sketch of training from sim env

This commit is contained in:
lucidrains 2025-10-24 09:13:09 -07:00
parent 27ac05efb0
commit 35c1db4c7d
3 changed files with 214 additions and 3 deletions

View File

@ -4,7 +4,7 @@ import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, TensorDataset, DataLoader
from accelerate import Accelerator
@ -12,7 +12,9 @@ from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsWorldModel
DynamicsWorldModel,
Experience,
combine_experiences
)
from ema_pytorch import EMA
@ -317,3 +319,198 @@ class DreamTrainer(Module):
self.value_head_optim.zero_grad()
self.print('training complete')
# training from sim
class SimTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
optim_klass = AdamW,
batch_size = 16,
generate_timesteps = 16,
learning_rate = 3e-4,
max_grad_norm = None,
epochs = 2,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
self.max_grad_norm = max_grad_norm
self.epochs = epochs
self.batch_size = batch_size
self.generate_timesteps = generate_timesteps
self.unwrapped_model = self.model
(
self.model,
self.policy_head_optim,
self.value_head_optim,
) = self.accelerator.prepare(
self.model,
self.policy_head_optim,
self.value_head_optim
)
@property
def device(self):
return self.accelerator.device
@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def learn(
self,
experience: Experience
):
step_size = experience.step_size
agent_index = experience.agent_index
latents = experience.latents
old_values = experience.values
rewards = experience.rewards
discrete_actions, continuous_actions = experience.actions
discrete_log_probs, continuous_log_probs = experience.log_probs
# handle empties
empty_tensor = torch.empty_like(rewards)
has_discrete = exists(discrete_actions)
has_continuous = exists(continuous_actions)
discrete_actions = default(discrete_actions, empty_tensor)
continuous_actions = default(continuous_actions, empty_tensor)
discrete_log_probs = default(discrete_log_probs, empty_tensor)
continuous_log_probs = default(continuous_log_probs, empty_tensor)
# create the dataset and dataloader
dataset = TensorDataset(
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
old_values,
rewards
)
dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
for epoch in range(self.epochs):
for (
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
old_values,
rewards
) in dataloader:
actions = (
discrete_actions if has_discrete else None,
continuous_actions if has_continuous else None
)
log_probs = (
discrete_log_probs if has_discrete else None,
continuous_log_probs if has_continuous else None
)
batch_experience = Experience(
latents = latents,
actions = actions,
log_probs = log_probs,
values = old_values,
rewards = rewards,
step_size = step_size,
agent_index = agent_index
)
policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
# update policy head
self.accelerator.backward(policy_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
self.policy_head_optim.step()
self.policy_head_optim.zero_grad()
# update value head
self.accelerator.backward(value_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
self.value_head_optim.step()
self.value_head_optim.zero_grad()
self.print('training complete')
def forward(
self,
env,
num_episodes = 50000,
max_experiences_before_learn = 8,
env_is_vectorized = False
):
for _ in range(num_episodes):
total_experience = 0
experiences = []
while total_experience < max_experiences_before_learn:
experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
num_experience = experience.video.shape[0]
total_experience += num_experience
experiences.append(experience)
combined_experiences = combine_experiences(experiences)
self.learn(combined_experiences)
experiences.clear()
self.print('training complete')

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.67"
version = "0.0.69"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -641,6 +641,8 @@ def test_online_rl(
mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
# manually
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
@ -650,3 +652,15 @@ def test_online_rl(
actor_loss.backward()
critic_loss.backward()
# with trainer
from dreamer4.trainers import SimTrainer
trainer = SimTrainer(
world_model_and_policy,
batch_size = 4,
cpu = True
)
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)