tabensemb.model.AbstractNN
- class tabensemb.model.AbstractNN(datamodule: DataModule, **kwargs)
Bases: LightningModule
A subclass of pytorch_lightning.LightningModule that is compatible with TorchModel and implements the training and inference steps.
- Attributes:
- default_loss_fn
The name of the default loss function returned by get_loss_fn().
- default_output_norm
The name of the default output normalization returned by get_output_norm().
- cont_feature_names
The names of continuous features.
- cat_feature_names
The names of categorical features.
- n_cont
The number of continuous features.
- n_cat
The number of categorical features.
- default_optimizer
An optimizer name from torch.optim.
- default_optimizer_params
Parameters of default_optimizer.
- default_lr_scheduler
A lr scheduler name from torch.optim.lr_scheduler.
- default_lr_scheduler_params
Parameters of default_lr_scheduler.
- derived_feature_names
The keys of derived unstacked features.
- derived_feature_dims
The dimensions of derived unstacked features.
- task
The task type: “regression”, “binary”, or “multiclass”.
- n_outputs
The number of outputs. Note that for classification tasks, logits are returned instead of probabilities. For binary classification, the logit for the positive class is returned.
- cat_num_unique
The number of unique values for each categorical feature.
- hidden_representation
The information extracted by a deep learning model when forward-passing a batch. It is typically the input of the last output layer (usually a linear layer or an MLP) and should be recorded manually in _forward().
- hidden_rep_dim
The dimension of hidden_representation. It should be set manually in __init__().
- device
The device where the model is located.
- training
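Because n_outputs describes logits rather than probabilities, a caller that needs probabilities has to convert the outputs itself. The following is a minimal plain-Python sketch, assuming one positive-class logit per sample for binary tasks and one logit per class for multiclass tasks. The helper names binary_probability and multiclass_probabilities are hypothetical and not part of tabensemb; with tensors, torch.sigmoid and torch.softmax perform the same conversions.

```python
import math
from typing import List


def binary_probability(logit: float) -> float:
    """Sigmoid: probability of the positive class from its logit (hypothetical helper)."""
    # Branch on the sign so math.exp never receives a large positive argument.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def multiclass_probabilities(logits: List[float]) -> List[float]:
    """Softmax over one sample's class logits (hypothetical helper)."""
    # Subtracting the maximum logit avoids overflow without changing the result.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Usage: convert raw model outputs for one sample.
p_positive = binary_probability(0.8)
p_classes = multiclass_probabilities([2.0, 0.5, -1.0])
```

The max-subtraction in the softmax and the sign branch in the sigmoid are standard numerical-stability choices; they leave the mathematical result unchanged.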
Methods
- __init__(datamodule: DataModule, **kwargs)
Record useful information for initializing and training models.
- Parameters:
- datamodule:
A tabensemb.data.datamodule.DataModule instance.
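The attribute descriptions ask subclasses to set hidden_rep_dim in __init__() and to record hidden_representation in _forward() as the input of the last output layer. The plain-Python toy below illustrates that contract under stated assumptions: ToyTabularNet, matmul, and the fixed weights are hypothetical, the class does not subclass the real AbstractNN (which requires a DataModule and PyTorch Lightning), and its forward signature is a simplification of forward(*tensors[, data_required_models]).

```python
from typing import Dict, List, Optional

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product standing in for a bias-free linear layer."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


class ToyTabularNet:
    """Hypothetical model mimicking the documented AbstractNN conventions.

    It is NOT a tabensemb subclass; it only shows where hidden_rep_dim and
    hidden_representation are set.
    """

    def __init__(self, n_cont: int, hidden_dim: int = 3, n_outputs: int = 1) -> None:
        self.n_cont = n_cont
        self.n_outputs = n_outputs
        # Set manually in __init__, as the attribute documentation requires.
        self.hidden_rep_dim = hidden_dim
        self.hidden_representation: Optional[Matrix] = None
        # Fixed weights keep the example deterministic.
        self.w_hidden = [[0.1] * hidden_dim for _ in range(n_cont)]
        self.w_out = [[1.0] * n_outputs for _ in range(hidden_dim)]

    def forward(self, x: Matrix, derived_tensors: Optional[Dict[str, Matrix]] = None) -> Matrix:
        # Simplified public wrapper delegating to the "real" forward method.
        return self._forward(x, derived_tensors or {})

    def _forward(self, x: Matrix, derived_tensors: Dict[str, Matrix]) -> Matrix:
        # derived_tensors is accepted but unused in this toy.
        # Hidden layer followed by a ReLU activation.
        hidden = [[max(0.0, v) for v in row] for row in matmul(x, self.w_hidden)]
        # Record the input of the last layer: shape (batch_size, hidden_rep_dim).
        self.hidden_representation = hidden
        # The last output layer maps hidden_rep_dim -> n_outputs.
        return matmul(hidden, self.w_out)


# Usage: one forward pass on a batch of two samples with two continuous features.
model = ToyTabularNet(n_cont=2)
out = model.forward([[1.0, 2.0], [3.0, 4.0]])
```

After the pass, model.hidden_representation holds one row of length hidden_rep_dim per sample, which matches the shape (batch_size, hidden_rep_dim) that get_hidden_state documents.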
- before_loss_fn(y, yhat)
Treatments on the prediction and the ground truth before passing them to loss_fn().
- cal_backward_step(loss)
Perform the backward propagation and optimization steps.
- Call zero_grad of the optimizers initialized in configure_optimizers().
- call_required_model(required_model, x, ...)
Call a required model and return its result.
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization.
- forward(*tensors[, data_required_models])
A wrapper of the original forward of nn.Module, kept for compatibility.
- get_hidden_state(required_model, x, ...[, ...])
The input of the last layer of a deep learning model, i.e. the hidden representation, whose dimension is (batch_size, required_model.hidden_rep_dim).
- get_loss_fn(loss, task)
The loss function for the output of forward and the target.
- get_output_norm(task)
The operation on the output of forward in training/validation/testing steps.
- loss_fn(y_pred, y_true, *data, **kwargs)
User-defined loss function.
- output_norm(y_pred)
User-defined operation before output.
- set_requires_grad(model[, requires_grad, state])
Set or reset the requires_grad states of an nn.Module.
- test_epoch(test_loader, **kwargs)
Evaluate a torch.nn.Module model in a single epoch.
- training_step(batch, batch_idx)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- validation_step(batch, batch_idx)
Operates on a single batch of data from the validation set.
- _early_stopping_eval(train_loss, val_loss)
Calculate the loss value (criterion) for early stopping.
- _forward(x, derived_tensors)
The real forward method.
- _test_required_model(n_inputs, required_model)
Test whether a required model has the attribute hidden_rep_dim and find its value.
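get_loss_fn(loss, task) selects a loss function for the output of forward and the target. A minimal sketch of such a task-dependent lookup follows, assuming mean squared error for regression and logit-based cross-entropy for classification (consistent with models returning logits, as noted under n_outputs). The standalone get_loss_fn here is hypothetical and only named after the documented method; the loss names "mse" and "cross_entropy", the nested-list data layout, and the lookup table are assumptions, and tabensemb's actual names and defaults may differ.

```python
import math
from typing import Callable, Dict, List, Tuple

# Each row holds one sample's values (hypothetical nested-list layout).
Batch = List[List[float]]
LossFn = Callable[[Batch, Batch], float]


def mse(y_pred: Batch, y_true: Batch) -> float:
    """Mean squared error over every output of every sample."""
    n = sum(len(row) for row in y_pred)
    return sum(
        (p - t) ** 2
        for pred_row, true_row in zip(y_pred, y_true)
        for p, t in zip(pred_row, true_row)
    ) / n


def bce_with_logits(y_pred: Batch, y_true: Batch) -> float:
    """Binary cross-entropy from one positive-class logit per sample."""
    total = 0.0
    for (z,), (y,) in zip(y_pred, y_true):
        # Numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))].
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(y_pred)


def cross_entropy(y_pred: Batch, y_true: Batch) -> float:
    """Multiclass cross-entropy from per-class logits; y_true rows hold [class_index]."""
    total = 0.0
    for logits, (label,) in zip(y_pred, y_true):
        # log-sum-exp with the maximum subtracted for stability.
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_sum_exp - logits[int(label)]
    return total / len(y_pred)


# Hypothetical (loss name, task) -> loss function table.
_LOSSES: Dict[Tuple[str, str], LossFn] = {
    ("mse", "regression"): mse,
    ("cross_entropy", "binary"): bce_with_logits,
    ("cross_entropy", "multiclass"): cross_entropy,
}


def get_loss_fn(loss: str, task: str) -> LossFn:
    """Return the loss registered for (loss, task); raise ValueError otherwise."""
    try:
        return _LOSSES[(loss, task)]
    except KeyError:
        raise ValueError(f"unsupported loss {loss!r} for task {task!r}") from None


# Usage: pick the loss from the task, then apply it to logits and targets.
loss_for_task = get_loss_fn("cross_entropy", "multiclass")
value = loss_for_task([[2.0, 0.5, -1.0]], [[0]])
```

Keying the table on both the loss name and the task lets the same name ("cross_entropy") resolve to the binary or multiclass formula, since the two tasks produce differently shaped logits.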