Compare commits
38 Commits
| SHA1 |
|---|
| 5bb027b386 |
| 9efe269688 |
| fb8c3793b4 |
| fb6d69f43a |
| 125693ce1c |
| 2e7f406d49 |
| 690ecf07dc |
| ac1c12f743 |
| 3c84b404a8 |
| d5b70e2b86 |
| c3532fa797 |
| 73029635fe |
| e1c41f4371 |
| f55c61c6cf |
| 051d4d6ee2 |
| ef3a5552e7 |
| 0c4224da18 |
| 256a81f658 |
| cfd34f1eba |
| 4ffbe37873 |
| 24ef72d528 |
| a4afcb22a6 |
| b0f6b8583d |
| 38cba80068 |
| c0a6cd56a1 |
| d756d1bb8c |
| 60681fce1d |
| 6870294d95 |
| 3beae186da |
| 0904e224ab |
| 767789d0ca |
| 35b87c4fa1 |
| c4a3cb09d5 |
| cb54121ace |
| 586379f2c8 |
| a358a44a53 |
| 3547344312 |
| 691d9ca007 |

README.md (95 changed lines)

````diff
@@ -1,10 +1,99 @@
 <img src="./dreamer4-fig2.png" width="400px"></img>
 
-## Dreamer 4 (wip)
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
+[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward()
+
+# dynamics world model
+
+world_model = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = world_model(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = world_model.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
+
+(actor_loss + critic_loss).backward()
+```
 
 ## Citation
 
@@ -19,3 +108,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
````
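
The README's usage section stops at single-batch losses ("do the above with much data"). A minimal training-loop sketch around the API shown in the diff above; the optimizer, learning rate, and random-tensor stand-ins are assumptions for illustration, not part of the repository:

```python
import torch
from torch.optim import Adam
from dreamer4 import VideoTokenizer

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

optim = Adam(tokenizer.parameters(), lr = 3e-4)  # hypothetical optimizer / lr

for _ in range(100):                             # stand-in for a real dataloader
    video = torch.randn(2, 3, 10, 256, 256)      # replace with real video clips
    loss = tokenizer(video)
    loss.backward()
    optim.step()
    optim.zero_grad()
```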

```diff
@@ -1,6 +1,7 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    AxialSpaceTimeTransformer
 )
 
 
```
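
The package root now also exports `AxialSpaceTimeTransformer`. A quick smoke check, assuming an installed release that includes the change above:

```python
# should import cleanly once the export above lands
from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
```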

File diff suppressed because it is too large

```diff
@@ -396,13 +396,20 @@ class SimTrainer(Module):
         old_values = experience.values
         rewards = experience.rewards
 
+        has_agent_embed = exists(experience.agent_embed)
+        agent_embed = experience.agent_embed
+
         discrete_actions, continuous_actions = experience.actions
         discrete_log_probs, continuous_log_probs = experience.log_probs
 
+        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
+
         # handle empties
 
         empty_tensor = torch.empty_like(rewards)
 
+        agent_embed = default(agent_embed, empty_tensor)
+
         has_discrete = exists(discrete_actions)
         has_continuous = exists(continuous_actions)
 
@@ -412,6 +419,9 @@ class SimTrainer(Module):
         discrete_log_probs = default(discrete_log_probs, empty_tensor)
         continuous_log_probs = default(continuous_log_probs, empty_tensor)
 
+        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
+        continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
+
         # create the dataset and dataloader
 
         dataset = TensorDataset(
@@ -420,6 +430,9 @@ class SimTrainer(Module):
             continuous_actions,
             discrete_log_probs,
             continuous_log_probs,
+            agent_embed,
+            discrete_old_action_unembeds,
+            continuous_old_action_unembeds,
             old_values,
             rewards
         )
@@ -434,6 +447,9 @@ class SimTrainer(Module):
             continuous_actions,
             discrete_log_probs,
             continuous_log_probs,
+            agent_embed,
+            discrete_old_action_unembeds,
+            continuous_old_action_unembeds,
             old_values,
             rewards
         ) in dataloader:
@@ -448,10 +464,17 @@ class SimTrainer(Module):
                 continuous_log_probs if has_continuous else None
             )
 
+            old_action_unembeds = (
+                discrete_old_action_unembeds if has_discrete else None,
+                continuous_old_action_unembeds if has_continuous else None
+            )
+
             batch_experience = Experience(
                 latents = latents,
                 actions = actions,
                 log_probs = log_probs,
+                agent_embed = agent_embed if has_agent_embed else None,
+                old_action_unembeds = old_action_unembeds,
                 values = old_values,
                 rewards = rewards,
                 step_size = step_size,
@@ -505,7 +528,7 @@ class SimTrainer(Module):
 
         total_experience += num_experience
 
-        experiences.append(experience)
+        experiences.append(experience.cpu())
 
         combined_experiences = combine_experiences(experiences)
```
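
The pattern in the hunk above exists because `TensorDataset` cannot hold `None`: optional fields like `agent_embed` and the old action unembeds are swapped for `torch.empty_like(rewards)` placeholders so every field zips into the dataset, then restored to `None` per batch via the `has_*` flags. A self-contained sketch of that technique; the variable names mirror the trainer but the setup is illustrative:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

rewards = torch.randn(8, 10)
agent_embed = None                    # optional field, absent in this run
has_agent_embed = agent_embed is not None

# placeholder tensor so TensorDataset has something to zip for every field
agent_embed = agent_embed if has_agent_embed else torch.empty_like(rewards)

dataset = TensorDataset(rewards, agent_embed)

for batch_rewards, batch_agent_embed in DataLoader(dataset, batch_size = 4):
    # undo the placeholder per batch, as in `agent_embed if has_agent_embed else None`
    batch_agent_embed = batch_agent_embed if has_agent_embed else None
```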

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.91"
+version = "0.1.24"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,8 @@ dependencies = [
     "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
-    "x-mlps-pytorch>=0.0.29"
+    "x-mlps-pytorch>=0.0.29",
+    "vit-pytorch>=1.15.3"
 ]
 
 [project.urls]
```
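
The version bump to 0.1.24 brings in `vit-pytorch` as a new dependency; a fresh install should resolve it automatically, assuming that release is published to PyPI:

```bash
$ pip install -U dreamer4   # pulls in vit-pytorch>=1.15.3 along with the other pins
```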

```diff
@@ -15,7 +15,8 @@ def exists(v):
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('use_time_kv_cache', (False, True))
+@param('add_state_pred_head', (False, True))
+@param('use_time_cache', (False, True))
 @param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
@@ -28,7 +29,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache,
+    add_state_pred_head,
+    use_time_cache,
     var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -41,7 +43,9 @@ def test_e2e(
         patch_size = 32,
         attn_dim_head = 16,
         num_latent_tokens = 4,
-        num_residual_streams = num_residual_streams
+        num_residual_streams = num_residual_streams,
+        encoder_add_decor_aux_loss = True,
+        decorr_sample_frac = 1.
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
@@ -69,12 +73,13 @@
         pred_orig_latent = pred_orig_latent,
         num_discrete_actions = 4,
         attn_dim_head = 16,
+        attn_heads = heads,
         attn_kwargs = dict(
-            heads = heads,
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
         add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        add_state_pred_head = add_state_pred_head,
         num_residual_streams = num_residual_streams
     )
 
@@ -121,7 +126,7 @@
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_kv_cache = use_time_kv_cache
+        use_time_cache = use_time_cache
     )
 
     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -346,6 +351,15 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)
 
+    # test kl div
+
+    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+    assert discrete_kl_div.shape == (2, 3)
+    assert continuous_kl_div.shape == (2, 3)
+
     # return discrete split by number of actions
 
     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
```
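
The new `kl_div` on the action embedder returns one KL term per action modality with batch and time dimensions preserved, which makes it usable as an auxiliary loss, for instance keeping a policy close to frozen target heads. A sketch continuing from the test's `embedder` and `action_embed`; the frozen-target framing is an assumption, only the calls shown in the test are real:

```python
import torch

# current policy heads
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)

# target heads, e.g. from an older checkpoint (illustrative)
with torch.no_grad():
    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)

discrete_kl, continuous_kl = embedder.kl_div(
    (discrete_logits, continuous_mean_log_var),
    (discrete_logits_tgt, continuous_mean_log_var_tgt)
)

# one value per (batch, time) position for each modality
aux_loss = (discrete_kl + continuous_kl).mean()
```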

```diff
@@ -606,9 +620,9 @@ def test_cache_generate():
         num_residual_streams = 1
     )
 
-    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_cache = dynamics.generate(1, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
 
 @param('vectorized', (False, True))
 @param('use_pmpo', (False, True))
```
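
The rename from `time_kv_cache` to `time_cache` is mechanical in the test, but the rolling pattern it exercises is the useful part: generate one frame, then thread the returned cache into the next call. A sketch of frame-by-frame rollout, assuming a `dynamics` model constructed as in the test:

```python
# first frame primes the cache
generated, time_cache = dynamics.generate(1, return_time_cache = True)

# subsequent frames reuse it, avoiding recomputation over past timesteps
for _ in range(15):
    generated, time_cache = dynamics.generate(
        1,
        time_cache = time_cache,
        return_time_cache = True
    )
```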

```diff
@@ -632,7 +646,9 @@ def test_online_rl(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 1
+        num_latent_tokens = 1,
+        image_height = 256,
+        image_width = 256,
     )
 
     world_model_and_policy = DynamicsWorldModel(
@@ -666,10 +682,18 @@
 
     # manually
 
+    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
+
     one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
     another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
 
-    combined_experience = combine_experiences([one_experience, another_experience])
+    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
+
+    # quick test moving the experience to different devices
+
+    if torch.cuda.is_available():
+        combined_experience = combined_experience.to(torch.device('cuda'))
+        combined_experience = combined_experience.to(world_model_and_policy.device)
 
     if store_agent_embed:
         assert exists(combined_experience.agent_embed)
@@ -786,3 +810,22 @@ def test_epo():
 
     fitness = torch.randn(16,)
     dynamics.evolve_(fitness)
+
+def test_images_to_video_tokenizer():
+    import torch
+    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
+
+    tokenizer = VideoTokenizer(
+        dim = 512,
+        dim_latent = 32,
+        patch_size = 32,
+        image_height = 256,
+        image_width = 256,
+        encoder_add_decor_aux_loss = True
+    )
+
+    images = torch.randn(2, 3, 256, 256)
+    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
+    loss.backward()
+
+    assert images.shape == recon_images.shape
```
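
The new test feeds 4-D image batches to `VideoTokenizer`, so single images are presumably handled as one-frame videos. A minimal sketch of the shape contract under that assumption:

```python
import torch
from dreamer4 import VideoTokenizer

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

images = torch.randn(2, 3, 256, 256)     # (batch, channels, height, width)
video = torch.randn(2, 3, 1, 256, 256)   # (batch, channels, time, height, width)

# both should be accepted; images act as single-frame videos
image_loss = tokenizer(images)
video_loss = tokenizer(video)
```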