first pass through gathering experience with a mock env for online rl

lucidrains 2025-10-22 08:32:46 -07:00
parent 03b16a48f2
commit d82debb7a6
5 changed files with 146 additions and 29 deletions

View File

@@ -1,7 +1,6 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel,
-    Dreamer
+    DynamicsWorldModel
 )

View File

@@ -81,7 +81,6 @@ class Experience:
     step_size: int | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
-    is_batched: bool = True

     # helpers
@@ -1964,6 +1963,107 @@ class DynamicsWorldModel(Module):
         return align_dims_left(times, align_dims_left_to)

+    # interacting with env for experience
+
+    @torch.no_grad()
+    def interact_with_env(
+        self,
+        env,
+        seed = None,
+        agent_index = 0,
+        step_size = 4,
+        max_timesteps = 16
+    ):
+        assert exists(self.video_tokenizer)
+
+        init_frame = env.reset()
+
+        video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
+
+        # accumulate
+
+        rewards = None
+        discrete_actions = None
+        continuous_actions = None
+        discrete_log_probs = None
+        continuous_log_probs = None
+        values = None
+        latents = None
+
+        for _ in range(max_timesteps):
+
+            latents = self.video_tokenizer(video, return_latents = True)
+
+            _, (agent_embed, _) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )
+
+            one_agent_embed = agent_embed[..., -1:, agent_index, :]
+
+            policy_embed = self.policy_head(one_agent_embed)
+
+            # sample actions
+
+            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
+
+            discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
+            continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)
+
+            # get the log prob and values for policy optimization
+
+            one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                pred_head_index = 0,
+                discrete_targets = sampled_discrete_actions,
+                continuous_targets = sampled_continuous_actions,
+            )
+
+            discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
+            continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # pass the sampled action to the environment and get back next state and reward
+
+            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            # batch and time dimension
+
+            next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
+            reward = rearrange(reward, ' -> 1 1')
+
+            # concat
+
+            video = cat((video, next_frame), dim = 2)
+            rewards = safe_cat((rewards, reward), dim = 1)
+
+        # package up one experience for learning
+
+        one_experience = Experience(
+            latents = latents,
+            video = video[:, :, :-1],
+            rewards = rewards,
+            actions = (discrete_actions, continuous_actions),
+            log_probs = (discrete_log_probs, continuous_log_probs),
+            values = values,
+            step_size = step_size,
+            agent_index = agent_index,
+            is_from_world_model = False
+        )
+
+        return one_experience
+
     # ppo

     def learn_from_experience(
@@ -1973,7 +2073,6 @@ class DynamicsWorldModel(Module):
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
     ):
-        assert experience.is_batched

         latents = experience.latents
         actions = experience.actions
@@ -2758,18 +2857,3 @@ class DynamicsWorldModel(Module):
             return total_loss

         return total_loss, losses
-
-# dreamer
-
-class Dreamer(Module):
-    def __init__(
-        self,
-        state_tokenizer: VideoTokenizer,
-        world_model: DynamicsWorldModel,
-    ):
-        super().__init__()
-        self.state_toke = state_tokenizer
-        self.world_model = world_model
-
-    def interact_with_sim(self, env) -> Experience:
-        raise NotImplementedError
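
The new interact_with_env only touches the environment through reset and step: reset() must return a single (channels, height, width) frame, and step(actions) must return (next_frame, reward) with a scalar reward tensor, where actions is the (discrete_actions, continuous_actions) tuple sampled by the policy. A minimal sketch of a compatible non-vectorized env follows; the class name and the 3-channel frames are assumptions for illustration, not taken from this diff (MockEnv in dreamer4.mocks plays this role in the test below).

from torch import randn, empty

class RandomPixelEnv:
    # hypothetical stand-in env with the interface interact_with_env expects

    def __init__(self, image_shape = (256, 256)):
        self.image_shape = image_shape

    def reset(self):
        # single initial frame, shape (channels, height, width)
        return randn(3, *self.image_shape)

    def step(self, actions):
        # actions is the (discrete_actions, continuous_actions) tuple from the policy
        next_frame = randn(3, *self.image_shape)
        reward = empty(()).uniform_(-100., 100.)  # scalar reward tensor
        return next_frame, reward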

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
 from random import choice

 import torch
-from torch import tensor, randn, randint
+from torch import tensor, empty, randn, randint
 from torch.nn import Module

 from einops import repeat
@@ -13,7 +13,7 @@ class MockEnv(Module):
     def __init__(
         self,
         image_shape,
-        reward_range = (-100., 100.),
+        reward_range = (-100, 100),
         num_envs = 1,
         vectorized = False
     ):
@@ -40,7 +40,8 @@ class MockEnv(Module):
         actions,
     ):
         state = self.get_random_state()

-        reward = randint(*self.reward_range, ()).float()
+        reward = empty(()).uniform_(*self.reward_range)

         if not self.vectorized:
             return state, reward
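
The mock reward is now drawn uniformly from reward_range as a float scalar, rather than via integer sampling. A quick standalone check of what the new sampling produces, under default dtypes (snippet for illustration, not part of the diff):

import torch

reward = torch.empty(()).uniform_(-100, 100)
print(reward.shape, reward.dtype)  # torch.Size([]) torch.float32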

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.60"
+version = "0.0.61"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@@ -600,10 +600,43 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

-@param('vectorized', (False, True))
-def test_online_rl(
-    vectorized
-):
-    from dreamer4.mocks import MockEnv
-
-    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+def test_online_rl():
+    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
+
+    tokenizer = VideoTokenizer(
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        time_block_every = 1,
+        dim_latent = 16,
+        patch_size = 32,
+        attn_dim_head = 16,
+        num_latent_tokens = 1
+    )
+
+    world_model_and_policy = DynamicsWorldModel(
+        video_tokenizer = tokenizer,
+        dim = 16,
+        dim_latent = 16,
+        max_steps = 64,
+        num_tasks = 4,
+        num_latent_tokens = 1,
+        depth = 1,
+        time_block_every = 1,
+        num_spatial_tokens = 1,
+        pred_orig_latent = True,
+        num_discrete_actions = 4,
+        attn_dim_head = 16,
+        prob_no_shortcut_train = 0.1,
+        num_residual_streams = 1
+    )
+
+    from dreamer4.mocks import MockEnv
+
+    mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
+
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
+
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
+
+    actor_loss.backward()
+    critic_loss.backward()
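
The test stops at a single gradient check. A rough sketch of how the same objects could be strung into an online RL loop follows, continuing world_model_and_policy and mock_env from the test above. The optimizer choice, learning rates, and the decision to step only the policy and value heads are assumptions (the default only_learn_policy_value_heads = True in learn_from_experience suggests the heads are what get trained), not a confirmed recipe from this commit.

from torch.optim import Adam

# hypothetical optimizers over just the policy / value heads (assumption)
policy_optim = Adam(world_model_and_policy.policy_head.parameters(), lr = 1e-4)
value_optim = Adam(world_model_and_policy.value_head.parameters(), lr = 1e-4)

for _ in range(8):
    # gather fresh experience from the (mock) environment
    experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)

    # ppo-style update from that experience
    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(experience)

    actor_loss.backward()
    critic_loss.backward()

    policy_optim.step()
    policy_optim.zero_grad()

    value_optim.step()
    value_optim.zero_grad()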