Compare commits

27 Commits
0.1.0 ... main

Author SHA1 Message Date
lucidrains
5bb027b386 allow for image pretraining on video tokenizer 2025-12-04 10:34:15 -08:00
lucidrains
9efe269688 oops 2025-12-03 08:11:47 -08:00
lucidrains
fb8c3793b4 complete the addition of a state entropy bonus 2025-12-03 07:52:30 -08:00
lucidrains
fb6d69f43a complete the latent autoregressive prediction, to use the log variance as a state entropy bonus 2025-12-03 06:40:19 -08:00
lucidrains
125693ce1c add a separate state prediction head for the state entropy 2025-12-02 15:58:25 -08:00
lucidrains
2e7f406d49 allow for the combining of experiences from environment and dream 2025-11-13 16:37:35 -08:00
lucidrains
690ecf07dc fix the rnn time caching issue 2025-11-11 17:04:02 -08:00
lucidrains
ac1c12f743 disable until rnn hiddens are handled properly 2025-11-10 15:52:43 -08:00
lucidrains
3c84b404a8 rnn layer needs to be hyper connected too 2025-11-10 15:51:33 -08:00
lucidrains
d5b70e2b86 allow for adding an RNN before time attention, but need to handle caching still 2025-11-10 11:42:20 -08:00
lucidrains
c3532fa797 add learned value residual 2025-11-10 09:33:58 -08:00
lucidrains
73029635fe last commit for the day 2025-11-09 11:12:37 -08:00
lucidrains
e1c41f4371 decorrelation loss for spatial attention as well 2025-11-09 10:41:58 -08:00
Phil Wang
f55c61c6cf cleanup 2025-11-09 10:22:47 -08:00
lucidrains
051d4d6ee2 oops 2025-11-09 10:12:51 -08:00
lucidrains
ef3a5552e7 eventually video tokenizer may need to be trained on single frames 2025-11-09 10:11:56 -08:00
lucidrains
0c4224da18 add a decorrelation loss for temporal attention in encoder of video tokenizer 2025-11-09 09:47:47 -08:00
Phil Wang
256a81f658 Merge pull request #5 from Cycl0/patch-1: Update Discord channel link in README to use permanent link 2025-11-09 08:17:41 -08:00
lucidrains
cfd34f1eba able to move the experience to cpu easily, and automatically move it to the device of the dynamics world model when learning from it 2025-11-09 16:16:13 +00:00
Lucas Kenzo Cyra
4ffbe37873 Update Discord channel link in README to use permanent link (Updated Discord channel link for collaboration.) 2025-11-09 10:12:45 -03:00
lucidrains
24ef72d528 0.1.4 2025-11-04 15:21:20 -08:00
Phil Wang
a4afcb22a6 Merge pull request #4 from dirkmcpherson/bugfix: fix a few typo bugs. Support info in return signature of environment … 2025-11-04 15:19:25 -08:00
j
b0f6b8583d fix a few typo bugs. Support info in return signature of environment step. Temporarily turn off flex attention when the kv_cache is used to avoid bug. 2025-11-04 17:29:12 -05:00
lucidrains
38cba80068 readme 2025-11-04 06:05:11 -08:00
lucidrains
c0a6cd56a1 link to new discord 2025-10-31 09:06:44 -07:00
lucidrains
d756d1bb8c addressing issues raised by an independent researcher with llm assistance 2025-10-31 08:37:39 -07:00
lucidrains
60681fce1d fix generation so that one more step is taken to decode agent embeds off the final cleaned set of latents, update readme 2025-10-31 06:48:49 -07:00
6 changed files with 582 additions and 125 deletions
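Several of the commits above describe adding a latent autoregressive prediction head and using its log variance as a state entropy bonus. Below is a minimal, hypothetical sketch of that idea, assuming a Gaussian prediction head; the class name, shapes, and weighting are illustrative assumptions, not the repository's actual API.

```python
import math
import torch
from torch import nn

class LatentPredictionHead(nn.Module):
    # hypothetical head: predicts the next latent as a diagonal gaussian
    def __init__(self, dim, dim_latent):
        super().__init__()
        self.to_mean_log_var = nn.Linear(dim, dim_latent * 2)

    def forward(self, agent_embed, next_latent):
        mean, log_var = self.to_mean_log_var(agent_embed).chunk(2, dim = -1)

        # gaussian negative log likelihood of the next latent, used as the prediction loss
        pred_loss = 0.5 * (log_var + (next_latent - mean) ** 2 / log_var.exp()).sum(dim = -1)

        # differential entropy of the predicted gaussian - states the model is less sure
        # about get a larger predicted variance, and therefore a larger exploration bonus
        entropy_bonus = 0.5 * (math.log(2 * math.pi * math.e) + log_var).sum(dim = -1)

        return pred_loss.mean(), entropy_bonus

head = LatentPredictionHead(dim = 512, dim_latent = 32)

agent_embed = torch.randn(2, 10, 512)   # per-timestep agent embeddings
next_latent = torch.randn(2, 10, 32)    # latents of the following frames

loss, bonus = head(agent_embed, next_latent)
loss.backward()

# bonus has shape (2, 10); it would be scaled and added to the per-frame reward
```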

README.md

@@ -1,9 +1,100 @@
 <img src="./dreamer4-fig2.png" width="400px"></img>
 
-## Dreamer 4 (wip)
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
+[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward()
+
+# dynamics world model
+
+world_model = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = world_model(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = world_model.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
+(actor_loss + critic_loss).backward()
+```
+
 ## Citation
 
 ```bibtex
@@ -17,3 +108,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*

dreamer4/__init__.py

@ -1,6 +1,7 @@
from dreamer4.dreamer4 import ( from dreamer4.dreamer4 import (
VideoTokenizer, VideoTokenizer,
DynamicsWorldModel DynamicsWorldModel,
AxialSpaceTimeTransformer
) )
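With this change the transformer backbone joins the tokenizer and world model as a top-level export. A minimal import check is shown below; only the import itself is illustrated, since the main module's diff is suppressed just after this file.

```python
# the new top-level export added by this diff; constructor arguments are not shown
# here because the diff of the module defining it is suppressed in this comparison
from dreamer4 import AxialSpaceTimeTransformer

print(AxialSpaceTimeTransformer)  # confirms the symbol is re-exported at the package root
```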

File diff suppressed because it is too large


@@ -528,7 +528,7 @@ class SimTrainer(Module):
                 total_experience += num_experience
-                experiences.append(experience)
+                experiences.append(experience.cpu())
 
             combined_experiences = combine_experiences(experiences)
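This change parks collected experiences on CPU before combining them. Continuing the README usage snippet above, a hedged sketch of what that looks like from the caller's side; the explicit `.to(...)` call mirrors the new test further down, and `learn_from_experience` moving data to the model's device follows the commit message.

```python
# continuing the README example above (world_model and mock_env already constructed)
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)

# experiences can be kept on cpu while more are collected
experience = experience.cpu()

# and moved back to the world model's device (or left to learn_from_experience,
# which per the commit message moves it there automatically)
experience = experience.to(world_model.device)

actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```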

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.102"
+version = "0.1.24"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,8 @@ dependencies = [
     "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
-    "x-mlps-pytorch>=0.0.29"
+    "x-mlps-pytorch>=0.0.29",
+    "vit-pytorch>=1.15.3"
 ]
 
 [project.urls]


@@ -15,7 +15,8 @@ def exists(v):
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('use_time_kv_cache', (False, True))
+@param('add_state_pred_head', (False, True))
+@param('use_time_cache', (False, True))
 @param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -28,7 +29,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache,
+    add_state_pred_head,
+    use_time_cache,
     var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -41,7 +43,9 @@ def test_e2e(
         patch_size = 32,
         attn_dim_head = 16,
         num_latent_tokens = 4,
-        num_residual_streams = num_residual_streams
+        num_residual_streams = num_residual_streams,
+        encoder_add_decor_aux_loss = True,
+        decorr_sample_frac = 1.
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
@@ -69,12 +73,13 @@ def test_e2e(
         pred_orig_latent = pred_orig_latent,
         num_discrete_actions = 4,
         attn_dim_head = 16,
+        attn_heads = heads,
         attn_kwargs = dict(
-            heads = heads,
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
         add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        add_state_pred_head = add_state_pred_head,
         num_residual_streams = num_residual_streams
     )
@@ -121,7 +126,7 @@ def test_e2e(
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_kv_cache = use_time_kv_cache
+        use_time_cache = use_time_cache
     )
 
     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -615,9 +620,9 @@ def test_cache_generate():
         num_residual_streams = 1
     )
 
-    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_cache = dynamics.generate(1, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
 
 @param('vectorized', (False, True))
 @param('use_pmpo', (False, True))
@@ -641,7 +646,9 @@ def test_online_rl(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 1
+        num_latent_tokens = 1,
+        image_height = 256,
+        image_width = 256,
     )
 
     world_model_and_policy = DynamicsWorldModel(
@@ -675,10 +682,18 @@ def test_online_rl(
     # manually
 
+    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
+
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
 
-    combined_experience = combine_experiences([one_experience, another_experience])
+    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
+
+    # quick test moving the experience to different devices
+
+    if torch.cuda.is_available():
+        combined_experience = combined_experience.to(torch.device('cuda'))
+
+    combined_experience = combined_experience.to(world_model_and_policy.device)
 
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
@@ -795,3 +810,22 @@ def test_epo():
     fitness = torch.randn(16,)
 
     dynamics.evolve_(fitness)
+
+def test_images_to_video_tokenizer():
+    import torch
+    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
+
+    tokenizer = VideoTokenizer(
+        dim = 512,
+        dim_latent = 32,
+        patch_size = 32,
+        image_height = 256,
+        image_width = 256,
+        encoder_add_decor_aux_loss = True
+    )
+
+    images = torch.randn(2, 3, 256, 256)
+
+    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
+    loss.backward()
+
+    assert images.shape == recon_images.shape