This commit is contained in:
lucidrains 2025-10-10 10:42:54 -07:00
parent 32aa355e37
commit c68942b026

View File

@ -1359,13 +1359,24 @@ class DynamicsWorldModel(Module):
self.agent_has_genes = num_latent_genes > 0
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
# policy head
self.policy_head = create_mlp(
dim_in = dim,
dim = dim * 4,
dim_out = dim * 4,
depth = policy_head_mlp_depth
)
# action embedder
self.action_embedder = ActionEmbedder(
dim = dim,
num_discrete_actions = num_discrete_actions,
num_continuous_actions = num_continuous_actions,
continuous_norm_stats = continuous_norm_stats
continuous_norm_stats = continuous_norm_stats,
can_unembed = True,
unembed_dim = dim * 4
)
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
@ -1386,15 +1397,6 @@ class DynamicsWorldModel(Module):
self.reward_loss_weight = reward_loss_weight
# policy head
self.policy_head = create_mlp(
dim_in = dim,
dim = dim * 4,
dim_out = dim,
depth = policy_head_mlp_depth
)
# value head
self.value_head = create_mlp(