New data derivers
The package currently provides only a small number of derivers. A deriver can calculate new features (continuous or categorical) from existing features, or load images, text, etc. as multimodal data. The source code of the integrated tabensemb.data.dataderiver.RelativeDeriver is extended here to demonstrate the implementation procedure.
[1]:
from tabensemb.data.dataderiver import AbstractDeriver
Data derivers inherit from tabensemb.data.AbstractDeriver, and four methods should be implemented:
_required_cols: Arguments for columns that must exist in the tabular dataset. The following code means that the arguments absolute_col and relative2_col should be given in the configuration, such as "data_derivers": [("MyRelativeDeriver", {"absolute_col": "cont_0", "relative2_col": "cont_1"})]
class MyRelativeDeriver(AbstractDeriver):
    def _required_cols(self):
        return ["absolute_col", "relative2_col"]
_required_kwargs: Parameters that must be specified in the configuration. The following code means that the parameter some_param should be given in the configuration, such as "data_derivers": [("MyRelativeDeriver", {"some_param": 1.5})]
    def _required_kwargs(self):
        return ["some_param"]
Remark: “stacked”, “intermediate”, “derived_name”, and “is_continuous” are shared necessary kwargs and do not need to be added to _required_kwargs.
_defaults: Default values for the arguments in _required_cols, _required_kwargs, and ["stacked", "intermediate", "derived_name", "is_continuous"]. If a default value is given, no error is raised when the argument is not set in the configuration.
    def _defaults(self):
        return dict(stacked=True, intermediate=False, is_continuous=True)
_derive: The main derivation step. It receives the tabular data (a DataFrame) and a DataModule and should return an np.ndarray. The returned array cannot be 1-D, so a single derived column is reshaped to shape (n_samples, 1). Arguments are checked and recorded in self.kwargs when the deriver is initialized.
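    def _derive(self, df, datamodule):
        # Column names and parameters come from the checked configuration.
        absolute_col = self.kwargs["absolute_col"]
        relative2_col = self.kwargs["relative2_col"]
        # some_param and stacked are read to show access; this example does not use them.
        some_param = self.kwargs["some_param"]
        stacked = self.kwargs["stacked"]
        relative = df[absolute_col] / df[relative2_col]
        # Return a 2-D column vector, not a 1-D array.
        relative = relative.values.reshape(-1, 1)
        return relative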
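The way required arguments, shared kwargs, and defaults interact can be pictured with a small stand-alone sketch. The helper `resolve_kwargs` below is hypothetical and written only for illustration; it is not tabensemb's implementation, and the actual checks performed by AbstractDeriver may differ in detail.

```python
def resolve_kwargs(required_cols, required_kwargs, defaults, user_kwargs):
    """Merge user arguments with defaults and check that nothing is missing.

    Hypothetical helper for illustration only; not part of tabensemb.
    """
    # Shared kwargs named in the remark above.
    shared = ["stacked", "intermediate", "derived_name", "is_continuous"]
    # Values from the configuration override the defaults.
    resolved = {**defaults, **user_kwargs}
    missing = [
        name
        for name in [*required_cols, *required_kwargs, *shared]
        if name not in resolved
    ]
    if missing:
        raise ValueError(f"Missing deriver arguments: {missing}")
    return resolved


defaults = dict(stacked=True, intermediate=False, is_continuous=True)
user_kwargs = {
    "absolute_col": "cont_0",
    "relative2_col": "cont_1",
    "some_param": 1.5,
    "derived_name": "cont_0_relative2_cont_1",
}
kwargs = resolve_kwargs(
    ["absolute_col", "relative2_col"], ["some_param"], defaults, user_kwargs
)
print(kwargs["stacked"])  # True, taken from the defaults
```

In this sketch, leaving out some_param raises an error because no default covers it, while stacked can be omitted because a default is given.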
[2]:
class MyRelativeDeriver(AbstractDeriver):
    def _required_cols(self):
        return ["absolute_col", "relative2_col"]

    def _required_kwargs(self):
        return ["some_param"]

    def _defaults(self):
        return dict(stacked=True, intermediate=False, is_continuous=True)

    def _derive(self, df, datamodule):
        absolute_col = self.kwargs["absolute_col"]
        relative2_col = self.kwargs["relative2_col"]
        # some_param and stacked are read to show access; this example does not use them.
        some_param = self.kwargs["some_param"]
        stacked = self.kwargs["stacked"]
        relative = df[absolute_col] / df[relative2_col]
        # Return a 2-D column vector, not a 1-D array.
        relative = relative.values.reshape(-1, 1)
        return relative
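The reshape(-1, 1) call in _derive is what turns the 1-D result of the division into the required 2-D array. A minimal NumPy sketch of the shapes involved, using made-up values in place of the DataFrame columns:

```python
import numpy as np

# Made-up stand-ins for df[absolute_col].values and df[relative2_col].values.
absolute = np.array([1.0, 2.0, 3.0])
relative2 = np.array([2.0, 4.0, 6.0])

ratio = absolute / relative2       # one-dimensional, shape (3,)
derived = ratio.reshape(-1, 1)     # two-dimensional, one column and one row per sample

print(ratio.shape, derived.shape)  # (3,) (3, 1)
```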
The implemented deriver should be registered as follows so that DataModule.set_data_derivers recognizes it automatically.
[3]:
from tabensemb.data.dataderiver import deriver_mapping
deriver_mapping["MyRelativeDeriver"] = MyRelativeDeriver
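Registration follows a common name-to-class registry pattern. The sketch below illustrates the pattern with stand-in names (`registry`, `DemoDeriver`); how tabensemb itself looks up and instantiates derivers is an assumption here, not something shown in this tutorial.

```python
# Hypothetical stand-ins; tabensemb's deriver_mapping and classes are not used here.
registry = {}


class DemoDeriver:
    def __init__(self, **kwargs):
        # Keep the configured arguments, similar in spirit to self.kwargs.
        self.kwargs = kwargs


# Same assignment pattern as deriver_mapping["MyRelativeDeriver"] = MyRelativeDeriver.
registry["DemoDeriver"] = DemoDeriver

# A configuration entry is a (name, kwargs) pair, as in "data_derivers".
config = [("DemoDeriver", {"derived_name": "demo"})]
derivers = [registry[name](**kwargs) for name, kwargs in config]
print(type(derivers[0]).__name__)  # DemoDeriver
```

With such a registry, the configuration only needs the string name; an unregistered name would fail at lookup.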
[4]:
from tabensemb.trainer import Trainer
import tabensemb
prefix = "../../../../"
tabensemb.setting["default_output_path"] = prefix + "output"
tabensemb.setting["default_config_path"] = prefix + "configs"
tabensemb.setting["default_data_path"] = prefix + "data"
trainer = Trainer(device="cpu")
trainer.load_config("sample")
The project will be saved to ../../../../output/sample/2023-09-18-18-15-00-0_sample
If stacked is True:
[5]:
trainer.datamodule.set_data_derivers([("MyRelativeDeriver", {"absolute_col": "cont_0", "relative2_col": "cont_1", "derived_name": "cont_0_relative2_cont_1", "some_param": 1.0, "stacked": True})])
trainer.load_data()
print(f"cont_0_relative2_cont_1 in continuous features?: {'cont_0_relative2_cont_1' in trainer.cont_feature_names}")
trainer.df
Dataset size: 153 51 52
Data saved to ../../../../output/sample/2023-09-18-18-15-00-0_sample (data.csv and tabular_data.csv).
cont_0_relative2_cont_1 in continuous features?: True
[5]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class | cont_0_relative2_cont_1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | 0.065895 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 | -19.827301 |
| 1 | 2.011257 | 0.117717 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 | 17.085552 |
| 2 | -1.216077 | 0.065895 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 | -18.454666 |
| 3 | 0.559299 | 0.117717 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 | 4.751225 |
| 4 | 0.910179 | -0.213096 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 | -4.271217 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 | -1.355422 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 | 1.088160 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 | 0.374183 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 | 1.199032 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 | -1.521539 |
256 rows × 24 columns
If stacked is True and intermediate is also True:
[6]:
trainer.datamodule.set_data_derivers([("MyRelativeDeriver", {"absolute_col": "cont_0", "relative2_col": "cont_1", "derived_name": "cont_0_relative2_cont_1", "some_param": 1.0, "stacked": True, "intermediate": True})])
trainer.load_data()
print(f"cont_0_relative2_cont_1 in continuous features?: {'cont_0_relative2_cont_1' in trainer.cont_feature_names}")
trainer.df
Using previously used data path ../../../../data/sample.csv
Dataset size: 153 51 52
Data saved to ../../../../output/sample/2023-09-18-18-15-00-0_sample (data.csv and tabular_data.csv).
cont_0_relative2_cont_1 in continuous features?: False
[6]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class | cont_0_relative2_cont_1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | -0.409756 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 | 3.188552 |
| 1 | 2.011257 | -0.409756 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 | -4.908431 |
| 2 | -1.216077 | 0.104704 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 | -11.614467 |
| 3 | 0.559299 | 0.104704 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 | 5.341736 |
| 4 | 0.910179 | -0.409756 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 | -2.221273 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 | -1.355422 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 | 1.088160 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 | 0.374183 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 | 1.199032 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 | -1.521539 |
256 rows × 24 columns
If stacked is False:
[7]:
trainer.datamodule.set_data_derivers([("MyRelativeDeriver", {"absolute_col": "cont_0", "relative2_col": "cont_1", "derived_name": "cont_0_relative2_cont_1", "some_param": 1.0, "stacked": False})])
trainer.load_data()
print(f"cont_0_relative2_cont_1 in continuous features?: {'cont_0_relative2_cont_1' in trainer.cont_feature_names}")
trainer.df
Using previously used data path ../../../../data/sample.csv
Dataset size: 153 51 52
Data saved to ../../../../output/sample/2023-09-18-18-15-00-0_sample (data.csv and tabular_data.csv).
cont_0_relative2_cont_1 in continuous features?: False
[7]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | 0.138315 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | -0.006111 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | 0.138315 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | -0.006111 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | -0.006111 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
[8]:
trainer.derived_data.keys()
[8]:
dict_keys(['cont_0_relative2_cont_1', 'categorical'])
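The three runs above can be summarized in a small lookup table. The values below are copied from the printed outputs only; None marks cases the outputs do not show, and nothing here calls tabensemb.

```python
name = "cont_0_relative2_cont_1"

# (stacked, intermediate) -> where the derived feature was observed in the runs above.
observed = {
    (True, False): {"in_df": True, "in_cont_feature_names": True, "in_derived_data": None},
    (True, True): {"in_df": True, "in_cont_feature_names": False, "in_derived_data": None},
    (False, False): {"in_df": False, "in_cont_feature_names": False, "in_derived_data": True},
}

for (stacked, intermediate), result in observed.items():
    print(f"stacked={stacked}, intermediate={intermediate}: {name} -> {result}")
```

Read this way, stacked controls whether the derived column is appended to the tabular data, intermediate keeps a stacked column out of the continuous feature names, and an unstacked feature was found under its derived_name in trainer.derived_data.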