Compare commits

...

66 Commits
0.0.64 ... main

Author SHA1 Message Date
lucidrains
5bb027b386 allow for image pretraining on video tokenizer 2025-12-04 10:34:15 -08:00
lucidrains
9efe269688 oops 2025-12-03 08:11:47 -08:00
lucidrains
fb8c3793b4 complete the addition of a state entropy bonus 2025-12-03 07:52:30 -08:00
lucidrains
fb6d69f43a complete the latent autoregressive prediction, to use the log variance as a state entropy bonus 2025-12-03 06:40:19 -08:00
lucidrains
125693ce1c add a separate state prediction head for the state entropy 2025-12-02 15:58:25 -08:00
lucidrains
2e7f406d49 allow for the combining of experiences from environment and dream 2025-11-13 16:37:35 -08:00
lucidrains
690ecf07dc fix the rnn time caching issue 2025-11-11 17:04:02 -08:00
lucidrains
ac1c12f743 disable until rnn hiddens are handled properly 2025-11-10 15:52:43 -08:00
lucidrains
3c84b404a8 rnn layer needs to be hyper connected too 2025-11-10 15:51:33 -08:00
lucidrains
d5b70e2b86 allow for adding an RNN before time attention, but need to handle caching still 2025-11-10 11:42:20 -08:00
lucidrains
c3532fa797 add learned value residual 2025-11-10 09:33:58 -08:00
lucidrains
73029635fe last commit for the day 2025-11-09 11:12:37 -08:00
lucidrains
e1c41f4371 decorrelation loss for spatial attention as well 2025-11-09 10:41:58 -08:00
Phil Wang
f55c61c6cf
cleanup 2025-11-09 10:22:47 -08:00
lucidrains
051d4d6ee2 oops 2025-11-09 10:12:51 -08:00
lucidrains
ef3a5552e7 eventually video tokenizer may need to be trained on single frames 2025-11-09 10:11:56 -08:00
lucidrains
0c4224da18 add a decorrelation loss for temporal attention in encoder of video tokenizer 2025-11-09 09:47:47 -08:00
Phil Wang
256a81f658
Merge pull request #5 from Cycl0/patch-1
Update Discord channel link in README to use permanent link
2025-11-09 08:17:41 -08:00
lucidrains
cfd34f1eba able to move the experience to cpu easily, and automatically move it to the device of the dynamics world model when learning from it 2025-11-09 16:16:13 +00:00
Lucas Kenzo Cyra
4ffbe37873
Update Discord channel link in README to use permanent link
Updated Discord channel link for collaboration.
2025-11-09 10:12:45 -03:00
lucidrains
24ef72d528 0.1.4 2025-11-04 15:21:20 -08:00
Phil Wang
a4afcb22a6
Merge pull request #4 from dirkmcpherson/bugfix
fix a few typo bugs. Support info in return signature of environment …
2025-11-04 15:19:25 -08:00
j
b0f6b8583d fix a few typo bugs. Support info in return signature of environment step. Temporarily turn off flex attention when the kv_cache is used to avoid bug. 2025-11-04 17:29:12 -05:00
lucidrains
38cba80068 readme 2025-11-04 06:05:11 -08:00
lucidrains
c0a6cd56a1 link to new discord 2025-10-31 09:06:44 -07:00
lucidrains
d756d1bb8c addressing issues raised by an independent researcher with llm assistance 2025-10-31 08:37:39 -07:00
lucidrains
60681fce1d fix generation so that one more step is taken to decode agent embeds off the final cleaned set of latents, update readme 2025-10-31 06:48:49 -07:00
Phil Wang
6870294d95
no longer needed 2025-10-30 09:23:27 -07:00
lucidrains
3beae186da some more control over whether to normalize advantages 2025-10-30 08:46:03 -07:00
lucidrains
0904e224ab make the reverse kl optional 2025-10-30 08:22:50 -07:00
lucidrains
767789d0ca they decided on 0.3 for the behavioral prior loss weight 2025-10-29 13:24:58 -07:00
lucidrains
35b87c4fa1 oops 2025-10-29 13:04:02 -07:00
lucidrains
c4a3cb09d5 swap for discrete kl div, thanks to Dirk for pointing this out on the discord 2025-10-29 11:54:18 -07:00
lucidrains
cb54121ace sim trainer needs to take care of agent embedding and old actions 2025-10-29 11:15:11 -07:00
lucidrains
586379f2c8 sum the kl div loss across number of actions by default for action embedder .kl_div 2025-10-29 10:46:42 -07:00
lucidrains
a358a44a53 always store old agent embeds and old action parameters when possible 2025-10-29 10:39:15 -07:00
lucidrains
3547344312 take care of storing the old action logits and mean log var, and calculate kl div for pmpo based off that during learn from experience 2025-10-29 10:31:32 -07:00
lucidrains
691d9ca007 add kl div on action embedder, working way towards the kl div loss in pmpo 2025-10-29 10:02:25 -07:00
lucidrains
91d697f8ca fix pmpo 2025-10-28 18:55:22 -07:00
lucidrains
7acaa764f6 evolutionary policy optimization on dreams will be interesting 2025-10-28 10:17:01 -07:00
lucidrains
c0450359f3 allow for evolutionary policy optimization 2025-10-28 10:11:13 -07:00
lucidrains
46f86cd247 fix storing of agent embedding 2025-10-28 09:36:58 -07:00
lucidrains
903c43b770 use the agent embeds off the stored experience if available 2025-10-28 09:14:02 -07:00
lucidrains
d476fa7b14 able to store the agent embeddings during rollouts with imagination or environment, for efficient policy optimization (but will also allow for finetuning world model for the heads) 2025-10-28 09:02:26 -07:00
lucidrains
789f091c63 redo so that max timesteps is treated as truncation at the last timestep, then allow for accepting the truncation signal from the environment and reuse same logic 2025-10-28 08:04:48 -07:00
lucidrains
41ab83f691 fix mock 2025-10-27 10:47:24 -07:00
lucidrains
995b1f64e5 handle environments that return a terminate flag, also make sure episode lens are logged in vectorized env 2025-10-27 10:14:28 -07:00
lucidrains
fd1e87983b quantile filter 2025-10-27 09:08:26 -07:00
lucidrains
fe79bfa951 optionally keep track of returns statistics and normalize with them before advantage 2025-10-27 09:02:08 -07:00
lucidrains
f808b1c1d2 oops 2025-10-27 08:34:22 -07:00
lucidrains
349a03acd7 redo so lens is always the episode length, including the bootstrap value timestep, and use is_truncated to mask out the bootstrap node from being learned on 2025-10-27 08:06:21 -07:00
lucidrains
59c458aea3 introduce an is_truncated field on Experience, and mask out rewards and values before calculating gae appropriately 2025-10-27 07:55:00 -07:00
lucidrains
fbfd59e42f handle variable lengthed experiences when doing policy optimization 2025-10-27 06:09:09 -07:00
lucidrains
46432aee9b fix an issue with bc 2025-10-25 12:30:08 -07:00
lucidrains
f97d9adc97 oops, forgot to add the view embedding for robotics 2025-10-25 11:39:06 -07:00
lucidrains
32cf142b4d take another step for variable len experiences 2025-10-25 11:31:41 -07:00
lucidrains
1ed6a15cb0 fix tests 2025-10-25 11:13:22 -07:00
lucidrains
4d8f5613cc start storing the experience lens 2025-10-25 10:55:47 -07:00
lucidrains
3d5617d769 take a step towards variable lengthed experiences during training 2025-10-25 10:45:34 -07:00
lucidrains
77a40e8701 validate that we can generate multiple video streams for robotics use-case 2025-10-25 09:23:07 -07:00
lucidrains
4ce82f34df given the VAT paper, add multiple video streams (third person, wrist camera, etc), geared for robotics. need to manage an extra dimension for multiple viewpoints 2025-10-25 09:20:55 -07:00
lucidrains
a9b728c611 incorporate proprioception into the dynamics world model 2025-10-24 11:24:22 -07:00
lucidrains
35c1db4c7d sketch of training from sim env 2025-10-24 09:13:09 -07:00
lucidrains
27ac05efb0 function for combining experiences 2025-10-24 08:00:10 -07:00
lucidrains
d0ffc6bfed with or without signed advantage 2025-10-23 16:24:29 -07:00
lucidrains
fb3e026fe0 handle vectorized env 2025-10-22 11:19:44 -07:00
8 changed files with 1810 additions and 228 deletions

View File

@@ -5,11 +5,11 @@ jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 20
timeout-minutes: 60
strategy:
fail-fast: false
matrix:
group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
steps:
- uses: actions/checkout@v4
@@ -24,4 +24,4 @@ jobs:
python -m uv pip install -e .[test]
- name: Test with pytest
run: |
python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/

View File

@@ -1,10 +1,99 @@
<img src="./dreamer4-fig2.png" width="400px"></img>
## Dreamer 4 (wip)
## Dreamer 4
Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
[Temporary Discord](https://discord.gg/MkACrrkrYR)
[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
## Appreciation
- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
## Install
```bash
$ pip install dreamer4
```
## Usage
```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
# video tokenizer, learned through MAE + lpips
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256
)
video = torch.randn(2, 3, 10, 256, 256)
# learn the tokenizer
loss = tokenizer(video)
loss.backward()
# dynamics world model
world_model = DynamicsWorldModel(
dim = 512,
dim_latent = 32,
video_tokenizer = tokenizer,
num_discrete_actions = 4,
num_residual_streams = 1
)
# state, action, rewards
video = torch.randn(2, 3, 10, 256, 256)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
rewards = torch.randn(2, 10)
# learn dynamics / behavior cloned model
loss = world_model(
video = video,
rewards = rewards,
discrete_actions = discrete_actions
)
loss.backward()
# do the above with much data
# then generate dreams
dreams = world_model.generate(
10,
batch_size = 2,
return_decoded_video = True,
return_for_policy_optimization = True
)
# learn from the dreams
actor_loss, critic_loss = world_model.learn_from_experience(dreams)
(actor_loss + critic_loss).backward()
# learn from environment
from dreamer4.mocks import MockEnv
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```
## Citation
@@ -19,3 +108,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
url = {https://arxiv.org/abs/2509.24527},
}
```
*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
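The latest commit above (`allow for image pretraining on video tokenizer`) and the new `test_images_to_video_tokenizer` test at the bottom of this diff indicate the tokenizer also trains on single frames. A minimal sketch, assuming the README constructor settings and the same loss-only return as the video call:

```python
import torch
from dreamer4 import VideoTokenizer

# video tokenizer, configured as in the README usage above
tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

# a batch of single frames, no time dimension, per the image-pretraining commit
images = torch.randn(2, 3, 256, 256)

# reconstruction loss on images, mirroring test_images_to_video_tokenizer
loss = tokenizer(images)
loss.backward()
```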

View File

@@ -1,6 +1,7 @@
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsWorldModel
DynamicsWorldModel,
AxialSpaceTimeTransformer
)

File diff suppressed because it is too large

View File

@@ -7,6 +7,11 @@ from torch.nn import Module
from einops import repeat
# helpers
def exists(v):
return v is not None
# mock env
class MockEnv(Module):
@@ -15,7 +20,11 @@ class MockEnv(Module):
image_shape,
reward_range = (-100, 100),
num_envs = 1,
vectorized = False
vectorized = False,
terminate_after_step = None,
rand_terminate_prob = 0.05,
can_truncate = False,
rand_truncate_prob = 0.05,
):
super().__init__()
self.image_shape = image_shape
@@ -23,6 +32,17 @@ class MockEnv(Module):
self.num_envs = num_envs
self.vectorized = vectorized
assert not (vectorized and num_envs == 1)
# mocking termination and truncation
self.can_terminate = exists(terminate_after_step)
self.terminate_after_step = terminate_after_step
self.rand_terminate_prob = rand_terminate_prob
self.can_truncate = can_truncate
self.rand_truncate_prob = rand_truncate_prob
self.register_buffer('_step', tensor(0))
def get_random_state(self):
@@ -33,7 +53,12 @@ class MockEnv(Module):
seed = None
):
self._step.zero_()
return self.get_random_state()
state = self.get_random_state()
if self.vectorized:
state = repeat(state, '... -> b ...', b = self.num_envs)
return state
def step(
self,
@@ -43,12 +68,30 @@ class MockEnv(Module):
reward = empty(()).uniform_(*self.reward_range)
if not self.vectorized:
return state, reward
if self.vectorized:
discrete, continuous = actions
assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
assert actions.shape[0] == self.num_envs
state = repeat(state, '... -> b ...', b = self.num_envs)
reward = repeat(reward, ' -> b', b = self.num_envs)
state = repeat(state, '... -> b ...', b = self.num_envs)
reward = repeat(reward, ' -> b', b = self.num_envs)
out = (state, reward)
return state, rewards
if self.can_terminate:
shape = (self.num_envs,) if self.vectorized else (1,)
valid_step = self._step > self.terminate_after_step
terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
out = (*out, terminate)
# maybe truncation
if self.can_truncate:
truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
out = (*out, truncate)
self._step.add_(1)
return out
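For orientation, a minimal usage sketch of the new termination / truncation options, based on the constructor arguments in the hunk above and the updated `test_online_rl` further down; the gym-style `reset` name and the exact return arity of `step` are assumptions rather than confirmed API:

```python
from dreamer4.mocks import MockEnv

# mock env that can randomly terminate after step 2 and also emit truncation signals
env = MockEnv(
    (256, 256),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,      # enables termination, since can_terminate = exists(terminate_after_step)
    rand_terminate_prob = 0.1,
    can_truncate = True,
    rand_truncate_prob = 0.05
)

# initial states are repeated across the batch dimension when vectorized
state = env.reset()

# with both flags enabled, step appends boolean tensors for termination and truncation,
# i.e. it returns (state, reward, terminate, truncate), each batched over num_envs
```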

View File

@@ -4,7 +4,7 @@ import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, TensorDataset, DataLoader
from accelerate import Accelerator
@@ -12,7 +12,9 @@ from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsWorldModel
DynamicsWorldModel,
Experience,
combine_experiences
)
from ema_pytorch import EMA
@@ -285,7 +287,7 @@ class DreamTrainer(Module):
for _ in range(self.num_train_steps):
dreams = self.unwrapped_model.generate(
self.generate_timesteps,
self.generate_timesteps + 1, # plus one for bootstrap value
batch_size = self.batch_size,
return_rewards_per_frame = True,
return_agent_actions = True,
@@ -317,3 +319,221 @@ class DreamTrainer(Module):
self.value_head_optim.zero_grad()
self.print('training complete')
# training from sim
class SimTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
optim_klass = AdamW,
batch_size = 16,
generate_timesteps = 16,
learning_rate = 3e-4,
max_grad_norm = None,
epochs = 2,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
self.max_grad_norm = max_grad_norm
self.epochs = epochs
self.batch_size = batch_size
self.generate_timesteps = generate_timesteps
self.unwrapped_model = self.model
(
self.model,
self.policy_head_optim,
self.value_head_optim,
) = self.accelerator.prepare(
self.model,
self.policy_head_optim,
self.value_head_optim
)
@property
def device(self):
return self.accelerator.device
@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def learn(
self,
experience: Experience
):
step_size = experience.step_size
agent_index = experience.agent_index
latents = experience.latents
old_values = experience.values
rewards = experience.rewards
has_agent_embed = exists(experience.agent_embed)
agent_embed = experience.agent_embed
discrete_actions, continuous_actions = experience.actions
discrete_log_probs, continuous_log_probs = experience.log_probs
discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
# handle empties
empty_tensor = torch.empty_like(rewards)
agent_embed = default(agent_embed, empty_tensor)
has_discrete = exists(discrete_actions)
has_continuous = exists(continuous_actions)
discrete_actions = default(discrete_actions, empty_tensor)
continuous_actions = default(continuous_actions, empty_tensor)
discrete_log_probs = default(discrete_log_probs, empty_tensor)
continuous_log_probs = default(continuous_log_probs, empty_tensor)
discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
# create the dataset and dataloader
dataset = TensorDataset(
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
)
dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
for epoch in range(self.epochs):
for (
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
) in dataloader:
actions = (
discrete_actions if has_discrete else None,
continuous_actions if has_continuous else None
)
log_probs = (
discrete_log_probs if has_discrete else None,
continuous_log_probs if has_continuous else None
)
old_action_unembeds = (
discrete_old_action_unembeds if has_discrete else None,
continuous_old_action_unembeds if has_continuous else None
)
batch_experience = Experience(
latents = latents,
actions = actions,
log_probs = log_probs,
agent_embed = agent_embed if has_agent_embed else None,
old_action_unembeds = old_action_unembeds,
values = old_values,
rewards = rewards,
step_size = step_size,
agent_index = agent_index
)
policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
# update policy head
self.accelerator.backward(policy_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
self.policy_head_optim.step()
self.policy_head_optim.zero_grad()
# update value head
self.accelerator.backward(value_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
self.value_head_optim.step()
self.value_head_optim.zero_grad()
self.print('training complete')
def forward(
self,
env,
num_episodes = 50000,
max_experiences_before_learn = 8,
env_is_vectorized = False
):
for _ in range(num_episodes):
total_experience = 0
experiences = []
while total_experience < max_experiences_before_learn:
experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
num_experience = experience.video.shape[0]
total_experience += num_experience
experiences.append(experience.cpu())
combined_experiences = combine_experiences(experiences)
self.learn(combined_experiences)
experiences.clear()
self.print('training complete')
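The new `SimTrainer` is exercised end to end by the updated `test_online_rl` below; a condensed sketch of that flow, assuming a tokenizer and world model configured as in the README usage:

```python
from dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv
from dreamer4.trainers import SimTrainer

# tokenizer and world model, same settings as the README usage section
tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

# vectorized mock environment standing in for a real simulator
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

# the trainer rolls out with interact_with_env, combines experiences with combine_experiences,
# then updates the policy and value heads via learn_from_experience
trainer = SimTrainer(
    world_model,
    batch_size = 4,
    cpu = True
)

trainer(mock_env, num_episodes = 2, env_is_vectorized = True)
```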

View File

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.62"
version = "0.1.24"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,7 +36,8 @@ dependencies = [
"hyper-connections>=0.2.1",
"torch>=2.4",
"torchvision",
"x-mlps-pytorch>=0.0.29"
"x-mlps-pytorch>=0.0.29",
"vit-pytorch>=1.15.3"
]
[project.urls]

View File

@@ -2,6 +2,9 @@ import pytest
param = pytest.mark.parametrize
import torch
def exists(v):
return v is not None
@param('pred_orig_latent', (False, True))
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
@@ -12,7 +15,9 @@ import torch
@param('condition_on_actions', (False, True))
@param('num_residual_streams', (1, 4))
@param('add_reward_embed_to_agent_token', (False, True))
@param('use_time_kv_cache', (False, True))
@param('add_state_pred_head', (False, True))
@param('use_time_cache', (False, True))
@param('var_len', (False, True))
def test_e2e(
pred_orig_latent,
grouped_query_attn,
@@ -24,7 +29,9 @@ def test_e2e(
condition_on_actions,
num_residual_streams,
add_reward_embed_to_agent_token,
use_time_kv_cache
add_state_pred_head,
use_time_cache,
var_len
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -36,7 +43,9 @@ def test_e2e(
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 4,
num_residual_streams = num_residual_streams
num_residual_streams = num_residual_streams,
encoder_add_decor_aux_loss = True,
decorr_sample_frac = 1.
)
video = torch.randn(2, 3, 4, 256, 256)
@@ -64,12 +73,13 @@ def test_e2e(
pred_orig_latent = pred_orig_latent,
num_discrete_actions = 4,
attn_dim_head = 16,
attn_heads = heads,
attn_kwargs = dict(
heads = heads,
query_heads = query_heads,
),
prob_no_shortcut_train = prob_no_shortcut_train,
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
add_state_pred_head = add_state_pred_head,
num_residual_streams = num_residual_streams
)
@@ -92,8 +102,13 @@ def test_e2e(
if condition_on_actions:
actions = torch.randint(0, 4, (2, 3, 1))
lens = None
if var_len:
lens = torch.randint(1, 4, (2,))
flow_loss = dynamics(
**dynamics_input,
lens = lens,
tasks = tasks,
signal_levels = signal_levels,
step_sizes_log2 = step_sizes_log2,
@@ -111,7 +126,7 @@ def test_e2e(
image_width = 128,
batch_size = 2,
return_rewards_per_frame = True,
use_time_kv_cache = use_time_kv_cache
use_time_cache = use_time_cache
)
assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -336,6 +351,15 @@ def test_action_embedder():
assert discrete_logits.shape == (2, 3, 8)
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
# test kl div
discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
assert discrete_kl_div.shape == (2, 3)
assert continuous_kl_div.shape == (2, 3)
# return discrete split by number of actions
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
@@ -397,14 +421,14 @@ def test_mtp():
reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
assert reward_targets.shape == (3, 15, 3)
assert mask.shape == (3, 15, 3)
assert reward_targets.shape == (3, 16, 3)
assert mask.shape == (3, 16, 3)
actions = torch.randint(0, 10, (3, 16, 2))
action_targets, mask = create_multi_token_prediction_targets(actions, 3)
assert action_targets.shape == (3, 15, 3, 2)
assert mask.shape == (3, 15, 3)
assert action_targets.shape == (3, 16, 3, 2)
assert mask.shape == (3, 16, 3)
from dreamer4.dreamer4 import ActionEmbedder
@@ -596,11 +620,22 @@ def test_cache_generate():
num_residual_streams = 1
)
generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
generated, time_cache = dynamics.generate(1, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
def test_online_rl():
@param('vectorized', (False, True))
@param('use_pmpo', (False, True))
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
@param('store_agent_embed', (False, True))
def test_online_rl(
vectorized,
use_pmpo,
env_can_terminate,
env_can_truncate,
store_agent_embed
):
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
tokenizer = VideoTokenizer(
@@ -611,7 +646,9 @@ def test_online_rl():
dim_latent = 16,
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 1
num_latent_tokens = 1,
image_height = 256,
image_width = 256,
)
world_model_and_policy = DynamicsWorldModel(
@@ -632,11 +669,163 @@ def test_online_rl():
)
from dreamer4.mocks import MockEnv
mock_env = MockEnv((256, 256), vectorized = False, num_envs = 4)
from dreamer4.dreamer4 import combine_experiences
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16)
mock_env = MockEnv(
(256, 256),
vectorized = vectorized,
num_envs = 4,
terminate_after_step = 2 if env_can_terminate else None,
can_truncate = env_can_truncate,
rand_terminate_prob = 0.1
)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
# manually
dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
# quick test moving the experience to different devices
if torch.cuda.is_available():
combined_experience = combined_experience.to(torch.device('cuda'))
combined_experience = combined_experience.to(world_model_and_policy.device)
if store_agent_embed:
assert exists(combined_experience.agent_embed)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
actor_loss.backward()
critic_loss.backward()
# with trainer
from dreamer4.trainers import SimTrainer
trainer = SimTrainer(
world_model_and_policy,
batch_size = 4,
cpu = True
)
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
@param('num_video_views', (1, 2))
def test_proprioception(
num_video_views
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_video_views = num_video_views,
num_discrete_actions = 4,
num_residual_streams = 1
)
if num_video_views > 1:
video_shape = (2, num_video_views, 3, 10, 256, 256)
else:
video_shape = (2, 3, 10, 256, 256)
video = torch.randn(*video_shape)
rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
tasks = torch.randint(0, 4, (2,))
loss = dynamics(
video = video,
rewards = rewards,
tasks = tasks,
proprio = proprio,
discrete_actions = discrete_actions
)
loss.backward()
generations = dynamics.generate(
10,
batch_size = 2,
return_decoded_video = True
)
assert exists(generations.proprio)
assert generations.video.shape == video_shape
def test_epo():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_latent_genes = 16,
num_discrete_actions = 4,
num_residual_streams = 1
)
fitness = torch.randn(16,)
dynamics.evolve_(fitness)
def test_images_to_video_tokenizer():
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256,
encoder_add_decor_aux_loss = True
)
images = torch.randn(2, 3, 256, 256)
loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
loss.backward()
assert images.shape == recon_images.shape