Advanced customized model base
AbstractModel and TorchModel provide some low-level methods that offer more flexibility for customizing training and testing.
Advanced customizations of AbstractModel
Assume that a model base TabNetFromAbstractInherited is built upon TabNetFromAbstract introduced in “Customized model base”.
class TabNetFromAbstractInherited(TabNetFromAbstract):
Training parameters
_custom_training_params returns a dictionary containing items that override settings in the configuration file for the model base. For example:
def _custom_training_params(self, model_name) -> Dict:
    return {"epoch": 100}
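The override can be pictured as a dictionary update. The sketch below assumes a shallow, key-by-key override (how tabensemb merges the two internally may differ); merge_training_params and the configuration values are hypothetical.

```python
from typing import Any, Dict


def merge_training_params(
    config_params: Dict[str, Any], custom_params: Dict[str, Any]
) -> Dict[str, Any]:
    """Hypothetical helper: entries from _custom_training_params win over the config."""
    merged = dict(config_params)  # copy so the loaded configuration is not mutated
    merged.update(custom_params)  # same-named keys are overridden by the custom values
    return merged


# Hypothetical values standing in for a configuration file.
config_params = {"epoch": 200, "lr": 0.001, "batch_size": 1024}
custom_params = {"epoch": 100}  # what _custom_training_params returns above

print(merge_training_params(config_params, custom_params))
# {'epoch': 100, 'lr': 0.001, 'batch_size': 1024}
```

Keys absent from the returned dictionary keep their configured values in this sketch.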
Bayesian optimization criterion
During Bayesian hyperparameter optimization, the objective might be the validation loss, the training loss, or something else. By default, the larger of the validation loss and the training loss is returned (the former is usually higher, but randomness may occasionally make the latter higher). For example, the following code returns only the validation loss:
def _bayes_eval(self, model, X_train, y_train, X_val, y_val):
    y_val_pred = self._pred_single_model(model, X_val, verbose=False)
    _, val_loss = self._default_metric_sklearn(y_val, y_val_pred)
    return val_loss
where _default_metric_sklearn returns MSE loss for regression tasks and log loss for classification tasks.
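To make the two criteria concrete, the sketch below compares the default "larger of the two losses" rule with the validation-only rule on toy numbers. It is plain Python with hypothetical helper names (mse, default_bayes_objective, val_only_objective); the real methods operate on fitted models through _pred_single_model and _default_metric_sklearn.

```python
from typing import Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error, the regression loss mentioned above."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def default_bayes_objective(train_loss: float, val_loss: float) -> float:
    """Sketch of the default criterion: the larger of the two losses."""
    return max(train_loss, val_loss)


def val_only_objective(train_loss: float, val_loss: float) -> float:
    """Sketch of the customized criterion above: the validation loss only."""
    return val_loss


train_loss = mse([1.0, 2.0, 3.0], [1.1, 1.9, 3.2])  # about 0.02
val_loss = mse([1.0, 2.0], [1.5, 2.5])  # 0.25

print(default_bayes_objective(train_loss, val_loss))  # 0.25, the validation loss
print(default_bayes_objective(0.3, 0.25))  # 0.3, an unusually high training loss wins
print(val_only_objective(0.3, 0.25))  # 0.25
```

The default rule penalizes a trial whenever either loss is high, while the customized rule ignores the training loss entirely.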
Validity of a model
_conditional_validity is used to check the validity of a model under certain circumstances. For example, some models might be invalid if a certain feature A_FEATURE is not provided:
def _conditional_validity(self, model_name: str) -> bool:
    if model_name == "SOME_MODEL" and "A_FEATURE" not in self.trainer.cont_feature_names:
        return False
    # All other cases are valid; without this line the method would return None.
    return True
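The method can be exercised without a trainer by stubbing the one attribute it reads. ValidityDemo and the SimpleNamespace stand-in for self.trainer are hypothetical; only the body of _conditional_validity mirrors the example.

```python
from types import SimpleNamespace
from typing import List


class ValidityDemo:
    """Hypothetical stand-in for a model base; only self.trainer is provided."""

    def __init__(self, cont_feature_names: List[str]) -> None:
        # Stub for self.trainer exposing just cont_feature_names.
        self.trainer = SimpleNamespace(cont_feature_names=cont_feature_names)

    def _conditional_validity(self, model_name: str) -> bool:
        if model_name == "SOME_MODEL" and "A_FEATURE" not in self.trainer.cont_feature_names:
            return False
        return True  # every other case is valid


demo = ValidityDemo(cont_feature_names=["B_FEATURE"])
print(demo._conditional_validity("SOME_MODEL"))  # False: A_FEATURE is missing
print(demo._conditional_validity("OTHER_MODEL"))  # True: the rule only targets SOME_MODEL
```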
Remark: We do not recommend overriding methods of AbstractModel other than those introduced in this part and in “Customized model base” unless you know what you are doing.
Advanced customizations of TorchModel
The above customizations of AbstractModel also apply to TorchModel. TorchModel works within a narrower framework but exposes more APIs for flexibility. Some lower-level, more specific customizations are provided in AbstractNN.
class TabNetFromTorchInherited(TabNetFromTorch):
Customized data processing
In TorchModel._train_data_preprocess, a model base processes tabular or multimodal datasets for itself. The method _prepare_custom_datamodule is called at the beginning and should return a DataModule instance (self.trainer.datamodule by default), which is used to generate the final datasets (torch.utils.data.Dataset instances) and provides other information. For example, the following code builds a DataModule that uses UnscaledDataDeriver to additionally record unscaled data as an item of derived data (multimodal data). Note that warm_start must be handled here; otherwise, the DataModule will be reset when new data is fed to fit with warm_start=True.
def _prepare_custom_datamodule(self, model_name, warm_start=False):
    from tabensemb.data import DataModule
    base = self.trainer.datamodule
    if not warm_start or not hasattr(self, "datamodule"):
        # Build a fresh DataModule with its own imputer, derivers, and processors.
        datamodule = DataModule(
            config=self.trainer.datamodule.args, initialize=False
        )
        datamodule.set_data_imputer("MeanImputer")
        datamodule.set_data_derivers(
            [("UnscaledDataDeriver", {"derived_name": "Unscaled"})]
        )
        datamodule.set_data_processors(
            [("CategoricalOrdinalEncoder", {}), ("StandardScaler", {})]
        )
        # Nothing has been fitted yet, so this DataModule cannot warm start.
        warm_start = False
    else:
        # Reuse the DataModule fitted in a previous call.
        datamodule = self.datamodule
    # Feed the trainer's data (with categories restored) and its data splits.
    datamodule.set_data(
        base.categories_inverse_transform(base.df),
        cont_feature_names=base.cont_feature_names,
        cat_feature_names=base.cat_feature_names,
        label_name=base.label_name,
        train_indices=base.train_indices,
        val_indices=base.val_indices,
        test_indices=base.test_indices,
        verbose=False,
        warm_start=warm_start,
    )
    # Keep the trainer's derived data; items of this DataModule take precedence.
    tmp_derived_data = base.derived_data.copy()
    tmp_derived_data.update(datamodule.derived_data)
    datamodule.derived_data = tmp_derived_data
    self.datamodule = datamodule
    return datamodule
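The warm_start branching and the derived-data merge can be checked in isolation. In the sketch below, DataModuleStub and ModelBaseStub are hypothetical stand-ins that model only derived_data; imputation, processing, and set_data are omitted, and the "Image" item is a made-up name.

```python
from typing import Dict, List, Optional


class DataModuleStub:
    """Hypothetical stand-in for a DataModule; only derived_data is modelled."""

    def __init__(self, derived_data: Optional[Dict[str, List[float]]] = None) -> None:
        self.derived_data = dict(derived_data or {})


class ModelBaseStub:
    """Mimics the branching of _prepare_custom_datamodule without tabensemb."""

    def __init__(self, base: DataModuleStub) -> None:
        self.base = base  # plays the role of self.trainer.datamodule

    def prepare(self, warm_start: bool = False) -> DataModuleStub:
        if not warm_start or not hasattr(self, "datamodule"):
            # Fresh module with its own derived item (cf. UnscaledDataDeriver).
            datamodule = DataModuleStub({"Unscaled": [1.0, 2.0]})
        else:
            # Warm start: keep the module built in the previous call.
            datamodule = self.datamodule
        # Base items first, then the module's own items, which win on name clashes.
        merged = dict(self.base.derived_data)
        merged.update(datamodule.derived_data)
        datamodule.derived_data = merged
        self.datamodule = datamodule
        return datamodule


model = ModelBaseStub(DataModuleStub({"Image": [0.0, 1.0]}))
first = model.prepare()
second = model.prepare(warm_start=True)
print(first is second)  # True: the module is reused instead of being reset
print(sorted(second.derived_data))  # ['Image', 'Unscaled']
```

Calling prepare(warm_start=True) before any module exists falls back to building a new one, mirroring the hasattr check in the documented method.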
In TorchModel._data_preprocess, _run_custom_data_module is called first to transform the incoming data into a form consistent with the model base's own DataModule. A common implementation is as follows:
def _run_custom_data_module(self, df, derived_data, model_name):
    df, my_derived_data = self.datamodule.prepare_new_data(df, ignore_absence=True)
    derived_data = derived_data.copy()
    derived_data.update(my_derived_data)
    derived_data = self.datamodule.sort_derived_data(derived_data)
    return df, derived_data, self.datamodule
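Copying derived_data before updating it leaves the caller's dictionary untouched. The plain-Python sketch below contrasts a copy-then-update merge with an in-place update; both helpers are hypothetical, and sort_derived_data is not modelled.

```python
from typing import Dict, List

DerivedData = Dict[str, List[float]]


def merge_with_copy(incoming: DerivedData, own: DerivedData) -> DerivedData:
    """Mirrors the copy-then-update step of _run_custom_data_module."""
    merged = incoming.copy()  # shallow copy: the caller's dict keeps its keys
    merged.update(own)  # items from the model base's DataModule take precedence
    return merged


def merge_in_place(incoming: DerivedData, own: DerivedData) -> DerivedData:
    """Counter-example: updating the argument directly mutates the caller's dict."""
    incoming.update(own)
    return incoming


incoming = {"Image": [0.1, 0.2]}
own = {"Unscaled": [10.0, 20.0]}

print(sorted(merge_with_copy(incoming, own)))  # ['Image', 'Unscaled']
print(sorted(incoming))  # ['Image']: unchanged
merge_in_place(incoming, own)
print(sorted(incoming))  # ['Image', 'Unscaled']: the caller's data was modified
```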
Output normalization
The functionality is provided in AbstractNN. Different tasks use different normalizations: torch.nn.Identity() for regression, which leaves the output unchanged, torch.nn.Softmax(dim=-1) for multi-class classification, and torch.nn.Sigmoid() for binary classification; the latter two convert logits to probabilities. For example, with the following code a model always returns non-negative predictions:
class TabNetNNInherited(TabNetNN):
    def output_norm(self, y_pred):
        return torch.abs(y_pred)
Remark: Normalization is not related to the calculation of the loss function.
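For intuition, the sketch below reimplements the per-task normalizations on plain Python lists for a single sample. It only mimics what torch.nn.Identity, torch.nn.Sigmoid, torch.nn.Softmax(dim=-1), and the torch.abs override compute elementwise; the function names are hypothetical.

```python
import math
from typing import List


def identity(logits: List[float]) -> List[float]:
    return list(logits)  # regression: the output is left unchanged


def sigmoid(logits: List[float]) -> List[float]:
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]  # binary: probability per output


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]  # multi-class: probabilities summing to 1


def abs_norm(outputs: List[float]) -> List[float]:
    return [abs(z) for z in outputs]  # the custom output_norm above


print(sigmoid([0.0]))  # [0.5]
print(abs_norm([-2.0, 0.0, 3.0]))  # [2.0, 0.0, 3.0]
```

Note that abs maps 0.0 to 0.0, which is why the override guarantees non-negative rather than strictly positive predictions.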
Loss function
The functionality is provided in AbstractNN. By default, torch.nn.BCEWithLogitsLoss() is used for binary classification; torch.nn.CrossEntropyLoss() is used for multi-class classification; torch.nn.MSELoss() (loss=="mse") or torch.nn.L1Loss() (loss=="mae") is used for regression. For example, a model with the following code uses torch.nn.SmoothL1Loss:
class TabNetNNInherited(TabNetNN):
    def loss_fn(self, y_pred, y_true, *data, **kwargs):
        return torch.nn.SmoothL1Loss()(y_pred, y_true)
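To see how the choice of loss changes the penalty, the sketch below computes the mean Smooth L1 loss with the formula documented for torch.nn.SmoothL1Loss (default beta=1.0, mean reduction) alongside MSE and MAE in plain Python. The helper names are hypothetical; it is not a replacement for the PyTorch modules.

```python
from typing import Sequence


def smooth_l1(y_pred: Sequence[float], y_true: Sequence[float], beta: float = 1.0) -> float:
    """Mean Smooth L1 loss: quadratic for small errors, linear for large ones."""
    total = 0.0
    for p, t in zip(y_pred, y_true):
        diff = abs(p - t)
        if diff < beta:
            total += 0.5 * diff ** 2 / beta
        else:
            total += diff - 0.5 * beta
    return total / len(y_pred)


def mse(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_pred)


def mae(y_pred: Sequence[float], y_true: Sequence[float]) -> float:
    return sum(abs(p - t) for p, t in zip(y_pred, y_true)) / len(y_pred)


y_pred, y_true = [0.5, 3.0], [0.0, 0.0]
print(smooth_l1(y_pred, y_true))  # 1.3125
print(mse(y_pred, y_true))  # 4.625
print(mae(y_pred, y_true))  # 1.75
```

The large error of 3.0 dominates the MSE, whereas Smooth L1 grows only linearly beyond beta, which makes it less sensitive to outliers.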
before_loss_fn is called before loss_fn to convert the output (from forward) and the target into the required format. To match the default loss_fn (self.default_loss_fn, returned by AbstractNN.get_loss_fn), a common implementation of before_loss_fn that meets the input requirements of torch.nn.BCEWithLogitsLoss() and torch.nn.CrossEntropyLoss() is as follows:
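class TabNetNNInherited(TabNetNN):
    def before_loss_fn(self, y, yhat):
        if self.task == "binary":
            # BCEWithLogitsLoss expects the input and the target to have the same shape.
            y = torch.flatten(y)
            yhat = torch.flatten(yhat)
        elif self.task == "multiclass":
            # CrossEntropyLoss expects integer (long) class indices as the target.
            yhat = torch.flatten(yhat).long()
        return y, yhat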
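The shape handling can be mirrored on nested lists. Judging from the .long() cast, which CrossEntropyLoss requires of its target, yhat appears to hold the target and y the output of forward; this reading is an inference, not something the documentation states. flatten and before_loss_fn_sketch are hypothetical helpers.

```python
from typing import List, Tuple


def flatten(x: List[List[float]]) -> List[float]:
    """(N, 1) nested list -> flat list of length N, like torch.flatten on that shape."""
    return [v for row in x for v in row]


def before_loss_fn_sketch(task: str, y, yhat) -> Tuple[list, list]:
    """List-based mirror of before_loss_fn (y: model output, yhat: target)."""
    if task == "binary":
        y = flatten(y)  # logits (N, 1) -> (N,)
        yhat = flatten(yhat)  # targets (N, 1) -> (N,), same shape as the logits
    elif task == "multiclass":
        yhat = [int(v) for v in flatten(yhat)]  # class indices as integers, like .long()
    return y, yhat


logits = [[0.1, 0.2, 0.7], [2.0, 0.5, 0.1]]  # (N=2, C=3), left untouched
print(before_loss_fn_sketch("multiclass", logits, [[2.0], [0.0]])[1])  # [2, 0]
print(before_loss_fn_sketch("binary", [[0.3], [-1.2]], [[1.0], [0.0]]))
# ([0.3, -1.2], [1.0, 0.0])
```

For regression nothing is changed, so MSELoss and L1Loss receive the output and target as they are.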
pytorch_lightning functionalities
AbstractNN is based on pytorch_lightning.LightningModule, so most methods of LightningModule can be used directly in AbstractNN. Note that some of them, such as training_step, validation_step, and configure_optimizers, are already implemented. Others, such as on_train_start and on_train_epoch_end, are called automatically by pytorch_lightning.Trainer. See the PyTorch Lightning documentation for advanced usage.
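The hook mechanism amounts to overriding no-op methods that the trainer calls at fixed points. To show the override pattern without installing Lightning, the toy below uses hypothetical ToyModule and ToyTrainer classes that mimic only the call order of on_train_start and on_train_epoch_end; a real pytorch_lightning.Trainer runs many more hooks.

```python
from typing import List


class ToyModule:
    """Framework-free stand-in for a LightningModule; hooks default to no-ops."""

    current_epoch = 0  # Lightning modules expose a current_epoch property

    def on_train_start(self) -> None:
        pass

    def on_train_epoch_end(self) -> None:
        pass


class ToyTrainer:
    """Hypothetical trainer calling the hooks in the order of a training run."""

    def fit(self, module: ToyModule, max_epochs: int) -> None:
        module.on_train_start()
        for epoch in range(max_epochs):
            module.current_epoch = epoch
            # ... the training and validation steps of the epoch would run here ...
            module.on_train_epoch_end()


class RecordingModule(ToyModule):
    """Overrides two hooks, as a subclass of AbstractNN might."""

    def on_train_start(self) -> None:
        self.finished_epochs: List[int] = []

    def on_train_epoch_end(self) -> None:
        self.finished_epochs.append(self.current_epoch)


module = RecordingModule()
ToyTrainer().fit(module, max_epochs=3)
print(module.finished_epochs)  # [0, 1, 2]
```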
Backward propagation
The functionality is provided in AbstractNN. Backward propagation and optimization are performed using the loss value returned by loss_fn (or attributes registered while loss_fn is called) and the optimizers returned by configure_optimizers. The default implementation, which uses only one optimizer and one loss item, is as follows:
class TabNetNNInherited(TabNetNN):
    def cal_backward_step(self, loss):
        self.manual_backward(loss)
        opt = self.optimizers()
        opt.step()
self.manual_backward should be used instead of loss.backward() because LightningModule requires it under manual optimization.
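The backward-then-step sequence can be illustrated without PyTorch. In the toy below, ToyParameter, ToySGD, and ToyModule are hypothetical; manual_backward stands in for autograd by accumulating the analytic gradient of (w - 3)^2, and calling zero_grad at the start of each iteration is an assumption of this toy, since the documented default does not show where gradients are cleared.

```python
class ToyParameter:
    """A single scalar weight with an accumulated gradient."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class ToySGD:
    """Hypothetical optimizer exposing step() and zero_grad() like torch.optim optimizers."""

    def __init__(self, param: ToyParameter, lr: float) -> None:
        self.param = param
        self.lr = lr

    def step(self) -> None:
        self.param.value -= self.lr * self.param.grad

    def zero_grad(self) -> None:
        self.param.grad = 0.0


class ToyModule:
    """Stand-in for a LightningModule under manual optimization."""

    def __init__(self) -> None:
        self.w = ToyParameter(0.0)
        self._opt = ToySGD(self.w, lr=0.1)

    def optimizers(self) -> ToySGD:
        return self._opt

    def compute_loss(self) -> float:
        return (self.w.value - 3.0) ** 2  # minimised at w = 3

    def manual_backward(self, loss: float) -> None:
        # Stand-in for autograd: accumulate d(loss)/dw = 2 * (w - 3).
        self.w.grad += 2.0 * (self.w.value - 3.0)

    def cal_backward_step(self, loss: float) -> None:
        # Same sequence as the default implementation above.
        self.manual_backward(loss)
        opt = self.optimizers()
        opt.step()


module = ToyModule()
for _ in range(50):
    module.optimizers().zero_grad()  # where gradients are cleared is assumed here
    module.cal_backward_step(module.compute_loss())
print(round(module.w.value, 3))  # 3.0
```

Each step shrinks the error w - 3 by a factor of 0.8, so after 50 steps w is within about 5e-5 of the minimum.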
Early stopping criterion
The functionality is provided in AbstractNN. Early stopping reduces the risk of over-fitting. _early_stopping_eval returns the value monitored by early stopping. By default, the validation loss is returned:
class TabNetNNInherited(TabNetNN):
    def _early_stopping_eval(self, train_loss: float, val_loss: float) -> float:
        return val_loss + 0.0 * train_loss
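The second term leaves a finite value unchanged but turns the result into NaN when the training loss is NaN (0.0 * NaN is NaN), so that NaN in the training loss is detected.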
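The sketch below shows the NaN propagation and feeds the monitored value into a simple patience rule. epochs_run and its stop-on-NaN behavior are hypothetical illustrations of how a monitored value is typically consumed, not tabensemb's early-stopping logic.

```python
import math
from typing import List, Tuple


def early_stopping_eval(train_loss: float, val_loss: float) -> float:
    return val_loss + 0.0 * train_loss  # 0.0 * nan is nan, so a NaN training loss surfaces


def epochs_run(history: List[Tuple[float, float]], patience: int) -> int:
    """Hypothetical rule: stop after `patience` epochs without improvement, or on NaN."""
    best = math.inf
    bad_epochs = 0
    for epoch, (train_loss, val_loss) in enumerate(history, start=1):
        monitored = early_stopping_eval(train_loss, val_loss)
        if math.isnan(monitored):
            return epoch  # a NaN monitored value ends training at once in this sketch
        if monitored < best:
            best, bad_epochs = monitored, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return len(history)


print(early_stopping_eval(0.3, 0.5))  # 0.5
print(early_stopping_eval(float("nan"), 0.5))  # nan
history = [(1.0, 0.9), (0.8, 0.7), (0.7, 0.75), (0.6, 0.8)]
print(epochs_run(history, patience=2))  # 4
```

Without the 0.0 * train_loss term, a NaN training loss with a finite validation loss would go unnoticed by the monitor.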