sketch of training from sim env
parent 27ac05efb0
commit 35c1db4c7d
@@ -4,7 +4,7 @@ import torch
 from torch import is_tensor
 from torch.nn import Module
 from torch.optim import AdamW
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, TensorDataset, DataLoader
 
 from accelerate import Accelerator
 
@@ -12,7 +12,9 @@ from adam_atan2_pytorch import MuonAdamAtan2
 
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    Experience,
+    combine_experiences
 )
 
 from ema_pytorch import EMA
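
For orientation, the two new imports are the rollout container and its batching helper. Judging only from how SimTrainer uses them below, Experience behaves roughly like the following container; this is a sketch inferred from usage, not the actual definition in dreamer4.dreamer4, and the field types are assumptions:

    # illustrative sketch of the Experience container, inferred from SimTrainer usage below
    from typing import NamedTuple, Optional, Tuple
    from torch import Tensor

    class ExperienceSketch(NamedTuple):
        latents: Tensor                                        # encoded latents consumed by learn()
        actions: Tuple[Optional[Tensor], Optional[Tensor]]     # (discrete, continuous) - either may be None
        log_probs: Tuple[Optional[Tensor], Optional[Tensor]]   # log probs matching the actions
        values: Tensor                                         # value estimates recorded at collection time
        rewards: Tensor
        step_size: float                                       # type assumed
        agent_index: int
        video: Optional[Tensor] = None                         # raw frames, batch first; forward() counts experiences by video.shape[0]

    # combine_experiences(experiences) is assumed to concatenate the tensor fields of a
    # list of Experience instances along the batch dimension, so learn() sees one rollout
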
@@ -317,3 +319,198 @@ class DreamTrainer(Module):
                 self.value_head_optim.zero_grad()
 
         self.print('training complete')
+
+# training from sim
+
+class SimTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        epochs = 2,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay,
+            **optim_kwargs
+        )
+
+        self.policy_head_optim = optim_klass(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = optim_klass(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.epochs = epochs
+        self.batch_size = batch_size
+
+        self.generate_timesteps = generate_timesteps
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def learn(
+        self,
+        experience: Experience
+    ):
+
+        step_size = experience.step_size
+        agent_index = experience.agent_index
+
+        latents = experience.latents
+        old_values = experience.values
+        rewards = experience.rewards
+
+        discrete_actions, continuous_actions = experience.actions
+        discrete_log_probs, continuous_log_probs = experience.log_probs
+
+        # handle empties
+
+        empty_tensor = torch.empty_like(rewards)
+
+        has_discrete = exists(discrete_actions)
+        has_continuous = exists(continuous_actions)
+
+        discrete_actions = default(discrete_actions, empty_tensor)
+        continuous_actions = default(continuous_actions, empty_tensor)
+
+        discrete_log_probs = default(discrete_log_probs, empty_tensor)
+        continuous_log_probs = default(continuous_log_probs, empty_tensor)
+
+        # create the dataset and dataloader
+
+        dataset = TensorDataset(
+            latents,
+            discrete_actions,
+            continuous_actions,
+            discrete_log_probs,
+            continuous_log_probs,
+            old_values,
+            rewards
+        )
+
+        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
+
+        for epoch in range(self.epochs):
+
+            for (
+                latents,
+                discrete_actions,
+                continuous_actions,
+                discrete_log_probs,
+                continuous_log_probs,
+                old_values,
+                rewards
+            ) in dataloader:
+
+                actions = (
+                    discrete_actions if has_discrete else None,
+                    continuous_actions if has_continuous else None
+                )
+
+                log_probs = (
+                    discrete_log_probs if has_discrete else None,
+                    continuous_log_probs if has_continuous else None
+                )
+
+                batch_experience = Experience(
+                    latents = latents,
+                    actions = actions,
+                    log_probs = log_probs,
+                    values = old_values,
+                    rewards = rewards,
+                    step_size = step_size,
+                    agent_index = agent_index
+                )
+
+                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
+
+                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+                # update policy head
+
+                self.accelerator.backward(policy_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
+
+                self.policy_head_optim.step()
+                self.policy_head_optim.zero_grad()
+
+                # update value head
+
+                self.accelerator.backward(value_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+                self.value_head_optim.step()
+                self.value_head_optim.zero_grad()
+
+        self.print('training complete')
+
+    def forward(
+        self,
+        env,
+        num_episodes = 50000,
+        max_experiences_before_learn = 8,
+        env_is_vectorized = False
+    ):
+
+        for _ in range(num_episodes):
+
+            total_experience = 0
+            experiences = []
+
+            while total_experience < max_experiences_before_learn:
+
+                experience = self.unwrapped_model.interact_with_env(env, max_timesteps = self.generate_timesteps, env_is_vectorized = env_is_vectorized)
+
+                num_experience = experience.video.shape[0]
+
+                total_experience += num_experience
+
+                experiences.append(experience)
+
+            combined_experiences = combine_experiences(experiences)
+
+            self.learn(combined_experiences)
+
+            experiences.clear()
+
+        self.print('training complete')
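
learn() calls exists and default helpers that do not appear in this hunk; they presumably already live near the top of dreamer4.trainers. For reference, definitions consistent with how they are used above would be the following (an assumption, not part of this commit):

    # assumed helper definitions - consistent with their usage in learn(), not shown in this diff
    def exists(v):
        return v is not None

    def default(v, d):
        return v if exists(v) else d
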
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.67"
+version = "0.0.69"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -641,6 +641,8 @@ def test_online_rl(
 
     mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
 
+    # manually
+
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
 
@@ -650,3 +652,15 @@ def test_online_rl(
 
     actor_loss.backward()
     critic_loss.backward()
+
+    # with trainer
+
+    from dreamer4.trainers import SimTrainer
+
+    trainer = SimTrainer(
+        world_model_and_policy,
+        batch_size = 4,
+        cpu = True
+    )
+
+    trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
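
The test drives the trainer through its env-facing forward(); the same trainer can also learn directly from manually collected experiences, roughly like this (a sketch reusing the objects from the test above, not part of the commit):

    # sketch: feed manually collected experiences straight into SimTrainer.learn
    from dreamer4.dreamer4 import combine_experiences

    combined = combine_experiences([one_experience, another_experience])
    trainer.learn(combined)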