handle environments that return a terminate flag, also make sure episode lens are logged in vectorized env
parent fd1e87983b
commit 995b1f64e5
@@ -2118,6 +2118,8 @@ class DynamicsWorldModel(Module):
     else:
         video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

+    batch, device = video.shape[0], video.device
+
     # accumulate

     rewards = None
@@ -2128,6 +2130,11 @@ class DynamicsWorldModel(Module):
     values = None
     latents = None

+    # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+    is_terminated = full((batch,), False, device = device)
+    episode_lens = full((batch,), 0, device = device)
+
     # maybe time kv cache

     time_kv_cache = None
@@ -2190,7 +2197,22 @@ class DynamicsWorldModel(Module):
         # pass the sampled action to the environment and get back next state and reward

-        next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+        env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+        if len(env_step_out) == 2:
+            next_frame, reward = env_step_out
+            terminate = full((batch,), False)
+
+        elif len(env_step_out) == 3:
+            next_frame, reward, terminate = env_step_out
+
+        # update episode lens
+
+        episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+        # update `is_terminated`
+
+        is_terminated |= terminate

         # batch and time dimension
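To make the new control flow concrete, here is a minimal standalone sketch (not from the repo, toy values only) of what the hunk above does: a 2-tuple step result gets a default all-False terminate flag, episode lengths stop incrementing for envs that have already terminated, and per-env terminate flags are OR-ed into is_terminated.

import torch

batch = 3
is_terminated = torch.full((batch,), False)
episode_lens = torch.full((batch,), 0)

def normalize_env_step(env_step_out, batch):
    # accept either (next_frame, reward) or (next_frame, reward, terminate)
    if len(env_step_out) == 2:
        next_frame, reward = env_step_out
        terminate = torch.full((batch,), False)
    elif len(env_step_out) == 3:
        next_frame, reward, terminate = env_step_out
    return next_frame, reward, terminate

for step in range(5):
    # pretend only env 0 reports termination, starting at step 2
    fake_step_out = (None, torch.zeros(batch), torch.tensor([step >= 2, False, False]))

    next_frame, reward, terminate = normalize_env_step(fake_step_out, batch)

    # envs that already terminated stop accumulating episode length
    episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
    is_terminated |= terminate

print(episode_lens)   # tensor([3, 5, 5]) - env 0 counts its terminal step, then stops
print(is_terminated)  # tensor([ True, False, False])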
@@ -2206,6 +2228,11 @@ class DynamicsWorldModel(Module):
         video = cat((video, next_frame), dim = 2)
         rewards = safe_cat((rewards, reward), dim = 1)

+        # early break out if all terminated
+
+        if is_terminated.all():
+            break
+
     # package up one experience for learning

     batch, device = latents.shape[0], latents.device
@@ -2219,7 +2246,8 @@ class DynamicsWorldModel(Module):
         values = values,
         step_size = step_size,
         agent_index = agent_index,
-        lens = full((batch,), max_timesteps + 1, device = device),
+        is_truncated = ~is_terminated,
+        lens = episode_lens,
         is_from_world_model = False
     )
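A tiny worked example of the new Experience fields (made-up values for three envs): lens now carries the per-env episode length instead of the old hardcoded max_timesteps + 1, and is_truncated marks the envs that never terminated on their own.

import torch

is_terminated = torch.tensor([True, False, False])  # only env 0 terminated during the rollout
episode_lens = torch.tensor([3, 5, 5])              # env 0 stopped counting at its terminal step

is_truncated = ~is_terminated                       # tensor([False,  True,  True])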
@@ -7,6 +7,11 @@ from torch.nn import Module

 from einops import repeat

+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env

 class MockEnv(Module):
@@ -15,7 +20,9 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -25,6 +32,12 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)

+        # mocking termination
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
         self.register_buffer('_step', tensor(0))

     def get_random_state(self):
@@ -50,13 +63,23 @@ class MockEnv(Module):
         reward = empty(()).uniform_(*self.reward_range)

-        if not self.vectorized:
-            return state, reward
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'

-        discrete, continuous = actions
-        assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)

-        state = repeat(state, '... -> b ...', b = self.num_envs)
-        reward = repeat(reward, ' -> b', b = self.num_envs)
+        out = (state, reward)

-        return state, reward
+        if self.can_terminate:
+            terminate = (
+                (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+                (self._step > self.terminate_after_step)
+            )
+
+            out = (*out, terminate)
+
+        self._step.add_(1)
+
+        return out
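A quick usage sketch of the updated mock (hedged; the action shapes below are assumptions, only the constructor arguments and the tuple-style step call are taken from the diff): once terminate_after_step is set, step() appends a per-env terminate flag, which only has a chance of being True after the internal step counter passes terminate_after_step.

import torch
from dreamer4.mocks import MockEnv

mock_env = MockEnv(
    (256, 256),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,
    rand_terminate_prob = 0.1
)

# hypothetical action batch - one discrete and one continuous action per env
discrete_actions = torch.randint(0, 4, (4,))
continuous_actions = torch.randn(4, 1)

state, reward, terminate = mock_env.step((discrete_actions, continuous_actions))

print(terminate.shape)  # torch.Size([4]) - one boolean per env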
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.81"
+version = "0.0.82"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -612,9 +612,11 @@ def test_cache_generate():

 @param('vectorized', (False, True))
 @param('use_signed_advantage', (False, True))
+@param('env_can_terminate', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage
+    use_signed_advantage,
+    env_can_terminate
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
@@ -649,7 +651,13 @@ def test_online_rl(
     from dreamer4.mocks import MockEnv
     from dreamer4.dreamer4 import combine_experiences

-    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+    mock_env = MockEnv(
+        (256, 256),
+        vectorized = vectorized,
+        num_envs = 4,
+        terminate_after_step = 2 if env_can_terminate else None,
+        rand_terminate_prob = 0.1
+    )

     # manually
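Finally, a hedged sketch (illustrative names and shapes, not from the repo) of what a real environment needs in order to opt into this behavior: since the rollout loop in DynamicsWorldModel now accepts either a 2-tuple or a 3-tuple from env.step, an env can simply return a third boolean tensor of shape (num_envs,) when it knows about termination.

import torch
from torch.nn import Module

class MyVectorizedEnv(Module):
    # illustrative environment, not part of the repo
    def __init__(self, num_envs, image_shape, max_episode_len = 100):
        super().__init__()
        self.num_envs = num_envs
        self.image_shape = image_shape
        self.max_episode_len = max_episode_len
        self.register_buffer('steps', torch.zeros(num_envs, dtype = torch.long))

    def step(self, actions):
        discrete, continuous = actions

        # shapes are placeholders - match whatever the world model expects from your env
        next_frame = torch.rand(self.num_envs, 3, *self.image_shape)
        reward = torch.randn(self.num_envs)

        self.steps += 1
        terminate = self.steps >= self.max_episode_len  # one boolean per env

        # returning the optional third element opts into termination handling
        return next_frame, reward, terminate

env = MyVectorizedEnv(num_envs = 4, image_shape = (64, 64))
frame, reward, terminate = env.step((torch.zeros(4, dtype = torch.long), torch.zeros(4, 1)))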