{ "cells": [ { "cell_type": "markdown", "source": [ "# New data imputers\n", "\n", "Imputation is necessary when invalid values are encountered in the tabular dataset. The package provides several imputers. To implement an arbitrary imputer, inherit `AbstractImputer`. If the imputer follows the structure of `sklearn.impute._base._BaseImputer` (or has `fit_transform` and `transform` methods), inheriting `AbstractSklearnImputer` is much easier.\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from tabensemb.data import AbstractImputer, AbstractSklearnImputer, DataModule\n", "import numpy as np\n", "import pandas as pd\n", "import sklearn.exceptions\n", "from sklearn.experimental import enable_iterative_imputer\n", "from sklearn.impute import IterativeImputer\n", "from sklearn.ensemble import RandomForestRegressor\n", "import warnings" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## Inherit `AbstractImputer`\n", "\n", "Take `tabensemb.data.dataimputer.MiceLightgbmImputer` as an example. `_defaults` provides a set of default parameters for the imputation. These parameters can be changed by specifying them in the configuration, such as `\"data_imputer\": (\"MiceLightgbmImputer\", {\"iterations\": 5})`. Parameters in the configuration do not necessarily need to be in `_defaults`.\n", "\n", "```python\n", "class MiceLightgbmImputer(AbstractImputer):\n", "    def _defaults(self):\n", "        return dict(iterations=2, n_estimators=1)\n", "```\n", "\n", "`_fit_transform` fits the imputer and transforms the training set and the validation set. `_transform` is called to impute the testing set or an upcoming dataset.\n", "\n", "`MiceLightgbmImputer` uses the `miceforest` package. 
The method `_get_impute_features` returns features that are not completely missing. The trained imputer should be recorded as the attribute `self.transformer`. The imputed `input_data` should be returned. Parameters defined in `_defaults` and modified in the configuration are recorded in `self.kwargs`.\n", "\n", "```python\n", "    def _fit_transform(\n", "        self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs\n", "    ):\n", "        import miceforest as mf\n", "\n", "        impute_features = self._get_impute_features(\n", "            datamodule.cont_feature_names, input_data\n", "        )\n", "        no_nan = not np.any(np.isnan(input_data[impute_features].values))\n", "        imputer = mf.ImputationKernel(\n", "            input_data[impute_features], random_state=0, train_nonmissing=no_nan\n", "        )\n", "        imputer.mice(**self.kwargs)\n", "        input_data[impute_features] = imputer.complete_data().values.astype(np.float64)\n", "        imputer.compile_candidate_preds()\n", "        self.transformer = imputer\n", "        return input_data\n", "```\n", "\n", "In `_transform`, the trained imputer should be used to impute a new dataset. `self.record_imputed_features` stores the features returned by the `self._get_impute_features` call in `_fit_transform`.\n", "\n", "```python\n", "    def _transform(self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs):\n", "        input_data[self.record_imputed_features] = (\n", "            self.transformer.impute_new_data(\n", "                new_data=input_data[self.record_imputed_features]\n", "            )\n", "            .complete_data()\n", "            .values.astype(np.float64)\n", "        )\n", "        return input_data\n", "```\n", "\n", "You can also implement `_required_kwargs` as we did in \"New data derivers\"." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "class MiceLightgbmImputer(AbstractImputer):\n", "    def _defaults(self):\n", "        return dict(iterations=2, n_estimators=1)\n", "\n", "    def _fit_transform(\n", "        self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs\n", "    ):\n", "        import miceforest as mf\n", "\n", "        impute_features = self._get_impute_features(\n", "            datamodule.cont_feature_names, input_data\n", "        )\n", "        no_nan = not np.any(np.isnan(input_data[impute_features].values))\n", "        imputer = mf.ImputationKernel(\n", "            input_data[impute_features], random_state=0, train_nonmissing=no_nan\n", "        )\n", "        imputer.mice(**self.kwargs)\n", "        input_data[impute_features] = imputer.complete_data().values.astype(np.float64)\n", "        imputer.compile_candidate_preds()\n", "        self.transformer = imputer\n", "        return input_data\n", "\n", "    def _transform(self, input_data: pd.DataFrame, datamodule: DataModule, **kwargs):\n", "        input_data[self.record_imputed_features] = (\n", "            self.transformer.impute_new_data(\n", "                new_data=input_data[self.record_imputed_features]\n", "            )\n", "            .complete_data()\n", "            .values.astype(np.float64)\n", "        )\n", "        return input_data" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## Inherit `AbstractSklearnImputer`\n", "\n", "Take `tabensemb.data.dataimputer.MissForestImputer` as an example, which uses the `IterativeImputer` from `sklearn`. The implementation is much easier. `_defaults` is similar to that above. `_new_imputer` returns an imputer instance whose `fit_transform` and `transform` methods each return an `np.ndarray`." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "class MissForestImputer(AbstractSklearnImputer):\n", "    def _defaults(self):\n", "        return dict(\n", "            n_estimators=1,\n", "            max_depth=3,\n", "            random_state=0,\n", "            bootstrap=True,\n", "            n_jobs=-1,\n", "        )\n", "\n", "    def _new_imputer(self):\n", "        warnings.simplefilter(\n", "            action=\"ignore\", category=sklearn.exceptions.ConvergenceWarning\n", "        )\n", "        estimator_rf = RandomForestRegressor(**self.kwargs)\n", "        return IterativeImputer(estimator=estimator_rf, random_state=0, max_iter=10)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "Register the implemented imputer as follows so that `DataModule.set_data_imputer` recognizes it automatically." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "from tabensemb.data.dataimputer import imputer_mapping\n", "imputer_mapping[\"MiceLightgbmImputer\"] = MiceLightgbmImputer\n", "imputer_mapping[\"MissForestImputer\"] = MissForestImputer" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The project will be saved to ../../../../output/sample/2023-09-18-18-15-03-0_sample\n", "Dataset size: 153 51 52\n", "Data saved to ../../../../output/sample/2023-09-18-18-15-03-0_sample (data.csv and tabular_data.csv).\n" ] } ], "source": [ "from tabensemb.trainer import Trainer\n", "import tabensemb\n", "\n", "prefix = \"../../../../\"\n", "tabensemb.setting[\"default_output_path\"] = prefix + \"output\"\n", "tabensemb.setting[\"default_config_path\"] = prefix + \"configs\"\n", "tabensemb.setting[\"default_data_path\"] = prefix + \"data\"\n", "\n", "trainer = Trainer(device=\"cpu\")\n", 
"\n", "trainer.load_config(\"sample\")\n", "trainer.datamodule.set_data_imputer((\"MiceLightgbmImputer\", {\"iterations\": 3}))\n", "trainer.load_data()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "The original `sample.csv` dataset has missing values:" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "data": { "text/plain": " cont_0 cont_1 cont_2 cont_3 cont_4 cont_5 cont_6 \\\n0 -1.306527 NaN -0.118164 -0.159573 1.658131 -1.346718 -0.680178 \n1 2.011257 NaN 0.195070 0.527004 -0.044595 0.616887 -1.781563 \n2 -1.216077 NaN -0.743672 0.730184 0.140672 1.272954 -0.159012 \n3 0.559299 NaN -0.431096 -0.809627 -1.063696 -0.860153 0.572751 \n4 0.910179 NaN 0.786328 -0.042257 0.317218 0.379152 -0.466419 \n.. ... ... ... ... ... ... ... \n251 0.280442 -0.206904 0.841631 0.880179 -0.993124 -1.570623 -0.249459 \n252 -1.165150 -1.070753 0.465662 1.054452 0.900826 -0.179925 -1.536244 \n253 -0.069856 -0.186691 -1.021913 -1.143641 0.250114 1.040239 -1.150438 \n254 -1.031482 -0.860262 -0.061638 0.328301 -1.429991 -1.048170 -1.432735 \n255 -1.461733 0.960693 0.367545 1.329063 -0.683440 -1.184687 0.190312 \n\n cont_7 cont_8 cont_9 ... cat_3 cat_4 cat_5 cat_6 cat_7 \\\n0 -1.334258 0.666383 -0.460720 ... 0 2 category_4 3 4 \n1 0.354758 -0.729045 0.196557 ... 4 3 category_3 3 1 \n2 -0.475175 0.240057 0.100159 ... 0 4 category_3 4 1 \n3 -0.467441 0.677557 1.307184 ... 4 1 category_3 4 2 \n4 -0.017020 -0.944446 -0.410050 ... 1 0 category_2 0 2 \n.. ... ... ... ... ... ... ... ... ... \n251 0.643314 0.049495 0.493837 ... 1 2 category_2 2 3 \n252 1.178780 1.488252 1.895889 ... 4 2 category_4 4 2 \n253 0.258798 -0.836111 0.642211 ... 0 3 category_3 2 2 \n254 0.607112 0.087531 0.938747 ... 0 0 category_3 4 1 \n255 -0.521580 -0.851729 1.822724 ... 
2 1 category_3 4 1 \n\n cat_8 cat_9 target target_binary target_multi_class \n0 4 3 -71.084217 0 1 \n1 3 2 13.415675 1 2 \n2 0 2 -47.492280 0 2 \n3 0 0 -94.482614 1 2 \n4 3 0 195.819531 1 3 \n.. ... ... ... ... ... \n251 0 2 -171.249549 0 0 \n252 1 1 23.708442 0 2 \n253 2 2 -33.414215 1 1 \n254 4 4 -359.199191 0 4 \n255 1 4 -135.199100 1 2 \n\n[256 rows x 23 columns]", "text/html": "
|  | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | NaN | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | NaN | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | NaN | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | NaN | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | NaN | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns

|  | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | -1.830029 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | 0.936795 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | -0.049324 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | -0.202897 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | -0.483250 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns

|  | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | NaN | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | NaN | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | NaN | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | NaN | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | NaN | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns
\n