tabensemb.model.AbstractNN.tbptt_split_batch
method
- AbstractNN.tbptt_split_batch(batch: Any, split_size: int) -> List[Any]
When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default; override this method only if you need custom splitting behavior.
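- Parameters:
  - batch (Any): the current batch.
  - split_size (int): the size of each split along the time dimension.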
- Returns:
List of batch splits. Each split will be passed to training_step() to enable truncated backpropagation through time. The default implementation splits root-level Tensors and Sequences at dim=1 (i.e. the time dimension). It assumes that every time dimension has the same length.
Examples:
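import collections.abc

import torch


def tbptt_split_batch(self, batch, split_size):
    # Length of the time dimension (dim=1) of each root-level Tensor or Sequence.
    time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for i, x in enumerate(batch):
            if isinstance(x, torch.Tensor):
                # Slice the tensor along the time dimension.
                split_x = x[:, t:t + split_size]
            elif isinstance(x, collections.abc.Sequence):
                # Slice every sample of the sequence individually.
                split_x = [None] * len(x)
                for batch_idx in range(len(x)):
                    split_x[batch_idx] = x[batch_idx][t:t + split_size]
            batch_split.append(split_x)
        splits.append(batch_split)
    return splits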
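To make the shape of the returned splits concrete, here is a minimal pure-Python sketch that mirrors only the Sequence branch of the example above. The helper name split_sequences_along_time and the sample data are hypothetical and are not part of tabensemb or Lightning. It assumes every root-level element is indexed as element[sample][time] and that all time dimensions have the same length.

```python
def split_sequences_along_time(batch, split_size):
    """Split every root-level sequence in `batch` along its time dimension.

    Hypothetical illustration helper: each element of `batch` is a list of
    per-sample sequences, indexed as element[sample][time].
    """
    # Assumption: all elements share the time length of the first sample.
    time_len = len(batch[0][0])
    splits = []
    for t in range(0, time_len, split_size):
        # Slice each sample of each element to the window [t, t + split_size).
        batch_split = [
            [sample[t:t + split_size] for sample in element]
            for element in batch
        ]
        splits.append(batch_split)
    return splits


# Two samples with five time steps each (made-up data).
features = [[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]
targets = [[0, 0, 1, 1, 0], [1, 1, 0, 0, 1]]

splits = split_sequences_along_time([features, targets], split_size=2)
for split_features, split_targets in splits:
    print(split_features, split_targets)
```

With five time steps and split_size=2, the last split is shorter than the others because the time length is not a multiple of the split size.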
Note
Called in the training loop after on_train_batch_start() if truncated_bptt_steps > 0. Each returned batch split is passed separately to training_step().
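The call order described in the note can be sketched with a toy stand-in. This is not Lightning's training loop: ToyModule, run_one_batch, and the simplified hook signatures are assumptions made purely for illustration. The real hooks take additional arguments.

```python
class ToyModule:
    """Hypothetical stand-in for a model that records the order of calls."""

    def __init__(self):
        self.calls = []

    def on_train_batch_start(self, batch):
        self.calls.append("on_train_batch_start")

    def tbptt_split_batch(self, batch, split_size):
        self.calls.append("tbptt_split_batch")
        # In this toy example the batch is a flat list of time steps.
        return [batch[t:t + split_size] for t in range(0, len(batch), split_size)]

    def training_step(self, split):
        self.calls.append(f"training_step({split})")


def run_one_batch(module, batch, truncated_bptt_steps):
    """Simplified, assumed version of the per-batch sequence in the note."""
    module.on_train_batch_start(batch)
    if truncated_bptt_steps > 0:
        splits = module.tbptt_split_batch(batch, truncated_bptt_steps)
    else:
        # Without truncated BPTT the whole batch is processed in one step.
        splits = [batch]
    for split in splits:
        module.training_step(split)


module = ToyModule()
run_one_batch(module, batch=[1, 2, 3, 4, 5], truncated_bptt_steps=2)
print(module.calls)
```

The recorded calls show the splitting hook running once per batch, followed by one training_step() call per returned split.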