# dreamer4/dreamer4/trainers.py
# (516 lines, 14 KiB, Python)

from __future__ import annotations
import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader
from accelerate import Accelerator
from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsWorldModel,
Experience,
combine_experiences
)
from ema_pytorch import EMA
# helpers
def exists(v):
    """Return True when `v` is anything other than `None`."""
    if v is None:
        return False

    return True
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is `None`.

    Falsy values such as `0` or an empty container still count as provided.
    """
    if exists(v):
        return v

    return d
def cycle(dl):
    """Yield items from `dl` forever, restarting iteration each time it is exhausted.

    `dl` is re-iterated on every pass rather than cached, so it must be
    re-iterable (a one-shot iterator would yield nothing after its first pass).
    """
    while True:
        yield from dl
# trainers
class VideoTokenizerTrainer(Module):
    """Runs a fixed number of optimization steps on a `VideoTokenizer`.

    The trainer builds a shuffled dataloader over `dataset`, constructs the
    optimizer, and passes model, dataloader and optimizer through
    `Accelerator.prepare`. Calling the trainer (`trainer()`) runs the loop.

    Args:
        model: the tokenizer to train; calling it on a batch must return the loss.
        dataset: dataset whose batches are passed to `model` as-is.
        optim_klass: optimizer class. `MuonAdamAtan2` is special-cased and
            receives `model.muon_parameters()` followed by `model.parameters()`;
            any other class is called as `optim_klass(model.parameters(), **kwargs)`.
        batch_size: batch size, capped at `len(dataset)`.
        learning_rate: passed to the optimizer as `lr`.
        max_grad_norm: if given, gradients are clipped to this norm every step.
        num_train_steps: number of optimizer steps performed by `forward`.
        weight_decay: passed to the optimizer as `weight_decay`.
        accelerate_kwargs: extra keyword arguments for `Accelerator`.
        optim_kwargs: extra keyword arguments for the optimizer; entries here
            take precedence over `learning_rate` / `weight_decay`.
        cpu: forwarded to `Accelerator(cpu = ...)`.
    """

    def __init__(
        self,
        model: VideoTokenizer,
        dataset: Dataset,
        optim_klass = MuonAdamAtan2,
        batch_size = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        # with `drop_last = True`, a batch larger than the dataset would yield no batches at all

        batch_size = min(batch_size, len(dataset))

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model
        self.dataset = dataset

        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        # merge into a fresh dict so the caller's `optim_kwargs` are honored
        # (previously they were overwritten) and the shared default dict is never mutated

        optim_kwargs = {
            'lr': learning_rate,
            'weight_decay': weight_decay,
            **optim_kwargs
        }

        if optim_klass is MuonAdamAtan2:
            optim = MuonAdamAtan2(
                model.muon_parameters(),
                model.parameters(),
                **optim_kwargs
            )
        else:
            optim = optim_klass(
                model.parameters(),
                **optim_kwargs
            )

        self.optim = optim

        self.max_grad_norm = max_grad_norm

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size

        (
            self.model,
            self.train_dataloader,
            self.optim
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader,
            self.optim
        )

    @property
    def device(self):
        """Device selected by the accelerator."""
        return self.accelerator.device

    def print(self, *args, **kwargs):
        """Print through the accelerator."""
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):
        """Run `num_train_steps` optimization steps, cycling over the dataloader as needed."""

        iter_train_dl = cycle(self.train_dataloader)

        for _ in range(self.num_train_steps):
            video = next(iter_train_dl)

            loss = self.model(video)

            self.accelerator.backward(loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            self.optim.step()
            self.optim.zero_grad()

        self.print('training complete')
# dynamics world model
class BehaviorCloneTrainer(Module):
    """Runs a fixed number of optimization steps on a `DynamicsWorldModel` from a dataset.

    Batches that are plain tensors are passed positionally to the model; any
    other batch is unpacked as keyword arguments (`model(**batch)`), so the
    dataset is expected to yield either tensors or mappings of model kwargs.

    Args:
        model: the world model to train; calling it on a batch must return the loss.
        dataset: dataset of tensors or keyword-argument mappings.
        optim_klass: optimizer class. `MuonAdamAtan2` is special-cased and
            receives `model.muon_parameters()` followed by `model.parameters()`;
            any other class is called as `optim_klass(model.parameters(), **kwargs)`.
        batch_size: batch size, capped at `len(dataset)`.
        learning_rate: passed to the optimizer as `lr`.
        max_grad_norm: if given, gradients are clipped to this norm every step.
        num_train_steps: number of optimizer steps performed by `forward`.
        weight_decay: passed to the optimizer as `weight_decay`.
        accelerate_kwargs: extra keyword arguments for `Accelerator`.
        optim_kwargs: extra keyword arguments for the optimizer; entries here
            take precedence over `learning_rate` / `weight_decay`.
        cpu: forwarded to `Accelerator(cpu = ...)`.
    """

    def __init__(
        self,
        model: DynamicsWorldModel,
        dataset: Dataset,
        optim_klass = MuonAdamAtan2,
        batch_size = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        # with `drop_last = True`, a batch larger than the dataset would yield no batches at all

        batch_size = min(batch_size, len(dataset))

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model
        self.dataset = dataset

        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        # merge into a fresh dict so the caller's `optim_kwargs` are honored
        # (previously they were overwritten) and the shared default dict is never mutated

        optim_kwargs = {
            'lr': learning_rate,
            'weight_decay': weight_decay,
            **optim_kwargs
        }

        if optim_klass is MuonAdamAtan2:
            optim = MuonAdamAtan2(
                model.muon_parameters(),
                model.parameters(),
                **optim_kwargs
            )
        else:
            optim = optim_klass(
                model.parameters(),
                **optim_kwargs
            )

        self.optim = optim

        self.max_grad_norm = max_grad_norm

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size

        (
            self.model,
            self.train_dataloader,
            self.optim
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader,
            self.optim
        )

    @property
    def device(self):
        """Device selected by the accelerator."""
        return self.accelerator.device

    def print(self, *args, **kwargs):
        """Print through the accelerator."""
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):
        """Run `num_train_steps` optimization steps, cycling over the dataloader as needed."""

        iter_train_dl = cycle(self.train_dataloader)

        for _ in range(self.num_train_steps):
            batch_data = next(iter_train_dl)

            # just assume raw video dynamics training if batch_data is a tensor
            # else kwargs for video, actions, rewards

            if is_tensor(batch_data):
                loss = self.model(batch_data)
            else:
                loss = self.model(**batch_data)

            self.accelerator.backward(loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            self.optim.step()
            self.optim.zero_grad()

        self.print('training complete')
# training from dreams
class DreamTrainer(Module):
    """Trains the policy and value heads of a `DynamicsWorldModel` on its own generated rollouts.

    Each step generates a batch of rollouts with `model.generate`, computes
    the policy and value head losses with `model.learn_from_experience`, and
    updates each head with its own optimizer.

    Args:
        model: world model exposing `generate`, `learn_from_experience`,
            `policy_head_parameters` and `value_head_parameters`.
        optim_klass: optimizer class, instantiated once per head as
            `optim_klass(head_parameters, **kwargs)`.
        batch_size: number of rollouts generated per step.
        generate_timesteps: rollout length; one extra step is generated for
            the bootstrap value.
        learning_rate: passed to both optimizers as `lr`.
        max_grad_norm: if given, each head's gradients are clipped to this norm.
        num_train_steps: number of generate-and-update iterations.
        weight_decay: passed to both optimizers as `weight_decay`.
        accelerate_kwargs: extra keyword arguments for `Accelerator`.
        optim_kwargs: extra keyword arguments for both optimizers; entries
            here take precedence over `learning_rate` / `weight_decay`.
        cpu: forwarded to `Accelerator(cpu = ...)`.
    """

    def __init__(
        self,
        model: DynamicsWorldModel,
        optim_klass = AdamW,
        batch_size = 16,
        generate_timesteps = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model

        # merge into a fresh dict so the caller's `optim_kwargs` are honored
        # (previously they were overwritten) and the shared default dict is never mutated

        optim_kwargs = {
            'lr': learning_rate,
            'weight_decay': weight_decay,
            **optim_kwargs
        }

        # honor `optim_klass` (previously AdamW was hard-coded and the argument ignored)

        self.policy_head_optim = optim_klass(model.policy_head_parameters(), **optim_kwargs)
        self.value_head_optim = optim_klass(model.value_head_parameters(), **optim_kwargs)

        self.max_grad_norm = max_grad_norm

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.generate_timesteps = generate_timesteps

        # NOTE(review): `unwrapped_model` is also defined as a read-only property below,
        # so reads resolve to the property; this assignment presumably only registers the
        # raw model as an extra submodule - confirm whether it can be dropped

        self.unwrapped_model = self.model

        (
            self.model,
            self.policy_head_optim,
            self.value_head_optim,
        ) = self.accelerator.prepare(
            self.model,
            self.policy_head_optim,
            self.value_head_optim
        )

    @property
    def device(self):
        """Device selected by the accelerator."""
        return self.accelerator.device

    @property
    def unwrapped_model(self):
        """The model with any accelerator wrapping removed."""
        return self.accelerator.unwrap_model(self.model)

    def print(self, *args, **kwargs):
        """Print through the accelerator."""
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):
        """Run `num_train_steps` iterations of generating rollouts and updating both heads."""

        for _ in range(self.num_train_steps):

            dreams = self.unwrapped_model.generate(
                self.generate_timesteps + 1, # plus one for bootstrap value
                batch_size = self.batch_size,
                return_rewards_per_frame = True,
                return_agent_actions = True,
                return_log_probs_and_values = True
            )

            policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)

            self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')

            # update policy head

            self.accelerator.backward(policy_head_loss)

            if exists(self.max_grad_norm):
                # `policy_head_parameters()` already returns the parameters; the original
                # called the result a second time, which raised as soon as clipping was enabled
                self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)

            self.policy_head_optim.step()
            self.policy_head_optim.zero_grad()

            # update value head

            self.accelerator.backward(value_head_loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)

            self.value_head_optim.step()
            self.value_head_optim.zero_grad()

        self.print('training complete')
# training from sim
class SimTrainer(Module):
    """Trains the policy and value heads of a `DynamicsWorldModel` from environment interaction.

    `forward` repeatedly collects experience with `model.interact_with_env`,
    combines it, and hands it to `learn`, which runs minibatch updates of both
    heads for `epochs` passes over the collected data.

    Args:
        model: world model exposing `interact_with_env`, `learn_from_experience`,
            `policy_head_parameters` and `value_head_parameters`.
        optim_klass: optimizer class, instantiated once per head as
            `optim_klass(head_parameters, **kwargs)`.
        batch_size: minibatch size used inside `learn`.
        generate_timesteps: stored on the trainer; not used by the code in this class.
        learning_rate: passed to both optimizers as `lr`.
        max_grad_norm: if given, each head's gradients are clipped to this norm.
        epochs: number of passes over the collected experience per `learn` call.
        weight_decay: passed to both optimizers as `weight_decay`.
        accelerate_kwargs: extra keyword arguments for `Accelerator`.
        optim_kwargs: extra keyword arguments for both optimizers; entries
            here take precedence over `learning_rate` / `weight_decay`.
        cpu: forwarded to `Accelerator(cpu = ...)`.
    """

    def __init__(
        self,
        model: DynamicsWorldModel,
        optim_klass = AdamW,
        batch_size = 16,
        generate_timesteps = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        epochs = 2,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model

        # merge into a fresh dict so the caller's `optim_kwargs` are honored
        # (previously they were overwritten) and the shared default dict is never mutated

        optim_kwargs = {
            'lr': learning_rate,
            'weight_decay': weight_decay,
            **optim_kwargs
        }

        # honor `optim_klass` (previously AdamW was hard-coded and the argument ignored)

        self.policy_head_optim = optim_klass(model.policy_head_parameters(), **optim_kwargs)
        self.value_head_optim = optim_klass(model.value_head_parameters(), **optim_kwargs)

        self.max_grad_norm = max_grad_norm

        self.epochs = epochs
        self.batch_size = batch_size

        self.generate_timesteps = generate_timesteps

        # NOTE(review): `unwrapped_model` is also defined as a read-only property below,
        # so reads resolve to the property; this assignment presumably only registers the
        # raw model as an extra submodule - confirm whether it can be dropped

        self.unwrapped_model = self.model

        (
            self.model,
            self.policy_head_optim,
            self.value_head_optim,
        ) = self.accelerator.prepare(
            self.model,
            self.policy_head_optim,
            self.value_head_optim
        )

    @property
    def device(self):
        """Device selected by the accelerator."""
        return self.accelerator.device

    @property
    def unwrapped_model(self):
        """The model with any accelerator wrapping removed."""
        return self.accelerator.unwrap_model(self.model)

    def print(self, *args, **kwargs):
        """Print through the accelerator."""
        return self.accelerator.print(*args, **kwargs)

    def learn(
        self,
        experience: Experience
    ):
        """Update both heads from `experience` using shuffled minibatches.

        The per-sample tensors of `experience` are placed in a `TensorDataset`,
        iterated for `self.epochs` passes, and each minibatch is rebuilt into
        an `Experience` for `model.learn_from_experience`. `step_size` and
        `agent_index` are carried over unchanged to every minibatch.
        """

        step_size = experience.step_size
        agent_index = experience.agent_index

        latents = experience.latents
        old_values = experience.values
        rewards = experience.rewards

        discrete_actions, continuous_actions = experience.actions
        discrete_log_probs, continuous_log_probs = experience.log_probs

        # handle empties
        # `TensorDataset` needs a tensor in every slot, so absent action types get a
        # placeholder; its contents are never read, the `has_*` flags turn it back into None

        empty_tensor = torch.empty_like(rewards)

        has_discrete = exists(discrete_actions)
        has_continuous = exists(continuous_actions)

        discrete_actions = default(discrete_actions, empty_tensor)
        continuous_actions = default(continuous_actions, empty_tensor)

        discrete_log_probs = default(discrete_log_probs, empty_tensor)
        continuous_log_probs = default(continuous_log_probs, empty_tensor)

        # create the dataset and dataloader

        dataset = TensorDataset(
            latents,
            discrete_actions,
            continuous_actions,
            discrete_log_probs,
            continuous_log_probs,
            old_values,
            rewards
        )

        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)

        for epoch in range(self.epochs):

            for (
                latents,
                discrete_actions,
                continuous_actions,
                discrete_log_probs,
                continuous_log_probs,
                old_values,
                rewards
            ) in dataloader:

                actions = (
                    discrete_actions if has_discrete else None,
                    continuous_actions if has_continuous else None
                )

                log_probs = (
                    discrete_log_probs if has_discrete else None,
                    continuous_log_probs if has_continuous else None
                )

                batch_experience = Experience(
                    latents = latents,
                    actions = actions,
                    log_probs = log_probs,
                    values = old_values,
                    rewards = rewards,
                    step_size = step_size,
                    agent_index = agent_index
                )

                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)

                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')

                # update policy head

                self.accelerator.backward(policy_head_loss)

                if exists(self.max_grad_norm):
                    # `policy_head_parameters()` already returns the parameters; the original
                    # called the result a second time, which raised as soon as clipping was enabled
                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)

                self.policy_head_optim.step()
                self.policy_head_optim.zero_grad()

                # update value head

                self.accelerator.backward(value_head_loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)

                self.value_head_optim.step()
                self.value_head_optim.zero_grad()

        self.print('training complete')

    def forward(
        self,
        env,
        num_episodes = 50000,
        max_experiences_before_learn = 8,
        env_is_vectorized = False
    ):
        """Alternate between collecting experience from `env` and learning from it.

        For each of `num_episodes` iterations, the model interacts with `env`
        until at least `max_experiences_before_learn` experiences (counted by
        the leading dimension of `experience.video`) are gathered; they are
        then combined and passed to `learn`.
        """

        for _ in range(num_episodes):

            total_experience = 0
            experiences = []

            while total_experience < max_experiences_before_learn:

                experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)

                num_experience = experience.video.shape[0]

                total_experience += num_experience

                experiences.append(experience)

            combined_experiences = combine_experiences(experiences)

            self.learn(combined_experiences)

        self.print('training complete')