sketch out the dream trainer, seems like they only fine tune the heads

This commit is contained in:
lucidrains 2025-10-22 06:41:10 -07:00
parent 6f1a7a24ed
commit 03b16a48f2
5 changed files with 193 additions and 23 deletions

View File

@ -3,3 +3,10 @@ from dreamer4.dreamer4 import (
DynamicsWorldModel,
Dreamer
)
from dreamer4.trainers import (
VideoTokenizerTrainer,
BehaviorCloneTrainer,
DreamTrainer
)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import math
from math import ceil, log2
from random import random
from contextlib import nullcontext
from collections import namedtuple
from functools import partial
from dataclasses import dataclass
@ -485,7 +486,7 @@ class ActionEmbedder(Module):
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
def unembed_parameters(self):
return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
return set([self.discrete_action_unembed, self.continuous_action_unembed])
@property
def device(self):
@ -1927,9 +1928,30 @@ class DynamicsWorldModel(Module):
def device(self):
return self.zero.device
# types of parameters
def muon_parameters(self):
return self.transformer.muon_parameters()
def policy_head_parameters(self):
return [
*self.policy_head.parameters(),
*self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
]
def value_head_parameters(self):
return self.value_head.parameters()
def parameter(self):
params = super().parameters()
if not exists(self.video_tokenizer):
return params
return list(set(params) - set(self.video_tokenizer.parameters()))
# helpers for shortcut flow matching
def get_times_from_signal_level(
self,
signal_levels,
@ -1942,19 +1964,14 @@ class DynamicsWorldModel(Module):
return align_dims_left(times, align_dims_left_to)
def parameter(self):
params = super().parameters()
if not exists(self.video_tokenizer):
return params
return list(set(params) - set(self.video_tokenizer.parameters()))
# ppo
def learn_from_experience(
self,
experience: Experience,
policy_optim: Optimizer | None = None,
value_optim: Optimizer | None = None
value_optim: Optimizer | None = None,
only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
):
assert experience.is_batched
@ -1971,6 +1988,10 @@ class DynamicsWorldModel(Module):
returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
# determine whether to finetune entire transformer or just learn the heads
world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
# apparently they just use the sign of the advantage
# https://arxiv.org/abs/2410.04166v1
@ -1980,20 +2001,26 @@ class DynamicsWorldModel(Module):
discrete_actions, continuous_actions = actions
_, (agent_embed, _) = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
rewards = rewards,
discrete_actions = discrete_actions,
continuous_actions = continuous_actions,
latent_is_noised = True,
return_pred_only = True,
return_intermediates = True
)
with world_model_forward_context():
_, (agent_embed, _) = self.forward(
latents = latents,
signal_levels = self.max_steps - 1,
step_sizes = step_size,
rewards = rewards,
discrete_actions = discrete_actions,
continuous_actions = continuous_actions,
latent_is_noised = True,
return_pred_only = True,
return_intermediates = True
)
agent_embed = agent_embed[..., agent_index, :]
# maybe detach agent embed
if only_learn_policy_value_heads:
agent_embed = agent_embed.detach()
# ppo
policy_embed = self.policy_head(agent_embed)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
@ -213,3 +214,106 @@ class BehaviorCloneTrainer(Module):
self.optim.zero_grad()
self.print('training complete')
# training from dreams
class DreamTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
optim_klass = AdamW,
batch_size = 16,
generate_timesteps = 16,
learning_rate = 3e-4,
max_grad_norm = None,
num_train_steps = 10_000,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
self.max_grad_norm = max_grad_norm
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.generate_timesteps = generate_timesteps
self.unwrapped_model = self.model
(
self.model,
self.policy_head_optim,
self.value_head_optim,
) = self.accelerator.prepare(
self.model,
self.policy_head_optim,
self.value_head_optim
)
@property
def device(self):
return self.accelerator.device
@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def forward(
self
):
for _ in range(self.num_train_steps):
dreams = self.unwrapped_model.generate(
self.generate_timesteps,
batch_size = self.batch_size,
return_rewards_per_frame = True,
return_agent_actions = True,
return_log_probs_and_values = True
)
policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
# update policy head
self.accelerator.backward(policy_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
self.policy_head_optim.step()
self.policy_head_optim.zero_grad()
# update value head
self.accelerator.backward(value_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
self.value_head_optim.step()
self.value_head_optim.zero_grad()
self.print('training complete')

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.59"
version = "0.0.60"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -518,7 +518,7 @@ def test_bc_trainer(
num_latent_tokens = 1
)
model = DynamicsWorldModel(
world_model = DynamicsWorldModel(
video_tokenizer = tokenizer,
dim = 16,
dim_latent = 16,
@ -536,7 +536,7 @@ def test_bc_trainer(
)
trainer = BehaviorCloneTrainer(
model,
world_model,
dataset = dataset,
batch_size = 1,
num_train_steps = 1,
@ -545,6 +545,38 @@ def test_bc_trainer(
trainer()
def test_dream_trainer():
from dreamer4.dreamer4 import DynamicsWorldModel
world_model = DynamicsWorldModel(
dim = 16,
dim_latent = 16,
max_steps = 64,
num_tasks = 4,
num_latent_tokens = 1,
depth = 1,
time_block_every = 1,
num_spatial_tokens = 1,
pred_orig_latent = True,
num_discrete_actions = 4,
attn_dim_head = 16,
prob_no_shortcut_train = 0.1,
num_residual_streams = 1
)
# training from self-generations (dreams)
from dreamer4.trainers import DreamTrainer
dream_trainer = DreamTrainer(
world_model,
batch_size = 2,
num_train_steps = 1,
cpu = True,
)
dream_trainer()
def test_cache_generate():
from dreamer4.dreamer4 import DynamicsWorldModel