diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index f2eaec2..ce5bd64 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -17,6 +17,9 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
 
+from torch.optim import Optimizer
+from adam_atan2_pytorch import MuonAdamAtan2
+
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
 
@@ -1699,7 +1702,9 @@ class DynamicsWorldModel(Module):
 
     def learn_policy_from_generations(
         self,
-        generation: Experience
+        generation: Experience,
+        policy_optim: Optimizer | None = None,
+        value_optim: Optimizer | None = None
     ):
         latents = generation.latents
         actions = generation.actions
@@ -1771,6 +1776,14 @@ class DynamicsWorldModel(Module):
             entropy_loss * self.policy_entropy_weight
         )
 
+        # maybe take policy optimizer step - retain the graph if a value step follows, as both losses branch off the shared agent_embed
+
+        if exists(policy_optim):
+            total_policy_loss.backward(retain_graph = exists(value_optim))
+
+            policy_optim.step()
+            policy_optim.zero_grad()
+
         # value loss
 
         value_bins = self.value_head(agent_embed)
@@ -1786,6 +1799,14 @@ class DynamicsWorldModel(Module):
 
         value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
 
+        # maybe take value optimizer step
+
+        if exists(value_optim):
+            value_loss.backward()
+
+            value_optim.step()
+            value_optim.zero_grad()
+
         return total_policy_loss, value_loss
 
     @torch.no_grad()
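
For context, a rough usage sketch of the new optional optimizer arguments. Everything outside `learn_policy_from_generations` below is an assumption: the constructor arguments, the `policy_head` attribute, and the rollout call are placeholders, and plain `Adam` stands in where the newly imported `MuonAdamAtan2` is presumably meant to be used.

```python
import torch
from torch.optim import Adam

from dreamer4.dreamer4 import DynamicsWorldModel

# hypothetical construction - the actual config lives elsewhere in the repo

world_model = DynamicsWorldModel(...)

# one optimizer per head; MuonAdamAtan2 could be swapped in here
# (`policy_head` is an assumed attribute name, `value_head` appears in the diff)

policy_optim = Adam(world_model.policy_head.parameters(), lr = 3e-4)
value_optim = Adam(world_model.value_head.parameters(), lr = 3e-4)

# `generation` is an Experience produced by the model's rollout / generation path

generation = ...

# with optimizers passed, the call backprops and steps both heads internally;
# without them, behavior is unchanged and only the two losses are returned

policy_loss, value_loss = world_model.learn_policy_from_generations(
    generation,
    policy_optim = policy_optim,
    value_optim = value_optim
)
```

Stepping inside the method also keeps the `retain_graph` bookkeeping local: the policy backward only needs to retain the graph when a value step will follow, since both heads branch off the shared `agent_embed`.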