import torch from torch import nn def optim_step( loss: torch.Tensor, optim: torch.optim.Optimizer, module: nn.Module, max_grad_norm: float | None = None, ) -> None: """Perform a single optimization step. :param loss: :param optim: :param module: :param max_grad_norm: if passed, will clip gradients using this """ optim.zero_grad() loss.backward() if max_grad_norm: nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) optim.step()