complete the addition of a state entropy bonus

lucidrains 2025-12-03 07:51:04 -08:00
parent fb6d69f43a
commit eb7a13502e
2 changed files with 34 additions and 9 deletions

View File

@@ -81,6 +81,8 @@ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cach
Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
Embeds = namedtuple('Embeds', ['agent', 'state_pred'])
MaybeTensor = Tensor | None
@dataclass
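Note: the forward pass now bundles its auxiliary outputs into the new Embeds namedtuple instead of returning the bare agent tokens, which is what the agent_embed / agent_tokens to embeds renames further down reflect. A minimal sketch of the access pattern used by those call sites, with hypothetical tensor shapes (not part of the diff):

# --- illustrative sketch, not part of the diff ---
from collections import namedtuple
import torch

# the namedtuple introduced in this commit
Embeds = namedtuple('Embeds', ['agent', 'state_pred'])

# hypothetical shapes: (batch, time, agent tokens, dim) and (batch, time, dim)
embeds = Embeds(
    agent = torch.randn(2, 8, 4, 512),
    state_pred = torch.randn(2, 8, 512)
)

agent_embed = embeds.agent                 # previously the raw agent tokens return value
one_agent_embed = agent_embed[:, -1:, 0]   # last frame, a single agent index
state_pred_token = embeds.state_pred       # fed to the state prediction head
# --- end sketch ---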
@@ -2062,6 +2064,7 @@ class DynamicsWorldModel(Module):
add_reward_embed_dropout = 0.1,
add_state_pred_head = False,
state_pred_loss_weight = 0.1,
state_entropy_bonus_weight = 0.05,
num_discrete_actions: int | tuple[int, ...] = 0,
num_continuous_actions = 0,
continuous_norm_stats = None,
@@ -2201,6 +2204,9 @@
Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
)
self.state_entropy_bonus_weight = state_entropy_bonus_weight
self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.
# reinforcement related
# they sum all the actions into a single token
@@ -2491,7 +2497,7 @@ class DynamicsWorldModel(Module):
latents = self.video_tokenizer(video, return_latents = True)
_, (agent_embed, next_time_cache) = self.forward(
_, (embeds, next_time_cache) = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
@@ -2511,6 +2517,8 @@ class DynamicsWorldModel(Module):
# get one agent
agent_embed = embeds.agent
one_agent_embed = agent_embed[..., -1:, agent_index, :]
# values
@@ -2565,6 +2573,17 @@ class DynamicsWorldModel(Module):
elif len(env_step_out) == 5:
next_frame, reward, terminated, truncated, info = env_step_out
# maybe add state entropy bonus
if self.add_state_entropy_bonus:
state_pred_token = embeds.state_pred
state_pred = self.to_state_pred(state_pred_token)
state_pred_log_variance = state_pred[..., 1].sum()
reward = reward + state_pred_log_variance * self.state_entropy_bonus_weight
# update episode lens
episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
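Note: the bonus added above is the summed log variance predicted by the state head, scaled by state_entropy_bonus_weight. For a diagonal Gaussian the differential entropy is 0.5 * sum(log(2 * pi * e * sigma^2)), so the summed log variance is an affine proxy for the entropy of the predicted state distribution, hence the name. A standalone sketch of the computation, assuming the head emits (mean, log variance) pairs per latent dimension as the Rearrange with two = 2 above suggests; shapes and the reward value are illustrative:

# --- illustrative sketch, not part of the diff ---
import torch

state_entropy_bonus_weight = 0.05    # default introduced by this commit

# hypothetical state head output: (batch, latent tokens, dims, 2),
# channel 0 = predicted mean, channel 1 = predicted log variance
state_pred = torch.randn(1, 16, 32, 2)

# summed log variance, an affine function of the gaussian differential entropy
state_pred_log_variance = state_pred[..., 1].sum()

reward = torch.tensor(1.5)            # reward returned by the env step
reward = reward + state_pred_log_variance * state_entropy_bonus_weight
# --- end sketch ---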
@@ -3035,7 +3054,7 @@ class DynamicsWorldModel(Module):
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
pred, (agent_embed, next_time_cache) = self.forward(
pred, (embeds, next_time_cache) = self.forward(
latents = noised_latent_with_context,
signal_levels = signal_levels_with_context,
step_sizes = step_size,
@@ -3102,6 +3121,8 @@ class DynamicsWorldModel(Module):
# take care of the rewards by predicting on the agent token embedding on the last denoising step
if return_rewards_per_frame:
agent_embed = embeds.agent
one_agent_embed = agent_embed[:, -1:, agent_index]
reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
@@ -3112,6 +3133,8 @@ class DynamicsWorldModel(Module):
# maybe store agent embed
if store_agent_embed:
agent_embed = embeds.agent
one_agent_embed = agent_embed[:, -1:, agent_index]
acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
@@ -3571,17 +3594,17 @@ class DynamicsWorldModel(Module):
# returning
predictions = Predictions(pred, pred_proprio, pred_state)
embeds = Embeds(agent_tokens, state_pred_token)
if not return_agent_tokens:
return predictions
if not return_time_cache:
return predictions, agent_tokens
return predictions, embeds
return predictions, (agent_tokens, intermediates)
return predictions, (embeds, intermediates)
# curry into get_prediction what does not change during first call as well as the shortcut ones
@@ -3589,13 +3612,13 @@ class DynamicsWorldModel(Module):
# forward the network
pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
pred, (embeds, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
if return_pred_only:
if not return_intermediates:
return pred
return pred, (encoded_agent_tokens, intermediates)
return pred, (embeds, intermediates)
# pack the predictions to calculate flow for different modalities all at once
@@ -3732,6 +3755,8 @@ class DynamicsWorldModel(Module):
if exists(rewards):
encoded_agent_tokens = embeds.agent
if rewards.ndim == 2: # (b t)
encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
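Note: when per-frame rewards of shape (b t) are supplied during training, the group of agent tokens, now read from embeds.agent, is mean pooled down to a single embedding per frame before the reward prediction. A small sketch of that pooling with hypothetical sizes (not part of the diff):

# --- illustrative sketch, not part of the diff ---
import torch
from einops import reduce

# hypothetical sizes: batch 2, 8 frames, 4 agent tokens per frame, dim 512
encoded_agent_tokens = torch.randn(2, 8, 4, 512)
rewards = torch.randn(2, 8)           # (b t), one scalar reward per frame

if rewards.ndim == 2:
    # collapse the agent token group so there is one embedding per frame
    encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')

assert encoded_agent_tokens.shape == (2, 8, 512)
# --- end sketch ---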

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.1.21"
version = "0.1.22"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }