Compare commits

1 commit: main...0.0.64

Author: lucidrains | SHA1: e4ee4d905a | Message: handle vectorized env | Date: 2025-10-22 08:52:08 -07:00
8 changed files with 231 additions and 1788 deletions

GitHub Actions test workflow (.github/workflows)

@@ -5,11 +5,11 @@ jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 60
timeout-minutes: 20
strategy:
fail-fast: false
matrix:
group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
steps:
- uses: actions/checkout@v4
@@ -24,4 +24,4 @@ jobs:
python -m uv pip install -e .[test]
- name: Test with pytest
run: |
python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/
python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
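To reproduce a single CI shard locally, the same flags can be passed through `pytest.main`. This is a minimal sketch; it assumes the `--num-shards` / `--shard-id` flags used in the workflow come from a pytest sharding plugin such as pytest-shard installed via the `[test]` extra.

```python
# minimal sketch: run one shard of the suite locally, mirroring the CI matrix above.
# assumes a pytest sharding plugin (e.g. pytest-shard) provides --num-shards / --shard-id.
import pytest

shard_id = 3  # corresponds to ${{ matrix.group }} in the workflow

exit_code = pytest.main([
    "--num-shards", "10",
    "--shard-id", str(shard_id),
    "tests/",
])
```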

README.md

@@ -1,99 +1,10 @@
<img src="./dreamer4-fig2.png" width="400px"></img>
## Dreamer 4
## Dreamer 4 (wip)
Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
## Appreciation
- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
## Install
```bash
$ pip install dreamer4
```
## Usage
```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
# video tokenizer, learned through MAE + lpips
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256
)
video = torch.randn(2, 3, 10, 256, 256)
# learn the tokenizer
loss = tokenizer(video)
loss.backward()
# dynamics world model
world_model = DynamicsWorldModel(
dim = 512,
dim_latent = 32,
video_tokenizer = tokenizer,
num_discrete_actions = 4,
num_residual_streams = 1
)
# state, action, rewards
video = torch.randn(2, 3, 10, 256, 256)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
rewards = torch.randn(2, 10)
# learn dynamics / behavior cloned model
loss = world_model(
video = video,
rewards = rewards,
discrete_actions = discrete_actions
)
loss.backward()
# do the above with much data
# then generate dreams
dreams = world_model.generate(
10,
batch_size = 2,
return_decoded_video = True,
return_for_policy_optimization = True
)
# learn from the dreams
actor_loss, critic_loss = world_model.learn_from_experience(dreams)
(actor_loss + critic_loss).backward()
# learn from environment
from dreamer4.mocks import MockEnv
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```
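A minimal sketch of the same environment loop without vectorization, assuming the `MockEnv` defaults (`vectorized = False`, `num_envs = 1`) shown in `dreamer4/mocks.py` below; `world_model` is the model constructed above.

```python
# single (non-vectorized) environment, relying on MockEnv defaults
from dreamer4.mocks import MockEnv

single_env = MockEnv((256, 256))  # vectorized = False, num_envs = 1 by default

experience = world_model.interact_with_env(single_env, max_timesteps = 8, env_is_vectorized = False)

actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```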
[Temporary Discord](https://discord.gg/MkACrrkrYR)
## Citation
@@ -108,5 +19,3 @@ actor_loss, critic_loss = world_model.learn_from_experience(experience)
url = {https://arxiv.org/abs/2509.24527},
}
```
*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*

dreamer4/__init__.py

@@ -1,7 +1,6 @@
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsWorldModel,
AxialSpaceTimeTransformer
DynamicsWorldModel
)

dreamer4/dreamer4.py: file diff suppressed because it is too large

dreamer4/mocks.py

@@ -7,11 +7,6 @@ from torch.nn import Module
from einops import repeat
# helpers
def exists(v):
return v is not None
# mock env
class MockEnv(Module):
@@ -20,11 +15,7 @@ class MockEnv(Module):
image_shape,
reward_range = (-100, 100),
num_envs = 1,
vectorized = False,
terminate_after_step = None,
rand_terminate_prob = 0.05,
can_truncate = False,
rand_truncate_prob = 0.05,
vectorized = False
):
super().__init__()
self.image_shape = image_shape
@@ -34,15 +25,6 @@ class MockEnv(Module):
self.vectorized = vectorized
assert not (vectorized and num_envs == 1)
# mocking termination and truncation
self.can_terminate = exists(terminate_after_step)
self.terminate_after_step = terminate_after_step
self.rand_terminate_prob = rand_terminate_prob
self.can_truncate = can_truncate
self.rand_truncate_prob = rand_truncate_prob
self.register_buffer('_step', tensor(0))
def get_random_state(self):
@@ -68,30 +50,13 @@ class MockEnv(Module):
reward = empty(()).uniform_(*self.reward_range)
if self.vectorized:
discrete, continuous = actions
assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
if not self.vectorized:
return state, reward
state = repeat(state, '... -> b ...', b = self.num_envs)
reward = repeat(reward, ' -> b', b = self.num_envs)
discrete, continuous = actions
assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
out = (state, reward)
state = repeat(state, '... -> b ...', b = self.num_envs)
reward = repeat(reward, ' -> b', b = self.num_envs)
if self.can_terminate:
shape = (self.num_envs,) if self.vectorized else (1,)
valid_step = self._step > self.terminate_after_step
terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
out = (*out, terminate)
# maybe truncation
if self.can_truncate:
truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
out = (*out, truncate)
self._step.add_(1)
return out
return state, reward
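A short sketch of driving the mock environment directly, under the assumption (suggested by the `discrete, continuous = actions` unpacking above) that stepping is done by calling the module with a `(discrete, continuous)` action tuple; the action shapes here are illustrative only.

```python
# hypothetical direct use of MockEnv; stepping-by-call and action shapes are assumptions
import torch
from dreamer4.mocks import MockEnv

mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

state = mock_env.get_random_state()      # initial observation

discrete = torch.randint(0, 4, (4, 1))   # one discrete action per environment
continuous = None                        # no continuous actions in this sketch

state, reward = mock_env((discrete, continuous))
# in vectorized mode, state and reward are repeated along a leading num_envs dimension (see repeat(...) above)
```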

dreamer4/trainers.py

@@ -4,7 +4,7 @@ import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
@@ -12,9 +12,7 @@ from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsWorldModel,
Experience,
combine_experiences
DynamicsWorldModel
)
from ema_pytorch import EMA
@@ -287,7 +285,7 @@ class DreamTrainer(Module):
for _ in range(self.num_train_steps):
dreams = self.unwrapped_model.generate(
self.generate_timesteps + 1, # plus one for bootstrap value
self.generate_timesteps,
batch_size = self.batch_size,
return_rewards_per_frame = True,
return_agent_actions = True,
@@ -319,221 +317,3 @@ class DreamTrainer(Module):
self.value_head_optim.zero_grad()
self.print('training complete')
# training from sim
class SimTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
optim_klass = AdamW,
batch_size = 16,
generate_timesteps = 16,
learning_rate = 3e-4,
max_grad_norm = None,
epochs = 2,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
self.max_grad_norm = max_grad_norm
self.epochs = epochs
self.batch_size = batch_size
self.generate_timesteps = generate_timesteps
self.unwrapped_model = self.model
(
self.model,
self.policy_head_optim,
self.value_head_optim,
) = self.accelerator.prepare(
self.model,
self.policy_head_optim,
self.value_head_optim
)
@property
def device(self):
return self.accelerator.device
@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def learn(
self,
experience: Experience
):
step_size = experience.step_size
agent_index = experience.agent_index
latents = experience.latents
old_values = experience.values
rewards = experience.rewards
has_agent_embed = exists(experience.agent_embed)
agent_embed = experience.agent_embed
discrete_actions, continuous_actions = experience.actions
discrete_log_probs, continuous_log_probs = experience.log_probs
discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
# handle empties
empty_tensor = torch.empty_like(rewards)
agent_embed = default(agent_embed, empty_tensor)
has_discrete = exists(discrete_actions)
has_continuous = exists(continuous_actions)
discrete_actions = default(discrete_actions, empty_tensor)
continuous_actions = default(continuous_actions, empty_tensor)
discrete_log_probs = default(discrete_log_probs, empty_tensor)
continuous_log_probs = default(continuous_log_probs, empty_tensor)
discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
# create the dataset and dataloader
dataset = TensorDataset(
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
)
dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
for epoch in range(self.epochs):
for (
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
) in dataloader:
actions = (
discrete_actions if has_discrete else None,
continuous_actions if has_continuous else None
)
log_probs = (
discrete_log_probs if has_discrete else None,
continuous_log_probs if has_continuous else None
)
old_action_unembeds = (
discrete_old_action_unembeds if has_discrete else None,
continuous_old_action_unembeds if has_continuous else None
)
batch_experience = Experience(
latents = latents,
actions = actions,
log_probs = log_probs,
agent_embed = agent_embed if has_agent_embed else None,
old_action_unembeds = old_action_unembeds,
values = old_values,
rewards = rewards,
step_size = step_size,
agent_index = agent_index
)
policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
# update policy head
self.accelerator.backward(policy_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
self.policy_head_optim.step()
self.policy_head_optim.zero_grad()
# update value head
self.accelerator.backward(value_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
self.value_head_optim.step()
self.value_head_optim.zero_grad()
self.print('training complete')
def forward(
self,
env,
num_episodes = 50000,
max_experiences_before_learn = 8,
env_is_vectorized = False
):
for _ in range(num_episodes):
total_experience = 0
experiences = []
while total_experience < max_experiences_before_learn:
experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
num_experience = experience.video.shape[0]
total_experience += num_experience
experiences.append(experience.cpu())
combined_experiences = combine_experiences(experiences)
self.learn(combined_experiences)
experiences.clear()
self.print('training complete')
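For orientation, the removed `SimTrainer` is exercised by the (also removed) `test_online_rl` further down in the tests diff; condensed here as a sketch, with `world_model` and `mock_env` standing in for the world model and environment constructed in that test.

```python
# condensed from the removed test_online_rl: wrap the world model / policy in SimTrainer,
# then call the trainer on an environment to alternate interaction with policy / value head updates
from dreamer4.trainers import SimTrainer

trainer = SimTrainer(
    world_model,
    batch_size = 4,
    cpu = True
)

trainer(mock_env, num_episodes = 2, env_is_vectorized = True)  # match how mock_env was built
```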

pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.1.24"
version = "0.0.64"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,8 +36,7 @@ dependencies = [
"hyper-connections>=0.2.1",
"torch>=2.4",
"torchvision",
"x-mlps-pytorch>=0.0.29",
"vit-pytorch>=1.15.3"
"x-mlps-pytorch>=0.0.29"
]
[project.urls]

tests/ (pytest suite)

@@ -2,9 +2,6 @@ import pytest
param = pytest.mark.parametrize
import torch
def exists(v):
return v is not None
@param('pred_orig_latent', (False, True))
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
@@ -15,9 +12,7 @@ def exists(v):
@param('condition_on_actions', (False, True))
@param('num_residual_streams', (1, 4))
@param('add_reward_embed_to_agent_token', (False, True))
@param('add_state_pred_head', (False, True))
@param('use_time_cache', (False, True))
@param('var_len', (False, True))
@param('use_time_kv_cache', (False, True))
def test_e2e(
pred_orig_latent,
grouped_query_attn,
@@ -29,9 +24,7 @@ def test_e2e(
condition_on_actions,
num_residual_streams,
add_reward_embed_to_agent_token,
add_state_pred_head,
use_time_cache,
var_len
use_time_kv_cache
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -43,9 +36,7 @@ def test_e2e(
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 4,
num_residual_streams = num_residual_streams,
encoder_add_decor_aux_loss = True,
decorr_sample_frac = 1.
num_residual_streams = num_residual_streams
)
video = torch.randn(2, 3, 4, 256, 256)
@@ -73,13 +64,12 @@ def test_e2e(
pred_orig_latent = pred_orig_latent,
num_discrete_actions = 4,
attn_dim_head = 16,
attn_heads = heads,
attn_kwargs = dict(
heads = heads,
query_heads = query_heads,
),
prob_no_shortcut_train = prob_no_shortcut_train,
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
add_state_pred_head = add_state_pred_head,
num_residual_streams = num_residual_streams
)
@@ -102,13 +92,8 @@ def test_e2e(
if condition_on_actions:
actions = torch.randint(0, 4, (2, 3, 1))
lens = None
if var_len:
lens = torch.randint(1, 4, (2,))
flow_loss = dynamics(
**dynamics_input,
lens = lens,
tasks = tasks,
signal_levels = signal_levels,
step_sizes_log2 = step_sizes_log2,
@@ -126,7 +111,7 @@ def test_e2e(
image_width = 128,
batch_size = 2,
return_rewards_per_frame = True,
use_time_cache = use_time_cache
use_time_kv_cache = use_time_kv_cache
)
assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -351,15 +336,6 @@ def test_action_embedder():
assert discrete_logits.shape == (2, 3, 8)
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
# test kl div
discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
assert discrete_kl_div.shape == (2, 3)
assert continuous_kl_div.shape == (2, 3)
# return discrete split by number of actions
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
@@ -421,14 +397,14 @@ def test_mtp():
reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
assert reward_targets.shape == (3, 16, 3)
assert mask.shape == (3, 16, 3)
assert reward_targets.shape == (3, 15, 3)
assert mask.shape == (3, 15, 3)
actions = torch.randint(0, 10, (3, 16, 2))
action_targets, mask = create_multi_token_prediction_targets(actions, 3)
assert action_targets.shape == (3, 16, 3, 2)
assert mask.shape == (3, 16, 3)
assert action_targets.shape == (3, 15, 3, 2)
assert mask.shape == (3, 15, 3)
from dreamer4.dreamer4 import ActionEmbedder
@@ -620,22 +596,12 @@ def test_cache_generate():
num_residual_streams = 1
)
generated, time_cache = dynamics.generate(1, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
@param('vectorized', (False, True))
@param('use_pmpo', (False, True))
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
@param('store_agent_embed', (False, True))
def test_online_rl(
vectorized,
use_pmpo,
env_can_terminate,
env_can_truncate,
store_agent_embed
):
def test_online_rl(vectorized):
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
tokenizer = VideoTokenizer(
@@ -646,9 +612,7 @@ def test_online_rl(
dim_latent = 16,
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 1,
image_height = 256,
image_width = 256,
num_latent_tokens = 1
)
world_model_and_policy = DynamicsWorldModel(
@@ -669,163 +633,11 @@ def test_online_rl(
)
from dreamer4.mocks import MockEnv
from dreamer4.dreamer4 import combine_experiences
mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
mock_env = MockEnv(
(256, 256),
vectorized = vectorized,
num_envs = 4,
terminate_after_step = 2 if env_can_terminate else None,
can_truncate = env_can_truncate,
rand_terminate_prob = 0.1
)
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
# manually
dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
# quick test moving the experience to different devices
if torch.cuda.is_available():
combined_experience = combined_experience.to(torch.device('cuda'))
combined_experience = combined_experience.to(world_model_and_policy.device)
if store_agent_embed:
assert exists(combined_experience.agent_embed)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
actor_loss.backward()
critic_loss.backward()
# with trainer
from dreamer4.trainers import SimTrainer
trainer = SimTrainer(
world_model_and_policy,
batch_size = 4,
cpu = True
)
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
@param('num_video_views', (1, 2))
def test_proprioception(
num_video_views
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_video_views = num_video_views,
num_discrete_actions = 4,
num_residual_streams = 1
)
if num_video_views > 1:
video_shape = (2, num_video_views, 3, 10, 256, 256)
else:
video_shape = (2, 3, 10, 256, 256)
video = torch.randn(*video_shape)
rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
tasks = torch.randint(0, 4, (2,))
loss = dynamics(
video = video,
rewards = rewards,
tasks = tasks,
proprio = proprio,
discrete_actions = discrete_actions
)
loss.backward()
generations = dynamics.generate(
10,
batch_size = 2,
return_decoded_video = True
)
assert exists(generations.proprio)
assert generations.video.shape == video_shape
def test_epo():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_latent_genes = 16,
num_discrete_actions = 4,
num_residual_streams = 1
)
fitness = torch.randn(16,)
dynamics.evolve_(fitness)
def test_images_to_video_tokenizer():
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256,
encoder_add_decor_aux_loss = True
)
images = torch.randn(2, 3, 256, 256)
loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
loss.backward()
assert images.shape == recon_images.shape