tabensemb.model.AbstractNN

class tabensemb.model.AbstractNN(datamodule: DataModule, **kwargs)

Bases: LightningModule

A subclass of pytorch_lightning.LightningModule that is compatible with TorchModel and implements the training and inference steps.

Attributes:
default_loss_fn

The name of the default loss function returned by get_loss_fn()

default_output_norm

The name of the default output normalization returned by get_output_norm()

cont_feature_names

The names of continuous features

cat_feature_names

The names of categorical features

n_cont

The number of continuous features

n_cat

The number of categorical features

default_optimizer

An optimizer name from torch.optim.

default_optimizer_params

Parameters of default_optimizer

default_lr_scheduler

A learning-rate scheduler name from torch.optim.lr_scheduler.

default_lr_scheduler_params

Parameters of default_lr_scheduler

derived_feature_names

The keys of derived unstacked features.

derived_feature_dims

The dimensions of derived unstacked features

task

The type of task: “regression”, “binary”, or “multiclass”.

n_outputs

The number of outputs. Note that for classification tasks, logits are returned instead of probabilities. For binary classification, the logit for the positive class is returned.
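Because classification outputs are logits rather than probabilities, a caller who needs probabilities has to apply the link function itself. The plain-Python sketch below assumes that n_outputs is 1 for binary tasks (the single positive-class logit) and equals the number of classes for multiclass tasks; the helper names binary_probability and multiclass_probabilities are hypothetical, not tabensemb functions. With torch tensors, torch.sigmoid and torch.softmax play the same roles.

```python
import math
from typing import List


def binary_probability(logit: float) -> float:
    """Sigmoid: map the positive-class logit to P(y = 1)."""
    # Branch on the sign so math.exp never receives a large positive argument.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def multiclass_probabilities(logits: List[float]) -> List[float]:
    """Softmax: map one logit per class to a probability distribution."""
    # Subtracting the largest logit avoids overflow and leaves the result unchanged.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(binary_probability(0.0))  # 0.5
print(multiclass_probabilities([1.0, 1.0]))  # [0.5, 0.5]
```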

cat_num_unique

The number of unique values for each categorical feature.

hidden_representation

The information extracted by a deep learning model during the forward pass of a batch. It is typically the input to the final output layer (usually a linear layer or an MLP) and must be recorded manually in _forward().

hidden_rep_dim

The dimension of hidden_representation. It should be manually set in __init__().
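The bookkeeping described for hidden_representation and hidden_rep_dim can be sketched in plain Python. ToyModel, its fixed weights, and its single ReLU hidden layer are hypothetical stand-ins rather than tabensemb code; a real subclass would use torch layers, but the pattern is the same: set hidden_rep_dim in __init__() and store the input of the output layer in _forward().

```python
from typing import List


class ToyModel:
    """Hypothetical stand-in for an AbstractNN subclass (plain Python, no torch)."""

    def __init__(self, n_inputs: int, hidden_dim: int):
        # hidden_rep_dim is set manually in __init__().
        self.hidden_rep_dim = hidden_dim
        # Fixed toy weights: hidden unit j sums the inputs and adds bias j.
        self.w_hidden = [[1.0] * n_inputs for _ in range(hidden_dim)]
        self.b_hidden = [float(j) for j in range(hidden_dim)]
        self.w_out = [1.0] * hidden_dim
        self.hidden_representation = None

    def _forward(self, x: List[List[float]]) -> List[float]:
        # Hidden layer with ReLU: one row of length hidden_rep_dim per sample.
        hidden = [
            [max(0.0, sum(w * v for w, v in zip(row, sample)) + b)
             for row, b in zip(self.w_hidden, self.b_hidden)]
            for sample in x
        ]
        # Record the input of the last (output) layer manually.
        self.hidden_representation = hidden
        # Output layer: a single linear unit per sample.
        return [sum(w * h for w, h in zip(self.w_out, row)) for row in hidden]


model = ToyModel(n_inputs=3, hidden_dim=2)
out = model._forward([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
print(out)  # [13.0, 1.0]
```

After the call, hidden_representation holds a batch_size × hidden_rep_dim nested list (here 2 × 2).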

device

The device on which the model resides.

training

Whether the module is in training mode (inherited from torch.nn.Module).

Methods

__init__(datamodule: DataModule, **kwargs)

Record useful information for initializing and training models.

Parameters:
datamodule:

A tabensemb.data.datamodule.DataModule instance.

before_loss_fn(y, yhat)

Operations applied to the prediction and the ground truth before they are passed to loss_fn().

cal_backward_step(loss)

Perform the backward propagation and optimization steps.

cal_zero_grad()

Call zero_grad() on the optimizers initialized in configure_optimizers().
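cal_zero_grad() and cal_backward_step() split the usual manual-optimization sequence: clear gradients, back-propagate, then step the optimizer. The plain-Python sketch below shows why that order matters; ToyParam, ToySGD, and backward are hypothetical stand-ins, not tabensemb internals. In PyTorch the same roles are played by optimizer.zero_grad(), loss.backward(), and optimizer.step().

```python
class ToyParam:
    """Hypothetical scalar parameter that accumulates gradients like a tensor."""

    def __init__(self, value: float):
        self.value = value
        self.grad = 0.0


class ToySGD:
    """Minimal SGD optimizer exposing zero_grad() and step()."""

    def __init__(self, params, lr: float):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


def backward(p: ToyParam, target: float) -> float:
    """Loss (w - target)^2; add its gradient 2 * (w - target) to p.grad."""
    loss = (p.value - target) ** 2
    p.grad += 2.0 * (p.value - target)  # accumulates, so stale gradients must be cleared
    return loss


w = ToyParam(0.0)
opt = ToySGD([w], lr=0.25)
for _ in range(10):
    opt.zero_grad()          # counterpart of cal_zero_grad()
    backward(w, target=3.0)  # backward half of cal_backward_step()
    opt.step()               # optimization half of cal_backward_step()
print(round(w.value, 4))  # 2.9971
```

Skipping zero_grad() would let gradients from previous iterations pile up in grad, which is why the clearing step comes first.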

call_required_model(required_model, x, ...)

Call a required model and return its result.

configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization.

forward(*tensors[, data_required_models])

A wrapper around the original forward of nn.Module, kept for compatibility.

get_hidden_state(required_model, x, ...[, ...])

The input of the last layer of a deep learning model, i.e. the hidden representation, with shape (batch_size, required_model.hidden_rep_dim).

get_loss_fn(loss, task)

The loss function for the output of forward and the target.
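The loss returned by get_loss_fn(loss, task) has to match the task, and since classification outputs are logits, a classification loss should take logits directly. The plain-Python sketch below pairs each task with a common choice; this mapping, the helper names, and the target encodings (0/1 floats for binary, integer class indices for multiclass) are assumptions, and tabensemb's actual defaults may differ. The PyTorch counterparts are torch.nn.MSELoss, torch.nn.BCEWithLogitsLoss, and torch.nn.CrossEntropyLoss.

```python
import math
from typing import List


def mse(y_pred: List[float], y_true: List[float]) -> float:
    """Mean squared error for regression."""
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)


def bce_with_logits(logits: List[float], y_true: List[float]) -> float:
    """Binary cross-entropy computed from logits in a numerically stable form."""
    # max(z, 0) - z * y + log(1 + exp(-|z|)) equals -[y log s(z) + (1 - y) log(1 - s(z))].
    total = sum(max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
                for z, y in zip(logits, y_true))
    return total / len(y_true)


def cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Multiclass cross-entropy from per-class logits and integer labels."""
    total = 0.0
    for row, k in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[k]
    return total / len(labels)


# Hypothetical task -> loss mapping, for illustration only.
LOSS_BY_TASK = {"regression": mse, "binary": bce_with_logits, "multiclass": cross_entropy}

print(LOSS_BY_TASK["regression"]([1.0, 2.0], [1.0, 4.0]))  # 2.0
```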

get_output_norm(task)

The operation applied to the output of forward in the training, validation, and testing steps.

loss_fn(y_pred, y_true, *data, **kwargs)

User-defined loss function.

output_norm(y_pred)

User-defined operation applied to the prediction before it is output.

set_requires_grad(model[, requires_grad, state])

Set or reset the requires_grad states of an nn.Module.
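A common use of such a helper is to freeze a module temporarily and later restore each parameter's original flag. The plain-Python sketch below assumes these semantics: without state, every flag is set to requires_grad; with state, the flags are restored from it; either way the previous flags are returned. ToyParam and set_requires_grad_sketch are hypothetical, and the real method's return value and argument handling may differ. In PyTorch, the flags live on each nn.Parameter yielded by nn.Module.parameters().

```python
from typing import List, Optional


class ToyParam:
    """Hypothetical parameter carrying only a requires_grad flag."""

    def __init__(self, requires_grad: bool = True):
        self.requires_grad = requires_grad


def set_requires_grad_sketch(params: List[ToyParam], requires_grad: bool = False,
                             state: Optional[List[bool]] = None) -> List[bool]:
    """Set every flag to requires_grad, or restore flags from state.

    Returns the flags as they were before the call so they can be restored later.
    """
    previous = [p.requires_grad for p in params]
    if state is not None:
        for p, flag in zip(params, state):
            p.requires_grad = flag
    else:
        for p in params:
            p.requires_grad = requires_grad
    return previous


params = [ToyParam(True), ToyParam(False), ToyParam(True)]
saved = set_requires_grad_sketch(params, requires_grad=False)  # freeze everything
print([p.requires_grad for p in params])  # [False, False, False]
set_requires_grad_sketch(params, state=saved)                   # restore original flags
print([p.requires_grad for p in params])  # [True, False, True]
```

Returning the previous flags, rather than unconditionally re-enabling gradients, preserves parameters that were frozen on purpose (the second one here).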

test_epoch(test_loader, **kwargs)

Evaluate a torch.nn.Module model in a single epoch.

training_step(batch, batch_idx)

Compute and return the training loss and, optionally, additional metrics for the progress bar or logger.

validation_step(batch, batch_idx)

Operates on a single batch of data from the validation set.

_early_stopping_eval(train_loss, val_loss)

Calculate the loss value (criterion) used for early stopping.
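The criterion returned by _early_stopping_eval(train_loss, val_loss) is the value an early-stopping loop monitors. The plain-Python sketch below assumes the criterion is simply the validation loss and uses a patience counter; both choices, and the helper names, are illustrative assumptions rather than tabensemb's actual behaviour. A subclass could instead return, for example, a weighted combination of the two losses.

```python
from typing import List


def early_stopping_eval(train_loss: float, val_loss: float) -> float:
    # Assumed criterion: monitor the validation loss only; train_loss is ignored.
    return val_loss


def stopping_epoch(train_losses: List[float], val_losses: List[float], patience: int) -> int:
    """Return the epoch at which training stops after `patience` epochs without improvement."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, (tr, va) in enumerate(zip(train_losses, val_losses)):
        criterion = early_stopping_eval(tr, va)
        if criterion < best:
            best = criterion
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    # No early stop: training ran through the last epoch.
    return len(val_losses) - 1


val = [1.0, 0.8, 0.9, 0.85, 0.95]
print(stopping_epoch([1.0] * 5, val, patience=2))  # 3
```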

_forward(x, derived_tensors)

The actual forward method, in which the model's computation is implemented.
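forward() is documented as a compatibility wrapper, while _forward(x, derived_tensors) holds the computation. The plain-Python sketch below shows that delegation; how positional inputs are split into x and derived_tensors, the derived feature name derived_a, and the class WrapperSketch are all assumptions for illustration, not tabensemb's actual unpacking logic.

```python
from typing import Any, Dict, List, Optional


class WrapperSketch:
    """Hypothetical illustration of a public forward() delegating to _forward()."""

    def __init__(self, derived_feature_names: List[str]):
        self.derived_feature_names = derived_feature_names

    def forward(self, *tensors: Any, data_required_models: Optional[Dict[str, Any]] = None):
        # Assumption: the first input is x; the rest follow derived_feature_names in order.
        x, *derived = tensors
        derived_tensors = dict(zip(self.derived_feature_names, derived))
        # data_required_models is accepted for signature compatibility but unused here.
        return self._forward(x, derived_tensors)

    def _forward(self, x: Any, derived_tensors: Dict[str, Any]):
        # A subclass would implement the real computation; this stub echoes its inputs.
        return x, derived_tensors


m = WrapperSketch(["derived_a"])
print(m.forward([1.0, 2.0], [0, 1]))  # ([1.0, 2.0], {'derived_a': [0, 1]})
```

Keeping argument handling in forward() lets subclasses implement only _forward() against a fixed (x, derived_tensors) signature.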

_test_required_model(n_inputs, required_model)

Test whether a required model has the attribute hidden_rep_dim and find its value.
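Code that consumes another model's hidden state relies on two documented facts: the required model exposes hidden_rep_dim, and get_hidden_state() returns an array of shape (batch_size, required_model.hidden_rep_dim). The plain-Python sketch below checks both; check_hidden_state is a hypothetical helper, SimpleNamespace stands in for a real required model, and the sketch does not reproduce how _test_required_model uses n_inputs.

```python
from types import SimpleNamespace
from typing import Any, List


def check_hidden_state(required_model: Any, hidden: List[List[float]], batch_size: int) -> int:
    """Validate a hidden state against the shape (batch_size, hidden_rep_dim)."""
    # The required model must define hidden_rep_dim.
    if not hasattr(required_model, "hidden_rep_dim"):
        raise AttributeError(f"{type(required_model).__name__} has no attribute hidden_rep_dim")
    dim = required_model.hidden_rep_dim
    # Every sample in the batch must have exactly hidden_rep_dim entries.
    if len(hidden) != batch_size or any(len(row) != dim for row in hidden):
        raise ValueError(f"expected shape ({batch_size}, {dim})")
    return dim


required = SimpleNamespace(hidden_rep_dim=3)  # stand-in for a required model
print(check_hidden_state(required, [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], batch_size=2))  # 3
```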