tabensemb.model.AbstractNN.training_step
method
- AbstractNN.training_step(batch: Any, batch_idx: Any)
Here you compute and return the training loss, plus any additional metrics for, e.g., the progress bar or logger.
- Parameters:
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – Integer index of this batch.
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Any) – Passed in if truncated_bptt_steps > 0.
- Returns:
Any of:
Tensor - The loss tensor.
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Training will skip to the next batch. This is only for automatic optimization. This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
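The three accepted return shapes can be illustrated with a small sketch. The helper extract_loss below is hypothetical and is not part of tabensemb or the underlying training framework; it only mirrors the rules listed above, and plain floats stand in for loss tensors so the snippet runs without any dependencies.

```python
from typing import Any, Optional


def extract_loss(step_output: Any) -> Optional[Any]:
    """Hypothetical helper: pull the loss out of a training_step return value."""
    if step_output is None:
        # None means "skip this batch" (automatic optimization only).
        return None
    if isinstance(step_output, dict):
        # A dict may carry any keys, but 'loss' is mandatory.
        if "loss" not in step_output:
            raise KeyError("a dict returned from training_step must contain the key 'loss'")
        return step_output["loss"]
    # Anything else is treated as the loss itself (a Tensor in real use).
    return step_output


# Floats stand in for loss tensors in this illustration.
print(extract_loss(0.25))
print(extract_loss({"loss": 0.25, "acc": 0.9}))
print(extract_loss(None))
```

A dict without the 'loss' key raises KeyError in this sketch; how the real trainer reports that case is not shown here.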
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
    def training_step(self, batch, batch_idx):
        x, y, z = batch
        out = self.encoder(x)
        loss = self.loss(out, x)
        return loss
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
    # Multiple optimizers (e.g.: GANs)
    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            # do training_step with encoder
            ...
        if optimizer_idx == 1:
            # do training_step with decoder
            ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
    # Truncated back-propagation through time
    def training_step(self, batch, batch_idx, hiddens):
        # Assumption: the batch is the input sequence itself
        data = batch
        # hiddens are the hidden states from the previous truncated backprop step
        out, hiddens = self.lstm(data, hiddens)
        loss = ...
        return {"loss": loss, "hiddens": hiddens}
Note
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
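To see why the displayed value can differ from the returned loss, here is a minimal sketch of a windowed running mean. The RunningMean class, the window length of 3, and the use of a plain arithmetic mean are all assumptions chosen for illustration; the note above only states that the value is averaged over the last values.

```python
from collections import deque


class RunningMean:
    """Illustrative smoother: mean of the most recent `window` values."""

    def __init__(self, window: int = 20):
        # deque with maxlen drops the oldest value automatically
        self.values = deque(maxlen=window)

    def update(self, value: float) -> float:
        self.values.append(value)
        return sum(self.values) / len(self.values)


smoother = RunningMean(window=3)
raw_losses = [4.0, 2.0, 6.0, 8.0]
displayed = [smoother.update(v) for v in raw_losses]

# Each displayed value lags behind the raw loss for that step.
for raw, shown in zip(raw_losses, displayed):
    print(f"returned={raw:.3f} displayed={shown:.3f}")
```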
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
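The arithmetic behind this normalization can be sketched as follows. The function normalized_losses is a hypothetical name used only for illustration, and floats stand in for loss tensors; the sketch shows the scaling itself, not how the framework applies it internally.

```python
def normalized_losses(losses, accumulate_grad_batches):
    """Divide each per-batch loss by the number of accumulated batches."""
    return [loss / accumulate_grad_batches for loss in losses]


# Four batches accumulated before one optimizer step (illustrative values).
batch_losses = [4.0, 2.0, 6.0, 8.0]
scaled = normalized_losses(batch_losses, accumulate_grad_batches=4)

# Summing the scaled losses gives the mean loss over the accumulated batches,
# so the accumulated gradient matches that of one large averaged batch.
effective_loss = sum(scaled)
mean_loss = sum(batch_losses) / len(batch_losses)
print(scaled, effective_loss, mean_loss)
```

Because of this scaling, there is no need to divide the loss by accumulate_grad_batches yourself inside training_step.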