diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index d2265e3..c1dd4aa 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1973,13 +1973,19 @@ class DynamicsWorldModel(Module):
         agent_index = 0,
         step_size = 4,
         max_timesteps = 16,
+        env_is_vectorized = False,
         use_time_kv_cache = True
     ):
         assert exists(self.video_tokenizer)

         init_frame = env.reset()

-        video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
+        # frame to video
+
+        if env_is_vectorized:
+            video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
+        else:
+            video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

         # accumulate

@@ -2053,8 +2059,12 @@ class DynamicsWorldModel(Module):

             # batch and time dimension

-            next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
-            reward = rearrange(reward, ' -> 1 1')
+            if env_is_vectorized:
+                next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
+                reward = rearrange(reward, 'b -> b 1')
+            else:
+                next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
+                reward = rearrange(reward, ' -> 1 1')

             # concat

diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py
index a207a4f..c26aefe 100644
--- a/dreamer4/mocks.py
+++ b/dreamer4/mocks.py
@@ -23,6 +23,8 @@ class MockEnv(Module):
         self.num_envs = num_envs
         self.vectorized = vectorized

+        assert not (vectorized and num_envs == 1)
+
         self.register_buffer('_step', tensor(0))

     def get_random_state(self):
@@ -33,7 +35,12 @@ class MockEnv(Module):
         seed = None
     ):
         self._step.zero_()
-        return self.get_random_state()
+        state = self.get_random_state()
+
+        if self.vectorized:
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+
+        return state

     def step(
         self,
@@ -46,9 +53,10 @@ class MockEnv(Module):
         if not self.vectorized:
             return state, reward

-        assert actions.shape[0] == self.num_envs
+        discrete, continuous = actions
+        assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'

         state = repeat(state, '... -> b ...', b = self.num_envs)
         reward = repeat(reward, ' -> b', b = self.num_envs)

-        return state, rewards
+        return state, reward
diff --git a/pyproject.toml b/pyproject.toml
index a72e080..c01c408 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.62"
+version = "0.0.65"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index f0cc357..cda0419 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -600,7 +600,8 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

-def test_online_rl():
+@param('vectorized', (False, True))
+def test_online_rl(vectorized):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

     tokenizer = VideoTokenizer(
@@ -632,9 +633,9 @@ def test_online_rl():
     )

     from dreamer4.mocks import MockEnv
-    mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
+    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)

-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)

     actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
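
Usage sketch (not part of the diff): a minimal illustration of the new vectorized path, mirroring the updated test_online_rl; world_model_and_policy is assumed to be the DynamicsWorldModel instance constructed earlier in that test, while MockEnv, interact_with_env, env_is_vectorized and learn_from_experience come from the changes above.

    from dreamer4.mocks import MockEnv

    # batched mock environment - reset() now returns frames shaped (num_envs, c, h, w)

    mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

    # flag that observations and rewards carry a leading batch dimension

    experience = world_model_and_policy.interact_with_env(
        mock_env,
        max_timesteps = 16,
        env_is_vectorized = True
    )

    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(experience)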