first pass through gathering experience with a mock env for online rl
commit d82debb7a6
parent 03b16a48f2
@@ -1,7 +1,6 @@
from dreamer4.dreamer4 import (
    VideoTokenizer,
    DynamicsWorldModel,
    Dreamer
    DynamicsWorldModel
)

@@ -81,7 +81,6 @@ class Experience:
    step_size: int | None = None
    agent_index: int = 0
    is_from_world_model: bool = True
    is_batched: bool = True

    # helpers

@@ -1964,6 +1963,107 @@ class DynamicsWorldModel(Module):

        return align_dims_left(times, align_dims_left_to)

    # interacting with env for experience

    @torch.no_grad()
    def interact_with_env(
        self,
        env,
        seed = None,
        agent_index = 0,
        step_size = 4,
        max_timesteps = 16
    ):
        assert exists(self.video_tokenizer)

        init_frame = env.reset()

        video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')

        # accumulate

        rewards = None
        discrete_actions = None
        continuous_actions = None
        discrete_log_probs = None
        continuous_log_probs = None
        values = None
        latents = None

        for _ in range(max_timesteps):

            latents = self.video_tokenizer(video, return_latents = True)

            _, (agent_embed, _) = self.forward(
                latents = latents,
                signal_levels = self.max_steps - 1,
                step_sizes = step_size,
                rewards = rewards,
                discrete_actions = discrete_actions,
                continuous_actions = continuous_actions,
                latent_is_noised = True,
                return_pred_only = True,
                return_intermediates = True
            )

            one_agent_embed = agent_embed[..., -1:, agent_index, :]

            policy_embed = self.policy_head(one_agent_embed)

            # sample actions

            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)

            discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
            continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)

            # get the log prob and values for policy optimization

            one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
                policy_embed,
                pred_head_index = 0,
                discrete_targets = sampled_discrete_actions,
                continuous_targets = sampled_continuous_actions,
            )

            discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
            continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

            value_bins = self.value_head(one_agent_embed)
            value = self.reward_encoder.bins_to_scalar_value(value_bins)

            values = safe_cat((values, value), dim = 1)

            # pass the sampled action to the environment and get back next state and reward

            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))

            # batch and time dimension

            next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
            reward = rearrange(reward, ' -> 1 1')

            # concat

            video = cat((video, next_frame), dim = 2)
            rewards = safe_cat((rewards, reward), dim = 1)

        # package up one experience for learning

        one_experience = Experience(
            latents = latents,
            video = video[:, :, :-1],
            rewards = rewards,
            actions = (discrete_actions, continuous_actions),
            log_probs = (discrete_log_probs, continuous_log_probs),
            values = values,
            step_size = step_size,
            agent_index = agent_index,
            is_from_world_model = False
        )

        return one_experience

    # ppo

    def learn_from_experience(
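The accumulation pattern above (accumulators initialized to None, then grown with safe_cat along the time dimension) relies on the repo's safe_cat helper, which is not part of this diff. A minimal sketch of the assumed behavior, where None entries are simply skipped:

import torch
from torch import cat

# assumed behavior of the repo's safe_cat helper (not shown in this diff):
# concatenate along `dim`, dropping any entries that are still None

def safe_cat(tensors, dim = 1):
    tensors = [t for t in tensors if t is not None]

    if len(tensors) == 0:
        return None

    if len(tensors) == 1:
        return tensors[0]

    return cat(tensors, dim = dim)

# usage mirroring the loop above - the first call just returns the new tensor

rewards = None
rewards = safe_cat((rewards, torch.ones(1, 1)), dim = 1)   # shape (1, 1)
rewards = safe_cat((rewards, torch.zeros(1, 1)), dim = 1)  # shape (1, 2)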
@@ -1973,7 +2073,6 @@ class DynamicsWorldModel(Module):
        value_optim: Optimizer | None = None,
        only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
    ):
        assert experience.is_batched

        latents = experience.latents
        actions = experience.actions

@@ -2758,18 +2857,3 @@ class DynamicsWorldModel(Module):
            return total_loss

        return total_loss, losses

# dreamer

class Dreamer(Module):
    def __init__(
        self,
        state_tokenizer: VideoTokenizer,
        world_model: DynamicsWorldModel,
    ):
        super().__init__()
        self.state_toke = state_tokenizer
        self.world_model = world_model

    def interact_with_sim(self, env) -> Experience:
        raise NotImplementedError
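The new interact_with_env only touches the environment through env.reset() and env.step(...), so any non-vectorized env obeying the shape contract used above should work. The class below is an inferred sketch of that contract, not an interface defined anywhere in this diff:

from torch import Tensor

class EnvInterface:
    """
    shape contract inferred from interact_with_env (an assumption, not from the diff):
    - reset() returns a single frame of shape (channels, height, width)
    - step(actions) accepts a (discrete, continuous) action tuple and returns
      (next_frame, reward), where next_frame is (channels, height, width) and
      reward is a scalar tensor, so it can be rearranged to shape (1, 1)
    """

    def reset(self) -> Tensor:
        raise NotImplementedError

    def step(self, actions) -> tuple[Tensor, Tensor]:
        raise NotImplementedError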
@@ -2,7 +2,7 @@ from __future__ import annotations
from random import choice

import torch
from torch import tensor, randn, randint
from torch import tensor, empty, randn, randint
from torch.nn import Module

from einops import repeat

@@ -13,7 +13,7 @@ class MockEnv(Module):
    def __init__(
        self,
        image_shape,
        reward_range = (-100., 100.),
        reward_range = (-100, 100),
        num_envs = 1,
        vectorized = False
    ):

@@ -40,7 +40,8 @@ class MockEnv(Module):
        actions,
    ):
        state = self.get_random_state()
        reward = randint(*self.reward_range, ()).float()

        reward = empty(()).uniform_(*self.reward_range)

        if not self.vectorized:
            return state, reward
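For reference, the MockEnv change swaps an integer reward draw for a uniform continuous draw over reward_range. A quick side-by-side of the two calls, using the same bounds as above:

from torch import randint, empty

# old style: integer-valued rewards, randint requires integer bounds
int_reward = randint(-100, 100, ()).float()

# new style: scalar reward drawn uniformly from the range as a float
float_reward = empty(()).uniform_(-100., 100.)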
@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.60"
version = "0.0.61"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -600,10 +600,43 @@ def test_cache_generate():
    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

@param('vectorized', (False, True))
def test_online_rl(
    vectorized
):
    from dreamer4.mocks import MockEnv
def test_online_rl():
    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 1
    )

    world_model_and_policy = DynamicsWorldModel(
        video_tokenizer = tokenizer,
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    from dreamer4.mocks import MockEnv
    mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)

    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)

    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)

    actor_loss.backward()
    critic_loss.backward()
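Taken together, the pieces of this commit suggest a simple online RL loop: roll out the policy with interact_with_env, then update from the returned Experience with learn_from_experience. The sketch below recombines the test above into such a loop; the Adam optimizer and the step/zero_grad bookkeeping are assumptions for illustration and are not part of this diff.

from torch.optim import Adam

from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
from dreamer4.mocks import MockEnv

# tokenizer and world model configured exactly as in test_online_rl above

tokenizer = VideoTokenizer(
    16,
    encoder_depth = 1,
    decoder_depth = 1,
    time_block_every = 1,
    dim_latent = 16,
    patch_size = 32,
    attn_dim_head = 16,
    num_latent_tokens = 1
)

world_model_and_policy = DynamicsWorldModel(
    video_tokenizer = tokenizer,
    dim = 16,
    dim_latent = 16,
    max_steps = 64,
    num_tasks = 4,
    num_latent_tokens = 1,
    depth = 1,
    time_block_every = 1,
    num_spatial_tokens = 1,
    pred_orig_latent = True,
    num_discrete_actions = 4,
    attn_dim_head = 16,
    prob_no_shortcut_train = 0.1,
    num_residual_streams = 1
)

mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)

# assumed optimizer setup, not part of this commit

optim = Adam(world_model_and_policy.parameters(), lr = 3e-4)

for _ in range(10):
    # gather one trajectory of experience from the (mock) environment

    experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)

    # ppo style losses for the policy and value heads, as in the test above

    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(experience)

    actor_loss.backward()
    critic_loss.backward()

    optim.step()
    optim.zero_grad()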