fix generation so that one more step is taken to decode agent embeds off the final cleaned set of latents, update readme
parent 6870294d95, commit a0bda62989
README.md (70 lines changed)
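In short, when anything must be decoded off the agent embedding (rewards, actions, values), generation now runs one extra, readout-only forward pass so those readouts come from the final cleaned latents rather than the last noised ones. A minimal sketch of the idea, with illustrative names only (not the actual dreamer4 internals):

```python
import torch

# sketch of the fix: after the final denoising update, run the model once more on the
# clean latents purely to read out the agent embedding (used for rewards / actions / values)

def denoise_then_decode(model, noised_latent, num_steps, decode_agent_embed = True):
    latent = noised_latent
    agent_embed = None

    for step in range(num_steps + int(decode_agent_embed)):
        denoised, agent_embed = model(latent)  # model returns (denoised latent, agent embedding)

        if step == num_steps:
            # the extra pass only exists to read the agent embedding off clean latents
            break

        latent = denoised                      # regular denoising update

    return latent, agent_embed if decode_agent_embed else None

# tiny usage example with a stand-in model
dummy_model = lambda latent: (latent * 0.9, latent.mean(dim = -1))
clean_latent, agent_embed = denoise_then_decode(dummy_model, torch.randn(2, 16, 32), num_steps = 4)
```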
@@ -1,9 +1,75 @@
<img src="./dreamer4-fig2.png" width="400px"></img>

## Dreamer 4 (wip)
## Dreamer 4

Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v1) for his [Dreamer](https://danijar.com/project/dreamer4/) line of work

## Install

```bash
$ pip install dreamer4-pytorch
```

## Usage

```python
import torch
from dreamer4 import VideoTokenizer, DynamicsWorldModel

# video tokenizer, learned through MAE + lpips

tokenizer = VideoTokenizer(
    dim = 512,
    dim_latent = 32,
    patch_size = 32,
    image_height = 256,
    image_width = 256
)

# dynamics world model

dynamics = DynamicsWorldModel(
    dim = 512,
    dim_latent = 32,
    video_tokenizer = tokenizer,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

# state, action, rewards

video = torch.randn(2, 3, 10, 256, 256)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
rewards = torch.randn(2, 10)

# learn dynamics / behavior cloned model

loss = dynamics(
    video = video,
    rewards = rewards,
    discrete_actions = discrete_actions
)

loss.backward()

# do the above with much data

# then generate dreams

dreams = dynamics.generate(
    10,
    batch_size = 2,
    return_decoded_video = True,
    return_for_policy_optimization = True
)

# learn from the dreams

actor_loss, critic_loss = dynamics.learn_from_experience(dreams)

(actor_loss + critic_loss).backward()
```
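The usage above constructs the `VideoTokenizer` but never trains it. As a plausible sketch only, continuing from the snippet above and assuming the tokenizer follows the common pattern of returning a scalar reconstruction loss when called directly on raw video (an assumption, not something this commit confirms):

```python
# assumed API: calling the tokenizer on raw video returns its MAE + lpips training loss
# hypothetical usage sketch, reusing `tokenizer` and `video` from the snippet above

tokenizer_loss = tokenizer(video)
tokenizer_loss.backward()

# pretrain the tokenizer this way before training the dynamics world model on its latents
```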
## Citation

```bibtex
@@ -17,3 +83,5 @@ Implementation of Danijar's [latest iteration](https://arxiv.org/abs/2509.24527v
    url = {https://arxiv.org/abs/2509.24527},
}
```

*the conquest of nature is to be achieved through number and measure* - angels to Descartes, in a dream, the story goes.
@@ -2429,6 +2429,7 @@ class DynamicsWorldModel(Module):
        normalize_advantages = None,
        eps = 1e-6
    ):
        assert isinstance(experience, Experience)

        latents = experience.latents
        actions = experience.actions
@@ -2441,7 +2442,7 @@ class DynamicsWorldModel(Module):
        step_size = experience.step_size
        agent_index = experience.agent_index

        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
        assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization - world_model.generate(..., return_log_probs_and_values = True)'

        batch, time = latents.shape[0], latents.shape[1]

@@ -2694,12 +2695,22 @@ class DynamicsWorldModel(Module):
        return_rewards_per_frame = False,
        return_agent_actions = False,
        return_log_probs_and_values = False,
        return_for_policy_optimization = False,
        return_time_kv_cache = False,
        store_agent_embed = True,
        store_old_action_unembeds = True

    ): # (b t n d) | (b c t h w)

        # handy flag for returning generations for rl

        if return_for_policy_optimization:
            return_agent_actions |= True
            return_log_probs_and_values |= True
            return_rewards_per_frame |= True

        # more variables

        has_proprio = self.has_proprio
        was_training = self.training
        self.eval()
@@ -2769,6 +2780,19 @@ class DynamicsWorldModel(Module):

        curr_time_steps = latents.shape[1]

        # determine whether to take an extra step if
        # (1) using time kv cache
        # (2) decoding anything off agent embedding (rewards, actions, etc)

        take_extra_step = (
            use_time_kv_cache or
            return_rewards_per_frame or
            store_agent_embed or
            return_agent_actions
        )
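
        # note: per the commit message, this extra iteration is a readout-only pass, so the agent
        # embedding (and whatever is decoded from it) comes from the final cleaned latents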

        # prepare noised latent / proprio inputs

        noised_latent = randn((batch_size, 1, self.num_video_views, *latent_shape), device = self.device)

        noised_proprio = None

@@ -2776,7 +2800,10 @@ class DynamicsWorldModel(Module):
        if has_proprio:
            noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)

        for step in range(num_steps):
        # denoising steps

        for step in range(num_steps + int(take_extra_step)):

            is_last_step = (step + 1) == num_steps

            signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
@@ -2819,6 +2846,11 @@ class DynamicsWorldModel(Module):
            if use_time_kv_cache and is_last_step:
                time_kv_cache = next_time_kv_cache

            # early break if taking an extra step to decode the agent embedding off the cleaned latents

            if take_extra_step and is_last_step:
                break

            # maybe proprio

            if has_proprio:
@@ -3021,7 +3053,7 @@ class DynamicsWorldModel(Module):
        latent_is_noised = False,
        return_all_losses = False,
        return_intermediates = False,
        add_autoregressive_action_loss = False,
        add_autoregressive_action_loss = True,
        update_loss_ema = None,
        latent_has_view_dim = False
    ):
@@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.102"
version = "0.1.0"
description = "Dreamer 4"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }