handle vectorized env

lucidrains 2025-10-22 11:19:44 -07:00
parent 7ecc5d03e8
commit fb3e026fe0
4 changed files with 29 additions and 10 deletions


@@ -1973,13 +1973,19 @@ class DynamicsWorldModel(Module):
        agent_index = 0,
        step_size = 4,
        max_timesteps = 16,
+       env_is_vectorized = False,
        use_time_kv_cache = True
    ):
        assert exists(self.video_tokenizer)
        init_frame = env.reset()
-       video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
+       # frame to video
+       if env_is_vectorized:
+           video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
+       else:
+           video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
        # accumulate
@@ -2053,8 +2059,12 @@ class DynamicsWorldModel(Module):
            # batch and time dimension
-           next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
-           reward = rearrange(reward, ' -> 1 1')
+           if env_is_vectorized:
+               next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
+               reward = rearrange(reward, 'b -> b 1')
+           else:
+               next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
+               reward = rearrange(reward, ' -> 1 1')
            # concat
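
In short, the single-env path adds both a batch and a time axis to each frame and reward, while the vectorized path keeps the existing environment batch and only adds the time axis. A minimal standalone sketch of the two reshapes, assuming einops and torch, a 3-channel 64x64 frame, and 4 environments (all values illustrative):

import torch
from einops import rearrange

# single env: frame is (c, vh, vw), reward is a scalar tensor
frame, reward = torch.randn(3, 64, 64), torch.tensor(1.)
video = rearrange(frame, 'c vh vw -> 1 c 1 vh vw')     # (1, 3, 1, 64, 64) - add batch and time dims
reward = rearrange(reward, ' -> 1 1')                   # (1, 1)

# vectorized env: frame is (num_envs, c, vh, vw), reward is (num_envs,)
frames, rewards = torch.randn(4, 3, 64, 64), torch.randn(4)
video = rearrange(frames, 'b c vh vw -> b c 1 vh vw')   # (4, 3, 1, 64, 64) - add only the time dim
rewards = rearrange(rewards, 'b -> b 1')                # (4, 1)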


@@ -23,6 +23,8 @@ class MockEnv(Module):
        self.num_envs = num_envs
        self.vectorized = vectorized
+       assert not (vectorized and num_envs == 1)
        self.register_buffer('_step', tensor(0))

    def get_random_state(self):
@@ -33,7 +35,12 @@ class MockEnv(Module):
        seed = None
    ):
        self._step.zero_()
-       return self.get_random_state()
+       state = self.get_random_state()
+       if self.vectorized:
+           state = repeat(state, '... -> b ...', b = self.num_envs)
+       return state

    def step(
        self,
@@ -46,9 +53,10 @@ class MockEnv(Module):
        if not self.vectorized:
            return state, reward
-       assert actions.shape[0] == self.num_envs
+       discrete, continuous = actions
+       assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
        state = repeat(state, '... -> b ...', b = self.num_envs)
        reward = repeat(reward, ' -> b', b = self.num_envs)
-       return state, rewards
+       return state, reward
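
A quick sketch of the expected behaviour of the vectorized mock env after this change, constructed the same way as in the updated test below (shapes beyond the leading env dimension depend on MockEnv internals not shown here):

from dreamer4.mocks import MockEnv

env = MockEnv((256, 256), vectorized = True, num_envs = 4)

state = env.reset()
assert state.shape[0] == 4    # reset repeats the initial state across the 4 environments

# step unpacks a (discrete, continuous) action pair, asserts the discrete actions are
# batched over the 4 environments, and returns state / reward with that same leading env dimension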


@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
-version = "0.0.62"
+version = "0.0.65"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }


@@ -600,7 +600,8 @@ def test_cache_generate():
    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

-def test_online_rl():
+@param('vectorized', (False, True))
+def test_online_rl(vectorized):
    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

    tokenizer = VideoTokenizer(
@@ -632,9 +633,9 @@ def test_online_rl():
    )

    from dreamer4.mocks import MockEnv
-   mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
+   mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
-   one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
+   one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)

    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
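
The parametrized test now exercises both code paths; outside of pytest, the equivalent loop would look roughly like this, assuming world_model_and_policy is built as in the test body above:

for vectorized in (False, True):
    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)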