Source code for tabensemb.model.pytorch_tabular

import warnings
import torch
from tabensemb.utils import *
from tabensemb.model import AbstractModel
from skopt.space import Integer, Real, Categorical
import shutil
import numpy as np
from pytorch_lightning import Callback
import pytorch_lightning as pl
from .base import PytorchLightningLossCallback
from .base import AbstractWrapper
from typing import Dict, Any
from packaging import version
from torch import nn
import re
import inspect


class PytorchTabular(AbstractModel):
    """
    Model base that builds, trains and queries ``pytorch_tabular`` models.

    The names of the models that are exposed are returned by :meth:`_get_model_names`.
    """

    def _get_program_name(self):
        """Return the program name of this model base, ``"PytorchTabular"``."""
        return "PytorchTabular"

    def _new_model(self, model_name, verbose, **kwargs):
        """
        Build an untrained ``pytorch_tabular.TabularModel`` for ``model_name``.

        Parameters
        ----------
        model_name
            A key of ``model_configs`` below, e.g. ``"Category Embedding"`` or ``"TabNet"``.
        verbose
            Not used by this method.
        **kwargs
            Hyperparameters. ``batch_size`` is required. ``lr`` (if given) is passed to the model
            config as ``learning_rate``. ``batch_size``, ``original_batch_size`` and every key that
            ends up in the optimizer / lr-scheduler parameters are not forwarded to the model
            config; all remaining keys are.

        Returns
        -------
        pytorch_tabular.TabularModel
            The configured model, with its logger disabled.

        Notes
        -----
        ``self.task`` and ``self.loss`` are set as a side effect, and ``<self.root>/ckpts`` is
        created if missing.
        """
        warnings.filterwarnings("ignore", message="Wandb")
        from ._pytorch_tabular.mute_track import mute_track

        mute_track()
        from functools import partialmethod
        import pytorch_tabular
        from pytorch_tabular import TabularModel
        from pytorch_tabular.config import (
            DataConfig,
            ExperimentRunManager,
            OptimizerConfig,
            TrainerConfig,
        )
        from pytorch_tabular.models import (
            CategoryEmbeddingModelConfig,
            NodeConfig,
            TabNetModelConfig,
            TabTransformerConfig,
            AutoIntConfig,
            FTTransformerConfig,
            GatedAdditiveTreeEnsembleConfig,
        )

        # Temporarily redirect the experiment version file into ``self.root``. The patch is
        # class-wide, so it MUST be undone even if building the model fails; hence try/finally.
        erm_original_init = ExperimentRunManager.__init__
        ExperimentRunManager.__init__ = partialmethod(
            erm_original_init,
            exp_version_manager=os.path.join(self.root, "exp_version_manager.yml"),
        )
        try:
            task = self.trainer.datamodule.task
            self.task = task
            if task in ["binary", "multiclass"]:
                task = "classification"

            # Translate known loss names to the torch loss class names; unknown names pass through.
            loss = self.trainer.datamodule.loss
            mapping = {
                "cross_entropy": "CrossEntropyLoss",
                "mse": "MSELoss",
                "mae": "L1Loss",
            }
            loss = mapping.get(loss, loss)
            self.loss = loss

            data_config = DataConfig(
                target=self.trainer.label_name,
                continuous_cols=self.trainer.cont_feature_names,
                categorical_cols=self.trainer.cat_feature_names,
                num_workers=0,
            )

            ckpt_dir = os.path.join(self.root, "ckpts")
            os.makedirs(ckpt_dir, exist_ok=True)
            trainer_config = TrainerConfig(
                batch_size=int(kwargs["batch_size"]),
                progress_bar="none",
                early_stopping="valid_loss",
                early_stopping_patience=self.trainer.static_params["patience"],
                checkpoints="valid_loss",
                checkpoints_path=ckpt_dir,
                checkpoints_save_top_k=1,
                checkpoints_name=model_name,
                load_best=True,
                accelerator="cpu" if self.device == "cpu" else "auto",
            )

            (
                opt_name,
                opt_params,
                lrs_name,
                lrs_params,
            ) = self._update_optimizer_lr_scheduler_params(
                model_name=model_name, **kwargs
            )
            if "lr" in opt_params:
                # pytorch_tabular deals with the learning rate individually.
                del opt_params["lr"]
            optimizer_config = OptimizerConfig(
                optimizer=opt_name,
                optimizer_params=opt_params,
                lr_scheduler=lrs_name,
                lr_scheduler_params=lrs_params,
            )

            model_configs = {
                "Category Embedding": CategoryEmbeddingModelConfig,
                "NODE": NodeConfig,
                "TabNet": TabNetModelConfig,
                "TabTransformer": TabTransformerConfig,
                "AutoInt": AutoIntConfig,
                "FTTransformer": FTTransformerConfig,
                "GATE": GatedAdditiveTreeEnsembleConfig,
            }
            # Extra, model-specific config arguments that depend on the installed version.
            special_configs = {
                "NODE": (
                    {"embed_categorical": True}
                    if version.parse(pytorch_tabular.__version__)
                    < version.parse("1.1.0")
                    else {}
                ),
            }

            excluded = ("lr", "batch_size", "original_batch_size")
            legal_kwargs = {
                key: value
                for key, value in kwargs.items()
                if key not in excluded
                and key not in opt_params
                and key not in lrs_params
            }
            if "lr" in kwargs:
                legal_kwargs["learning_rate"] = kwargs["lr"]

            # Coerce numpy string/int scalars to plain ``int`` where possible. Values that do not
            # parse as integers (e.g. a ``np.str_`` holding a name) are deliberately left as is.
            for key, value in list(legal_kwargs.items()):
                if type(value) in (np.str_, np.int_):
                    try:
                        legal_kwargs[key] = int(value)
                    except (ValueError, TypeError):
                        pass

            with HiddenPrints():
                model_config = model_configs[model_name](
                    task=task,
                    loss=loss,
                    **special_configs.get(model_name, {}),
                    **legal_kwargs,
                )
                tabular_model = TabularModel(
                    data_config=data_config,
                    model_config=model_config,
                    optimizer_config=optimizer_config,
                    trainer_config=trainer_config,
                )
            tabular_model.logger = False
            tabular_model.config["progress_bar_refresh_rate"] = 0
            return tabular_model
        finally:
            ExperimentRunManager.__init__ = erm_original_init

    def _train_data_preprocess(self, model_name, warm_start=False):
        """
        Collect the train/validation/test splits from the trainer's datamodule.

        Feature frames are passed through ``categories_inverse_transform`` and restricted to
        ``self.trainer.all_feature_names``; targets are returned unchanged.

        Returns
        -------
        dict
            Keys ``X_train``, ``y_train``, ``X_val``, ``y_val``, ``X_test`` and ``y_test``.
        """
        data = self.trainer.datamodule
        all_feature_names = self.trainer.all_feature_names
        X_train = data.categories_inverse_transform(data.X_train)[all_feature_names]
        X_val = data.categories_inverse_transform(data.X_val)[all_feature_names]
        X_test = data.categories_inverse_transform(data.X_test)[all_feature_names]
        return {
            "X_train": X_train,
            "y_train": data.y_train,
            "X_val": X_val,
            "y_val": data.y_val,
            "X_test": X_test,
            "y_test": data.y_test,
        }

    def _data_preprocess(self, df, derived_data, model_name):
        """
        Prepare an arbitrary frame for prediction.

        A copy of ``df`` is passed through ``categories_inverse_transform`` and restricted to
        ``self.trainer.all_feature_names``. ``derived_data`` and ``model_name`` are not used.
        """
        all_feature_names = self.trainer.all_feature_names
        df = self.trainer.datamodule.categories_inverse_transform(df.copy())[
            all_feature_names
        ]
        return df

    def _train_single_model(
        self,
        model,
        model_name,
        epoch,
        X_train,
        y_train,
        X_val,
        y_val,
        verbose,
        warm_start,
        in_bayes_opt,
        **kwargs,
    ):
        """
        Fit ``model`` on the training split and record its loss curves.

        Parameters
        ----------
        model
            The object returned by :meth:`_new_model`.
        model_name
            Key under which losses and the restored epoch are stored.
        epoch
            Passed to ``model.fit`` as ``max_epochs``.
        X_train, y_train, X_val, y_val
            Feature frames and targets; the target is attached to a copy of each feature frame
            under ``self.trainer.label_name`` before fitting.
        verbose
            When falsy, stdout and logging are hidden during fitting.
        warm_start, in_bayes_opt, **kwargs
            Not used by this method.

        Notes
        -----
        Side effects: ``self.train_losses[model_name]`` and ``self.val_losses[model_name]`` are
        set; ``self.restored_epochs[model_name]`` is set when a ``ModelCheckpoint`` callback is
        found; ``<self.root>/ckpts`` is removed afterwards.
        """
        tc = TqdmController()
        tc.disable_tqdm()
        # Re-enable tqdm even when fitting raises, so a failed run does not leave progress bars
        # disabled for the rest of the process.
        try:
            label_name = self.trainer.label_name
            train_data = X_train.copy()
            train_data[label_name] = y_train
            val_data = X_val.copy()
            val_data[label_name] = y_val

            pl_loss_callback = PytorchLightningLossCallback(
                verbose=verbose, total_epoch=epoch
            )
            with HiddenPrints(
                disable_std=not verbose,
                disable_logging=not verbose,
            ):
                with warnings.catch_warnings():
                    from pytorch_lightning.utilities.rank_zero import (
                        LightningDeprecationWarning,
                    )

                    warnings.filterwarnings(
                        "ignore", category=LightningDeprecationWarning
                    )
                    warnings.simplefilter(action="ignore", category=UserWarning)
                    model.fit(
                        train=train_data,
                        validation=val_data,
                        max_epochs=epoch,
                        callbacks=[
                            PytorchTabularVerboseLossCallback(),
                            pl_loss_callback,
                        ],
                    )
            self.train_losses[model_name] = pl_loss_callback.train_ls
            self.val_losses[model_name] = pl_loss_callback.val_ls

            from pytorch_lightning.callbacks import ModelCheckpoint

            ckpt_callback = next(
                (cb for cb in model.callbacks if isinstance(cb, ModelCheckpoint)),
                None,
            )
            if ckpt_callback is not None:
                # The epoch number is parsed out of the checkpoint file name ("epoch=<n>-...").
                self.restored_epochs[model_name] = int(
                    re.findall(
                        r"epoch=([0-9]*)-", ckpt_callback.kth_best_model_path
                    )[0]
                )

            ckpt_dir = os.path.join(self.root, "ckpts")
            if os.path.exists(ckpt_dir):
                shutil.rmtree(ckpt_dir)
        finally:
            tc.enable_tqdm()

    def _pred_single_model(self, model, X_test, verbose, **kwargs):
        """
        Predict on ``X_test`` with a fitted ``TabularModel``.

        Parameters
        ----------
        model
            The fitted model.
        X_test
            Frame of features to predict on. The whole frame is predicted in a single batch.
        verbose, **kwargs
            Not used by this method.

        Returns
        -------
        numpy.ndarray
            For ``self.task == "regression"``: one column per entry of ``model.config.target``.
            For ``"binary"``: a single column taken from ``"1_probability"``.
            Otherwise: all columns of the prediction frame except the last one.
        """
        from ._pytorch_tabular.mute_track import mute_track

        mute_track()
        targets = model.config.target
        with HiddenPrints():
            # Two annoying warnings that cannot be suppressed:
            # 1. DeprecationWarning: Default for ``include_input_features`` will change from True to False in the next
            # release. Please set it explicitly.
            # 2. DeprecationWarning: "The ``out_ff_layers``, ``out_ff_activation``, ``out_ff_dropoout``, and
            # ``out_ff_initialization`` arguments are deprecated and will be removed next release. Please use head and
            # head_config as an alternative.
            warnings.filterwarnings(
                "ignore", category=DeprecationWarning, module="pytorch_tabular"
            )
            original_batch_size = model.datamodule.batch_size
            model.datamodule.batch_size = len(X_test)
            # Restore the configured batch size even if prediction fails, otherwise the model
            # would silently keep the temporary prediction-time batch size.
            try:
                all_res = model.predict(X_test, include_input_features=False)
            finally:
                model.datamodule.batch_size = original_batch_size

        if self.task == "regression":
            preds = [
                np.array(all_res[f"{target}_prediction"]).reshape(-1, 1)
                for target in targets
            ]
            res = np.concatenate(preds, axis=1)
        elif self.task == "binary":
            res = np.array(all_res["1_probability"]).reshape(-1, 1)
        else:
            # Every column but the last is kept.
            # NOTE(review): presumably the last column is the predicted label and the others are
            # per-class probabilities -- confirm against pytorch_tabular's ``predict`` output.
            n_classes = len(all_res.columns) - 1
            res = np.array(all_res)[:, :n_classes]
        return res

    @staticmethod
    def _get_model_names():
        """Return the names of the models exposed by this model base."""
        return [
            "Category Embedding",
            "NODE",
            "TabNet",
            "TabTransformer",
            "AutoInt",
            "FTTransformer",
            # "GATE", Low efficiency
        ]

    def _space(self, model_name):
        """
        Return the hyperparameter search space of ``model_name``.

        Spaces are selected around default parameters. Each space is the model-specific list
        below followed by ``self.trainer.SPACE``. Trailing ``# value`` comments are kept from the
        original code (they appear to record the library defaults).

        Raises
        ------
        KeyError
            If ``model_name`` has no space defined.
        """
        common = self.trainer.SPACE
        space_dict = {
            "Category Embedding": [
                Real(low=0, high=0.5, prior="uniform", name="dropout"),  # 0.5
                Real(low=0, high=0.5, prior="uniform", name="embedding_dropout"),  # 0.5
            ]
            + common,
            "NODE": [
                Integer(low=2, high=5, prior="uniform", name="depth", dtype=int),  # 6
                Real(low=0, high=0.3, prior="uniform", name="embedding_dropout"),  # 0.0
                Real(low=0, high=0.3, prior="uniform", name="input_dropout"),  # 0.0
                Integer(low=64, high=256, prior="uniform", name="num_trees", dtype=int),
            ]
            + common,
            "TabNet": [
                Integer(low=4, high=16, prior="uniform", name="n_d", dtype=int),  # 8
                Integer(low=4, high=16, prior="uniform", name="n_a", dtype=int),  # 8
                Integer(low=1, high=6, prior="uniform", name="n_steps", dtype=int),  # 3
                Real(low=1.0, high=1.5, prior="uniform", name="gamma"),  # 1.3
                Integer(
                    low=1, high=4, prior="uniform", name="n_independent", dtype=int
                ),  # 2
                Integer(
                    low=1, high=4, prior="uniform", name="n_shared", dtype=int
                ),  # 2
            ]
            + common,
            "TabTransformer": [
                Categorical(categories=[8, 16, 32], name="input_embed_dim"),
                Real(low=0, high=0.3, prior="uniform", name="embedding_dropout"),  # 0.1
                Real(low=0, high=0.3, prior="uniform", name="ff_dropout"),  # 0.1
                Categorical([2, 4, 8], name="num_heads"),  # 8
                Integer(
                    low=4,
                    high=8,
                    prior="uniform",
                    name="num_attn_blocks",
                    dtype=int,
                ),  # 6
                Real(low=0, high=0.3, prior="uniform", name="attn_dropout"),  # 0.1
                Real(low=0, high=0.3, prior="uniform", name="add_norm_dropout"),  # 0.1
                Integer(
                    low=2,
                    high=6,
                    prior="uniform",
                    name="ff_hidden_multiplier",
                    dtype=int,
                ),  # 4
            ]
            + common,
            "AutoInt": [
                Real(low=0, high=0.3, prior="uniform", name="attn_dropouts"),  # 0.0
                # Categorical([16, 32, 64, 128], name='attn_embed_dim'),  # 32
                Real(low=0, high=0.3, prior="uniform", name="dropout"),  # 0.0
                Categorical([4, 8, 16, 32], name="embedding_dim"),  # 16
                Real(low=0, high=0.3, prior="uniform", name="embedding_dropout"),  # 0.0
                Integer(
                    low=1,
                    high=4,
                    prior="uniform",
                    name="num_attn_blocks",
                    dtype=int,
                ),  # 3
                Categorical([1, 2, 4], name="num_heads"),
            ]
            + common,
            "FTTransformer": [
                Categorical(categories=[8, 16, 32], name="input_embed_dim"),
                Real(low=0, high=0.3, prior="uniform", name="embedding_dropout"),  # 0.1
                Categorical([2, 4, 8], name="num_heads"),
                Integer(
                    low=2,
                    high=4,
                    prior="uniform",
                    name="num_attn_blocks",
                    dtype=int,
                ),  # 6
                Real(low=0, high=0.3, prior="uniform", name="attn_dropout"),  # 0.1
                Real(low=0, high=0.3, prior="uniform", name="add_norm_dropout"),  # 0.1
                Real(low=0, high=0.3, prior="uniform", name="ff_dropout"),  # 0.1
                Integer(
                    low=2,
                    high=6,
                    prior="uniform",
                    name="ff_hidden_multiplier",
                    dtype=int,
                ),  # 4
            ]
            + common,
            "GATE": [
                Integer(low=2, high=10, prior="uniform", name="gflu_stages", dtype=int),
                Real(low=0.0, high=0.3, prior="uniform", name="gflu_dropout"),
                Integer(low=2, high=4, prior="uniform", name="tree_depth", dtype=int),
                Integer(low=10, high=20, prior="uniform", name="num_trees", dtype=int),
                Real(low=0.0, high=0.3, prior="uniform", name="tree_dropout"),
                Real(
                    low=0.0,
                    high=0.3,
                    prior="uniform",
                    name="tree_wise_attention_dropout",
                ),
                Real(low=0.0, high=0.3, prior="uniform", name="embedding_dropout"),
            ]
            + common,
        }
        return space_dict[model_name]

    def _initial_values(self, model_name):
        """
        Return the initial hyperparameter values of ``model_name``.

        The model-specific values below are merged with ``self.trainer.chosen_params``; on a key
        clash the value from ``chosen_params`` wins.

        Raises
        ------
        KeyError
            If ``model_name`` has no initial values defined.
        """
        params_dict = {
            "Category Embedding": {
                "dropout": 0.0,
                "embedding_dropout": 0.1,
            },
            "NODE": {
                "depth": 4,
                "embedding_dropout": 0.0,
                "input_dropout": 0.0,
                "num_trees": 256,
            },
            "TabNet": {
                "n_d": 8,
                "n_a": 8,
                "n_steps": 3,
                "gamma": 1.3,
                "n_independent": 2,
                "n_shared": 2,
            },
            "TabTransformer": {
                "input_embed_dim": 32,
                "embedding_dropout": 0.1,
                "ff_dropout": 0.1,
                "num_heads": 8,
                "num_attn_blocks": 6,
                "attn_dropout": 0.1,
                "add_norm_dropout": 0.1,
                "ff_hidden_multiplier": 4,
            },
            "AutoInt": {
                "attn_dropouts": 0.0,
                "dropout": 0.0,
                "embedding_dim": 16,
                "embedding_dropout": 0.0,
                "num_attn_blocks": 3,
                "num_heads": 2,
            },
            "FTTransformer": {
                "input_embed_dim": 32,
                "embedding_dropout": 0.1,
                "num_heads": 8,
                "num_attn_blocks": 4,
                "attn_dropout": 0.1,
                "add_norm_dropout": 0.1,
                "ff_dropout": 0.1,
                "ff_hidden_multiplier": 4,
            },
            "GATE": {
                "gflu_stages": 6,
                "gflu_dropout": 0.0,
                # ``tree_depth`` influences the memory usage a lot. ``tree_depth``==10 with other default settings
                # consumes about 4 GiBs of ram.
                # When "tree_depth" larger than 4, and num_trees larger than 20 (approximately), performance on GPU
                # decreases dramatically.
                "tree_depth": 4,
                "num_trees": 20,
                "tree_dropout": 0.0,
                "tree_wise_attention_dropout": 0.0,
                "embedding_dropout": 0.1,
            },
        }
        # Only the requested model's dict is returned, so only that one needs the shared params.
        params = params_dict[model_name]
        params.update(self.trainer.chosen_params)
        return params
