diff --git a/dreamer4/__init__.py b/dreamer4/__init__.py
index eeb1f2c..82442fb 100644
--- a/dreamer4/__init__.py
+++ b/dreamer4/__init__.py
@@ -3,3 +3,10 @@ from dreamer4.dreamer4 import (
     DynamicsWorldModel,
     Dreamer
 )
+
+
+from dreamer4.trainers import (
+    VideoTokenizerTrainer,
+    BehaviorCloneTrainer,
+    DreamTrainer
+)
diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 9902535..64922be 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import math
 from math import ceil, log2
 from random import random
+from contextlib import nullcontext
 from collections import namedtuple
 from functools import partial
 from dataclasses import dataclass
@@ -485,7 +486,7 @@ class ActionEmbedder(Module):
         return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
 
     def unembed_parameters(self):
-        return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
+        return set([self.discrete_action_unembed, self.continuous_action_unembed])
 
     @property
     def device(self):
@@ -1927,9 +1928,30 @@ class DynamicsWorldModel(Module):
     def device(self):
         return self.zero.device
 
+    # types of parameters
+
     def muon_parameters(self):
         return self.transformer.muon_parameters()
 
+    def policy_head_parameters(self):
+        return [
+            *self.policy_head.parameters(),
+            *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
+        ]
+
+    def value_head_parameters(self):
+        return self.value_head.parameters()
+
+    def parameter(self):
+        params = super().parameters()
+
+        if not exists(self.video_tokenizer):
+            return params
+
+        return list(set(params) - set(self.video_tokenizer.parameters()))
+
+    # helpers for shortcut flow matching
+
     def get_times_from_signal_level(
         self,
         signal_levels,
@@ -1942,19 +1964,14 @@
         return align_dims_left(times, align_dims_left_to)
 
-    def parameter(self):
-        params = super().parameters()
-
-        if not exists(self.video_tokenizer):
-            return params
-
-        return list(set(params) - set(self.video_tokenizer.parameters()))
+    # ppo
 
     def learn_from_experience(
         self,
         experience: Experience,
         policy_optim: Optimizer | None = None,
-        value_optim: Optimizer | None = None
+        value_optim: Optimizer | None = None,
+        only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
     ):
         assert experience.is_batched
@@ -1971,6 +1988,10 @@
 
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
+        # determine whether to finetune entire transformer or just learn the heads
+
+        world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
@@ -1980,20 +2001,26 @@
 
         discrete_actions, continuous_actions = actions
 
-        _, (agent_embed, _) = self.forward(
-            latents = latents,
-            signal_levels = self.max_steps - 1,
-            step_sizes = step_size,
-            rewards = rewards,
-            discrete_actions = discrete_actions,
-            continuous_actions = continuous_actions,
-            latent_is_noised = True,
-            return_pred_only = True,
-            return_intermediates = True
-        )
+        with world_model_forward_context():
+            _, (agent_embed, _) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )
 
         agent_embed = agent_embed[..., agent_index, :]
 
+        # maybe detach agent embed
+
+        if only_learn_policy_value_heads:
+            agent_embed = agent_embed.detach()
+
         # ppo
 
         policy_embed = self.policy_head(agent_embed)
diff --git a/dreamer4/trainers.py b/dreamer4/trainers.py
index 8b59f98..0c730da 100644
--- a/dreamer4/trainers.py
+++ b/dreamer4/trainers.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import torch
 from torch import is_tensor
 from torch.nn import Module
+from torch.optim import AdamW
 from torch.utils.data import Dataset, DataLoader
 
 from accelerate import Accelerator
@@ -213,3 +214,106 @@ class BehaviorCloneTrainer(Module):
             self.optim.zero_grad()
 
         self.print('training complete')
+
+# training from dreams
+
+class DreamTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay,
+            **optim_kwargs
+        )
+
+        # separate optimizers for just the policy and value heads
+
+        self.policy_head_optim = optim_klass(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = optim_klass(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+        self.generate_timesteps = generate_timesteps
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+
+        for _ in range(self.num_train_steps):
+
+            # generate imagined rollouts (dreams) to learn from
+
+            dreams = self.unwrapped_model.generate(
+                self.generate_timesteps,
+                batch_size = self.batch_size,
+                return_rewards_per_frame = True,
+                return_agent_actions = True,
+                return_log_probs_and_values = True
+            )
+
+            policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
+
+            self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+            # update policy head
+
+            self.accelerator.backward(policy_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
+
+            self.policy_head_optim.step()
+            self.policy_head_optim.zero_grad()
+
+            # update value head
+
+            self.accelerator.backward(value_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+            self.value_head_optim.step()
+            self.value_head_optim.zero_grad()
+
+        self.print('training complete')
diff --git a/pyproject.toml b/pyproject.toml
index cb11e5b..ca30342 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.59"
+version = "0.0.60"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
 ]
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index 7a715f5..683e981 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -518,7 +518,7 @@ def test_bc_trainer(
         num_latent_tokens = 1
     )
 
-    model = DynamicsWorldModel(
+    world_model = DynamicsWorldModel(
         video_tokenizer = tokenizer,
         dim = 16,
         dim_latent = 16,
@@ -536,7 +536,7 @@
     )
 
     trainer = BehaviorCloneTrainer(
-        model,
+        world_model,
         dataset = dataset,
         batch_size = 1,
         num_train_steps = 1,
@@ -545,6 +545,38 @@
 
     trainer()
 
+
+def test_dream_trainer():
+    from dreamer4.dreamer4 import DynamicsWorldModel
+
+    world_model = DynamicsWorldModel(
+        dim = 16,
+        dim_latent = 16,
+        max_steps = 64,
+        num_tasks = 4,
+        num_latent_tokens = 1,
+        depth = 1,
+        time_block_every = 1,
+        num_spatial_tokens = 1,
+        pred_orig_latent = True,
+        num_discrete_actions = 4,
+        attn_dim_head = 16,
+        prob_no_shortcut_train = 0.1,
+        num_residual_streams = 1
+    )
+
+    # training from self-generations (dreams)
+
+    from dreamer4.trainers import DreamTrainer
+
+    dream_trainer = DreamTrainer(
+        world_model,
+        batch_size = 2,
+        num_train_steps = 1,
+        cpu = True,
+    )
+
+    dream_trainer()
+
 def test_cache_generate():
     from dreamer4.dreamer4 import DynamicsWorldModel