tabensemb.model.CatEmbed._generate_dataset_for_required_models
method
- CatEmbed._generate_dataset_for_required_models(df, derived_data, tensors, required_models)
Call AbstractModel._data_preprocess() to generate the dataset, outputs, and hidden representations for the required models.
- Parameters:
- df
The new tabular dataset, with the same structure as self.trainer.datamodule.X_test.
- derived_data
Unstacked data derived from tabensemb.data.datamodule.DataModule.derive_unstacked().
- tensors
Tensors stored in a tabensemb.data.datamodule.DataModule and obtained by tabensemb.data.datamodule.DataModule.update_dataset().
- required_models
Required models specified in AbstractModel.required_models() and extracted by AbstractModel._get_required_models().
- Returns:
- torch.utils.data.Dataset
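This page does not show how the returned torch.utils.data.Dataset is assembled. The following is a rough, hypothetical sketch of the general pattern (not tabensemb's implementation): the input tensors and derived data are bundled with per-sample outputs of the required models into one map-style dataset. The names RequiredModelDataset and generate_dataset_for_required_models are illustrative only. The sketch also makes several assumptions: tensors[0] holds the feature rows, and each required model is a plain callable applied row by row. It omits df and the hidden representations, which could be stored the same way as the outputs.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

try:
    # Use PyTorch's base class when it is installed.
    from torch.utils.data import Dataset
except ImportError:
    # Fall back to a plain base class so this sketch also runs without PyTorch.
    Dataset = object


class RequiredModelDataset(Dataset):
    """Hypothetical map-style dataset pairing base tensors with extra per-sample data."""

    def __init__(
        self,
        tensors: List[Sequence[Any]],
        derived_data: Dict[str, Sequence[Any]],
        model_outputs: Dict[str, Sequence[Any]],
    ):
        n = len(tensors[0])
        columns = list(tensors) + list(derived_data.values()) + list(model_outputs.values())
        # Every input must describe the same samples.
        if any(len(col) != n for col in columns):
            raise ValueError("all inputs must have the same number of samples")
        self.tensors = tensors
        self.derived_data = derived_data
        self.model_outputs = model_outputs
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> Tuple[tuple, dict, dict]:
        # Return (base tensors, derived data, required-model outputs) for one sample.
        base = tuple(t[idx] for t in self.tensors)
        derived = {k: v[idx] for k, v in self.derived_data.items()}
        outputs = {k: v[idx] for k, v in self.model_outputs.items()}
        return base, derived, outputs


def generate_dataset_for_required_models(
    tensors: List[Sequence[Any]],
    derived_data: Dict[str, Sequence[Any]],
    required_models: Dict[str, Callable[[Any], Any]],
) -> RequiredModelDataset:
    # Hypothetical stand-in: treat tensors[0] as the feature rows, run every
    # required model on each row, and store its output next to the inputs.
    features = tensors[0]
    model_outputs = {
        name: [model(row) for row in features] for name, model in required_models.items()
    }
    return RequiredModelDataset(tensors, derived_data, model_outputs)


# Usage with toy data: two samples, one derived column, one "required model".
features = [[1.0, 2.0], [3.0, 4.0]]
targets = [0, 1]
ds = generate_dataset_for_required_models(
    [features, targets],
    {"cat_code": [5, 7]},
    {"row_sum": lambda row: sum(row)},
)
print(len(ds))  # 2
print(ds[1])    # (([3.0, 4.0], 1), {'cat_code': 7}, {'row_sum': 7.0})
```

When PyTorch is installed, the class subclasses torch.utils.data.Dataset, so it can be indexed and measured like any other map-style dataset.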