add a separate state prediction head for the state entropy

lucidrains 2025-12-02 15:58:25 -08:00
parent 2e7f406d49
commit 125693ce1c


@@ -2058,6 +2058,8 @@ class DynamicsWorldModel(Module):
         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
+        add_state_pred_head = False,
+        state_pred_loss_weight = 0.1,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
@@ -2181,6 +2183,22 @@ class DynamicsWorldModel(Module):
         self.pred_orig_latent = pred_orig_latent # x-space or v-space
 
         self.loss_weight_fn = loss_weight_fn
 
+        # state prediction, for state entropy bonus
+
+        self.add_state_pred_head = add_state_pred_head
+        self.state_pred_loss_weight = state_pred_loss_weight
+
+        self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.
+
+        if self.should_pred_state:
+            self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)
+
+            self.to_state_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, num_latent_tokens * dim_latent * 2),
+                Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
+            )
+
         # reinforcement related
         # they sum all the actions into a single token
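For orientation, here is a minimal standalone sketch of the shape flow through the new head. The sizes below are made up, and plain `nn.Sequential` / `nn.RMSNorm` stand in for the repository's own `Sequential` / `RMSNorm` wrappers; the trailing axis of size 2 presumably carries two statistics per latent dimension.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

# hypothetical sizes, chosen only to illustrate shapes
dim, num_latent_tokens, dim_latent = 512, 4, 32

to_state_pred = nn.Sequential(
    nn.RMSNorm(dim),                                     # stand-in for the repo's RMSNorm (PyTorch >= 2.4)
    nn.Linear(dim, num_latent_tokens * dim_latent * 2),  # project the token to two values per latent dim
    Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
)

state_pred_token = torch.randn(2, 10, 1, dim)            # (batch, time, 1, dim), as in the forward hunks below
out = to_state_pred(state_pred_token)
print(out.shape)                                         # torch.Size([2, 10, 1, 4, 32, 2])
```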
@@ -3387,9 +3405,13 @@ class DynamicsWorldModel(Module):
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
+        # empty token
+
+        empty_token = agent_tokens[:, :, 0:0]
+
         # maybe reward tokens
 
-        reward_tokens = agent_tokens[:, :, 0:0]
+        reward_tokens = empty_token
 
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
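The `empty_token = agent_tokens[:, :, 0:0]` introduced here is a zero-length slice along the token axis, so it can be reused as a dtype- and device-matched placeholder wherever a token group is absent. A small sketch with made-up sizes:

```python
import torch

agent_tokens = torch.randn(2, 10, 3, 512)   # hypothetical (batch, time, agent tokens, dim)
empty_token = agent_tokens[:, :, 0:0]       # zero-length slice along the token axis

print(empty_token.shape)                    # torch.Size([2, 10, 0, 512])

# concatenating it along the token axis is a no-op, which is what lets the same
# placeholder stand in for "no reward tokens", "no state pred token", and so on
print(torch.cat([agent_tokens, empty_token], dim = 2).shape)   # torch.Size([2, 10, 3, 512])
```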
@@ -3429,6 +3451,13 @@ class DynamicsWorldModel(Module):
         else:
             noised_proprio = proprio
 
+        # maybe state prediction token
+
+        if self.should_pred_state:
+            state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
+        else:
+            state_pred_token = empty_token
+
         # maybe create the action tokens
 
         if exists(discrete_actions) or exists(continuous_actions):
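The learned `state_pred_token` is a single `(dim,)` parameter broadcast to every batch element and timestep. A tiny einops sketch with hypothetical sizes:

```python
import torch
from einops import repeat

dim, batch, time = 512, 2, 10                 # hypothetical sizes
state_pred_token = torch.randn(dim)           # stands in for the learned nn.Parameter

tokens = repeat(state_pred_token, 'd -> b t 1 d', b = batch, t = time)
print(tokens.shape)                           # torch.Size([2, 10, 1, 512])
```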
@@ -3453,11 +3482,11 @@ class DynamicsWorldModel(Module):
             action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
         else:
-            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+            action_tokens = empty_token # else empty off agent tokens
 
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
 
             # latents to spatial tokens
@@ -3506,7 +3535,7 @@ class DynamicsWorldModel(Module):
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
 
             # attention
@@ -3514,7 +3543,7 @@ class DynamicsWorldModel(Module):
             # unpack
 
-            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
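These two hunks thread the new token through the `pack` / `unpack` pair that flattens the per-frame token groups for attention. A self-contained sketch of that round trip, with made-up token counts, including a zero-length group like the empty reward tokens:

```python
import torch
from einops import pack, unpack

b, t, d = 2, 10, 512                                   # hypothetical sizes
flow_token       = torch.randn(b, t, 1, d)
space_tokens     = torch.randn(b, t, 16, d)
state_pred_token = torch.randn(b, t, 1, d)
reward_tokens    = torch.randn(b, t, 0, d)             # empty placeholder, e.g. no rewards given

tokens, packed_shape = pack([flow_token, space_tokens, state_pred_token, reward_tokens], 'b t * d')
print(tokens.shape)                                    # torch.Size([2, 10, 18, 512])

# after attention, the same per-group shapes come back out, zero-length groups included
flow_token, space_tokens, state_pred_token, reward_tokens = unpack(tokens, packed_shape, 'b t * d')
print(reward_tokens.shape)                             # torch.Size([2, 10, 0, 512])
```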
@@ -3529,6 +3558,11 @@ class DynamicsWorldModel(Module):
                 pred = (pred, pred_proprio)
 
+            # maybe state pred
+
+            if self.should_pred_state:
+                pred_state = self.to_state_pred(state_pred_token)
+
             # returning
 
             if not return_agent_tokens:
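The hunks shown end at producing `pred_state`; how it feeds the state-prediction loss and the state entropy bonus is outside this diff. One plausible reading, and purely an assumption here, is to treat the trailing axis of size 2 as the mean and log-variance of a diagonal Gaussian over each latent dimension, which yields both a regression loss and a per-state entropy estimate:

```python
import math
import torch

def gaussian_state_loss_and_entropy(pred_state, target_latents):
    # hypothetical helper, not from the repository
    # pred_state:     (b, t, 1, n, d, 2) - output of to_state_pred above
    # target_latents: (b, t, n, d)       - assumed regression target (clean latents)
    mean, log_var = pred_state.squeeze(2).unbind(dim = -1)

    # negative log likelihood under the predicted diagonal Gaussian
    nll = 0.5 * (log_var + (target_latents - mean) ** 2 / log_var.exp() + math.log(2 * math.pi))

    # entropy of a diagonal Gaussian grows with the predicted log-variance
    entropy = 0.5 * (1. + math.log(2 * math.pi)) + 0.5 * log_var

    return nll.mean(), entropy.sum(dim = -1)   # scalar loss, per (batch, time, latent token) entropy

pred_state = torch.randn(2, 10, 1, 4, 32, 2)   # random stand-ins just to show shapes
targets    = torch.randn(2, 10, 4, 32)
loss, state_entropy = gaussian_state_loss_and_entropy(pred_state, targets)
```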
@@ -3541,7 +3575,7 @@ class DynamicsWorldModel(Module):
         # curry into get_prediction what does not change during first call as well as the shortcut ones
 
-        _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
+        _get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
 
         # forward the network
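Finally, the `partial` here is plain argument currying: the token groups that stay fixed across the main call and the extra shortcut/consistency calls are bound once, so later calls only pass what changes. A toy sketch (names and body are simplified stand-ins, not the repository's `get_prediction`):

```python
from functools import partial

# simplified stand-in for the closure in the diff
def get_prediction(noised_latents, signal_levels, state_pred_token = None, action_tokens = None):
    return noised_latents * signal_levels, state_pred_token, action_tokens

# bind what does not change between the first call and the shortcut consistency calls
_get_prediction = partial(get_prediction, state_pred_token = 'state token', action_tokens = 'action tokens')

# each subsequent call only varies the noised inputs and signal levels
out_main     = _get_prediction(1.0, 0.5)
out_shortcut = _get_prediction(1.0, 0.25)
```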