tabensemb.model.AbstractNN.on_before_batch_transfer

method

AbstractNN.on_before_batch_transfer(batch: Any, dataloader_idx: int) → Any

Override this hook to alter or augment a batch before it is transferred to the device.

Note

To check the current stage of execution inside this hook, use self.trainer.training, self.trainer.testing, self.trainer.validating, or self.trainer.predicting; this lets you apply different logic for each stage.
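As a rough sketch of that pattern, the hook below augments batches only during training and passes validation, test, and prediction batches through unchanged. To keep it runnable without Lightning, the trainer is faked with types.SimpleNamespace; the class name StagedAugmentation, the list-valued "x" key, and the Gaussian-noise augmentation are hypothetical choices for illustration. In a real AbstractNN subclass, self.trainer is supplied by Lightning and the batch layout depends on your dataloader.

```python
import random
from types import SimpleNamespace


class StagedAugmentation:
    """Hypothetical stand-in for an AbstractNN subclass; only the hook is shown."""

    def __init__(self, trainer, noise_std=0.1, seed=0):
        # In a real LightningModule, Lightning attaches self.trainer itself;
        # here it is passed in so the sketch runs without Lightning.
        self.trainer = trainer
        self.noise_std = noise_std
        self.rng = random.Random(seed)

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # Augment only while training; other stages get the batch unchanged.
        if self.trainer.training:
            noisy = [v + self.rng.gauss(0.0, self.noise_std) for v in batch["x"]]
            batch = {**batch, "x": noisy}  # copy so the caller's dict is untouched
        return batch


# Fake trainer states standing in for Lightning's Trainer stage flags.
fitting = SimpleNamespace(training=True, testing=False, validating=False, predicting=False)
testing = SimpleNamespace(training=False, testing=True, validating=False, predicting=False)

batch = {"x": [1.0, 2.0, 3.0], "y": [0, 1, 0]}
train_out = StagedAugmentation(fitting).on_before_batch_transfer(batch, dataloader_idx=0)
test_out = StagedAugmentation(testing).on_before_batch_transfer(batch, dataloader_idx=0)
```

Copying the dict instead of mutating it in place is a design choice for the sketch; mutating and returning the same batch, as in the example at the end of this page, works too.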

Note

This hook runs only with single-GPU training and DDP, not with DataParallel (DP). DataParallel is not yet supported.

Parameters:
  • batch – A batch of data that needs to be altered or augmented.

  • dataloader_idx – The index of the dataloader to which the batch belongs.
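When several dataloaders feed the same stage (for example, multiple validation sets), dataloader_idx identifies which one produced the batch. The sketch below is written as a plain function so it runs without Lightning; it applies a hypothetical clipping step only to batches from dataloader 0. The clip bounds and the dict layout are assumptions for illustration.

```python
def on_before_batch_transfer(batch, dataloader_idx, clip=(-1.0, 1.0)):
    """Hypothetical hook body: clip features only for the first dataloader."""
    if dataloader_idx == 0:
        lo, hi = clip
        batch = {**batch, "x": [min(max(v, lo), hi) for v in batch["x"]]}
    # Batches from any other dataloader are returned unchanged.
    return batch


batch = {"x": [-2.5, 0.3, 4.0]}
first = on_before_batch_transfer(batch, dataloader_idx=0)   # clipped copy
second = on_before_batch_transfer(batch, dataloader_idx=1)  # same object, untouched
```

Here first["x"] is [-1.0, 0.3, 1.0], while second is the original batch.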

Returns:

The (possibly altered) batch of data.

Example:

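# self.transforms is a hypothetical callable (e.g. built in __init__).
def on_before_batch_transfer(self, batch, dataloader_idx):
    # Runs before the batch is moved to the device.
    batch['x'] = self.transforms(batch['x'])
    return batch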