{ "cells": [ { "cell_type": "markdown", "source": [ "# Build your own model upon others\n", "\n", "Models can be built on top of other trained models, either from the current model base or from other model bases. Both `AbstractModel` and `TorchModel` support this feature.\n", "\n", "## For `AbstractModel`" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import tabensemb\n", "import numpy as np\n", "import torch\n", "import os\n", "from tempfile import TemporaryDirectory\n", "from tabensemb.model import WideDeep, AbstractModel\n", "\n", "temp_path = TemporaryDirectory()\n", "tabensemb.setting[\"default_output_path\"] = os.path.join(temp_path.name, \"output\")\n", "tabensemb.setting[\"default_config_path\"] = os.path.join(temp_path.name, \"configs\")\n", "tabensemb.setting[\"default_data_path\"] = os.path.join(temp_path.name, \"data\")\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "markdown", "source": [ "Suppose that we want to call the `TabMlp` model of `WideDeep` from another model base, `CallTabMlp`:\n", "\n", "```python\n", "class CallTabMlp(AbstractModel):\n", "    def _get_program_name(self):\n", "        return \"CallTabMlp\"\n", "\n", "    def _get_model_names(self):\n", "        return [\"TabMlp\"]\n", "\n", "    def _space(self, model_name):\n", "        return []\n", "\n", "    def _initial_values(self, model_name):\n", "        return {}\n", "```\n", "\n", "Another model is extracted by returning its name, in a specific format, from `required_models`. In the following code, `EXTERN` means that the model comes from another model base, `WideDeep` is the name of that model base, and `TabMlp` is the name of the wanted model within it. If the model is from the current model base, only its name is needed (`return [\"TabMlp\"]`). 
Multiple required models can be specified in the returned list.\n", "\n", "```python\n", "    def required_models(self, model_name: str):\n", "        return [\"EXTERN_WideDeep_TabMlp\"]\n", "```\n", "\n", "As usual, `_train_data_preprocess`, `_data_preprocess`, `_new_model`, `_train_single_model`, and `_pred_single_model` must be implemented. `_train_data_preprocess` is called first, and it uses `_get_required_models` to extract the external model. `_get_required_models` returns a dictionary keyed by the names given in `required_models`. Here, its `\"EXTERN_WideDeep_TabMlp\"` entry is a `WideDeep` instance containing the trained TabMlp model. For a required model from the current model base (e.g. `\"TabMlp\"`), the corresponding entry is equivalent to `self.model[\"TabMlp\"]`.\n", "\n", "The `_train_data_preprocess` method of `WideDeep` is then called directly, so that the processed data is compatible with the extracted model.\n", "\n", "```python\n", "    def _train_data_preprocess(self, model_name):\n", "        if not hasattr(self, \"net\"):\n", "            self.net = self._get_required_models(\"TabMlp\")[\"EXTERN_WideDeep_TabMlp\"]\n", "            self.net.trainer = self.trainer\n", "        return self.net._train_data_preprocess(\"TabMlp\")\n", "```\n", "\n", "Similarly, `_data_preprocess` delegates to the method of the same name in `WideDeep` to obtain compatible processed data.\n", "\n", "```python\n", "    def _data_preprocess(self, df, derived_data, model_name):\n", "        return self.net._data_preprocess(df, derived_data, \"TabMlp\")\n", "```\n", "\n", "In `_new_model`, the extracted model is returned directly.\n", "\n", "```python\n", "    def _new_model(self, model_name, verbose, **kwargs):\n", "        return self.net\n", "```\n", "\n", "`_pred_single_model` calls the method of the same name in `WideDeep` to make predictions with the extracted model.\n", "\n", "```python\n", "    def _pred_single_model(self, model, X_test, verbose, **kwargs):\n", "        return model._pred_single_model(model.model[\"TabMlp\"], X_test, verbose, **kwargs)\n", "```\n", "\n", "In this example, we won't do further training on the extracted model, but it is straightforward to do 
other operations on the predictions of the extracted model, obtained with `model._pred_single_model` as shown above. Here, `_train_single_model` simply does nothing.\n", "\n", "```python\n", "    def _train_single_model(self, *args, **kwargs):\n", "        pass\n", "```" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "class CallTabMlp(AbstractModel):\n", "    def _get_program_name(self):\n", "        return \"CallTabMlp\"\n", "\n", "    def _get_model_names(self):\n", "        return [\"TabMlp\"]\n", "\n", "    def _space(self, model_name):\n", "        return []\n", "\n", "    def _initial_values(self, model_name):\n", "        return {}\n", "\n", "    def required_models(self, model_name: str):\n", "        return [\"EXTERN_WideDeep_TabMlp\"]\n", "\n", "    def _train_data_preprocess(self, model_name):\n", "        if not hasattr(self, \"net\"):\n", "            self.net = self._get_required_models(\"TabMlp\")[\"EXTERN_WideDeep_TabMlp\"]\n", "            self.net.trainer = self.trainer\n", "        return self.net._train_data_preprocess(\"TabMlp\")\n", "\n", "    def _data_preprocess(self, df, derived_data, model_name):\n", "        return self.net._data_preprocess(df, derived_data, \"TabMlp\")\n", "\n", "    def _new_model(self, model_name, verbose, **kwargs):\n", "        return self.net\n", "\n", "    def _train_single_model(self, *args, **kwargs):\n", "        pass\n", "\n", "    def _pred_single_model(self, model, X_test, verbose, **kwargs):\n", "        return model._pred_single_model(model.model[\"TabMlp\"], X_test, verbose, **kwargs)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## For `TorchModel`\n", "\n", "Building a model on top of others is easier with `TorchModel`, because the complex dataset-building operations are already implemented internally.\n", "\n", "As above, we implement the required methods, except that `_train_data_preprocess` and `_data_preprocess` are no longer needed.\n", "\n", "```python\n", "class CallTabMlpTorch(TorchModel):\n", "    def _get_program_name(self):\n", 
" return \"CallTabMlpTorch\"\n", "\n", " def _get_model_names(self):\n", " return [\"TabMlp\"]\n", "\n", " def required_models(self, model_name: str):\n", " return [\"EXTERN_WideDeep_TabMlp\"]\n", "\n", " def _space(self, model_name):\n", " return []\n", "\n", " def _initial_values(self, model_name):\n", " return {}\n", "```\n", "\n", "We build our model `CallTabMlpNN` on the top of TabMlp from WideDeep. In this tutorial, we will not train anything.\n", "\n", "```python\n", " def _new_model(self, model_name, verbose, **kwargs):\n", " return CallTabMlpNN(datamodule=self.trainer.datamodule, **kwargs)\n", "\n", " def _train_single_model(self, *args, **kwargs):\n", " pass\n", "```\n", "\n", "Now comes `CallTabMlpNN`. A positional argument `required_models` is passed to `__init__` containing all required and extracted models specified in `CallTabMlpTorch.required_models`.\n", "\n", "```python\n", "class CallTabMlpNN(AbstractNN):\n", " def __init__(self, datamodule, required_models, **kwargs):\n", " super(CallTabMlpNN, self).__init__(datamodule, **kwargs)\n", " self.net = required_models[\"EXTERN_WideDeep_TabMlp\"]\n", "```\n", "\n", "To get results from the extracted model, use `self.call_required_model`.\n", "\n", "```python\n", " def _forward(self, x: torch.Tensor, derived_tensors) -> torch.Tensor:\n", " return self.call_required_model(self.net, x, derived_tensors)\n", "```\n", "\n", "**Remark**: Indeed, the output of the model is already calculated when preparing the dataset and is stored in `derived_tensors[\"data_required_models\"][\"MODELNAME_pred\"]`. `self.call_required_model` first tries to find the pre-calculated output. If failed, the output is calculated using the dataset for the model base stored in `derived_tensors[\"data_required_models\"][\"MODELNAME\"]`. Therefore, if you want to actually calculate the output during `forward`, just remove the stored predictions in `derived_tensors`." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "from tabensemb.model import TorchModel, AbstractNN\n", "\n", "class CallTabMlpNN(AbstractNN):\n", "    def __init__(self, datamodule, required_models, **kwargs):\n", "        super(CallTabMlpNN, self).__init__(datamodule, **kwargs)\n", "        self.net = required_models[\"EXTERN_WideDeep_TabMlp\"]\n", "\n", "    def _forward(self, x: torch.Tensor, derived_tensors) -> torch.Tensor:\n", "        return self.call_required_model(self.net, x, derived_tensors)\n", "\n", "class CallTabMlpTorch(TorchModel):\n", "    def _new_model(self, model_name, verbose, **kwargs):\n", "        return CallTabMlpNN(datamodule=self.trainer.datamodule, **kwargs)\n", "\n", "    def _get_program_name(self):\n", "        return \"CallTabMlpTorch\"\n", "\n", "    def _get_model_names(self):\n", "        return [\"TabMlp\"]\n", "\n", "    def required_models(self, model_name: str):\n", "        return [\"EXTERN_WideDeep_TabMlp\"]\n", "\n", "    def _space(self, model_name):\n", "        return []\n", "\n", "    def _initial_values(self, model_name):\n", "        return {}\n", "\n", "    def _train_single_model(self, *args, **kwargs):\n", "        pass" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "We can now compare the results of the original model and of the two models built on it. All three produce exactly the same metrics." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading https://archive.ics.uci.edu/static/public/9/auto+mpg.zip to /tmp/tmpvlx3s8em/data/Auto MPG.zip\n", "cylinders is Integer and will be treated as a continuous feature.\n", "model_year is Integer and will be treated as a continuous feature.\n", "origin is Integer and will be treated as a continuous feature.\n", "Unknown values are detected in ['horsepower']. 
They will be treated as np.nan.\n", "The project will be saved to /tmp/tmpvlx3s8em/output/auto-mpg/2023-09-23-20-41-06-0_UserInputConfig\n", "Dataset size: 238 80 80\n", "Data saved to /tmp/tmpvlx3s8em/output/auto-mpg/2023-09-23-20-41-06-0_UserInputConfig (data.csv and tabular_data.csv).\n", "\n", "-------------Run WideDeep-------------\n", "\n", "Training TabMlp\n", "Epoch: 1/300, Train loss: 635.5330, Val loss: 555.4755, Min val loss: 555.4755\n", "Epoch: 21/300, Train loss: 441.6902, Val loss: 375.7337, Min val loss: 375.7337\n", "Epoch: 41/300, Train loss: 145.8623, Val loss: 119.9598, Min val loss: 119.9598\n", "Epoch: 61/300, Train loss: 45.9133, Val loss: 34.0160, Min val loss: 34.0160\n", "Epoch: 81/300, Train loss: 27.6878, Val loss: 24.1525, Min val loss: 24.1525\n", "Epoch: 101/300, Train loss: 23.0877, Val loss: 18.2096, Min val loss: 18.2096\n", "Epoch: 121/300, Train loss: 21.4056, Val loss: 17.2203, Min val loss: 17.1303\n", "Epoch: 141/300, Train loss: 21.2559, Val loss: 16.0746, Min val loss: 16.0746\n", "Epoch: 161/300, Train loss: 19.2337, Val loss: 15.3027, Min val loss: 15.3027\n", "Epoch: 181/300, Train loss: 16.1232, Val loss: 14.5777, Min val loss: 14.5777\n", "Epoch: 201/300, Train loss: 16.7095, Val loss: 14.2274, Min val loss: 14.2274\n", "Epoch: 221/300, Train loss: 15.7366, Val loss: 13.5223, Min val loss: 13.5223\n", "Epoch: 241/300, Train loss: 16.9825, Val loss: 12.9892, Min val loss: 12.9892\n", "Epoch: 261/300, Train loss: 15.3358, Val loss: 12.4278, Min val loss: 12.4278\n", "Epoch: 281/300, Train loss: 13.3989, Val loss: 12.1155, Min val loss: 12.1155\n", "Restoring model weights from the end of the best epoch\n", "Training mse loss: 10.17037\n", "Validation mse loss: 11.66271\n", "Testing mse loss: 6.43856\n", "Trainer saved. 
To load the trainer, run trainer = load_trainer(path='/tmp/tmpvlx3s8em/output/auto-mpg/2023-09-23-20-41-06-0_UserInputConfig/trainer.pkl')\n", "\n", "-------------WideDeep End-------------\n", "\n", "\n", "-------------Run CallTabMlp-------------\n", "\n", "Training TabMlp\n", "Training mse loss: 10.17037\n", "Validation mse loss: 11.66271\n", "Testing mse loss: 6.43856\n", "Trainer saved. To load the trainer, run trainer = load_trainer(path='/tmp/tmpvlx3s8em/output/auto-mpg/2023-09-23-20-41-06-0_UserInputConfig/trainer.pkl')\n", "\n", "-------------CallTabMlp End-------------\n", "\n", "\n", "-------------Run CallTabMlpTorch-------------\n", "\n", "Training TabMlp\n", "Training mse loss: 10.17037\n", "Validation mse loss: 11.66271\n", "Testing mse loss: 6.43856\n", "Trainer saved. To load the trainer, run trainer = load_trainer(path='/tmp/tmpvlx3s8em/output/auto-mpg/2023-09-23-20-41-06-0_UserInputConfig/trainer.pkl')\n", "\n", "-------------CallTabMlpTorch End-------------\n", "\n", "WideDeep metrics\n", "TabMlp 1/1\n", "CallTabMlp metrics\n", "TabMlp 1/1\n", "CallTabMlpTorch metrics\n", "TabMlp 1/1\n", "Trainer saved. To load the trainer, run trainer = load_trainer(path='/tmp/tmpvlx3s8em/output/auto-mpg/2023-09-23-20-41-06-0_UserInputConfig/trainer.pkl')\n" ] }, { "data": { "text/plain": " Program Model Training RMSE Training MSE Training MAE \\\n0 WideDeep TabMlp 3.189102 10.170372 2.318564 \n1 CallTabMlp TabMlp 3.189102 10.170372 2.318564 \n2 CallTabMlpTorch TabMlp 3.189102 10.170372 2.318564 \n\n Training MAPE Training R2 Training MEDIAN_ABSOLUTE_ERROR \\\n0 0.096454 0.842218 1.669983 \n1 0.096454 0.842218 1.669983 \n2 0.096454 0.842218 1.669983 \n\n Training EXPLAINED_VARIANCE_SCORE Testing RMSE ... Testing R2 \\\n0 0.859805 2.537431 ... 0.88025 \n1 0.859805 2.537431 ... 0.88025 \n2 0.859805 2.537431 ... 
0.88025 \n\n Testing MEDIAN_ABSOLUTE_ERROR Testing EXPLAINED_VARIANCE_SCORE \\\n0 1.767459 0.900587 \n1 1.767459 0.900587 \n2 1.767459 0.900587 \n\n Validation RMSE Validation MSE Validation MAE Validation MAPE \\\n0 3.415071 11.662707 2.539188 0.116035 \n1 3.415071 11.662707 2.539188 0.116035 \n2 3.415071 11.662707 2.539188 0.116035 \n\n Validation R2 Validation MEDIAN_ABSOLUTE_ERROR \\\n0 0.791657 1.90416 \n1 0.791657 1.90416 \n2 0.791657 1.90416 \n\n Validation EXPLAINED_VARIANCE_SCORE \n0 0.806152 \n1 0.806152 \n2 0.806152 \n\n[3 rows x 23 columns]", "text/html": "
|  | Program | Model | Training RMSE | Training MSE | Training MAE | Training MAPE | Training R2 | Training MEDIAN_ABSOLUTE_ERROR | Training EXPLAINED_VARIANCE_SCORE | Testing RMSE | ... | Testing R2 | Testing MEDIAN_ABSOLUTE_ERROR | Testing EXPLAINED_VARIANCE_SCORE | Validation RMSE | Validation MSE | Validation MAE | Validation MAPE | Validation R2 | Validation MEDIAN_ABSOLUTE_ERROR | Validation EXPLAINED_VARIANCE_SCORE |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | WideDeep | TabMlp | 3.189102 | 10.170372 | 2.318564 | 0.096454 | 0.842218 | 1.669983 | 0.859805 | 2.537431 | ... | 0.88025 | 1.767459 | 0.900587 | 3.415071 | 11.662707 | 2.539188 | 0.116035 | 0.791657 | 1.90416 | 0.806152 |
| 1 | CallTabMlp | TabMlp | 3.189102 | 10.170372 | 2.318564 | 0.096454 | 0.842218 | 1.669983 | 0.859805 | 2.537431 | ... | 0.88025 | 1.767459 | 0.900587 | 3.415071 | 11.662707 | 2.539188 | 0.116035 | 0.791657 | 1.90416 | 0.806152 |
| 2 | CallTabMlpTorch | TabMlp | 3.189102 | 10.170372 | 2.318564 | 0.096454 | 0.842218 | 1.669983 | 0.859805 | 2.537431 | ... | 0.88025 | 1.767459 | 0.900587 | 3.415071 | 11.662707 | 2.539188 | 0.116035 | 0.791657 | 1.90416 | 0.806152 |
3 rows × 23 columns

|  | Program | Model | Training RMSE | Training MSE | Training MAE | Training MAPE | Training R2 | Training MEDIAN_ABSOLUTE_ERROR | Training EXPLAINED_VARIANCE_SCORE | Testing RMSE | ... | Testing R2 | Testing MEDIAN_ABSOLUTE_ERROR | Testing EXPLAINED_VARIANCE_SCORE | Validation RMSE | Validation MSE | Validation MAE | Validation MAPE | Validation R2 | Validation MEDIAN_ABSOLUTE_ERROR | Validation EXPLAINED_VARIANCE_SCORE |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | WideDeep | TabMlp | 3.189102 | 10.170372 | 2.318564 | 0.096454 | 0.842218 | 1.669983 | 0.859805 | 2.537431 | ... | 0.88025 | 1.767459 | 0.900587 | 3.415071 | 11.662707 | 2.539188 | 0.116035 | 0.791657 | 1.90416 | 0.806152 |
| 1 | CallTabMlp | TabMlp | 3.189102 | 10.170372 | 2.318564 | 0.096454 | 0.842218 | 1.669983 | 0.859805 | 2.537431 | ... | 0.88025 | 1.767459 | 0.900587 | 3.415071 | 11.662707 | 2.539188 | 0.116035 | 0.791657 | 1.90416 | 0.806152 |
| 2 | CallTabMlpTorch | TabMlp | 3.189102 | 10.170372 | 2.318564 | 0.096454 | 0.842218 | 1.669983 | 0.859805 | 2.537431 | ... | 0.88025 | 1.767459 | 0.900587 | 3.415071 | 11.662707 | 2.539188 | 0.116035 | 0.791657 | 1.90416 | 0.806152 |
| 3 | CallTabMlpTorchWrapped | TabMlp | 3.189102 | 10.170372 | 2.318564 | 0.096454 | 0.842218 | 1.669983 | 0.859805 | 2.537431 | ... | 0.88025 | 1.767459 | 0.900587 | 3.415071 | 11.662707 | 2.539188 | 0.116035 | 0.791657 | 1.90416 | 0.806152 |
4 rows × 23 columns
\n