Compare commits: pytest-sha ... main (136 commits)

Commit SHA1s (author and date columns not captured):
5bb027b386, 9efe269688, fb8c3793b4, fb6d69f43a, 125693ce1c, 2e7f406d49, 690ecf07dc, ac1c12f743,
3c84b404a8, d5b70e2b86, c3532fa797, 73029635fe, e1c41f4371, f55c61c6cf, 051d4d6ee2, ef3a5552e7,
0c4224da18, 256a81f658, cfd34f1eba, 4ffbe37873, 24ef72d528, a4afcb22a6, b0f6b8583d, 38cba80068,
c0a6cd56a1, d756d1bb8c, 60681fce1d, 6870294d95, 3beae186da, 0904e224ab, 767789d0ca, 35b87c4fa1,
c4a3cb09d5, cb54121ace, 586379f2c8, a358a44a53, 3547344312, 691d9ca007, 91d697f8ca, 7acaa764f6,
c0450359f3, 46f86cd247, 903c43b770, d476fa7b14, 789f091c63, 41ab83f691, 995b1f64e5, fd1e87983b,
fe79bfa951, f808b1c1d2, 349a03acd7, 59c458aea3, fbfd59e42f, 46432aee9b, f97d9adc97, 32cf142b4d,
1ed6a15cb0, 4d8f5613cc, 3d5617d769, 77a40e8701, 4ce82f34df, a9b728c611, 35c1db4c7d, 27ac05efb0,
d0ffc6bfed, fb3e026fe0, 7ecc5d03e8, d82debb7a6, 03b16a48f2, 6f1a7a24ed, e316499047, 40da985c6b,
2fc3b17149, 283d59d75a, 4a5465eeb6, b34128d3d0, 7ba3988fb9, ea13d4fcab, 15876d34cf, b4763caff9,
7195bbb196, ca244a290c, a7e0c395c3, 1345326656, 55574c054e, 27ed6d0ba5, 4930002e99, ecbe13efe8,
f651d779e3, 374667d8a9, 79a1b1c46e, b6aa19f31e, bc629d78b1, 0ee475d2df, 8c88a33d3b, 911a1a8434,
5fc0022bbf, 83cfd2cd1b, 22e13c45fc, c967404471, 0c1b067f97, cb416c0d44, 61773c8219, 0dba734280,
a0161760a0, 2d20d0a6c1, d74f09f0b3, 2ccb290e26, 517ef6b94b, ec18bc0fa4, 2a902eaaf7, d28251e9f9,
ff81dd761b, 6dbdc3d7d8, 9c78962736, c5e64ff4ce, ab5de6795f, 8a73a27fc7, 01bf70e18a, b2725d9b6e,
02558d1f08, 563b269f8a, 5df3e69583, 9230267d34, c68942b026, 32aa355e37, 9101a49cdd, 31f4363be7,
e2d86a4543, b62c08be65, 4c2ed100a3, ed0918c974, 892654d442, c4e0f46528, a50e360502, 9c56ba0c9d
.github/workflows/test.yml (6 changed lines, vendored)

@@ -5,11 +5,11 @@ jobs:
  build:

    runs-on: ubuntu-latest
    timeout-minutes: 20
    timeout-minutes: 60
    strategy:
      fail-fast: false
      matrix:
        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

    steps:
      - uses: actions/checkout@v4
@@ -24,4 +24,4 @@ jobs:
          python -m uv pip install -e .[test]
      - name: Test with pytest
        run: |
          python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
          python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/
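`--num-shards` and `--shard-id` are not built-in pytest options, so the branch presumably registers them in a conftest that this compare does not show. Below is a minimal sketch of one way such options could be implemented; the hook bodies and the hash-based shard assignment here are assumptions for illustration, not the repository's actual code.

```python
# conftest.py (hypothetical sketch, not the repository's implementation)
# registers --num-shards / --shard-id and deselects tests outside the current shard
import zlib

def pytest_addoption(parser):
    parser.addoption('--num-shards', type = int, default = 1)
    parser.addoption('--shard-id', type = int, default = 0)

def pytest_collection_modifyitems(config, items):
    num_shards = config.getoption('--num-shards')
    shard_id = config.getoption('--shard-id')

    if num_shards <= 1:
        return

    # stable assignment of each collected test to a shard via a hash of its node id
    selected, deselected = [], []

    for item in items:
        shard = zlib.crc32(item.nodeid.encode()) % num_shards
        (selected if shard == shard_id else deselected).append(item)

    if deselected:
        config.hook.pytest_deselected(items = deselected)
        items[:] = selected
```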
README.md (102 changed lines)

@@ -1,10 +1,99 @@
<img src="./dreamer4-fig2.png" width="400px"></img>

## Dreamer 4 (wip)
## Dreamer 4

Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

[Temporary Discord](https://discord.gg/MkACrrkrYR)
[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work

## Appreciation

- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!

## Install

```bash
$ pip install dreamer4
```

## Usage

```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel

# video tokenizer, learned through MAE + lpips

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

video = torch.randn(2, 3, 10, 256, 256)

# learn the tokenizer

loss = tokenizer(video)
loss.backward()

# dynamics world model

world_model = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

# state, action, rewards

video = torch.randn(2, 3, 10, 256, 256)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
rewards = torch.randn(2, 10)

# learn dynamics / behavior cloned model

loss = world_model(
    video = video,
    rewards = rewards,
    discrete_actions = discrete_actions
)

loss.backward()

# do the above with much data

# then generate dreams

dreams = world_model.generate(
    10,
    batch_size = 2,
    return_decoded_video = True,
    return_for_policy_optimization = True
)

# learn from the dreams

actor_loss, critic_loss = world_model.learn_from_experience(dreams)

(actor_loss + critic_loss).backward()

# learn from environment

from dreamer4.mocks import MockEnv

mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)

experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)

actor_loss, critic_loss = world_model.learn_from_experience(experience)

(actor_loss + critic_loss).backward()
```

## Citation

@@ -20,11 +109,4 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
}
```

```bibtex
@misc{xiong2025ndrope,
    author = {Jerry Xiong},
    title  = {On n-dimensional rotary positional embeddings},
    year   = {2025},
    url    = {https://jerryxio.ng/posts/nd-rope/}
}
```
*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*
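The README usage above shows learning from dreams and learning from environment interaction as separate snippets. A hedged sketch of combining the two signals in one loop follows; the alternation schedule, step counts, and the omission of optimizer steps are arbitrary illustration choices, and only the `world_model` and `mock_env` objects from the README example are reused.

```python
# sketch only: alternate learning from real interaction and from dreams,
# reusing the world_model / mock_env constructed in the README example above

for step in range(100):
    # gather a short rollout from the (mock) environment and learn from it
    experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
    actor_loss, critic_loss = world_model.learn_from_experience(experience)
    (actor_loss + critic_loss).backward()

    # then imagine a few trajectories and learn from those as well
    dreams = world_model.generate(10, batch_size = 2, return_for_policy_optimization = True)
    actor_loss, critic_loss = world_model.learn_from_experience(dreams)
    (actor_loss + critic_loss).backward()

    # optimizer steps omitted here, as in the README example
```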
dreamer4/__init__.py

@@ -1,5 +1,12 @@
from dreamer4.dreamer4 import (
    VideoTokenizer,
    DynamicsModel,
    Dreamer
    DynamicsWorldModel,
    AxialSpaceTimeTransformer
)

from dreamer4.trainers import (
    VideoTokenizerTrainer,
    BehaviorCloneTrainer,
    DreamTrainer
)
dreamer4/dreamer4.py (2962 changed lines): file diff suppressed because it is too large.
dreamer4/mocks.py (new file, 97 lines)

@@ -0,0 +1,97 @@
from __future__ import annotations
from random import choice

import torch
from torch import tensor, empty, randn, randint
from torch.nn import Module

from einops import repeat

# helpers

def exists(v):
    return v is not None

# mock env

class MockEnv(Module):
    def __init__(
        self,
        image_shape,
        reward_range = (-100, 100),
        num_envs = 1,
        vectorized = False,
        terminate_after_step = None,
        rand_terminate_prob = 0.05,
        can_truncate = False,
        rand_truncate_prob = 0.05,
    ):
        super().__init__()
        self.image_shape = image_shape
        self.reward_range = reward_range

        self.num_envs = num_envs
        self.vectorized = vectorized
        assert not (vectorized and num_envs == 1)

        # mocking termination and truncation

        self.can_terminate = exists(terminate_after_step)
        self.terminate_after_step = terminate_after_step
        self.rand_terminate_prob = rand_terminate_prob

        self.can_truncate = can_truncate
        self.rand_truncate_prob = rand_truncate_prob

        self.register_buffer('_step', tensor(0))

    def get_random_state(self):
        return randn(3, *self.image_shape)

    def reset(
        self,
        seed = None
    ):
        self._step.zero_()
        state = self.get_random_state()

        if self.vectorized:
            state = repeat(state, '... -> b ...', b = self.num_envs)

        return state

    def step(
        self,
        actions,
    ):
        state = self.get_random_state()

        reward = empty(()).uniform_(*self.reward_range)

        if self.vectorized:
            discrete, continuous = actions
            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'

            state = repeat(state, '... -> b ...', b = self.num_envs)
            reward = repeat(reward, ' -> b', b = self.num_envs)

        out = (state, reward)

        if self.can_terminate:
            shape = (self.num_envs,) if self.vectorized else (1,)
            valid_step = self._step > self.terminate_after_step

            terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step

            out = (*out, terminate)

            # maybe truncation

            if self.can_truncate:
                truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
                out = (*out, truncate)

        self._step.add_(1)

        return out
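For orientation, a small sketch of driving MockEnv directly, based only on the class shown above. The `(discrete, continuous)` action tuple layout and the `(state, reward, ...)` return follow `step()` as written; the action tensor shapes are illustrative assumptions.

```python
# minimal sketch of stepping MockEnv by hand
import torch
from dreamer4.mocks import MockEnv

env = MockEnv((64, 64), vectorized = True, num_envs = 4, terminate_after_step = 2, can_truncate = True)

state = env.reset()                          # (num_envs, 3, 64, 64) when vectorized

for _ in range(8):
    discrete = torch.randint(0, 4, (4, 1))   # one discrete action per env (shape assumed)
    actions = (discrete, None)               # (discrete, continuous) tuple, as unpacked in step()

    out = env.step(actions)
    state, reward = out[0], out[1]           # extra elements are terminate / truncate flags when enabled
```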
dreamer4/trainers.py

@@ -1,17 +1,539 @@
from __future__ import annotations

import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader

from accelerate import Accelerator

from adam_atan2_pytorch import MuonAdamAtan2

from dreamer4.dreamer4 import (
    VideoTokenizer,
    DynamicsModel
    DynamicsWorldModel,
    Experience,
    combine_experiences
)

from ema_pytorch import EMA

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# trainers

class VideoTokenizerTrainer(Module):
    def __init__(
        self,
        model: VideoTokenizer
        model: VideoTokenizer,
        dataset: Dataset,
        optim_klass = MuonAdamAtan2,
        batch_size = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()
        raise NotImplementedError
        batch_size = min(batch_size, len(dataset))

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model
        self.dataset = dataset
        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        optim_kwargs = dict(
            lr = learning_rate,
            weight_decay = weight_decay
        )

        if optim_klass is MuonAdamAtan2:
            optim = MuonAdamAtan2(
                model.muon_parameters(),
                model.parameters(),
                **optim_kwargs
            )
        else:
            optim = optim_klass(
                model.parameters(),
                **optim_kwargs
            )

        self.optim = optim

        self.max_grad_norm = max_grad_norm

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size

        (
            self.model,
            self.train_dataloader,
            self.optim
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader,
            self.optim
        )

    @property
    def device(self):
        return self.accelerator.device

    def print(self, *args, **kwargs):
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):
        iter_train_dl = cycle(self.train_dataloader)

        for _ in range(self.num_train_steps):
            video = next(iter_train_dl)

            loss = self.model(video)
            self.accelerator.backward(loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            self.optim.step()
            self.optim.zero_grad()

        self.print('training complete')

# dynamics world model

class BehaviorCloneTrainer(Module):
    def __init__(
        self,
        model: DynamicsWorldModel,
        dataset: Dataset,
        optim_klass = MuonAdamAtan2,
        batch_size = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()
        batch_size = min(batch_size, len(dataset))

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model
        self.dataset = dataset
        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        optim_kwargs = dict(
            lr = learning_rate,
            weight_decay = weight_decay
        )

        if optim_klass is MuonAdamAtan2:
            optim = MuonAdamAtan2(
                model.muon_parameters(),
                model.parameters(),
                **optim_kwargs
            )
        else:
            optim = optim_klass(
                model.parameters(),
                **optim_kwargs
            )

        self.optim = optim

        self.max_grad_norm = max_grad_norm

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size

        (
            self.model,
            self.train_dataloader,
            self.optim
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader,
            self.optim
        )

    @property
    def device(self):
        return self.accelerator.device

    def print(self, *args, **kwargs):
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):
        iter_train_dl = cycle(self.train_dataloader)

        for _ in range(self.num_train_steps):
            batch_data = next(iter_train_dl)

            # just assume raw video dynamics training if batch_data is a tensor
            # else kwargs for video, actions, rewards

            if is_tensor(batch_data):
                loss = self.model(batch_data)
            else:
                loss = self.model(**batch_data)

            self.accelerator.backward(loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            self.optim.step()
            self.optim.zero_grad()

        self.print('training complete')
# training from dreams

class DreamTrainer(Module):
    def __init__(
        self,
        model: DynamicsWorldModel,
        optim_klass = AdamW,
        batch_size = 16,
        generate_timesteps = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        num_train_steps = 10_000,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model

        optim_kwargs = dict(
            lr = learning_rate,
            weight_decay = weight_decay
        )

        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)

        self.max_grad_norm = max_grad_norm

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.generate_timesteps = generate_timesteps

        self.unwrapped_model = self.model

        (
            self.model,
            self.policy_head_optim,
            self.value_head_optim,
        ) = self.accelerator.prepare(
            self.model,
            self.policy_head_optim,
            self.value_head_optim
        )

    @property
    def device(self):
        return self.accelerator.device

    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    def print(self, *args, **kwargs):
        return self.accelerator.print(*args, **kwargs)

    def forward(
        self
    ):

        for _ in range(self.num_train_steps):

            dreams = self.unwrapped_model.generate(
                self.generate_timesteps + 1, # plus one for bootstrap value
                batch_size = self.batch_size,
                return_rewards_per_frame = True,
                return_agent_actions = True,
                return_log_probs_and_values = True
            )

            policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)

            self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')

            # update policy head

            self.accelerator.backward(policy_head_loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)

            self.policy_head_optim.step()
            self.policy_head_optim.zero_grad()

            # update value head

            self.accelerator.backward(value_head_loss)

            if exists(self.max_grad_norm):
                self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)

            self.value_head_optim.step()
            self.value_head_optim.zero_grad()

        self.print('training complete')
# training from sim

class SimTrainer(Module):
    def __init__(
        self,
        model: DynamicsWorldModel,
        optim_klass = AdamW,
        batch_size = 16,
        generate_timesteps = 16,
        learning_rate = 3e-4,
        max_grad_norm = None,
        epochs = 2,
        weight_decay = 0.,
        accelerate_kwargs: dict = dict(),
        optim_kwargs: dict = dict(),
        cpu = False,
    ):
        super().__init__()

        self.accelerator = Accelerator(
            cpu = cpu,
            **accelerate_kwargs
        )

        self.model = model

        optim_kwargs = dict(
            lr = learning_rate,
            weight_decay = weight_decay
        )

        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)

        self.max_grad_norm = max_grad_norm

        self.epochs = epochs
        self.batch_size = batch_size

        self.generate_timesteps = generate_timesteps

        self.unwrapped_model = self.model

        (
            self.model,
            self.policy_head_optim,
            self.value_head_optim,
        ) = self.accelerator.prepare(
            self.model,
            self.policy_head_optim,
            self.value_head_optim
        )

    @property
    def device(self):
        return self.accelerator.device

    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    def print(self, *args, **kwargs):
        return self.accelerator.print(*args, **kwargs)

    def learn(
        self,
        experience: Experience
    ):

        step_size = experience.step_size
        agent_index = experience.agent_index

        latents = experience.latents
        old_values = experience.values
        rewards = experience.rewards

        has_agent_embed = exists(experience.agent_embed)
        agent_embed = experience.agent_embed

        discrete_actions, continuous_actions = experience.actions
        discrete_log_probs, continuous_log_probs = experience.log_probs

        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))

        # handle empties

        empty_tensor = torch.empty_like(rewards)

        agent_embed = default(agent_embed, empty_tensor)

        has_discrete = exists(discrete_actions)
        has_continuous = exists(continuous_actions)

        discrete_actions = default(discrete_actions, empty_tensor)
        continuous_actions = default(continuous_actions, empty_tensor)

        discrete_log_probs = default(discrete_log_probs, empty_tensor)
        continuous_log_probs = default(continuous_log_probs, empty_tensor)

        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
        continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)

        # create the dataset and dataloader

        dataset = TensorDataset(
            latents,
            discrete_actions,
            continuous_actions,
            discrete_log_probs,
            continuous_log_probs,
            agent_embed,
            discrete_old_action_unembeds,
            continuous_old_action_unembeds,
            old_values,
            rewards
        )

        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)

        for epoch in range(self.epochs):

            for (
                latents,
                discrete_actions,
                continuous_actions,
                discrete_log_probs,
                continuous_log_probs,
                agent_embed,
                discrete_old_action_unembeds,
                continuous_old_action_unembeds,
                old_values,
                rewards
            ) in dataloader:

                actions = (
                    discrete_actions if has_discrete else None,
                    continuous_actions if has_continuous else None
                )

                log_probs = (
                    discrete_log_probs if has_discrete else None,
                    continuous_log_probs if has_continuous else None
                )

                old_action_unembeds = (
                    discrete_old_action_unembeds if has_discrete else None,
                    continuous_old_action_unembeds if has_continuous else None
                )

                batch_experience = Experience(
                    latents = latents,
                    actions = actions,
                    log_probs = log_probs,
                    agent_embed = agent_embed if has_agent_embed else None,
                    old_action_unembeds = old_action_unembeds,
                    values = old_values,
                    rewards = rewards,
                    step_size = step_size,
                    agent_index = agent_index
                )

                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)

                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')

                # update policy head

                self.accelerator.backward(policy_head_loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)

                self.policy_head_optim.step()
                self.policy_head_optim.zero_grad()

                # update value head

                self.accelerator.backward(value_head_loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)

                self.value_head_optim.step()
                self.value_head_optim.zero_grad()

        self.print('training complete')

    def forward(
        self,
        env,
        num_episodes = 50000,
        max_experiences_before_learn = 8,
        env_is_vectorized = False
    ):

        for _ in range(num_episodes):

            total_experience = 0
            experiences = []

            while total_experience < max_experiences_before_learn:

                experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)

                num_experience = experience.video.shape[0]

                total_experience += num_experience

                experiences.append(experience.cpu())

            combined_experiences = combine_experiences(experiences)

            self.learn(combined_experiences)

            experiences.clear()

        self.print('training complete')
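For orientation, a condensed sketch of wiring the VideoTokenizerTrainer above to a toy dataset, mirroring the constructor signature shown in this diff (and the trainer test later in the compare). The `RandomVideoDataset` name and the tiny hyperparameters are placeholders for illustration, not recommended settings.

```python
# sketch only: VideoTokenizer + VideoTokenizerTrainer on random video clips
import torch
from torch.utils.data import Dataset

from dreamer4.dreamer4 import VideoTokenizer
from dreamer4.trainers import VideoTokenizerTrainer

class RandomVideoDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.randn(3, 2, 64, 64)  # (channels, frames, height, width)

tokenizer = VideoTokenizer(
    16,
    encoder_depth = 1,
    decoder_depth = 1,
    time_block_every = 1,
    dim_latent = 16,
    patch_size = 32,
    attn_dim_head = 16,
    num_latent_tokens = 4
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset = RandomVideoDataset(),
    batch_size = 2,
    num_train_steps = 10,
    max_grad_norm = 0.5,
    cpu = True
)

trainer()  # runs the training loop defined in the trainer's forward()
```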
pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.6"
version = "0.1.24"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,13 +27,17 @@ classifiers=[

dependencies = [
    "accelerate",
    "adam-atan2-pytorch>=0.2.2",
    "assoc-scan",
    "einx>=0.3.0",
    "einops>=0.8.1",
    "ema-pytorch",
    "hl-gauss-pytorch",
    "hyper-connections>=0.2.1",
    "torch>=2.4",
    "torchvision",
    "x-mlps-pytorch>=0.0.29"
    "x-mlps-pytorch>=0.0.29",
    "vit-pytorch>=1.15.3"
]

[project.urls]
tests/ (pytest test file)

@@ -2,6 +2,9 @@ import pytest
param = pytest.mark.parametrize
import torch

def exists(v):
    return v is not None

@param('pred_orig_latent', (False, True))
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
@@ -9,7 +12,12 @@ import torch
@param('add_task_embeds', (False, True))
@param('num_spatial_tokens', (2, 8))
@param('signal_and_step_passed_in', (False, True))
@param('condition_on_actions', (False, True))
@param('num_residual_streams', (1, 4))
@param('add_reward_embed_to_agent_token', (False, True))
@param('add_state_pred_head', (False, True))
@param('use_time_cache', (False, True))
@param('var_len', (False, True))
def test_e2e(
    pred_orig_latent,
    grouped_query_attn,
@@ -18,18 +26,26 @@ def test_e2e(
    add_task_embeds,
    num_spatial_tokens,
    signal_and_step_passed_in,
    add_reward_embed_to_agent_token
    condition_on_actions,
    num_residual_streams,
    add_reward_embed_to_agent_token,
    add_state_pred_head,
    use_time_cache,
    var_len
):
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        encoder_depth = 4,
        decoder_depth = 4,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 4
        num_latent_tokens = 4,
        num_residual_streams = num_residual_streams,
        encoder_add_decor_aux_loss = True,
        decorr_sample_frac = 1.
    )

    video = torch.randn(2, 3, 4, 256, 256)
@@ -45,7 +61,7 @@ def test_e2e(

    query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)

    dynamics = DynamicsModel(
    dynamics = DynamicsWorldModel(
        dim = 16,
        video_tokenizer = tokenizer,
        dim_latent = 16,
@@ -55,13 +71,16 @@ def test_e2e(
        depth = 4,
        num_spatial_tokens = num_spatial_tokens,
        pred_orig_latent = pred_orig_latent,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        attn_heads = heads,
        attn_kwargs = dict(
            heads = heads,
            query_heads = query_heads,
        ),
        prob_no_shortcut_train = prob_no_shortcut_train,
        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
        add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
        add_state_pred_head = add_state_pred_head,
        num_residual_streams = num_residual_streams
    )

    signal_levels = step_sizes_log2 = None
@@ -79,27 +98,39 @@ def test_e2e(
    if add_task_embeds:
        tasks = torch.randint(0, 4, (2,))

    actions = None
    if condition_on_actions:
        actions = torch.randint(0, 4, (2, 3, 1))

    lens = None
    if var_len:
        lens = torch.randint(1, 4, (2,))

    flow_loss = dynamics(
        **dynamics_input,
        lens = lens,
        tasks = tasks,
        signal_levels = signal_levels,
        step_sizes_log2 = step_sizes_log2
        step_sizes_log2 = step_sizes_log2,
        discrete_actions = actions,
        add_autoregressive_action_loss = True
    )

    assert flow_loss.numel() == 1

    # generating

    generated_video, generated_rewards = dynamics.generate(
    generations = dynamics.generate(
        time_steps = 10,
        image_height = 128,
        image_width = 128,
        batch_size = 2,
        return_rewards_per_frame = True
        return_rewards_per_frame = True,
        use_time_cache = use_time_cache
    )

    assert generated_video.shape == (2, 3, 10, 128, 128)
    assert generated_rewards.shape == (2, 10)
    assert generations.video.shape == (2, 3, 10, 128, 128)
    assert generations.rewards.shape == (2, 10)

    # rl

@@ -173,3 +204,628 @@ def test_attend_factory(
    out = attend(q, k, v)

    assert torch.allclose(flex_out, out, atol = 1e-6)
def test_action_with_world_model():
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 4,
        decoder_depth = 4,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        depth = 4,
        num_discrete_actions = 4
    )

    rewards = torch.randn(1, 4)
    discrete_actions = torch.randint(0, 4, (1, 4, 1))

    gen = dynamics.generate(
        16,
        batch_size = 4,
        return_rewards_per_frame = True,
        return_agent_actions = True,
        return_log_probs_and_values = True
    )

    assert gen.video.shape == (4, 3, 16, 256, 256)
    assert gen.rewards.shape == (4, 16)

    discrete_actions, continuous_actions = gen.actions

    assert discrete_actions.shape == (4, 16, 1)
    assert continuous_actions is None

    discrete_log_probs, _ = gen.log_probs

    assert discrete_log_probs.shape == (4, 16, 1)
    assert gen.values.shape == (4, 16)

    # take a reinforcement learning step

    actor_loss, critic_loss = dynamics.learn_from_experience(gen)

    actor_loss.backward(retain_graph = True)
    critic_loss.backward()

def test_action_embedder():
    from dreamer4.dreamer4 import ActionEmbedder

    # 1 discrete action with 4 choices

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = 4
    )

    actions = torch.randint(0, 4, (2, 3, 1))
    action_embed = embedder(discrete_actions = actions)

    assert action_embed.shape == (2, 3, 512)

    # 2 discrete actions with 4 choices each

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4)
    )

    actions = torch.randint(0, 4, (2, 3, 2))
    action_embed = embedder(discrete_actions = actions)

    assert action_embed.shape == (2, 3, 512)

    # picking out only the second discrete action

    actions = torch.randint(0, 4, (2, 3, 1))
    action_embed = embedder(discrete_actions = actions, discrete_action_types = 1)

    assert action_embed.shape == (2, 3, 512)

    # 2 continuous actions

    embedder = ActionEmbedder(
        512,
        num_continuous_actions = 2,
        continuous_norm_stats = ((0., 2.), (1., 1.)) # (mean, std) for normalizing each action
    )

    actions = torch.randn((2, 3, 2))
    action_embed = embedder(continuous_actions = actions)

    assert action_embed.shape == (2, 3, 512)

    # 2 discrete actions with 4 choices each and 2 continuous actions

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4),
        num_continuous_actions = 2
    )

    discrete_actions = torch.randint(0, 4, (2, 3, 2))
    continuous_actions = torch.randn(2, 3, 2)

    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
    assert action_embed.shape == (2, 3, 512)

    # picking out one discrete and one continuous

    discrete_actions = torch.randint(0, 4, (2, 3, 1))
    continuous_actions = torch.randn(2, 3, 1)

    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)

    assert action_embed.shape == (2, 3, 512)

    # unembed

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4),
        num_continuous_actions = 2,
        can_unembed = True
    )

    discrete_actions = torch.randint(0, 4, (2, 3, 2))
    continuous_actions = torch.randn(2, 3, 2)

    action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)

    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)

    assert discrete_logits.shape == (2, 3, 8)
    assert continuous_mean_log_var.shape == (2, 3, 2, 2)

    # test kl div

    discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)

    discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))

    assert discrete_kl_div.shape == (2, 3)
    assert continuous_kl_div.shape == (2, 3)

    # return discrete split by number of actions

    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
    assert discrete_logits[0].shape == discrete_logits[1].shape == (2, 3, 4)

    # unembed subset of actions

    discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0)

    assert discrete_logits.shape == (2, 3, 4)
    assert continuous_mean_log_var.shape == (2, 3, 1, 2)

    # sample actions

    sampled_discrete_actions, sampled_continuous_actions = embedder.sample(action_embed, discrete_action_types = 1, continuous_action_types = 0)

    assert sampled_discrete_actions.shape == (2, 3, 1)
    assert sampled_continuous_actions.shape == (2, 3, 1)

    # log probs

    assert discrete_logits.shape == (2, 3, 4)
    assert continuous_mean_log_var.shape == (2, 3, 1, 2)

    discrete_log_probs, continuous_log_probs = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,
        continuous_targets = continuous_actions,
        parallel_discrete_calc = False
    )

    assert discrete_log_probs.shape == (2, 3, 2)
    assert continuous_log_probs.shape == (2, 3, 2)

    _, (discrete_entropies, continuous_entropies) = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,
        continuous_targets = continuous_actions,
        parallel_discrete_calc = True,
        return_entropies = True
    )

    assert discrete_entropies.shape == (2, 3, 2)
    assert continuous_entropies.shape == (2, 3, 2)

    parallel_discrete_log_probs, _ = embedder.log_probs(
        action_embed,
        discrete_targets = discrete_actions,
        continuous_targets = continuous_actions,
        parallel_discrete_calc = True
    )

    assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
def test_mtp():
    from dreamer4.dreamer4 import create_multi_token_prediction_targets

    rewards = torch.randn(3, 16) # (b t)

    reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead

    assert reward_targets.shape == (3, 16, 3)
    assert mask.shape == (3, 16, 3)

    actions = torch.randint(0, 10, (3, 16, 2))
    action_targets, mask = create_multi_token_prediction_targets(actions, 3)

    assert action_targets.shape == (3, 16, 3, 2)
    assert mask.shape == (3, 16, 3)

    from dreamer4.dreamer4 import ActionEmbedder

    embedder = ActionEmbedder(
        512,
        num_discrete_actions = (4, 4),
        num_continuous_actions = 2,
        can_unembed = True,
        num_unembed_preds = 8
    )

    discrete_actions = torch.randint(0, 4, (2, 3, 2))
    continuous_actions = torch.randn(2, 3, 2)

    action_embed = torch.randn(2, 16, 512)
    discrete_logits, continuous_logits = embedder.unembed(action_embed)

    assert discrete_logits.shape == (8, 2, 16, 8)

    discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)

    assert discrete_logits.shape == (2, 16, 8)

def test_loss_normalizer():
    from torch import cat
    from dreamer4.dreamer4 import LossNormalizer

    loss_normalizer = LossNormalizer(4, beta = 0.)

    losses = torch.rand(4)

    _ = loss_normalizer(losses)
    normed_losses = loss_normalizer(losses)

    assert (normed_losses == 1.).all()

def test_tokenizer_trainer():
    from dreamer4.trainers import VideoTokenizerTrainer
    from dreamer4.dreamer4 import VideoTokenizer
    from torch.utils.data import Dataset

    class MockDataset(Dataset):
        def __len__(self):
            return 2

        def __getitem__(self, idx):
            return torch.randn(3, 2, 64, 64)

    dataset = MockDataset()

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 4
    )

    trainer = VideoTokenizerTrainer(
        tokenizer,
        dataset = dataset,
        num_train_steps = 1,
        batch_size = 1,
        cpu = True,
        max_grad_norm = 0.5
    )

    trainer()

@param('with_actions', (True, False))
@param('with_rewards', (True, False))
def test_bc_trainer(
    with_actions,
    with_rewards
):
    from dreamer4.trainers import BehaviorCloneTrainer
    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

    from torch.utils.data import Dataset

    class MockDataset(Dataset):
        def __len__(self):
            return 2

        def __getitem__(self, idx):
            state = torch.randn(3, 2, 64, 64)

            pkg = dict(video = state)

            if with_actions:
                pkg.update(discrete_actions = torch.randint(0, 4, (2, 1)))

            if with_rewards:
                pkg.update(rewards = torch.randn(2,))

            return pkg

    dataset = MockDataset()

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 1
    )

    world_model = DynamicsWorldModel(
        video_tokenizer = tokenizer,
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    trainer = BehaviorCloneTrainer(
        world_model,
        dataset = dataset,
        batch_size = 1,
        num_train_steps = 1,
        cpu = True
    )

    trainer()

def test_dream_trainer():
    from dreamer4.dreamer4 import DynamicsWorldModel

    world_model = DynamicsWorldModel(
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    # training from self-generations (dreams)

    from dreamer4.trainers import DreamTrainer

    dream_trainer = DreamTrainer(
        world_model,
        batch_size = 2,
        num_train_steps = 1,
        cpu = True,
    )

    dream_trainer()
def test_cache_generate():
    from dreamer4.dreamer4 import DynamicsWorldModel

    dynamics = DynamicsWorldModel(
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 4,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    generated, time_cache = dynamics.generate(1, return_time_cache = True)
    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
    generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)

@param('vectorized', (False, True))
@param('use_pmpo', (False, True))
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
@param('store_agent_embed', (False, True))
def test_online_rl(
    vectorized,
    use_pmpo,
    env_can_terminate,
    env_can_truncate,
    store_agent_embed
):
    from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer

    tokenizer = VideoTokenizer(
        16,
        encoder_depth = 1,
        decoder_depth = 1,
        time_block_every = 1,
        dim_latent = 16,
        patch_size = 32,
        attn_dim_head = 16,
        num_latent_tokens = 1,
        image_height = 256,
        image_width = 256,
    )

    world_model_and_policy = DynamicsWorldModel(
        video_tokenizer = tokenizer,
        dim = 16,
        dim_latent = 16,
        max_steps = 64,
        num_tasks = 4,
        num_latent_tokens = 1,
        depth = 1,
        time_block_every = 1,
        num_spatial_tokens = 1,
        pred_orig_latent = True,
        num_discrete_actions = 4,
        attn_dim_head = 16,
        prob_no_shortcut_train = 0.1,
        num_residual_streams = 1
    )

    from dreamer4.mocks import MockEnv
    from dreamer4.dreamer4 import combine_experiences

    mock_env = MockEnv(
        (256, 256),
        vectorized = vectorized,
        num_envs = 4,
        terminate_after_step = 2 if env_can_terminate else None,
        can_truncate = env_can_truncate,
        rand_terminate_prob = 0.1
    )

    # manually

    dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)

    one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
    another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)

    combined_experience = combine_experiences([dream_experience, one_experience, another_experience])

    # quick test moving the experience to different devices

    if torch.cuda.is_available():
        combined_experience = combined_experience.to(torch.device('cuda'))
        combined_experience = combined_experience.to(world_model_and_policy.device)

    if store_agent_embed:
        assert exists(combined_experience.agent_embed)

    actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)

    actor_loss.backward()
    critic_loss.backward()

    # with trainer

    from dreamer4.trainers import SimTrainer

    trainer = SimTrainer(
        world_model_and_policy,
        batch_size = 4,
        cpu = True
    )

    trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
@param('num_video_views', (1, 2))
def test_proprioception(
    num_video_views
):
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 2,
        decoder_depth = 2,
        time_block_every = 2,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        dim_proprio = 21,
        num_tasks = 4,
        num_video_views = num_video_views,
        num_discrete_actions = 4,
        num_residual_streams = 1
    )

    if num_video_views > 1:
        video_shape = (2, num_video_views, 3, 10, 256, 256)
    else:
        video_shape = (2, 3, 10, 256, 256)

    video = torch.randn(*video_shape)
    rewards = torch.randn(2, 10)
    proprio = torch.randn(2, 10, 21)
    discrete_actions = torch.randint(0, 4, (2, 10, 1))
    tasks = torch.randint(0, 4, (2,))

    loss = dynamics(
        video = video,
        rewards = rewards,
        tasks = tasks,
        proprio = proprio,
        discrete_actions = discrete_actions
    )

    loss.backward()

    generations = dynamics.generate(
        10,
        batch_size = 2,
        return_decoded_video = True
    )

    assert exists(generations.proprio)
    assert generations.video.shape == video_shape

def test_epo():
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 2,
        decoder_depth = 2,
        time_block_every = 2,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        dim_proprio = 21,
        num_tasks = 4,
        num_latent_genes = 16,
        num_discrete_actions = 4,
        num_residual_streams = 1
    )

    fitness = torch.randn(16,)
    dynamics.evolve_(fitness)

def test_images_to_video_tokenizer():
    import torch
    from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer

    tokenizer = VideoTokenizer(
        dim = 512,
        dim_latent = 32,
        patch_size = 32,
        image_height = 256,
        image_width = 256,
        encoder_add_decor_aux_loss = True
    )

    images = torch.randn(2, 3, 256, 256)
    loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
    loss.backward()

    assert images.shape == recon_images.shape