if an optimizer is passed into the learn from dreams function, take the optimizer steps, otherwise let the researcher handle it externally. also ready muon
parent cb416c0d44
commit 0c1b067f97
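In other words: when policy_optim / value_optim are supplied, learn_policy_from_generations calls backward() and steps them itself; when they are omitted it only returns the losses and the update is left to the caller. A minimal sketch of both modes — world_model, generation, Adam and the learning rates are placeholders, and policy_head is an assumed attribute (only value_head appears in this diff):

from torch.optim import Adam

# mode 1: optimizers passed in -- backward(), step() and zero_grad() happen inside the call
# (policy_head is an assumed attribute, not shown in this diff)

policy_optim = Adam(world_model.policy_head.parameters(), lr = 1e-4)
value_optim = Adam(world_model.value_head.parameters(), lr = 1e-4)

policy_loss, value_loss = world_model.learn_policy_from_generations(
    generation,
    policy_optim = policy_optim,
    value_optim = value_optim
)

# mode 2: no optimizers -- only the losses come back, the researcher steps externally

optim = Adam(world_model.parameters(), lr = 1e-4)

policy_loss, value_loss = world_model.learn_policy_from_generations(generation)

(policy_loss + value_loss).backward()
optim.step()
optim.zero_grad()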
@@ -17,6 +17,9 @@ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones
 import torchvision
 from torchvision.models import VGG16_Weights
 
+from torch.optim import Optimizer
+
+from adam_atan2_pytorch import MuonAdamAtan2
 
 from x_mlps_pytorch.normed_mlp import create_mlp
 from x_mlps_pytorch.ensemble import Ensemble
@@ -1699,7 +1702,9 @@ class DynamicsWorldModel(Module):
     def learn_policy_from_generations(
         self,
-        generation: Experience
+        generation: Experience,
+        policy_optim: Optimizer | None = None,
+        value_optim: Optimizer | None = None
     ):
 
         latents = generation.latents
         actions = generation.actions
@@ -1771,6 +1776,14 @@ class DynamicsWorldModel(Module):
             entropy_loss * self.policy_entropy_weight
         )
 
+        # maybe take policy optimizer step
+
+        if exists(policy_optim):
+            total_policy_loss.backward()
+
+            policy_optim.step()
+            policy_optim.zero_grad()
+
         # value loss
 
         value_bins = self.value_head(agent_embed)
@@ -1786,6 +1799,14 @@ class DynamicsWorldModel(Module):
 
         value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
 
+        # maybe take value optimizer step
+
+        if exists(value_optim):
+            value_loss.backward()
+
+            value_optim.step()
+            value_optim.zero_grad()
+
         return total_policy_loss, value_loss
 
     @torch.no_grad()
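As for "ready muon": with MuonAdamAtan2 now imported, a Muon-flavored optimizer can be dropped in wherever plain Adam was used in the sketch above. This assumes MuonAdamAtan2 follows the standard torch.optim constructor of (params, lr = ...); its actual signature is not shown in this diff:

from adam_atan2_pytorch import MuonAdamAtan2

# constructor arguments and the policy_head attribute are assumptions, not taken from this diff
policy_optim = MuonAdamAtan2(world_model.policy_head.parameters(), lr = 1e-4)
value_optim = MuonAdamAtan2(world_model.value_head.parameters(), lr = 1e-4)

world_model.learn_policy_from_generations(
    generation,
    policy_optim = policy_optim,
    value_optim = value_optim
)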