Compare commits


1 Commit
main ... 0.0.64

Author SHA1 Message Date
lucidrains e4ee4d905a handle vectorized env 2025-10-22 08:52:08 -07:00
8 changed files with 231 additions and 1788 deletions

View File

@ -5,11 +5,11 @@ jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
timeout-minutes: 60 timeout-minutes: 20
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@ -24,4 +24,4 @@ jobs:
python -m uv pip install -e .[test] python -m uv pip install -e .[test]
- name: Test with pytest - name: Test with pytest
run: | run: |
python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/ python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
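The hunk above halves the CI test matrix: on the main side the suite is split across 20 shards with a 60-minute job timeout, while 0.0.64 uses 10 shards and a 20-minute timeout. Each matrix job passes its `group` index as `--shard-id` so it runs a disjoint slice of the collected tests. A minimal sketch of the idea, assuming a simple index-modulo assignment (the actual pytest sharding plugin may distribute tests differently):

```python
# Hypothetical illustration of CI test sharding, not the plugin's implementation:
# each job keeps only the tests whose collection index falls in its shard.
def select_shard(test_ids, num_shards, shard_id):
    assert 0 <= shard_id < num_shards
    return [t for i, t in enumerate(test_ids) if i % num_shards == shard_id]

tests = [f'test_{i}' for i in range(25)]
shards = [select_shard(tests, num_shards = 10, shard_id = s) for s in range(10)]

# every test runs in exactly one of the 10 matrix jobs
assert sorted(sum(shards, [])) == sorted(tests)
```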

View File

@ -1,99 +1,10 @@
<img src="./dreamer4-fig2.png" width="400px"></img> <img src="./dreamer4-fig2.png" width="400px"></img>
## Dreamer 4 ## Dreamer 4 (wip)
Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work [Temporary Discord](https://discord.gg/MkACrrkrYR)
## Appreciation
- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
## Install
```bash
$ pip install dreamer4
```
## Usage
```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
# video tokenizer, learned through MAE + lpips
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256
)
video = torch.randn(2, 3, 10, 256, 256)
# learn the tokenizer
loss = tokenizer(video)
loss.backward()
# dynamics world model
world_model = DynamicsWorldModel(
dim = 512,
dim_latent = 32,
video_tokenizer = tokenizer,
num_discrete_actions = 4,
num_residual_streams = 1
)
# state, action, rewards
video = torch.randn(2, 3, 10, 256, 256)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
rewards = torch.randn(2, 10)
# learn dynamics / behavior cloned model
loss = world_model(
video = video,
rewards = rewards,
discrete_actions = discrete_actions
)
loss.backward()
# do the above with much data
# then generate dreams
dreams = world_model.generate(
10,
batch_size = 2,
return_decoded_video = True,
return_for_policy_optimization = True
)
# learn from the dreams
actor_loss, critic_loss = world_model.learn_from_experience(dreams)
(actor_loss + critic_loss).backward()
# learn from environment
from dreamer4.mocks import MockEnv
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```
## Citation ## Citation
@ -108,5 +19,3 @@ actor_loss, critic_loss = world_model.learn_from_experience(experience)
url = {https://arxiv.org/abs/2509.24527}, url = {https://arxiv.org/abs/2509.24527},
} }
``` ```
*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*

View File

@ -1,7 +1,6 @@
from dreamer4.dreamer4 import ( from dreamer4.dreamer4 import (
VideoTokenizer, VideoTokenizer,
DynamicsWorldModel, DynamicsWorldModel
AxialSpaceTimeTransformer
) )

File diff suppressed because it is too large

View File

@ -7,11 +7,6 @@ from torch.nn import Module
from einops import repeat from einops import repeat
# helpers
def exists(v):
return v is not None
# mock env # mock env
class MockEnv(Module): class MockEnv(Module):
@ -20,11 +15,7 @@ class MockEnv(Module):
image_shape, image_shape,
reward_range = (-100, 100), reward_range = (-100, 100),
num_envs = 1, num_envs = 1,
vectorized = False, vectorized = False
terminate_after_step = None,
rand_terminate_prob = 0.05,
can_truncate = False,
rand_truncate_prob = 0.05,
): ):
super().__init__() super().__init__()
self.image_shape = image_shape self.image_shape = image_shape
@ -34,15 +25,6 @@ class MockEnv(Module):
self.vectorized = vectorized self.vectorized = vectorized
assert not (vectorized and num_envs == 1) assert not (vectorized and num_envs == 1)
# mocking termination and truncation
self.can_terminate = exists(terminate_after_step)
self.terminate_after_step = terminate_after_step
self.rand_terminate_prob = rand_terminate_prob
self.can_truncate = can_truncate
self.rand_truncate_prob = rand_truncate_prob
self.register_buffer('_step', tensor(0)) self.register_buffer('_step', tensor(0))
def get_random_state(self): def get_random_state(self):
@ -68,30 +50,13 @@ class MockEnv(Module):
reward = empty(()).uniform_(*self.reward_range) reward = empty(()).uniform_(*self.reward_range)
if self.vectorized: if not self.vectorized:
discrete, continuous = actions return state, reward
assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
state = repeat(state, '... -> b ...', b = self.num_envs) discrete, continuous = actions
reward = repeat(reward, ' -> b', b = self.num_envs) assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
out = (state, reward) state = repeat(state, '... -> b ...', b = self.num_envs)
reward = repeat(reward, ' -> b', b = self.num_envs)
return state, reward
if self.can_terminate:
shape = (self.num_envs,) if self.vectorized else (1,)
valid_step = self._step > self.terminate_after_step
terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
out = (*out, terminate)
# maybe truncation
if self.can_truncate:
truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
out = (*out, truncate)
self._step.add_(1)
return out
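On the main side of this diff, `MockEnv` also mocks episode termination and truncation: when `terminate_after_step` is set (and optionally `can_truncate = True`), the step output grows from `(state, reward)` to `(state, reward, terminate)` and then `(state, reward, terminate, truncate)` once the internal `_step` counter passes the threshold. A minimal construction sketch using the main-side arguments (values borrowed from the online-RL test further down; the env then plugs into `interact_with_env` exactly as in the README usage above):

```python
# Sketch of the main-side MockEnv with mocked termination/truncation enabled.
# Argument names come from the diff above; values mirror test_online_rl below.
from dreamer4.mocks import MockEnv

mock_env = MockEnv(
    (256, 256),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,   # allow random termination once past step 2
    rand_terminate_prob = 0.1,  # per-step termination probability once eligible
    can_truncate = True         # also emit a truncation flag (never when already terminated)
)

# world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
# then receives the terminate/truncate flags alongside state and reward
```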

View File

@ -4,7 +4,7 @@ import torch
from torch import is_tensor from torch import is_tensor
from torch.nn import Module from torch.nn import Module
from torch.optim import AdamW from torch.optim import AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator from accelerate import Accelerator
@ -12,9 +12,7 @@ from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import ( from dreamer4.dreamer4 import (
VideoTokenizer, VideoTokenizer,
DynamicsWorldModel, DynamicsWorldModel
Experience,
combine_experiences
) )
from ema_pytorch import EMA from ema_pytorch import EMA
@ -287,7 +285,7 @@ class DreamTrainer(Module):
for _ in range(self.num_train_steps): for _ in range(self.num_train_steps):
dreams = self.unwrapped_model.generate( dreams = self.unwrapped_model.generate(
self.generate_timesteps + 1, # plus one for bootstrap value self.generate_timesteps,
batch_size = self.batch_size, batch_size = self.batch_size,
return_rewards_per_frame = True, return_rewards_per_frame = True,
return_agent_actions = True, return_agent_actions = True,
@ -319,221 +317,3 @@ class DreamTrainer(Module):
self.value_head_optim.zero_grad() self.value_head_optim.zero_grad()
self.print('training complete') self.print('training complete')
# training from sim
class SimTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
optim_klass = AdamW,
batch_size = 16,
generate_timesteps = 16,
learning_rate = 3e-4,
max_grad_norm = None,
epochs = 2,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
self.max_grad_norm = max_grad_norm
self.epochs = epochs
self.batch_size = batch_size
self.generate_timesteps = generate_timesteps
self.unwrapped_model = self.model
(
self.model,
self.policy_head_optim,
self.value_head_optim,
) = self.accelerator.prepare(
self.model,
self.policy_head_optim,
self.value_head_optim
)
@property
def device(self):
return self.accelerator.device
@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def learn(
self,
experience: Experience
):
step_size = experience.step_size
agent_index = experience.agent_index
latents = experience.latents
old_values = experience.values
rewards = experience.rewards
has_agent_embed = exists(experience.agent_embed)
agent_embed = experience.agent_embed
discrete_actions, continuous_actions = experience.actions
discrete_log_probs, continuous_log_probs = experience.log_probs
discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
# handle empties
empty_tensor = torch.empty_like(rewards)
agent_embed = default(agent_embed, empty_tensor)
has_discrete = exists(discrete_actions)
has_continuous = exists(continuous_actions)
discrete_actions = default(discrete_actions, empty_tensor)
continuous_actions = default(continuous_actions, empty_tensor)
discrete_log_probs = default(discrete_log_probs, empty_tensor)
continuous_log_probs = default(continuous_log_probs, empty_tensor)
discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
# create the dataset and dataloader
dataset = TensorDataset(
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
)
dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
for epoch in range(self.epochs):
for (
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
) in dataloader:
actions = (
discrete_actions if has_discrete else None,
continuous_actions if has_continuous else None
)
log_probs = (
discrete_log_probs if has_discrete else None,
continuous_log_probs if has_continuous else None
)
old_action_unembeds = (
discrete_old_action_unembeds if has_discrete else None,
continuous_old_action_unembeds if has_continuous else None
)
batch_experience = Experience(
latents = latents,
actions = actions,
log_probs = log_probs,
agent_embed = agent_embed if has_agent_embed else None,
old_action_unembeds = old_action_unembeds,
values = old_values,
rewards = rewards,
step_size = step_size,
agent_index = agent_index
)
policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
# update policy head
self.accelerator.backward(policy_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
self.policy_head_optim.step()
self.policy_head_optim.zero_grad()
# update value head
self.accelerator.backward(value_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
self.value_head_optim.step()
self.value_head_optim.zero_grad()
self.print('training complete')
def forward(
self,
env,
num_episodes = 50000,
max_experiences_before_learn = 8,
env_is_vectorized = False
):
for _ in range(num_episodes):
total_experience = 0
experiences = []
while total_experience < max_experiences_before_learn:
experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
num_experience = experience.video.shape[0]
total_experience += num_experience
experiences.append(experience.cpu())
combined_experiences = combine_experiences(experiences)
self.learn(combined_experiences)
experiences.clear()
self.print('training complete')
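The main-side `SimTrainer` removed in this hunk wraps the full online loop: it collects rollouts with `interact_with_env` until `max_experiences_before_learn` is reached, merges them with `combine_experiences`, builds a shuffled `TensorDataset` over the flattened experience fields, and runs `epochs` passes of policy-head and value-head updates through `learn_from_experience`. A minimal usage sketch, mirroring the main-side `test_online_rl` further down (assumes `world_model_and_policy` is an already constructed `DynamicsWorldModel`):

```python
# Sketch only; wiring taken from the main-side test_online_rl below.
from dreamer4.mocks import MockEnv
from dreamer4.trainers import SimTrainer

trainer = SimTrainer(
    world_model_and_policy,     # an existing DynamicsWorldModel with policy/value heads
    batch_size = 4,
    generate_timesteps = 16,
    cpu = True                  # keep the Accelerator on CPU for a smoke test
)

mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

# each episode: gather experience, combine it, then update the policy and value heads
trainer(mock_env, num_episodes = 2, env_is_vectorized = True)
```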

View File

@ -1,6 +1,6 @@
[project] [project]
name = "dreamer4" name = "dreamer4"
version = "0.1.24" version = "0.0.64"
description = "Dreamer 4" description = "Dreamer 4"
authors = [ authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" } { name = "Phil Wang", email = "lucidrains@gmail.com" }
@ -36,8 +36,7 @@ dependencies = [
"hyper-connections>=0.2.1", "hyper-connections>=0.2.1",
"torch>=2.4", "torch>=2.4",
"torchvision", "torchvision",
"x-mlps-pytorch>=0.0.29", "x-mlps-pytorch>=0.0.29"
"vit-pytorch>=1.15.3"
] ]
[project.urls] [project.urls]

View File

@ -2,9 +2,6 @@ import pytest
param = pytest.mark.parametrize param = pytest.mark.parametrize
import torch import torch
def exists(v):
return v is not None
@param('pred_orig_latent', (False, True)) @param('pred_orig_latent', (False, True))
@param('grouped_query_attn', (False, True)) @param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True)) @param('dynamics_with_video_input', (False, True))
@ -15,9 +12,7 @@ def exists(v):
@param('condition_on_actions', (False, True)) @param('condition_on_actions', (False, True))
@param('num_residual_streams', (1, 4)) @param('num_residual_streams', (1, 4))
@param('add_reward_embed_to_agent_token', (False, True)) @param('add_reward_embed_to_agent_token', (False, True))
@param('add_state_pred_head', (False, True)) @param('use_time_kv_cache', (False, True))
@param('use_time_cache', (False, True))
@param('var_len', (False, True))
def test_e2e( def test_e2e(
pred_orig_latent, pred_orig_latent,
grouped_query_attn, grouped_query_attn,
@ -29,9 +24,7 @@ def test_e2e(
condition_on_actions, condition_on_actions,
num_residual_streams, num_residual_streams,
add_reward_embed_to_agent_token, add_reward_embed_to_agent_token,
add_state_pred_head, use_time_kv_cache
use_time_cache,
var_len
): ):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@ -43,9 +36,7 @@ def test_e2e(
patch_size = 32, patch_size = 32,
attn_dim_head = 16, attn_dim_head = 16,
num_latent_tokens = 4, num_latent_tokens = 4,
num_residual_streams = num_residual_streams, num_residual_streams = num_residual_streams
encoder_add_decor_aux_loss = True,
decorr_sample_frac = 1.
) )
video = torch.randn(2, 3, 4, 256, 256) video = torch.randn(2, 3, 4, 256, 256)
@ -73,13 +64,12 @@ def test_e2e(
pred_orig_latent = pred_orig_latent, pred_orig_latent = pred_orig_latent,
num_discrete_actions = 4, num_discrete_actions = 4,
attn_dim_head = 16, attn_dim_head = 16,
attn_heads = heads,
attn_kwargs = dict( attn_kwargs = dict(
heads = heads,
query_heads = query_heads, query_heads = query_heads,
), ),
prob_no_shortcut_train = prob_no_shortcut_train, prob_no_shortcut_train = prob_no_shortcut_train,
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token, add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
add_state_pred_head = add_state_pred_head,
num_residual_streams = num_residual_streams num_residual_streams = num_residual_streams
) )
@ -102,13 +92,8 @@ def test_e2e(
if condition_on_actions: if condition_on_actions:
actions = torch.randint(0, 4, (2, 3, 1)) actions = torch.randint(0, 4, (2, 3, 1))
lens = None
if var_len:
lens = torch.randint(1, 4, (2,))
flow_loss = dynamics( flow_loss = dynamics(
**dynamics_input, **dynamics_input,
lens = lens,
tasks = tasks, tasks = tasks,
signal_levels = signal_levels, signal_levels = signal_levels,
step_sizes_log2 = step_sizes_log2, step_sizes_log2 = step_sizes_log2,
@ -126,7 +111,7 @@ def test_e2e(
image_width = 128, image_width = 128,
batch_size = 2, batch_size = 2,
return_rewards_per_frame = True, return_rewards_per_frame = True,
use_time_cache = use_time_cache use_time_kv_cache = use_time_kv_cache
) )
assert generations.video.shape == (2, 3, 10, 128, 128) assert generations.video.shape == (2, 3, 10, 128, 128)
@ -351,15 +336,6 @@ def test_action_embedder():
assert discrete_logits.shape == (2, 3, 8) assert discrete_logits.shape == (2, 3, 8)
assert continuous_mean_log_var.shape == (2, 3, 2, 2) assert continuous_mean_log_var.shape == (2, 3, 2, 2)
# test kl div
discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
assert discrete_kl_div.shape == (2, 3)
assert continuous_kl_div.shape == (2, 3)
# return discrete split by number of actions # return discrete split by number of actions
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True) discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
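The removed assertions above exercise `ActionEmbedder.kl_div`, which takes two `(discrete_logits, continuous_mean_log_var)` tuples and returns a per-timestep KL divergence for each action type with shape `(batch, time)`. A minimal sketch of what those terms look like when computed directly with `torch.distributions`, assuming the last axis of the continuous parameters holds `(mean, log variance)` and that per-action KLs are summed; the library's own parameterization and reduction may differ:

```python
# Sketch only, not ActionEmbedder.kl_div: categorical KL for the discrete actions,
# diagonal-Gaussian KL (summed over action dims) for the continuous actions.
import torch
from torch.distributions import Categorical, Normal, kl_divergence

discrete_logits     = torch.randn(2, 3, 8)       # (batch, time, num_discrete_actions)
discrete_logits_tgt = torch.randn(2, 3, 8)

mean_log_var     = torch.randn(2, 3, 2, 2)       # (batch, time, num_continuous, [mean, log_var])
mean_log_var_tgt = torch.randn(2, 3, 2, 2)

discrete_kl = kl_divergence(
    Categorical(logits = discrete_logits),
    Categorical(logits = discrete_logits_tgt)
)                                                # (2, 3)

mean, log_var         = mean_log_var.unbind(-1)
mean_tgt, log_var_tgt = mean_log_var_tgt.unbind(-1)

continuous_kl = kl_divergence(
    Normal(mean, (0.5 * log_var).exp()),
    Normal(mean_tgt, (0.5 * log_var_tgt).exp())
).sum(dim = -1)                                  # sum over action dims -> (2, 3)

assert discrete_kl.shape == continuous_kl.shape == (2, 3)
```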
@ -421,14 +397,14 @@ def test_mtp():
reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
assert reward_targets.shape == (3, 16, 3) assert reward_targets.shape == (3, 15, 3)
assert mask.shape == (3, 16, 3) assert mask.shape == (3, 15, 3)
actions = torch.randint(0, 10, (3, 16, 2)) actions = torch.randint(0, 10, (3, 16, 2))
action_targets, mask = create_multi_token_prediction_targets(actions, 3) action_targets, mask = create_multi_token_prediction_targets(actions, 3)
assert action_targets.shape == (3, 16, 3, 2) assert action_targets.shape == (3, 15, 3, 2)
assert mask.shape == (3, 16, 3) assert mask.shape == (3, 15, 3)
from dreamer4.dreamer4 import ActionEmbedder from dreamer4.dreamer4 import ActionEmbedder
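In the multi-token-prediction test above, `create_multi_token_prediction_targets(rewards, 3)` returns per-position lookahead targets plus a validity mask; the main side expects shape `(3, 16, 3)` for a length-16 sequence, while 0.0.64 expects `(3, 15, 3)`, suggesting main keeps every timestep and masks lookaheads that run past the end, whereas 0.0.64 trims a position instead. A minimal sketch of the masked variant, under that assumption and not the library's actual implementation:

```python
# Hypothetical k-step lookahead target construction with a validity mask,
# reproducing only the main-side (batch, time, k, ...) shapes from the test.
import torch

def multi_token_targets(seq, k):
    b, t = seq.shape[:2]
    idx = torch.arange(t).unsqueeze(-1) + torch.arange(1, k + 1)   # (t, k) future indices
    mask = idx < t                                                 # lookaheads past the end are invalid
    targets = seq[:, idx.clamp(max = t - 1)]                       # gather, clamping out-of-range indices
    return targets, mask.expand(b, t, k)

rewards = torch.randn(3, 16)
reward_targets, mask = multi_token_targets(rewards, 3)
assert reward_targets.shape == (3, 16, 3) and mask.shape == (3, 16, 3)

actions = torch.randint(0, 10, (3, 16, 2))
action_targets, mask = multi_token_targets(actions, 3)
assert action_targets.shape == (3, 16, 3, 2) and mask.shape == (3, 16, 3)
```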
@ -620,22 +596,12 @@ def test_cache_generate():
num_residual_streams = 1 num_residual_streams = 1
) )
generated, time_cache = dynamics.generate(1, return_time_cache = True) generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True) generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True) generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
@param('vectorized', (False, True)) @param('vectorized', (False, True))
@param('use_pmpo', (False, True)) def test_online_rl(vectorized):
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
@param('store_agent_embed', (False, True))
def test_online_rl(
vectorized,
use_pmpo,
env_can_terminate,
env_can_truncate,
store_agent_embed
):
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
tokenizer = VideoTokenizer( tokenizer = VideoTokenizer(
@ -646,9 +612,7 @@ def test_online_rl(
dim_latent = 16, dim_latent = 16,
patch_size = 32, patch_size = 32,
attn_dim_head = 16, attn_dim_head = 16,
num_latent_tokens = 1, num_latent_tokens = 1
image_height = 256,
image_width = 256,
) )
world_model_and_policy = DynamicsWorldModel( world_model_and_policy = DynamicsWorldModel(
@ -669,163 +633,11 @@ def test_online_rl(
) )
from dreamer4.mocks import MockEnv from dreamer4.mocks import MockEnv
from dreamer4.dreamer4 import combine_experiences mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
mock_env = MockEnv( one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
(256, 256),
vectorized = vectorized,
num_envs = 4,
terminate_after_step = 2 if env_can_terminate else None,
can_truncate = env_can_truncate,
rand_terminate_prob = 0.1
)
# manually actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)
dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
# quick test moving the experience to different devices
if torch.cuda.is_available():
combined_experience = combined_experience.to(torch.device('cuda'))
combined_experience = combined_experience.to(world_model_and_policy.device)
if store_agent_embed:
assert exists(combined_experience.agent_embed)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
actor_loss.backward() actor_loss.backward()
critic_loss.backward() critic_loss.backward()
# with trainer
from dreamer4.trainers import SimTrainer
trainer = SimTrainer(
world_model_and_policy,
batch_size = 4,
cpu = True
)
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
@param('num_video_views', (1, 2))
def test_proprioception(
num_video_views
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_video_views = num_video_views,
num_discrete_actions = 4,
num_residual_streams = 1
)
if num_video_views > 1:
video_shape = (2, num_video_views, 3, 10, 256, 256)
else:
video_shape = (2, 3, 10, 256, 256)
video = torch.randn(*video_shape)
rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
tasks = torch.randint(0, 4, (2,))
loss = dynamics(
video = video,
rewards = rewards,
tasks = tasks,
proprio = proprio,
discrete_actions = discrete_actions
)
loss.backward()
generations = dynamics.generate(
10,
batch_size = 2,
return_decoded_video = True
)
assert exists(generations.proprio)
assert generations.video.shape == video_shape
def test_epo():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_latent_genes = 16,
num_discrete_actions = 4,
num_residual_streams = 1
)
fitness = torch.randn(16,)
dynamics.evolve_(fitness)
def test_images_to_video_tokenizer():
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256,
encoder_add_decor_aux_loss = True
)
images = torch.randn(2, 3, 256, 256)
loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
loss.backward()
assert images.shape == recon_images.shape