fix generation so that one more step is taken to decode agent embeds off the final cleaned set of latents, update readme

lucidrains 2025-10-31 06:48:49 -07:00
parent 6870294d95
commit 60681fce1d
3 changed files with 105 additions and 5 deletions
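The gist of the generation fix, as a standalone toy sketch (assumed names throughout; `model`, its return signature, and `need_agent_embed` are illustrative stand-ins, not the library's API): when anything must be decoded off the agent embedding, the denoising loop runs one extra forward pass so the embedding is read from the fully cleaned latents, then breaks before applying another update.

```python
import torch

def generate(model, num_steps, latent_shape, need_agent_embed = True):
    # anything decoded off the agent embedding (rewards, actions)
    # needs one extra forward pass over the fully cleaned latents
    take_extra_step = need_agent_embed

    latents = torch.randn(latent_shape)
    agent_embed = None

    for step in range(num_steps + int(take_extra_step)):
        # each pass returns the denoised latents and the agent embedding
        denoised, agent_embed = model(latents, step)

        if take_extra_step and step == num_steps:
            # latents are already fully denoised - this last pass only
            # refreshed the agent embedding, so stop before another update
            break

        latents = denoised

    return latents, agent_embed

# toy stand-in for the world model - halves the noise, mean pools an "agent embedding"
model = lambda latents, step: (latents * 0.5, latents.mean(dim = -1))

clean_latents, agent_embed = generate(model, num_steps = 4, latent_shape = (2, 8, 32))
```

In the actual diff below, `take_extra_step` is additionally gated on `use_time_kv_cache`, `return_rewards_per_frame`, `store_agent_embed`, and `return_agent_actions`.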

README.md

@@ -1,9 +1,75 @@
<img src="./dreamer4-fig2.png" width="400px"></img>

-## Dreamer 4 (wip)
+## Dreamer 4

Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work
+## Install
+
+```bash
+$ pip install dreamer4
+```
+
+## Usage
+
+```python
+import torch
+from dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+# video tokenizer, learned through MAE + lpips
+
+tokenizer = VideoTokenizer(
+    dim = 512,
+    dim_latent = 32,
+    patch_size = 32,
+    image_height = 256,
+    image_width = 256
+)
+
+# dynamics world model
+
+dynamics = DynamicsWorldModel(
+    dim = 512,
+    dim_latent = 32,
+    video_tokenizer = tokenizer,
+    num_discrete_actions = 4,
+    num_residual_streams = 1
+)
+
+# state, action, rewards
+
+video = torch.randn(2, 3, 10, 256, 256)
+discrete_actions = torch.randint(0, 4, (2, 10, 1))
+rewards = torch.randn(2, 10)
+
+# learn dynamics / behavior cloned model
+
+loss = dynamics(
+    video = video,
+    rewards = rewards,
+    discrete_actions = discrete_actions
+)
+
+loss.backward()
+
+# do the above with much data
+# then generate dreams
+
+dreams = dynamics.generate(
+    10,
+    batch_size = 2,
+    return_decoded_video = True,
+    return_for_policy_optimization = True
+)
+
+# learn from the dreams
+
+actor_loss, critic_loss = dynamics.learn_from_experience(dreams)
+
+(actor_loss + critic_loss).backward()
+```
## Citation
```bibtex
@@ -17,3 +83,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
    url = {https://arxiv.org/abs/2509.24527},
}
```
+
+*the conquest of nature is to be achieved through number and measure - angels to Descartes in a dream*

dreamer4/dreamer4.py

@@ -2429,6 +2429,7 @@ class DynamicsWorldModel(Module):
        normalize_advantages = None,
        eps = 1e-6
    ):
+        assert isinstance(experience, Experience)

        latents = experience.latents
        actions = experience.actions
@@ -2441,7 +2442,7 @@
        step_size = experience.step_size
        agent_index = experience.agent_index

-        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
+        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'

        batch, time = latents.shape[0], latents.shape[1]
@@ -2694,12 +2695,22 @@
        return_rewards_per_frame = False,
        return_agent_actions = False,
        return_log_probs_and_values = False,
+        return_for_policy_optimization = False,
        return_time_kv_cache = False,
        store_agent_embed = True,
        store_old_action_unembeds = True
    ): # (b t n d) | (b c t h w)
+
+        # handy flag for returning generations for rl
+
+        if return_for_policy_optimization:
+            return_agent_actions |= True
+            return_log_probs_and_values |= True
+            return_rewards_per_frame |= True
+
+        # more variables
+
        has_proprio = self.has_proprio
        was_training = self.training
        self.eval()
@@ -2769,6 +2780,19 @@
        curr_time_steps = latents.shape[1]

+        # determine whether to take an extra step if
+        # (1) using time kv cache
+        # (2) decoding anything off agent embedding (rewards, actions, etc)
+
+        take_extra_step = (
+            use_time_kv_cache or
+            return_rewards_per_frame or
+            store_agent_embed or
+            return_agent_actions
+        )
+
+        # prepare noised latent / proprio inputs
+
        noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)

        noised_proprio = None
@@ -2776,7 +2800,10 @@
        if has_proprio:
            noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)

-        for step in range(num_steps):
+        # denoising steps
+
+        for step in range(num_steps + int(take_extra_step)):

            is_last_step = (step + 1) == num_steps

            signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2819,6 +2846,11 @@
            if use_time_kv_cache and is_last_step:
                time_kv_cache = next_time_kv_cache

+            # early break if taking an extra step for agent embedding off cleaned latents for decoding
+
+            if take_extra_step and is_last_step:
+                break
+
            # maybe proprio

            if has_proprio:
@@ -3021,7 +3053,7 @@
        latent_is_noised = False,
        return_all_losses = False,
        return_intermediates = False,
-        add_autoregressive_action_loss = False,
+        add_autoregressive_action_loss = True,
        update_loss_ema = None,
        latent_has_view_dim = False
    ):

pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
-version = "0.0.102"
+version = "0.1.0"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }