add a separate state prediction head for the state entropy
commit 125693ce1c
parent 2e7f406d49
@@ -2058,6 +2058,8 @@ class DynamicsWorldModel(Module):
         prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
         add_reward_embed_dropout = 0.1,
+        add_state_pred_head = False,
+        state_pred_loss_weight = 0.1,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
         continuous_norm_stats = None,
@@ -2181,6 +2183,22 @@ class DynamicsWorldModel(Module):
         self.pred_orig_latent = pred_orig_latent # x-space or v-space
         self.loss_weight_fn = loss_weight_fn

+        # state prediction, for state entropy bonus
+
+        self.add_state_pred_head = add_state_pred_head
+        self.state_pred_loss_weight = state_pred_loss_weight
+
+        self.should_pred_state = add_state_pred_head and state_pred_loss_weight > 0.
+
+        if self.should_pred_state:
+            self.state_pred_token = nn.Parameter(torch.randn(dim) * 1e-2)
+
+            self.to_state_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, num_latent_tokens * dim_latent * 2),
+                Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
+            )
+
         # reinforcement related

         # they sum all the actions into a single token
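The new head turns one learned token per timestep into two values for every latent dimension. Below is a minimal standalone sketch of that shape contract; torch's `RMSNorm` is a stand-in for the repo's own, and reading the trailing dimension of 2 as a mean / log-variance pair is an assumption, not something this commit specifies.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, num_latent_tokens, dim_latent = 512, 4, 32          # made-up sizes

to_state_pred = nn.Sequential(
    nn.RMSNorm(dim),                                      # stand-in for the repo's RMSNorm
    nn.Linear(dim, num_latent_tokens * dim_latent * 2),
    Rearrange('... (n d two) -> ... n d two', n = num_latent_tokens, two = 2)
)

state_pred_token = torch.randn(2, 10, 1, dim)             # (batch, time, 1, dim)
out = to_state_pred(state_pred_token)

print(out.shape)                                          # torch.Size([2, 10, 1, 4, 32, 2])
mean, log_var = out.unbind(dim = -1)                      # one plausible reading of the final pair
```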
@@ -3387,9 +3405,13 @@ class DynamicsWorldModel(Module):

         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)

+        # empty token
+
+        empty_token = agent_tokens[:, :, 0:0]
+
         # maybe reward tokens

-        reward_tokens = agent_tokens[:, :, 0:0]
+        reward_tokens = empty_token

         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
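The zero-width slice `agent_tokens[:, :, 0:0]` keeps the batch, time and feature dimensions but holds zero tokens, so optional token groups can always be handed to `pack` and simply contribute nothing when absent. A small sketch with made-up shapes:

```python
import torch
from einops import pack

agent_tokens = torch.randn(2, 10, 3, 512)    # (batch, time, agent tokens, dim)
empty_token = agent_tokens[:, :, 0:0]        # (2, 10, 0, 512) - zero tokens, same other dims

tokens, packed_shapes = pack([agent_tokens, empty_token], 'b t * d')
print(tokens.shape)                          # torch.Size([2, 10, 3, 512]) - the empty group adds nothing
```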
@@ -3429,6 +3451,13 @@ class DynamicsWorldModel(Module):
         else:
             noised_proprio = proprio

+        # maybe state prediction token
+
+        if self.should_pred_state:
+            state_pred_token = repeat(self.state_pred_token, 'd -> b t 1 d', b = batch, t = time)
+        else:
+            state_pred_token = empty_token
+
         # maybe create the action tokens

         if exists(discrete_actions) or exists(continuous_actions):
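The state prediction token is stored as a single learned vector of size `dim`; `einops.repeat` broadcasts it into one readout slot per batch element and timestep. A small sketch with made-up sizes:

```python
import torch
from einops import repeat

learned_token = torch.randn(512)                                         # the learned (dim,) parameter
state_pred_token = repeat(learned_token, 'd -> b t 1 d', b = 2, t = 10)
print(state_pred_token.shape)                                            # torch.Size([2, 10, 1, 512])
```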
@@ -3453,11 +3482,11 @@ class DynamicsWorldModel(Module):
             action_tokens = torch.zeros_like(agent_tokens[:, :, 0:1])

         else:
-            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+            action_tokens = empty_token # else empty off agent tokens

         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss

-        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, state_pred_token, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_cache = False):

             # latents to spatial tokens

@@ -3506,7 +3535,7 @@ class DynamicsWorldModel(Module):

             # pack to tokens for attending

-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, state_pred_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')

             # attention

@@ -3514,7 +3543,7 @@ class DynamicsWorldModel(Module):

             # unpack

-            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, state_pred_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')

             # pooling

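Because `pack` records how the groups were concatenated, `unpack` splits the attended sequence back into the same groups, which is how the `state_pred_token` is recovered after attention. A round-trip sketch with made-up shapes:

```python
import torch
from einops import pack, unpack

b, t, d = 2, 10, 512
space_tokens = torch.randn(b, t, 16, d)
state_pred_token = torch.randn(b, t, 1, d)
agent_tokens = torch.randn(b, t, 3, d)

tokens, packed_shapes = pack([space_tokens, state_pred_token, agent_tokens], 'b t * d')
print(tokens.shape)                              # torch.Size([2, 10, 20, 512])

# attention over `tokens` would run here, leaving the token axis length unchanged

space_tokens, state_pred_token, agent_tokens = unpack(tokens, packed_shapes, 'b t * d')
print(state_pred_token.shape)                    # torch.Size([2, 10, 1, 512])
```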
@@ -3529,6 +3558,11 @@ class DynamicsWorldModel(Module):

                 pred = (pred, pred_proprio)

+            # maybe state pred
+
+            if self.should_pred_state:
+                pred_state = self.to_state_pred(state_pred_token)
+
             # returning

             if not return_agent_tokens:
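The commit stops at producing `pred_state`; how the two output channels are consumed is not shown here. One plausible reading, stated purely as an assumption, is a diagonal Gaussian over the clean latents: its negative log-likelihood would serve as the auxiliary loss (scaled by `state_pred_loss_weight`) and its log-variance yields a state entropy signal.

```python
import math
import torch
import torch.nn.functional as F

# hypothetical consumer of pred_state - not part of this commit
def state_pred_loss(pred_state, target_latents, loss_weight = 0.1):
    # pred_state: (..., num_latent_tokens, dim_latent, 2)
    # target_latents: (..., num_latent_tokens, dim_latent)
    mean, log_var = pred_state.unbind(dim = -1)

    # negative log likelihood of the clean latents under a diagonal Gaussian
    nll = F.gaussian_nll_loss(mean, target_latents, log_var.exp())

    # differential entropy of that Gaussian, a candidate state entropy bonus
    entropy = 0.5 * (log_var + math.log(2 * math.pi * math.e)).mean()

    return loss_weight * nll, entropy
```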
@@ -3541,7 +3575,7 @@ class DynamicsWorldModel(Module):

         # curry into get_prediction what does not change during first call as well as the shortcut ones

-        _get_prediction = partial(get_prediction, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)
+        _get_prediction = partial(get_prediction, state_pred_token = state_pred_token, action_tokens = action_tokens, reward_tokens = reward_tokens, agent_tokens = agent_tokens)

         # forward the network
