correct a misunderstanding: past actions get their own action token, while the agent token is used for predicting the next action, rewards, and values
parent 9c78962736
commit 6dbdc3d7d8
@@ -87,6 +87,9 @@ def is_power_two(num):
 
 # tensor helpers
 
+def is_empty(t):
+    return t.numel() == 0
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
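Not part of the diff, but for context: the new is_empty helper lets downstream code treat "no past actions" as a zero-width token tensor. A minimal sketch with made-up shapes of how a 0:0 slice produces such a tensor:

import torch

def is_empty(t):
    # a tensor is "empty" when it holds zero elements, whatever its shape
    return t.numel() == 0

agent_tokens = torch.randn(2, 4, 1, 512)   # (batch, time, agents, dim)
no_actions = agent_tokens[:, :, 0:0]       # zero-width slice -> shape (2, 4, 0, 512)

assert is_empty(no_actions) and not is_empty(agent_tokens)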
@@ -1944,20 +1947,6 @@ class DynamicsWorldModel(Module):
 
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
-        # maybe add the action embed to the agent tokens per time step
-
-        if exists(discrete_actions) or exists(continuous_actions):
-            assert self.action_embedder.has_actions
-
-            action_embed = self.action_embedder(
-                discrete_actions = discrete_actions,
-                discrete_action_types = discrete_action_types,
-                continuous_actions = continuous_actions,
-                continuous_action_types = continuous_action_types
-            )
-
-            agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
-
         # maybe add a reward embedding to agent tokens
 
         if exists(rewards):
@@ -1975,9 +1964,23 @@ class DynamicsWorldModel(Module):
             agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
 
+        # maybe create the action tokens
+
+        if exists(discrete_actions) or exists(continuous_actions):
+            assert self.action_embedder.has_actions
+
+            action_tokens = self.action_embedder(
+                discrete_actions = discrete_actions,
+                discrete_action_types = discrete_action_types,
+                continuous_actions = continuous_actions,
+                continuous_action_types = continuous_action_types
+            )
+        else:
+            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
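To make the correction concrete, here is a minimal sketch, not the repo's code (ToyActionEmbedder and all dimensions and shapes are made up), of the new wiring: past actions are embedded into their own action_tokens stream that rides alongside the agent tokens instead of being summed into them, so the agent token stays a clean query for predicting the next action, rewards and values.

import torch
from torch import nn

class ToyActionEmbedder(nn.Module):
    # hypothetical stand-in for the repo's action embedder
    def __init__(self, num_actions, dim):
        super().__init__()
        self.embed = nn.Embedding(num_actions, dim)

    def forward(self, discrete_actions):
        return self.embed(discrete_actions)             # (batch, time, dim)

dim, batch, time, num_agents = 512, 2, 4, 1
embedder = ToyActionEmbedder(num_actions = 16, dim = dim)

agent_tokens = torch.randn(batch, time, num_agents, dim)     # learned queries for next action / reward / value prediction
past_actions = torch.randint(0, 16, (batch, time))

# before this commit: the action embedding was added onto the agent tokens
# after: past actions become their own token per time step, attended to separately
action_tokens = embedder(past_actions).unsqueeze(-2)          # (batch, time, 1, dim)

tokens = torch.cat([action_tokens, agent_tokens], dim = -2)   # both streams enter the sequence
assert tokens.shape == (batch, time, num_agents + 1, dim)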
@@ -1986,6 +1989,10 @@ class DynamicsWorldModel(Module):
 
             num_spatial_tokens = space_tokens.shape[-2]
 
+            # action tokens
+
+            num_action_tokens = 1 if not is_empty(action_tokens) else 0
+
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
@@ -2003,7 +2010,7 @@ class DynamicsWorldModel(Module):
 
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')
 
             # attend functions for space and time
 
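A small sketch with assumed shapes (not from the repo) of why einops pack / unpack can absorb the optional action tokens: a zero-width group contributes zero positions to the packed sequence and unpacks back to a zero-width tensor.

import torch
from einops import pack, unpack

b, t, d = 2, 4, 512
flow_token    = torch.randn(b, t, 1, d)
space_tokens  = torch.randn(b, t, 64, d)
registers     = torch.randn(b, t, 8, d)
action_tokens = torch.randn(b, t, 0, d)   # empty when no past actions were given
agent_tokens  = torch.randn(b, t, 1, d)

tokens, packed_shape = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')
assert tokens.shape == (b, t, 1 + 64 + 8 + 0 + 1, d)

# ... space / time attention would run over `tokens` here ...

flow_token, space_tokens, registers, action_tokens, agent_tokens = unpack(tokens, packed_shape, 'b t * d')
assert action_tokens.shape == (b, t, 0, d)   # the empty group round-trips untouched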
@@ -2015,6 +2022,7 @@ class DynamicsWorldModel(Module):
 
             space_seq_len = (
                 + 1 # signal + step
+                + num_action_tokens # past action tokens - todo: account for multi-agent
                 + self.num_agents # action / agent tokens
                 + self.num_register_tokens
                 + num_spatial_tokens
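For concreteness, a toy calculation with made-up sizes showing how the per-time-step sequence length changes; the extra slot only appears when past actions are passed in.

num_spatial_tokens  = 64
num_register_tokens = 8
num_agents          = 1

for num_action_tokens in (0, 1):   # without / with past actions
    space_seq_len = (
        + 1                   # signal + step token
        + num_action_tokens   # past action token
        + num_agents          # agent token(s)
        + num_register_tokens
        + num_spatial_tokens
    )
    print(num_action_tokens, space_seq_len)   # prints "0 74" then "1 75"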
@@ -2056,7 +2064,7 @@ class DynamicsWorldModel(Module):
 
             # unpack
 
-            flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, register_tokens, action_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
 
@@ -2071,7 +2079,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)
 
         if return_pred_only:
             if not return_agent_tokens:
@@ -2108,7 +2116,7 @@ class DynamicsWorldModel(Module):
        step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
        half_step_size = 2 ** step_sizes_log2_minus_one
 
-        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
+        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)
 
        # first derive b'
 
@@ -2127,7 +2135,7 @@ class DynamicsWorldModel(Module):
        # get second prediction for b''
 
        signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens)
+        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)
 
        if is_v_space_pred:
            second_step_pred_flow = second_step_pred
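The two get_prediction_no_grad calls are, in spirit, the shortcut-model consistency target: two frozen half steps of size d/2 yield the target for one learned step of size d. A schematic sketch with a hypothetical flow callable standing in for the network (names and shapes are illustrative only, not the repo's exact formulation):

import torch

def two_half_step_target(flow, noised, signal_levels, step_sizes_log2):
    # step size d = 2 ** step_sizes_log2, so d / 2 = 2 ** (step_sizes_log2 - 1)
    step_sizes_log2_minus_one = step_sizes_log2 - 1
    half_step = 2. ** step_sizes_log2_minus_one

    with torch.no_grad():
        first_flow = flow(noised, signal_levels, step_sizes_log2_minus_one)                           # b'
        denoised = noised + first_flow * half_step[:, None]
        second_flow = flow(denoised, signal_levels + half_step[:, None], step_sizes_log2_minus_one)   # b''

    # the consistency target for the full step is the average of the two half-step flows
    return (first_flow + second_flow) / 2

# usage with a made-up linear "flow"
flow = lambda x, signal, step_log2: -x
target = two_half_step_target(flow, torch.randn(3, 8), torch.zeros(3, 1), torch.full((3,), -2.))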
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.17"
+version = "0.0.18"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }