tabensemb.model.RFE._generate_dataset_for_required_models

method

RFE._generate_dataset_for_required_models(df, derived_data, tensors, required_models)

Call AbstractModel._data_preprocess() to generate the dataset, outputs, and hidden representations for the required models.

Parameters:
df

The new tabular dataset, which has the same structure as self.trainer.datamodule.X_test.

derived_data

Unstacked data derived from tabensemb.data.datamodule.DataModule.derive_unstacked().

tensors

Tensors stored in a tabensemb.data.datamodule.DataModule and obtained by tabensemb.data.datamodule.DataModule.update_dataset().

required_models

Required models specified in AbstractModel.required_models() and extracted by AbstractModel._get_required_models().

Returns:
torch.utils.data.Dataset
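A minimal, self-contained sketch of the pattern described above, under simplifying assumptions: every required model preprocesses the same input, and the per-model results are bundled with the raw rows into one map-style dataset. All names here (StubRequiredModel, data_preprocess, DictDataset, generate_dataset_for_required_models) are hypothetical illustrations, not the tabensemb API. The real method takes a DataFrame together with derived_data and tensors and returns a torch.utils.data.Dataset. To stay dependency-free, the sketch implements the same __len__/__getitem__ protocol instead of subclassing torch's class.

```python
from typing import Dict, List, Tuple

Row = List[float]


class StubRequiredModel:
    """Hypothetical stand-in for a required model (NOT the tabensemb API)."""

    def __init__(self, scale: float):
        self.scale = scale

    def data_preprocess(self, rows: List[Row]) -> List[Row]:
        # Hypothetical analogue of AbstractModel._data_preprocess():
        # produce one "hidden representation" per input row.
        return [[self.scale * x for x in row] for row in rows]


class DictDataset:
    """Map-style dataset following the torch.utils.data.Dataset protocol
    (__len__ and __getitem__), without importing torch."""

    def __init__(self, rows: List[Row], per_model: Dict[str, List[Row]]):
        self.rows = rows
        self.per_model = per_model

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Tuple[Row, Dict[str, Row]]:
        # Pair the raw row with each required model's representation of it.
        return self.rows[idx], {name: reps[idx] for name, reps in self.per_model.items()}


def generate_dataset_for_required_models(
    rows: List[Row], required_models: Dict[str, StubRequiredModel]
) -> DictDataset:
    # Run every required model's preprocessing on the same input rows.
    per_model = {name: model.data_preprocess(rows) for name, model in required_models.items()}
    return DictDataset(rows, per_model)


# Usage with two hypothetical required models.
rows = [[1.0, 2.0], [3.0, 4.0]]
models = {"A": StubRequiredModel(2.0), "B": StubRequiredModel(-1.0)}
ds = generate_dataset_for_required_models(rows, models)
print(len(ds), ds[1])
```

Keeping the raw row and the per-model representations together in each item lets a downstream model consume its own input alongside the outputs of the models it depends on, which is what a single combined dataset is useful for.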