From d82debb7a6329046cbd942c9b0783ee176b5e705 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 22 Oct 2025 08:32:46 -0700 Subject: [PATCH] first pass through gathering experience with a mock env for online rl --- dreamer4/__init__.py | 3 +- dreamer4/dreamer4.py | 118 ++++++++++++++++++++++++++++++++++++------ dreamer4/mocks.py | 7 +-- pyproject.toml | 2 +- tests/test_dreamer.py | 45 +++++++++++++--- 5 files changed, 146 insertions(+), 29 deletions(-) diff --git a/dreamer4/__init__.py b/dreamer4/__init__.py index 82442fb..d7ce984 100644 --- a/dreamer4/__init__.py +++ b/dreamer4/__init__.py @@ -1,7 +1,6 @@ from dreamer4.dreamer4 import ( VideoTokenizer, - DynamicsWorldModel, - Dreamer + DynamicsWorldModel ) diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 64922be..41809cd 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -81,7 +81,6 @@ class Experience: step_size: int | None = None agent_index: int = 0 is_from_world_model: bool = True - is_batched: bool = True # helpers @@ -1964,6 +1963,107 @@ class DynamicsWorldModel(Module): return align_dims_left(times, align_dims_left_to) + # interacting with env for experience + + @torch.no_grad() + def interact_with_env( + self, + env, + seed = None, + agent_index = 0, + step_size = 4, + max_timesteps = 16 + ): + assert exists(self.video_tokenizer) + + init_frame = env.reset() + + video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw') + + # accumulate + + rewards = None + discrete_actions = None + continuous_actions = None + discrete_log_probs = None + continuous_log_probs = None + values = None + latents = None + + for _ in range(max_timesteps): + + latents = self.video_tokenizer(video, return_latents = True) + + _, (agent_embed, _) = self.forward( + latents = latents, + signal_levels = self.max_steps - 1, + step_sizes = step_size, + rewards = rewards, + discrete_actions = discrete_actions, + continuous_actions = continuous_actions, + latent_is_noised = True, + return_pred_only = True, + return_intermediates = True + ) + + one_agent_embed = agent_embed[..., -1:, agent_index, :] + + policy_embed = self.policy_head(one_agent_embed) + + # sample actions + + sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True) + + discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1) + continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1) + + # get the log prob and values for policy optimization + + one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs( + policy_embed, + pred_head_index = 0, + discrete_targets = sampled_discrete_actions, + continuous_targets = sampled_continuous_actions, + ) + + discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1) + continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1) + + value_bins = self.value_head(one_agent_embed) + value = self.reward_encoder.bins_to_scalar_value(value_bins) + + values = safe_cat((values, value), dim = 1) + + # pass the sampled action to the environment and get back next state and reward + + next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions)) + + # batch and time dimension + + next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw') + reward = rearrange(reward, ' -> 1 1') + + # concat + + video = cat((video, next_frame), dim = 2) + rewards = safe_cat((rewards, reward), dim = 1) + + # package up one experience 
for learning + + one_experience = Experience( + latents = latents, + video = video[:, :, :-1], + rewards = rewards, + actions = (discrete_actions, continuous_actions), + log_probs = (discrete_log_probs, continuous_log_probs), + values = values, + step_size = step_size, + agent_index = agent_index, + is_from_world_model = False + ) + + return one_experience + # ppo def learn_from_experience( @@ -1973,7 +2073,6 @@ class DynamicsWorldModel(Module): value_optim: Optimizer | None = None, only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads ): - assert experience.is_batched latents = experience.latents actions = experience.actions @@ -2758,18 +2857,3 @@ class DynamicsWorldModel(Module): return total_loss return total_loss, losses - -# dreamer - -class Dreamer(Module): - def __init__( - self, - state_tokenizer: VideoTokenizer, - world_model: DynamicsWorldModel, - ): - super().__init__() - self.state_toke = state_tokenizer - self.world_model = world_model - - def interact_with_sim(self, env) -> Experience: - raise NotImplementedError diff --git a/dreamer4/mocks.py b/dreamer4/mocks.py index 55de215..a207a4f 100644 --- a/dreamer4/mocks.py +++ b/dreamer4/mocks.py @@ -2,7 +2,7 @@ from __future__ import annotations from random import choice import torch -from torch import tensor, randn, randint +from torch import tensor, empty, randn, randint from torch.nn import Module from einops import repeat @@ -13,7 +13,7 @@ class MockEnv(Module): def __init__( self, image_shape, - reward_range = (-100., 100.), + reward_range = (-100, 100), num_envs = 1, vectorized = False ): @@ -40,7 +40,8 @@ class MockEnv(Module): actions, ): state = self.get_random_state() - reward = randint(*self.reward_range, ()).float() + + reward = empty(()).uniform_(*self.reward_range) if not self.vectorized: return state, reward diff --git a/pyproject.toml b/pyproject.toml index ca30342..8483df6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.60" +version = "0.0.61" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py index 683e981..f0cc357 100644 --- a/tests/test_dreamer.py +++ b/tests/test_dreamer.py @@ -600,10 +600,43 @@ def test_cache_generate(): generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True) generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True) -@param('vectorized', (False, True)) -def test_online_rl( - vectorized -): - from dreamer4.mocks import MockEnv +def test_online_rl(): + from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer - mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4) \ No newline at end of file + tokenizer = VideoTokenizer( + 16, + encoder_depth = 1, + decoder_depth = 1, + time_block_every = 1, + dim_latent = 16, + patch_size = 32, + attn_dim_head = 16, + num_latent_tokens = 1 + ) + + world_model_and_policy = DynamicsWorldModel( + video_tokenizer = tokenizer, + dim = 16, + dim_latent = 16, + max_steps = 64, + num_tasks = 4, + num_latent_tokens = 1, + depth = 1, + time_block_every = 1, + num_spatial_tokens = 1, + pred_orig_latent = True, + num_discrete_actions = 4, + attn_dim_head = 16, + prob_no_shortcut_train = 0.1, + num_residual_streams = 1 + ) + + from dreamer4.mocks import MockEnv + mock_env = MockEnv((256, 256), vectorized = False, 
num_envs = 4) + + one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16) + + actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience) + + actor_loss.backward() + critic_loss.backward()
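
A note on usage, not part of the patch: the sketch below shows how the pieces introduced here could be driven as a small online RL loop. Constructor arguments are copied from test_online_rl above; the choice of Adam, the learning rate, and the assumption that learn_from_experience applies the optimizers it is handed (rather than only returning losses for the caller to backward through, as the test does) are illustrative guesses, not something this commit pins down.

    from torch.optim import Adam

    from dreamer4 import VideoTokenizer, DynamicsWorldModel
    from dreamer4.mocks import MockEnv

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 1
    )

    world_model_and_policy = DynamicsWorldModel(
        video_tokenizer = tokenizer,
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    env = MockEnv((256, 256))

    # optimizers only over the two heads, since only_learn_policy_value_heads
    # defaults to True and the rest of the dynamics model stays frozen
    # (assumes policy_head / value_head expose parameters() as ordinary nn.Modules)

    policy_optim = Adam(world_model_and_policy.policy_head.parameters(), lr = 1e-4)
    value_optim = Adam(world_model_and_policy.value_head.parameters(), lr = 1e-4)

    for _ in range(10):
        # roll out the current policy against the environment (runs under no_grad)

        experience = world_model_and_policy.interact_with_env(
            env,
            step_size = 4,
            max_timesteps = 16
        )

        # one PPO-style update on the gathered experience - assumes the method
        # steps the optimizers it is given; if it only returns losses, call
        # backward() and step() manually as test_online_rl does

        world_model_and_policy.learn_from_experience(
            experience,
            policy_optim = policy_optim,
            value_optim = value_optim
        )

The loop mirrors the single rollout exercised in test_online_rl; repeating collect-then-update is the natural extension once real optimizers are attached to the policy and value heads.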