if an optimizer is passed into the learn-from-dreams function, take the optimizer steps; otherwise let the researcher handle it externally. also ready muon

lucidrains 2025-10-17 08:55:20 -07:00
parent cb416c0d44
commit 0c1b067f97


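In effect, a researcher can either hand the optimizers to `learn_policy_from_generations` and let the method step them, or omit them and run the update externally. Below is a minimal usage sketch of both modes. It assumes a constructed `DynamicsWorldModel` instance named `world_model`, an `Experience` named `generation` produced by its rollout, and plain `torch.optim.Adam` optimizers over illustrative parameter groups; the newly imported `MuonAdamAtan2` could be substituted, assuming it follows the standard torch optimizer constructor convention.

    import torch
    from torch.optim import Adam

    # illustrative optimizers; parameter groups and learning rates are placeholders
    policy_optim = Adam(world_model.parameters(), lr = 1e-4)
    value_optim = Adam(world_model.value_head.parameters(), lr = 1e-4)

    # option 1: pass the optimizers in, and the method takes the backward / step / zero_grad itself
    policy_loss, value_loss = world_model.learn_policy_from_generations(
        generation,
        policy_optim = policy_optim,
        value_optim = value_optim
    )

    # option 2: omit the optimizers, get the losses back, and handle the update externally
    policy_loss, value_loss = world_model.learn_policy_from_generations(generation)

    (policy_loss + value_loss).backward()

    policy_optim.step()
    policy_optim.zero_grad()

    value_optim.step()
    value_optim.zero_grad()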
@@ -17,6 +17,9 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
+from torch.optim import Optimizer
+from adam_atan2_pytorch import MuonAdamAtan2
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
@@ -1699,7 +1702,9 @@ class DynamicsWorldModel(Module):
     def learn_policy_from_generations(
         self,
-        generation: Experience
+        generation: Experience,
+        policy_optim: Optimizer | None = None,
+        value_optim: Optimizer | None = None
     ):
         latents = generation.latents
         actions = generation.actions
@@ -1771,6 +1776,14 @@
             entropy_loss * self.policy_entropy_weight
         )

+        # maybe take policy optimizer step
+
+        if exists(policy_optim):
+            total_policy_loss.backward()
+            policy_optim.step()
+            policy_optim.zero_grad()
+
         # value loss

         value_bins = self.value_head(agent_embed)
@@ -1786,6 +1799,14 @@
         value_loss = torch.maximum(value_loss_1, value_loss_2).mean()

+        # maybe take value optimizer step
+
+        if exists(value_optim):
+            value_loss.backward()
+            value_optim.step()
+            value_optim.zero_grad()
+
         return total_policy_loss, value_loss

     @torch.no_grad()