Compare commits

...

136 Commits

Author SHA1 Message Date
lucidrains
5bb027b386 allow for image pretraining on video tokenizer 2025-12-04 10:34:15 -08:00
lucidrains
9efe269688 oops 2025-12-03 08:11:47 -08:00
lucidrains
fb8c3793b4 complete the addition of a state entropy bonus 2025-12-03 07:52:30 -08:00
lucidrains
fb6d69f43a complete the latent autoregressive prediction, to use the log variance as a state entropy bonus 2025-12-03 06:40:19 -08:00
lucidrains
125693ce1c add a separate state prediction head for the state entropy 2025-12-02 15:58:25 -08:00
lucidrains
2e7f406d49 allow for the combining of experiences from environment and dream 2025-11-13 16:37:35 -08:00
lucidrains
690ecf07dc fix the rnn time caching issue 2025-11-11 17:04:02 -08:00
lucidrains
ac1c12f743 disable until rnn hiddens are handled properly 2025-11-10 15:52:43 -08:00
lucidrains
3c84b404a8 rnn layer needs to be hyper connected too 2025-11-10 15:51:33 -08:00
lucidrains
d5b70e2b86 allow for adding an RNN before time attention, but need to handle caching still 2025-11-10 11:42:20 -08:00
lucidrains
c3532fa797 add learned value residual 2025-11-10 09:33:58 -08:00
lucidrains
73029635fe last commit for the day 2025-11-09 11:12:37 -08:00
lucidrains
e1c41f4371 decorrelation loss for spatial attention as well 2025-11-09 10:41:58 -08:00
Phil Wang
f55c61c6cf cleanup 2025-11-09 10:22:47 -08:00
lucidrains
051d4d6ee2 oops 2025-11-09 10:12:51 -08:00
lucidrains
ef3a5552e7 eventually video tokenizer may need to be trained on single frames 2025-11-09 10:11:56 -08:00
lucidrains
0c4224da18 add a decorrelation loss for temporal attention in encoder of video tokenizer 2025-11-09 09:47:47 -08:00
Phil Wang
256a81f658
Merge pull request #5 from Cycl0/patch-1
Update Discord channel link in README to use permanent link
2025-11-09 08:17:41 -08:00
lucidrains
cfd34f1eba able to move the experience to cpu easily, and automatically move it to the device of the dynamics world model when learning from it 2025-11-09 16:16:13 +00:00
Lucas Kenzo Cyra
4ffbe37873
Update Discord channel link in README to use permanent link
Updated Discord channel link for collaboration.
2025-11-09 10:12:45 -03:00
lucidrains
24ef72d528 0.1.4 2025-11-04 15:21:20 -08:00
Phil Wang
a4afcb22a6
Merge pull request #4 from dirkmcpherson/bugfix
fix a few typo bugs. Support info in return signature of environment …
2025-11-04 15:19:25 -08:00
j
b0f6b8583d fix a few typo bugs. Support info in return signature of environment step. Temporarily turn off flex attention when the kv_cache is used to avoid bug. 2025-11-04 17:29:12 -05:00
lucidrains
38cba80068 readme 2025-11-04 06:05:11 -08:00
lucidrains
c0a6cd56a1 link to new discord 2025-10-31 09:06:44 -07:00
lucidrains
d756d1bb8c addressing issues raised by an independent researcher with llm assistance 2025-10-31 08:37:39 -07:00
lucidrains
60681fce1d fix generation so that one more step is taken to decode agent embeds off the final cleaned set of latents, update readme 2025-10-31 06:48:49 -07:00
Phil Wang
6870294d95 no longer needed 2025-10-30 09:23:27 -07:00
lucidrains
3beae186da some more control over whether to normalize advantages 2025-10-30 08:46:03 -07:00
lucidrains
0904e224ab make the reverse kl optional 2025-10-30 08:22:50 -07:00
lucidrains
767789d0ca they decided on 0.3 for the behavioral prior loss weight 2025-10-29 13:24:58 -07:00
lucidrains
35b87c4fa1 oops 2025-10-29 13:04:02 -07:00
lucidrains
c4a3cb09d5 swap for discrete kl div, thanks to Dirk for pointing this out on the discord 2025-10-29 11:54:18 -07:00
lucidrains
cb54121ace sim trainer needs to take care of agent embedding and old actions 2025-10-29 11:15:11 -07:00
lucidrains
586379f2c8 sum the kl div loss across number of actions by default for action embedder .kl_div 2025-10-29 10:46:42 -07:00
lucidrains
a358a44a53 always store old agent embeds and old action parameters when possible 2025-10-29 10:39:15 -07:00
lucidrains
3547344312 take care of storing the old action logits and mean log var, and calculate kl div for pmpo based off that during learn from experience 2025-10-29 10:31:32 -07:00
lucidrains
691d9ca007 add kl div on action embedder, working way towards the kl div loss in pmpo 2025-10-29 10:02:25 -07:00
lucidrains
91d697f8ca fix pmpo 2025-10-28 18:55:22 -07:00
lucidrains
7acaa764f6 evolutionary policy optimization on dreams will be interesting 2025-10-28 10:17:01 -07:00
lucidrains
c0450359f3 allow for evolutionary policy optimization 2025-10-28 10:11:13 -07:00
lucidrains
46f86cd247 fix storing of agent embedding 2025-10-28 09:36:58 -07:00
lucidrains
903c43b770 use the agent embeds off the stored experience if available 2025-10-28 09:14:02 -07:00
lucidrains
d476fa7b14 able to store the agent embeddings during rollouts with imagination or environment, for efficient policy optimization (but will also allow for finetuning world model for the heads) 2025-10-28 09:02:26 -07:00
lucidrains
789f091c63 redo so that max timesteps is treated as truncation at the last timestep, then allow for accepting the truncation signal from the environment and reuse same logic 2025-10-28 08:04:48 -07:00
lucidrains
41ab83f691 fix mock 2025-10-27 10:47:24 -07:00
lucidrains
995b1f64e5 handle environments that return a terminate flag, also make sure episode lens are logged in vectorized env 2025-10-27 10:14:28 -07:00
lucidrains
fd1e87983b quantile filter 2025-10-27 09:08:26 -07:00
lucidrains
fe79bfa951 optionally keep track of returns statistics and normalize with them before advantage 2025-10-27 09:02:08 -07:00
lucidrains
f808b1c1d2 oops 2025-10-27 08:34:22 -07:00
lucidrains
349a03acd7 redo so lens is always the episode length, including the bootstrap value timestep, and use is_truncated to mask out the bootstrap node from being learned on 2025-10-27 08:06:21 -07:00
lucidrains
59c458aea3 introduce an is_truncated field on Experience, and mask out rewards and values before calculating gae appropriately 2025-10-27 07:55:00 -07:00
lucidrains
fbfd59e42f handle variable lengthed experiences when doing policy optimization 2025-10-27 06:09:09 -07:00
lucidrains
46432aee9b fix an issue with bc 2025-10-25 12:30:08 -07:00
lucidrains
f97d9adc97 oops, forgot to add the view embedding for robotics 2025-10-25 11:39:06 -07:00
lucidrains
32cf142b4d take another step for variable len experiences 2025-10-25 11:31:41 -07:00
lucidrains
1ed6a15cb0 fix tests 2025-10-25 11:13:22 -07:00
lucidrains
4d8f5613cc start storing the experience lens 2025-10-25 10:55:47 -07:00
lucidrains
3d5617d769 take a step towards variable lengthed experiences during training 2025-10-25 10:45:34 -07:00
lucidrains
77a40e8701 validate that we can generate multiple video streams for robotics use-case 2025-10-25 09:23:07 -07:00
lucidrains
4ce82f34df given the VAT paper, add multiple video streams (third person, wrist camera, etc), geared for robotics. need to manage an extra dimension for multiple viewpoints 2025-10-25 09:20:55 -07:00
lucidrains
a9b728c611 incorporate proprioception into the dynamics world model 2025-10-24 11:24:22 -07:00
lucidrains
35c1db4c7d sketch of training from sim env 2025-10-24 09:13:09 -07:00
lucidrains
27ac05efb0 function for combining experiences 2025-10-24 08:00:10 -07:00
lucidrains
d0ffc6bfed with or without signed advantage 2025-10-23 16:24:29 -07:00
lucidrains
fb3e026fe0 handle vectorized env 2025-10-22 11:19:44 -07:00
lucidrains
7ecc5d03e8 wire up the time kv cache when interacting with sim / env 2025-10-22 08:39:11 -07:00
lucidrains
d82debb7a6 first pass through gathering experience with a mock env for online rl 2025-10-22 08:32:46 -07:00
lucidrains
03b16a48f2 sketch out the dream trainer, seems like they only fine tune the heads 2025-10-22 06:41:10 -07:00
lucidrains
6f1a7a24ed try to fix ci 2025-10-21 11:47:39 -07:00
lucidrains
e316499047 naming 2025-10-21 10:57:55 -07:00
lucidrains
40da985c6b tweak bc trainer 2025-10-21 10:55:24 -07:00
lucidrains
2fc3b17149 take a gradient step with behavioral clone trainer, make sure it works with and without actions and rewards 2025-10-21 10:20:08 -07:00
lucidrains
283d59d75a oops 2025-10-21 09:50:07 -07:00
lucidrains
4a5465eeb6 fix ci 2025-10-21 09:17:53 -07:00
lucidrains
b34128d3d0 make sure time kv cache can be passed back in during generation 2025-10-21 09:15:32 -07:00
lucidrains
7ba3988fb9 prepare a mock for interacting with online env 2025-10-21 09:03:20 -07:00
lucidrains
ea13d4fcab take a gradient step with video tokenizer trainer 2025-10-21 08:52:22 -07:00
lucidrains
15876d34cf more muon prep 2025-10-21 08:23:59 -07:00
lucidrains
b4763caff9 fix rotary embeddings in presence of kv caching 2025-10-21 07:10:21 -07:00
lucidrains
7195bbb196 oops 2025-10-20 12:42:27 -07:00
lucidrains
ca244a290c first pass through the kv cache for the time block in the dynamics model 2025-10-20 12:25:50 -07:00
lucidrains
a7e0c395c3 allow for only rmsnorm for keys in attention 2025-10-20 11:20:49 -07:00
lucidrains
1345326656 another measure for the attending to nothing issue 2025-10-20 10:32:31 -07:00
lucidrains
55574c054e assert 2025-10-19 09:59:42 -07:00
lucidrains
27ed6d0ba5 fix time kv cache 2025-10-19 09:16:06 -07:00
lucidrains
4930002e99 bit of progress on time kv cache 2025-10-19 09:04:26 -07:00
lucidrains
ecbe13efe8 allow for setting different loss weights for each MTP head (perhaps more weight on the next vs some far out prediction) 2025-10-19 08:37:56 -07:00
lucidrains
f651d779e3 able to control the update of the loss ema from dynamics model forward 2025-10-19 08:25:50 -07:00
lucidrains
374667d8a9 take care of the loss normalization mentioned at the end of the first paragraph of section 3 2025-10-19 08:24:41 -07:00
lucidrains
79a1b1c46e oops 2025-10-18 10:31:48 -07:00
lucidrains
b6aa19f31e complete multi-token prediction for actions, tackle loss balancing another day 2025-10-18 10:23:14 -07:00
lucidrains
bc629d78b1 inverse norm for continuous actions when sampling 2025-10-18 08:55:04 -07:00
lucidrains
0ee475d2df oops 2025-10-18 08:50:53 -07:00
lucidrains
8c88a33d3b complete multi token prediction for the reward head 2025-10-18 08:33:06 -07:00
lucidrains
911a1a8434 oops 2025-10-18 08:07:06 -07:00
lucidrains
5fc0022bbf the function for generating the MTP targets, as well as the mask for the losses 2025-10-18 08:04:51 -07:00
lucidrains
83cfd2cd1b task conditioning when dreaming 2025-10-18 07:47:13 -07:00
lucidrains
22e13c45fc rename 2025-10-17 14:44:25 -07:00
lucidrains
c967404471 0.0.31 2025-10-17 08:55:42 -07:00
lucidrains
0c1b067f97 if optimizer is passed into the learn from dreams function, take the optimizer steps, otherwise let the researcher handle it externally. also ready muon 2025-10-17 08:55:20 -07:00
lucidrains
cb416c0d44 handle the entropies during policy optimization 2025-10-17 08:47:26 -07:00
lucidrains
61773c8219 eventually we will need to learn from the outside stream of experience 2025-10-17 08:06:24 -07:00
lucidrains
0dba734280 start the learning in dreams portion 2025-10-17 08:00:47 -07:00
lucidrains
a0161760a0 extract the log probs and predicted values (symexp two hot encoded) for the phase 3 RL training 2025-10-16 10:40:59 -07:00
lucidrains
2d20d0a6c1 able to roll out actions from one agent within the dreams of a world model 2025-10-16 10:15:43 -07:00
lucidrains
d74f09f0b3 a researcher in discord pointed out that the tokenizer also uses the axial space time transformer. redo without the 3d rotary and block causal, greatly simplifying the implementation 2025-10-16 09:40:14 -07:00
lucidrains
2ccb290e26 pass the attend kwargs for the block causal masking in tokenizer 2025-10-16 08:33:26 -07:00
lucidrains
517ef6b94b oops 2025-10-16 07:03:51 -07:00
lucidrains
ec18bc0fa4 cleanup 2025-10-16 06:44:28 -07:00
lucidrains
2a902eaaf7 allow reward tokens to be attended to as state optionally, DT-esque. figure out multi-agent scenario once i get around to it 2025-10-16 06:41:02 -07:00
lucidrains
d28251e9f9 another consideration before knocking out the RL logic 2025-10-14 11:10:26 -07:00
lucidrains
ff81dd761b separate action and agent embeds 2025-10-13 11:36:21 -07:00
lucidrains
6dbdc3d7d8 correct a misunderstanding where past actions is a separate action token, while agent token is used for the prediction of next action, rewards, values 2025-10-12 16:16:18 -07:00
lucidrains
9c78962736 sampling actions 2025-10-12 11:27:12 -07:00
lucidrains
c5e64ff4ce separate out the key from the value projections in attention for muon 2025-10-12 09:42:22 -07:00
lucidrains
ab5de6795f bring in muon 2025-10-12 09:35:06 -07:00
lucidrains
8a73a27fc7 add nested tensor way for getting log prob of multiple discrete actions 2025-10-11 10:53:24 -07:00
lucidrains
01bf70e18a 0.0.14 2025-10-11 09:24:58 -07:00
lucidrains
b2725d9b6e complete behavior cloning for one agent 2025-10-11 09:24:49 -07:00
lucidrains
02558d1f08 will organize the unembedding parameters under the actor optimizer 2025-10-11 06:55:57 -07:00
lucidrains
563b269f8a bring in hyper connections 2025-10-11 06:52:57 -07:00
lucidrains
5df3e69583 last commit for the day 2025-10-10 11:59:18 -07:00
lucidrains
9230267d34 handle subset of discrete action unembedding 2025-10-10 11:27:05 -07:00
lucidrains
c68942b026 cleanup 2025-10-10 10:42:54 -07:00
lucidrains
32aa355e37 prepare unembedding parameters in ActionEmbedder as well as the policy head, to allow for behavioral cloning before RL 2025-10-10 10:41:48 -07:00
lucidrains
9101a49cdd handle continuous value normalization if stats passed in 2025-10-09 08:59:54 -07:00
lucidrains
31f4363be7 must be able to do phase1 and phase2 training 2025-10-09 08:04:36 -07:00
lucidrains
e2d86a4543 add a complete action embedder that can accept any number of discrete actions with variable bins as well as any number of continuous actions, pooled and added to the agent token as described in the paper (seems like they fixed that horrendous hack in dreamer v3 with sticky action) 2025-10-09 07:53:42 -07:00
lucidrains
b62c08be65 fix task embed in presence of multiple agent tokens 2025-10-08 08:42:25 -07:00
lucidrains
4c2ed100a3 fix masking for multiple agent tokens 2025-10-08 08:26:44 -07:00
lucidrains
ed0918c974 prepare for evolution within dreams 2025-10-08 08:13:16 -07:00
lucidrains
892654d442 multiple agent tokens sharing the same state 2025-10-08 08:06:13 -07:00
lucidrains
c4e0f46528 for the value head, we will go for symexp encoding as well (following the "stop regressing" paper from Farebrother et al), also use layernormed mlp given recent papers 2025-10-08 07:37:34 -07:00
lucidrains
a50e360502 makes more sense for the noise to be fixed 2025-10-08 07:17:05 -07:00
Phil Wang
9c56ba0c9d
Merge pull request #3 from lucidrains/pytest-shard
add pytest shard
2025-10-08 07:03:11 -07:00
8 changed files with 3975 additions and 421 deletions

@@ -5,11 +5,11 @@ jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 20
timeout-minutes: 60
strategy:
fail-fast: false
matrix:
group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
group: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
steps:
- uses: actions/checkout@v4
@@ -24,4 +24,4 @@ jobs:
python -m uv pip install -e .[test]
- name: Test with pytest
run: |
python -m pytest --num-shards 10 --shard-id ${{ matrix.group }} tests/
python -m pytest --num-shards 20 --shard-id ${{ matrix.group }} tests/

README.md

@@ -1,10 +1,99 @@
<img src="./dreamer4-fig2.png" width="400px"></img>
## Dreamer 4 (wip)
## Dreamer 4
Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
[Temporary Discord](https://discord.gg/MkACrrkrYR)
[Discord channel](https://discord.gg/PmGR7KRwxq) for collaborating with other researchers interested in this work
## Appreciation
- [@dirkmcpherson](https://github.com/dirkmcpherson) for fixes to typo errors and unpassed arguments!
## Install
```bash
$ pip install dreamer4
```
## Usage
```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel
# video tokenizer, learned through MAE + lpips
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256
)
video = torch.randn(2, 3, 10, 256, 256)
# learn the tokenizer
loss = tokenizer(video)
loss.backward()
# dynamics world model
world_model = DynamicsWorldModel(
dim = 512,
dim_latent = 32,
video_tokenizer = tokenizer,
num_discrete_actions = 4,
num_residual_streams = 1
)
# state, action, rewards
video = torch.randn(2, 3, 10, 256, 256)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
rewards = torch.randn(2, 10)
# learn dynamics / behavior cloned model
loss = world_model(
video = video,
rewards = rewards,
discrete_actions = discrete_actions
)
loss.backward()
# do the above with much data
# then generate dreams
dreams = world_model.generate(
10,
batch_size = 2,
return_decoded_video = True,
return_for_policy_optimization = True
)
# learn from the dreams
actor_loss, critic_loss = world_model.learn_from_experience(dreams)
(actor_loss + critic_loss).backward()
# learn from environment
from dreamer4.mocks import MockEnv
mock_env = MockEnv((256, 256), vectorized = True, num_envs = 4)
experience = world_model.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = True)
actor_loss, critic_loss = world_model.learn_from_experience(experience)
(actor_loss + critic_loss).backward()
```
## Citation
@@ -20,11 +109,4 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
}
```
```bibtex
@misc{xiong2025ndrope,
author = {Jerry Xiong},
title = {On n-dimensional rotary positional embeddings},
year = {2025},
url = {https://jerryxio.ng/posts/nd-rope/}
}
```
*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*

dreamer4/__init__.py

@@ -1,5 +1,12 @@
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsModel,
Dreamer
DynamicsWorldModel,
AxialSpaceTimeTransformer
)
from dreamer4.trainers import (
VideoTokenizerTrainer,
BehaviorCloneTrainer,
DreamTrainer
)

File diff suppressed because it is too large

dreamer4/mocks.py (new file)

@@ -0,0 +1,97 @@
from __future__ import annotations
from random import choice
import torch
from torch import tensor, empty, randn, randint
from torch.nn import Module
from einops import repeat
# helpers
def exists(v):
return v is not None
# mock env
class MockEnv(Module):
def __init__(
self,
image_shape,
reward_range = (-100, 100),
num_envs = 1,
vectorized = False,
terminate_after_step = None,
rand_terminate_prob = 0.05,
can_truncate = False,
rand_truncate_prob = 0.05,
):
super().__init__()
self.image_shape = image_shape
self.reward_range = reward_range
self.num_envs = num_envs
self.vectorized = vectorized
assert not (vectorized and num_envs == 1)
# mocking termination and truncation
self.can_terminate = exists(terminate_after_step)
self.terminate_after_step = terminate_after_step
self.rand_terminate_prob = rand_terminate_prob
self.can_truncate = can_truncate
self.rand_truncate_prob = rand_truncate_prob
self.register_buffer('_step', tensor(0))
def get_random_state(self):
return randn(3, *self.image_shape)
def reset(
self,
seed = None
):
self._step.zero_()
state = self.get_random_state()
if self.vectorized:
state = repeat(state, '... -> b ...', b = self.num_envs)
return state
def step(
self,
actions,
):
state = self.get_random_state()
reward = empty(()).uniform_(*self.reward_range)
if self.vectorized:
discrete, continuous = actions
assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
state = repeat(state, '... -> b ...', b = self.num_envs)
reward = repeat(reward, ' -> b', b = self.num_envs)
out = (state, reward)
if self.can_terminate:
shape = (self.num_envs,) if self.vectorized else (1,)
valid_step = self._step > self.terminate_after_step
terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
out = (*out, terminate)
# maybe truncation
if self.can_truncate:
truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
out = (*out, truncate)
self._step.add_(1)
return out
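
For orientation, here is a minimal sketch of driving the mock environment by hand. It assumes only the `reset` / `step` signatures shown above; the `(discrete, continuous)` action tuple and the `None` placeholder for the unused continuous slot are read off the vectorized branch of `step`.

```python
import torch
from dreamer4.mocks import MockEnv

# vectorized mock env with 4 copies and 256 x 256 RGB observations
env = MockEnv((256, 256), vectorized = True, num_envs = 4)

state = env.reset()
assert state.shape == (4, 3, 256, 256)

# step expects a (discrete, continuous) action tuple; continuous is unused by the mock
discrete_actions = torch.randint(0, 4, (4, 1))
state, reward = env.step((discrete_actions, None))

assert state.shape == (4, 3, 256, 256)
assert reward.shape == (4,)
```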

dreamer4/trainers.py

@@ -1,17 +1,539 @@
from __future__ import annotations
import torch
from torch import is_tensor
from torch.nn import Module
from torch.optim import AdamW
from torch.utils.data import Dataset, TensorDataset, DataLoader
from accelerate import Accelerator
from adam_atan2_pytorch import MuonAdamAtan2
from dreamer4.dreamer4 import (
VideoTokenizer,
DynamicsModel
DynamicsWorldModel,
Experience,
combine_experiences
)
from ema_pytorch import EMA
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def cycle(dl):
while True:
for batch in dl:
yield batch
# trainers
class VideoTokenizerTrainer(Module):
def __init__(
self,
model: VideoTokenizer
model: VideoTokenizer,
dataset: Dataset,
optim_klass = MuonAdamAtan2,
batch_size = 16,
learning_rate = 3e-4,
max_grad_norm = None,
num_train_steps = 10_000,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
raise NotImplementedError
batch_size = min(batch_size, len(dataset))
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
self.dataset = dataset
self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
if optim_klass is MuonAdamAtan2:
optim = MuonAdamAtan2(
model.muon_parameters(),
model.parameters(),
**optim_kwargs
)
else:
optim = optim_klass(
model.parameters(),
**optim_kwargs
)
self.optim = optim
self.max_grad_norm = max_grad_norm
self.num_train_steps = num_train_steps
self.batch_size = batch_size
(
self.model,
self.train_dataloader,
self.optim
) = self.accelerator.prepare(
self.model,
self.train_dataloader,
self.optim
)
@property
def device(self):
return self.accelerator.device
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def forward(
self
):
iter_train_dl = cycle(self.train_dataloader)
for _ in range(self.num_train_steps):
video = next(iter_train_dl)
loss = self.model(video)
self.accelerator.backward(loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optim.step()
self.optim.zero_grad()
self.print('training complete')
# dynamics world model
class BehaviorCloneTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
dataset: Dataset,
optim_klass = MuonAdamAtan2,
batch_size = 16,
learning_rate = 3e-4,
max_grad_norm = None,
num_train_steps = 10_000,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
batch_size = min(batch_size, len(dataset))
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
self.dataset = dataset
self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
if optim_klass is MuonAdamAtan2:
optim = MuonAdamAtan2(
model.muon_parameters(),
model.parameters(),
**optim_kwargs
)
else:
optim = optim_klass(
model.parameters(),
**optim_kwargs
)
self.optim = optim
self.max_grad_norm = max_grad_norm
self.num_train_steps = num_train_steps
self.batch_size = batch_size
(
self.model,
self.train_dataloader,
self.optim
) = self.accelerator.prepare(
self.model,
self.train_dataloader,
self.optim
)
@property
def device(self):
return self.accelerator.device
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def forward(
self
):
iter_train_dl = cycle(self.train_dataloader)
for _ in range(self.num_train_steps):
batch_data = next(iter_train_dl)
# just assume raw video dynamics training if batch_data is a tensor
# else kwargs for video, actions, rewards
if is_tensor(batch_data):
loss = self.model(batch_data)
else:
loss = self.model(**batch_data)
self.accelerator.backward(loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optim.step()
self.optim.zero_grad()
self.print('training complete')
# training from dreams
class DreamTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
optim_klass = AdamW,
batch_size = 16,
generate_timesteps = 16,
learning_rate = 3e-4,
max_grad_norm = None,
num_train_steps = 10_000,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
self.max_grad_norm = max_grad_norm
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.generate_timesteps = generate_timesteps
(
self.model,
self.policy_head_optim,
self.value_head_optim,
) = self.accelerator.prepare(
self.model,
self.policy_head_optim,
self.value_head_optim
)
@property
def device(self):
return self.accelerator.device
@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def forward(
self
):
for _ in range(self.num_train_steps):
dreams = self.unwrapped_model.generate(
self.generate_timesteps + 1, # plus one for bootstrap value
batch_size = self.batch_size,
return_rewards_per_frame = True,
return_agent_actions = True,
return_log_probs_and_values = True
)
policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
# update policy head
self.accelerator.backward(policy_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
self.policy_head_optim.step()
self.policy_head_optim.zero_grad()
# update value head
self.accelerator.backward(value_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
self.value_head_optim.step()
self.value_head_optim.zero_grad()
self.print('training complete')
# training from sim
class SimTrainer(Module):
def __init__(
self,
model: DynamicsWorldModel,
optim_klass = AdamW,
batch_size = 16,
generate_timesteps = 16,
learning_rate = 3e-4,
max_grad_norm = None,
epochs = 2,
weight_decay = 0.,
accelerate_kwargs: dict = dict(),
optim_kwargs: dict = dict(),
cpu = False,
):
super().__init__()
self.accelerator = Accelerator(
cpu = cpu,
**accelerate_kwargs
)
self.model = model
optim_kwargs = dict(
lr = learning_rate,
weight_decay = weight_decay
)
self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
self.max_grad_norm = max_grad_norm
self.epochs = epochs
self.batch_size = batch_size
self.generate_timesteps = generate_timesteps
(
self.model,
self.policy_head_optim,
self.value_head_optim,
) = self.accelerator.prepare(
self.model,
self.policy_head_optim,
self.value_head_optim
)
@property
def device(self):
return self.accelerator.device
@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)
def print(self, *args, **kwargs):
return self.accelerator.print(*args, **kwargs)
def learn(
self,
experience: Experience
):
step_size = experience.step_size
agent_index = experience.agent_index
latents = experience.latents
old_values = experience.values
rewards = experience.rewards
has_agent_embed = exists(experience.agent_embed)
agent_embed = experience.agent_embed
discrete_actions, continuous_actions = experience.actions
discrete_log_probs, continuous_log_probs = experience.log_probs
discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
# handle empties
empty_tensor = torch.empty_like(rewards)
agent_embed = default(agent_embed, empty_tensor)
has_discrete = exists(discrete_actions)
has_continuous = exists(continuous_actions)
discrete_actions = default(discrete_actions, empty_tensor)
continuous_actions = default(continuous_actions, empty_tensor)
discrete_log_probs = default(discrete_log_probs, empty_tensor)
continuous_log_probs = default(continuous_log_probs, empty_tensor)
discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
continuous_old_action_unembeds = default(continuous_old_action_unembeds, empty_tensor)
# create the dataset and dataloader
dataset = TensorDataset(
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
)
dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
for epoch in range(self.epochs):
for (
latents,
discrete_actions,
continuous_actions,
discrete_log_probs,
continuous_log_probs,
agent_embed,
discrete_old_action_unembeds,
continuous_old_action_unembeds,
old_values,
rewards
) in dataloader:
actions = (
discrete_actions if has_discrete else None,
continuous_actions if has_continuous else None
)
log_probs = (
discrete_log_probs if has_discrete else None,
continuous_log_probs if has_continuous else None
)
old_action_unembeds = (
discrete_old_action_unembeds if has_discrete else None,
continuous_old_action_unembeds if has_continuous else None
)
batch_experience = Experience(
latents = latents,
actions = actions,
log_probs = log_probs,
agent_embed = agent_embed if has_agent_embed else None,
old_action_unembeds = old_action_unembeds,
values = old_values,
rewards = rewards,
step_size = step_size,
agent_index = agent_index
)
policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
# update policy head
self.accelerator.backward(policy_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.policy_head_parameters(), self.max_grad_norm)
self.policy_head_optim.step()
self.policy_head_optim.zero_grad()
# update value head
self.accelerator.backward(value_head_loss)
if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
self.value_head_optim.step()
self.value_head_optim.zero_grad()
self.print('training complete')
def forward(
self,
env,
num_episodes = 50000,
max_experiences_before_learn = 8,
env_is_vectorized = False
):
for _ in range(num_episodes):
total_experience = 0
experiences = []
while total_experience < max_experiences_before_learn:
experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
num_experience = experience.video.shape[0]
total_experience += num_experience
experiences.append(experience.cpu())
combined_experiences = combine_experiences(experiences)
self.learn(combined_experiences)
experiences.clear()
self.print('training complete')
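
To tie the trainer and the mock environment together, here is a minimal sketch of running `SimTrainer` end to end on CPU. The tiny `VideoTokenizer` / `DynamicsWorldModel` configuration is lifted from `test_online_rl` further below and only exercises the loop; it is not a meaningful training setup.

```python
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
from dreamer4.mocks import MockEnv
from dreamer4.trainers import SimTrainer

# tiny tokenizer and world model, sized to run quickly on cpu
tokenizer = VideoTokenizer(
    16,
    encoder_depth = 1,
    decoder_depth = 1,
    time_block_every = 1,
    dim_latent = 16,
    patch_size = 32,
    attn_dim_head = 16,
    num_latent_tokens = 1,
    image_height = 256,
    image_width = 256
)

world_model = DynamicsWorldModel(
    video_tokenizer = tokenizer,
    dim = 16,
    dim_latent = 16,
    max_steps = 64,
    num_tasks = 4,
    num_latent_tokens = 1,
    depth = 1,
    time_block_every = 1,
    num_spatial_tokens = 1,
    pred_orig_latent = True,
    num_discrete_actions = 4,
    attn_dim_head = 16,
    prob_no_shortcut_train = 0.1,
    num_residual_streams = 1
)

env = MockEnv((256, 256), vectorized = True, num_envs = 4)

trainer = SimTrainer(
    world_model,
    batch_size = 4,
    cpu = True
)

# gather experience from the env, then run epochs of policy / value head updates
trainer(env, num_episodes = 1, env_is_vectorized = True)
```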

pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.6"
version = "0.1.24"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,13 +27,17 @@ classifiers=[
dependencies = [
"accelerate",
"adam-atan2-pytorch>=0.2.2",
"assoc-scan",
"einx>=0.3.0",
"einops>=0.8.1",
"ema-pytorch",
"hl-gauss-pytorch",
"hyper-connections>=0.2.1",
"torch>=2.4",
"torchvision",
"x-mlps-pytorch>=0.0.29"
"x-mlps-pytorch>=0.0.29",
"vit-pytorch>=1.15.3"
]
[project.urls]

@@ -2,6 +2,9 @@ import pytest
param = pytest.mark.parametrize
import torch
def exists(v):
return v is not None
@param('pred_orig_latent', (False, True))
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
@@ -9,7 +12,12 @@ import torch
@param('add_task_embeds', (False, True))
@param('num_spatial_tokens', (2, 8))
@param('signal_and_step_passed_in', (False, True))
@param('condition_on_actions', (False, True))
@param('num_residual_streams', (1, 4))
@param('add_reward_embed_to_agent_token', (False, True))
@param('add_state_pred_head', (False, True))
@param('use_time_cache', (False, True))
@param('var_len', (False, True))
def test_e2e(
pred_orig_latent,
grouped_query_attn,
@@ -18,18 +26,26 @@ def test_e2e(
add_task_embeds,
num_spatial_tokens,
signal_and_step_passed_in,
add_reward_embed_to_agent_token
condition_on_actions,
num_residual_streams,
add_reward_embed_to_agent_token,
add_state_pred_head,
use_time_cache,
var_len
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsModel
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
16,
encoder_depth = 1,
decoder_depth = 1,
encoder_depth = 4,
decoder_depth = 4,
dim_latent = 16,
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 4
num_latent_tokens = 4,
num_residual_streams = num_residual_streams,
encoder_add_decor_aux_loss = True,
decorr_sample_frac = 1.
)
video = torch.randn(2, 3, 4, 256, 256)
@@ -45,7 +61,7 @@
query_heads, heads = (16, 4) if grouped_query_attn else (8, 8)
dynamics = DynamicsModel(
dynamics = DynamicsWorldModel(
dim = 16,
video_tokenizer = tokenizer,
dim_latent = 16,
@@ -55,13 +71,16 @@
depth = 4,
num_spatial_tokens = num_spatial_tokens,
pred_orig_latent = pred_orig_latent,
num_discrete_actions = 4,
attn_dim_head = 16,
attn_heads = heads,
attn_kwargs = dict(
heads = heads,
query_heads = query_heads,
),
prob_no_shortcut_train = prob_no_shortcut_train,
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
add_reward_embed_to_agent_token = add_reward_embed_to_agent_token,
add_state_pred_head = add_state_pred_head,
num_residual_streams = num_residual_streams
)
signal_levels = step_sizes_log2 = None
@@ -79,27 +98,39 @@
if add_task_embeds:
tasks = torch.randint(0, 4, (2,))
actions = None
if condition_on_actions:
actions = torch.randint(0, 4, (2, 3, 1))
lens = None
if var_len:
lens = torch.randint(1, 4, (2,))
flow_loss = dynamics(
**dynamics_input,
lens = lens,
tasks = tasks,
signal_levels = signal_levels,
step_sizes_log2 = step_sizes_log2
step_sizes_log2 = step_sizes_log2,
discrete_actions = actions,
add_autoregressive_action_loss = True
)
assert flow_loss.numel() == 1
# generating
generated_video, generated_rewards = dynamics.generate(
generations = dynamics.generate(
time_steps = 10,
image_height = 128,
image_width = 128,
batch_size = 2,
return_rewards_per_frame = True
return_rewards_per_frame = True,
use_time_cache = use_time_cache
)
assert generated_video.shape == (2, 3, 10, 128, 128)
assert generated_rewards.shape == (2, 10)
assert generations.video.shape == (2, 3, 10, 128, 128)
assert generations.rewards.shape == (2, 10)
# rl
@@ -173,3 +204,628 @@ def test_attend_factory(
out = attend(q, k, v)
assert torch.allclose(flex_out, out, atol = 1e-6)
def test_action_with_world_model():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 4,
decoder_depth = 4,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
depth = 4,
num_discrete_actions = 4
)
rewards = torch.randn(1, 4)
discrete_actions = torch.randint(0, 4, (1, 4, 1))
gen = dynamics.generate(
16,
batch_size = 4,
return_rewards_per_frame = True,
return_agent_actions = True,
return_log_probs_and_values = True
)
assert gen.video.shape == (4, 3, 16, 256, 256)
assert gen.rewards.shape == (4, 16)
discrete_actions, continuous_actions = gen.actions
assert discrete_actions.shape == (4, 16, 1)
assert continuous_actions is None
discrete_log_probs, _ = gen.log_probs
assert discrete_log_probs.shape == (4, 16, 1)
assert gen.values.shape == (4, 16)
# take a reinforcement learning step
actor_loss, critic_loss = dynamics.learn_from_experience(gen)
actor_loss.backward(retain_graph = True)
critic_loss.backward()
def test_action_embedder():
from dreamer4.dreamer4 import ActionEmbedder
# 1 discrete action with 4 choices
embedder = ActionEmbedder(
512,
num_discrete_actions = 4
)
actions = torch.randint(0, 4, (2, 3, 1))
action_embed = embedder(discrete_actions = actions)
assert action_embed.shape == (2, 3, 512)
# 2 discrete actions with 4 choices each
embedder = ActionEmbedder(
512,
num_discrete_actions = (4, 4)
)
actions = torch.randint(0, 4, (2, 3, 2))
action_embed = embedder(discrete_actions = actions)
assert action_embed.shape == (2, 3, 512)
# picking out only the second discrete action
actions = torch.randint(0, 4, (2, 3, 1))
action_embed = embedder(discrete_actions = actions, discrete_action_types = 1)
assert action_embed.shape == (2, 3, 512)
# 2 continuous actions
embedder = ActionEmbedder(
512,
num_continuous_actions = 2,
continuous_norm_stats = ((0., 2.), (1., 1.)) # (mean, std) for normalizing each action
)
actions = torch.randn((2, 3, 2))
action_embed = embedder(continuous_actions = actions)
assert action_embed.shape == (2, 3, 512)
# 2 discrete actions with 4 choices each and 2 continuous actions
embedder = ActionEmbedder(
512,
num_discrete_actions = (4, 4),
num_continuous_actions = 2
)
discrete_actions = torch.randint(0, 4, (2, 3, 2))
continuous_actions = torch.randn(2, 3, 2)
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
assert action_embed.shape == (2, 3, 512)
# picking out one discrete and one continuous
discrete_actions = torch.randint(0, 4, (2, 3, 1))
continuous_actions = torch.randn(2, 3, 1)
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions, discrete_action_types = 1, continuous_action_types = 0)
assert action_embed.shape == (2, 3, 512)
# unembed
embedder = ActionEmbedder(
512,
num_discrete_actions = (4, 4),
num_continuous_actions = 2,
can_unembed = True
)
discrete_actions = torch.randint(0, 4, (2, 3, 2))
continuous_actions = torch.randn(2, 3, 2)
action_embed = embedder(discrete_actions = discrete_actions, continuous_actions = continuous_actions)
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed)
assert discrete_logits.shape == (2, 3, 8)
assert continuous_mean_log_var.shape == (2, 3, 2, 2)
# test kl div
discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
assert discrete_kl_div.shape == (2, 3)
assert continuous_kl_div.shape == (2, 3)
# return discrete split by number of actions
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
assert discrete_logits[0].shape == discrete_logits[1].shape == (2, 3, 4)
# unembed subset of actions
discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, discrete_action_types = 1, continuous_action_types = 0)
assert discrete_logits.shape == (2, 3, 4)
assert continuous_mean_log_var.shape == (2, 3, 1, 2)
# sample actions
sampled_discrete_actions, sampled_continuous_actions = embedder.sample(action_embed, discrete_action_types = 1, continuous_action_types = 0)
assert sampled_discrete_actions.shape == (2, 3, 1)
assert sampled_continuous_actions.shape == (2, 3, 1)
# log probs
assert discrete_logits.shape == (2, 3, 4)
assert continuous_mean_log_var.shape == (2, 3, 1, 2)
discrete_log_probs, continuous_log_probs = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions,
parallel_discrete_calc = False
)
assert discrete_log_probs.shape == (2, 3, 2)
assert continuous_log_probs.shape == (2, 3, 2)
_, (discrete_entropies, continuous_entropies) = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions,
parallel_discrete_calc = True,
return_entropies = True
)
assert discrete_entropies.shape == (2, 3, 2)
assert continuous_entropies.shape == (2, 3, 2)
parallel_discrete_log_probs, _ = embedder.log_probs(
action_embed,
discrete_targets = discrete_actions,
continuous_targets = continuous_actions,
parallel_discrete_calc = True
)
assert torch.allclose(discrete_log_probs, parallel_discrete_log_probs, atol = 1e-5)
def test_mtp():
from dreamer4.dreamer4 import create_multi_token_prediction_targets
rewards = torch.randn(3, 16) # (b t)
reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
assert reward_targets.shape == (3, 16, 3)
assert mask.shape == (3, 16, 3)
actions = torch.randint(0, 10, (3, 16, 2))
action_targets, mask = create_multi_token_prediction_targets(actions, 3)
assert action_targets.shape == (3, 16, 3, 2)
assert mask.shape == (3, 16, 3)
from dreamer4.dreamer4 import ActionEmbedder
embedder = ActionEmbedder(
512,
num_discrete_actions = (4, 4),
num_continuous_actions = 2,
can_unembed = True,
num_unembed_preds = 8
)
discrete_actions = torch.randint(0, 4, (2, 3, 2))
continuous_actions = torch.randn(2, 3, 2)
action_embed = torch.randn(2, 16, 512)
discrete_logits, continuous_logits = embedder.unembed(action_embed)
assert discrete_logits.shape == (8, 2, 16, 8)
discrete_logits, continuous_logits = embedder.unembed(action_embed, pred_head_index = 0)
assert discrete_logits.shape == (2, 16, 8)
def test_loss_normalizer():
from torch import cat
from dreamer4.dreamer4 import LossNormalizer
loss_normalizer = LossNormalizer(4, beta = 0.)
losses = torch.rand(4)
_ = loss_normalizer(losses)
normed_losses = loss_normalizer(losses)
assert (normed_losses == 1.).all()
def test_tokenizer_trainer():
from dreamer4.trainers import VideoTokenizerTrainer
from dreamer4.dreamer4 import VideoTokenizer
from torch.utils.data import Dataset
class MockDataset(Dataset):
def __len__(self):
return 2
def __getitem__(self, idx):
return torch.randn(3, 2, 64, 64)
dataset = MockDataset()
tokenizer = VideoTokenizer(
16,
encoder_depth = 1,
decoder_depth = 1,
time_block_every = 1,
dim_latent = 16,
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 4
)
trainer = VideoTokenizerTrainer(
tokenizer,
dataset = dataset,
num_train_steps = 1,
batch_size = 1,
cpu = True,
max_grad_norm = 0.5
)
trainer()
@param('with_actions', (True, False))
@param('with_rewards', (True, False))
def test_bc_trainer(
with_actions,
with_rewards
):
from dreamer4.trainers import BehaviorCloneTrainer
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
from torch.utils.data import Dataset
class MockDataset(Dataset):
def __len__(self):
return 2
def __getitem__(self, idx):
state = torch.randn(3, 2, 64, 64)
pkg = dict(video = state)
if with_actions:
pkg.update(discrete_actions = torch.randint(0, 4, (2, 1)))
if with_rewards:
pkg.update(rewards = torch.randn(2,))
return pkg
dataset = MockDataset()
tokenizer = VideoTokenizer(
16,
encoder_depth = 1,
decoder_depth = 1,
time_block_every = 1,
dim_latent = 16,
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 1
)
world_model = DynamicsWorldModel(
video_tokenizer = tokenizer,
dim = 16,
dim_latent = 16,
max_steps = 64,
num_tasks = 4,
num_latent_tokens = 1,
depth = 1,
time_block_every = 1,
num_spatial_tokens = 1,
pred_orig_latent = True,
num_discrete_actions = 4,
attn_dim_head = 16,
prob_no_shortcut_train = 0.1,
num_residual_streams = 1
)
trainer = BehaviorCloneTrainer(
world_model,
dataset = dataset,
batch_size = 1,
num_train_steps = 1,
cpu = True
)
trainer()
def test_dream_trainer():
from dreamer4.dreamer4 import DynamicsWorldModel
world_model = DynamicsWorldModel(
dim = 16,
dim_latent = 16,
max_steps = 64,
num_tasks = 4,
num_latent_tokens = 1,
depth = 1,
time_block_every = 1,
num_spatial_tokens = 1,
pred_orig_latent = True,
num_discrete_actions = 4,
attn_dim_head = 16,
prob_no_shortcut_train = 0.1,
num_residual_streams = 1
)
# training from self-generations (dreams)
from dreamer4.trainers import DreamTrainer
dream_trainer = DreamTrainer(
world_model,
batch_size = 2,
num_train_steps = 1,
cpu = True,
)
dream_trainer()
def test_cache_generate():
from dreamer4.dreamer4 import DynamicsWorldModel
dynamics = DynamicsWorldModel(
dim = 16,
dim_latent = 16,
max_steps = 64,
num_tasks = 4,
num_latent_tokens = 4,
depth = 1,
time_block_every = 1,
num_spatial_tokens = 1,
pred_orig_latent = True,
num_discrete_actions = 4,
attn_dim_head = 16,
prob_no_shortcut_train = 0.1,
num_residual_streams = 1
)
generated, time_cache = dynamics.generate(1, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
generated, time_cache = dynamics.generate(1, time_cache = time_cache, return_time_cache = True)
@param('vectorized', (False, True))
@param('use_pmpo', (False, True))
@param('env_can_terminate', (False, True))
@param('env_can_truncate', (False, True))
@param('store_agent_embed', (False, True))
def test_online_rl(
vectorized,
use_pmpo,
env_can_terminate,
env_can_truncate,
store_agent_embed
):
from dreamer4.dreamer4 import DynamicsWorldModel, VideoTokenizer
tokenizer = VideoTokenizer(
16,
encoder_depth = 1,
decoder_depth = 1,
time_block_every = 1,
dim_latent = 16,
patch_size = 32,
attn_dim_head = 16,
num_latent_tokens = 1,
image_height = 256,
image_width = 256,
)
world_model_and_policy = DynamicsWorldModel(
video_tokenizer = tokenizer,
dim = 16,
dim_latent = 16,
max_steps = 64,
num_tasks = 4,
num_latent_tokens = 1,
depth = 1,
time_block_every = 1,
num_spatial_tokens = 1,
pred_orig_latent = True,
num_discrete_actions = 4,
attn_dim_head = 16,
prob_no_shortcut_train = 0.1,
num_residual_streams = 1
)
from dreamer4.mocks import MockEnv
from dreamer4.dreamer4 import combine_experiences
mock_env = MockEnv(
(256, 256),
vectorized = vectorized,
num_envs = 4,
terminate_after_step = 2 if env_can_terminate else None,
can_truncate = env_can_truncate,
rand_terminate_prob = 0.1
)
# manually
dream_experience = world_model_and_policy.generate(10, batch_size = 1, store_agent_embed = store_agent_embed, return_for_policy_optimization = True)
one_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 8, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
another_experience = world_model_and_policy.interact_with_env(mock_env, max_timesteps = 16, env_is_vectorized = vectorized, store_agent_embed = store_agent_embed)
combined_experience = combine_experiences([dream_experience, one_experience, another_experience])
# quick test moving the experience to different devices
if torch.cuda.is_available():
combined_experience = combined_experience.to(torch.device('cuda'))
combined_experience = combined_experience.to(world_model_and_policy.device)
if store_agent_embed:
assert exists(combined_experience.agent_embed)
actor_loss, critic_loss = world_model_and_policy.learn_from_experience(combined_experience, use_pmpo = use_pmpo)
actor_loss.backward()
critic_loss.backward()
# with trainer
from dreamer4.trainers import SimTrainer
trainer = SimTrainer(
world_model_and_policy,
batch_size = 4,
cpu = True
)
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
@param('num_video_views', (1, 2))
def test_proprioception(
num_video_views
):
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_video_views = num_video_views,
num_discrete_actions = 4,
num_residual_streams = 1
)
if num_video_views > 1:
video_shape = (2, num_video_views, 3, 10, 256, 256)
else:
video_shape = (2, 3, 10, 256, 256)
video = torch.randn(*video_shape)
rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
tasks = torch.randint(0, 4, (2,))
loss = dynamics(
video = video,
rewards = rewards,
tasks = tasks,
proprio = proprio,
discrete_actions = discrete_actions
)
loss.backward()
generations = dynamics.generate(
10,
batch_size = 2,
return_decoded_video = True
)
assert exists(generations.proprio)
assert generations.video.shape == video_shape
def test_epo():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_latent_genes = 16,
num_discrete_actions = 4,
num_residual_streams = 1
)
fitness = torch.randn(16,)
dynamics.evolve_(fitness)
def test_images_to_video_tokenizer():
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel, AxialSpaceTimeTransformer
tokenizer = VideoTokenizer(
dim = 512,
dim_latent = 32,
patch_size = 32,
image_height = 256,
image_width = 256,
encoder_add_decor_aux_loss = True
)
images = torch.randn(2, 3, 256, 256)
loss, (losses, recon_images) = tokenizer(images, return_intermediates = True)
loss.backward()
assert images.shape == recon_images.shape