if an optimizer is passed into the learn-from-dreams function, take the optimizer steps; otherwise let the researcher handle it externally. also ready muon
commit 0c1b067f97
parent cb416c0d44
@@ -17,6 +17,9 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights

+from torch.optim import Optimizer
+from adam_atan2_pytorch import MuonAdamAtan2
+
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble

@@ -1699,7 +1702,9 @@ class DynamicsWorldModel(Module):

     def learn_policy_from_generations(
         self,
-        generation: Experience
+        generation: Experience,
+        policy_optim: Optimizer | None = None,
+        value_optim: Optimizer | None = None
     ):
         latents = generation.latents
         actions = generation.actions
@@ -1771,6 +1776,14 @@ class DynamicsWorldModel(Module):
             entropy_loss * self.policy_entropy_weight
         )

+        # maybe take policy optimizer step
+
+        if exists(policy_optim):
+            total_policy_loss.backward()
+
+            policy_optim.step()
+            policy_optim.zero_grad()
+
         # value loss

         value_bins = self.value_head(agent_embed)
@@ -1786,6 +1799,14 @@ class DynamicsWorldModel(Module):

         value_loss = torch.maximum(value_loss_1, value_loss_2).mean()

+        # maybe take value optimizer step
+
+        if exists(value_optim):
+            value_loss.backward()
+
+            value_optim.step()
+            value_optim.zero_grad()
+
         return total_policy_loss, value_loss

     @torch.no_grad()
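The change can be driven either way from the outside. Below is a minimal usage sketch, not part of the commit: how `world_model` and `generation` are constructed, and the `policy_head` parameter grouping, are assumptions for illustration; only `value_head`, `Experience`, `learn_policy_from_generations`, and the new optimizer keyword arguments appear in the diff.

```python
from torch.optim import Adam

# assumes `world_model` is an already-constructed DynamicsWorldModel and
# `generation` is an Experience produced by its dream / rollout machinery
# (neither construction is shown in this commit)

# option 1 - pass the optimizers in; backward(), step() and zero_grad()
# are then taken inside learn_policy_from_generations

policy_optim = Adam(world_model.policy_head.parameters(), lr = 3e-4)  # `policy_head` is an assumed attribute
value_optim = Adam(world_model.value_head.parameters(), lr = 3e-4)    # `value_head` is referenced in the diff

policy_loss, value_loss = world_model.learn_policy_from_generations(
    generation,
    policy_optim = policy_optim,
    value_optim = value_optim
)

# option 2 - pass nothing and let the researcher handle the update externally

policy_loss, value_loss = world_model.learn_policy_from_generations(generation)

(policy_loss + value_loss).backward()
# researcher-managed optimizer step / schedule goes here
```

The `MuonAdamAtan2` import readies a Muon-style optimizer from `adam_atan2_pytorch` as a possible drop-in for `Adam` above; its constructor arguments are not shown in this diff, so the sketch sticks to `torch.optim.Adam`.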