Compare commits
1 commit

| Author | SHA1 | Date |
|---|---|---|
|  | e4ee4d905a |  |

.github/workflows/test.yml (6 changed lines, vendored)

```diff
@@ -5,11 +5,11 @@ jobs:
   build:

     runs-on: ubuntu-latest
-    timeout-minutes: 60
+    timeout-minutes: 20
     strategy:
       fail-fast: false
       matrix:
-        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

     steps:
     - uses: actions/checkout@v4
@@ -24,4 +24,4 @@ jobs:
         python -m uv pip install -e .[test]
     - name: Test with pytest
       run: |
-        python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/
+        python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
```

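As a side note, the sharded pytest invocation above can be reproduced locally. The sketch below is an assumption-laden illustration, not part of this change: it assumes the `--num-shards` / `--shard-id` flags come from a sharding plugin such as pytest-shard, and simply loops over the shard ids that the CI matrix fans out across separate jobs.

```python
# Minimal local sketch of the CI sharding above (assumes a pytest sharding
# plugin such as pytest-shard provides --num-shards / --shard-id).
import pytest

num_shards = 10  # mirrors the matrix group count in the workflow

for shard_id in range(num_shards):
    # each CI matrix job runs exactly one of these invocations,
    # with shard_id supplied by ${{ matrix.group }}
    exit_code = pytest.main([
        "--num-shards", str(num_shards),
        "--shard-id", str(shard_id),
        "tests/",
    ])
    print(f"shard {shard_id} finished with exit code {int(exit_code)}")
```
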
README.md (95 changed lines)

````diff
@@ -1,99 +1,10 @@
 <img src="./dreamer4-fig2.png" width="400px"></img>

-## Dreamer 4
+## Dreamer 4 (wip)

 Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

-[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
+[Temporary Discord](https://discord.gg/MkACrrkrYR)

-## Appreciation
-
-- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
-
-## Install
-
-```bash
-$ pip install dreamer4
-```
-
-## Usage
-
-```python
-import torch
-from dreamer4 import VideoTokenizer, DynamicsWorldModel
-
-# video tokenizer, learned through MAE + lpips
-
-tokenizer = VideoTokenizer(
-    dim = 512,
-    dim_latent = 32,
-    patch_size = 32,
-    image_height = 256,
-    image_width = 256
-)
-
-video = torch.randn(2, 3, 10, 256, 256)
-
-# learn the tokenizer
-
-loss = tokenizer(video)
-loss.backward()
-
-# dynamics world model
-
-world_model = DynamicsWorldModel(
-    dim = 512,
-    dim_latent = 32,
-    video_tokenizer = tokenizer,
-    num_discrete_actions = 4,
-    num_residual_streams = 1
-)
-
-# state, action, rewards
-
-video = torch.randn(2, 3, 10, 256, 256)
-discrete_actions = torch.randint(0, 4, (2, 10, 1))
-rewards = torch.randn(2, 10)
-
-# learn dynamics / behavior cloned model
-
-loss = world_model(
-    video = video,
-    rewards = rewards,
-    discrete_actions = discrete_actions
-)
-
-loss.backward()
-
-# do the above with much data
-
-# then generate dreams
-
-dreams = world_model.generate(
-    10,
-    batch_size = 2,
-    return_decoded_video = True,
-    return_for_policy_optimization = True
-)
-
-# learn from the dreams
-
-actor_loss, critic_loss = world_model.learn_from_experience(dreams)
-
-(actor_loss + critic_loss).backward()
-
-# learn from environment
-
-from dreamer4.mocks import MockEnv
-
-mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
-
-experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
-
-actor_loss, critic_loss = world_model.learn_from_experience(experience)
-
-(actor_loss + critic_loss).backward()
-```
-
 ## Citation

@@ -108,5 +19,3 @@ actor_loss, critic_loss = world_model.learn_from_experience(experience)
     url = {https://arxiv.org/abs/2509.24527},
 }
 ```
-
-*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
````

dreamer4/__init__.py

```diff
@@ -1,7 +1,6 @@
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel,
-    AxialSpaceTimeTransformer
+    DynamicsWorldModel
 )


```

dreamer4/dreamer4.py (1411 changed lines)

File diff suppressed because it is too large.

dreamer4/mocks.py

```diff
@@ -7,11 +7,6 @@ from torch.nn import Module

 from einops import repeat

-# helpers
-
-def exists(v):
-    return v is not None
-
 # mock env

 class MockEnv(Module):
@@ -20,11 +15,7 @@ class MockEnv(Module):
         image_shape,
         reward_range = (-100, 100),
         num_envs = 1,
-        vectorized = False,
-        terminate_after_step = None,
-        rand_terminate_prob = 0.05,
-        can_truncate = False,
-        rand_truncate_prob = 0.05,
+        vectorized = False
     ):
         super().__init__()
         self.image_shape = image_shape
@@ -34,15 +25,6 @@ class MockEnv(Module):
         self.vectorized = vectorized
         assert not (vectorized and num_envs == 1)

-        # mocking termination and truncation
-
-        self.can_terminate = exists(terminate_after_step)
-        self.terminate_after_step = terminate_after_step
-        self.rand_terminate_prob = rand_terminate_prob
-
-        self.can_truncate = can_truncate
-        self.rand_truncate_prob = rand_truncate_prob
-
         self.register_buffer('_step', tensor(0))

     def get_random_state(self):
@@ -68,30 +50,13 @@

         reward = empty(()).uniform_(*self.reward_range)

-        if self.vectorized:
-            discrete, continuous = actions
-            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
-
-            state = repeat(state, '... -> b ...', b = self.num_envs)
-            reward = repeat(reward, ' -> b', b = self.num_envs)
-
-        out = (state, reward)
-
-        if self.can_terminate:
-            shape = (self.num_envs,) if self.vectorized else (1,)
-            valid_step = self._step > self.terminate_after_step
-
-            terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
-
-            out = (*out, terminate)
-
-        # maybe truncation
-
-        if self.can_truncate:
-            truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
-
-            out = (*out, truncate)
-
-        self._step.add_(1)
-
-        return out
+        if not self.vectorized:
+            return state, reward
+
+        discrete, continuous = actions
+        assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+        state = repeat(state, '... -> b ...', b = self.num_envs)
+        reward = repeat(reward, ' -> b', b = self.num_envs)
+
+        return state, reward
```

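For orientation, a minimal sketch of how the termination / truncation side of the hunk above is exercised. It is not part of this change: the tokenizer and world model setup is copied from the README earlier in this compare, only the MockEnv keyword arguments come from the hunk, and the specific values chosen for them here are illustrative.

```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv

# tokenizer and world model exactly as in the README usage section
tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

# mock environment that may randomly terminate after step 2 and may truncate,
# using the keyword arguments shown on the longer side of the hunk above
mock_env = MockEnv(
    (256, 256),
    vectorized = True,
    num_envs = 4,
    terminate_after_step = 2,
    rand_terminate_prob = 0.1,
    can_truncate = True,
    rand_truncate_prob = 0.05
)

# roll out in the mock env, then take an actor / critic step, as in the README
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)

actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```
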
dreamer4/trainers.py

```diff
@@ -4,7 +4,7 @@ import torch
 from torch import is_tensor
 from torch.nn import Module
 from torch.optim import AdamW
-from torch.utils.data import Dataset, TensorDataset, DataLoader
+from torch.utils.data import Dataset, DataLoader

 from accelerate import Accelerator

@@ -12,9 +12,7 @@ from adam_atan2_pytorch import MuonAdamAtan2

 from dreamer4.dreamer4 import (
     VideoTokenizer,
-    DynamicsWorldModel,
-    Experience,
-    combine_experiences
+    DynamicsWorldModel
 )

 from ema_pytorch import EMA
@@ -287,7 +285,7 @@ class DreamTrainer(Module):
         for _ in range(self.num_train_steps):

             dreams = self.unwrapped_model.generate(
-                self.generate_timesteps + 1, # plus one for bootstrap value
+                self.generate_timesteps,
                 batch_size = self.batch_size,
                 return_rewards_per_frame = True,
                 return_agent_actions = True,
@@ -319,221 +317,3 @@ class DreamTrainer(Module):
             self.value_head_optim.zero_grad()

         self.print('training complete')
-
-# training from sim
-
-class SimTrainer(Module):
-    def __init__(
-        self,
-        model: DynamicsWorldModel,
-        optim_klass = AdamW,
-        batch_size = 16,
-        generate_timesteps = 16,
-        learning_rate = 3e-4,
-        max_grad_norm = None,
-        epochs = 2,
-        weight_decay = 0.,
-        accelerate_kwargs: dict = dict(),
-        optim_kwargs: dict = dict(),
-        cpu = False,
-    ):
-        super().__init__()
-
-        self.accelerator = Accelerator(
-            cpu = cpu,
-            **accelerate_kwargs
-        )
-
-        self.model = model
-
-        optim_kwargs = dict(
-            lr = learning_rate,
-            weight_decay = weight_decay
-        )
-
-        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
-        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
-
-        self.max_grad_norm = max_grad_norm
-
-        self.epochs = epochs
-        self.batch_size = batch_size
-
-        self.generate_timesteps = generate_timesteps
-
-        self.unwrapped_model = self.model
-
-        (
-            self.model,
-            self.policy_head_optim,
-            self.value_head_optim,
-        ) = self.accelerator.prepare(
-            self.model,
-            self.policy_head_optim,
-            self.value_head_optim
-        )
-
-    @property
-    def device(self):
-        return self.accelerator.device
-
-    @property
-    def unwrapped_model(self):
-        return self.accelerator.unwrap_model(self.model)
-
-    def print(self, *args, **kwargs):
-        return self.accelerator.print(*args, **kwargs)
-
-    def learn(
-        self,
-        experience: Experience
-    ):
-
-        step_size = experience.step_size
-        agent_index = experience.agent_index
-
-        latents = experience.latents
-        old_values = experience.values
-        rewards = experience.rewards
-
-        has_agent_embed = exists(experience.agent_embed)
-        agent_embed = experience.agent_embed
-
-        discrete_actions, continuous_actions = experience.actions
-        discrete_log_probs, continuous_log_probs = experience.log_probs
-
-        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
-
-        # handle empties
-
-        empty_tensor = torch.empty_like(rewards)
-
-        agent_embed = default(agent_embed, empty_tensor)
-
-        has_discrete = exists(discrete_actions)
-        has_continuous = exists(continuous_actions)
-
-        discrete_actions = default(discrete_actions, empty_tensor)
-        continuous_actions = default(continuous_actions, empty_tensor)
-
-        discrete_log_probs = default(discrete_log_probs, empty_tensor)
-        continuous_log_probs = default(continuous_log_probs, empty_tensor)
-
-        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
-        continuous_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
-
-        # create the dataset and dataloader
-
-        dataset = TensorDataset(
-            latents,
-            discrete_actions,
-            continuous_actions,
-            discrete_log_probs,
-            continuous_log_probs,
-            agent_embed,
-            discrete_old_action_unembeds,
-            continuous_old_action_unembeds,
-            old_values,
-            rewards
-        )
-
-        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
-
-        for epoch in range(self.epochs):
-
-            for (
-                latents,
-                discrete_actions,
-                continuous_actions,
-                discrete_log_probs,
-                continuous_log_probs,
-                agent_embed,
-                discrete_old_action_unembeds,
-                continuous_old_action_unembeds,
-                old_values,
-                rewards
-            ) in dataloader:
-
-                actions = (
-                    discrete_actions if has_discrete else None,
-                    continuous_actions if has_continuous else None
-                )
-
-                log_probs = (
-                    discrete_log_probs if has_discrete else None,
-                    continuous_log_probs if has_continuous else None
-                )
-
-                old_action_unembeds = (
-                    discrete_old_action_unembeds if has_discrete else None,
-                    continuous_old_action_unembeds if has_continuous else None
-                )
-
-                batch_experience = Experience(
-                    latents = latents,
-                    actions = actions,
-                    log_probs = log_probs,
-                    agent_embed = agent_embed if has_agent_embed else None,
-                    old_action_unembeds = old_action_unembeds,
-                    values = old_values,
-                    rewards = rewards,
-                    step_size = step_size,
-                    agent_index = agent_index
-                )
-
-                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
-
-                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
-
-                # update policy head
-
-                self.accelerator.backward(policy_head_loss)
-
-                if exists(self.max_grad_norm):
-                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
-
-                self.policy_head_optim.step()
-                self.policy_head_optim.zero_grad()
-
-                # update value head
-
-                self.accelerator.backward(value_head_loss)
-
-                if exists(self.max_grad_norm):
-                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
-
-                self.value_head_optim.step()
-                self.value_head_optim.zero_grad()
-
-        self.print('training complete')
-
-    def forward(
-        self,
-        env,
-        num_episodes = 50000,
-        max_experiences_before_learn = 8,
-        env_is_vectorized = False
-    ):
-
-        for _ in range(num_episodes):
-
-            total_experience = 0
-            experiences = []
-
-            while total_experience < max_experiences_before_learn:
-
-                experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
-
-                num_experience = experience.video.shape[0]
-
-                total_experience += num_experience
-
-                experiences.append(experience.cpu())
-
-            combined_experiences = combine_experiences(experiences)
-
-            self.learn(combined_experiences)
-
-            experiences.clear()
-
-        self.print('training complete')
```

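For orientation, a minimal sketch of driving the SimTrainer shown on the longer side of the hunk above, following how the test file later in this compare uses it. It is not part of this change: the model and mock env construction is copied from the README, and the trainer arguments mirror the test.

```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv
from dreamer4.trainers import SimTrainer

# tokenizer and world model as in the README usage section
tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

# the trainer owns the policy / value head optimizers and the env interaction loop
trainer = SimTrainer(
    world_model,
    batch_size = 4,
    cpu = True
)

# collect experience from the env and run policy / value head updates over it
trainer(mock_env, num_episodes = 2, env_is_vectorized = True)
```
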
pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.24"
+version = "0.0.64"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -36,8 +36,7 @@ dependencies = [
     "hyper-connections>=0.2.1",
     "torch>=2.4",
     "torchvision",
-    "x-mlps-pytorch>=0.0.29",
-    "vit-pytorch>=1.15.3"
+    "x-mlps-pytorch>=0.0.29"
 ]

 [project.urls]
```

tests (pytest suite)

```diff
@@ -2,9 +2,6 @@ import pytest
 param = pytest.mark.parametrize
 import torch

-def exists(v):
-    return v is not None
-
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
@@ -15,9 +12,7 @@ def exists(v):
 @param('condition_on_actions', (False, True))
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
-@param('add_state_pred_head', (False, True))
-@param('use_time_cache', (False, True))
-@param('var_len', (False, True))
+@param('use_time_kv_cache', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -29,9 +24,7 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    add_state_pred_head,
-    use_time_cache,
-    var_len
+    use_time_kv_cache
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

@@ -43,9 +36,7 @@ def test_e2e(
         patch_size = 32,
         attn_dim_head = 16,
         num_latent_tokens = 4,
-        num_residual_streams = num_residual_streams,
-        encoder_add_decor_aux_loss = True,
-        decorr_sample_frac = 1.
+        num_residual_streams = num_residual_streams
     )

     video = torch.randn(2, 3, 4, 256, 256)
@@ -73,13 +64,12 @@ def test_e2e(
         pred_orig_latent = pred_orig_latent,
         num_discrete_actions = 4,
         attn_dim_head = 16,
-        attn_heads = heads,
         attn_kwargs = dict(
+            heads = heads,
             query_heads = query_heads,
         ),
         prob_no_shortcut_train = prob_no_shortcut_train,
         add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
-        add_state_pred_head = add_state_pred_head,
         num_residual_streams = num_residual_streams
     )

@@ -102,13 +92,8 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))

-    lens = None
-    if var_len:
-        lens = torch.randint(1, 4, (2,))
-
     flow_loss = dynamics(
         **dynamics_input,
-        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
@@ -126,7 +111,7 @@ def test_e2e(
         image_width = 128,
         batch_size = 2,
         return_rewards_per_frame = True,
-        use_time_cache = use_time_cache
+        use_time_kv_cache = use_time_kv_cache
     )

     assert generations.video.shape == (2, 3, 10, 128, 128)
@@ -351,15 +336,6 @@ def test_action_embedder():
     assert discrete_logits.shape == (2, 3, 8)
     assert continuous_mean_log_var.shape == (2, 3, 2, 2)

-    # test kl div
-
-    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
-
-    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
-
-    assert discrete_kl_div.shape == (2, 3)
-    assert continuous_kl_div.shape == (2, 3)
-
     # return discrete split by number of actions

     discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
@@ -421,14 +397,14 @@ def test_mtp():

     reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead

-    assert reward_targets.shape == (3, 16, 3)
-    assert mask.shape == (3, 16, 3)
+    assert reward_targets.shape == (3, 15, 3)
+    assert mask.shape == (3, 15, 3)

     actions = torch.randint(0, 10, (3, 16, 2))
     action_targets, mask = create_multi_token_prediction_targets(actions, 3)

-    assert action_targets.shape == (3, 16, 3, 2)
-    assert mask.shape == (3, 16, 3)
+    assert action_targets.shape == (3, 15, 3, 2)
+    assert mask.shape == (3, 15, 3)

     from dreamer4.dreamer4 import ActionEmbedder

@@ -620,22 +596,12 @@ def test_cache_generate():
         num_residual_streams = 1
     )

-    generated, time_cache = dynamics.generate(1, return_time_cache = True)
-    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
-    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
+    generated, time_kv_cache = dynamics.generate(1, return_time_kv_cache = True)
+    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)
+    generated, time_kv_cache = dynamics.generate(1, time_kv_cache = time_kv_cache, return_time_kv_cache = True)

 @param('vectorized', (False, True))
-@param('use_pmpo', (False, True))
-@param('env_can_terminate', (False, True))
-@param('env_can_truncate', (False, True))
-@param('store_agent_embed', (False, True))
-def test_online_rl(
-    vectorized,
-    use_pmpo,
-    env_can_terminate,
-    env_can_truncate,
-    store_agent_embed
-):
+def test_online_rl(vectorized):
     from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

     tokenizer = VideoTokenizer(
@@ -646,9 +612,7 @@ def test_online_rl(
         dim_latent = 16,
         patch_size = 32,
         attn_dim_head = 16,
-        num_latent_tokens = 1,
-        image_height = 256,
-        image_width = 256,
+        num_latent_tokens = 1
     )

     world_model_and_policy = DynamicsWorldModel(
@@ -669,163 +633,11 @@ def test_online_rl(
     )

     from dreamer4.mocks import MockEnv
-    from dreamer4.dreamer4 import combine_experiences

-    mock_env = MockEnv(
-        (256, 256),
-        vectorized = vectorized,
-        num_envs = 4,
-        terminate_after_step = 2 if env_can_terminate else None,
-        can_truncate = env_can_truncate,
-        rand_terminate_prob = 0.1
-    )
-
-    # manually
-
-    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
-
-    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
-    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
-
-    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
-
-    # quick test moving the experience to different devices
-
-    if torch.cuda.is_available():
-        combined_experience = combined_experience.to(torch.device('cuda'))
-        combined_experience = combined_experience.to(world_model_and_policy.device)
-
-    if store_agent_embed:
-        assert exists(combined_experience.agent_embed)
-
-    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
+    mock_env = MockEnv((256, 256), vectorized = vectorized, num_envs = 4)
+
+    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized)
+
+    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(one_experience)

     actor_loss.backward()
     critic_loss.backward()
-
-    # with trainer
-
-    from dreamer4.trainers import SimTrainer
-
-    trainer = SimTrainer(
-        world_model_and_policy,
-        batch_size = 4,
-        cpu = True
-    )
-
-    trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
-
-@param('num_video_views', (1, 2))
-def test_proprioception(
-    num_video_views
-):
-    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
-
-    tokenizer = VideoTokenizer(
-        512,
-        dim_latent = 32,
-        patch_size = 32,
-        encoder_depth = 2,
-        decoder_depth = 2,
-        time_block_every = 2,
-        attn_heads = 8,
-        image_height = 256,
-        image_width = 256,
-        attn_kwargs = dict(
-            query_heads = 16
-        )
-    )
-
-    dynamics = DynamicsWorldModel(
-        512,
-        num_agents = 1,
-        video_tokenizer = tokenizer,
-        dim_latent = 32,
-        dim_proprio = 21,
-        num_tasks = 4,
-        num_video_views = num_video_views,
-        num_discrete_actions = 4,
-        num_residual_streams = 1
-    )
-
-    if num_video_views > 1:
-        video_shape = (2, num_video_views, 3, 10, 256, 256)
-    else:
-        video_shape = (2, 3, 10, 256, 256)
-
-    video = torch.randn(*video_shape)
-    rewards = torch.randn(2, 10)
-    proprio = torch.randn(2, 10, 21)
-    discrete_actions = torch.randint(0, 4, (2, 10, 1))
-    tasks = torch.randint(0, 4, (2,))
-
-    loss = dynamics(
-        video = video,
-        rewards = rewards,
-        tasks = tasks,
-        proprio = proprio,
-        discrete_actions = discrete_actions
-    )
-
-    loss.backward()
-
-    generations = dynamics.generate(
-        10,
-        batch_size = 2,
-        return_decoded_video = True
-    )
-
-    assert exists(generations.proprio)
-    assert generations.video.shape == video_shape
-
-def test_epo():
-    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
-
-    tokenizer = VideoTokenizer(
-        512,
-        dim_latent = 32,
-        patch_size = 32,
-        encoder_depth = 2,
-        decoder_depth = 2,
-        time_block_every = 2,
-        attn_heads = 8,
-        image_height = 256,
-        image_width = 256,
-        attn_kwargs = dict(
-            query_heads = 16
-        )
-    )
-
-    dynamics = DynamicsWorldModel(
-        512,
-        num_agents = 1,
-        video_tokenizer = tokenizer,
-        dim_latent = 32,
-        dim_proprio = 21,
-        num_tasks = 4,
-        num_latent_genes = 16,
-        num_discrete_actions = 4,
-        num_residual_streams = 1
-    )
-
-    fitness = torch.randn(16,)
-    dynamics.evolve_(fitness)
-
-def test_images_to_video_tokenizer():
-    import torch
-    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
-
-    tokenizer = VideoTokenizer(
-        dim = 512,
-        dim_latent = 32,
-        patch_size = 32,
-        image_height = 256,
-        image_width = 256,
-        encoder_add_decor_aux_loss = True
-    )
-
-    images = torch.randn(2, 3, 256, 256)
-    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
-    loss.backward()
-
-    assert images.shape == recon_images.shape
```