handle environments that return a terminate flag, also make sure episode lens are logged in vectorized env
commit 995b1f64e5
parent fd1e87983b
@@ -2118,6 +2118,8 @@ class DynamicsWorldModel(Module):
         else:
             video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
 
+        batch, device = video.shape[0], video.device
+
         # accumulate
 
         rewards = None
@@ -2128,6 +2130,11 @@ class DynamicsWorldModel(Module):
         values = None
         latents = None
 
+        # keep track of termination, for setting the `is_truncated` field on Experience and for early stopping interaction with env
+
+        is_terminated = full((batch,), False, device = device)
+        episode_lens = full((batch,), 0, device = device)
+
         # maybe time kv cache
 
         time_kv_cache = None
@@ -2190,7 +2197,22 @@ class DynamicsWorldModel(Module):
 
             # pass the sampled action to the environment and get back next state and reward
 
-            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+            env_step_out = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            if len(env_step_out) == 2:
+                next_frame, reward = env_step_out
+                terminate = full((batch,), False)
+
+            elif len(env_step_out) == 3:
+                next_frame, reward, terminate = env_step_out
+
+            # update episode lens
+
+            episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)
+
+            # update `is_terminated`
+
+            is_terminated |= terminate
 
             # batch and time dimension
 
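A minimal standalone sketch of the step-output handling above; `env`, `actions`, and the tensor shapes are stand-ins, but the 2-tuple/3-tuple dispatch and the termination and length bookkeeping mirror the diff:

import torch
from torch import full

def step_and_track(env, actions, is_terminated, episode_lens):
    # accept envs that return (frame, reward) as well as (frame, reward, terminate)
    env_step_out = env.step(actions)

    if len(env_step_out) == 2:
        next_frame, reward = env_step_out
        terminate = full(is_terminated.shape, False)  # a 2-tuple env never terminates
    elif len(env_step_out) == 3:
        next_frame, reward, terminate = env_step_out

    # envs that already terminated stop accruing episode length
    episode_lens = torch.where(is_terminated, episode_lens, episode_lens + 1)

    # once terminated, stay terminated
    is_terminated = is_terminated | terminate

    return next_frame, reward, is_terminated, episode_lens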
@@ -2206,6 +2228,11 @@ class DynamicsWorldModel(Module):
             video = cat((video, next_frame), dim = 2)
             rewards = safe_cat((rewards, reward), dim = 1)
 
+            # early break out if all terminated
+
+            if is_terminated.all():
+                break
+
         # package up one experience for learning
 
         batch, device = latents.shape[0], latents.device
@@ -2219,7 +2246,8 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
-            lens = full((batch,), max_timesteps + 1, device = device),
+            is_truncated = ~is_terminated,
+            lens = episode_lens,
             is_from_world_model = False
         )
 
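A small worked example (values illustrative) of what the two changed fields now carry: an episode the env terminated on its own is not truncated, while episodes cut short by the interaction budget are.

import torch

# 3 envs, interaction budget of 4 steps; env 0 terminated after 2 steps
is_terminated = torch.tensor([True, False, False])
episode_lens  = torch.tensor([2, 4, 4])

is_truncated = ~is_terminated
print(is_truncated)  # tensor([False,  True,  True])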
@@ -7,6 +7,11 @@ from torch.nn import Module
 
 from einops import repeat
 
+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env
 
 class MockEnv(Module):
@@ -15,7 +20,9 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -25,6 +32,12 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)
 
+        # mocking termination
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
         self.register_buffer('_step', tensor(0))
 
     def get_random_state(self):
@@ -50,13 +63,23 @@ class MockEnv(Module):
 
         reward = empty(()).uniform_(*self.reward_range)
 
-        if not self.vectorized:
-            return state, reward
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
 
-        discrete, continuous = actions
-        assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)
 
-        state = repeat(state, '... -> b ...', b = self.num_envs)
-        reward = repeat(reward, ' -> b', b = self.num_envs)
+        out = (state, reward)
 
-        return state, reward
+        if self.can_terminate:
+            terminate = (
+                (torch.rand((self.num_envs)) < self.rand_terminate_prob) &
+                (self._step > self.terminate_after_step)
+            )
+
+            out = (*out, terminate)
+
+        self._step.add_(1)
+
+        return out
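Hypothetical usage of the mocked termination, with constructor values taken from the test below; the action tensors are placeholders, since their real shapes come from the agent's sampled actions:

import torch
from dreamer4.mocks import MockEnv

env = MockEnv(
    (256, 256),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,  # termination only becomes possible after this many steps
    rand_terminate_prob = 0.1  # then each step terminates with 10% probability
)

discrete = torch.randint(0, 5, (4, 2))  # placeholder action shapes
continuous = torch.randn(4, 2)

# with termination mocked, `step` returns a 3-tuple instead of (state, reward)
state, reward, terminate = env.step((discrete, continuous))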
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.81"
+version = "0.0.82"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -612,9 +612,11 @@ def test_cache_generate():
 
 @param('vectorized', (False, True))
 @param('use_signed_advantage', (False, True))
+@param('env_can_terminate', (False, True))
 def test_online_rl(
     vectorized,
-    use_signed_advantage
+    use_signed_advantage,
+    env_can_terminate
 ):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
 
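Assuming `param` is the test module's alias for `pytest.mark.parametrize` (the usual convention in this repo's tests), the new flag doubles the grid; the stack above is equivalent to:

import pytest

param = pytest.mark.parametrize  # assumed alias

@param('vectorized', (False, True))
@param('use_signed_advantage', (False, True))
@param('env_can_terminate', (False, True))
def test_online_rl(vectorized, use_signed_advantage, env_can_terminate):
    ...  # 2 x 2 x 2 = 8 collected cases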
@@ -649,7 +651,13 @@ def test_online_rl(
     from dreamer4.mocks import MockEnv
     from dreamer4.dreamer4 import combine_experiences
 
-    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+    mock_env = MockEnv(
+        (256, 256),
+        vectorized = vectorized,
+        num_envs = 4,
+        terminate_after_step = 2 if env_can_terminate else None,
+        rand_terminate_prob = 0.1
+    )
 
     # manually
 