handle environments that return a terminate flag; also make sure episode lengths are logged in the vectorized env

lucidrains 2025-10-27 10:14:28 -07:00
parent fd1e87983b
commit 995b1f64e5
4 changed files with 72 additions and 13 deletions
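
In short: `env.step(...)` may now return either `(next_frame, reward)` or `(next_frame, reward, terminate)`. A per-environment termination mask is kept sticky, episode lengths stop counting once an env has terminated, the rollout breaks early when every env is done, and the packaged Experience records `is_truncated = ~is_terminated` along with the true per-env `lens`. A minimal standalone sketch of that bookkeeping (the `accumulate_termination` helper and the fake step outputs are illustrative, not code from the repo):

import torch

def accumulate_termination(step_outputs, num_envs):
    # step_outputs: iterable of per-step `terminate` tensors, or None when the
    # environment only returns (next_frame, reward)
    is_terminated = torch.full((num_envs,), False)
    episode_lens = torch.zeros(num_envs, dtype = torch.long)

    for terminate in step_outputs:
        if terminate is None:
            terminate = torch.full((num_envs,), False)  # env never terminates

        # envs that already terminated stop accumulating length
        episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)

        # termination is sticky
        is_terminated |= terminate

        # early break once every env is done, mirroring the rollout loop below
        if is_terminated.all():
            break

    return episode_lens, is_terminated

# env 1 terminates on the first step, env 0 on the second -> lens [2, 1], both terminated
lens, done = accumulate_termination([
    torch.tensor([False, True]),
    torch.tensor([True, False])
], num_envs = 2)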

dreamer4/dreamer4.py

@@ -2118,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

+        batch, device = video.shape[0], video.device
+
         # accumulate

         rewards = None
@@ -2128,6 +2130,11 @@
         values = None
         latents = None

+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache

         time_kv_cache = None
@@ -2190,7 +2197,22 @@
             # pass the sampled action to the environment and get back next state and reward

-            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminate = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminate = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            is_terminated |= terminate

             # batch and time dimension
@@ -2206,6 +2228,11 @@
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)

+            # early break out if all terminated
+
+            if is_terminated.all():
+                break
+
         # package up one experience for learning

         batch, device = latents.shape[0], latents.device
@@ -2219,7 +2246,8 @@
             values = values,
             step_size = step_size,
             agent_index = agent_index,
-            lens = full((batch,), max_timesteps + 1, device = device),
+            is_truncated = ~is_terminated,
+            lens = episode_lens,
             is_from_world_model = False
         )
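
Note on the two changed fields above: `lens` now records how many steps each environment actually ran (instead of a constant `max_timesteps + 1`), and `is_truncated` marks environments that never reported termination, i.e. rollouts cut off by the step limit rather than ended by the env. Consumers of variable-length rollouts typically turn such lengths into a boolean mask over the time dimension; a generic sketch of that conversion (not taken from the repo):

import torch

# turn per-env episode lengths into a (batch, time) bool mask,
# the usual way variable-length rollouts get masked in loss computations
def lens_to_mask(lens, max_time):
    times = torch.arange(max_time, device = lens.device)
    return times[None, :] < lens[:, None]

mask = lens_to_mask(torch.tensor([3, 1, 5]), max_time = 5)
# mask[0] -> [True, True, True, False, False]
# mask[1] -> [True, False, False, False, False]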

dreamer4/mocks.py

@@ -7,6 +7,11 @@ from torch.nn import Module
 from einops import repeat

+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env

 class MockEnv(Module):
@@ -15,7 +20,9 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -25,6 +32,12 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)

+        # mocking termination
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+        self.register_buffer('_step', tensor(0))

     def get_random_state(self):
@@ -50,13 +63,23 @@
         reward = empty(()).uniform_(*self.reward_range)

         if not self.vectorized:
             return state, reward

-        if self.vectorized:
-            discrete, continuous = actions
-            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+        discrete, continuous = actions
+        assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'

-            state = repeat(state, '... -> b ...', b = self.num_envs)
-            reward = repeat(reward, ' -> b', b = self.num_envs)
+        state = repeat(state, '... -> b ...', b = self.num_envs)
+        reward = repeat(reward, ' -> b', b = self.num_envs)

-        return state, reward
+        out = (state, reward)
+
+        if self.can_terminate:
+            terminate = (
+                (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+                (self._step > self.terminate_after_step)
+            )
+            out = (*out, terminate)
+
+        self._step.add_(1)
+
+        return out
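
A hedged usage sketch for the updated MockEnv, using only the constructor arguments and step signature visible in the diff above; the action shapes are placeholders, since the mock only asserts that the leading dimension equals num_envs:

import torch
from dreamer4.mocks import MockEnv

# termination is only mocked on the vectorized path, and only after `terminate_after_step` steps
env = MockEnv(
    (256, 256),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,
    rand_terminate_prob = 0.1
)

# placeholder actions - only the leading (num_envs) dimension is checked by the mock
discrete_actions = torch.randint(0, 4, (4,))
continuous_actions = torch.randn(4, 2)

out = env.step((discrete_actions, continuous_actions))

if len(out) == 3:
    next_frame, reward, terminate = out   # new three-tuple signature with a per-env terminate flag
else:
    next_frame, reward = out              # plain (state, reward) signature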

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.81"
+version = "0.0.82"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_dreamer4.py

@@ -612,9 +612,11 @@ def test_cache_generate():

 @param('vectorized', (False, True))
 @param('use_signed_advantage', (False, True))
+@param('env_can_terminate', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage
+    use_signed_advantage,
+    env_can_terminate
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
@@ -649,7 +651,13 @@
     from dreamer4.mocks import MockEnv
     from dreamer4.dreamer4 import combine_experiences

-    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+    mock_env = MockEnv(
+        (256, 256),
+        vectorized = vectorized,
+        num_envs = 4,
+        terminate_after_step = 2 if env_can_terminate else None,
+        rand_terminate_prob = 0.1
+    )

     # manually