diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 13eb4b2..5c5cc2e 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1359,13 +1359,24 @@ class DynamicsWorldModel(Module):
         self.agent_has_genes = num_latent_genes > 0
         self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)

+        # policy head
+
+        self.policy_head = create_mlp(
+            dim_in = dim,
+            dim = dim * 4,
+            dim_out = dim * 4,
+            depth = policy_head_mlp_depth
+        )
+
         # action embedder

         self.action_embedder = ActionEmbedder(
             dim = dim,
             num_discrete_actions = num_discrete_actions,
             num_continuous_actions = num_continuous_actions,
-            continuous_norm_stats = continuous_norm_stats
+            continuous_norm_stats = continuous_norm_stats,
+            can_unembed = True,
+            unembed_dim = dim * 4
         )

         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
@@ -1386,15 +1397,6 @@ class DynamicsWorldModel(Module):

         self.reward_loss_weight = reward_loss_weight

-        # policy head
-
-        self.policy_head = create_mlp(
-            dim_in = dim,
-            dim = dim * 4,
-            dim_out = dim,
-            depth = policy_head_mlp_depth
-        )
-
         # value head

         self.value_head = create_mlp(
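
The diff moves the policy head construction above the action embedder and widens its output from dim to dim * 4, while the ActionEmbedder is now built with can_unembed = True and unembed_dim = dim * 4 - suggesting the embedder itself projects the policy head's widened output back into action logits. Below is a minimal sketch of that wiring under stated assumptions: ToyActionEmbedder is a hypothetical stand-in (not the repo's ActionEmbedder), the unembed method name is assumed, and the MLP stands in for create_mlp, whose internals are not shown here.

    # sketch only - toy stand-ins for the repo's ActionEmbedder / create_mlp
    import torch
    from torch import nn

    class ToyActionEmbedder(nn.Module):
        # stand-in for ActionEmbedder(..., can_unembed = True, unembed_dim = dim * 4)
        def __init__(self, dim, num_discrete_actions, unembed_dim):
            super().__init__()
            # usual role: embed actions into the model dimension
            self.embed = nn.Embedding(num_discrete_actions, dim)
            # new role implied by the diff: unembed policy output to action logits
            self.to_logits = nn.Linear(unembed_dim, num_discrete_actions, bias = False)

        def unembed(self, policy_embed):
            # policy head output (..., unembed_dim) -> per-action logits
            return self.to_logits(policy_embed)

    dim = 64

    # mirrors create_mlp(dim_in = dim, dim = dim * 4, dim_out = dim * 4, ...);
    # the SiLU nonlinearity is an assumption
    policy_head = nn.Sequential(
        nn.Linear(dim, dim * 4),
        nn.SiLU(),
        nn.Linear(dim * 4, dim * 4)
    )

    embedder = ToyActionEmbedder(dim, num_discrete_actions = 8, unembed_dim = dim * 4)

    agent_tokens = torch.randn(2, dim)   # (batch, dim) agent embeddings
    logits = embedder.unembed(policy_head(agent_tokens))
    assert logits.shape == (2, 8)

One plausible reading of the design choice: routing the unembedding through the ActionEmbedder keeps both directions of the action mapping (actions -> embeddings, policy output -> action logits) in one module, so the policy head only has to produce a generic dim * 4 feature rather than knowing the action space's layout.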