sketch of training from sim env
This commit is contained in:
parent
27ac05efb0
commit
8526347316
@ -4,7 +4,7 @@ import torch
|
|||||||
from torch import is_tensor
|
from torch import is_tensor
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, TensorDataset, DataLoader
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
@ -12,7 +12,9 @@ from adam_atan2_pytorch import MuonAdamAtan2
|
|||||||
|
|
||||||
from dreamer4.dreamer4 import (
|
from dreamer4.dreamer4 import (
|
||||||
VideoTokenizer,
|
VideoTokenizer,
|
||||||
DynamicsWorldModel
|
DynamicsWorldModel,
|
||||||
|
Experience,
|
||||||
|
combine_experiences
|
||||||
)
|
)
|
||||||
|
|
||||||
from ema_pytorch import EMA
|
from ema_pytorch import EMA
|
||||||
@ -317,3 +319,196 @@ class DreamTrainer(Module):
|
|||||||
self.value_head_optim.zero_grad()
|
self.value_head_optim.zero_grad()
|
||||||
|
|
||||||
self.print('training complete')
|
self.print('training complete')
|
||||||
|
|
||||||
|
# training from sim
|
||||||
|
|
||||||
|
class SimTrainer(Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: DynamicsWorldModel,
|
||||||
|
optim_klass = AdamW,
|
||||||
|
batch_size = 16,
|
||||||
|
generate_timesteps = 16,
|
||||||
|
learning_rate = 3e-4,
|
||||||
|
max_grad_norm = None,
|
||||||
|
epochs = 2,
|
||||||
|
weight_decay = 0.,
|
||||||
|
accelerate_kwargs: dict = dict(),
|
||||||
|
optim_kwargs: dict = dict(),
|
||||||
|
cpu = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.accelerator = Accelerator(
|
||||||
|
cpu = cpu,
|
||||||
|
**accelerate_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
optim_kwargs = dict(
|
||||||
|
lr = learning_rate,
|
||||||
|
weight_decay = weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
|
||||||
|
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
|
||||||
|
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
self.epochs = epochs
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
self.generate_timesteps = generate_timesteps
|
||||||
|
|
||||||
|
self.unwrapped_model = self.model
|
||||||
|
|
||||||
|
(
|
||||||
|
self.model,
|
||||||
|
self.policy_head_optim,
|
||||||
|
self.value_head_optim,
|
||||||
|
) = self.accelerator.prepare(
|
||||||
|
self.model,
|
||||||
|
self.policy_head_optim,
|
||||||
|
self.value_head_optim
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.accelerator.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unwrapped_model(self):
|
||||||
|
return self.accelerator.unwrap_model(self.model)
|
||||||
|
|
||||||
|
def print(self, *args, **kwargs):
|
||||||
|
return self.accelerator.print(*args, **kwargs)
|
||||||
|
|
||||||
|
def learn(
|
||||||
|
self,
|
||||||
|
experience: Experience
|
||||||
|
):
|
||||||
|
|
||||||
|
step_size = experience.step_size
|
||||||
|
agent_index = experience.agent_index
|
||||||
|
|
||||||
|
latents = experience.latents
|
||||||
|
old_values = experience.values
|
||||||
|
rewards = experience.rewards
|
||||||
|
|
||||||
|
discrete_actions, continuous_actions = experience.actions
|
||||||
|
discrete_log_probs, continuous_log_probs = experience.log_probs
|
||||||
|
|
||||||
|
# handle empties
|
||||||
|
|
||||||
|
empty_tensor = torch.empty_like(rewards)
|
||||||
|
|
||||||
|
has_discrete = exists(discrete_actions)
|
||||||
|
has_continuous = exists(continuous_actions)
|
||||||
|
|
||||||
|
discrete_actions = default(discrete_actions, empty_tensor)
|
||||||
|
continuous_actions = default(continuous_actions, empty_tensor)
|
||||||
|
|
||||||
|
discrete_log_probs = default(discrete_log_probs, empty_tensor)
|
||||||
|
continuous_log_probs = default(continuous_log_probs, empty_tensor)
|
||||||
|
|
||||||
|
# create the dataset and dataloader
|
||||||
|
|
||||||
|
dataset = TensorDataset(
|
||||||
|
latents,
|
||||||
|
discrete_actions,
|
||||||
|
continuous_actions,
|
||||||
|
discrete_log_probs,
|
||||||
|
continuous_log_probs,
|
||||||
|
old_values,
|
||||||
|
rewards
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
|
||||||
|
|
||||||
|
for epoch in range(self.epochs):
|
||||||
|
|
||||||
|
for (
|
||||||
|
latents,
|
||||||
|
discrete_actions,
|
||||||
|
continuous_actions,
|
||||||
|
discrete_log_probs,
|
||||||
|
continuous_log_probs,
|
||||||
|
old_values,
|
||||||
|
rewards
|
||||||
|
) in dataloader:
|
||||||
|
|
||||||
|
actions = (
|
||||||
|
discrete_actions if has_discrete else None,
|
||||||
|
continuous_actions if has_continuous else None
|
||||||
|
)
|
||||||
|
|
||||||
|
log_probs = (
|
||||||
|
discrete_log_probs if has_discrete else None,
|
||||||
|
continuous_log_probs if has_continuous else None
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_experience = Experience(
|
||||||
|
latents = latents,
|
||||||
|
actions = actions,
|
||||||
|
log_probs = log_probs,
|
||||||
|
values = old_values,
|
||||||
|
rewards = rewards,
|
||||||
|
step_size = step_size,
|
||||||
|
agent_index = agent_index
|
||||||
|
)
|
||||||
|
|
||||||
|
policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
|
||||||
|
|
||||||
|
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
|
||||||
|
|
||||||
|
# update policy head
|
||||||
|
|
||||||
|
self.accelerator.backward(policy_head_loss)
|
||||||
|
|
||||||
|
if exists(self.max_grad_norm):
|
||||||
|
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
|
||||||
|
|
||||||
|
self.policy_head_optim.step()
|
||||||
|
self.policy_head_optim.zero_grad()
|
||||||
|
|
||||||
|
# update value head
|
||||||
|
|
||||||
|
self.accelerator.backward(value_head_loss)
|
||||||
|
|
||||||
|
if exists(self.max_grad_norm):
|
||||||
|
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
self.value_head_optim.step()
|
||||||
|
self.value_head_optim.zero_grad()
|
||||||
|
|
||||||
|
self.print('training complete')
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
env,
|
||||||
|
num_episodes = 50000,
|
||||||
|
max_experiences_before_learn = 8,
|
||||||
|
env_is_vectorized = False
|
||||||
|
):
|
||||||
|
|
||||||
|
for _ in range(num_episodes):
|
||||||
|
|
||||||
|
total_experience = 0
|
||||||
|
experiences = []
|
||||||
|
|
||||||
|
while total_experience < max_experiences_before_learn:
|
||||||
|
|
||||||
|
experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
|
||||||
|
|
||||||
|
num_experience = experience.video.shape[0]
|
||||||
|
|
||||||
|
total_experience += num_experience
|
||||||
|
|
||||||
|
experiences.append(experience)
|
||||||
|
|
||||||
|
experiences = combine_experiences(experiences)
|
||||||
|
|
||||||
|
self.learn(experiences)
|
||||||
|
|
||||||
|
self.print('training complete')
|
||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "dreamer4"
|
name = "dreamer4"
|
||||||
version = "0.0.67"
|
version = "0.0.68"
|
||||||
description = "Dreamer 4"
|
description = "Dreamer 4"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
||||||
|
|||||||
@ -641,6 +641,8 @@ def test_online_rl(
|
|||||||
|
|
||||||
mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
|
mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
|
||||||
|
|
||||||
|
# manually
|
||||||
|
|
||||||
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
|
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
|
||||||
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
|
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
|
||||||
|
|
||||||
@ -650,3 +652,15 @@ def test_online_rl(
|
|||||||
|
|
||||||
actor_loss.backward()
|
actor_loss.backward()
|
||||||
critic_loss.backward()
|
critic_loss.backward()
|
||||||
|
|
||||||
|
# with trainer
|
||||||
|
|
||||||
|
from dreamer4.trainers import SimTrainer
|
||||||
|
|
||||||
|
trainer = SimTrainer(
|
||||||
|
world_model_and_policy,
|
||||||
|
batch_size = 4,
|
||||||
|
cpu = True
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user