tabensemb.model.CatEmbed

class tabensemb.model.CatEmbed(*args, lightning_trainer_kwargs: Dict | None = None, **kwargs)

Bases: TorchModel

Methods

__init__(*args, lightning_trainer_kwargs: Dict | None = None, **kwargs)
Parameters:
trainer:

A Trainer instance that holds all information and datasets and will be linked to the model base. The trainer must already have loaded its configs and data.

program:

The name of the model base. If None, the name from _get_program_name() is used.

model_subset:

The names of models selected to be trained in the model base.

exclude_models:

The names of models that should not be trained. Only one of model_subset and exclude_models can be specified.
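The mutual exclusion between model_subset and exclude_models can be sketched as a small selection function. This only illustrates the documented rule and is not tabensemb's implementation; select_models and the model names "A", "B", "C" are hypothetical.

```python
from typing import List, Optional, Sequence


def select_models(
    available: Sequence[str],
    model_subset: Optional[Sequence[str]] = None,
    exclude_models: Optional[Sequence[str]] = None,
) -> List[str]:
    """Return the model names to train (hypothetical helper, not tabensemb API)."""
    if model_subset is not None and exclude_models is not None:
        # The two filters are mutually exclusive.
        raise ValueError("Only one of model_subset and exclude_models can be specified.")
    if model_subset is not None:
        # Keep only the requested models, in the order they are available.
        return [name for name in available if name in model_subset]
    if exclude_models is not None:
        # Drop the excluded models and keep everything else.
        return [name for name in available if name not in exclude_models]
    # No filter given: every available model is trained.
    return list(available)
```

For example, select_models(["A", "B", "C"], exclude_models=["B"]) keeps "A" and "C", while passing both filters raises a ValueError.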

store_in_harddisk:

Whether to save models on the hard disk. If the global setting tabensemb.setting["low_memory"] is True, this parameter is forced to True.
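A minimal sketch of the documented override, assuming tabensemb.setting behaves like a dictionary. resolve_store_in_harddisk is a hypothetical helper, and its setting argument stands in for tabensemb.setting.

```python
from typing import Mapping


def resolve_store_in_harddisk(store_in_harddisk: bool, setting: Mapping[str, object]) -> bool:
    """Effective value of store_in_harddisk (hypothetical helper, not tabensemb API)."""
    # The global low-memory switch forces models to be kept on disk.
    if setting.get("low_memory", False):
        return True
    # Otherwise the value passed by the user is used unchanged.
    return store_in_harddisk
```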

optimizers:

A dictionary of optimizer names (chosen from those in torch.optim) and their hyperparameters for each model. To optimize these hyperparameters, _initial_values() and _space() must be changed accordingly.

lr_schedulers:

A dictionary of lr scheduler names (chosen from those in torch.optim.lr_scheduler) and their hyperparameters for each model. To optimize these hyperparameters, _initial_values() and _space() must be changed accordingly.
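Because optimizers and lr schedulers are referred to by name, a model base has to turn each name into a class from torch.optim or torch.optim.lr_scheduler. The sketch below shows that name lookup with a hypothetical helper, build_from_name; it does not reproduce the per-model dictionary layout tabensemb expects, which is not documented here. So that it runs without PyTorch, it accepts any module, and the checks can use the standard-library collections module in place of torch.optim.

```python
import types
from typing import Any


def build_from_name(module: types.ModuleType, name: str, *args: Any, **hyperparameters: Any) -> Any:
    """Look up the class ``name`` in ``module`` and instantiate it (hypothetical helper).

    In practice ``module`` would be torch.optim or torch.optim.lr_scheduler.
    """
    cls = getattr(module, name, None)
    # Reject unknown names and attributes that are not classes (e.g. plain functions).
    if cls is None or not isinstance(cls, type):
        raise ValueError(f"{name!r} is not a class in {module.__name__}")
    # Positional args (e.g. model parameters or an optimizer) come first,
    # followed by the hyperparameters stored in the configuration.
    return cls(*args, **hyperparameters)
```

With PyTorch installed, the same helper would be called as build_from_name(torch.optim, "Adam", model.parameters(), lr=1e-3) for an optimizer and build_from_name(torch.optim.lr_scheduler, "StepLR", optimizer, step_size=10) for a scheduler.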

**kwargs:

Ignored.

required_models(model_name)

The names of models required by the requested model.
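Assuming that a required model must be available before the model that needs it is trained, the names returned by required_models() define a dependency graph. The sketch below orders such a graph with the standard-library graphlib.TopologicalSorter (Python 3.9+); training_order and the requirements mapping are hypothetical and do not come from tabensemb.

```python
from graphlib import TopologicalSorter
from typing import Dict, Iterable, List, Set


def training_order(requirements: Dict[str, Iterable[str]], requested: Iterable[str]) -> List[str]:
    """Order models so that every required model precedes the models that need it.

    ``requirements`` maps a model name to the names a required_models()-style
    lookup would return (hypothetical data).
    """
    graph: Dict[str, Set[str]] = {}
    stack = list(requested)
    # Collect the requested models plus everything they transitively require.
    while stack:
        name = stack.pop()
        if name in graph:
            continue
        deps = set(requirements.get(name, ()))
        graph[name] = deps
        stack.extend(deps)
    # static_order() yields predecessors first and raises CycleError on cycles.
    return list(TopologicalSorter(graph).static_order())
```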

_conditional_validity(model_name)

Check the validity of a model.

_get_model_names()

Get the names of all available models implemented in the model base.
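The toy class below suggests how _get_model_names() and _conditional_validity() might interact: every implemented name is listed, and the validity check filters out models that cannot run on the current data. ToyModelBase, available_models(), the model names, and the validity rules are all hypothetical and do not describe TorchModel's actual behavior.

```python
from typing import List


class ToyModelBase:
    """Stand-in for a model base (hypothetical, not tabensemb's TorchModel)."""

    def __init__(self, n_categorical: int, n_continuous: int):
        self.n_categorical = n_categorical
        self.n_continuous = n_continuous

    def _get_model_names(self) -> List[str]:
        # Every model implemented in this toy base.
        return ["embed_only", "embed_and_mlp"]

    def _conditional_validity(self, model_name: str) -> bool:
        # Hypothetical rules: all models need categorical features to embed;
        # "embed_and_mlp" also needs continuous features for its MLP branch.
        if self.n_categorical == 0:
            return False
        if model_name == "embed_and_mlp":
            return self.n_continuous > 0
        return True

    def available_models(self) -> List[str]:
        # Keep only the implemented models that are valid for this data.
        return [name for name in self._get_model_names() if self._conditional_validity(name)]
```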

_get_program_name()

Get the default name of the model base.

_initial_values(model_name)

Initial values of hyperparameters to be optimized.

_new_model(model_name, verbose, **kwargs)

Generate a new instance of the selected model based on kwargs.

_prepare_custom_datamodule(model_name[, ...])

Override this method if a customized preprocessing stage is needed.

_run_custom_data_module(df, derived_data, ...)

Override this method if a customized preprocessing stage is implemented in _prepare_custom_datamodule().
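The split between _prepare_custom_datamodule(), where the preprocessing stage is set up, and _run_custom_data_module(), which applies it to a DataFrame, appears to resemble a fit/apply pair. As a hedged illustration relevant to categorical embeddings, the sketch below fits an integer vocabulary per categorical column and then applies it, reserving index 0 for unseen categories. The function names are hypothetical, lists of dicts stand in for pandas DataFrames, and the real methods take further arguments that are not listed here.

```python
from typing import Dict, List, Sequence

Row = Dict[str, object]
Vocab = Dict[str, Dict[object, int]]


def prepare_category_vocab(train_rows: Sequence[Row], columns: Sequence[str]) -> Vocab:
    """Fit step: give each category seen in the training rows an integer index.

    Index 0 is reserved for categories that were never seen while fitting.
    """
    vocab: Vocab = {}
    for col in columns:
        mapping: Dict[object, int] = {}
        for row in train_rows:
            value = row[col]
            if value not in mapping:
                mapping[value] = len(mapping) + 1
        vocab[col] = mapping
    return vocab


def run_category_encoding(rows: Sequence[Row], vocab: Vocab) -> List[Row]:
    """Apply step: replace categories with indices; unseen values become 0."""
    encoded = []
    for row in rows:
        new_row = dict(row)  # copy so the input rows are not modified
        for col, mapping in vocab.items():
            new_row[col] = mapping.get(row[col], 0)
        encoded.append(new_row)
    return encoded
```

Columns not listed in the vocabulary, such as continuous features, pass through unchanged.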

_space(model_name)

A list of scikit-optimize search spaces for the selected model.
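_initial_values() and _space() presumably describe the same hyperparameters: one gives starting points, the other the bounds searched by scikit-optimize. A quick consistency check can catch a hyperparameter added to one but not the other, or an initial value outside its bounds. In real code the dimensions would be skopt.space objects such as Real or Integer; the Dim named tuple and check_initial_values below are hypothetical stand-ins that keep the sketch dependency-free and only cover numeric dimensions.

```python
from typing import Dict, List, NamedTuple, Union

Number = Union[int, float]


class Dim(NamedTuple):
    """Stand-in for a numeric scikit-optimize dimension such as skopt.space.Real."""

    name: str
    low: Number
    high: Number


def check_initial_values(space: List[Dim], initial: Dict[str, Number]) -> List[str]:
    """Return a list of problems; an empty list means both definitions agree."""
    problems = []
    space_names = {dim.name for dim in space}
    # Initial values that have no matching search dimension.
    for name in sorted(set(initial) - space_names):
        problems.append(f"{name}: has an initial value but no search dimension")
    # Dimensions without an initial value, or with one outside the bounds.
    for dim in space:
        if dim.name not in initial:
            problems.append(f"{dim.name}: missing initial value")
        elif not dim.low <= initial[dim.name] <= dim.high:
            problems.append(
                f"{dim.name}: initial value {initial[dim.name]} outside [{dim.low}, {dim.high}]"
            )
    return problems
```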