Redo the rollout so that reaching max timesteps is treated as truncation at the last timestep, then also accept the truncation signal from the environment and reuse the same logic
This commit is contained in:
parent
41ab83f691
commit
789f091c63
@ -2133,13 +2133,18 @@ class DynamicsWorldModel(Module):
|
||||
# keep track of termination and truncation - `is_truncated` sets the field of the same name on Experience, `is_terminated` is for early stopping interaction with env
|
||||
|
||||
is_terminated = full((batch,), False, device = device)
|
||||
is_truncated = full((batch,), False, device = device)
|
||||
|
||||
episode_lens = full((batch,), 0, device = device)
|
||||
|
||||
# maybe time kv cache
|
||||
|
||||
time_kv_cache = None
|
||||
|
||||
for i in range(max_timesteps + 1):
|
||||
step_index = 0
|
||||
|
||||
while not is_terminated.all():
|
||||
step_index += 1
|
||||
|
||||
latents = self.video_tokenizer(video, return_latents = True)
|
||||
|
||||
@ -2201,10 +2206,15 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
if len(env_step_out) == 2:
|
||||
next_frame, reward = env_step_out
|
||||
terminate = full((batch,), False)
|
||||
terminated = full((batch,), False)
|
||||
truncated = full((batch,), False)
|
||||
|
||||
elif len(env_step_out) == 3:
|
||||
next_frame, reward, terminate = env_step_out
|
||||
next_frame, reward, terminated = env_step_out
|
||||
truncated = full((batch,), False)
|
||||
|
||||
elif len(env_step_out) == 4:
|
||||
next_frame, reward, terminated, truncated = env_step_out
|
||||
|
||||
# update episode lens
|
||||
|
||||
@ -2212,7 +2222,20 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
# update `is_terminated`
|
||||
|
||||
is_terminated |= terminate
|
||||
# (1) - environment says it is terminated
|
||||
# (2) - the previous step was truncated (this step is only for the bootstrap value)
|
||||
|
||||
is_terminated |= (terminated | is_truncated)
|
||||
|
||||
# update `is_truncated`
|
||||
|
||||
if step_index <= max_timesteps:
|
||||
is_truncated |= truncated
|
||||
|
||||
if step_index == max_timesteps:
|
||||
# if the step index is at the max time step allowed, set the truncated flag, if not already terminated
|
||||
|
||||
is_truncated |= ~is_terminated
|
||||
|
||||
# batch and time dimension
|
||||
|
||||
@ -2228,11 +2251,6 @@ class DynamicsWorldModel(Module):
|
||||
video = cat((video, next_frame), dim = 2)
|
||||
rewards = safe_cat((rewards, reward), dim = 1)
|
||||
|
||||
# early break out if all terminated
|
||||
|
||||
if is_terminated.all():
|
||||
break
|
||||
|
||||
# package up one experience for learning
|
||||
|
||||
batch, device = latents.shape[0], latents.device
|
||||
@ -2246,7 +2264,7 @@ class DynamicsWorldModel(Module):
|
||||
values = values,
|
||||
step_size = step_size,
|
||||
agent_index = agent_index,
|
||||
is_truncated = ~is_terminated,
|
||||
is_truncated = is_truncated,
|
||||
lens = episode_lens,
|
||||
is_from_world_model = False
|
||||
)
|
||||
|
||||
@ -22,7 +22,9 @@ class MockEnv(Module):
|
||||
num_envs = 1,
|
||||
vectorized = False,
|
||||
terminate_after_step = None,
|
||||
rand_terminate_prob = 0.05
|
||||
rand_terminate_prob = 0.05,
|
||||
can_truncate = False,
|
||||
rand_truncate_prob = 0.05,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_shape = image_shape
|
||||
@ -32,12 +34,15 @@ class MockEnv(Module):
|
||||
self.vectorized = vectorized
|
||||
assert not (vectorized and num_envs == 1)
|
||||
|
||||
# mocking termination
|
||||
# mocking termination and truncation
|
||||
|
||||
self.can_terminate = exists(terminate_after_step)
|
||||
self.terminate_after_step = terminate_after_step
|
||||
self.rand_terminate_prob = rand_terminate_prob
|
||||
|
||||
self.can_truncate = can_truncate
|
||||
self.rand_truncate_prob = rand_truncate_prob
|
||||
|
||||
self.register_buffer('_step', tensor(0))
|
||||
|
||||
def get_random_state(self):
|
||||
@ -72,16 +77,21 @@ class MockEnv(Module):
|
||||
|
||||
out = (state, reward)
|
||||
|
||||
|
||||
if self.can_terminate:
|
||||
shape = (self.num_envs,) if self.vectorized else (1,)
|
||||
valid_step = self._step > self.terminate_after_step
|
||||
|
||||
terminate = (
|
||||
(torch.rand(shape) < self.rand_terminate_prob) &
|
||||
(self._step > self.terminate_after_step)
|
||||
)
|
||||
terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
|
||||
|
||||
out = (*out, terminate)
|
||||
|
||||
# maybe truncation
|
||||
|
||||
if self.can_truncate:
|
||||
truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
|
||||
out = (*out, truncate)
|
||||
|
||||
self._step.add_(1)
|
||||
|
||||
return out
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dreamer4"
|
||||
version = "0.0.82"
|
||||
version = "0.0.83"
|
||||
description = "Dreamer 4"
|
||||
authors = [
|
||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||
|
||||
@ -613,10 +613,12 @@ def test_cache_generate():
|
||||
@param('vectorized', (False, True))
|
||||
@param('use_signed_advantage', (False, True))
|
||||
@param('env_can_terminate', (False, True))
|
||||
@param('env_can_truncate', (False, True))
|
||||
def test_online_rl(
|
||||
vectorized,
|
||||
use_signed_advantage,
|
||||
env_can_terminate
|
||||
env_can_terminate,
|
||||
env_can_truncate
|
||||
):
|
||||
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
|
||||
|
||||
@ -656,6 +658,7 @@ def test_online_rl(
|
||||
vectorized = vectorized,
|
||||
num_envs = 4,
|
||||
terminate_after_step = 2 if env_can_terminate else None,
|
||||
can_truncate = env_can_truncate,
|
||||
rand_terminate_prob = 0.1
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user