From 0c1b067f9740095c29dca4e70bc73a42904a8322 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 17 Oct 2025 08:55:20 -0700
Subject: [PATCH] if optimizer is passed into the learn from dreams function,
 take the optimizer steps, otherwise let the researcher handle it externally.
 also ready muon

---
 dreamer4/dreamer4.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index f2eaec2..ce5bd64 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -17,6 +17,9 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
 
+from torch.optim import Optimizer
+from adam_atan2_pytorch import MuonAdamAtan2
+
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
@@ -1699,7 +1702,9 @@ class DynamicsWorldModel(Module):
 
     def learn_policy_from_generations(
         self,
-        generation: Experience
+        generation: Experience,
+        policy_optim: Optimizer | None = None,
+        value_optim: Optimizer | None = None
     ):
         latents = generation.latents
         actions = generation.actions
@@ -1771,6 +1776,14 @@ class DynamicsWorldModel(Module):
             entropy_loss * self.policy_entropy_weight
         )
 
+        # maybe take policy optimizer step
+
+        if exists(policy_optim):
+            total_policy_loss.backward()
+
+            policy_optim.step()
+            policy_optim.zero_grad()
+
         # value loss
 
         value_bins = self.value_head(agent_embed)
@@ -1786,6 +1799,14 @@ class DynamicsWorldModel(Module):
 
         value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
 
+        # maybe take value optimizer step
+
+        if exists(value_optim):
+            value_loss.backward()
+
+            value_optim.step()
+            value_optim.zero_grad()
+
         return total_policy_loss, value_loss
 
     @torch.no_grad()
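
For reference, a minimal usage sketch (not part of the patch) of the two call patterns the new optional optimizer arguments enable. The names world_model and generation are assumed to be an already constructed DynamicsWorldModel and an Experience sampled from it, and the optimizer choice and parameter grouping are illustrative only; a plain torch.optim.Adam is used here rather than the MuonAdamAtan2 imported above, to keep the sketch self-contained.

    import torch
    from torch.optim import Adam

    # world_model: a DynamicsWorldModel instance (assumed)
    # generation: an Experience rolled out in the learned world model (assumed)

    # option 1 - pass the optimizers in, and the function takes the steps itself
    # (in practice the policy and value optimizers would wrap disjoint parameter groups)

    policy_optim = Adam(world_model.parameters(), lr = 1e-4)
    value_optim = Adam(world_model.parameters(), lr = 1e-4)

    policy_loss, value_loss = world_model.learn_policy_from_generations(
        generation,
        policy_optim = policy_optim,
        value_optim = value_optim
    )

    # option 2 - omit the optimizers; the losses are returned without backward,
    # and the researcher handles optimization externally

    policy_loss, value_loss = world_model.learn_policy_from_generations(generation)

    (policy_loss + value_loss).backward()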