{ "cells": [ { "cell_type": "markdown", "source": [ "# Advanced customized model base\n", "\n", "Some low-level methods of `AbstractModel` and `TorchModel` offer more flexibility for customizing training and testing.\n", "\n", "## Advanced customizations of `AbstractModel`\n", "\n", "Assume that a model base `TabNetFromAbstractInherited` is built upon `TabNetFromAbstract`, which is introduced in \"Customized model base\". The methods in this part are defined within this class:\n", "\n", "```python\n", "from typing import Dict\n", "\n", "\n", "class TabNetFromAbstractInherited(TabNetFromAbstract):\n", "```\n", "\n", "### Training parameters\n", "\n", "`_custom_training_params` returns a dictionary whose items override the corresponding settings in the configuration file for this model base. For example:\n", "\n", "```python\n", "    def _custom_training_params(self, model_name) -> Dict:\n", "        return {\"epoch\": 100}\n", "```\n", "\n", "### Bayesian optimization criterion\n", "\n", "During Bayesian hyperparameter optimization, the objective might be the validation loss, the training loss, or something else. By default, the larger of the validation loss and the training loss is returned (the former is usually higher, but randomness may sometimes make the latter higher). For example, the following code returns only the validation loss:\n", "\n", "```python\n", "    def _bayes_eval(self, model, X_train, y_train, X_val, y_val):\n", "        y_val_pred = self._pred_single_model(model, X_val, verbose=False)\n", "        _, val_loss = self._default_metric_sklearn(y_val, y_val_pred)\n", "        return val_loss\n", "```\n", "\n", "Here, `_default_metric_sklearn` returns the MSE loss for regression tasks and the log loss for classification tasks.\n", "\n", "### Validity of a model\n", "\n", "`_conditional_validity` checks whether a model is valid under certain circumstances. 
For example, some models might be invalid if a certain feature `A_FEATURE` is not provided:\n", "\n", "```python\n", "    def _conditional_validity(self, model_name: str) -> bool:\n", "        # SOME_MODEL cannot be used without the continuous feature A_FEATURE\n", "        if model_name == \"SOME_MODEL\" and \"A_FEATURE\" not in self.trainer.cont_feature_names:\n", "            return False\n", "        return True\n", "```\n", "\n", "**Remark**: Unless you know what you are doing, we do not recommend modifying methods of `AbstractModel` other than those introduced in this part and in \"Customized model base\".\n", "\n", "## Advanced customizations of `TorchModel`\n", "\n", "The above customizations of `AbstractModel` also apply to `TorchModel`. `TorchModel` is restricted to a narrower framework but provides more APIs for flexibility. Some customizations are provided at a lower, more specific level in `AbstractNN`.\n", "\n", "```python\n", "class TabNetFromTorchInherited(TabNetFromTorch):\n", "```\n", "\n", "### Customized data processing\n", "\n", "In `TorchModel._train_data_preprocess`, a model base processes tabular or multimodal datasets for itself. The method `_prepare_custom_datamodule` is called at the beginning and should return a `DataModule` instance (`self.trainer.datamodule` by default), which is used to generate the final datasets (`torch.utils.data.Dataset` instances) and to provide other information. For example, the following code builds a `DataModule` that additionally records unscaled data as an item of derived data (multimodal data) using `UnscaledDataDeriver`. 
Note that `warm_start` must be handled here; otherwise, when new data is fed to `fit` with `warm_start=True`, the `DataModule` will be reset.\n", "\n", "```python\n", "    def _prepare_custom_datamodule(self, model_name, warm_start=False):\n", "        from tabensemb.data import DataModule\n", "\n", "        base = self.trainer.datamodule\n", "        if not warm_start or not hasattr(self, \"datamodule\"):\n", "            # Build a new DataModule with its own imputer, derivers, and processors\n", "            datamodule = DataModule(\n", "                config=self.trainer.datamodule.args, initialize=False\n", "            )\n", "            datamodule.set_data_imputer(\"MeanImputer\")\n", "            datamodule.set_data_derivers(\n", "                [(\"UnscaledDataDeriver\", {\"derived_name\": \"Unscaled\"})]\n", "            )\n", "            datamodule.set_data_processors(\n", "                [(\"CategoricalOrdinalEncoder\", {}), (\"StandardScaler\", {})]\n", "            )\n", "            # A newly built DataModule has nothing to warm-start from\n", "            warm_start = False\n", "        else:\n", "            # Reuse the DataModule built during the previous fit\n", "            datamodule = self.datamodule\n", "        datamodule.set_data(\n", "            base.categories_inverse_transform(base.df),\n", "            cont_feature_names=base.cont_feature_names,\n", "            cat_feature_names=base.cat_feature_names,\n", "            label_name=base.label_name,\n", "            train_indices=base.train_indices,\n", "            val_indices=base.val_indices,\n", "            test_indices=base.test_indices,\n", "            verbose=False,\n", "            warm_start=warm_start,\n", "        )\n", "        # Combine the trainer's derived data with the items derived here (the latter take precedence)\n", "        tmp_derived_data = base.derived_data.copy()\n", "        tmp_derived_data.update(datamodule.derived_data)\n", "        datamodule.derived_data = tmp_derived_data\n", "        self.datamodule = datamodule\n", "        return datamodule\n", "```\n", "\n", "In `TorchModel._data_preprocess`, `_run_custom_data_module` is called first to transform the incoming data into a consistent form. 
A common implementation is as follows:\n", "\n", "```python\n", "    def _run_custom_data_module(self, df, derived_data, model_name):\n", "        df, my_derived_data = self.datamodule.prepare_new_data(df, ignore_absence=True)\n", "        derived_data = derived_data.copy()\n", "        derived_data.update(my_derived_data)\n", "        derived_data = self.datamodule.sort_derived_data(derived_data)\n", "        return df, derived_data, self.datamodule\n", "```\n", "\n", "### Output normalization\n", "\n", "This functionality is provided in `AbstractNN`. Different normalizations are used for different tasks: `torch.nn.Identity()` for regression, which leaves the output unchanged; `torch.nn.Softmax(dim=-1)` for multi-class classification and `torch.nn.Sigmoid()` for binary classification, which compute probabilities from logits. For example, with the following code a model always returns non-negative predictions:\n", "\n", "```python\n", "import torch\n", "\n", "\n", "class TabNetNNInherited(TabNetNN):\n", "    def output_norm(self, y_pred):\n", "        return torch.abs(y_pred)\n", "```\n", "\n", "**Remark**: Output normalization does not affect the calculation of the loss function.\n", "\n", "### Loss function\n", "\n", "This functionality is provided in `AbstractNN`. By default, `torch.nn.BCEWithLogitsLoss()` is used for binary classification, `torch.nn.CrossEntropyLoss()` for multi-class classification, and `torch.nn.MSELoss()` (`loss==\"mse\"`) or `torch.nn.L1Loss()` (`loss==\"mae\"`) for regression. For example, a model with the following code uses `torch.nn.SmoothL1Loss`:\n", "\n", "```python\n", "import torch\n", "\n", "\n", "class TabNetNNInherited(TabNetNN):\n", "    def loss_fn(self, y_pred, y_true, *data, **kwargs):\n", "        return torch.nn.SmoothL1Loss()(y_pred, y_true)\n", "```\n", "\n", "`before_loss_fn` is called before `loss_fn` to transform the output (from `forward`) and the target into the desired format. 
Corresponding to the default `loss_fn` (`self.default_loss_fn`, returned by `AbstractNN.get_loss_fn`), a common implementation of `before_loss_fn` that meets the input requirements of `torch.nn.BCEWithLogitsLoss()` and `torch.nn.CrossEntropyLoss()` is as follows:\n", "\n", "```python\n", "import torch\n", "\n", "\n", "class TabNetNNInherited(TabNetNN):\n", "    def before_loss_fn(self, y, yhat):\n", "        # `y` is the output of `forward`; `yhat` is the target\n", "        if self.task == \"binary\":\n", "            # BCEWithLogitsLoss expects the logits and the target to have the same shape\n", "            y = torch.flatten(y)\n", "            yhat = torch.flatten(yhat)\n", "        elif self.task == \"multiclass\":\n", "            # CrossEntropyLoss expects a 1-D target of integer class indices\n", "            yhat = torch.flatten(yhat).long()\n", "        return y, yhat\n", "```\n", "\n", "### `pytorch_lightning` functionalities\n", "\n", "`AbstractNN` is based on `pytorch_lightning.LightningModule`, so most methods of `LightningModule` can be used directly in `AbstractNN`. Note that some of these methods, such as `training_step`, `validation_step`, and `configure_optimizers`, are already implemented. Others, such as `on_train_start` and `on_train_epoch_end`, are called automatically by `pytorch_lightning.Trainer`. See the official [documentation](https://lightning.ai/docs/pytorch/stable/) for advanced usage.\n", "\n", "### Backward propagation\n", "\n", "This functionality is provided in `AbstractNN`. Backward propagation and optimization are performed using the loss value returned by `loss_fn` (or attributes registered while `loss_fn` is called) and the optimizers returned by `configure_optimizers`. The default implementation, which uses only one optimizer and one loss item, is as follows:\n", "\n", "```python\n", "class TabNetNNInherited(TabNetNN):\n", "    def cal_backward_step(self, loss):\n", "        self.manual_backward(loss)\n", "        opt = self.optimizers()\n", "        opt.step()\n", "```\n", "\n", "`self.manual_backward(loss)` should be used instead of `loss.backward()`, as `LightningModule` requires for manual optimization.\n", "\n", "### Early stopping criterion\n", "\n", "This functionality is provided in `AbstractNN`. Early stopping is used to reduce the risk of overfitting. 
`_early_stopping_eval` returns the value monitored by early stopping. By default, the validation loss is returned:\n", "\n", "```python\n", "class TabNetNNInherited(TabNetNN):\n", "    def _early_stopping_eval(self, train_loss: float, val_loss: float) -> float:\n", "        return val_loss + 0.0 * train_loss\n", "```\n", "\n", "The second term does not change the value but propagates NaN: if the training loss is NaN, `0.0 * train_loss` is also NaN, so the monitored value becomes NaN." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }