Dataset and configuration#
This part introduces how to prepare a new dataset and its configuration file, and covers the basic usage of UserConfig and DataModule. After reading it, you will be able to run benchmarks on your own dataset.
The dataset#
We provide a randomly generated sample dataset (data/sample.csv) and its configuration file (configs/sample.py) in the repository. First, let's check the content of sample.csv. It contains 256 data points, 10 continuous features (cont_0 to cont_9), 10 categorical features (cat_0 to cat_9), and three target columns: target (the label used in this part), target_binary, and target_multi_class.
Remark: The dataset file should not contain an index column.
Remark: Both .csv and .xlsx are supported. We recommend .csv files for their efficiency.
Remark: If a categorical feature contains non-numerical values (bool, string, or mixed types), all of its values are converted to strings. For example, the number 3 and the string "3" in such a feature are treated as the same value (both are interpreted as the string "3").
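The string conversion described in the last remark can be illustrated in plain Python. This is only a sketch of the described behavior, not tabensemb's actual code: normalize_categorical is a hypothetical helper, and leaving purely numerical columns untouched is an assumption.

```python
def normalize_categorical(values):
    """Cast every value to str if any value is non-numerical (bool, str, ...)."""

    def is_numerical(v):
        # bool is a subclass of int in Python, so exclude it explicitly.
        return isinstance(v, (int, float)) and not isinstance(v, bool)

    if all(is_numerical(v) for v in values):
        # Assumption: purely numerical columns are left as they are.
        return list(values)
    return [str(v) for v in values]


# The number 3 and the string "3" collapse into the same category.
print(normalize_categorical([3, "3", "category_4"]))
# A purely numerical column is returned unchanged.
print(normalize_categorical([0, 1, 2]))
```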
[1]:
import pandas as pd
prefix = "../../../../"
pd.read_csv(prefix + "data/sample.csv")
[1]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | NaN | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | NaN | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | NaN | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | NaN | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | NaN | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
The configuration file#
A configuration file contains a dictionary of values that override a given default configuration.
Remark: The configuration file can be a .py file containing a dict object named cfg, or a .json file.
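As a rough illustration of the two formats, the sketch below serializes a small dictionary of overrides to JSON text and reads it back. It assumes that a .json configuration holds the same flat dictionary as the cfg object of a .py configuration; the exact layout tabensemb expects is not confirmed here, and the shortened feature lists are only for illustration.

```python
import json

# Overrides relative to the default configuration
# (the same kind of keys a `cfg` dict in a .py file would contain).
cfg = {
    "database": "sample",
    "continuous_feature_names": ["cont_0", "cont_1"],
    "categorical_feature_names": ["cat_0"],
    "label_name": ["target"],
}

# The text that would be stored in a .json configuration file.
json_text = json.dumps(cfg, indent=4)
print(json_text)

# Round trip: parsing the JSON text gives back the same dictionary.
restored = json.loads(json_text)
print(restored == cfg)
```

Note that JSON has no tuple type, so entries such as the (name, arguments) pairs in data_processors would be read back as lists.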
The default configuration#
To see the default values, use tabensemb.config.UserConfig, which inherits from dict.
[2]:
from tabensemb.config import UserConfig
from tabensemb.utils import pretty
import tabensemb
tabensemb.setting["default_config_path"] = prefix + "configs"
cfg = UserConfig("sample")
print(pretty(cfg.defaults()))
{
'database': 'sample',
'task': None,
'loss': None,
'bayes_opt': False,
'bayes_calls': 50,
'bayes_epoch': 30,
'patience': 100,
'epoch': 300,
'lr': 0.001,
'weight_decay': 1e-09,
'batch_size': 1024,
'layers': [
64,
128,
256,
128,
64
],
'SPACEs': {
'lr': {
'type': 'Real',
'low': 0.0001,
'high': 0.05,
'prior': 'log-uniform'
},
'weight_decay': {
'type': 'Real',
'low': 1e-09,
'high': 0.05,
'prior': 'log-uniform'
},
'batch_size': {
'type': 'Categorical',
'categories': [
64,
128,
256,
512,
1024,
2048
]
}
},
'data_splitter': 'RandomSplitter',
'split_ratio': [
0.6,
0.2,
0.2
],
'data_imputer': 'MissForestImputer',
'data_processors': [
(
'CategoricalOrdinalEncoder',
{
}
),
(
'NaNFeatureRemover',
{
}
),
(
'VarianceFeatureSelector',
{
'thres': 1
}
),
(
'StandardScaler',
{
}
)
],
'data_derivers': [
],
'categorical_feature_names': [
],
'continuous_feature_names': [
],
'feature_types': {
},
'unique_feature_types': [
],
'label_name': [
'target'
]
}
The configuration of the given sample dataset#
configs/sample.py contains the following contents:
cfg = {
"database": "sample",
"continuous_feature_names": ["cont_0", "cont_1", "cont_2", "cont_3", "cont_4"],
"categorical_feature_names": ["cat_0", "cat_1", "cat_2"],
"label_name": ["target"],
}
Load configs/sample.py and see the changes.
[3]:
cfg = UserConfig("sample")
print(pretty(cfg))
{
'database': 'sample',
'task': None,
'loss': None,
'bayes_opt': False,
'bayes_calls': 50,
'bayes_epoch': 30,
'patience': 100,
'epoch': 300,
'lr': 0.001,
'weight_decay': 1e-09,
'batch_size': 1024,
'layers': [
64,
128,
256,
128,
64
],
'SPACEs': {
'lr': {
'type': 'Real',
'low': 0.0001,
'high': 0.05,
'prior': 'log-uniform'
},
'weight_decay': {
'type': 'Real',
'low': 1e-09,
'high': 0.05,
'prior': 'log-uniform'
},
'batch_size': {
'type': 'Categorical',
'categories': [
64,
128,
256,
512,
1024,
2048
]
}
},
'data_splitter': 'RandomSplitter',
'split_ratio': [
0.6,
0.2,
0.2
],
'data_imputer': 'MissForestImputer',
'data_processors': [
(
'CategoricalOrdinalEncoder',
{
}
),
(
'NaNFeatureRemover',
{
}
),
(
'VarianceFeatureSelector',
{
'thres': 1
}
),
(
'StandardScaler',
{
}
)
],
'data_derivers': [
],
'categorical_feature_names': [
'cat_0',
'cat_1',
'cat_2'
],
'continuous_feature_names': [
'cont_0',
'cont_1',
'cont_2',
'cont_3',
'cont_4'
],
'feature_types': {
'cont_0': 'Continuous',
'cont_1': 'Continuous',
'cont_2': 'Continuous',
'cont_3': 'Continuous',
'cont_4': 'Continuous',
'cat_0': 'Categorical',
'cat_1': 'Categorical',
'cat_2': 'Categorical'
},
'unique_feature_types': [
'Categorical',
'Continuous'
],
'label_name': [
'target'
]
}
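Comparing the two printouts, only the keys set in configs/sample.py differ from the defaults, and feature_types and unique_feature_types are filled in automatically. The following minimal sketch mimics that behavior with plain dictionaries. It is not the UserConfig implementation: merge_config is a hypothetical helper, and the defaults shown are a shortened subset.

```python
import copy


def merge_config(defaults, overrides):
    """Overlay user overrides on a copy of the defaults and derive feature types."""
    merged = copy.deepcopy(defaults)  # never mutate the defaults
    merged.update(overrides)
    if not merged.get("feature_types"):
        # Label every feature by the list it was declared in.
        types = {name: "Continuous" for name in merged["continuous_feature_names"]}
        types.update(
            {name: "Categorical" for name in merged["categorical_feature_names"]}
        )
        merged["feature_types"] = types
    merged["unique_feature_types"] = sorted(set(merged["feature_types"].values()))
    return merged


defaults = {
    "database": "sample",
    "lr": 0.001,
    "continuous_feature_names": [],
    "categorical_feature_names": [],
    "feature_types": {},
    "unique_feature_types": [],
    "label_name": ["target"],
}
overrides = {
    "continuous_feature_names": ["cont_0", "cont_1"],
    "categorical_feature_names": ["cat_0"],
}

merged = merge_config(defaults, overrides)
print(merged["feature_types"])
print(merged["unique_feature_types"])
print(merged["lr"])  # untouched keys keep their default values
```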
Descriptions of keys in a configuration file#
- database: The name of the database file. The file should be placed in the script directory or in tabensemb.setting["default_data_path"]. If no file extension (.csv or .xlsx) is provided, the program automatically searches for a matching one. If both .csv and .xlsx exist, an exception will be raised.
- task: “regression” for regression tasks, “binary” for binary classification, and “multiclass” for multiclass classification. If left None, the task will be guessed from the type of the target. If the target is of the type object or integers, “binary” or “multiclass” is guessed depending on the number of unique targets; otherwise, “regression” is guessed.
- loss: “mse” (default) or “mae” for regression tasks, and “cross_entropy” for classification tasks. This loss will be used across all model bases. If left None, “mse” or “cross_entropy” will be used.
- bayes_opt: Perform Gaussian-process-based Bayesian hyperparameter optimization (HPO) using the scikit-optimize package when training each model.
- bayes_calls: The number of calls of the Bayesian HPO. During each call, the model will be trained given a set of hyperparameters, and then the metric on the validation set will be returned to the Bayesian HPO process.
- bayes_epoch: The number of epochs during each Bayesian HPO call.
- patience: Early stopping patience. If the metric on the validation set does not improve after patience epochs, the training process terminates and the best model is loaded.
- epoch: Total epochs to train each model.
- lr: Initial learning rate.
- weight_decay: Initial weight_decay (for a torch.optim.Adam optimizer).
- batch_size: Initial batch_size.
- layers: Default hidden layers for some models.
- SPACEs: Default Bayesian HPO spaces for lr, weight_decay, and batch_size. The key type determines the skopt.space class, and the rest of the keys determine its arguments.
- data_splitter: The dataset splitting method to split training/validation/testing sets. See tabensemb.data.datasplitter.splitter_mapping for available classes.
- split_ratio: The ratio of training/validation/testing sets.
- data_imputer: The imputation method for NaN values. See tabensemb.data.dataimputer.imputer_mapping for available classes.
- data_processors: A list of data processing steps and their corresponding arguments. See tabensemb.data.dataprocessor.processor_mapping for available classes. See API docs for definitions of arguments.
- data_derivers: A list of feature augmentation steps and their corresponding arguments. See tabensemb.data.dataderiver.deriver_mapping for available classes. See API docs or _required_cols of each class for its additional arguments. Some fixed arguments are:
  - stacked: True to append the derived feature to continuous features and the final DataFrame representing the processed dataset. False to leave it as an unstacked feature (mostly for multi-modal data).
  - intermediate: True to ignore the derived feature in continuous features even when stacked=True, but still append the feature to the DataFrame.
- continuous_feature_names: Continuous features. Each of them should be all floats or integers.
- categorical_feature_names: Categorical features. Each of them should be all integers or strings.
- feature_types: A dictionary stating categories of each feature defined in continuous_feature_names and categorical_feature_names. If it is not given in the configuration, “Continuous” and “Categorical” will be automatically used to assign the values of continuous and categorical features, respectively.
- unique_feature_types: Unique values in the dictionary feature_types.
- label_name: The predicted target.
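The rule used to guess task when it is left as None can be sketched as follows. This is an illustration of the rule as described above, not tabensemb's code: guess_task is a hypothetical helper, it works on plain Python values rather than DataFrame dtypes, and treating two or fewer unique values as “binary” is an assumption.

```python
def guess_task(targets):
    """Guess the task type from a list of target values."""

    def is_label(v):
        # Strings (object dtype) and integers are treated as class labels;
        # floats indicate a continuous target.
        return isinstance(v, (str, bool, int))

    if all(is_label(v) for v in targets):
        n_unique = len(set(targets))
        # Assumption: at most two classes means binary classification.
        return "binary" if n_unique <= 2 else "multiclass"
    return "regression"


# Values in the style of the three target columns of the sample dataset.
print(guess_task([-71.084217, 13.415675, -47.492280]))  # float target
print(guess_task([0, 1, 0, 1, 1]))                      # two integer classes
print(guess_task([1, 2, 2, 3, 0, 4]))                   # several integer classes
```

Setting task explicitly in the configuration avoids relying on this guess, for example when an integer-valued target should be treated as a regression target.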
Use the configuration file to load the dataset#
The DataModule requires a UserConfig to load the dataset, and then initializes and runs all data processing steps on the dataset or on an upcoming new dataset. The following lines load the dataset and show the loaded and processed DataFrame before imputation.
[4]:
from tabensemb.data import DataModule
tabensemb.setting["default_data_path"] = prefix + "data"
datamodule = DataModule(cfg)
datamodule.load_data()
datamodule.get_not_imputed_df()
Dataset size: 153 51 52
[4]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | NaN | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | NaN | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | NaN | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | NaN | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | NaN | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
DataModule.df presents the imputed DataFrame.
[5]:
datamodule.df
[5]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | 0.256888 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | 0.256888 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | 0.256888 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | 0.256888 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | -0.228308 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
DataModule.train_indices, DataModule.val_indices, and DataModule.test_indices hold the indices of the training, validation, and testing sets, respectively.
[6]:
datamodule.train_indices, datamodule.val_indices, datamodule.test_indices
[6]:
(array([158, 161, 186, 1, 208, 60, 202, 141, 13, 113, 216, 49, 75,
151, 159, 200, 190, 218, 32, 58, 28, 54, 19, 129, 24, 183,
61, 66, 253, 18, 5, 85, 43, 92, 0, 63, 185, 163, 244,
138, 20, 44, 96, 59, 236, 27, 94, 79, 100, 10, 213, 106,
176, 93, 62, 240, 98, 239, 247, 221, 204, 166, 99, 160, 181,
209, 150, 21, 14, 89, 201, 167, 145, 41, 12, 128, 101, 251,
124, 69, 205, 82, 241, 6, 191, 226, 117, 45, 22, 110, 35,
33, 74, 148, 105, 34, 77, 168, 90, 84, 179, 78, 2, 220,
155, 184, 47, 15, 140, 72, 195, 243, 232, 23, 39, 127, 71,
111, 144, 107, 211, 210, 173, 50, 254, 237, 194, 68, 162, 70,
197, 231, 123, 103, 228, 170, 136, 142, 80, 207, 116, 40, 171,
91, 135, 152, 248, 4, 125, 104, 83, 121, 177]),
array([165, 122, 187, 235, 238, 17, 87, 29, 42, 174, 178, 48, 169,
65, 46, 242, 224, 245, 130, 97, 215, 56, 164, 143, 57, 7,
175, 137, 252, 249, 203, 88, 3, 223, 16, 109, 154, 86, 227,
64, 38, 115, 149, 222, 95, 188, 102, 131, 219, 250, 126]),
array([ 9, 36, 153, 114, 120, 108, 246, 139, 118, 198, 76, 51, 212,
157, 196, 156, 172, 146, 189, 225, 233, 206, 37, 52, 199, 182,
229, 30, 192, 180, 234, 67, 11, 112, 8, 134, 147, 132, 25,
81, 119, 214, 53, 230, 31, 26, 73, 55, 217, 255, 133, 193]))
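The three arrays above contain 153, 51, and 52 indices, matching the printed "Dataset size: 153 51 52". These sizes are consistent with applying split_ratio [0.6, 0.2, 0.2] to 256 rows, rounding the first two parts down and giving the remainder to the testing set. The sketch below reproduces the sizes under that assumption. It is not the RandomSplitter implementation: random_split and its seed argument are hypothetical, and the shuffled indices will differ from the ones shown.

```python
import math
import random


def random_split(n_samples, split_ratio, seed=0):
    """Shuffle indices and cut them into training/validation/testing parts."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Assumed rounding scheme: floor the first two parts.
    n_train = math.floor(n_samples * split_ratio[0])
    n_val = math.floor(n_samples * split_ratio[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to the testing set
    return train, val, test


train, val, test = random_split(256, [0.6, 0.2, 0.2])
print(len(train), len(val), len(test))

# Every row belongs to exactly one of the three sets.
print(sorted(train + val + test) == list(range(256)))
```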
For detailed functionalities of DataModule, please check the API documentation.
A Trainer does all things for you#
In practice, a user does not need to create a UserConfig or a DataModule manually, because Trainer performs all the steps above. After calling Trainer.load_config and Trainer.load_data, a UserConfig instance containing the configuration and a DataModule instance containing the processing steps and the loaded data are created. They can be accessed through Trainer.args and Trainer.datamodule, respectively.
[7]:
from tabensemb.trainer import Trainer
tabensemb.setting["default_output_path"] = prefix + "output"
trainer = Trainer(device="cpu")
trainer.load_config("sample")
trainer.load_data()
type(trainer.args), type(trainer.datamodule)
The project will be saved to ../../../../output/sample/2023-09-23-20-37-10-0_sample
Dataset size: 153 51 52
Data saved to ../../../../output/sample/2023-09-23-20-37-10-0_sample (data.csv and tabular_data.csv).
[7]:
(tabensemb.config.user_config.UserConfig, tabensemb.data.datamodule.DataModule)