redo so that max timesteps is treated as truncation at the last timestep, then allow for accepting the truncation signal from the environment and reuse same logic

This commit is contained in:
lucidrains 2025-10-28 08:04:48 -07:00
parent 41ab83f691
commit 789f091c63
4 changed files with 49 additions and 18 deletions

View File

@ -2133,13 +2133,18 @@ class DynamicsWorldModel(Module):
# keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
is_terminated = full((batch,), False, device = device)
is_truncated = full((batch,), False, device = device)
episode_lens = full((batch,), 0, device = device)
# maybe time kv cache
time_kv_cache = None
for i in range(max_timesteps + 1):
step_index = 0
while not is_terminated.all():
step_index += 1
latents = self.video_tokenizer(video, return_latents = True)
@ -2201,10 +2206,15 @@ class DynamicsWorldModel(Module):
if len(env_step_out) == 2:
next_frame, reward = env_step_out
terminate = full((batch,), False)
terminated = full((batch,), False)
truncated = full((batch,), False)
elif len(env_step_out) == 3:
next_frame, reward, terminate = env_step_out
next_frame, reward, terminated = env_step_out
truncated = full((batch,), False)
elif len(env_step_out) == 4:
next_frame, reward, terminated, truncated = env_step_out
# update episode lens
@ -2212,7 +2222,20 @@ class DynamicsWorldModel(Module):
# update `is_terminated`
is_terminated |= terminate
# (1) - environment says it is terminated
# (2) - previous step is truncated (this step is for bootstrap value)
is_terminated |= (terminated | is_truncated)
# update `is_truncated`
if step_index <= max_timesteps:
is_truncated |= truncated
if step_index == max_timesteps:
# if the step index is at the max time step allowed, set the truncated flag, if not already terminated
is_truncated |= ~is_terminated
# batch and time dimension
@ -2228,11 +2251,6 @@ class DynamicsWorldModel(Module):
video = cat((video, next_frame), dim = 2)
rewards = safe_cat((rewards, reward), dim = 1)
# early break out if all terminated
if is_terminated.all():
break
# package up one experience for learning
batch, device = latents.shape[0], latents.device
@ -2246,7 +2264,7 @@ class DynamicsWorldModel(Module):
values = values,
step_size = step_size,
agent_index = agent_index,
is_truncated = ~is_terminated,
is_truncated = is_truncated,
lens = episode_lens,
is_from_world_model = False
)

View File

@ -22,7 +22,9 @@ class MockEnv(Module):
num_envs = 1,
vectorized = False,
terminate_after_step = None,
rand_terminate_prob = 0.05
rand_terminate_prob = 0.05,
can_truncate = False,
rand_truncate_prob = 0.05,
):
super().__init__()
self.image_shape = image_shape
@ -32,12 +34,15 @@ class MockEnv(Module):
self.vectorized = vectorized
assert not (vectorized and num_envs == 1)
# mocking termination
# mocking termination and truncation
self.can_terminate = exists(terminate_after_step)
self.terminate_after_step = terminate_after_step
self.rand_terminate_prob = rand_terminate_prob
self.can_truncate = can_truncate
self.rand_truncate_prob = rand_truncate_prob
self.register_buffer('_step', tensor(0))
def get_random_state(self):
@ -72,16 +77,21 @@ class MockEnv(Module):
out = (state, reward)
if self.can_terminate:
shape = (self.num_envs,) if self.vectorized else (1,)
valid_step = self._step > self.terminate_after_step
terminate = (
(torch.rand(shape) < self.rand_terminate_prob) &
(self._step > self.terminate_after_step)
)
terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
out = (*out, terminate)
# maybe truncation
if self.can_truncate:
truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
out = (*out, truncate)
self._step.add_(1)
return out

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.82"
version = "0.0.83"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -613,10 +613,12 @@ def test_cache_generate():
@param('vectorized', (False, True))
@param('use_signed_advantage', (False, True))
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
def test_online_rl(
vectorized,
use_signed_advantage,
env_can_terminate
env_can_terminate,
env_can_truncate
):
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
@ -656,6 +658,7 @@ def test_online_rl(
vectorized = vectorized,
num_envs = 4,
terminate_after_step = 2 if env_can_terminate else None,
can_truncate = env_can_truncate,
rand_terminate_prob = 0.1
)