sketch out the dream trainer, seems like they only fine tune the heads

parent 6f1a7a24ed
commit 03b16a48f2
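The core idea of the change: during the dream / imagination phase, the world-model trunk runs without gradients and only the policy and value heads (plus the action unembed) get updated. Below is a minimal sketch of that head-only pattern; `head_only_step`, `world_model`, `policy_head`, `value_head` and `batch` are illustrative stand-ins, and only the flag name and the no_grad / nullcontext / detach logic mirror the diff that follows.

import torch
from contextlib import nullcontext

def head_only_step(world_model, policy_head, value_head, batch, only_learn_policy_value_heads = True):
    # run the frozen trunk without building a graph when only the heads are trained
    forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext

    with forward_context():
        agent_embed = world_model(batch)

    # extra safety: cut any remaining graph back into the trunk
    if only_learn_policy_value_heads:
        agent_embed = agent_embed.detach()

    # only these two branches receive gradients
    return policy_head(agent_embed), value_head(agent_embed)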
@@ -3,3 +3,10 @@ from dreamer4.dreamer4 import (
     DynamicsWorldModel,
     Dreamer
 )
+
+
+from dreamer4.trainers import (
+    VideoTokenizerTrainer,
+    BehaviorCloneTrainer,
+    DreamTrainer
+)
@@ -3,6 +3,7 @@ from __future__ import annotations
 import math
 from math import ceil, log2
 from random import random
+from contextlib import nullcontext
 from collections import namedtuple
 from functools import partial
 from dataclasses import dataclass
@@ -485,7 +486,7 @@ class ActionEmbedder(Module):
         return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
 
     def unembed_parameters(self):
-        return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
+        return set([self.discrete_action_unembed, self.continuous_action_unembed])
 
     @property
     def device(self):
@@ -1927,9 +1928,30 @@ class DynamicsWorldModel(Module):
     def device(self):
         return self.zero.device
 
+    # types of parameters
+
     def muon_parameters(self):
         return self.transformer.muon_parameters()
 
+    def policy_head_parameters(self):
+        return [
+            *self.policy_head.parameters(),
+            *self.action_embedder.unembed_parameters() # includes the unembed from the action-embedder
+        ]
+
+    def value_head_parameters(self):
+        return self.value_head.parameters()
+
+    def parameter(self):
+        params = super().parameters()
+
+        if not exists(self.video_tokenizer):
+            return params
+
+        return list(set(params) - set(self.video_tokenizer.parameters()))
+
+    # helpers for shortcut flow matching
+
     def get_times_from_signal_level(
         self,
         signal_levels,
@@ -1942,19 +1964,14 @@ class DynamicsWorldModel(Module):
 
         return align_dims_left(times, align_dims_left_to)
 
-    def parameter(self):
-        params = super().parameters()
-
-        if not exists(self.video_tokenizer):
-            return params
-
-        return list(set(params) - set(self.video_tokenizer.parameters()))
+    # ppo
 
     def learn_from_experience(
         self,
         experience: Experience,
         policy_optim: Optimizer | None = None,
-        value_optim: Optimizer | None = None
+        value_optim: Optimizer | None = None,
+        only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
     ):
         assert experience.is_batched
 
@@ -1971,6 +1988,10 @@ class DynamicsWorldModel(Module):
 
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
+        # determine whether to finetune entire transformer or just learn the heads
+
+        world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
@@ -1980,20 +2001,26 @@ class DynamicsWorldModel(Module):
 
         discrete_actions, continuous_actions = actions
 
-        _, (agent_embed, _) = self.forward(
-            latents = latents,
-            signal_levels = self.max_steps - 1,
-            step_sizes = step_size,
-            rewards = rewards,
-            discrete_actions = discrete_actions,
-            continuous_actions = continuous_actions,
-            latent_is_noised = True,
-            return_pred_only = True,
-            return_intermediates = True
-        )
+        with world_model_forward_context():
+            _, (agent_embed, _) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )
 
         agent_embed = agent_embed[..., agent_index, :]
 
+        # maybe detach agent embed
+
+        if only_learn_policy_value_heads:
+            agent_embed = agent_embed.detach()
+
         # ppo
 
         policy_embed = self.policy_head(agent_embed)
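The learn_from_experience hunks above keep the existing comment that, per arXiv:2410.04166v1, only the sign of the advantage is used. A minimal sketch of what a sign-of-advantage clipped objective could look like; the helper name, the clipping constant and the exact loss shape are assumptions for illustration, not code from this repository.

import torch

def sign_advantage_policy_loss(log_probs, old_log_probs, advantages, clip_eps = 0.2):
    adv = torch.sign(advantages)                          # keep only the sign of the advantage, per the comment above
    ratio = (log_probs - old_log_probs).exp()             # standard PPO probability ratio
    unclipped = ratio * adv
    clipped = ratio.clamp(1. - clip_eps, 1. + clip_eps) * adv
    return -torch.minimum(unclipped, clipped).mean()      # clipped surrogate, negated for gradient descent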
@@ -3,6 +3,7 @@ from __future__ import annotations
 import torch
 from torch import is_tensor
 from torch.nn import Module
+from torch.optim import AdamW
 from torch.utils.data import Dataset, DataLoader
 
 from accelerate import Accelerator
@@ -213,3 +214,106 @@ class BehaviorCloneTrainer(Module):
             self.optim.zero_grad()
 
         self.print('training complete')
+
+# training from dreams
+
+class DreamTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+        self.generate_timesteps = generate_timesteps
+
+        self.unwrapped_model = self.model
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+
+        for _ in range(self.num_train_steps):
+
+            dreams = self.unwrapped_model.generate(
+                self.generate_timesteps,
+                batch_size = self.batch_size,
+                return_rewards_per_frame = True,
+                return_agent_actions = True,
+                return_log_probs_and_values = True
+            )
+
+            policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
+
+            self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+            # update policy head
+
+            self.accelerator.backward(policy_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
+
+            self.policy_head_optim.step()
+            self.policy_head_optim.zero_grad()
+
+            # update value head
+
+            self.accelerator.backward(value_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+            self.value_head_optim.step()
+            self.value_head_optim.zero_grad()
+
+        self.print('training complete')
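One detail worth noting about the loop above: policy_head_loss and value_head_loss come out of a single learn_from_experience call and each gets its own backward. That works without retain_graph = True only because, on the only_learn_policy_value_heads path, the shared agent_embed is detached, so the two heads sit on disjoint graphs. A toy illustration of that behavior, standalone and not repository code:

import torch
from torch import nn

embed = torch.randn(2, 16)            # stands in for the detached agent_embed
policy_head = nn.Linear(16, 4)
value_head = nn.Linear(16, 1)

policy_loss = policy_head(embed).sum()
value_loss = value_head(embed).sum()

policy_loss.backward()                # frees only the policy branch of the graph
value_loss.backward()                 # still fine: the value branch was never shared

If the trunk were also being finetuned (the flag set to False), the two losses would share the trunk's graph and the second backward would need retain_graph = True.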
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.59"
+version = "0.0.60"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -518,7 +518,7 @@ def test_bc_trainer(
         num_latent_tokens = 1
     )
 
-    model = DynamicsWorldModel(
+    world_model = DynamicsWorldModel(
         video_tokenizer = tokenizer,
         dim = 16,
         dim_latent = 16,
@@ -536,7 +536,7 @@ def test_bc_trainer(
     )
 
     trainer = BehaviorCloneTrainer(
-        model,
+        world_model,
         dataset = dataset,
         batch_size = 1,
         num_train_steps = 1,
@@ -545,6 +545,38 @@ def test_bc_trainer(
 
     trainer()
 
+def test_dream_trainer():
+    from dreamer4.dreamer4 import DynamicsWorldModel
+
+    world_model = DynamicsWorldModel(
+        dim = 16,
+        dim_latent = 16,
+        max_steps = 64,
+        num_tasks = 4,
+        num_latent_tokens = 1,
+        depth = 1,
+        time_block_every = 1,
+        num_spatial_tokens = 1,
+        pred_orig_latent = True,
+        num_discrete_actions = 4,
+        attn_dim_head = 16,
+        prob_no_shortcut_train = 0.1,
+        num_residual_streams = 1
+    )
+
+    # training from self-generations (dreams)
+
+    from dreamer4.trainers import DreamTrainer
+
+    dream_trainer = DreamTrainer(
+        world_model,
+        batch_size = 2,
+        num_train_steps = 1,
+        cpu = True,
+    )
+
+    dream_trainer()
+
 def test_cache_generate():
     from dreamer4.dreamer4 import DynamicsWorldModel
 