first pass through gathering experience with a mock env for online rl

parent 03b16a48f2
commit d82debb7a6
@@ -1,7 +1,6 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel,
-    Dreamer
+    DynamicsWorldModel
 )


@@ -81,7 +81,6 @@ class Experience:
     step_size: int | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
-    is_batched: bool = True

 # helpers

@@ -1964,6 +1963,107 @@ class DynamicsWorldModel(Module):

         return align_dims_left(times, align_dims_left_to)

+    # interacting with env for experience
+
+    @torch.no_grad()
+    def interact_with_env(
+        self,
+        env,
+        seed = None,
+        agent_index = 0,
+        step_size = 4,
+        max_timesteps = 16
+    ):
+        assert exists(self.video_tokenizer)
+
+        init_frame = env.reset()
+
+        video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
+
+        # accumulate
+
+        rewards = None
+        discrete_actions = None
+        continuous_actions = None
+        discrete_log_probs = None
+        continuous_log_probs = None
+        values = None
+        latents = None
+
+        for _ in range(max_timesteps):
+
+            latents = self.video_tokenizer(video, return_latents = True)
+
+            _, (agent_embed, _) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )
+
+            one_agent_embed = agent_embed[..., -1:, agent_index, :]
+
+            policy_embed = self.policy_head(one_agent_embed)
+
+            # sample actions
+
+            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
+
+            discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
+            continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)
+
+            # get the log prob and values for policy optimization
+
+            one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                pred_head_index = 0,
+                discrete_targets = sampled_discrete_actions,
+                continuous_targets = sampled_continuous_actions,
+            )
+
+            discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
+            continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # pass the sampled action to the environment and get back next state and reward
+
+            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            # batch and time dimension
+
+            next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
+            reward = rearrange(reward, ' -> 1 1')
+
+            # concat
+
+            video = cat((video, next_frame), dim = 2)
+            rewards = safe_cat((rewards, reward), dim = 1)
+
+        # package up one experience for learning
+
+        one_experience = Experience(
+            latents = latents,
+            video = video[:, :, :-1],
+            rewards = rewards,
+            actions = (discrete_actions, continuous_actions),
+            log_probs = (discrete_log_probs, continuous_log_probs),
+            values = values,
+            step_size = step_size,
+            agent_index = agent_index,
+            is_from_world_model = False
+        )
+
+        return one_experience
+
     # ppo

     def learn_from_experience(
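Note: the accumulation in interact_with_env leans on a safe_cat helper that is not shown in this commit. Judging only from how it is called (each accumulator starts out as None and grows along the time dimension), a None-tolerant concat along the lines of the sketch below would be consistent; this is an assumption about the helper, not its actual definition in the repository.

# illustrative sketch of the assumed safe_cat behavior: drop None entries,
# return None if nothing remains, otherwise concatenate along dim
import torch

def safe_cat(tensors, dim = 0):
    tensors = [t for t in tensors if t is not None]

    if len(tensors) == 0:
        return None

    return torch.cat(tensors, dim = dim)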
@@ -1973,7 +2073,6 @@ class DynamicsWorldModel(Module):
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
     ):
-        assert experience.is_batched

         latents = experience.latents
         actions = experience.actions
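Note: the new test further below drives learn_from_experience manually, simply calling .backward() on the two returned losses. A hedged sketch of a full update step is shown here; the Adam optimizer, learning rate, and parameter selection are illustrative assumptions (only the value_optim keyword is visible in this hunk), and the sketch reuses the world_model_and_policy and mock_env objects constructed in the test.

# hypothetical update step, assuming policy_head and value_head are submodules
# whose parameters are the only ones being learned (only_learn_policy_value_heads = True)
from torch.optim import Adam

optim = Adam([
    *world_model_and_policy.policy_head.parameters(),
    *world_model_and_policy.value_head.parameters()
], lr = 3e-4)

one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)

actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)

(actor_loss + critic_loss).backward()

optim.step()
optim.zero_grad()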
@@ -2758,18 +2857,3 @@ class DynamicsWorldModel(Module):
             return total_loss

         return total_loss, losses
-
-# dreamer
-
-class Dreamer(Module):
-    def __init__(
-        self,
-        state_tokenizer: VideoTokenizer,
-        world_model: DynamicsWorldModel,
-    ):
-        super().__init__()
-        self.state_toke = state_tokenizer
-        self.world_model = world_model
-
-    def interact_with_sim(self, env) -> Experience:
-        raise NotImplementedError
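Note: before the MockEnv hunks below, this is roughly the non-vectorized environment contract that interact_with_env relies on: reset() returning a single (channels, height, width) frame, and step((discrete_actions, continuous_actions)) returning (next_frame, scalar_reward). The class below is only an illustration of that contract with an assumed channel count of 3; MockEnv in dreamer4.mocks is what the test actually uses.

# hypothetical stand-in for the env interface consumed by interact_with_env above
import torch

class MinimalEnv:
    def __init__(self, image_shape = (256, 256), channels = 3):
        self.channels = channels
        self.image_shape = image_shape

    def reset(self):
        # one frame, no batch or time dimension - the caller adds those with rearrange
        return torch.randn(self.channels, *self.image_shape)

    def step(self, actions):
        discrete_actions, continuous_actions = actions

        next_frame = torch.randn(self.channels, *self.image_shape)
        reward = torch.empty(()).uniform_(-100., 100.)  # scalar reward, reshaped to (1, 1) by the caller

        return next_frame, reward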
@@ -2,7 +2,7 @@ from __future__ import annotations
 from random import choice

 import torch
-from torch import tensor, randn, randint
+from torch import tensor, empty, randn, randint
 from torch.nn import Module

 from einops import repeat
@@ -13,7 +13,7 @@ class MockEnv(Module):
     def __init__(
         self,
         image_shape,
-        reward_range = (-100., 100.),
+        reward_range = (-100, 100),
         num_envs = 1,
         vectorized = False
     ):
@@ -40,7 +40,8 @@ class MockEnv(Module):
         actions,
     ):
         state = self.get_random_state()
-        reward = randint(*self.reward_range, ()).float()
+
+        reward = empty(()).uniform_(*self.reward_range)

         if not self.vectorized:
             return state, reward
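Note: torch.randint expects integer bounds, while Tensor.uniform_ draws a continuous scalar from any numeric range, which better matches a real-valued reward signal; the reward_range default is changed to plain integers alongside the switch. A minimal comparison of the two draws:

import torch

# previous mock reward: integer draw, cast to float
reward_old = torch.randint(-100, 100, ()).float()

# new mock reward: continuous scalar drawn uniformly from the range
reward_new = torch.empty(()).uniform_(-100, 100)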
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.60"
+version = "0.0.61"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -600,10 +600,43 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

-@param('vectorized', (False, True))
-def test_online_rl(
-    vectorized
-):
-    from dreamer4.mocks import MockEnv
+def test_online_rl():
+    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

-    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+    tokenizer = VideoTokenizer(
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        time_block_every = 1,
+        dim_latent = 16,
+        patch_size = 32,
+        attn_dim_head = 16,
+        num_latent_tokens = 1
+    )
+
+    world_model_and_policy = DynamicsWorldModel(
+        video_tokenizer = tokenizer,
+        dim = 16,
+        dim_latent = 16,
+        max_steps = 64,
+        num_tasks = 4,
+        num_latent_tokens = 1,
+        depth = 1,
+        time_block_every = 1,
+        num_spatial_tokens = 1,
+        pred_orig_latent = True,
+        num_discrete_actions = 4,
+        attn_dim_head = 16,
+        prob_no_shortcut_train = 0.1,
+        num_residual_streams = 1
+    )
+
+    from dreamer4.mocks import MockEnv
+    mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
+
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
+
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
+
+    actor_loss.backward()
+    critic_loss.backward()