From d28251e9f9b3a3b3cf16aa4c10ce44a0d649d4c1 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 14 Oct 2025 11:10:26 -0700 Subject: [PATCH] another consideration before knocking out the RL logic --- dreamer4/dreamer4.py | 9 +++++++-- pyproject.toml | 2 +- tests/test_dreamer.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 3db4e7d..1c4a988 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1817,8 +1817,8 @@ class DynamicsWorldModel(Module): latent_gene_ids = None, # (b) tasks = None, # (b) rewards = None, # (b t) - discrete_actions = None, # (b t na) - continuous_actions = None, # (b t na) + discrete_actions = None, # (b t na) | (b t-1 na) + continuous_actions = None, # (b t na) | (b t-1 na) discrete_action_types = None, # (na) continuous_action_types = None, # (na) return_pred_only = False, @@ -1980,6 +1980,11 @@ class DynamicsWorldModel(Module): continuous_action_types = continuous_action_types ) + # handle first timestep not having an associated past action + + if action_tokens.shape[1] == (time - 1): + action_tokens = pad_at_dim(action_tokens, (1, 0), value = 0. , dim = 1) + action_tokens = add('1 d, b t d', self.action_learned_embed, action_tokens) else: diff --git a/pyproject.toml b/pyproject.toml index 9f40547..343fb20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.19" +version = "0.0.20" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index aa2b39f..484470d 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -88,7 +88,7 @@ def test_e2e( actions = None if condition_on_actions: - actions = torch.randint(0, 4, (2, 4, 1)) + actions = torch.randint(0, 4, (2, 3, 1)) flow_loss = dynamics( **dynamics_input,