first pass at gathering experience with a mock env for online rl

lucidrains 2025-10-22 08:32:46 -07:00
parent 03b16a48f2
commit d82debb7a6
5 changed files with 146 additions and 29 deletions
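
Taken together, the commit wires up a first online RL loop: DynamicsWorldModel.interact_with_env rolls the current policy in an environment and packages the result as an Experience, which learn_from_experience then consumes for a PPO-style update of the policy and value heads. A minimal sketch of the loop, lifted almost verbatim from the test added in this commit (the tiny hyperparameters are only there to keep it cheap):

from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv

# tiny tokenizer and world model, sized down for illustration (values taken from the new test)

tokenizer = VideoTokenizer(16, encoder_depth = 1, decoder_depth = 1, time_block_every = 1, dim_latent = 16, patch_size = 32, attn_dim_head = 16, num_latent_tokens = 1)

world_model = DynamicsWorldModel(video_tokenizer = tokenizer, dim = 16, dim_latent = 16, max_steps = 64, num_tasks = 4, num_latent_tokens = 1, depth = 1, time_block_every = 1, num_spatial_tokens = 1, pred_orig_latent = True, num_discrete_actions = 4, attn_dim_head = 16, prob_no_shortcut_train = 0.1, num_residual_streams = 1)

env = MockEnv((256, 256), vectorized = False, num_envs = 4)

# 1. roll out the policy in the environment, collecting video, actions, log probs, values and rewards

experience = world_model.interact_with_env(env, max_timesteps = 16)

# 2. one PPO step on the policy / value heads from that experience

actor_loss, critic_loss = world_model.learn_from_experience(experience)
actor_loss.backward()
critic_loss.backward()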

dreamer4/__init__.py

@@ -1,7 +1,6 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel,
-    Dreamer
+    DynamicsWorldModel
 )

dreamer4/dreamer4.py

@@ -81,7 +81,6 @@ class Experience:
     step_size: int | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
-    is_batched: bool = True

     # helpers
@@ -1964,6 +1963,107 @@ class DynamicsWorldModel(Module):
         return align_dims_left(times, align_dims_left_to)

+    # interacting with env for experience
+
+    @torch.no_grad()
+    def interact_with_env(
+        self,
+        env,
+        seed = None,
+        agent_index = 0,
+        step_size = 4,
+        max_timesteps = 16
+    ):
+        assert exists(self.video_tokenizer)
+
+        init_frame = env.reset()
+
+        video = rearrange(init_frame, 'c vh vw -> 1 c 1 vh vw')
+
+        # accumulate
+
+        rewards = None
+        discrete_actions = None
+        continuous_actions = None
+        discrete_log_probs = None
+        continuous_log_probs = None
+        values = None
+        latents = None
+
+        for _ in range(max_timesteps):
+
+            latents = self.video_tokenizer(video, return_latents = True)
+
+            _, (agent_embed, _) = self.forward(
+                latents = latents,
+                signal_levels = self.max_steps - 1,
+                step_sizes = step_size,
+                rewards = rewards,
+                discrete_actions = discrete_actions,
+                continuous_actions = continuous_actions,
+                latent_is_noised = True,
+                return_pred_only = True,
+                return_intermediates = True
+            )
+
+            one_agent_embed = agent_embed[..., -1:, agent_index, :]
+
+            policy_embed = self.policy_head(one_agent_embed)
+
+            # sample actions
+
+            sampled_discrete_actions, sampled_continuous_actions = self.action_embedder.sample(policy_embed, pred_head_index = 0, squeeze = True)
+
+            discrete_actions = safe_cat((discrete_actions, sampled_discrete_actions), dim = 1)
+            continuous_actions = safe_cat((continuous_actions, sampled_continuous_actions), dim = 1)
+
+            # get the log prob and values for policy optimization
+
+            one_discrete_log_probs, one_continuous_log_probs = self.action_embedder.log_probs(
+                policy_embed,
+                pred_head_index = 0,
+                discrete_targets = sampled_discrete_actions,
+                continuous_targets = sampled_continuous_actions,
+            )
+
+            discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
+            continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # pass the sampled action to the environment and get back next state and reward
+
+            next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
+
+            # batch and time dimension
+
+            next_frame = rearrange(next_frame, 'c vh vw -> 1 c 1 vh vw')
+            reward = rearrange(reward, ' -> 1 1')
+
+            # concat
+
+            video = cat((video, next_frame), dim = 2)
+            rewards = safe_cat((rewards, reward), dim = 1)
+
+        # package up one experience for learning
+
+        one_experience = Experience(
+            latents = latents,
+            video = video[:, :, :-1],
+            rewards = rewards,
+            actions = (discrete_actions, continuous_actions),
+            log_probs = (discrete_log_probs, continuous_log_probs),
+            values = values,
+            step_size = step_size,
+            agent_index = agent_index,
+            is_from_world_model = False
+        )
+
+        return one_experience
+
     # ppo

     def learn_from_experience(
@@ -1973,7 +2073,6 @@ class DynamicsWorldModel(Module):
         value_optim: Optimizer | None = None,
         only_learn_policy_value_heads = True # in the paper, they do not finetune the entire dynamics model, they just learn the heads
     ):
-        assert experience.is_batched

         latents = experience.latents
         actions = experience.actions
@@ -2758,18 +2857,3 @@ class DynamicsWorldModel(Module):
             return total_loss

         return total_loss, losses
-
-# dreamer
-
-class Dreamer(Module):
-    def __init__(
-        self,
-        state_tokenizer: VideoTokenizer,
-        world_model: DynamicsWorldModel,
-    ):
-        super().__init__()
-
-        self.state_toke = state_tokenizer
-        self.world_model = world_model
-
-    def interact_with_sim(self, env) -> Experience:
-        raise NotImplementedError
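
The only_learn_policy_value_heads = True default in the learn_from_experience hunk above reflects the paper's recipe: the dynamics backbone stays frozen during online RL and only the policy and value heads are trained. The actual gating is not shown in this diff; a hypothetical sketch of the idea, with a made-up helper name, might look like the following (illustration only, not the repository's implementation):

from torch.optim import Adam

def build_rl_optimizers(world_model, only_learn_policy_value_heads = True, lr = 3e-4):
    # policy_head and value_head are attributes of DynamicsWorldModel, per the diff above
    if only_learn_policy_value_heads:
        policy_params = list(world_model.policy_head.parameters())
        value_params = list(world_model.value_head.parameters())
    else:
        # finetune everything, including the dynamics transformer
        policy_params = list(world_model.parameters())
        value_params = list(world_model.parameters())

    return Adam(policy_params, lr = lr), Adam(value_params, lr = lr)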

dreamer4/mocks.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 from random import choice

 import torch
-from torch import tensor, randn, randint
+from torch import tensor, empty, randn, randint
 from torch.nn import Module

 from einops import repeat
@@ -13,7 +13,7 @@ class MockEnv(Module):
     def __init__(
         self,
         image_shape,
-        reward_range = (-100., 100.),
+        reward_range = (-100, 100),
         num_envs = 1,
         vectorized = False
     ):
@@ -40,7 +40,8 @@
         actions,
     ):
         state = self.get_random_state()
-        reward = randint(*self.reward_range, ()).float()
+
+        reward = empty(()).uniform_(*self.reward_range)

         if not self.vectorized:
             return state, reward
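
Reading interact_with_env above, the environment surface it assumes (in the non-vectorized case) is small: reset() returns one frame shaped (channels, height, width), and step((discrete_actions, continuous_actions)) returns the next frame plus a scalar reward tensor, which is exactly what MockEnv now provides. A hypothetical environment satisfying that contract (class name and channel count are made up for illustration):

from torch import randn, empty

class RandomPixelsEnv:
    # stand-in environment; any object with this reset / step surface should slot into interact_with_env
    def __init__(self, image_size = 256, channels = 3):
        # channels must match what the video tokenizer expects (3 for RGB here)
        self.shape = (channels, image_size, image_size)

    def reset(self):
        # single frame, (channels, height, width)
        return randn(self.shape)

    def step(self, actions):
        discrete_actions, continuous_actions = actions
        next_frame = randn(self.shape)
        reward = empty(()).uniform_(-1., 1.)  # scalar reward tensor
        return next_frame, reward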

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.60"
+version = "0.0.61"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_dreamer4.py

@@ -600,10 +600,43 @@ def test_cache_generate():
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
     generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

-@param('vectorized', (False, True))
-def test_online_rl(
-    vectorized
-):
-    from dreamer4.mocks import MockEnv
+def test_online_rl():
+    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

-    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+    tokenizer = VideoTokenizer(
+        16,
+        encoder_depth = 1,
+        decoder_depth = 1,
+        time_block_every = 1,
+        dim_latent = 16,
+        patch_size = 32,
+        attn_dim_head = 16,
+        num_latent_tokens = 1
+    )
+
+    world_model_and_policy = DynamicsWorldModel(
+        video_tokenizer = tokenizer,
+        dim = 16,
+        dim_latent = 16,
+        max_steps = 64,
+        num_tasks = 4,
+        num_latent_tokens = 1,
+        depth = 1,
+        time_block_every = 1,
+        num_spatial_tokens = 1,
+        pred_orig_latent = True,
+        num_discrete_actions = 4,
+        attn_dim_head = 16,
+        prob_no_shortcut_train = 0.1,
+        num_residual_streams = 1
+    )
+
+    from dreamer4.mocks import MockEnv
+    mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
+
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
+
+    actor_loss.backward()
+    critic_loss.backward()