sketch out the dream trainer, seems like they only fine tune the heads
This commit is contained in:
parent
6f1a7a24ed
commit
03b16a48f2
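In outline, the training path sketched out in this commit works as follows (a condensed illustration of the code in the diff below, not a verbatim excerpt; `world_model` stands for an already-constructed `DynamicsWorldModel`):

    # dream: roll out imagined trajectories from the world model itself
    dreams = world_model.generate(
        16,                                  # number of timesteps to imagine
        batch_size = 2,
        return_rewards_per_frame = True,
        return_agent_actions = True,
        return_log_probs_and_values = True
    )

    # learn: PPO-style update that, by default, only trains the policy and value heads
    # (the transformer trunk runs under torch.no_grad and the agent embedding is detached)
    policy_loss, value_loss = world_model.learn_from_experience(
        dreams,
        only_learn_policy_value_heads = True
    )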
@@ -3,3 +3,10 @@ from dreamer4.dreamer4 import (
     DynamicsWorldModel,
     Dreamer
 )
+
+from dreamer4.trainers import (
+    VideoTokenizerTrainer,
+    BehaviorCloneTrainer,
+    DreamTrainer
+)
+
@@ -3,6 +3,7 @@ from __future__ import annotations
 import math
 from math import ceil, log2
 from random import random
+from contextlib import nullcontext
 from collections import namedtuple
 from functools import partial
 from dataclasses import dataclass
@@ -485,7 +486,7 @@ class ActionEmbedder(Module):
         return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])

     def unembed_parameters(self):
-        return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
+        return set([self.discrete_action_unembed, self.continuous_action_unembed])

     @property
     def device(self):
@@ -1927,9 +1928,30 @@ class DynamicsWorldModel(Module):
     def device(self):
         return self.zero.device

+    # types of parameters
+
+    def muon_parameters(self):
+        return self.transformer.muon_parameters()
+
+    def policy_head_parameters(self):
+        return [
+            *self.policy_head.parameters(),
+            *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
+        ]
+
+    def value_head_parameters(self):
+        return self.value_head.parameters()
+
+    def parameters(self):
+        params = super().parameters()
+
+        if not exists(self.video_tokenizer):
+            return params
+
+        return list(set(params) - set(self.video_tokenizer.parameters()))

     # helpers for shortcut flow matching

     def get_times_from_signal_level(
         self,
         signal_levels,
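The `parameters()` override relocated above uses a set difference so the pretrained video tokenizer's weights never reach an optimizer built from `model.parameters()`. A minimal self-contained sketch of the same pattern (illustrative module names, not the repository's classes):

    import torch
    from torch import nn

    class WorldModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer = nn.Linear(16, 16)      # stands in for the dynamics trunk
            self.video_tokenizer = nn.Linear(16, 16)  # pretrained, should not be optimized here

        def parameters(self):
            # exclude the tokenizer's parameters from the default parameter list
            params = super().parameters()
            return list(set(params) - set(self.video_tokenizer.parameters()))

    model = WorldModel()
    optim = torch.optim.AdamW(model.parameters(), lr = 3e-4)  # tokenizer weights stay untouched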
@@ -1942,19 +1964,14 @@ class DynamicsWorldModel(Module):

         return align_dims_left(times, align_dims_left_to)

-    def parameters(self):
-        params = super().parameters()
-
-        if not exists(self.video_tokenizer):
-            return params
-
-        return list(set(params) - set(self.video_tokenizer.parameters()))
-
     # ppo

     def learn_from_experience(
         self,
         experience: Experience,
         policy_optim: Optimizer | None = None,
-        value_optim: Optimizer | None = None
+        value_optim: Optimizer | None = None,
+        only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
     ):
         assert experience.is_batched
@@ -1971,6 +1988,10 @@ class DynamicsWorldModel(Module):

         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+        # determine whether to finetune the entire transformer or just learn the heads
+
+        world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1

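The comment above refers to using only the sign of the advantage in the policy update, following the linked paper. A rough, hypothetical illustration of that idea — tensor names and shapes are placeholders, and the repository's actual objective (labeled "# ppo" in this file) is more involved:

    import torch

    # toy tensors standing in for what learn_from_experience computes
    returns    = torch.randn(4, 16)   # GAE returns per frame
    old_values = torch.randn(4, 16)   # value estimates collected while dreaming
    log_probs  = torch.randn(4, 16)   # log-probabilities of the taken actions

    advantages = returns - old_values

    # keep only the direction of the advantage, discarding its magnitude
    advantage_sign = advantages.sign()

    # simple policy-gradient surrogate using the sign of the advantage
    policy_loss = -(log_probs * advantage_sign).mean()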
@@ -1980,20 +2001,26 @@ class DynamicsWorldModel(Module):

         discrete_actions, continuous_actions = actions

-        _, (agent_embed, _) = self.forward(
-            latents = latents,
-            signal_levels = self.max_steps - 1,
-            step_sizes = step_size,
-            rewards = rewards,
-            discrete_actions = discrete_actions,
-            continuous_actions = continuous_actions,
-            latent_is_noised = True,
-            return_pred_only = True,
-            return_intermediates = True
-        )
+        with world_model_forward_context():
+            _, (agent_embed, _) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )

         agent_embed = agent_embed[..., agent_index, :]

+        # maybe detach agent embed
+
+        if only_learn_policy_value_heads:
+            agent_embed = agent_embed.detach()
+
         # ppo

         policy_embed = self.policy_head(agent_embed)

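The two additions above are what makes this heads-only finetuning: the transformer trunk runs under `torch.no_grad` (or a `nullcontext` when full finetuning is wanted), and the resulting agent embedding is detached, so gradients only ever reach the policy and value heads. A minimal standalone sketch of the same mechanism (toy modules, not the repository's classes):

    import torch
    from torch import nn
    from contextlib import nullcontext

    trunk       = nn.Linear(32, 32)  # stands in for the dynamics transformer
    policy_head = nn.Linear(32, 8)
    value_head  = nn.Linear(32, 1)

    only_learn_heads = True
    forward_context = torch.no_grad if only_learn_heads else nullcontext

    x = torch.randn(4, 32)

    with forward_context():
        embed = trunk(x)              # no graph is built for the trunk when frozen

    if only_learn_heads:
        embed = embed.detach()        # cut any remaining graph back to the trunk

    loss = policy_head(embed).mean() + value_head(embed).mean()
    loss.backward()                   # gradients land on the heads only; trunk.weight.grad stays None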
@@ -3,6 +3,7 @@ from __future__ import annotations
 import torch
 from torch import is_tensor
 from torch.nn import Module
+from torch.optim import AdamW
 from torch.utils.data import Dataset, DataLoader

 from accelerate import Accelerator
@@ -213,3 +214,106 @@ class BehaviorCloneTrainer(Module):
             self.optim.zero_grad()

         self.print('training complete')
+
+# training from dreams
+
+class DreamTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+        self.generate_timesteps = generate_timesteps
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(self):
+
+        for _ in range(self.num_train_steps):
+
+            dreams = self.unwrapped_model.generate(
+                self.generate_timesteps,
+                batch_size = self.batch_size,
+                return_rewards_per_frame = True,
+                return_agent_actions = True,
+                return_log_probs_and_values = True
+            )
+
+            policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
+
+            self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+            # update policy head
+
+            self.accelerator.backward(policy_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
+
+            self.policy_head_optim.step()
+            self.policy_head_optim.zero_grad()
+
+            # update value head
+
+            self.accelerator.backward(value_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+            self.value_head_optim.step()
+            self.value_head_optim.zero_grad()
+
+        self.print('training complete')
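One detail worth noting in the trainer above: `policy_head_loss` and `value_head_loss` get independent `backward()` calls without `retain_graph = True`. That works because the shared agent embedding is detached upstream (the default), so each loss only owns the graph of its own head; if the trunk were being finetuned, the first backward would free the shared trunk graph and the second would need `retain_graph = True`. A toy demonstration of the same point (illustrative modules only):

    import torch
    from torch import nn

    policy_head = nn.Linear(16, 4)
    value_head  = nn.Linear(16, 1)

    embed = torch.randn(2, 16)        # pretend this came from the frozen world model, already detached

    policy_loss = policy_head(embed).mean()
    value_loss  = value_head(embed).mean()

    policy_loss.backward()            # frees only the policy head's graph
    value_loss.backward()             # still fine: the value head's graph is separate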
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.59"
+version = "0.0.60"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -518,7 +518,7 @@ def test_bc_trainer(
         num_latent_tokens = 1
     )

-    model = DynamicsWorldModel(
+    world_model = DynamicsWorldModel(
         video_tokenizer = tokenizer,
         dim = 16,
         dim_latent = 16,
@@ -536,7 +536,7 @@ def test_bc_trainer(
     )

     trainer = BehaviorCloneTrainer(
-        model,
+        world_model,
         dataset = dataset,
         batch_size = 1,
         num_train_steps = 1,
@@ -545,6 +545,38 @@ def test_bc_trainer(

     trainer()

+def test_dream_trainer():
+    from dreamer4.dreamer4 import DynamicsWorldModel
+
+    world_model = DynamicsWorldModel(
+        dim = 16,
+        dim_latent = 16,
+        max_steps = 64,
+        num_tasks = 4,
+        num_latent_tokens = 1,
+        depth = 1,
+        time_block_every = 1,
+        num_spatial_tokens = 1,
+        pred_orig_latent = True,
+        num_discrete_actions = 4,
+        attn_dim_head = 16,
+        prob_no_shortcut_train = 0.1,
+        num_residual_streams = 1
+    )
+
+    # training from self-generations (dreams)
+
+    from dreamer4.trainers import DreamTrainer
+
+    dream_trainer = DreamTrainer(
+        world_model,
+        batch_size = 2,
+        num_train_steps = 1,
+        cpu = True,
+    )
+
+    dream_trainer()
+
 def test_cache_generate():
     from dreamer4.dreamer4 import DynamicsWorldModel
