sketch out the dream trainer, seems like they only fine tune the heads

lucidrains 2025-10-22 06:41:10 -07:00
parent 6f1a7a24ed
commit 03b16a48f2
5 changed files with 193 additions and 23 deletions

dreamer4/__init__.py

@@ -3,3 +3,10 @@ from dreamer4.dreamer4 import (
     DynamicsWorldModel,
     Dreamer
 )
+
+from dreamer4.trainers import (
+    VideoTokenizerTrainer,
+    BehaviorCloneTrainer,
+    DreamTrainer
+)
+
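With the re-export above, the new trainer should also be reachable from the package root, assuming the names imported in __init__.py are meant as the public surface:

    # hypothetical downstream import; relies on the package-level re-export added above
    from dreamer4 import DynamicsWorldModel, DreamTrainer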

dreamer4/dreamer4.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import math
 from math import ceil, log2
 from random import random
+from contextlib import nullcontext
 from collections import namedtuple
 from functools import partial
 from dataclasses import dataclass
@@ -485,7 +486,7 @@ class ActionEmbedder(Module):
         return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])

     def unembed_parameters(self):
-        return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
+        return set([self.discrete_action_unembed, self.continuous_action_unembed])

     @property
     def device(self):
@@ -1927,9 +1928,30 @@ class DynamicsWorldModel(Module):
     def device(self):
         return self.zero.device

+    # types of parameters
+
     def muon_parameters(self):
         return self.transformer.muon_parameters()

+    def policy_head_parameters(self):
+        return [
+            *self.policy_head.parameters(),
+            *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
+        ]
+
+    def value_head_parameters(self):
+        return self.value_head.parameters()
+
+    def parameter(self):
+        params = super().parameters()
+        if not exists(self.video_tokenizer):
+            return params
+
+        return list(set(params) - set(self.video_tokenizer.parameters()))
+
+    # helpers for shortcut flow matching
+
     def get_times_from_signal_level(
         self,
         signal_levels,
@@ -1942,19 +1964,14 @@ class DynamicsWorldModel(Module):

         return align_dims_left(times, align_dims_left_to)

-    def parameter(self):
-        params = super().parameters()
-        if not exists(self.video_tokenizer):
-            return params
-
-        return list(set(params) - set(self.video_tokenizer.parameters()))
+    # ppo

     def learn_from_experience(
         self,
         experience: Experience,
         policy_optim: Optimizer | None = None,
-        value_optim: Optimizer | None = None
+        value_optim: Optimizer | None = None,
+        only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
     ):

         assert experience.is_batched
@@ -1971,6 +1988,10 @@ class DynamicsWorldModel(Module):

         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+        # determine whether to finetune entire transformer or just learn the heads
+
+        world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
@@ -1980,20 +2001,26 @@ class DynamicsWorldModel(Module):

         discrete_actions, continuous_actions = actions

-        _, (agent_embed, _) = self.forward(
-            latents = latents,
-            signal_levels = self.max_steps - 1,
-            step_sizes = step_size,
-            rewards = rewards,
-            discrete_actions = discrete_actions,
-            continuous_actions = continuous_actions,
-            latent_is_noised = True,
-            return_pred_only = True,
-            return_intermediates = True
-        )
+        with world_model_forward_context():
+            _, (agent_embed, _) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )

         agent_embed = agent_embed[..., agent_index, :]

+        # maybe detach agent embed
+
+        if only_learn_policy_value_heads:
+            agent_embed = agent_embed.detach()
+
         # ppo

         policy_embed = self.policy_head(agent_embed)
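The head-only gating above boils down to a stock PyTorch pattern: run the frozen world model under torch.no_grad (or nullcontext when full finetuning is desired), detach the resulting embedding, and let gradients reach only the heads. A minimal, self-contained sketch with toy stand-in modules, not the library's own classes:

    import torch
    from torch import nn
    from contextlib import nullcontext

    backbone = nn.Linear(8, 8)        # stand-in for the dynamics transformer
    policy_head = nn.Linear(8, 4)     # stand-in for the policy head
    value_head = nn.Linear(8, 1)      # stand-in for the value head

    only_learn_heads = True
    forward_context = torch.no_grad if only_learn_heads else nullcontext

    x = torch.randn(2, 8)

    with forward_context():
        embed = backbone(x)           # no autograd graph for the backbone in head-only mode

    if only_learn_heads:
        embed = embed.detach()        # extra safety: no gradient can reach the backbone

    policy_loss = policy_head(embed).logsumexp(dim = -1).mean()
    value_loss = value_head(embed).pow(2).mean()

    (policy_loss + value_loss).backward()

    assert backbone.weight.grad is None         # backbone untouched
    assert policy_head.weight.grad is not None  # heads receive gradients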

dreamer4/trainers.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import torch
 from torch import is_tensor
 from torch.nn import Module
+from torch.optim import AdamW
 from torch.utils.data import Dataset, DataLoader

 from accelerate import Accelerator
@@ -213,3 +214,106 @@ class BehaviorCloneTrainer(Module):
             self.optim.zero_grad()

         self.print('training complete')
+
+# training from dreams
+
+class DreamTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+        self.generate_timesteps = generate_timesteps
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+        for _ in range(self.num_train_steps):
+
+            dreams = self.unwrapped_model.generate(
+                self.generate_timesteps,
+                batch_size = self.batch_size,
+                return_rewards_per_frame = True,
+                return_agent_actions = True,
+                return_log_probs_and_values = True
+            )
+
+            policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
+
+            self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+            # update policy head
+
+            self.accelerator.backward(policy_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
+
+            self.policy_head_optim.step()
+            self.policy_head_optim.zero_grad()
+
+            # update value head
+
+            self.accelerator.backward(value_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+            self.value_head_optim.step()
+            self.value_head_optim.zero_grad()
+
+        self.print('training complete')
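Stripped of the Dreamer-specific pieces, the per-step plumbing above is the standard Hugging Face Accelerate skeleton: prepare the module and optimizer, call accelerator.backward on the loss, optionally clip, then step and zero. A toy sketch under that assumption (generic module and loss, not the actual policy or value heads):

    import torch
    from torch import nn
    from torch.optim import AdamW
    from accelerate import Accelerator

    accelerator = Accelerator(cpu = True)

    head = nn.Linear(8, 1)                       # stand-in for a policy or value head
    optim = AdamW(head.parameters(), lr = 3e-4, weight_decay = 0.)

    head, optim = accelerator.prepare(head, optim)

    for _ in range(2):                           # num_train_steps
        x = torch.randn(4, 8, device = accelerator.device)
        loss = head(x).pow(2).mean()             # stand-in for the head loss

        accelerator.backward(loss)               # handles scaling for mixed precision / distributed
        accelerator.clip_grad_norm_(head.parameters(), 1.0)

        optim.step()
        optim.zero_grad()

    accelerator.print('training complete')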

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.59"
+version = "0.0.60"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests

@@ -518,7 +518,7 @@ def test_bc_trainer(
         num_latent_tokens = 1
     )

-    model = DynamicsWorldModel(
+    world_model = DynamicsWorldModel(
         video_tokenizer = tokenizer,
         dim = 16,
         dim_latent = 16,
@@ -536,7 +536,7 @@ def test_bc_trainer(
     )

     trainer = BehaviorCloneTrainer(
-        model,
+        world_model,
         dataset = dataset,
         batch_size = 1,
         num_train_steps = 1,
@@ -545,6 +545,38 @@ def test_bc_trainer(

     trainer()

+def test_dream_trainer():
+    from dreamer4.dreamer4 import DynamicsWorldModel
+
+    world_model = DynamicsWorldModel(
+        dim = 16,
+        dim_latent = 16,
+        max_steps = 64,
+        num_tasks = 4,
+        num_latent_tokens = 1,
+        depth = 1,
+        time_block_every = 1,
+        num_spatial_tokens = 1,
+        pred_orig_latent = True,
+        num_discrete_actions = 4,
+        attn_dim_head = 16,
+        prob_no_shortcut_train = 0.1,
+        num_residual_streams = 1
+    )
+
+    # training from self-generations (dreams)
+
+    from dreamer4.trainers import DreamTrainer
+
+    dream_trainer = DreamTrainer(
+        world_model,
+        batch_size = 2,
+        num_train_steps = 1,
+        cpu = True,
+    )
+
+    dream_trainer()
+
 def test_cache_generate():
     from dreamer4.dreamer4 import DynamicsWorldModel