{ "cells": [ { "cell_type": "markdown", "source": [ "# Dataset and configuration\n", "\n", "This part shows how to prepare a new dataset and its configuration file, and introduces the basic usage of `UserConfig` and `DataModule`. After reading it, you will be able to run benchmarks on your own dataset.\n", "\n", "## The dataset\n", "\n", "We provide a randomly generated sample dataset (`data/sample.csv`) and its configuration file (`configs/sample.py`) in the repository. First, let's check the content of `sample.csv`. It contains 256 data points, 10 continuous features (`cont_0` to `cont_9`), 10 categorical features (`cat_0` to `cat_9`), and three target columns (`target`, `target_binary`, and `target_multi_class`).\n", "\n", "**Remark**: The dataset file should not contain an index column.\n", "\n", "**Remark**: Both `.csv` and `.xlsx` are supported. We recommend `.csv` files for their efficiency.\n", "\n", "**Remark**: Categorical features that contain non-numerical values (bool, string, or mixed types) are converted to strings. For example, the number `3` and the string `\"3\"` in such a feature are treated as the same category (both are interpreted as the string `\"3\"`)." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 1, "outputs": [ { "data": { "text/plain": " cont_0 cont_1 cont_2 cont_3 cont_4 cont_5 cont_6 \\\n0 -1.306527 NaN -0.118164 -0.159573 1.658131 -1.346718 -0.680178 \n1 2.011257 NaN 0.195070 0.527004 -0.044595 0.616887 -1.781563 \n2 -1.216077 NaN -0.743672 0.730184 0.140672 1.272954 -0.159012 \n3 0.559299 NaN -0.431096 -0.809627 -1.063696 -0.860153 0.572751 \n4 0.910179 NaN 0.786328 -0.042257 0.317218 0.379152 -0.466419 \n.. ... ... ... ... ... ... ... 
\n251 0.280442 -0.206904 0.841631 0.880179 -0.993124 -1.570623 -0.249459 \n252 -1.165150 -1.070753 0.465662 1.054452 0.900826 -0.179925 -1.536244 \n253 -0.069856 -0.186691 -1.021913 -1.143641 0.250114 1.040239 -1.150438 \n254 -1.031482 -0.860262 -0.061638 0.328301 -1.429991 -1.048170 -1.432735 \n255 -1.461733 0.960693 0.367545 1.329063 -0.683440 -1.184687 0.190312 \n\n cont_7 cont_8 cont_9 ... cat_3 cat_4 cat_5 cat_6 cat_7 \\\n0 -1.334258 0.666383 -0.460720 ... 0 2 category_4 3 4 \n1 0.354758 -0.729045 0.196557 ... 4 3 category_3 3 1 \n2 -0.475175 0.240057 0.100159 ... 0 4 category_3 4 1 \n3 -0.467441 0.677557 1.307184 ... 4 1 category_3 4 2 \n4 -0.017020 -0.944446 -0.410050 ... 1 0 category_2 0 2 \n.. ... ... ... ... ... ... ... ... ... \n251 0.643314 0.049495 0.493837 ... 1 2 category_2 2 3 \n252 1.178780 1.488252 1.895889 ... 4 2 category_4 4 2 \n253 0.258798 -0.836111 0.642211 ... 0 3 category_3 2 2 \n254 0.607112 0.087531 0.938747 ... 0 0 category_3 4 1 \n255 -0.521580 -0.851729 1.822724 ... 2 1 category_3 4 1 \n\n cat_8 cat_9 target target_binary target_multi_class \n0 4 3 -71.084217 0 1 \n1 3 2 13.415675 1 2 \n2 0 2 -47.492280 0 2 \n3 0 0 -94.482614 1 2 \n4 3 0 195.819531 1 3 \n.. ... ... ... ... ... \n251 0 2 -171.249549 0 0 \n252 1 1 23.708442 0 2 \n253 2 2 -33.414215 1 1 \n254 4 4 -359.199191 0 4 \n255 1 4 -135.199100 1 2 \n\n[256 rows x 23 columns]", "text/html": "
" }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "prefix = \"../../../../\"\n", "pd.read_csv(prefix + \"data/sample.csv\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## The configuration file\n", "\n", "A configuration file contains a dictionary that overrides selected values of a given default configuration.\n", "\n", "**Remark**: The configuration file can be a `.py` file containing a `dict` object named `cfg`, or a `.json` file.\n", "\n", "### The default configuration\n", "\n", "To see the default values, use `tabensemb.config.UserConfig`, which inherits from `dict`." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 2, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", "\t'database': 'sample',\n", "\t'task': None,\n", "\t'loss': None,\n", "\t'bayes_opt': False,\n", "\t'bayes_calls': 50,\n", "\t'bayes_epoch': 30,\n", "\t'patience': 100,\n", "\t'epoch': 300,\n", "\t'lr': 0.001,\n", "\t'weight_decay': 1e-09,\n", "\t'batch_size': 1024,\n", "\t'layers': [\n", "\t\t64,\n", "\t\t128,\n", "\t\t256,\n", "\t\t128,\n", "\t\t64\n", "\t],\n", "\t'SPACEs': {\n", "\t\t'lr': {\n", "\t\t\t'type': 'Real',\n", "\t\t\t'low': 0.0001,\n", "\t\t\t'high': 0.05,\n", "\t\t\t'prior': 'log-uniform'\n", "\t\t},\n", "\t\t'weight_decay': {\n", "\t\t\t'type': 'Real',\n", "\t\t\t'low': 1e-09,\n", "\t\t\t'high': 0.05,\n", "\t\t\t'prior': 'log-uniform'\n", "\t\t},\n", "\t\t'batch_size': {\n", "\t\t\t'type': 'Categorical',\n", "\t\t\t'categories': [\n", "\t\t\t\t64,\n", "\t\t\t\t128,\n", "\t\t\t\t256,\n", "\t\t\t\t512,\n", "\t\t\t\t1024,\n", "\t\t\t\t2048\n", "\t\t\t]\n", "\t\t}\n", "\t},\n", "\t'data_splitter': 'RandomSplitter',\n", "\t'split_ratio': [\n", "\t\t0.6,\n", "\t\t0.2,\n", "\t\t0.2\n", "\t],\n", "\t'data_imputer': 'MissForestImputer',\n", "\t'data_processors': 
[\n", "\t\t(\n", "\t\t\t'CategoricalOrdinalEncoder',\n", "\t\t\t{\n", "\t\t\t}\n", "\t\t),\n", "\t\t(\n", "\t\t\t'NaNFeatureRemover',\n", "\t\t\t{\n", "\t\t\t}\n", "\t\t),\n", "\t\t(\n", "\t\t\t'VarianceFeatureSelector',\n", "\t\t\t{\n", "\t\t\t\t'thres': 1\n", "\t\t\t}\n", "\t\t),\n", "\t\t(\n", "\t\t\t'StandardScaler',\n", "\t\t\t{\n", "\t\t\t}\n", "\t\t)\n", "\t],\n", "\t'data_derivers': [\n", "\t],\n", "\t'categorical_feature_names': [\n", "\t],\n", "\t'continuous_feature_names': [\n", "\t],\n", "\t'feature_types': {\n", "\t},\n", "\t'unique_feature_types': [\n", "\t],\n", "\t'label_name': [\n", "\t\t'target'\n", "\t]\n", "}\n" ] } ], "source": [ "from tabensemb.config import UserConfig\n", "from tabensemb.utils import pretty\n", "import tabensemb\n", "\n", "tabensemb.setting[\"default_config_path\"] = prefix + \"configs\"\n", "\n", "cfg = UserConfig(\"sample\")\n", "print(pretty(cfg.defaults()))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### The configuration of the given sample dataset\n", "\n", "`configs/sample.py` contains the following contents:\n", "```python\n", "cfg = {\n", " \"database\": \"sample\",\n", " \"continuous_feature_names\": [\"cont_0\", \"cont_1\", \"cont_2\", \"cont_3\", \"cont_4\"],\n", " \"categorical_feature_names\": [\"cat_0\", \"cat_1\", \"cat_2\"],\n", " \"label_name\": [\"target\"],\n", "}\n", "```\n", "Load `configs/sample.py` and see the changes." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", "\t'database': 'sample',\n", "\t'task': None,\n", "\t'loss': None,\n", "\t'bayes_opt': False,\n", "\t'bayes_calls': 50,\n", "\t'bayes_epoch': 30,\n", "\t'patience': 100,\n", "\t'epoch': 300,\n", "\t'lr': 0.001,\n", "\t'weight_decay': 1e-09,\n", "\t'batch_size': 1024,\n", "\t'layers': [\n", "\t\t64,\n", "\t\t128,\n", "\t\t256,\n", "\t\t128,\n", "\t\t64\n", "\t],\n", "\t'SPACEs': {\n", "\t\t'lr': {\n", "\t\t\t'type': 'Real',\n", "\t\t\t'low': 0.0001,\n", "\t\t\t'high': 0.05,\n", "\t\t\t'prior': 'log-uniform'\n", "\t\t},\n", "\t\t'weight_decay': {\n", "\t\t\t'type': 'Real',\n", "\t\t\t'low': 1e-09,\n", "\t\t\t'high': 0.05,\n", "\t\t\t'prior': 'log-uniform'\n", "\t\t},\n", "\t\t'batch_size': {\n", "\t\t\t'type': 'Categorical',\n", "\t\t\t'categories': [\n", "\t\t\t\t64,\n", "\t\t\t\t128,\n", "\t\t\t\t256,\n", "\t\t\t\t512,\n", "\t\t\t\t1024,\n", "\t\t\t\t2048\n", "\t\t\t]\n", "\t\t}\n", "\t},\n", "\t'data_splitter': 'RandomSplitter',\n", "\t'split_ratio': [\n", "\t\t0.6,\n", "\t\t0.2,\n", "\t\t0.2\n", "\t],\n", "\t'data_imputer': 'MissForestImputer',\n", "\t'data_processors': [\n", "\t\t(\n", "\t\t\t'CategoricalOrdinalEncoder',\n", "\t\t\t{\n", "\t\t\t}\n", "\t\t),\n", "\t\t(\n", "\t\t\t'NaNFeatureRemover',\n", "\t\t\t{\n", "\t\t\t}\n", "\t\t),\n", "\t\t(\n", "\t\t\t'VarianceFeatureSelector',\n", "\t\t\t{\n", "\t\t\t\t'thres': 1\n", "\t\t\t}\n", "\t\t),\n", "\t\t(\n", "\t\t\t'StandardScaler',\n", "\t\t\t{\n", "\t\t\t}\n", "\t\t)\n", "\t],\n", "\t'data_derivers': [\n", "\t],\n", "\t'categorical_feature_names': [\n", "\t\t'cat_0',\n", "\t\t'cat_1',\n", "\t\t'cat_2'\n", "\t],\n", "\t'continuous_feature_names': [\n", "\t\t'cont_0',\n", "\t\t'cont_1',\n", "\t\t'cont_2',\n", "\t\t'cont_3',\n", "\t\t'cont_4'\n", "\t],\n", "\t'feature_types': {\n", "\t\t'cont_0': 'Continuous',\n", 
"\t\t'cont_1': 'Continuous',\n", "\t\t'cont_2': 'Continuous',\n", "\t\t'cont_3': 'Continuous',\n", "\t\t'cont_4': 'Continuous',\n", "\t\t'cat_0': 'Categorical',\n", "\t\t'cat_1': 'Categorical',\n", "\t\t'cat_2': 'Categorical'\n", "\t},\n", "\t'unique_feature_types': [\n", "\t\t'Categorical',\n", "\t\t'Continuous'\n", "\t],\n", "\t'label_name': [\n", "\t\t'target'\n", "\t]\n", "}\n" ] } ], "source": [ "cfg = UserConfig(\"sample\")\n", "print(pretty(cfg))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "### Descriptions of keys in a configuration file\n", "\n", "* `database`: The name of the database file. The file should be placed in the script directory or in `tabensemb.setting[\"default_data_path\"]`. If no file extension (`.csv` or `.xlsx`) is provided, the program automatically searches for a file with a matching extension. If both `.csv` and `.xlsx` exist, an exception will be raised.\n", "* `task`: \"regression\" for regression tasks, \"binary\" for binary classification, and \"multiclass\" for multiclass classification. If left as `None`, the task will be guessed from the type of the target. If the target is of the type `object` or integers, \"binary\" or \"multiclass\" is guessed depending on the number of unique targets; otherwise, \"regression\" is guessed.\n", "* `loss`: \"mse\" (default) or \"mae\" for regression tasks, and \"cross_entropy\" for classification tasks. This loss will be used across all model bases. If left as `None`, \"mse\" or \"cross_entropy\" will be used depending on the task.\n", "* `bayes_opt`: Perform Gaussian-process-based Bayesian hyperparameter optimization (HPO) using the `scikit-optimize` package when training each model.\n", "* `bayes_calls`: The number of calls of the Bayesian HPO. 
During each call, the model is trained with a given set of hyperparameters, and the metric on the validation set is returned to the Bayesian HPO process.\n", "* `bayes_epoch`: The number of epochs during each Bayesian HPO call.\n", "* `patience`: Early stopping patience. If the metric on the validation set does not improve after `patience` epochs, the training process terminates and the best model is loaded.\n", "* `epoch`: Total epochs to train each model.\n", "* `lr`: Initial learning rate.\n", "* `weight_decay`: Initial weight decay (for a `torch.optim.Adam` optimizer).\n", "* `batch_size`: Initial batch size.\n", "* `layers`: Default hidden layers for some models.\n", "* `SPACEs`: Default Bayesian HPO spaces for `lr`, `weight_decay`, and `batch_size`. The key `type` selects the class in `skopt.space`, and the remaining keys are its arguments.\n", "* `data_splitter`: The dataset splitting method to split training/validation/testing sets. See `tabensemb.data.datasplitter.splitter_mapping` for available classes.\n", "* `split_ratio`: The ratio of training/validation/testing sets.\n", "* `data_imputer`: The imputation method for `NaN` values. See `tabensemb.data.dataimputer.imputer_mapping` for available classes.\n", "* `data_processors`: A list of data processing steps and their corresponding arguments. See `tabensemb.data.dataprocessor.processor_mapping` for available classes. See API docs for definitions of arguments.\n", "* `data_derivers`: A list of feature augmentation steps and their corresponding arguments. Some fixed arguments are:\n", " * `stacked`: `True` to append the derived feature to continuous features and the final `DataFrame` representing the processed dataset. 
`False` to leave it as an unstacked feature (mostly for multi-modal data).\n", " * `intermediate`: `True` to ignore the derived feature in continuous features even when `stacked=True`, but still append the feature to the `DataFrame`.\n", "\n", " See `tabensemb.data.dataderiver.deriver_mapping` for available classes. See API docs or `_required_cols` of each class for its additional arguments.\n", "* `continuous_feature_names`: Names of continuous features. Each of them should contain only floats or integers.\n", "* `categorical_feature_names`: Names of categorical features. Each of them should contain only integers or strings.\n", "* `feature_types`: A dictionary mapping each feature defined in `continuous_feature_names` and `categorical_feature_names` to its type. If it is not given in the configuration, continuous features are automatically assigned \"Continuous\" and categorical features are assigned \"Categorical\".\n", "* `unique_feature_types`: Unique values in the dictionary `feature_types`.\n", "* `label_name`: A list of the target columns to predict." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "markdown", "source": [ "## Use the configuration file to load the dataset\n", "\n", "The `DataModule` requires a `UserConfig` to load the dataset. It then initializes and runs all data processing steps on that dataset or on a new dataset provided later. The following lines load the dataset and display the loaded and processed `DataFrame` without imputation." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset size: 153 51 52\n" ] }, { "data": { "text/plain": " cont_0 cont_1 cont_2 cont_3 cont_4 cont_5 cont_6 \\\n0 -1.306527 NaN -0.118164 -0.159573 1.658131 -1.346718 -0.680178 \n1 2.011257 NaN 0.195070 0.527004 -0.044595 0.616887 -1.781563 \n2 -1.216077 NaN -0.743672 0.730184 0.140672 1.272954 -0.159012 \n3 0.559299 NaN -0.431096 -0.809627 -1.063696 -0.860153 0.572751 \n4 0.910179 NaN 0.786328 -0.042257 0.317218 0.379152 -0.466419 \n.. ... ... ... ... ... ... ... \n251 0.280442 -0.206904 0.841631 0.880179 -0.993124 -1.570623 -0.249459 \n252 -1.165150 -1.070753 0.465662 1.054452 0.900826 -0.179925 -1.536244 \n253 -0.069856 -0.186691 -1.021913 -1.143641 0.250114 1.040239 -1.150438 \n254 -1.031482 -0.860262 -0.061638 0.328301 -1.429991 -1.048170 -1.432735 \n255 -1.461733 0.960693 0.367545 1.329063 -0.683440 -1.184687 0.190312 \n\n cont_7 cont_8 cont_9 ... cat_3 cat_4 cat_5 cat_6 cat_7 \\\n0 -1.334258 0.666383 -0.460720 ... 0 2 category_4 3 4 \n1 0.354758 -0.729045 0.196557 ... 4 3 category_3 3 1 \n2 -0.475175 0.240057 0.100159 ... 0 4 category_3 4 1 \n3 -0.467441 0.677557 1.307184 ... 4 1 category_3 4 2 \n4 -0.017020 -0.944446 -0.410050 ... 1 0 category_2 0 2 \n.. ... ... ... ... ... ... ... ... ... \n251 0.643314 0.049495 0.493837 ... 1 2 category_2 2 3 \n252 1.178780 1.488252 1.895889 ... 4 2 category_4 4 2 \n253 0.258798 -0.836111 0.642211 ... 0 3 category_3 2 2 \n254 0.607112 0.087531 0.938747 ... 0 0 category_3 4 1 \n255 -0.521580 -0.851729 1.822724 ... 2 1 category_3 4 1 \n\n cat_8 cat_9 target target_binary target_multi_class \n0 4 3 -71.084217 0 1 \n1 3 2 13.415675 1 2 \n2 0 2 -47.492280 0 2 \n3 0 0 -94.482614 1 2 \n4 3 0 195.819531 1 3 \n.. ... ... ... ... ... 
\n251 0 2 -171.249549 0 0 \n252 1 1 23.708442 0 2 \n253 2 2 -33.414215 1 1 \n254 4 4 -359.199191 0 4 \n255 1 4 -135.199100 1 2 \n\n[256 rows x 23 columns]", "text/html": "
" }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tabensemb.data import DataModule\n", "tabensemb.setting[\"default_data_path\"] = prefix + \"data\"\n", "\n", "datamodule = DataModule(cfg)\n", "datamodule.load_data()\n", "datamodule.get_not_imputed_df()" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "`DataModule.df` holds the imputed `DataFrame`." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "data": { "text/plain": " cont_0 cont_1 cont_2 cont_3 cont_4 cont_5 cont_6 \\\n0 -1.306527 0.256888 -0.118164 -0.159573 1.658131 -1.346718 -0.680178 \n1 2.011257 0.256888 0.195070 0.527004 -0.044595 0.616887 -1.781563 \n2 -1.216077 0.256888 -0.743672 0.730184 0.140672 1.272954 -0.159012 \n3 0.559299 0.256888 -0.431096 -0.809627 -1.063696 -0.860153 0.572751 \n4 0.910179 -0.228308 0.786328 -0.042257 0.317218 0.379152 -0.466419 \n.. ... ... ... ... ... ... ... \n251 0.280442 -0.206904 0.841631 0.880179 -0.993124 -1.570623 -0.249459 \n252 -1.165150 -1.070753 0.465662 1.054452 0.900826 -0.179925 -1.536244 \n253 -0.069856 -0.186691 -1.021913 -1.143641 0.250114 1.040239 -1.150438 \n254 -1.031482 -0.860262 -0.061638 0.328301 -1.429991 -1.048170 -1.432735 \n255 -1.461733 0.960693 0.367545 1.329063 -0.683440 -1.184687 0.190312 \n\n cont_7 cont_8 cont_9 ... cat_3 cat_4 cat_5 cat_6 \\\n0 -1.334258 0.666383 -0.460720 ... 0 2 category_4 3 \n1 0.354758 -0.729045 0.196557 ... 4 3 category_3 3 \n2 -0.475175 0.240057 0.100159 ... 0 4 category_3 4 \n3 -0.467441 0.677557 1.307184 ... 4 1 category_3 4 \n4 -0.017020 -0.944446 -0.410050 ... 1 0 category_2 0 \n.. ... ... ... ... ... ... ... ... \n251 0.643314 0.049495 0.493837 ... 1 2 category_2 2 \n252 1.178780 1.488252 1.895889 ... 4 2 category_4 4 \n253 0.258798 -0.836111 0.642211 ... 0 3 category_3 2 \n254 0.607112 0.087531 0.938747 ... 
0 0 category_3 4 \n255 -0.521580 -0.851729 1.822724 ... 2 1 category_3 4 \n\n cat_7 cat_8 cat_9 target target_binary target_multi_class \n0 4 4 3 -71.084217 0 1 \n1 1 3 2 13.415675 1 2 \n2 1 0 2 -47.492280 0 2 \n3 2 0 0 -94.482614 1 2 \n4 2 3 0 195.819531 1 3 \n.. ... ... ... ... ... ... \n251 3 0 2 -171.249549 0 0 \n252 2 1 1 23.708442 0 2 \n253 2 2 2 -33.414215 1 1 \n254 1 4 4 -359.199191 0 4 \n255 1 1 4 -135.199100 1 2 \n\n[256 rows x 23 columns]", "text/html": "
" }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datamodule.df" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "`DataModule.train_indices`, `DataModule.val_indices`, and `DataModule.test_indices` hold the indices of the training, validation, and testing sets, respectively." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "data": { "text/plain": "(array([158, 161, 186, 1, 208, 60, 202, 141, 13, 113, 216, 49, 75,\n 151, 159, 200, 190, 218, 32, 58, 28, 54, 19, 129, 24, 183,\n 61, 66, 253, 18, 5, 85, 43, 92, 0, 63, 185, 163, 244,\n 138, 20, 44, 96, 59, 236, 27, 94, 79, 100, 10, 213, 106,\n 176, 93, 62, 240, 98, 239, 247, 221, 204, 166, 99, 160, 181,\n 209, 150, 21, 14, 89, 201, 167, 145, 41, 12, 128, 101, 251,\n 124, 69, 205, 82, 241, 6, 191, 226, 117, 45, 22, 110, 35,\n 33, 74, 148, 105, 34, 77, 168, 90, 84, 179, 78, 2, 220,\n 155, 184, 47, 15, 140, 72, 195, 243, 232, 23, 39, 127, 71,\n 111, 144, 107, 211, 210, 173, 50, 254, 237, 194, 68, 162, 70,\n 197, 231, 123, 103, 228, 170, 136, 142, 80, 207, 116, 40, 171,\n 91, 135, 152, 248, 4, 125, 104, 83, 121, 177]),\n array([165, 122, 187, 235, 238, 17, 87, 29, 42, 174, 178, 48, 169,\n 65, 46, 242, 224, 245, 130, 97, 215, 56, 164, 143, 57, 7,\n 175, 137, 252, 249, 203, 88, 3, 223, 16, 109, 154, 86, 227,\n 64, 38, 115, 149, 222, 95, 188, 102, 131, 219, 250, 126]),\n array([ 9, 36, 153, 114, 120, 108, 246, 139, 118, 198, 76, 51, 212,\n 157, 196, 156, 172, 146, 189, 225, 233, 206, 37, 52, 199, 182,\n 229, 30, 192, 180, 234, 67, 11, 112, 8, 134, 147, 132, 25,\n 81, 119, 214, 53, 230, 31, 26, 73, 55, 217, 255, 133, 193]))" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datamodule.train_indices, datamodule.val_indices, datamodule.test_indices" ], "metadata": { "collapsed": false, "pycharm": { "name": 
"#%%\n" } } }, { "cell_type": "markdown", "source": [ "For the full functionality of `DataModule`, please check the API documentation." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "markdown", "source": [ "## A `Trainer` does everything for you\n", "\n", "In practice, you do not need to create a `UserConfig` or a `DataModule` manually, because `Trainer` performs all the steps above. After calling `Trainer.load_config` and `Trainer.load_data`, the `UserConfig` instance holding the configuration is available as `Trainer.args`, and the `DataModule` instance holding the processing steps and the loaded data is available as `Trainer.datamodule`.\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The project will be saved to ../../../../output/sample/2023-09-23-20-37-10-0_sample\n", "Dataset size: 153 51 52\n", "Data saved to ../../../../output/sample/2023-09-23-20-37-10-0_sample (data.csv and tabular_data.csv).\n" ] }, { "data": { "text/plain": "(tabensemb.config.user_config.UserConfig, tabensemb.data.datamodule.DataModule)" }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tabensemb.trainer import Trainer\n", "\n", "tabensemb.setting[\"default_output_path\"] = prefix + \"output\"\n", "trainer = Trainer(device=\"cpu\")\n", "trainer.load_config(\"sample\")\n", "trainer.load_data()\n", "type(trainer.args), type(trainer.datamodule)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, 
"nbformat": 4, "nbformat_minor": 0 }