add a separate state prediction head for the state entropy
This commit is contained in:
parent
2e7f406d49
commit
125693ce1c
@ -2058,6 +2058,8 @@ class DynamicsWorldModel(Module):
|
||||
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
|
||||
add_reward_embed_to_agent_token = False,
|
||||
add_reward_embed_dropout = 0.1,
|
||||
add_state_pred_head = False,
|
||||
state_pred_loss_weight = 0.1,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats = None,
|
||||
@ -2181,6 +2183,22 @@ class DynamicsWorldModel(Module):
|
||||
self.pred_orig_latent = pred_orig_latent # x-space or v-space
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
|
||||
# state prediction, for state entropy bonus
|
||||
|
||||
self.add_state_pred_head = add_state_pred_head
|
||||
self.state_pred_loss_weight = state_pred_loss_weight
|
||||
|
||||
self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.
|
||||
|
||||
if self.should_pred_state:
|
||||
self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)
|
||||
|
||||
self.to_state_pred = Sequential(
|
||||
RMSNorm(dim),
|
||||
nn.Linear(dim, num_latent_tokens * dim_latent * 2),
|
||||
Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
|
||||
)
|
||||
|
||||
# reinforcement related
|
||||
|
||||
# they sum all the actions into a single token
|
||||
@ -3387,9 +3405,13 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
|
||||
|
||||
# empty token
|
||||
|
||||
empty_token = agent_tokens[:, :, 0:0]
|
||||
|
||||
# maybe reward tokens
|
||||
|
||||
reward_tokens = agent_tokens[:, :, 0:0]
|
||||
reward_tokens = empty_token
|
||||
|
||||
if exists(rewards):
|
||||
two_hot_encoding = self.reward_encoder(rewards)
|
||||
@ -3429,6 +3451,13 @@ class DynamicsWorldModel(Module):
|
||||
else:
|
||||
noised_proprio = proprio
|
||||
|
||||
# maybe state prediction token
|
||||
|
||||
if self.should_pred_state:
|
||||
state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
|
||||
else:
|
||||
state_pred_token = empty_token
|
||||
|
||||
# maybe create the action tokens
|
||||
|
||||
if exists(discrete_actions) or exists(continuous_actions):
|
||||
@ -3453,11 +3482,11 @@ class DynamicsWorldModel(Module):
|
||||
action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])
|
||||
|
||||
else:
|
||||
action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
|
||||
action_tokens = empty_token # else empty off agent tokens
|
||||
|
||||
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
|
||||
|
||||
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
|
||||
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
|
||||
|
||||
# latents to spatial tokens
|
||||
|
||||
@ -3506,7 +3535,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# pack to tokens for attending
|
||||
|
||||
tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
|
||||
tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
|
||||
|
||||
# attention
|
||||
|
||||
@ -3514,7 +3543,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# unpack
|
||||
|
||||
flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
||||
flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
|
||||
|
||||
# pooling
|
||||
|
||||
@ -3529,6 +3558,11 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
pred = (pred, pred_proprio)
|
||||
|
||||
# maybe state pred
|
||||
|
||||
if self.should_pred_state:
|
||||
pred_state = self.to_state_pred(state_pred_token)
|
||||
|
||||
# returning
|
||||
|
||||
if not return_agent_tokens:
|
||||
@ -3541,7 +3575,7 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# curry into get_prediction what does not change during first call as well as the shortcut ones
|
||||
|
||||
_get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
|
||||
_get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
|
||||
|
||||
# forward the network
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user