sketch of training from sim env

lucidrains 2025-10-24 08:56:51 -07:00
parent 27ac05efb0
commit 8526347316
3 changed files with 212 additions and 3 deletions

View File (dreamer4/trainers.py)

@@ -4,7 +4,7 @@ import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, TensorDataset, DataLoader
from accelerate import Accelerator
@@ -12,7 +12,9 @@ from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import (
    VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    Experience,
+    combine_experiences
)
from ema_pytorch import EMA
@@ -317,3 +319,196 @@ class DreamTrainer(Module):
                self.value_head_optim.zero_grad()

        self.print('training complete')

# training from sim

class SimTrainer(Module):
    def __init__(
        self,
        model: DynamicsWorldModel,
        optim_klass = AdamW,
        batch_size = 16,
        generate_timesteps = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        epochs = 2,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model

        # merge learning rate and weight decay into any extra optimizer kwargs passed in

        optim_kwargs = dict(
            lr = learning_rate,
            weight_decay = weight_decay,
            **optim_kwargs
        )

        # use the passed-in optimizer class (defaults to AdamW) for both heads

        self.policy_head_optim = optim_klass(model.policy_head_parameters(), **optim_kwargs)
        self.value_head_optim = optim_klass(model.value_head_parameters(), **optim_kwargs)

        self.max_grad_norm = max_grad_norm

        self.epochs = epochs
        self.batch_size = batch_size
        self.generate_timesteps = generate_timesteps

        (
            self.model,
            self.policy_head_optim,
            self.value_head_optim,
        ) = self.accelerator.prepare(
            self.model,
            self.policy_head_optim,
            self.value_head_optim
        )

    @property
    def device(self):
        return self.accelerator.device

    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    def print(self, *args, **kwargs):
        return self.accelerator.print(*args, **kwargs)

    def learn(
        self,
        experience: Experience
    ):
        step_size = experience.step_size
        agent_index = experience.agent_index

        latents = experience.latents
        old_values = experience.values
        rewards = experience.rewards

        discrete_actions, continuous_actions = experience.actions
        discrete_log_probs, continuous_log_probs = experience.log_probs

        # handle empties - TensorDataset cannot hold None, so substitute placeholder tensors and restore the Nones per mini-batch below

        empty_tensor = torch.empty_like(rewards)

        has_discrete = exists(discrete_actions)
        has_continuous = exists(continuous_actions)

        discrete_actions = default(discrete_actions, empty_tensor)
        continuous_actions = default(continuous_actions, empty_tensor)

        discrete_log_probs = default(discrete_log_probs, empty_tensor)
        continuous_log_probs = default(continuous_log_probs, empty_tensor)

        # create the dataset and dataloader for shuffled mini-batches over the collected experience

        dataset = TensorDataset(
            latents,
            discrete_actions,
            continuous_actions,
            discrete_log_probs,
            continuous_log_probs,
            old_values,
            rewards
        )

        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)

        for epoch in range(self.epochs):
            for (
                latents,
                discrete_actions,
                continuous_actions,
                discrete_log_probs,
                continuous_log_probs,
                old_values,
                rewards
            ) in dataloader:

                # restore None for the action / log prob slots that were only placeholders

                actions = (
                    discrete_actions if has_discrete else None,
                    continuous_actions if has_continuous else None
                )

                log_probs = (
                    discrete_log_probs if has_discrete else None,
                    continuous_log_probs if has_continuous else None
                )

                # repackage the mini-batch as an Experience for the world model to learn from

                batch_experience = Experience(
                    latents = latents,
                    actions = actions,
                    log_probs = log_probs,
                    values = old_values,
                    rewards = rewards,
                    step_size = step_size,
                    agent_index = agent_index
                )

                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)

                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')

                # update policy head

                self.accelerator.backward(policy_head_loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)

                self.policy_head_optim.step()
                self.policy_head_optim.zero_grad()

                # update value head

                self.accelerator.backward(value_head_loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)

                self.value_head_optim.step()
                self.value_head_optim.zero_grad()

        self.print('training complete')

    def forward(
        self,
        env,
        num_episodes = 50000,
        max_experiences_before_learn = 8,
        env_is_vectorized = False
    ):
        for _ in range(num_episodes):
            total_experience = 0
            experiences = []

            # gather rollouts from the environment until there is enough for one round of learning

            while total_experience < max_experiences_before_learn:
                experience = self.unwrapped_model.interact_with_env(env, max_timesteps = self.generate_timesteps, env_is_vectorized = env_is_vectorized)

                num_experience = experience.video.shape[0]
                total_experience += num_experience

                experiences.append(experience)

            experiences = combine_experiences(experiences)

            self.learn(experiences)

        self.print('training complete')
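
For reference, a rough usage sketch of the new trainer against a simulator, mirroring the test further below. `world_model_and_policy` and `sim_env` are placeholders: the former is assumed to be an already constructed DynamicsWorldModel with policy and value heads, the latter whatever environment interact_with_env is pointed at.

# hypothetical usage sketch

from dreamer4.trainers import SimTrainer

trainer = SimTrainer(
    world_model_and_policy,      # an already constructed DynamicsWorldModel (placeholder)
    batch_size = 16,
    generate_timesteps = 16,
    learning_rate = 3e-4,
    epochs = 2
)

# roll out into the sim, then update the policy and value heads from the gathered experience

trainer(sim_env, num_episodes = 50000, max_experiences_before_learn = 8, env_is_vectorized = False)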

View File (pyproject.toml)

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
-version = "0.0.67"
+version = "0.0.68"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -641,6 +641,8 @@ def test_online_rl(
    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)

    # manually

    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
@@ -650,3 +652,15 @@
    actor_loss.backward()
    critic_loss.backward()

    # with trainer

    from dreamer4.trainers import SimTrainer

    trainer = SimTrainer(
        world_model_and_policy,
        batch_size = 4,
        cpu = True
    )

    trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
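
For comparison with the manual path above, each episode inside the trainer call decomposes into roughly the following. This is a sketch of the trainer's forward loop reusing the test's mock_env, world_model_and_policy, trainer and vectorized flag; the trainer itself takes the timestep count from generate_timesteps where it is written out here.

# sketch of one trainer episode, spelled out with the objects from the test

from dreamer4.dreamer4 import combine_experiences

experiences = []
total_experience = 0

# gather rollouts until there is enough experience for a round of learning

while total_experience < 8:  # max_experiences_before_learn
    experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
    total_experience += experience.video.shape[0]
    experiences.append(experience)

# merge into a single Experience and run the epochs of policy / value head updates

trainer.learn(combine_experiences(experiences))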