tabensemb.model.CatEmbed
- class tabensemb.model.CatEmbed(*args, lightning_trainer_kwargs: Dict | None = None, **kwargs)
Bases: TorchModelMethods
- __init__(*args, lightning_trainer_kwargs: Dict | None = None, **kwargs)
- Parameters:
  - trainer: A Trainer instance that contains all information and datasets and will be linked to the model base. The trainer has already loaded the configs and data.
  - program: The name of the model base. If None, the name from _get_program_name() is used.
  - model_subset: The names of models selected to be trained in the model base.
  - exclude_models: The names of models that should not be trained. Only one of model_subset and exclude_models can be specified.
  - store_in_harddisk: Whether to save models on the hard disk. If the global setting tabensemb.setting["low_memory"] is True, True is used regardless of this argument.
  - optimizers: A dictionary of optimizer names (chosen from those in torch.optim) and their hyperparameters for each model. To optimize these hyperparameters, remember to change _initial_values() and _space() accordingly.
  - lr_schedulers: A dictionary of learning-rate scheduler names (chosen from those in torch.optim.lr_scheduler) and their hyperparameters for each model. To optimize these hyperparameters, remember to change _initial_values() and _space() accordingly.
  - **kwargs: Ignored.
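The rules stated for model_subset, exclude_models, and store_in_harddisk can be illustrated with a small plain-Python sketch. It is hypothetical and is not tabensemb's implementation: resolve_models and resolve_store_in_harddisk are invented names, the dict passed as setting stands in for tabensemb.setting, and the model names are placeholders.

```python
from typing import Dict, List, Optional, Sequence


def resolve_models(
    available: Sequence[str],
    model_subset: Optional[Sequence[str]] = None,
    exclude_models: Optional[Sequence[str]] = None,
) -> List[str]:
    """Hypothetical helper: choose which models a model base would train."""
    # Documented rule: only one of the two filters may be given.
    if model_subset is not None and exclude_models is not None:
        raise ValueError("Specify only one of model_subset and exclude_models.")
    if model_subset is not None:
        # Train exactly the requested models.
        return list(model_subset)
    if exclude_models is not None:
        # Train every available model except the excluded ones.
        return [m for m in available if m not in exclude_models]
    # No filter given: train every model the base implements.
    return list(available)


def resolve_store_in_harddisk(store_in_harddisk: bool, setting: Dict[str, object]) -> bool:
    """Hypothetical helper: the global low_memory flag forces on-disk storage."""
    if setting.get("low_memory", False):
        return True
    return store_in_harddisk


if __name__ == "__main__":
    available = ["model_a", "model_b", "model_c"]  # placeholder names
    print(resolve_models(available, exclude_models=["model_c"]))  # ['model_a', 'model_b']
    print(resolve_store_in_harddisk(False, {"low_memory": True}))  # True
```

Raising an error when both filters are given is one way to enforce the "only one of" rule; the exact exception type tabensemb uses is not stated here and is an assumption.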
- Methods:
  - required_models(model_name): The names of models required by the requested model.
  - _conditional_validity(model_name): Check the validity of a model.
  - Get names of all available models implemented in the model base.
  - Get the default name of the model base.
  - _initial_values(model_name): Initial values of hyperparameters to be optimized.
  - _new_model(model_name, verbose, **kwargs): Generate a new selected model based on kwargs.
  - _prepare_custom_datamodule(model_name[, ...]): Change this method if a customized preprocessing stage is needed.
  - _run_custom_data_module(df, derived_data, ...): Change this method if a customized preprocessing stage is implemented in _prepare_custom_datamodule().
  - _space(model_name): A list of scikit-optimize search spaces for the selected model.
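The hooks listed above divide the work of a model base: declaring dependencies, checking validity, supplying initial hyperparameters and their search space, and building a model. The following framework-free sketch shows one way these pieces could fit together. Everything in it is hypothetical: ToyModelBase, Dim, available_models, and build_all are invented for illustration; the real _space() returns scikit-optimize dimensions rather than Dim, and the real _new_model() builds a torch model rather than a dict. Checking that the names in _space() match the keys of _initial_values() is an assumption based on the advice to update both methods together.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Set


@dataclass
class Dim:
    """Stand-in for a scikit-optimize dimension such as skopt.space.Real."""
    name: str
    low: float
    high: float


class ToyModelBase:
    """Hypothetical model base that mimics the hook names listed above."""

    def available_models(self) -> List[str]:
        return ["model_a", "model_b"]

    def required_models(self, model_name: str) -> List[str]:
        # model_b depends on model_a being built first.
        return ["model_a"] if model_name == "model_b" else []

    def _conditional_validity(self, model_name: str) -> bool:
        # Every toy model is treated as valid for the loaded data.
        return True

    def _initial_values(self, model_name: str) -> Dict[str, float]:
        return {"lr": 1e-3, "weight_decay": 1e-9}

    def _space(self, model_name: str) -> List[Dim]:
        return [Dim("lr", 1e-4, 1e-2), Dim("weight_decay", 1e-10, 1e-2)]

    def _new_model(self, model_name: str, verbose: bool, **kwargs: Any) -> Dict[str, Any]:
        # A real model base would construct a torch/Lightning model here.
        if verbose:
            print(f"Building {model_name} with {kwargs}")
        return {"name": model_name, **kwargs}


def build_all(base: ToyModelBase) -> List[Dict[str, Any]]:
    """Build valid models in order, checking that required models were built earlier."""
    built: List[Dict[str, Any]] = []
    done: Set[str] = set()
    for name in base.available_models():
        if not base._conditional_validity(name):
            continue
        missing = [r for r in base.required_models(name) if r not in done]
        if missing:
            raise RuntimeError(f"{name} requires {missing} to be built first")
        init = base._initial_values(name)
        # The search space and the initial values should cover the same hyperparameters.
        if {d.name for d in base._space(name)} != set(init):
            raise ValueError(f"_space() and _initial_values() disagree for {name}")
        built.append(base._new_model(name, verbose=False, **init))
        done.add(name)
    return built


if __name__ == "__main__":
    models = build_all(ToyModelBase())
    print([m["name"] for m in models])  # ['model_a', 'model_b']
```

The sketch assumes available_models() already lists dependencies before the models that need them and raises otherwise; a real model base might instead sort models topologically using required_models().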