From 46432aee9b581822f7b89e2a215f1af02b5b7882 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 25 Oct 2025 12:30:08 -0700 Subject: [PATCH] fix an issue with bc --- dreamer4/dreamer4.py | 13 +++++++++++-- pyproject.toml | 2 +- tests/test_dreamer.py | 8 ++++---- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index a3cc903..8b507b7 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -284,7 +284,7 @@ def create_multi_token_prediction_targets( batch, seq_len, device = *t.shape[:2], t.device batch_arange = arange(batch, device = device) - seq_arange = arange(seq_len, device = device)[1:] + seq_arange = arange(seq_len, device = device) steps_arange = arange(steps_future, device = device) indices = add('t, steps -> t steps', seq_arange, steps_arange) @@ -3100,7 +3100,7 @@ class DynamicsWorldModel(Module): reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp') - reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len) + reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len) reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp') @@ -3126,6 +3126,15 @@ class DynamicsWorldModel(Module): ): assert self.action_embedder.has_actions + # handle actions having time vs time - 1 length + # remove the first action if it is equal to time (as it would come from some agent token in the past) + + if exists(discrete_actions) and discrete_actions.shape[1] == time: + discrete_actions = discrete_actions[:, 1:] + + if exists(continuous_actions) and continuous_actions.shape[1] == time: + continuous_actions = continuous_actions[:, 1:] + # only for 1 agent agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d') diff --git a/pyproject.toml b/pyproject.toml index 8981fb6..6622e4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.75" +version = "0.0.76" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index ae1f679..3a198f4 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -407,14 +407,14 @@ def test_mtp(): reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead - assert reward_targets.shape == (3, 15, 3) - assert mask.shape == (3, 15, 3) + assert reward_targets.shape == (3, 16, 3) + assert mask.shape == (3, 16, 3) actions = torch.randint(0, 10, (3, 16, 2)) action_targets, mask = create_multi_token_prediction_targets(actions, 3) - assert action_targets.shape == (3, 15, 3, 2) - assert mask.shape == (3, 15, 3) + assert action_targets.shape == (3, 16, 3, 2) + assert mask.shape == (3, 16, 3) from dreamer4.dreamer4 import ActionEmbedder