tabensemb.model.AbstractNN.transfer_batch_to_device
method
- AbstractNN.transfer_batch_to_device(batch: Any, device: device, dataloader_idx: int) → Any
Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
- torch.Tensor or anything that implements .to(…)
- list
- dict
- tuple
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
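The "arbitrary nesting" rule above can be pictured as a recursive walk over the batch. The following is only a minimal sketch of that idea, not the library's actual implementation: FakeTensor and move_to_device are hypothetical names introduced for illustration, and FakeTensor stands in for torch.Tensor so the snippet runs without PyTorch installed.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only records its device."""

    name: str
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        # Return a new object on the target device instead of mutating in place.
        return FakeTensor(self.name, device)


def move_to_device(data: Any, device: str) -> Any:
    """Recursively move every object exposing .to(...) inside nested containers."""
    if callable(getattr(data, "to", None)):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild a plain list or tuple of the same type (namedtuples not handled).
        return type(data)(move_to_device(item, device) for item in data)
    # Anything else (ints, strings, unknown classes) is returned unchanged.
    return data


batch = {"x": [FakeTensor("a"), FakeTensor("b")], "y": (FakeTensor("c"), 3)}
moved = move_to_device(batch, "cuda:0")
```

In this sketch, an object of an unknown class without a .to method is passed through untouched, which is the situation where overriding the hook becomes necessary.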
Note
This hook should only transfer the data and not modify it, nor should it move the data to any device other than the one passed in as an argument (unless you know what you are doing). To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting so that you can add different logic as required.
Note
This hook only runs on single-GPU training and DDP (no data-parallel). Data-parallel support will come in the near future.
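- Parameters:
batch – A batch of data that needs to be transferred to a new device.
device – The target device as defined in PyTorch.
dataloader_idx – The index of the dataloader to which the batch belongs.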
- Returns:
A reference to the data on the new device.
Example:
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
- Raises:
MisconfigurationException – If using data-parallel, Trainer(strategy='dp').
MisconfigurationException – If using IPUs, Trainer(accelerator='ipu').
See also
- move_data_to_device()
- apply_to_collection()