tabensemb.model.AbstractNN.all_gather

method

AbstractNN.all_gather(data: Tensor | Dict | List | Tuple, group: Any | None = None, sync_grads: bool = False) → Tensor | Dict | List | Tuple

Allows users to call self.all_gather() from the LightningModule, which makes the all_gather operation accelerator-agnostic. The underlying all_gather is a function provided by the accelerator that gathers a tensor from several distributed processes.

Parameters:
  • data – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof.

  • group – the process group to gather results from. Defaults to all processes (world).

  • sync_grads – flag that allows users to synchronize gradients through the all_gather operation.

Returns:

A tensor of shape (world_size, batch, …), or, if the input was a collection, a collection of the same structure whose tensors have this shape.
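The shape contract described under Returns can be illustrated without launching several processes. The sketch below is a pure-Python stand-in, not the tabensemb or Lightning implementation: simulate_all_gather and shape_of are hypothetical helpers, an array.array plays the role of a tensor of shape (batch,), a plain int or float plays the role of a 0-d tensor, and the list per_rank holds the data that each of the world_size processes would pass in.

```python
from array import array
from typing import Any, List, Tuple


def simulate_all_gather(per_rank: List[Any]) -> Any:
    """Hypothetical helper that mimics only the *output layout* of all_gather.

    per_rank[r] is the `data` that rank r would pass in. Each leaf is gathered
    into a list of length world_size (standing in for the new leading
    dimension), while dict/list/tuple containers keep their structure.
    """
    first = per_rank[0]
    if isinstance(first, dict):
        # Gather each value key by key, preserving the dict structure.
        return {key: simulate_all_gather([d[key] for d in per_rank]) for key in first}
    if isinstance(first, (list, tuple)):
        # Gather position by position, preserving list vs. tuple.
        gathered = [
            simulate_all_gather([c[i] for c in per_rank]) for i in range(len(first))
        ]
        return type(first)(gathered)
    # Leaf: one entry per rank, i.e. a new leading world_size axis.
    return list(per_rank)


def shape_of(gathered_leaf: List[Any]) -> Tuple[int, ...]:
    """Shape of a gathered leaf: (world_size,) or (world_size, batch)."""
    if isinstance(gathered_leaf[0], array):
        return (len(gathered_leaf), len(gathered_leaf[0]))
    return (len(gathered_leaf),)


if __name__ == "__main__":
    # Two simulated ranks, each holding a batch of 3 predictions and a scalar loss.
    rank0 = {"preds": array("d", [0.1, 0.2, 0.3]), "loss": 0.5}
    rank1 = {"preds": array("d", [0.4, 0.5, 0.6]), "loss": 0.7}
    out = simulate_all_gather([rank0, rank1])
    print(shape_of(out["preds"]))  # (2, 3), i.e. (world_size, batch)
    print(shape_of(out["loss"]))   # (2,), i.e. (world_size,)
```

In a real distributed run there is no per_rank list: each process calls self.all_gather(data) with its own local data, and every process in the group receives the same gathered result.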