tabensemb.model.AbstractNN.optimizer_step

method

AbstractNN.optimizer_step(epoch: int, batch_idx: int, optimizer: Optimizer | LightningOptimizer, optimizer_idx: int = 0, optimizer_closure: Callable[[], Any] | None = None, on_tpu: bool = False, using_lbfgs: bool = False) -> None

Override this method to adjust the default way the Trainer calls each optimizer.

By default, Lightning calls step() and zero_grad() once per optimizer, as shown in the DEFAULT example below (zero_grad() runs inside the optimizer closure). When the Trainer is configured with accumulate_grad_batches != 1, this method (and zero_grad()) is not called on the intermediate batches of the gradient-accumulation phase. Overriding this hook has no benefit when using manual optimization.
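
The accumulation rule can be made concrete with a small pure-Python sketch. It lists the batch indices on which an optimizer step, and therefore this hook, would fire. `steps_with_accumulation` is a hypothetical helper and is not part of tabensemb or Lightning. The sketch makes a simplifying assumption: a step happens after every `accumulate_grad_batches`-th batch, plus one more on the last batch if the epoch ends partway through an accumulation window.

```python
from typing import List


def steps_with_accumulation(num_batches: int, accumulate_grad_batches: int) -> List[int]:
    """Return the batch indices after which an optimizer step would be taken.

    Hypothetical helper (not part of tabensemb or Lightning). It assumes a
    step after every `accumulate_grad_batches`-th batch, plus one final step
    on the last batch if the epoch ends partway through a window.
    """
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    steps = []
    for batch_idx in range(num_batches):
        window_complete = (batch_idx + 1) % accumulate_grad_batches == 0
        is_last_batch = batch_idx == num_batches - 1
        if window_complete or is_last_batch:
            steps.append(batch_idx)
    return steps


# 10 batches with accumulate_grad_batches=4: steps after batches 3 and 7,
# plus a final step after batch 9 for the 2 leftover batches.
print(steps_with_accumulation(10, 4))  # [3, 7, 9]
# With accumulate_grad_batches=1, every batch steps.
print(steps_with_accumulation(3, 1))   # [0, 1, 2]
```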

Parameters:
  • epoch – Current epoch

  • batch_idx – Index of current batch

  • optimizer – A PyTorch optimizer (possibly wrapped in a LightningOptimizer)

  • optimizer_idx – If you use multiple optimizers, this is the index of the current optimizer in that list.

  • optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().

  • on_tpu – True if TPU backward is required

  • using_lbfgs – True if the matching optimizer is torch.optim.LBFGS
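
The closure contract can be illustrated without PyTorch. In the sketch below, `ToyParam` and `ToySGD` are hypothetical stand-ins for a tensor and a torch.optim optimizer. Only the calling convention is meant to match: `step(closure)` first evaluates the closure, which zeroes the gradient, computes the loss and back-propagates, and only then applies the update. This is a simplified model and not Lightning's actual closure implementation.

```python
from typing import Callable


class ToyParam:
    """A single scalar parameter with a gradient slot (stand-in for a tensor)."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class ToySGD:
    """Hypothetical optimizer mimicking the step(closure) calling convention."""

    def __init__(self, param: ToyParam, lr: float) -> None:
        self.param = param
        self.lr = lr

    def zero_grad(self) -> None:
        self.param.grad = 0.0

    def step(self, closure: Callable[[], float]) -> float:
        # The closure recomputes the loss and fills in the gradient;
        # the optimizer only applies the update afterwards.
        loss = closure()
        self.param.value -= self.lr * self.param.grad
        return loss


param = ToyParam(2.0)
opt = ToySGD(param, lr=0.25)
calls = []


def optimizer_closure() -> float:
    # Mirrors what the closure bundles together:
    # zero_grad() -> training_step() (forward + loss) -> backward().
    opt.zero_grad()
    loss = param.value ** 2          # "training_step": loss = w^2
    param.grad += 2 * param.value    # "backward": d(w^2)/dw = 2w
    calls.append(loss)
    return loss


loss = opt.step(closure=optimizer_closure)
print(loss, param.value)  # 4.0 1.0
```

If the hook never ran the closure, training_step() and backward() would never run either, which is why the closure must always be executed. Optimizers such as torch.optim.LBFGS may evaluate the closure several times within a single step(), so the closure recomputes the gradient from zero on every call.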

Examples:

# DEFAULT
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    optimizer.step(closure=optimizer_closure)

# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure, on_tpu, using_lbfgs):
    # update generator opt every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator opt every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            optimizer.step(closure=optimizer_closure)
        else:
            # call the closure by itself to run `training_step` + `backward` without an optimizer step
            optimizer_closure()

    # ...
    # add as many optimizers as you want
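
The branching in the GAN example can be traced without Lightning by replaying it over a few batches. The pure-Python sketch below records, for each batch and optimizer index, whether `optimizer.step()` is taken or only the closure is run. `alternating_schedule` is a hypothetical helper. The sketch assumes two optimizers and assumes the hook is invoked once per optimizer for each batch, in index order.

```python
from typing import List, Tuple


def alternating_schedule(num_batches: int) -> List[Tuple[int, int, bool]]:
    """Return (batch_idx, optimizer_idx, stepped) for every hook invocation.

    Hypothetical helper reproducing the example's branching: optimizer 0
    (generator) steps on every batch, optimizer 1 (discriminator) only on
    every second batch. When it does not step, the closure still runs.
    """
    schedule = []
    for batch_idx in range(num_batches):
        for optimizer_idx in (0, 1):
            if optimizer_idx == 0:
                stepped = True
            else:
                stepped = (batch_idx + 1) % 2 == 0
            schedule.append((batch_idx, optimizer_idx, stepped))
    return schedule


for batch_idx, optimizer_idx, stepped in alternating_schedule(4):
    action = "step(closure)" if stepped else "closure only"
    print(f"batch {batch_idx}, optimizer {optimizer_idx}: {action}")
# Optimizer 1 takes a real step only on batches 1 and 3.
```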

The next example shows how to use this hook for more advanced tasks, such as learning-rate warm-up:

# learning rate warm-up
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu,
    using_lbfgs,
):
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up lr without a scheduler
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            # self.learning_rate is assumed to hold the target (post-warm-up) learning rate
            pg["lr"] = lr_scale * self.learning_rate
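
The scaling rule in the warm-up example can be evaluated on its own to show the values it produces. `warmup_lr` is a hypothetical helper. Its `base_lr` argument plays the role of `self.learning_rate`, and the 500-step warm-up length is taken from the example. The example writes the new value into param_groups after optimizer.step(). As a result, the rate computed at a given global_step is used by the following update, and the very first update runs at whatever rate the optimizer was created with.

```python
def warmup_lr(global_step: int, base_lr: float, warmup_steps: int = 500) -> float:
    """Learning rate written into param_groups at `global_step` by the example.

    Hypothetical helper. During warm-up it applies
    lr_scale = min(1.0, (global_step + 1) / warmup_steps). Afterwards the
    example no longer touches param_groups, so (assuming nothing else changes
    the rate) the last value written, base_lr, stays in effect.
    """
    if global_step < warmup_steps:
        lr_scale = min(1.0, float(global_step + 1) / warmup_steps)
        return lr_scale * base_lr
    return base_lr


# With base_lr=1.0 the result is simply lr_scale.
print(warmup_lr(0, 1.0))    # 0.002
print(warmup_lr(249, 1.0))  # 0.5
print(warmup_lr(499, 1.0))  # 1.0
```

The rate ramps linearly from 1/500 of the target value to the full target at global_step 499 and is left unchanged afterwards.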