New data imputers
Imputation is necessary when invalid values are encountered in the tabular dataset. The package provides several imputers. An arbitrary imputation class should inherit AbstractImputer. If the imputation class follows the structure of sklearn.impute._base._BaseImputer (or has fit_transform and transform methods), inheriting AbstractSklearnImputer is much easier.
[1]:
from tabensemb.data import AbstractImputer, AbstractSklearnImputer, DataModule
import numpy as np
import pandas as pd
import sklearn.exceptions
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor
import warnings
Inherit AbstractImputer
Take tabensemb.data.dataimputer.MiceLightgbmImputer as an example. _defaults provides a set of default parameters for the imputation. These parameters can be changed by specifying them in the configuration, such as "data_imputer": ("MissForestImputer", {"iterations": 5}). Parameters given in the configuration do not need to appear in _defaults.
class MiceLightgbmImputer(AbstractImputer):
    def _defaults(self):
        return dict(iterations=2, n_estimators=1)
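To make the override behaviour concrete, the sketch below merges a defaults dictionary with the keyword arguments of a configuration entry. It only illustrates the idea and is not tabensemb's internal code: `resolve_kwargs` is a hypothetical helper, and `some_extra_option` is a made-up key that stands for a parameter absent from the defaults.

```python
def resolve_kwargs(defaults: dict, config_entry: tuple) -> dict:
    """Merge imputer defaults with overrides from a (name, kwargs) entry.

    Hypothetical helper for illustration; tabensemb may do this differently.
    """
    _name, overrides = config_entry
    merged = dict(defaults)   # copy so the defaults stay untouched
    merged.update(overrides)  # configured values win; extra keys are kept
    return merged


defaults = dict(iterations=2, n_estimators=1)
entry = ("MiceLightgbmImputer", {"iterations": 5, "some_extra_option": True})
kwargs = resolve_kwargs(defaults, entry)
print(kwargs)
```

In this sketch the configured `iterations` replaces the default, `n_estimators` keeps its default, and the extra key is passed through unchanged.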
_fit_transform is used to fit the imputer and transform the training set and the validation set. _transform will be called to impute the testing set or an upcoming dataset.
MiceLightgbmImputer uses the miceforest package. The method _get_impute_features returns features that are not completely missing. The trained imputer should be recorded as the attribute self.transformer. The imputed input_data should be returned. Parameters defined in _defaults and modified in the configuration are recorded in self.kwargs.
    def _fit_transform(
        self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs
    ):
        import miceforest as mf

        impute_features = self._get_impute_features(
            datamodule.cont_feature_names, input_data
        )
        no_nan = not np.any(np.isnan(input_data[impute_features].values))
        imputer = mf.ImputationKernel(
            input_data[impute_features], random_state=0, train_nonmissing=no_nan
        )
        imputer.mice(**self.kwargs)
        input_data[impute_features] = imputer.complete_data().values.astype(np.float64)
        imputer.compile_candidate_preds()
        self.transformer = imputer
        return input_data
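The filtering step described above (keeping only features that are not completely missing) can be sketched with plain pandas. The function below is a stand-in written for illustration; the actual `_get_impute_features` may be implemented differently, and `features_with_observations` is a hypothetical name.

```python
import numpy as np
import pandas as pd


def features_with_observations(feature_names, df):
    """Return the features that have at least one non-missing value in `df`."""
    return [name for name in feature_names if df[name].notna().any()]


df = pd.DataFrame(
    {
        "cont_0": [1.0, np.nan, 3.0],
        "cont_1": [np.nan, np.nan, np.nan],  # completely missing
        "cont_2": [0.5, 0.7, 0.9],
    }
)
selected = features_with_observations(["cont_0", "cont_1", "cont_2"], df)
print(selected)  # ['cont_0', 'cont_2']
```

A completely missing column carries no information an imputer could learn from, which is presumably why such columns are excluded before fitting.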
In _transform, the trained imputer should be used to impute a new dataset. self.record_imputed_features is a copy of the feature list returned by self._get_impute_features in _fit_transform.
    def _transform(self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs):
        input_data[self.record_imputed_features] = (
            self.transformer.impute_new_data(
                new_data=input_data[self.record_imputed_features]
            )
            .complete_data()
            .values.astype(np.float64)
        )
        return input_data
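The division of labour between fitting and transforming can be illustrated without tabensemb or miceforest. The toy class below is an assumption-laden sketch (a column-mean imputer on NumPy arrays, with the hypothetical name `ColumnMeanImputer`): statistics are learned once from the training data and then reused, unchanged, on later data.

```python
import numpy as np


class ColumnMeanImputer:
    """Toy imputer: learn column means once, reuse them for later data."""

    def fit_transform(self, X: np.ndarray) -> np.ndarray:
        # Statistics are learned only from the data seen at fit time.
        self.means_ = np.nanmean(X, axis=0)
        return self.transform(X)

    def transform(self, X: np.ndarray) -> np.ndarray:
        # Fill NaNs with the stored means; nothing is re-estimated here.
        X = X.copy()
        rows, cols = np.where(np.isnan(X))
        X[rows, cols] = self.means_[cols]
        return X


train = np.array([[1.0, 10.0], [3.0, np.nan], [np.nan, 30.0]])
test = np.array([[np.nan, np.nan], [100.0, 200.0]])

imputer = ColumnMeanImputer()
train_filled = imputer.fit_transform(train)
test_filled = imputer.transform(test)
print(test_filled[0])  # filled with the training means, not the test means
```

Keeping the fitted state on the object is the same role that self.transformer plays in the imputer above.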
You can also implement _required_kwargs as we did in “New data derivers”.
[2]:
class MiceLightgbmImputer(AbstractImputer):
    def _defaults(self):
        return dict(iterations=2, n_estimators=1)

    def _fit_transform(
        self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs
    ):
        import miceforest as mf

        impute_features = self._get_impute_features(
            datamodule.cont_feature_names, input_data
        )
        no_nan = not np.any(np.isnan(input_data[impute_features].values))
        imputer = mf.ImputationKernel(
            input_data[impute_features], random_state=0, train_nonmissing=no_nan
        )
        imputer.mice(**self.kwargs)
        input_data[impute_features] = imputer.complete_data().values.astype(np.float64)
        imputer.compile_candidate_preds()
        self.transformer = imputer
        return input_data

    def _transform(self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs):
        input_data[self.record_imputed_features] = (
            self.transformer.impute_new_data(
                new_data=input_data[self.record_imputed_features]
            )
            .complete_data()
            .values.astype(np.float64)
        )
        return input_data
Inherit AbstractSklearnImputer
Take tabensemb.data.dataimputer.MissForestImputer as an example. It uses the IterativeImputer from sklearn, and its implementation is much simpler. _defaults is similar to the one above. _new_imputer returns an imputer instance that has fit_transform and transform methods, each of which returns an np.ndarray.
[3]:
class MissForestImputer(AbstractSklearnImputer):
    def _defaults(self):
        return dict(
            n_estimators=1,
            max_depth=3,
            random_state=0,
            bootstrap=True,
            n_jobs=-1,
        )

    def _new_imputer(self):
        warnings.simplefilter(
            action="ignore", category=sklearn.exceptions.ConvergenceWarning
        )
        estimator_rf = RandomForestRegressor(**self.kwargs)
        return IterativeImputer(estimator=estimator_rf, random_state=0, max_iter=10)
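To see what kind of object _new_imputer is expected to return, the sketch below uses sklearn's IterativeImputer on its own, outside tabensemb. The synthetic data and the chosen hyperparameters are assumptions made for illustration; only the fit_transform and transform calls, both returning np.ndarray, matter here.

```python
import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(0)
X_train = rng.normal(size=(60, 4))
X_train[:, 3] = X_train[:, 0] + 0.1 * rng.normal(size=60)  # correlated column
X_missing = X_train.copy()
X_missing[rng.random(X_missing.shape) < 0.2] = np.nan  # remove roughly 20% of cells

estimator = RandomForestRegressor(n_estimators=5, max_depth=3, random_state=0)
imputer = IterativeImputer(estimator=estimator, random_state=0, max_iter=5)

# Fit on the training data and impute it in one step.
X_filled = imputer.fit_transform(X_missing)

# Reuse the fitted imputer on a new row without refitting.
X_new = imputer.transform(np.array([[0.5, np.nan, -1.0, np.nan]]))
print(X_filled.shape, X_new.shape)
```

Any object exposing these two methods with the same behaviour should, by the description above, be usable through AbstractSklearnImputer.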
The implemented imputer should be registered as follows to be recognized by DataModule.set_data_imputer automatically.
[4]:
from tabensemb.data.dataimputer import imputer_mapping
imputer_mapping["MiceLightgbmImputer"] = MiceLightgbmImputer
imputer_mapping["MissForestImputer"] = MissForestImputer
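The registration above follows a common name-to-class registry pattern. The sketch below reproduces that pattern in isolation; `BaseImputer`, `MeanImputer`, `registry`, and `build_imputer` are hypothetical names, and how DataModule.set_data_imputer actually resolves the name is not shown in this document.

```python
from typing import Dict, Tuple, Type


class BaseImputer:
    """Stand-in base class used only in this sketch."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class MeanImputer(BaseImputer):
    pass


# Name -> class registry, analogous in spirit to imputer_mapping.
registry: Dict[str, Type[BaseImputer]] = {}
registry["MeanImputer"] = MeanImputer


def build_imputer(spec: Tuple[str, dict]) -> BaseImputer:
    """Look up the class by name and instantiate it with the given kwargs."""
    name, kwargs = spec
    if name not in registry:
        raise KeyError(f"Imputer {name!r} is not registered.")
    return registry[name](**kwargs)


imputer = build_imputer(("MeanImputer", {"iterations": 3}))
print(type(imputer).__name__, imputer.kwargs)
```

With such a registry, a configuration only needs to carry a string name and a dictionary of parameters, which matches the ("MiceLightgbmImputer", {"iterations": 3}) form used in the next cell.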
[5]:
from tabensemb.trainer import Trainer
import tabensemb
prefix = "../../../../"
tabensemb.setting["default_output_path"] = prefix + "output"
tabensemb.setting["default_config_path"] = prefix + "configs"
tabensemb.setting["default_data_path"] = prefix + "data"
trainer = Trainer(device="cpu")
trainer.load_config("sample")
trainer.datamodule.set_data_imputer(("MiceLightgbmImputer", {"iterations": 3}))
trainer.load_data()
The project will be saved to ../../../../output/sample/2023-09-18-18-15-03-0_sample
Dataset size: 153 51 52
Data saved to ../../../../output/sample/2023-09-18-18-15-03-0_sample (data.csv and tabular_data.csv).
The original sample.csv dataset has missing values:
[6]:
import os
pd.read_csv(os.path.join(tabensemb.setting["default_data_path"], "sample.csv"))
[6]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | NaN | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | NaN | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | NaN | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | NaN | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | NaN | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
After imputation, these missing values are filled using correlations learned by the imputer.
[7]:
trainer.df
[7]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | -1.830029 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | 0.936795 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | -0.049324 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | -0.202897 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | -0.483250 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
The following code accesses the dataset without imputation. Derived stacked features are also supported but the case is not shown here.
[8]:
trainer.datamodule.get_not_imputed_df()
[8]:
| | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | NaN | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | NaN | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | NaN | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | NaN | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | NaN | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
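Having both the imputed and the non-imputed frames makes it possible to locate exactly which cells were filled. The sketch below uses two small made-up frames in place of trainer.datamodule.get_not_imputed_df() and trainer.df, and assumes both share the same index and columns.

```python
import numpy as np
import pandas as pd

# Toy stand-ins for the raw and imputed frames (values are made up).
raw = pd.DataFrame({"cont_0": [1.0, 2.0, 3.0], "cont_1": [np.nan, np.nan, 0.5]})
imputed = pd.DataFrame({"cont_0": [1.0, 2.0, 3.0], "cont_1": [0.4, 0.6, 0.5]})

# True where a missing value was replaced by the imputer.
filled_mask = raw.isna() & imputed.notna()

# Number of filled cells per column.
n_filled = filled_mask.sum()

# Observed values should be untouched: filling the raw NaNs with the
# imputed values must reproduce the imputed frame exactly.
observed_kept = raw.fillna(imputed).equals(imputed)

print(n_filled)
print(observed_kept)
```

A mask like `filled_mask` can also be used to highlight imputed values in plots or to exclude them from later analyses.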