tabensemb.model.AbstractNN.tbptt_split_batch

method

AbstractNN.tbptt_split_batch(batch: Any, split_size: int) -> List[Any]

When using truncated backpropagation through time (TBPTT), each batch must be split along the time dimension. Lightning handles this by default; override this method only if you need custom splitting behavior.

Parameters:
  • batch – Current batch

  • split_size – Maximum number of time steps in each split

Returns:

List of batch splits. Each split is passed to training_step() to enable truncated backpropagation through time. The default implementation splits root-level Tensors and Sequences at dim=1 (the time dimension) and assumes that all of them have the same time-dimension length.

Examples:

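import collections.abc

import torch

def tbptt_split_batch(self, batch, split_size):
    # Time-dimension length (dim=1) of every root-level Tensor or Sequence.
    time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
    assert len(time_dims) >= 1, "Unable to determine batch time dimension"
    assert all(d == time_dims[0] for d in time_dims), "Batch time dimension length is ambiguous"

    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for x in batch:
            if isinstance(x, torch.Tensor):
                # Tensors are sliced directly along dim=1.
                split_x = x[:, t:t + split_size]
            elif isinstance(x, collections.abc.Sequence):
                # Sequences are sliced sample by sample.
                split_x = [None] * len(x)
                for batch_idx in range(len(x)):
                    split_x[batch_idx] = x[batch_idx][t:t + split_size]
            batch_split.append(split_x)
        splits.append(batch_split)
    return splits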
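To make the splitting arithmetic concrete, the following minimal sketch applies the same idea to nested Python lists instead of tensors, so that it runs without PyTorch. The helper name `split_along_time` and the toy data are illustrative assumptions and are not part of tabensemb or Lightning; every root-level item is assumed to be laid out as [sample][time step].

```python
from typing import Any, List


def split_along_time(batch: List[Any], split_size: int) -> List[List[Any]]:
    """Split every root-level item of ``batch`` along its time axis (dim=1)."""
    if split_size <= 0:
        raise ValueError("split_size must be positive")
    # Assumption: each item is a nested list indexed as [sample][time step].
    time_dims = [len(x[0]) for x in batch]
    if not time_dims or any(d != time_dims[0] for d in time_dims):
        raise ValueError("all items must share one time-dimension length")
    splits = []
    for t in range(0, time_dims[0], split_size):
        # Cut the same time window out of every sample of every item.
        splits.append([[sample[t:t + split_size] for sample in x] for x in batch])
    return splits


# Toy batch: 2 samples, 5 time steps, with inputs and targets as separate items.
inputs = [[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]
targets = [[0, 0, 1, 1, 0], [1, 1, 0, 0, 1]]
splits = split_along_time([inputs, targets], split_size=2)

for split_inputs, split_targets in splits:
    print(split_inputs, split_targets)
```

With 5 time steps and split_size=2, this yields three splits covering steps 0-1, 2-3, and 4; the last split holds only the remaining step.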

Note

Called in the training loop after on_train_batch_start() if truncated_bptt_steps > 0. Each returned batch split is passed separately to training_step().
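The call order described in this note can be pictured with a toy loop. This is only a sketch under stated assumptions: `run_batch_splits` and `toy_training_step` are hypothetical names, the step signature is simplified, and the code is not Lightning's actual training loop.

```python
from typing import Any, Callable, List


def run_batch_splits(
    splits: List[Any], training_step: Callable[[Any, int], float]
) -> List[float]:
    """Call ``training_step`` once per split, in time order (hypothetical stand-in)."""
    losses = []
    for split_idx, split in enumerate(splits):
        # Each split is handled as its own, shorter training step.
        losses.append(training_step(split, split_idx))
    return losses


def toy_training_step(split: Any, split_idx: int) -> float:
    # Hypothetical step: the "loss" is simply the mean of all values in the split.
    values = [v for sample in split for v in sample]
    return sum(values) / len(values)


# Two splits of a batch with 2 samples: the first has 2 time steps, the second 1.
splits = [[[1.0, 2.0], [3.0, 4.0]], [[5.0], [7.0]]]
losses = run_batch_splits(splits, toy_training_step)
print(losses)
```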