From 789f091c63e2adf649d5730225575f387003ed90 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Tue, 28 Oct 2025 08:04:48 -0700 Subject: [PATCH] redo so that max timesteps is treated as truncation at the last timestep, then allow for accepting the truncation signal from the environment and reuse same logic --- dreamer4/dreamer4.py | 38 ++++++++++++++++++++++++++++---------- dreamer4/mocks.py | 22 ++++++++++++++++------ pyproject.toml | 2 +- tests/test_dreamer.py | 5 ++++- 4 files changed, 49 insertions(+), 18 deletions(-) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 0b0beaf..91c4299 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -2133,13 +2133,18 @@ class DynamicsWorldModel(Module): # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env is_terminated = full((batch,), False, device = device) + is_truncated = full((batch,), False, device = device) + episode_lens = full((batch,), 0, device = device) # maybe time kv cache time_kv_cache = None - for i in range(max_timesteps + 1): + step_index = 0 + + while not is_terminated.all(): + step_index += 1 latents = self.video_tokenizer(video, return_latents = True) @@ -2201,10 +2206,15 @@ class DynamicsWorldModel(Module): if len(env_step_out) == 2: next_frame, reward = env_step_out - terminate = full((batch,), False) + terminated = full((batch,), False) + truncated = full((batch,), False) elif len(env_step_out) == 3: - next_frame, reward, terminate = env_step_out + next_frame, reward, terminated = env_step_out + truncated = full((batch,), False) + + elif len(env_step_out) == 4: + next_frame, reward, terminated, truncated = env_step_out # update episode lens @@ -2212,7 +2222,20 @@ class DynamicsWorldModel(Module): # update `is_terminated` - is_terminated |= terminate + # (1) - environment says it is terminated + # (2) - previous step is truncated (this step is for bootstrap value) + + is_terminated |= (terminated | is_truncated) + + # update `is_truncated` + + if step_index <= max_timesteps: + is_truncated |= truncated + + if step_index == max_timesteps: + # if the step index is at the max time step allowed, set the truncated flag, if not already terminated + + is_truncated |= ~is_terminated # batch and time dimension @@ -2228,11 +2251,6 @@ class DynamicsWorldModel(Module): video = cat((video, next_frame), dim = 2) rewards = safe_cat((rewards, reward), dim = 1) - # early break out if all terminated - - if is_terminated.all(): - break - # package up one experience for learning batch, device = latents.shape[0], latents.device @@ -2246,7 +2264,7 @@ class DynamicsWorldModel(Module): values = values, step_size = step_size, agent_index = agent_index, - is_truncated = ~is_terminated, + is_truncated = is_truncated, lens = episode_lens, is_from_world_model = False ) diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py index 09be2d4..16ebfe2 100644 --- a/dreamer4/mocks.py +++ b/dreamer4/mocks.py @@ -22,7 +22,9 @@ class MockEnv(Module): num_envs = 1, vectorized = False, terminate_after_step = None, - rand_terminate_prob = 0.05 + rand_terminate_prob = 0.05, + can_truncate = False, + rand_truncate_prob = 0.05, ): super().__init__() self.image_shape = image_shape @@ -32,12 +34,15 @@ class MockEnv(Module): self.vectorized = vectorized assert not (vectorized and num_envs == 1) - # mocking termination + # mocking termination and truncation self.can_terminate = exists(terminate_after_step) self.terminate_after_step = terminate_after_step self.rand_terminate_prob = rand_terminate_prob + self.can_truncate = can_truncate + self.rand_truncate_prob = rand_truncate_prob + self.register_buffer('_step', tensor(0)) def get_random_state(self): @@ -72,16 +77,21 @@ class MockEnv(Module): out = (state, reward) + if self.can_terminate: shape = (self.num_envs,) if self.vectorized else (1,) + valid_step = self._step > self.terminate_after_step - terminate = ( - (torch.rand(shape) < self.rand_terminate_prob) & - (self._step > self.terminate_after_step) - ) + terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step out = (*out, terminate) + # maybe truncation + + if self.can_truncate: + truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate + out = (*out, truncate) + self._step.add_(1) return out diff --git a/pyproject.toml b/pyproject.toml index 9378e94..4ce3fa3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.82" +version = "0.0.83" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index b2a850c..0be14df 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -613,10 +613,12 @@ def test_cache_generate(): @param('vectorized', (False, True)) @param('use_signed_advantage', (False, True)) @param('env_can_terminate', (False, True)) +@param('env_can_truncate', (False, True)) def test_online_rl( vectorized, use_signed_advantage, - env_can_terminate + env_can_terminate, + env_can_truncate ): from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer @@ -656,6 +658,7 @@ def test_online_rl( vectorized = vectorized, num_envs = 4, terminate_after_step = 2 if env_can_terminate else None, + can_truncate = env_can_truncate, rand_terminate_prob = 0.1 )