Compare commits

No commits in common. "main" and "0.0.90" have entirely different histories.
main ... 0.0.90

6 changed files with 165 additions and 806 deletions

@@ -1,99 +1,10 @@
 <img src="./dreamer4-fig2.png" width="400px"></img>

-## Dreamer 4
+## Dreamer 4 (wip)

 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

-[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
+[Temporary Discord](https://discord.gg/MkACrrkrYR)

-## Appreciation
-
-- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
-
-## Install
-
-```bash
-$ pip install dreamer4
-```
-
-## Usage
-
-```python
-import torch
-from dreamer4 import VideoTokenizer, DynamicsWorldModel
-
-# video tokenizer, learned through MAE + lpips
-
-tokenizer = VideoTokenizer(
-    dim = 512,
-    dim_latent = 32,
-    patch_size = 32,
-    image_height = 256,
-    image_width = 256
-)
-
-video = torch.randn(2, 3, 10, 256, 256)
-
-# learn the tokenizer
-
-loss = tokenizer(video)
-loss.backward()
-
-# dynamics world model
-
-world_model = DynamicsWorldModel(
-    dim = 512,
-    dim_latent = 32,
-    video_tokenizer = tokenizer,
-    num_discrete_actions = 4,
-    num_residual_streams = 1
-)
-
-# state, action, rewards
-
-video = torch.randn(2, 3, 10, 256, 256)
-discrete_actions = torch.randint(0, 4, (2, 10, 1))
-rewards = torch.randn(2, 10)
-
-# learn dynamics / behavior cloned model
-
-loss = world_model(
-    video = video,
-    rewards = rewards,
-    discrete_actions = discrete_actions
-)
-
-loss.backward()
-
-# do the above with much data
-
-# then generate dreams
-
-dreams = world_model.generate(
-    10,
-    batch_size = 2,
-    return_decoded_video = True,
-    return_for_policy_optimization = True
-)
-
-# learn from the dreams
-
-actor_loss, critic_loss = world_model.learn_from_experience(dreams)
-(actor_loss + critic_loss).backward()
-
-# learn from environment
-
-from dreamer4.mocks import MockEnv
-
-mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
-
-experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
-
-actor_loss, critic_loss = world_model.learn_from_experience(experience)
-(actor_loss + critic_loss).backward()
-```
-
 ## Citation

@@ -108,5 +19,3 @@ actor_loss, critic_loss = world_model.learn_from_experience(experience)
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
-
-*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
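As context for the usage section removed above, here is a minimal sketch (not part of either branch) of how the two training calls from that README would typically be wired into a loop. The optimizer choice, learning rate, loop structure, and random placeholder batches are assumptions; only the `VideoTokenizer`, `DynamicsWorldModel`, and loss-call signatures come from the removed README code.

```python
# hypothetical outline built around the main-branch README API shown above

import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

# assumed optimizers - any torch optimizer would do

opt_tokenizer = torch.optim.Adam(tokenizer.parameters(), lr = 3e-4)
opt_world = torch.optim.Adam(world_model.parameters(), lr = 3e-4)

for _ in range(2):  # stand-in for "do the above with much data"
    video = torch.randn(2, 3, 10, 256, 256)              # placeholder batch of videos
    discrete_actions = torch.randint(0, 4, (2, 10, 1))   # placeholder actions
    rewards = torch.randn(2, 10)                          # placeholder rewards

    # stage 1 - tokenizer reconstruction loss

    tokenizer_loss = tokenizer(video)
    tokenizer_loss.backward()
    opt_tokenizer.step()
    opt_tokenizer.zero_grad()

    # stage 2 - dynamics / behavior cloning loss

    wm_loss = world_model(video = video, rewards = rewards, discrete_actions = discrete_actions)
    wm_loss.backward()
    opt_world.step()
    opt_world.zero_grad()
```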

@@ -1,7 +1,6 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel,
-    AxialSpaceTimeTransformer
+    DynamicsWorldModel
 )
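On the main side of this comparison `AxialSpaceTimeTransformer` is re-exported from the package root, while on 0.0.90 it is not. A small sketch of guarding the import when code has to run against either version; the fallback module path is an assumption:

```python
# hypothetical compatibility shim: AxialSpaceTimeTransformer is exported from the
# package root on main but not on 0.0.90

try:
    from dreamer4 import AxialSpaceTimeTransformer
except ImportError:
    try:
        # assumed fallback location; on some versions the class may not exist at all
        from dreamer4.dreamer4 import AxialSpaceTimeTransformer
    except ImportError:
        AxialSpaceTimeTransformer = None
```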

File diff suppressed because it is too large

@@ -396,20 +396,13 @@ class SimTrainer(Module):
         old_values = experience.values
         rewards = experience.rewards

-        has_agent_embed = exists(experience.agent_embed)
-        agent_embed = experience.agent_embed
-
         discrete_actions, continuous_actions = experience.actions
         discrete_log_probs, continuous_log_probs = experience.log_probs
-        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))

         # handle empties

         empty_tensor = torch.empty_like(rewards)

-        agent_embed = default(agent_embed, empty_tensor)
-
         has_discrete = exists(discrete_actions)
         has_continuous = exists(continuous_actions)
@@ -419,9 +412,6 @@ class SimTrainer(Module):
         discrete_log_probs = default(discrete_log_probs, empty_tensor)
         continuous_log_probs = default(continuous_log_probs, empty_tensor)

-        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
-        continuous_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
-
         # create the dataset and dataloader

         dataset = TensorDataset(
@@ -430,9 +420,6 @@ class SimTrainer(Module):
             continuous_actions,
             discrete_log_probs,
             continuous_log_probs,
-            agent_embed,
-            discrete_old_action_unembeds,
-            continuous_old_action_unembeds,
             old_values,
             rewards
         )
@@ -447,9 +434,6 @@ class SimTrainer(Module):
             continuous_actions,
             discrete_log_probs,
             continuous_log_probs,
-            agent_embed,
-            discrete_old_action_unembeds,
-            continuous_old_action_unembeds,
             old_values,
             rewards
         ) in dataloader:
@@ -464,17 +448,10 @@ class SimTrainer(Module):
                 continuous_log_probs if has_continuous else None
             )

-            old_action_unembeds = (
-                discrete_old_action_unembeds if has_discrete else None,
-                continuous_old_action_unembeds if has_continuous else None
-            )
-
             batch_experience = Experience(
                 latents = latents,
                 actions = actions,
                 log_probs = log_probs,
-                agent_embed = agent_embed if has_agent_embed else None,
-                old_action_unembeds = old_action_unembeds,
                 values = old_values,
                 rewards = rewards,
                 step_size = step_size,
@@ -528,7 +505,7 @@ class SimTrainer(Module):
             total_experience += num_experience

-            experiences.append(experience.cpu())
+            experiences.append(experience)

         combined_experiences = combine_experiences(experiences)
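The deletions in this file show the main-branch pattern for threading optional experience fields (`agent_embed`, the old action unembeds) through `TensorDataset`: substitute an empty placeholder so every column is a tensor, remember a `has_*` flag, then restore `None` per mini-batch when rebuilding the `Experience`. A standalone sketch of that pattern with generic, illustrative names (not the actual `SimTrainer` fields):

```python
# minimal sketch of the optional-field-through-TensorDataset pattern above;
# `required_field` / `optional_field` are illustrative stand-ins, not real SimTrainer names

import torch
from torch.utils.data import TensorDataset, DataLoader

required_field = torch.randn(8, 4)   # always present, e.g. rewards
optional_field = None                # sometimes absent, e.g. agent_embed

has_optional = optional_field is not None

# TensorDataset cannot store None, so swap in a placeholder with the right batch size

placeholder = torch.empty(8, 0)
optional_column = optional_field if has_optional else placeholder

dataset = TensorDataset(required_field, optional_column)
dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)

for required_batch, optional_batch in dataloader:
    # restore None-ness per batch, mirroring `agent_embed if has_agent_embed else None`
    optional_batch = optional_batch if has_optional else None
    # ... build the per-batch Experience from these fields
```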

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.24"
+version = "0.0.90"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,8 +36,7 @@ dependencies = [
     "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
-    "x-mlps-pytorch>=0.0.29",
-    "vit-pytorch>=1.15.3"
+    "x-mlps-pytorch>=0.0.29"
 ]

 [project.urls]

@@ -15,8 +15,7 @@ def exists(v):
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('add_state_pred_head', (False, True))
-@param('use_time_cache', (False, True))
+@param('use_time_kv_cache', (False, True))
 @param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -29,8 +28,7 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    add_state_pred_head,
-    use_time_cache,
+    use_time_kv_cache,
     var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -43,9 +41,7 @@ def test_e2e(
         patch_size = 32,
         attn_dim_head = 16,
         num_latent_tokens = 4,
-        num_residual_streams = num_residual_streams,
-        encoder_add_decor_aux_loss = True,
-        decorr_sample_frac = 1.
+        num_residual_streams = num_residual_streams
     )

     video = torch.randn(2, 3, 4, 256, 256)
@@ -73,13 +69,12 @@
         pred_orig_latent = pred_orig_latent,
         num_discrete_actions = 4,
         attn_dim_head = 16,
-        attn_heads = heads,
         attn_kwargs = dict(
+            heads = heads,
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
         add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
-        add_state_pred_head = add_state_pred_head,
         num_residual_streams = num_residual_streams
     )
@@ -126,7 +121,7 @@ def test_e2e(
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_cache = use_time_cache
+        use_time_kv_cache = use_time_kv_cache
     )

     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -351,15 +346,6 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)

-    # test kl div
-
-    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
-
-    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
-
-    assert discrete_kl_div.shape == (2, 3)
-    assert continuous_kl_div.shape == (2, 3)
-
     # return discrete split by number of actions

     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
@@ -620,18 +606,18 @@ def test_cache_generate():
         num_residual_streams = 1
     )

-    generated, time_cache = dynamics.generate(1, return_time_cache = True)
-    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
-    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
+    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

 @param('vectorized', (False, True))
-@param('use_pmpo', (False, True))
+@param('use_signed_advantage', (False, True))
 @param('env_can_terminate', (False, True))
 @param('env_can_truncate', (False, True))
 @param('store_agent_embed', (False, True))
 def test_online_rl(
     vectorized,
-    use_pmpo,
+    use_signed_advantage,
     env_can_terminate,
     env_can_truncate,
     store_agent_embed
@@ -646,9 +632,7 @@ def test_online_rl(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 1,
-        image_height = 256,
-        image_width = 256,
+        num_latent_tokens = 1
     )

     world_model_and_policy = DynamicsWorldModel(
@@ -682,23 +666,15 @@
     # manually

-    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
-
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

-    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
+    combined_experience = combine_experiences([one_experience, another_experience])

-    # quick test moving the experience to different devices
-
-    if torch.cuda.is_available():
-        combined_experience = combined_experience.to(torch.device('cuda'))
-        combined_experience = combined_experience.to(world_model_and_policy.device)
-
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)

-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_signed_advantage = use_signed_advantage)

     actor_loss.backward()
     critic_loss.backward()
@@ -810,22 +786,3 @@ def test_epo():
     fitness = torch.randn(16,)
     dynamics.evolve_(fitness)
-
-def test_images_to_video_tokenizer():
-    import torch
-    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
-
-    tokenizer = VideoTokenizer(
-        dim = 512,
-        dim_latent = 32,
-        patch_size = 32,
-        image_height = 256,
-        image_width = 256,
-        encoder_add_decor_aux_loss = True
-    )
-
-    images = torch.randn(2, 3, 256, 256)
-
-    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
-    loss.backward()
-
-    assert images.shape == recon_images.shape
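The `test_cache_generate` and `test_online_rl` hunks above capture two keyword renames between the branches: 0.0.90's `time_kv_cache` / `return_time_kv_cache` appear as `time_cache` / `return_time_cache` on main, and 0.0.90's `use_signed_advantage` appears as `use_pmpo`. A sketch of papering over the cache rename when calling `generate`, assuming only the keyword names differ between versions:

```python
# hypothetical helper for calling DynamicsWorldModel.generate across the two APIs
# shown in this comparison; assumes only the keyword names differ between versions

import inspect

def generate_with_time_cache(world_model, num_frames, cache = None):
    params = inspect.signature(world_model.generate).parameters

    if 'time_cache' in params:
        # main-branch naming; passing cache = None on the first call is assumed
        # to behave the same as omitting the keyword
        return world_model.generate(num_frames, time_cache = cache, return_time_cache = True)

    # 0.0.90 naming
    return world_model.generate(num_frames, time_kv_cache = cache, return_time_kv_cache = True)

# usage, mirroring test_cache_generate:
# generated, cache = generate_with_time_cache(world_model, 1)
# generated, cache = generate_with_time_cache(world_model, 1, cache = cache)
```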