diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 3881307..5504951 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -2058,6 +2058,8 @@ class DynamicsWorldModel(Module):
         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
+        add_state_pred_head = False,
+        state_pred_loss_weight = 0.1,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
@@ -2181,6 +2183,22 @@ class DynamicsWorldModel(Module):
         self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn
 
+        # state prediction, for state entropy bonus
+
+        self.add_state_pred_head = add_state_pred_head
+        self.state_pred_loss_weight = state_pred_loss_weight
+
+        self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.
+
+        if self.should_pred_state:
+            self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)
+
+            self.to_state_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, num_latent_tokens * dim_latent * 2),
+                Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
+            )
+
         # reinforcement related
 
         # they sum all the actions into a single token
@@ -3387,9 +3405,13 @@ class DynamicsWorldModel(Module):
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
+        # empty token
+
+        empty_token = agent_tokens[:, :, 0:0]
+
         # maybe reward tokens
 
-        reward_tokens = agent_tokens[:, :, 0:0]
+        reward_tokens = empty_token
 
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
@@ -3429,6 +3451,13 @@ class DynamicsWorldModel(Module):
         else:
             noised_proprio = proprio
 
+        # maybe state prediction token
+
+        if self.should_pred_state:
+            state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
+        else:
+            state_pred_token = empty_token
+
         # maybe create the action tokens
 
         if exists(discrete_actions) or exists(continuous_actions):
@@ -3453,11 +3482,11 @@ class DynamicsWorldModel(Module):
 
                 action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
         else:
-            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+            action_tokens = empty_token # else empty off agent tokens
 
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
 
             # latents to spatial tokens
 
@@ -3506,7 +3535,7 @@ class DynamicsWorldModel(Module):
 
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
 
             # attention
 
@@ -3514,7 +3543,7 @@ class DynamicsWorldModel(Module):
 
             # unpack
 
-            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
 
@@ -3529,6 +3558,11 @@ class DynamicsWorldModel(Module):
 
             pred = (pred, pred_proprio)
 
+            # maybe state pred
+
+            if self.should_pred_state:
+                pred_state = self.to_state_pred(state_pred_token)
+
             # returning
 
            if not return_agent_tokens:
@@ -3541,7 +3575,7 @@ class DynamicsWorldModel(Module):
 
         # curry into get_prediction what does not change during first call as well as the shortcut ones
 
-        _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
+        _get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
 
         # forward the network
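
For context, a minimal sketch of how the new head's output might be consumed downstream. The `Rearrange(..., two = 2)` splits the last axis into two parameters per latent dimension; assuming these are the mean and log-variance of a diagonal Gaussian over the clean latents (an assumption, not confirmed by this diff), a regression loss and the state entropy bonus named in the comment could look like:

```python
import math
import torch
import torch.nn.functional as F

def state_pred_nll(pred_state, target_latents):
    # pred_state: (b, t, 1, n, d, 2) out of to_state_pred, target_latents: (b, t, n, d)
    # assumes last axis is (mean, log-variance) - hypothetical, inferred from `two = 2`
    mean, log_var = pred_state.squeeze(2).unbind(dim = -1)
    return F.gaussian_nll_loss(mean, target_latents, log_var.exp())

def state_entropy_bonus(pred_state):
    # differential entropy of a diagonal Gaussian, summed over latent dims:
    # H = 0.5 * sum(log(2 * pi * e) + log_var) - grows when the model is
    # uncertain about the state, so it could serve as a per-frame exploration bonus
    log_var = pred_state[..., 1].squeeze(2)
    return 0.5 * (math.log(2 * math.pi * math.e) + log_var).sum(dim = (-1, -2))
```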
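And a standalone illustration of the einops `pack` / `unpack` routing the diff extends: the tensor names and sizes below are made up, but the pattern matches how `state_pred_token` is concatenated with the other per-frame token streams, attended over jointly, then split back out.

```python
import torch
from einops import pack, unpack, repeat

b, t, n, d = 2, 4, 8, 64                    # batch, time, space tokens, model dim

space_tokens = torch.randn(b, t, n, d)
state_pred_token = repeat(torch.randn(d), 'd -> b t 1 d', b = b, t = t)

# concat the streams along the per-frame sequence axis (the *)

tokens, ps = pack([space_tokens, state_pred_token], 'b t * d')    # (b, t, n + 1, d)

# ... attention over `tokens` would happen here ...

space_tokens, state_pred_token = unpack(tokens, ps, 'b t * d')    # original shapes restored
```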