correct a misunderstanding: past actions now get a separate action token, while the agent token is used for predicting the next action, rewards, and values
parent 9c78962736
commit 6dbdc3d7d8
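In other words, the previous action embedding now rides along as its own token in the per-time-step sequence, while the agent token stays a clean readout for next-action, reward and value prediction. A minimal sketch of the packed layout after this change (shapes and token counts are illustrative, not the repository defaults):

```python
import torch
from einops import pack, unpack

b, t, d = 2, 4, 8                          # batch, time, model dimension (illustrative)

flow_token    = torch.randn(b, t, 1, d)    # signal level + step size embedding
space_tokens  = torch.randn(b, t, 16, d)   # latent spatial tokens
registers     = torch.randn(b, t, 2, d)    # register tokens
action_tokens = torch.randn(b, t, 1, d)    # embeds the *past* action - the new separate token
agent_tokens  = torch.randn(b, t, 1, d)    # readout for next action / reward / value heads

# packed in the order used by the updated get_prediction below
tokens, ps = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')
print(tokens.shape)   # torch.Size([2, 4, 21, 8])

# ... transformer attends over `tokens` across space and time ...

flow_token, space_tokens, register_tokens, action_tokens, agent_tokens = unpack(tokens, ps, 'b t * d')
print(agent_tokens.shape)  # torch.Size([2, 4, 1, 8])
```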
@@ -87,6 +87,9 @@ def is_power_two(num):
 # tensor helpers
 
+def is_empty(t):
+    return t.numel() == 0
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
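The new `is_empty` helper is what later lets the model represent "no actions supplied" as a zero-width token segment. A small usage sketch (tensor shapes are made up for illustration):

```python
import torch

def is_empty(t):
    return t.numel() == 0

agent_tokens = torch.randn(2, 4, 1, 8)

action_tokens = agent_tokens[:, :, 0:0]    # zero-length slice along the token axis
print(is_empty(action_tokens))             # True  -> num_action_tokens = 0 further down
print(is_empty(agent_tokens))              # False
```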
@@ -1944,20 +1947,6 @@ class DynamicsWorldModel(Module):
         agent_tokens = repeat(agent_tokens, 'b ... d -> b t ... d', t = time)
 
-        # maybe add the action embed to the agent tokens per time step
-
-        if exists(discrete_actions) or exists(continuous_actions):
-            assert self.action_embedder.has_actions
-
-            action_embed = self.action_embedder(
-                discrete_actions = discrete_actions,
-                discrete_action_types = discrete_action_types,
-                continuous_actions = continuous_actions,
-                continuous_action_types = continuous_action_types
-            )
-
-            agent_tokens = einx.add('b t ... d, b t d', agent_tokens, action_embed)
-
         # maybe add a reward embedding to agent tokens
 
         if exists(rewards):
@@ -1975,9 +1964,23 @@ class DynamicsWorldModel(Module):
             agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds)
 
+        # maybe create the action tokens
+
+        if exists(discrete_actions) or exists(continuous_actions):
+            assert self.action_embedder.has_actions
+
+            action_tokens = self.action_embedder(
+                discrete_actions = discrete_actions,
+                discrete_action_types = discrete_action_types,
+                continuous_actions = continuous_actions,
+                continuous_action_types = continuous_action_types
+            )
+        else:
+            action_tokens = agent_tokens[:, :, 0:0] # else empty off agent tokens
+
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = False):
+        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = False):
 
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
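Note the `else` branch: `agent_tokens[:, :, 0:0]` is a zero-length slice along the token axis, so the later `pack` call can always include an action segment regardless of whether actions were passed in. A rough sketch of that behaviour (shapes are illustrative, and it leans on `pack`/`unpack` accepting a zero-width segment, just as the diff itself does):

```python
import torch
from einops import pack, unpack

b, t, d = 2, 3, 8
space_tokens = torch.randn(b, t, 5, d)
agent_tokens = torch.randn(b, t, 1, d)

# no actions passed in -> empty placeholder, zero tokens wide
action_tokens = agent_tokens[:, :, 0:0]

tokens, ps = pack([space_tokens, action_tokens, agent_tokens], 'b t * d')
print(tokens.shape)          # torch.Size([2, 3, 6, 8]) - the empty segment adds nothing

_, action_tokens, _ = unpack(tokens, ps, 'b t * d')
print(action_tokens.shape)   # torch.Size([2, 3, 0, 8])
```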
@@ -1986,6 +1989,10 @@ class DynamicsWorldModel(Module):
             num_spatial_tokens = space_tokens.shape[-2]
 
+            # action tokens
+
+            num_action_tokens = 1 if not is_empty(action_tokens) else 0
+
             # pack to tokens
             # [signal + step size embed] [latent space tokens] [register] [actions / agent]
@@ -2003,7 +2010,7 @@ class DynamicsWorldModel(Module):
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d')
 
             # attend functions for space and time
@@ -2015,6 +2022,7 @@ class DynamicsWorldModel(Module):
             space_seq_len = (
                 + 1 # signal + step
+                + num_action_tokens # past action tokens - todo: account for multi-agent
                 + self.num_agents # action / agent tokens
                 + self.num_register_tokens
                 + num_spatial_tokens
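With the extra past-action token, the per-time-step sequence length seen by spatial attention grows by `num_action_tokens`. A quick tally under assumed, illustrative sizes:

```python
# assumed, illustrative sizes - not the repository defaults
num_spatial_tokens  = 16
num_action_tokens   = 1   # past action token (0 when no actions are passed)
num_agents          = 1
num_register_tokens = 2

space_seq_len = (
    + 1                   # signal + step size embed
    + num_action_tokens   # past action token
    + num_agents          # agent token(s)
    + num_register_tokens
    + num_spatial_tokens
)
print(space_seq_len)  # 21
```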
@@ -2056,7 +2064,7 @@ class DynamicsWorldModel(Module):
             # unpack
 
-            flow_token, space_tokens, register_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, register_tokens, action_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
@@ -2071,7 +2079,7 @@ class DynamicsWorldModel(Module):
         # forward the network
 
-        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, agent_tokens, return_agent_tokens = True)
+        pred, encoded_agent_tokens = get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, agent_tokens, return_agent_tokens = True)
 
         if return_pred_only:
             if not return_agent_tokens:
@@ -2108,7 +2116,7 @@ class DynamicsWorldModel(Module):
         step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
         half_step_size = 2 ** step_sizes_log2_minus_one
 
-        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, agent_tokens)
+        first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one, action_tokens, agent_tokens)
 
         # first derive b'
@@ -2127,7 +2135,7 @@ class DynamicsWorldModel(Module):
         # get second prediction for b''
 
         signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, agent_tokens)
+        second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one, action_tokens, agent_tokens)
 
         if is_v_space_pred:
             second_step_pred_flow = second_step_pred
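The surrounding shortcut-training logic compares one prediction at step size d against two chained predictions at step size d / 2, hence the exponent bookkeeping in the context lines above. A quick check of that arithmetic:

```python
import torch

step_sizes_log2 = torch.tensor([3, 4])            # step sizes d of 8 and 16, as powers of two
step_size       = 2 ** step_sizes_log2            # tensor([ 8, 16])

step_sizes_log2_minus_one = step_sizes_log2 - 1
half_step_size = 2 ** step_sizes_log2_minus_one   # tensor([4, 8]), i.e. d / 2 for each entry
```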
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.17"
+version = "0.0.18"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }