diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 91bc687..3f86e9e 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -81,6 +81,8 @@ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cach
 
 Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
 
+Embeds = namedtuple('Embeds', ['agent', 'state_pred'])
+
 MaybeTensor = Tensor | None
 
 @dataclass
@@ -2062,6 +2064,7 @@ class DynamicsWorldModel(Module):
         add_reward_embed_dropout = 0.1,
         add_state_pred_head = False,
         state_pred_loss_weight = 0.1,
+        state_entropy_bonus_weight = 0.05,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
@@ -2201,6 +2204,9 @@ class DynamicsWorldModel(Module):
             Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
         )
 
+        self.state_entropy_bonus_weight = state_entropy_bonus_weight
+        self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.
+
         # reinforcement related
 
         # they sum all the actions into a single token
@@ -2491,7 +2497,7 @@ class DynamicsWorldModel(Module):
 
             latents = self.video_tokenizer(video, return_latents = True)
 
-            _, (agent_embed, next_time_cache) = self.forward(
+            _, (embeds, next_time_cache) = self.forward(
                 latents = latents,
                 signal_levels = self.max_steps - 1,
                 step_sizes = step_size,
@@ -2511,6 +2517,8 @@ class DynamicsWorldModel(Module):
 
             # get one agent
 
+            agent_embed = embeds.agent
+
             one_agent_embed = agent_embed[..., -1:, agent_index, :]
 
             # values
@@ -2565,6 +2573,17 @@ class DynamicsWorldModel(Module):
             elif len(env_step_out) == 5:
                 next_frame, reward, terminated, truncated, info = env_step_out
 
+            # maybe add state entropy bonus
+
+            if self.add_state_entropy_bonus:
+                state_pred_token = embeds.state_pred
+
+                state_pred = self.to_state_pred(state_pred_token)
+
+                state_pred_log_variance = state_pred[..., 1].sum()
+
+                reward = reward + state_pred_log_variance * self.state_entropy_bonus_weight
+
             # update episode lens
 
             episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
@@ -3035,7 +3054,7 @@ class DynamicsWorldModel(Module):
 
             signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
-            pred, (agent_embed, next_time_cache) = self.forward(
+            pred, (embeds, next_time_cache) = self.forward(
                 latents = noised_latent_with_context,
                 signal_levels = signal_levels_with_context,
                 step_sizes = step_size,
@@ -3102,6 +3121,8 @@ class DynamicsWorldModel(Module):
             # take care of the rewards by predicting on the agent token embedding on the last denoising step
 
             if return_rewards_per_frame:
+                agent_embed = embeds.agent
+
                 one_agent_embed = agent_embed[:, -1:, agent_index]
 
                 reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
@@ -3112,6 +3133,8 @@ class DynamicsWorldModel(Module):
             # maybe store agent embed
 
             if store_agent_embed:
+                agent_embed = embeds.agent
+
                 one_agent_embed = agent_embed[:, -1:, agent_index]
 
                 acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
@@ -3571,17 +3594,17 @@ class DynamicsWorldModel(Module):
 
             # returning
 
-            # returning
-
             predictions = Predictions(pred, pred_proprio, pred_state)
 
+            embeds = Embeds(agent_tokens, state_pred_token)
+
             if not return_agent_tokens:
                 return predictions
 
             if not return_time_cache:
-                return predictions, agent_tokens
+                return predictions, embeds
 
-            return predictions, (agent_tokens, intermediates)
+            return predictions, (embeds, intermediates)
 
         # curry into get_prediction what does not change during first call as well as the shortcut ones
 
@@ -3589,13 +3612,13 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
+        pred, (embeds, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
 
         if return_pred_only:
             if not return_intermediates:
                 return pred
 
-            return pred, (encoded_agent_tokens, intermediates)
+            return pred, (embeds, intermediates)
 
         # pack the predictions to calculate flow for different modalities all at once
 
@@ -3732,6 +3755,8 @@ class DynamicsWorldModel(Module):
 
         if exists(rewards):
 
+            encoded_agent_tokens = embeds.agent
+
             if rewards.ndim == 2: # (b t)
                 encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
 
diff --git a/pyproject.toml b/pyproject.toml
index 95cd36a..8499691 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.21"
+version = "0.1.22"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
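Below is a minimal, self-contained sketch of the state entropy bonus this change introduces. It is illustrative only: the function name, tensor shapes, and the standalone `state_pred` tensor are assumptions, not the repo's API; in the diff itself the bonus is computed inside the environment interaction loop from `self.to_state_pred(embeds.state_pred)`.

```python
# illustrative sketch (not the repo's API): the state prediction head emits
# (mean, log-variance) pairs per latent dimension, and the summed predicted
# log-variance is folded into the environment reward as an exploration bonus,
# scaled by `state_entropy_bonus_weight` (0.05 by default in the diff)

import torch
from torch import Tensor

def add_state_entropy_bonus(
    reward: Tensor,                            # scalar env reward for the current step
    state_pred: Tensor,                        # (..., n, d, 2) - last dim holds (mean, log-variance)
    state_entropy_bonus_weight: float = 0.05
) -> Tensor:
    # higher predicted variance over the next state -> larger exploration bonus
    state_pred_log_variance = state_pred[..., 1].sum()
    return reward + state_pred_log_variance * state_entropy_bonus_weight

# toy usage with made-up shapes
reward = torch.tensor(1.)
state_pred = torch.randn(4, 16, 2)  # n = 4 latent tokens, d = 16 dims, (mean, logvar)
shaped_reward = add_state_entropy_bonus(reward, state_pred)
```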