diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index b156843..18e22f8 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1548,7 +1548,7 @@ class DynamicsWorldModel(Module): reward_embeds = pad_at_dim(reward_embeds, (1, -pop_last_reward), dim = -2, value = 0.) # shift as each agent token predicts the next reward - agent_tokens = einx.add('b t ... d, b t', agent_tokens, reward_embeds) + agent_tokens = einx.add('b t ... d, b t d', agent_tokens, reward_embeds) # main function, needs to be defined as such for shortcut training - additional calls for consistency loss @@ -1589,8 +1589,8 @@ class DynamicsWorldModel(Module): attend_kwargs = dict(use_flex = use_flex, softclamp_value = self.attn_softclamp_value, device = device) space_seq_len = ( - 1 # action / agent token - + 1 # signal + step + + 1 # signal + step + + self.num_agents # action / agent tokens + self.num_register_tokens + num_spatial_tokens ) diff --git a/pyproject.toml b/pyproject.toml index 65e3543..82c9f58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.6" +version = "0.0.7" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }