diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 2d11d75..3727f64 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2032,7 +2032,7 @@ class DynamicsWorldModel(Module): # pack to tokens for attending - tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, agent_tokens], 'b t * d') + tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d') # attend functions for space and time @@ -2087,7 +2087,7 @@ class DynamicsWorldModel(Module): # unpack - flow_token, space_tokens, register_tokens, action_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d') + flow_token, space_tokens, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d') # pooling diff --git a/pyproject.toml b/pyproject.toml index 11fc5ad..299d863 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.21" +version = "0.0.22" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }