{ "cells": [ { "cell_type": "markdown", "source": [ "# New data derivers\n", "\n", "Only a limited number of derivers are currently provided in this package. A deriver calculates new features (continuous or categorical) from existing features, or loads images, text, etc. as multimodal data. Here, the source code of the built-in `tabensemb.data.dataderiver.RelativeDeriver` is extended to demonstrate how a new deriver is implemented.\n" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from tabensemb.data.dataderiver import AbstractDeriver" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "Data derivers inherit from `tabensemb.data.dataderiver.AbstractDeriver` and must implement four methods:\n", "\n", "* `_required_cols`: Returns the names of arguments whose values are column names; these columns must exist in the tabular dataset. The following code requires the arguments `absolute_col` and `relative2_col` to be given in the configuration, e.g. `\"data_derivers\": [(\"MyRelativeDeriver\", {\"absolute_col\": \"cont_0\", \"relative2_col\": \"cont_1\"})]`\n", "\n", "```python\n", "class MyRelativeDeriver(AbstractDeriver):\n", "    def _required_cols(self):\n", "        return [\"absolute_col\", \"relative2_col\"]\n", "```\n", "\n", "* `_required_kwargs`: Returns the names of other parameters that must be specified in the configuration. 
The following code requires the parameter `some_param` to be given in the configuration, e.g. `\"data_derivers\": [(\"MyRelativeDeriver\", {\"some_param\": 1.5})]`\n", "\n", "```python\n", "    def _required_kwargs(self):\n", "        return [\"some_param\"]\n", "```\n", "\n", "**Remark**: `stacked`, `intermediate`, `derived_name`, and `is_continuous` are required kwargs shared by all derivers and do not need to be added to `_required_kwargs`.\n", "\n", "* `_defaults`: Returns default values for arguments listed in `_required_cols`, `_required_kwargs`, and `[\"stacked\", \"intermediate\", \"derived_name\", \"is_continuous\"]`. If an argument has a default value, no error is raised when it is missing from the configuration.\n", "\n", "```python\n", "    def _defaults(self):\n", "        return dict(stacked=True, intermediate=False, is_continuous=True)\n", "```\n", "\n", "* `_derive`: The main derivation step. It receives the tabular data (a `DataFrame`) and a `DataModule` and should return an `np.ndarray`. The returned array must not be 1-D; a single derived feature should be reshaped to shape `(n_samples, 1)`. 
Arguments are validated and stored in `self.kwargs` during initialization.\n", "\n", "```python\n", "    def _derive(self, df, datamodule):\n", "        absolute_col = self.kwargs[\"absolute_col\"]\n", "        relative2_col = self.kwargs[\"relative2_col\"]\n", "        # `some_param` and `stacked` are not used below; they only illustrate reading kwargs.\n", "        some_param = self.kwargs[\"some_param\"]\n", "        stacked = self.kwargs[\"stacked\"]\n", "\n", "        # Ratio of the two columns, reshaped to a 2-D array of shape (n_samples, 1)\n", "        relative = df[absolute_col] / df[relative2_col]\n", "        relative = relative.values.reshape(-1, 1)\n", "        return relative\n", "```" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "class MyRelativeDeriver(AbstractDeriver):\n", "    def _required_cols(self):\n", "        return [\"absolute_col\", \"relative2_col\"]\n", "\n", "    def _required_kwargs(self):\n", "        return [\"some_param\"]\n", "\n", "    def _defaults(self):\n", "        return dict(stacked=True, intermediate=False, is_continuous=True)\n", "\n", "    def _derive(self, df, datamodule):\n", "        absolute_col = self.kwargs[\"absolute_col\"]\n", "        relative2_col = self.kwargs[\"relative2_col\"]\n", "        # `some_param` and `stacked` are not used below; they only illustrate reading kwargs.\n", "        some_param = self.kwargs[\"some_param\"]\n", "        stacked = self.kwargs[\"stacked\"]\n", "\n", "        # Ratio of the two columns, reshaped to a 2-D array of shape (n_samples, 1)\n", "        relative = df[absolute_col] / df[relative2_col]\n", "        relative = relative.values.reshape(-1, 1)\n", "        return relative" ] }, { "cell_type": "markdown", "source": [ "The implemented deriver should be registered as follows so that `DataModule.set_data_derivers` recognizes it automatically."
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "from tabensemb.data.dataderiver import deriver_mapping\n", "deriver_mapping[\"MyRelativeDeriver\"] = MyRelativeDeriver" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The project will be saved to ../../../../output/sample/2023-09-18-18-15-00-0_sample\n" ] } ], "source": [ "from tabensemb.trainer import Trainer\n", "import tabensemb\n", "\n", "prefix = \"../../../../\"\n", "tabensemb.setting[\"default_output_path\"] = prefix + \"output\"\n", "tabensemb.setting[\"default_config_path\"] = prefix + \"configs\"\n", "tabensemb.setting[\"default_data_path\"] = prefix + \"data\"\n", "\n", "trainer = Trainer(device=\"cpu\")\n", "\n", "trainer.load_config(\"sample\")" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "If `stacked` is `True`:" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset size: 153 51 52\n", "Data saved to ../../../../output/sample/2023-09-18-18-15-00-0_sample (data.csv and tabular_data.csv).\n", "cont_0_relative2_cont_1 in continuous features?: True\n" ] }, { "data": { "text/plain": " cont_0 cont_1 cont_2 cont_3 cont_4 cont_5 cont_6 \\\n0 -1.306527 0.065895 -0.118164 -0.159573 1.658131 -1.346718 -0.680178 \n1 2.011257 0.117717 0.195070 0.527004 -0.044595 0.616887 -1.781563 \n2 -1.216077 0.065895 -0.743672 0.730184 0.140672 1.272954 -0.159012 \n3 0.559299 0.117717 -0.431096 -0.809627 -1.063696 -0.860153 0.572751 \n4 0.910179 -0.213096 0.786328 -0.042257 0.317218 0.379152 -0.466419 \n.. ... ... ... ... ... ... ... 
\n251 0.280442 -0.206904 0.841631 0.880179 -0.993124 -1.570623 -0.249459 \n252 -1.165150 -1.070753 0.465662 1.054452 0.900826 -0.179925 -1.536244 \n253 -0.069856 -0.186691 -1.021913 -1.143641 0.250114 1.040239 -1.150438 \n254 -1.031482 -0.860262 -0.061638 0.328301 -1.429991 -1.048170 -1.432735 \n255 -1.461733 0.960693 0.367545 1.329063 -0.683440 -1.184687 0.190312 \n\n cont_7 cont_8 cont_9 ... cat_4 cat_5 cat_6 cat_7 \\\n0 -1.334258 0.666383 -0.460720 ... 2 category_4 3 4 \n1 0.354758 -0.729045 0.196557 ... 3 category_3 3 1 \n2 -0.475175 0.240057 0.100159 ... 4 category_3 4 1 \n3 -0.467441 0.677557 1.307184 ... 1 category_3 4 2 \n4 -0.017020 -0.944446 -0.410050 ... 0 category_2 0 2 \n.. ... ... ... ... ... ... ... ... \n251 0.643314 0.049495 0.493837 ... 2 category_2 2 3 \n252 1.178780 1.488252 1.895889 ... 2 category_4 4 2 \n253 0.258798 -0.836111 0.642211 ... 3 category_3 2 2 \n254 0.607112 0.087531 0.938747 ... 0 category_3 4 1 \n255 -0.521580 -0.851729 1.822724 ... 1 category_3 4 1 \n\n cat_8 cat_9 target target_binary target_multi_class \\\n0 4 3 -71.084217 0 1 \n1 3 2 13.415675 1 2 \n2 0 2 -47.492280 0 2 \n3 0 0 -94.482614 1 2 \n4 3 0 195.819531 1 3 \n.. ... ... ... ... ... \n251 0 2 -171.249549 0 0 \n252 1 1 23.708442 0 2 \n253 2 2 -33.414215 1 1 \n254 4 4 -359.199191 0 4 \n255 1 4 -135.199100 1 2 \n\n cont_0_relative2_cont_1 \n0 -19.827301 \n1 17.085552 \n2 -18.454666 \n3 4.751225 \n4 -4.271217 \n.. ... \n251 -1.355422 \n252 1.088160 \n253 0.374183 \n254 1.199032 \n255 -1.521539 \n\n[256 rows x 24 columns]", "text/html": "
|  | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class | cont_0_relative2_cont_1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | 0.065895 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 | -19.827301 |
| 1 | 2.011257 | 0.117717 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 | 17.085552 |
| 2 | -1.216077 | 0.065895 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 | -18.454666 |
| 3 | 0.559299 | 0.117717 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 | 4.751225 |
| 4 | 0.910179 | -0.213096 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 | -4.271217 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 | -1.355422 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 | 1.088160 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 | 0.374183 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 | 1.199032 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 | -1.521539 |
256 rows × 24 columns
|  | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class | cont_0_relative2_cont_1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | -0.409756 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 | 3.188552 |
| 1 | 2.011257 | -0.409756 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 | -4.908431 |
| 2 | -1.216077 | 0.104704 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 | -11.614467 |
| 3 | 0.559299 | 0.104704 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 | 5.341736 |
| 4 | 0.910179 | -0.409756 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 | -2.221273 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 | -1.355422 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 | 1.088160 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 | 0.374183 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 | 1.199032 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 | -1.521539 |
256 rows × 24 columns
|  | cont_0 | cont_1 | cont_2 | cont_3 | cont_4 | cont_5 | cont_6 | cont_7 | cont_8 | cont_9 | ... | cat_3 | cat_4 | cat_5 | cat_6 | cat_7 | cat_8 | cat_9 | target | target_binary | target_multi_class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -1.306527 | 0.138315 | -0.118164 | -0.159573 | 1.658131 | -1.346718 | -0.680178 | -1.334258 | 0.666383 | -0.460720 | ... | 0 | 2 | category_4 | 3 | 4 | 4 | 3 | -71.084217 | 0 | 1 |
| 1 | 2.011257 | -0.006111 | 0.195070 | 0.527004 | -0.044595 | 0.616887 | -1.781563 | 0.354758 | -0.729045 | 0.196557 | ... | 4 | 3 | category_3 | 3 | 1 | 3 | 2 | 13.415675 | 1 | 2 |
| 2 | -1.216077 | 0.138315 | -0.743672 | 0.730184 | 0.140672 | 1.272954 | -0.159012 | -0.475175 | 0.240057 | 0.100159 | ... | 0 | 4 | category_3 | 4 | 1 | 0 | 2 | -47.492280 | 0 | 2 |
| 3 | 0.559299 | -0.006111 | -0.431096 | -0.809627 | -1.063696 | -0.860153 | 0.572751 | -0.467441 | 0.677557 | 1.307184 | ... | 4 | 1 | category_3 | 4 | 2 | 0 | 0 | -94.482614 | 1 | 2 |
| 4 | 0.910179 | -0.006111 | 0.786328 | -0.042257 | 0.317218 | 0.379152 | -0.466419 | -0.017020 | -0.944446 | -0.410050 | ... | 1 | 0 | category_2 | 0 | 2 | 3 | 0 | 195.819531 | 1 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 0.280442 | -0.206904 | 0.841631 | 0.880179 | -0.993124 | -1.570623 | -0.249459 | 0.643314 | 0.049495 | 0.493837 | ... | 1 | 2 | category_2 | 2 | 3 | 0 | 2 | -171.249549 | 0 | 0 |
| 252 | -1.165150 | -1.070753 | 0.465662 | 1.054452 | 0.900826 | -0.179925 | -1.536244 | 1.178780 | 1.488252 | 1.895889 | ... | 4 | 2 | category_4 | 4 | 2 | 1 | 1 | 23.708442 | 0 | 2 |
| 253 | -0.069856 | -0.186691 | -1.021913 | -1.143641 | 0.250114 | 1.040239 | -1.150438 | 0.258798 | -0.836111 | 0.642211 | ... | 0 | 3 | category_3 | 2 | 2 | 2 | 2 | -33.414215 | 1 | 1 |
| 254 | -1.031482 | -0.860262 | -0.061638 | 0.328301 | -1.429991 | -1.048170 | -1.432735 | 0.607112 | 0.087531 | 0.938747 | ... | 0 | 0 | category_3 | 4 | 1 | 4 | 4 | -359.199191 | 0 | 4 |
| 255 | -1.461733 | 0.960693 | 0.367545 | 1.329063 | -0.683440 | -1.184687 | 0.190312 | -0.521580 | -0.851729 | 1.822724 | ... | 2 | 1 | category_3 | 4 | 1 | 1 | 4 | -135.199100 | 1 | 2 |
256 rows × 23 columns