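Handle vectorized environments: add an `env_is_vectorized` flag so batched frames and rewards are supported when interacting with an env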
This commit is contained in:
parent
7ecc5d03e8
commit
fb3e026fe0
@ -1973,13 +1973,19 @@ class DynamicsWorldModel(Module):
|
||||
agent_index = 0,
|
||||
step_size = 4,
|
||||
max_timesteps = 16,
|
||||
env_is_vectorized = False,
|
||||
use_time_kv_cache = True
|
||||
):
|
||||
assert exists(self.video_tokenizer)
|
||||
|
||||
init_frame = env.reset()
|
||||
|
||||
video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
|
||||
# frame to video
|
||||
|
||||
if env_is_vectorized:
|
||||
video = rearrange(init_frame, 'b c vh vw -> b c 1 vh vw')
|
||||
else:
|
||||
video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
|
||||
|
||||
# accumulate
|
||||
|
||||
@ -2053,8 +2059,12 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# batch and time dimension
|
||||
|
||||
next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
|
||||
reward = rearrange(reward, ' -> 1 1')
|
||||
if env_is_vectorized:
|
||||
next_frame = rearrange(next_frame, 'b c vh vw -> b c 1 vh vw')
|
||||
reward = rearrange(reward, 'b -> b 1')
|
||||
else:
|
||||
next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
|
||||
reward = rearrange(reward, ' -> 1 1')
|
||||
|
||||
# concat
|
||||
|
||||
|
||||
@ -23,6 +23,8 @@ class MockEnv(Module):
|
||||
|
||||
self.num_envs = num_envs
|
||||
self.vectorized = vectorized
|
||||
assert not (vectorized and num_envs == 1)
|
||||
|
||||
self.register_buffer('_step', tensor(0))
|
||||
|
||||
def get_random_state(self):
|
||||
@ -33,7 +35,12 @@ class MockEnv(Module):
|
||||
seed = None
|
||||
):
|
||||
self._step.zero_()
|
||||
return self.get_random_state()
|
||||
state = self.get_random_state()
|
||||
|
||||
if self.vectorized:
|
||||
state = repeat(state, '... -> b ...', b = self.num_envs)
|
||||
|
||||
return state
|
||||
|
||||
def step(
|
||||
self,
|
||||
@ -46,9 +53,10 @@ class MockEnv(Module):
|
||||
if not self.vectorized:
|
||||
return state, reward
|
||||
|
||||
assert actions.shape[0] == self.num_envs
|
||||
discrete, continuous = actions
|
||||
assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
|
||||
|
||||
state = repeat(state, '... -> b ...', b = self.num_envs)
|
||||
reward = repeat(reward, ' -> b', b = self.num_envs)
|
||||
|
||||
return state, rewards
|
||||
return state, reward
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.62"
|
||||
version = "0.0.65"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -600,7 +600,8 @@ def test_cache_generate():
|
||||
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
|
||||
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
|
||||
|
||||
def test_online_rl():
|
||||
@param('vectorized', (False, True))
|
||||
def test_online_rl(vectorized):
|
||||
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
|
||||
|
||||
tokenizer = VideoTokenizer(
|
||||
@ -632,9 +633,9 @@ def test_online_rl():
|
||||
)
|
||||
|
||||
from dreamer4.mocks import MockEnv
|
||||
mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
|
||||
mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
|
||||
|
||||
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
|
||||
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
|
||||
|
||||
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user