def pytorch_tabular_forward(self, backbone_features: torch.Tensor) -> Dict[str, Any]:
    """
    Compute the head output while remembering the features fed into the head.

    Written as an unbound method: ``self`` is the object it gets bound to, which must provide
    ``head``, ``apply_output_sigmoid_scaling`` and ``pack_output``.

    Parameters
    ----------
    backbone_features
        The features passed to ``self.head``. They are also stored on ``self`` as
        ``_hidden_representation``.

    Returns
    -------
    The result of ``self.pack_output`` for the scaled head output and ``backbone_features``.
    """
    # Keep a handle on the head's input so it can be read back after the forward pass.
    self._hidden_representation = backbone_features
    scaled_output = self.apply_output_sigmoid_scaling(self.head(backbone_features))
    return self.pack_output(scaled_output, backbone_features)
class PytorchTabularWrapper(AbstractWrapper):
    """
    Wrapper around a :class:`PytorchTabular` model base that exposes the input of the model's
    head (the "hidden representation") by swapping the underlying module's ``compute_head``.

    TabNet is rejected at construction time.
    """

    def __init__(self, model: PytorchTabular):
        super().__init__(model=model)
        if self.model_name == "TabNet":
            raise Exception("Wrapping TabNet is not supported.")

    def _tabular_component(self):
        """Return the ``.model`` attribute of the wrapped entry for ``self.model_name``."""
        return self.wrapped_model.model[self.model_name].model

    def wrap_forward(self):
        """Replace ``compute_head`` with ``pytorch_tabular_forward``, keeping the original."""
        from pytorch_tabular.models.base_model import BaseModel

        target = self._tabular_component()
        self.original_forward = target.compute_head
        # Bind the plain function to this instance so it behaves like a regular method.
        target.compute_head = pytorch_tabular_forward.__get__(target, BaseModel)

    def reset_forward(self):
        """Put the saved ``compute_head`` back, if one was saved."""
        if self.original_forward is None:
            return
        self._tabular_component().compute_head = self.original_forward

    @property
    def hidden_rep_dim(self):
        """
        Input width of the head.

        Read from the first layer of a ``LinearHead`` or from ``pi`` of a
        ``MixtureDensityHead``; any other head type (exact type match) raises ``Exception``.
        """
        from pytorch_tabular.models.common.heads import LinearHead, MixtureDensityHead

        head = self._tabular_component().head
        head_type = type(head)
        if head_type == LinearHead:
            return head.layers[0].in_features
        if head_type == MixtureDensityHead:
            return head.pi.in_features
        raise Exception(
            f"Only LinearHead and MixtureDensityHead is supported to extract a hidden_rep_dim, but "
            f"got {type(head)} instead. It might be a customized one."
        )

    @property
    def hidden_representation(self):
        """The value stored as ``_hidden_representation`` on the underlying module."""
        return self._tabular_component()._hidden_representation
class PytorchTabularVerboseLossCallback(Callback):
    """
    Callback that publishes the losses under the extra names ``train_loss_verbose`` and
    ``valid_loss_verbose``.
    """

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: pl.LightningModule,
        outputs,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Log the batch loss as ``train_loss_verbose``, aggregated per epoch only."""
        step_loss = outputs["loss"]
        # The first dimension of the batch's "target" entry is used as the batch size for logging.
        n_samples = batch["target"].shape[0]
        pl_module.log(
            "train_loss_verbose",
            step_loss,
            on_step=False,
            on_epoch=True,
            batch_size=n_samples,
        )

    def on_validation_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        """Copy the ``valid_loss`` callback metric to ``valid_loss_verbose``."""
        valid_loss = trainer.callback_metrics["valid_loss"]
        trainer.callback_metrics["valid_loss_verbose"] = valid_loss