From 995b1f64e52929ae346311096b477e504b2aa442 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 27 Oct 2025 10:14:28 -0700
Subject: [PATCH] handle environments that return a terminate flag, also make sure episode lens are logged in vectorized env

---
 dreamer4/dreamer4.py  | 32 ++++++++++++++++++++++++++++++--
 dreamer4/mocks.py     | 39 +++++++++++++++++++++++++++++++--------
 pyproject.toml        |  2 +-
 tests/test_dreamer.py | 12 ++++++++++--
 4 files changed, 72 insertions(+), 13 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 7d0b920..0b0beaf 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -2118,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

+        batch, device = video.shape[0], video.device
+
         # accumulate

         rewards = None
@@ -2128,6 +2130,11 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None

+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache

         time_kv_cache = None
@@ -2190,7 +2197,22 @@

             # pass the sampled action to the environment and get back next state and reward

-            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminate = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminate = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            is_terminated |= terminate

             # batch and time dimension

@@ -2206,6 +2228,11 @@
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)

+            # early break out if all terminated
+
+            if is_terminated.all():
+                break
+
         # package up one experience for learning

         batch, device = latents.shape[0], latents.device
@@ -2219,7 +2246,8 @@
             values = values,
             step_size = step_size,
             agent_index = agent_index,
-            lens = full((batch,), max_timesteps + 1, device = device),
+            is_truncated = ~is_terminated,
+            lens = episode_lens,
             is_from_world_model = False
         )

diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py
index c26aefe..2fb1c50 100644
--- a/dreamer4/mocks.py
+++ b/dreamer4/mocks.py
@@ -7,6 +7,11 @@ from torch.nn import Module

 from einops import repeat

+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env

 class MockEnv(Module):
@@ -15,7 +20,9 @@
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -25,6 +32,12 @@
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)

+        # mocking termination
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
         self.register_buffer('_step', tensor(0))

     def get_random_state(self):
@@ -50,13 +63,23 @@

         reward = empty(()).uniform_(*self.reward_range)

-        if not self.vectorized:
-            return state, reward
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'

-        discrete, continuous = actions
-        assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)

-        state = repeat(state, '... -> b ...', b = self.num_envs)
-        reward = repeat(reward, ' -> b', b = self.num_envs)
+        out = (state, reward)

-        return state, reward
+        if self.can_terminate:
+            terminate = (
+                (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+                (self._step > self.terminate_after_step)
+            )
+
+            out = (*out, terminate)
+
+        self._step.add_(1)
+
+        return out
diff --git a/pyproject.toml b/pyproject.toml
index 34ca4e2..9378e94 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.81"
+version = "0.0.82"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 3a198f4..b2a850c 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -612,9 +612,11 @@ def test_cache_generate():

 @param('vectorized', (False, True))
 @param('use_signed_advantage', (False, True))
+@param('env_can_terminate', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage
+    use_signed_advantage,
+    env_can_terminate
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

@@ -649,7 +651,13 @@

     from dreamer4.mocks import MockEnv
     from dreamer4.dreamer4 import combine_experiences

-    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+    mock_env = MockEnv(
+        (256, 256),
+        vectorized = vectorized,
+        num_envs = 4,
+        terminate_after_step = 2 if env_can_terminate else None,
+        rand_terminate_prob = 0.1
+    )

     # manually
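A minimal sketch (not part of the patch) of a custom vectorized environment that exercises the new termination path. The class name `ToyTerminatingEnv` and its internals are hypothetical; the only interface assumed from this change is that `env.step(...)` may return either `(next_frame, reward)` or `(next_frame, reward, terminate)`, with `terminate` a boolean tensor holding one flag per environment.

import torch
from torch import nn

class ToyTerminatingEnv(nn.Module):
    # hypothetical vectorized env, loosely mirroring MockEnv in dreamer4/mocks.py,
    # that terminates every environment after a fixed number of steps

    def __init__(self, image_shape = (256, 256), num_envs = 4, max_episode_len = 8):
        super().__init__()
        self.image_shape = image_shape
        self.num_envs = num_envs
        self.max_episode_len = max_episode_len
        self.register_buffer('_step', torch.tensor(0))

    def step(self, actions):
        # the rollout loop passes a (discrete, continuous) tuple with one batch entry per env
        discrete_actions, continuous_actions = actions

        # next observation and reward for every environment
        next_frame = torch.rand((self.num_envs, 3, *self.image_shape))
        reward = torch.rand((self.num_envs,))

        self._step.add_(1)

        # the optional third return value opts into the new handling above:
        # terminated envs stop accruing episode length, the rollout breaks out
        # once all envs are done, and envs that never terminated are marked truncated
        terminate = torch.full((self.num_envs,), bool(self._step.item() >= self.max_episode_len))

        return next_frame, reward, terminate

Environments that keep returning only `(next_frame, reward)` (for example `MockEnv` with the default `terminate_after_step = None`) still work unchanged, since the rollout falls back to `terminate = full((batch,), False)` when the step output has two elements.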