Compare commits


66 Commits
0.0.64 ... main

Author SHA1 Message Date
lucidrains
5bb027b386 allow for image pretraining on video tokenizer 2025-12-04 10:34:15 -08:00
lucidrains
9efe269688 oops 2025-12-03 08:11:47 -08:00
lucidrains
fb8c3793b4 complete the addition of a state entropy bonus 2025-12-03 07:52:30 -08:00
lucidrains
fb6d69f43a complete the latent autoregressive prediction, to use the log variance as a state entropy bonus 2025-12-03 06:40:19 -08:00
lucidrains
125693ce1c add a separate state prediction head for the state entropy 2025-12-02 15:58:25 -08:00
lucidrains
2e7f406d49 allow for the combining of experiences from environment and dream 2025-11-13 16:37:35 -08:00
lucidrains
690ecf07dc fix the rnn time caching issue 2025-11-11 17:04:02 -08:00
lucidrains
ac1c12f743 disable until rnn hiddens are handled properly 2025-11-10 15:52:43 -08:00
lucidrains
3c84b404a8 rnn layer needs to be hyper connected too 2025-11-10 15:51:33 -08:00
lucidrains
d5b70e2b86 allow for adding an RNN before time attention, but need to handle caching still 2025-11-10 11:42:20 -08:00
lucidrains
c3532fa797 add learned value residual 2025-11-10 09:33:58 -08:00
lucidrains
73029635fe last commit for the day 2025-11-09 11:12:37 -08:00
lucidrains
e1c41f4371 decorrelation loss for spatial attention as well 2025-11-09 10:41:58 -08:00
Phil Wang
f55c61c6cf cleanup 2025-11-09 10:22:47 -08:00
lucidrains
051d4d6ee2 oops 2025-11-09 10:12:51 -08:00
lucidrains
ef3a5552e7 eventually video tokenizer may need to be trained on single frames 2025-11-09 10:11:56 -08:00
lucidrains
0c4224da18 add a decorrelation loss for temporal attention in encoder of video tokenizer 2025-11-09 09:47:47 -08:00
Phil Wang
256a81f658 Merge pull request #5 from Cycl0/patch-1 (Update Discord channel link in README to use permanent link) 2025-11-09 08:17:41 -08:00
lucidrains
cfd34f1eba able to move the experience to cpu easily, and automatically move it to the device of the dynamics world model when learning from it 2025-11-09 16:16:13 +00:00
Lucas Kenzo Cyra
4ffbe37873 Update Discord channel link in README to use permanent link (Updated Discord channel link for collaboration.) 2025-11-09 10:12:45 -03:00
lucidrains
24ef72d528 0.1.4 2025-11-04 15:21:20 -08:00
Phil Wang
a4afcb22a6 Merge pull request #4 from dirkmcpherson/bugfix (fix a few typo bugs. Support info in return signature of environment …) 2025-11-04 15:19:25 -08:00
j
b0f6b8583d fix a few typo bugs. Support info in return signature of environment step. Temporarily turn off flex attention when the kv_cache is used to avoid bug. 2025-11-04 17:29:12 -05:00
lucidrains
38cba80068 readme 2025-11-04 06:05:11 -08:00
lucidrains
c0a6cd56a1 link to new discord 2025-10-31 09:06:44 -07:00
lucidrains
d756d1bb8c addressing issues raised by an independent researcher with llm assistance 2025-10-31 08:37:39 -07:00
lucidrains
60681fce1d fix generation so that one more step is taken to decode agent embeds off the final cleaned set of latents, update readme 2025-10-31 06:48:49 -07:00
Phil Wang
6870294d95 no longer needed 2025-10-30 09:23:27 -07:00
lucidrains
3beae186da some more control over whether to normalize advantages 2025-10-30 08:46:03 -07:00
lucidrains
0904e224ab make the reverse kl optional 2025-10-30 08:22:50 -07:00
lucidrains
767789d0ca they decided on 0.3 for the behavioral prior loss weight 2025-10-29 13:24:58 -07:00
lucidrains
35b87c4fa1 oops 2025-10-29 13:04:02 -07:00
lucidrains
c4a3cb09d5 swap for discrete kl div, thanks to Dirk for pointing this out on the discord 2025-10-29 11:54:18 -07:00
lucidrains
cb54121ace sim trainer needs to take care of agent embedding and old actions 2025-10-29 11:15:11 -07:00
lucidrains
586379f2c8 sum the kl div loss across number of actions by default for action embedder .kl_div 2025-10-29 10:46:42 -07:00
lucidrains
a358a44a53 always store old agent embeds and old action parameters when possible 2025-10-29 10:39:15 -07:00
lucidrains
3547344312 take care of storing the old action logits and mean log var, and calculate kl div for pmpo based off that during learn from experience 2025-10-29 10:31:32 -07:00
lucidrains
691d9ca007 add kl div on action embedder, working way towards the kl div loss in pmpo 2025-10-29 10:02:25 -07:00
lucidrains
91d697f8ca fix pmpo 2025-10-28 18:55:22 -07:00
lucidrains
7acaa764f6 evolutionary policy optimization on dreams will be interesting 2025-10-28 10:17:01 -07:00
lucidrains
c0450359f3 allow for evolutionary policy optimization 2025-10-28 10:11:13 -07:00
lucidrains
46f86cd247 fix storing of agent embedding 2025-10-28 09:36:58 -07:00
lucidrains
903c43b770 use the agent embeds off the stored experience if available 2025-10-28 09:14:02 -07:00
lucidrains
d476fa7b14 able to store the agent embeddings during rollouts with imagination or environment, for efficient policy optimization (but will also allow for finetuning world model for the heads) 2025-10-28 09:02:26 -07:00
lucidrains
789f091c63 redo so that max timesteps is treated as truncation at the last timestep, then allow for accepting the truncation signal from the environment and reuse same logic 2025-10-28 08:04:48 -07:00
lucidrains
41ab83f691 fix mock 2025-10-27 10:47:24 -07:00
lucidrains
995b1f64e5 handle environments that return a terminate flag, also make sure episode lens are logged in vectorized env 2025-10-27 10:14:28 -07:00
lucidrains
fd1e87983b quantile filter 2025-10-27 09:08:26 -07:00
lucidrains
fe79bfa951 optionally keep track of returns statistics and normalize with them before advantage 2025-10-27 09:02:08 -07:00
lucidrains
f808b1c1d2 oops 2025-10-27 08:34:22 -07:00
lucidrains
349a03acd7 redo so lens is always the episode length, including the bootstrap value timestep, and use is_truncated to mask out the bootstrap node from being learned on 2025-10-27 08:06:21 -07:00
lucidrains
59c458aea3 introduce an is_truncated field on Experience, and mask out rewards and values before calculating gae appropriately 2025-10-27 07:55:00 -07:00
lucidrains
fbfd59e42f handle variable lengthed experiences when doing policy optimization 2025-10-27 06:09:09 -07:00
lucidrains
46432aee9b fix an issue with bc 2025-10-25 12:30:08 -07:00
lucidrains
f97d9adc97 oops, forgot to add the view embedding for robotics 2025-10-25 11:39:06 -07:00
lucidrains
32cf142b4d take another step for variable len experiences 2025-10-25 11:31:41 -07:00
lucidrains
1ed6a15cb0 fix tests 2025-10-25 11:13:22 -07:00
lucidrains
4d8f5613cc start storing the experience lens 2025-10-25 10:55:47 -07:00
lucidrains
3d5617d769 take a step towards variable lengthed experiences during training 2025-10-25 10:45:34 -07:00
lucidrains
77a40e8701 validate that we can generate multiple video streams for robotics use-case 2025-10-25 09:23:07 -07:00
lucidrains
4ce82f34df given the VAT paper, add multiple video streams (third person, wrist camera, etc), geared for robotics. need to manage an extra dimension for multiple viewpoints 2025-10-25 09:20:55 -07:00
lucidrains
a9b728c611 incorporate proprioception into the dynamics world model 2025-10-24 11:24:22 -07:00
lucidrains
35c1db4c7d sketch of training from sim env 2025-10-24 09:13:09 -07:00
lucidrains
27ac05efb0 function for combining experiences 2025-10-24 08:00:10 -07:00
lucidrains
d0ffc6bfed with or without signed advantage 2025-10-23 16:24:29 -07:00
lucidrains
fb3e026fe0 handle vectorized env 2025-10-22 11:19:44 -07:00
8 changed files with 1810 additions and 228 deletions

View File

@@ -5,11 +5,11 @@ jobs:
   build:
     runs-on: ubuntu-latest
-    timeout-minutes: 20
+    timeout-minutes: 60
     strategy:
       fail-fast: false
       matrix:
-        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
     steps:
     - uses: actions/checkout@v4
@@ -24,4 +24,4 @@ jobs:
         python -m uv pip install -e .[test]
     - name: Test with pytest
       run: |
-        python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
+        python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/

View File

@@ -1,10 +1,99 @@
 <img src="./dreamer4-fig2.png" width="400px"></img>
 
-## Dreamer 4 (wip)
+## Dreamer 4
 
 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
 
-[Temporary Discord](https://discord.gg/MkACrrkrYR)
+[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
+
+## Appreciation
+
+- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
+
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+video = torch.randn(2, 3, 10, 256, 256)
+
+# learn the tokenizer
+
+loss = tokenizer(video)
+loss.backward()
+
+# dynamics world model
+
+world_model = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = world_model(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+
+# then generate dreams
+
+dreams = world_model.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = world_model.learn_from_experience(dreams)
+(actor_loss + critic_loss).backward()
+
+# learn from environment
+
+from dreamer4.mocks import MockEnv
+
+mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
+
+experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
+
+actor_loss, critic_loss = world_model.learn_from_experience(experience)
+(actor_loss + critic_loss).backward()
+```
 
 ## Citation
@@ -19,3 +108,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*

View File

@@ -1,6 +1,7 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    AxialSpaceTimeTransformer
 )

File diff suppressed because it is too large

View File

@@ -7,6 +7,11 @@ from torch.nn import Module
 from einops import repeat
 
+# helpers
+
+def exists(v):
+    return v is not None
+
 # mock env
 
 class MockEnv(Module):
@@ -15,7 +20,11 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05,
+        can_truncate = False,
+        rand_truncate_prob = 0.05,
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -23,6 +32,17 @@ class MockEnv(Module):
         self.num_envs = num_envs
         self.vectorized = vectorized
 
+        assert not (vectorized and num_envs == 1)
+
+        # mocking termination and truncation
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
+        self.can_truncate = can_truncate
+        self.rand_truncate_prob = rand_truncate_prob
+
         self.register_buffer('_step', tensor(0))
 
     def get_random_state(self):
@@ -33,7 +53,12 @@ class MockEnv(Module):
         seed = None
     ):
         self._step.zero_()
-        return self.get_random_state()
+        state = self.get_random_state()
+
+        if self.vectorized:
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+
+        return state
 
     def step(
         self,
@@ -43,12 +68,30 @@
         reward = empty(()).uniform_(*self.reward_range)
 
-        if not self.vectorized:
-            return state, reward
-
-        assert actions.shape[0] == self.num_envs
-
-        state = repeat(state, '... -> b ...', b = self.num_envs)
-        reward = repeat(reward, ' -> b', b = self.num_envs)
-
-        return state, rewards
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)
+
+        out = (state, reward)
+
+        if self.can_terminate:
+            shape = (self.num_envs,) if self.vectorized else (1,)
+            valid_step = self._step > self.terminate_after_step
+            terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
+            out = (*out, terminate)
+
+        # maybe truncation
+
+        if self.can_truncate:
+            truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+            out = (*out, truncate)
+
+        self._step.add_(1)
+
+        return out
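
The added flags change the arity of `MockEnv.step()`: the return grows from `(state, reward)` to optionally include `terminate` and `truncate` tensors. A minimal sketch of how a caller might unpack this, assuming the `reset()` / `step()` entry points sketched in the hunks above and a `(discrete, continuous)` action tuple — the action format here is an illustrative assumption, not necessarily the exact interface `interact_with_env` uses:

```python
import torch
from dreamer4.mocks import MockEnv

# enable both termination and truncation signals
env = MockEnv(
    (256, 256),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,   # termination becomes possible after this step
    can_truncate = True
)

state = env.reset()

discrete_actions = torch.randint(0, 4, (4, 1))   # one discrete action per env
out = env.step((discrete_actions, None))         # continuous actions unused by the mock

# with both flags enabled, the tuple is (state, reward, terminate, truncate)
state, reward, terminate, truncate = out
```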

View File

@@ -4,7 +4,7 @@ import torch
 from torch import is_tensor
 from torch.nn import Module
 from torch.optim import AdamW
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, TensorDataset, DataLoader
 
 from accelerate import Accelerator
@@ -12,7 +12,9 @@ from adam_atan2_pytorch import MuonAdamAtan2
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel
+    DynamicsWorldModel,
+    Experience,
+    combine_experiences
 )
 
 from ema_pytorch import EMA
@@ -285,7 +287,7 @@ class DreamTrainer(Module):
         for _ in range(self.num_train_steps):
 
             dreams = self.unwrapped_model.generate(
-                self.generate_timesteps,
+                self.generate_timesteps + 1, # plus one for bootstrap value
                 batch_size = self.batch_size,
                 return_rewards_per_frame = True,
                 return_agent_actions = True,
@@ -317,3 +319,221 @@
             self.value_head_optim.zero_grad()
 
         self.print('training complete')
+
+# training from sim
+
+class SimTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        epochs = 2,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.generate_timesteps = generate_timesteps
+
+        self.unwrapped_model = self.model
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def learn(
+        self,
+        experience: Experience
+    ):
+        step_size = experience.step_size
+        agent_index = experience.agent_index
+
+        latents = experience.latents
+        old_values = experience.values
+        rewards = experience.rewards
+
+        has_agent_embed = exists(experience.agent_embed)
+        agent_embed = experience.agent_embed
+
+        discrete_actions, continuous_actions = experience.actions
+        discrete_log_probs, continuous_log_probs = experience.log_probs
+
+        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
+
+        # handle empties
+
+        empty_tensor = torch.empty_like(rewards)
+
+        agent_embed = default(agent_embed, empty_tensor)
+
+        has_discrete = exists(discrete_actions)
+        has_continuous = exists(continuous_actions)
+
+        discrete_actions = default(discrete_actions, empty_tensor)
+        continuous_actions = default(continuous_actions, empty_tensor)
+
+        discrete_log_probs = default(discrete_log_probs, empty_tensor)
+        continuous_log_probs = default(continuous_log_probs, empty_tensor)
+
+        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
+        continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
+
+        # create the dataset and dataloader
+
+        dataset = TensorDataset(
+            latents,
+            discrete_actions,
+            continuous_actions,
+            discrete_log_probs,
+            continuous_log_probs,
+            agent_embed,
+            discrete_old_action_unembeds,
+            continuous_old_action_unembeds,
+            old_values,
+            rewards
+        )
+
+        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
+
+        for epoch in range(self.epochs):
+            for (
+                latents,
+                discrete_actions,
+                continuous_actions,
+                discrete_log_probs,
+                continuous_log_probs,
+                agent_embed,
+                discrete_old_action_unembeds,
+                continuous_old_action_unembeds,
+                old_values,
+                rewards
+            ) in dataloader:
+
+                actions = (
+                    discrete_actions if has_discrete else None,
+                    continuous_actions if has_continuous else None
+                )
+
+                log_probs = (
+                    discrete_log_probs if has_discrete else None,
+                    continuous_log_probs if has_continuous else None
+                )
+
+                old_action_unembeds = (
+                    discrete_old_action_unembeds if has_discrete else None,
+                    continuous_old_action_unembeds if has_continuous else None
+                )
+
+                batch_experience = Experience(
+                    latents = latents,
+                    actions = actions,
+                    log_probs = log_probs,
+                    agent_embed = agent_embed if has_agent_embed else None,
+                    old_action_unembeds = old_action_unembeds,
+                    values = old_values,
+                    rewards = rewards,
+                    step_size = step_size,
+                    agent_index = agent_index
+                )
+
+                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
+
+                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+                # update policy head
+
+                self.accelerator.backward(policy_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
+
+                self.policy_head_optim.step()
+                self.policy_head_optim.zero_grad()
+
+                # update value head
+
+                self.accelerator.backward(value_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+                self.value_head_optim.step()
+                self.value_head_optim.zero_grad()
+
+        self.print('training complete')
+
+    def forward(
+        self,
+        env,
+        num_episodes = 50000,
+        max_experiences_before_learn = 8,
+        env_is_vectorized = False
+    ):
+        for _ in range(num_episodes):
+            total_experience = 0
+            experiences = []
+
+            while total_experience < max_experiences_before_learn:
+                experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
+
+                num_experience = experience.video.shape[0]
+                total_experience += num_experience
+
+                experiences.append(experience.cpu())
+
+            combined_experiences = combine_experiences(experiences)
+
+            self.learn(combined_experiences)
+
+            experiences.clear()
+
+        self.print('training complete')
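
For orientation, a minimal usage sketch of the new `SimTrainer`, mirroring `test_online_rl` further down; the tokenizer and world model arguments are taken from the README usage above, and `MockEnv` stands in for a real environment:

```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv
from dreamer4.trainers import SimTrainer

# small tokenizer + world model, as in the README usage example
tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

world_model_and_policy = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

trainer = SimTrainer(
    world_model_and_policy,
    batch_size = 4,
    cpu = True
)

# rolls out in the env, combines the experiences, then optimizes the policy and value heads
trainer(mock_env, num_episodes = 2, env_is_vectorized = True)
```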

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.62"
+version = "0.1.24"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,8 @@ dependencies = [
     "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
-    "x-mlps-pytorch>=0.0.29"
+    "x-mlps-pytorch>=0.0.29",
+    "vit-pytorch>=1.15.3"
 ]
 
 [project.urls]

View File

@@ -2,6 +2,9 @@ import pytest
 param = pytest.mark.parametrize
 
 import torch
 
+def exists(v):
+    return v is not None
+
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
@@ -12,7 +15,9 @@ import torch
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('use_time_kv_cache', (False, True))
+@param('add_state_pred_head', (False, True))
+@param('use_time_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -24,7 +29,9 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    add_state_pred_head,
+    use_time_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -36,7 +43,9 @@ def test_e2e(
         patch_size = 32,
         attn_dim_head = 16,
         num_latent_tokens = 4,
-        num_residual_streams = num_residual_streams
+        num_residual_streams = num_residual_streams,
+        encoder_add_decor_aux_loss = True,
+        decorr_sample_frac = 1.
     )
 
     video = torch.randn(2, 3, 4, 256, 256)
@@ -64,12 +73,13 @@ def test_e2e(
         pred_orig_latent = pred_orig_latent,
         num_discrete_actions = 4,
         attn_dim_head = 16,
+        attn_heads = heads,
         attn_kwargs = dict(
-            heads = heads,
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
         add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
+        add_state_pred_head = add_state_pred_head,
         num_residual_streams = num_residual_streams
     )
@@ -92,8 +102,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))
 
+    lens = None
+
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
@@ -111,7 +126,7 @@
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_kv_cache = use_time_kv_cache
+        use_time_cache = use_time_cache
     )
 
     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -336,6 +351,15 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)
 
+    # test kl div
+
+    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
+
+    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
+
+    assert discrete_kl_div.shape == (2, 3)
+    assert continuous_kl_div.shape == (2, 3)
+
     # return discrete split by number of actions
 
     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
@@ -397,14 +421,14 @@ def test_mtp():
     reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
 
-    assert reward_targets.shape == (3, 15, 3)
-    assert mask.shape == (3, 15, 3)
+    assert reward_targets.shape == (3, 16, 3)
+    assert mask.shape == (3, 16, 3)
 
     actions = torch.randint(0, 10, (3, 16, 2))
     action_targets, mask = create_multi_token_prediction_targets(actions, 3)
 
-    assert action_targets.shape == (3, 15, 3, 2)
-    assert mask.shape == (3, 15, 3)
+    assert action_targets.shape == (3, 16, 3, 2)
+    assert mask.shape == (3, 16, 3)
 
     from dreamer4.dreamer4 import ActionEmbedder
@@ -596,11 +620,22 @@ def test_cache_generate():
         num_residual_streams = 1
     )
 
-    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
-    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_cache = dynamics.generate(1, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
 
-def test_online_rl():
+@param('vectorized', (False, True))
+@param('use_pmpo', (False, True))
+@param('env_can_terminate', (False, True))
+@param('env_can_truncate', (False, True))
+@param('store_agent_embed', (False, True))
+def test_online_rl(
+    vectorized,
+    use_pmpo,
+    env_can_terminate,
+    env_can_truncate,
+    store_agent_embed
+):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
 
     tokenizer = VideoTokenizer(
@@ -611,7 +646,9 @@ def test_online_rl():
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 1
+        num_latent_tokens = 1,
+        image_height = 256,
+        image_width = 256,
     )
 
     world_model_and_policy = DynamicsWorldModel(
@@ -632,11 +669,163 @@ def test_online_rl():
     )
 
     from dreamer4.mocks import MockEnv
+    from dreamer4.dreamer4 import combine_experiences
 
-    mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
+    mock_env = MockEnv(
+        (256, 256),
+        vectorized = vectorized,
+        num_envs = 4,
+        terminate_after_step = 2 if env_can_terminate else None,
+        can_truncate = env_can_truncate,
+        rand_terminate_prob = 0.1
+    )
 
-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
+    # manually
 
-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
+    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
+
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
+
+    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
+
+    # quick test moving the experience to different devices
+
+    if torch.cuda.is_available():
+        combined_experience = combined_experience.to(torch.device('cuda'))
+        combined_experience = combined_experience.to(world_model_and_policy.device)
+
+    if store_agent_embed:
+        assert exists(combined_experience.agent_embed)
+
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
 
     actor_loss.backward()
     critic_loss.backward()
+
+    # with trainer
+
+    from dreamer4.trainers import SimTrainer
+
+    trainer = SimTrainer(
+        world_model_and_policy,
+        batch_size = 4,
+        cpu = True
+    )
+
+    trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
+
+@param('num_video_views', (1, 2))
+def test_proprioception(
+    num_video_views
+):
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 2,
+        decoder_depth = 2,
+        time_block_every = 2,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        dim_proprio = 21,
+        num_tasks = 4,
+        num_video_views = num_video_views,
+        num_discrete_actions = 4,
+        num_residual_streams = 1
+    )
+
+    if num_video_views > 1:
+        video_shape = (2, num_video_views, 3, 10, 256, 256)
+    else:
+        video_shape = (2, 3, 10, 256, 256)
+
+    video = torch.randn(*video_shape)
+
+    rewards = torch.randn(2, 10)
+    proprio = torch.randn(2, 10, 21)
+    discrete_actions = torch.randint(0, 4, (2, 10, 1))
+    tasks = torch.randint(0, 4, (2,))
+
+    loss = dynamics(
+        video = video,
+        rewards = rewards,
+        tasks = tasks,
+        proprio = proprio,
+        discrete_actions = discrete_actions
+    )
+
+    loss.backward()
+
+    generations = dynamics.generate(
+        10,
+        batch_size = 2,
+        return_decoded_video = True
+    )
+
+    assert exists(generations.proprio)
+    assert generations.video.shape == video_shape
+
+def test_epo():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 2,
+        decoder_depth = 2,
+        time_block_every = 2,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        dim_proprio = 21,
+        num_tasks = 4,
+        num_latent_genes = 16,
+        num_discrete_actions = 4,
+        num_residual_streams = 1
+    )
+
+    fitness = torch.randn(16,)
+
+    dynamics.evolve_(fitness)
+
+def test_images_to_video_tokenizer():
+    import torch
+    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
+
+    tokenizer = VideoTokenizer(
+        dim = 512,
+        dim_latent = 32,
+        patch_size = 32,
+        image_height = 256,
+        image_width = 256,
+        encoder_add_decor_aux_loss = True
+    )
+
+    images = torch.randn(2, 3, 256, 256)
+
+    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
+    loss.backward()
+
+    assert images.shape == recon_images.shape