From 35c1db4c7d92d04e439ef4118fdc371a4bacfb79 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 24 Oct 2025 09:13:09 -0700
Subject: [PATCH] sketch of training from sim env

---
 dreamer4/trainers.py  | 201 +++++++++++++++++++++++++++++++++++++++++
 pyproject.toml        |   2 +-
 tests/test_dreamer.py |  14 +++
 3 files changed, 214 insertions(+), 3 deletions(-)

diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py
index 0c730da..3ebbd95 100644
--- a/dreamer4/trainers.py
+++ b/dreamer4/trainers.py
@@ -4,7 +4,7 @@ import torch
 from torch import is_tensor
 from torch.nn import Module
 from torch.optim import AdamW
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, TensorDataset, DataLoader
 
 from accelerate import Accelerator
 
@@ -12,7 +12,9 @@ from adam_atan2_pytorch import MuonAdamAtan2
 
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    Experience,
+    combine_experiences
 )
 
 from ema_pytorch import EMA
@@ -317,3 +319,198 @@ class DreamTrainer(Module):
                 self.value_head_optim.zero_grad()
 
         self.print('training complete')
+
+# training from sim
+
+class SimTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        epochs = 2,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay,
+            **optim_kwargs
+        )
+
+        # policy and value head optimizers
+        self.policy_head_optim = optim_klass(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = optim_klass(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.epochs = epochs
+        self.batch_size = batch_size
+
+        self.generate_timesteps = generate_timesteps
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def learn(
+        self,
+        experience: Experience
+    ):
+
+        step_size = experience.step_size
+        agent_index = experience.agent_index
+
+        latents = experience.latents
+        old_values = experience.values
+        rewards = experience.rewards
+
+        discrete_actions, continuous_actions = experience.actions
+        discrete_log_probs, continuous_log_probs = experience.log_probs
+
+        # handle empties
+
+        empty_tensor = torch.empty_like(rewards)
+
+        has_discrete = exists(discrete_actions)
+        has_continuous = exists(continuous_actions)
+
+        discrete_actions = default(discrete_actions, empty_tensor)
+        continuous_actions = default(continuous_actions, empty_tensor)
+
+        discrete_log_probs = default(discrete_log_probs, empty_tensor)
+        continuous_log_probs = default(continuous_log_probs, empty_tensor)
+
+        # create the dataset and dataloader
+
+        dataset = TensorDataset(
+            latents,
+            discrete_actions,
+            continuous_actions,
+            discrete_log_probs,
+            continuous_log_probs,
+            old_values,
+            rewards
+        )
+
+        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
+
+        for epoch in range(self.epochs):
+
+            for (
+                latents,
+                discrete_actions,
+                continuous_actions,
+                discrete_log_probs,
+                continuous_log_probs,
+                old_values,
+                rewards
+            ) in dataloader:
+
+                actions = (
+                    discrete_actions if has_discrete else None,
+                    continuous_actions if has_continuous else None
+                )
+
+                log_probs = (
+                    discrete_log_probs if has_discrete else None,
+                    continuous_log_probs if has_continuous else None
+                )
+
+                batch_experience = Experience(
+                    latents = latents,
+                    actions = actions,
+                    log_probs = log_probs,
+                    values = old_values,
+                    rewards = rewards,
+                    step_size = step_size,
+                    agent_index = agent_index
+                )
+
+                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
+
+                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+                # update policy head
+
+                self.accelerator.backward(policy_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
+
+                self.policy_head_optim.step()
+                self.policy_head_optim.zero_grad()
+
+                # update value head
+
+                self.accelerator.backward(value_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+                self.value_head_optim.step()
+                self.value_head_optim.zero_grad()
+
+        self.print('training complete')
+
+    def forward(
+        self,
+        env,
+        num_episodes = 50000,
+        max_experiences_before_learn = 8,
+        env_is_vectorized = False
+    ):
+
+        for _ in range(num_episodes):
+
+            total_experience = 0
+            experiences = []
+
+            while total_experience < max_experiences_before_learn:
+
+                experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
+
+                num_experience = experience.video.shape[0]
+
+                total_experience += num_experience
+
+                experiences.append(experience)
+
+            combined_experiences = combine_experiences(experiences)
+
+            self.learn(combined_experiences)
+
+            experiences.clear()
+
+        self.print('training complete')
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 04b0cd1..0d7b0af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.67"
+version = "0.0.69"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 39dbb9c..5afe5eb 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -641,6 +641,8 @@ def test_online_rl(
 
     mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
 
+    # manually
+
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
 
@@ -650,3 +652,15 @@ def test_online_rl(
 
     actor_loss.backward()
     critic_loss.backward()
+
+    # with trainer
+
+    from dreamer4.trainers import SimTrainer
+
+    trainer = SimTrainer(
+        world_model_and_policy,
+        batch_size = 4,
+        cpu = True
+    )
+
+    trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
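
SimTrainer.forward above collects rollouts by repeatedly calling interact_with_env, then merges them with combine_experiences before the PPO-style learn pass. That helper lives in dreamer4/dreamer4.py and is not part of this patch; the sketch below shows only the assumed behavior, concatenating the tensor fields of each Experience along the leading rollout dimension and carrying step_size / agent_index through as shared metadata. ExperienceSketch, combine_experiences_sketch, and maybe_cat are hypothetical names, not the library's actual implementation.

import torch
from typing import NamedTuple

# hypothetical stand-in for dreamer4.dreamer4.Experience, reduced to the fields SimTrainer touches

class ExperienceSketch(NamedTuple):
    latents: torch.Tensor
    actions: tuple      # (discrete | None, continuous | None)
    log_probs: tuple    # (discrete | None, continuous | None)
    values: torch.Tensor
    rewards: torch.Tensor
    step_size: float
    agent_index: int

def combine_experiences_sketch(experiences):
    # concatenate tensor fields along the leading (rollout) dimension,
    # keeping an optional action / log prob slot only if every experience has it

    def maybe_cat(tensors):
        if any(t is None for t in tensors):
            return None
        return torch.cat(list(tensors), dim = 0)

    first = experiences[0]

    return ExperienceSketch(
        latents = torch.cat([e.latents for e in experiences], dim = 0),
        actions = tuple(maybe_cat(ts) for ts in zip(*[e.actions for e in experiences])),
        log_probs = tuple(maybe_cat(ts) for ts in zip(*[e.log_probs for e in experiences])),
        values = torch.cat([e.values for e in experiences], dim = 0),
        rewards = torch.cat([e.rewards for e in experiences], dim = 0),
        step_size = first.step_size,     # assumed identical across rollouts
        agent_index = first.agent_index  # assumed identical across rollouts
    )

# usage: merged = combine_experiences_sketch([exp_a, exp_b]) stacks both rollouts batch-wise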
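Inside SimTrainer.learn, absent discrete or continuous actions are swapped for torch.empty_like(rewards) placeholders because TensorDataset cannot hold None; the has_discrete / has_continuous flags then restore the optional slots for each minibatch. A self-contained illustration of that pattern follows, with toy tensor shapes invented for the example.

import torch
from torch.utils.data import TensorDataset, DataLoader

# toy rollout: 8 steps with continuous actions only, no discrete head

rewards = torch.randn(8)
continuous_actions = torch.randn(8, 3)
discrete_actions = None

has_discrete = discrete_actions is not None

# TensorDataset cannot store None, so ride an empty placeholder along with the real tensors

placeholder = torch.empty_like(rewards)

dataset = TensorDataset(
    discrete_actions if has_discrete else placeholder,
    continuous_actions,
    rewards
)

for discrete_batch, continuous_batch, reward_batch in DataLoader(dataset, batch_size = 4, shuffle = True):
    # restore the optional structure per minibatch, as learn does when rebuilding its Experience
    actions = (discrete_batch if has_discrete else None, continuous_batch)
    assert actions[0] is None and actions[1].shape == (4, 3)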