complete the addition of a state entropy bonus

commit eb7a13502e
parent fb6d69f43a
@@ -81,6 +81,8 @@ TransformerIntermediates = namedtuple('TransformerIntermediates', ('next_kv_cach

 Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))

+Embeds = namedtuple('Embeds', ['agent', 'state_pred'])
+
 MaybeTensor = Tensor | None

 @dataclass
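The new Embeds namedtuple groups the encoded agent tokens together with the state prediction token, so forward can hand both back through a single value. A minimal sketch of how such a container is built and unpacked (the tensor shapes below are illustrative assumptions, not taken from the model):

    import torch
    from collections import namedtuple

    Embeds = namedtuple('Embeds', ['agent', 'state_pred'])

    # illustrative shapes: (batch, time, num agent tokens, dim) and (batch, time, dim)
    agent_tokens = torch.randn(1, 4, 2, 512)
    state_pred_token = torch.randn(1, 4, 512)

    embeds = Embeds(agent_tokens, state_pred_token)
    assert embeds.agent is agent_tokens and embeds.state_pred is state_pred_token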
@@ -2062,6 +2064,7 @@ class DynamicsWorldModel(Module):
 add_reward_embed_dropout = 0.1,
 add_state_pred_head = False,
 state_pred_loss_weight = 0.1,
+state_entropy_bonus_weight = 0.05,
 num_discrete_actions: int | tuple[int, ...] = 0,
 num_continuous_actions = 0,
 continuous_norm_stats = None,
@@ -2201,6 +2204,9 @@ class DynamicsWorldModel(Module):
 Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
 )

+self.state_entropy_bonus_weight = state_entropy_bonus_weight
+self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.
+
 # reinforcement related

 # they sum all the actions into a single token
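The bonus is gated at construction time: it only turns on when the state prediction head exists and the weight is positive. A small sketch of that wiring, assuming a hypothetical module where should_pred_state is derived from add_state_pred_head:

    import torch
    from torch import nn

    class TinyWorldModel(nn.Module):
        # hypothetical stand-in for the __init__ wiring above, not the repo's class
        def __init__(self, dim = 512, state_dim = 16, add_state_pred_head = True, state_entropy_bonus_weight = 0.05):
            super().__init__()
            self.should_pred_state = add_state_pred_head

            # head predicting a mean and a log variance per state dimension (assumed parameterization)
            self.to_state_pred = nn.Linear(dim, state_dim * 2) if add_state_pred_head else None

            self.state_entropy_bonus_weight = state_entropy_bonus_weight
            self.add_state_entropy_bonus = self.should_pred_state and state_entropy_bonus_weight > 0.

    assert TinyWorldModel().add_state_entropy_bonus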
@@ -2491,7 +2497,7 @@ class DynamicsWorldModel(Module):

 latents = self.video_tokenizer(video, return_latents = True)

-_, (agent_embed, next_time_cache) = self.forward(
+_, (embeds, next_time_cache) = self.forward(
 latents = latents,
 signal_levels = self.max_steps - 1,
 step_sizes = step_size,
@@ -2511,6 +2517,8 @@ class DynamicsWorldModel(Module):

 # get one agent

+agent_embed = embeds.agent
+
 one_agent_embed = agent_embed[..., -1:, agent_index, :]

 # values
@@ -2565,6 +2573,17 @@ class DynamicsWorldModel(Module):
 elif len(env_step_out) == 5:
 next_frame, reward, terminated, truncated, info = env_step_out

+# maybe add state entropy bonus
+
+if self.add_state_entropy_bonus:
+state_pred_token = embeds.state_pred
+
+state_pred = self.to_state_pred(state_pred_token)
+
+state_pred_log_variance = state_pred[..., 1].sum()
+
+reward = reward + state_pred_log_variance * self.state_entropy_bonus_weight
+
 # update episode lens

 episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
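At environment interaction time, the predicted state log variance is summed and folded into the reward, scaled by the bonus weight. Since the entropy of a diagonal Gaussian is 0.5 * sum(log variance) plus a constant, rewarding a large summed log variance acts as a state entropy (uncertainty) bonus that encourages exploration, and it should shrink as the state head grows confident. A self-contained sketch of just this step, with hypothetical shapes and a stand-in state head:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, state_dim = 512, 16
    state_entropy_bonus_weight = 0.05

    # stand-in for the model's state head: a mean and a log variance per state dimension
    to_state_pred = nn.Sequential(
        nn.Linear(dim, state_dim * 2),
        Rearrange('... (d two) -> ... d two', two = 2)
    )

    state_pred_token = torch.randn(1, dim)   # embeds.state_pred for the latest frame
    reward = torch.tensor([0.3])             # reward coming back from env.step

    state_pred = to_state_pred(state_pred_token)          # (..., state_dim, 2)
    state_pred_log_variance = state_pred[..., 1].sum()    # index 1 assumed to be the log variance

    # the summed log variance behaves as an (unnormalized) state entropy bonus on the reward
    reward = reward + state_pred_log_variance * state_entropy_bonus_weight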
@@ -3035,7 +3054,7 @@ class DynamicsWorldModel(Module):

 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)

-pred, (agent_embed, next_time_cache) = self.forward(
+pred, (embeds, next_time_cache) = self.forward(
 latents = noised_latent_with_context,
 signal_levels = signal_levels_with_context,
 step_sizes = step_size,
@@ -3102,6 +3121,8 @@ class DynamicsWorldModel(Module):
 # take care of the rewards by predicting on the agent token embedding on the last denoising step

 if return_rewards_per_frame:
+agent_embed = embeds.agent
+
 one_agent_embed = agent_embed[:, -1:, agent_index]

 reward_logits = self.to_reward_pred.forward_one(one_agent_embed, id = 0)
@@ -3112,6 +3133,8 @@ class DynamicsWorldModel(Module):
 # maybe store agent embed

 if store_agent_embed:
+agent_embed = embeds.agent
+
 one_agent_embed = agent_embed[:, -1:, agent_index]
 acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)

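safe_cat here is presumably a small helper that concatenates along a dimension while tolerating a None accumulator on the first step; a hedged sketch of that pattern (the helper below is an assumption, not the repo's implementation):

    import torch

    def safe_cat(tensors, dim = 1):
        # hypothetical helper: drop Nones, return None if nothing remains to concatenate
        tensors = [t for t in tensors if t is not None]
        if len(tensors) == 0:
            return None
        return torch.cat(tensors, dim = dim)

    acc_agent_embed = None
    for step in range(3):
        one_agent_embed = torch.randn(1, 1, 512)   # (batch, one frame, dim)
        acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)

    assert acc_agent_embed.shape == (1, 3, 512)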
@@ -3571,17 +3594,17 @@ class DynamicsWorldModel(Module):

 # returning

-# returning
-
 predictions = Predictions(pred, pred_proprio, pred_state)

+embeds = Embeds(agent_tokens, state_pred_token)
+
 if not return_agent_tokens:
 return predictions

 if not return_time_cache:
-return predictions, agent_tokens
+return predictions, embeds

-return predictions, (agent_tokens, intermediates)
+return predictions, (embeds, intermediates)

 # curry into get_prediction what does not change during first call as well as the shortcut ones

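With this change, callers that previously received the bare agent tokens (or (agent_tokens, intermediates)) now receive the Embeds tuple in that slot and pull out .agent or .state_pred explicitly. A sketch of the calling pattern, with a stub standing in for the real forward and purely illustrative shapes:

    import torch
    from collections import namedtuple

    Predictions = namedtuple('Predictions', ('flow', 'proprioception', 'state'))
    Embeds = namedtuple('Embeds', ['agent', 'state_pred'])

    # stub standing in for the model's forward with return_agent_tokens = True
    def forward_stub():
        predictions = Predictions(torch.randn(1, 4, 8, 512), None, None)
        embeds = Embeds(torch.randn(1, 4, 2, 512), torch.randn(1, 4, 512))
        return predictions, embeds

    predictions, embeds = forward_stub()

    agent_embed = embeds.agent                 # previously this was the second return value itself
    one_agent_embed = agent_embed[:, -1:, 0]   # last frame, one agent, as in the hunks above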
@@ -3589,13 +3612,13 @@ class DynamicsWorldModel(Module):

 # forward the network

-pred, (encoded_agent_tokens, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)
+pred, (embeds, intermediates) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_cache = True)

 if return_pred_only:
 if not return_intermediates:
 return pred

-return pred, (encoded_agent_tokens, intermediates)
+return pred, (embeds, intermediates)

 # pack the predictions to calculate flow for different modalities all at once

@@ -3732,6 +3755,8 @@ class DynamicsWorldModel(Module):

 if exists(rewards):

+encoded_agent_tokens = embeds.agent
+
 if rewards.ndim == 2: # (b t)
 encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')

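When rewards arrive as one scalar per frame (shape (b, t)), the per-agent token embeddings are first averaged over the agent axis so there is a single embedding per frame. A tiny sketch of that einops reduce, with made-up shapes:

    import torch
    from einops import reduce

    encoded_agent_tokens = torch.randn(2, 6, 3, 512)   # (batch, time, num agent tokens, dim)

    # collapse the agent axis by mean so there is one embedding per frame
    encoded_agent_tokens = reduce(encoded_agent_tokens, 'b t g d -> b t d', 'mean')
    assert encoded_agent_tokens.shape == (2, 6, 512)

The final hunk below simply bumps the package version from 0.1.21 to 0.1.22 in the project metadata.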
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.21"
+version = "0.1.22"
 description = "Dreamer 4"
 authors = [
 { name = "Phil Wang", email = "lucidrains@gmail.com" }