incorporate proprioception into the dynamics world model

This commit is contained in:
lucidrains 2025-10-24 11:24:22 -07:00
parent 35c1db4c7d
commit a9b728c611
3 changed files with 235 additions and 33 deletions

View File

@ -5,7 +5,7 @@ from math import ceil, log2
from random import random
from contextlib import nullcontext
from collections import namedtuple
from functools import partial
from functools import partial, wraps
from dataclasses import dataclass, asdict
import torch
@ -75,6 +75,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
class Experience:
latents: Tensor
video: Tensor | None = None
proprio: Tensor | None = None
rewards: Tensor | None = None
actions: tuple[Tensor, Tensor] | None = None
log_probs: tuple[Tensor, Tensor] | None = None
@ -130,6 +131,9 @@ def default(v, d):
def first(arr):
return arr[0]
def xnor(x, y):
return not (x ^ y)
def has_at_least_one(*bools):
return sum([*map(int, bools)]) > 0
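For reference, a quick standalone check of the two new helpers (a minimal sketch, runnable on its own):

def xnor(x, y):
    return not (x ^ y)  # True when both or neither is truthy

def has_at_least_one(*bools):
    return sum([*map(int, bools)]) > 0  # True if any flag is set

assert xnor(True, True) and xnor(False, False) and not xnor(True, False)
assert has_at_least_one(False, True) and not has_at_least_one(False, False)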
@ -1731,6 +1735,7 @@ class DynamicsWorldModel(Module):
num_latent_tokens = None,
num_agents = 1,
num_tasks = 0,
dim_proprio = None,
reward_encoder_kwargs: dict = dict(),
depth = 4,
pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
@ -1816,6 +1821,19 @@ class DynamicsWorldModel(Module):
Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
)
# proprioception
self.has_proprio = exists(dim_proprio)
self.dim_proprio = dim_proprio
if self.has_proprio:
self.to_proprio_token = nn.Linear(dim_proprio, dim)
self.to_proprio_pred = Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_proprio)
)
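A minimal standalone sketch of the proprio round trip wired up above, with hypothetical dims and torch's built-in RMSNorm standing in for the repo's own:

import torch
from torch import nn

dim, dim_proprio = 512, 21

to_proprio_token = nn.Linear(dim_proprio, dim)
to_proprio_pred = nn.Sequential(nn.RMSNorm(dim), nn.Linear(dim, dim_proprio))

proprio = torch.randn(2, 10, dim_proprio)      # (batch, time, dim_proprio)
proprio_token = to_proprio_token(proprio)      # (2, 10, 512) - one token per frame
pred_proprio = to_proprio_pred(proprio_token)  # (2, 10, 21) - back to proprio space

assert pred_proprio.shape == proprio.shape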
# register tokens
self.num_register_tokens = num_register_tokens
@ -2221,7 +2239,7 @@ class DynamicsWorldModel(Module):
entropy_loss * self.policy_entropy_weight
)
# maye take policy optimizer step
# maybe take policy optimizer step
if exists(policy_optim):
total_policy_loss.backward()
@ -2271,10 +2289,11 @@ class DynamicsWorldModel(Module):
return_rewards_per_frame = False,
return_agent_actions = False,
return_log_probs_and_values = False,
return_time_kv_cache = False
return_time_kv_cache = False,
): # (b t n d) | (b c t h w)
has_proprio = self.has_proprio
was_training = self.training
self.eval()
@ -2301,7 +2320,14 @@ class DynamicsWorldModel(Module):
latents = empty((batch_size, 0, *latent_shape), device = self.device)
past_context_noise = latents.clone()
past_latents_context_noise = latents.clone()
# maybe internal state
if has_proprio:
proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
past_proprio_context_noise = proprio.clone()
# maybe return actions
@ -2327,17 +2353,35 @@ class DynamicsWorldModel(Module):
while latents.shape[1] < time_steps:
curr_time_steps = latents.shape[1]
noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
noised_proprio = None
if has_proprio:
noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
for step in range(num_steps):
is_last_step = (step + 1) == num_steps
signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
# noising past latent context
noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
# handle proprio
noised_proprio_with_context = None
if has_proprio:
noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
# signal levels, with past context frames pinned at max signal (already denoised)
signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
pred, (agent_embed, next_time_kv_cache) = self.forward(
@ -2348,6 +2392,7 @@ class DynamicsWorldModel(Module):
tasks = tasks,
discrete_actions = decoded_discrete_actions,
continuous_actions = decoded_continuous_actions,
proprio = noised_proprio_with_context,
time_kv_cache = time_kv_cache,
latent_is_noised = True,
return_pred_only = True,
@ -2357,24 +2402,44 @@ class DynamicsWorldModel(Module):
if use_time_kv_cache and is_last_step:
time_kv_cache = next_time_kv_cache
# maybe proprio
if has_proprio:
pred, pred_proprio = pred
# unpack pred
_, pred = unpack(pred, pack_context_shape, 'b * n d')
if has_proprio:
_, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
# derive flow, based on whether in x-space or not
if self.pred_orig_latent:
times = self.get_times_from_signal_level(signal_levels, noised_latent)
flow = (pred - noised_latent) / (1. - times)
else:
flow = pred
def denoise_step(pred, noised, signal_levels):
if self.pred_orig_latent:
times = self.get_times_from_signal_level(signal_levels)
aligned_times = align_dims_left(times, noised)
flow = (pred - noised) / (1. - aligned_times)
else:
flow = pred
return flow * (step_size / self.max_steps)
# denoise
noised_latent += flow * (step_size / self.max_steps)
noised_latent += denoise_step(pred, noised_latent, signal_levels)
if has_proprio:
noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)
denoised_latent = noised_latent # it is now denoised
if has_proprio:
denoised_proprio = noised_proprio
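A standalone sketch of the x-space Euler update `denoise_step` performs, under the convention that times run from 0 (pure noise) to 1 (pure data); shapes and step sizes here are toy values:

import torch

max_steps, step_size = 64, 4
t = torch.tensor(0.25)                # current signal level mapped to time in [0, 1)

noised = torch.randn(2, 1, 4, 8)      # (b, t, n, d) noised latent
pred_x0 = torch.randn_like(noised)    # network's x-space estimate of the clean data

flow = (pred_x0 - noised) / (1. - t)              # x-space prediction -> velocity
noised = noised + flow * (step_size / max_steps)  # one Euler step toward the data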
# take care of the rewards by predicting from the agent token embedding on the last denoising step
if return_rewards_per_frame:
@ -2421,7 +2486,14 @@ class DynamicsWorldModel(Module):
# add new fixed context noise for the temporal consistency
past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
# handle proprio
if has_proprio:
proprio = cat((proprio, denoised_proprio), dim = 1)
past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)
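This mirrors the latent bookkeeping: each finished frame is appended along time, together with a fixed noise tensor that is reused whenever the context is re-noised for later frames (the temporal consistency trick). A toy sketch of the pattern:

import torch

b, d = 2, 21
proprio = torch.empty(b, 0, d)        # starts with zero frames
past_noise = proprio.clone()

for _ in range(3):                    # pretend three frames get generated
    frame = torch.randn(b, 1, d)      # stand-in for a freshly denoised frame
    proprio = torch.cat((proprio, frame), dim = 1)
    past_noise = torch.cat((past_noise, torch.randn_like(frame)), dim = 1)

noised_context = proprio.lerp(past_noise, 0.1)   # re-noise context against the same stored noise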
# restore state
@ -2443,7 +2515,7 @@ class DynamicsWorldModel(Module):
# only return video or latent if not requesting anything else, for first stage training
if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
out = video if return_decoded_video else latents
if not return_time_kv_cache:
@ -2456,6 +2528,7 @@ class DynamicsWorldModel(Module):
gen = Experience(
latents = latents,
video = video,
proprio = proprio if has_proprio else None,
step_size = step_size,
agent_index = agent_index,
is_from_world_model = True
@ -2492,6 +2565,7 @@ class DynamicsWorldModel(Module):
continuous_actions = None, # (b t na) | (b t-1 na)
discrete_action_types = None, # (na)
continuous_action_types = None, # (na)
proprio = None, # (b t dp)
time_kv_cache = None,
return_pred_only = False,
latent_is_noised = False,
@ -2587,16 +2661,17 @@ class DynamicsWorldModel(Module):
# times is from 0 to 1
times = self.get_times_from_signal_level(signal_levels, latents)
times = self.get_times_from_signal_level(signal_levels)
if not latent_is_noised:
# get the noise
noise = randn_like(latents)
aligned_times = align_dims_left(times, latents)
# noise from 0 as noise to 1 as data
noised_latents = noise.lerp(latents, times)
noised_latents = noise.lerp(latents, aligned_times)
else:
noised_latents = latents
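For clarity, the lerp noising convention used here (and for proprio just below): `noise.lerp(data, t)` equals (1 - t) * noise + t * data, so t = 0 is pure noise and t = 1 is pure data. A tiny check:

import torch

data, noise = torch.full((3,), 10.), torch.zeros(3)

assert torch.equal(noise.lerp(data, 0.), noise)               # t = 0 -> pure noise
assert torch.equal(noise.lerp(data, 1.), data)                # t = 1 -> pure data
assert torch.equal(noise.lerp(data, 0.5), (noise + data) / 2)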
@ -2644,6 +2719,27 @@ class DynamicsWorldModel(Module):
reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
# maybe proprioception
assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
noised_proprio = None
if self.has_proprio:
if not latent_is_noised:
# get the noise
proprio_noise = randn_like(proprio)
aligned_times = align_dims_left(times, proprio)
# noise from 0 as noise to 1 as data
noised_proprio = proprio_noise.lerp(proprio, aligned_times)
else:
noised_proprio = proprio
# maybe create the action tokens
if exists(discrete_actions) or exists(continuous_actions):
@ -2672,7 +2768,7 @@ class DynamicsWorldModel(Module):
# main function, needs to be defined as such for shortcut training - additional calls for consistency loss
def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
# latents to spatial tokens
space_tokens = self.latents_to_spatial_tokens(noised_latents)
@ -2694,6 +2790,13 @@ class DynamicsWorldModel(Module):
registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
# maybe proprio
if exists(noised_proprio):
proprio_token = self.to_proprio_token(noised_proprio)
else:
proprio_token = registers[:, :, 0:0]
# determine signal + step size embed for their diffusion forcing + shortcut
signal_embed = self.signal_levels_embed(signal_levels)
@ -2706,7 +2809,7 @@ class DynamicsWorldModel(Module):
# pack to tokens for attending
tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
# attention
@ -2714,7 +2817,7 @@ class DynamicsWorldModel(Module):
# unpack
flow_token, space_tokens, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
# pooling
@ -2722,6 +2825,15 @@ class DynamicsWorldModel(Module):
pred = self.to_latent_pred(space_tokens)
# maybe proprio
if self.has_proprio:
pred_proprio = self.to_proprio_pred(proprio_token)
pred = (pred, pred_proprio)
# returning
if not return_agent_tokens:
return pred
@ -2736,7 +2848,7 @@ class DynamicsWorldModel(Module):
# forward the network
pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
if return_pred_only:
if not return_intermediates:
@ -2744,6 +2856,40 @@ class DynamicsWorldModel(Module):
return pred, (encoded_agent_tokens, next_time_kv_cache)
# pack the predictions to calculate flow for different modalities all at once
if self.has_proprio:
pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
noised, _ = pack((noised_latents, noised_proprio), 'b t *')
data, _ = pack((latents, proprio), 'b t *')
noise, _ = pack((noise, proprio_noise), 'b t *')
else:
noised = noised_latents
data = latents
# wrapper function for maybe unpacking and packing modalities for doing flow math in unison
def maybe_pack_unpack(fn):
@wraps(fn)
@torch.no_grad()
def inner(noised, *args, **kwargs):
noised_proprio = None
if self.has_proprio:
noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
pred = fn(noised, noised_proprio, *args, **kwargs)
if self.has_proprio:
pred, _ = pack(pred, 'b t *')
return pred
return inner
wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
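The wrapper leans on the einops pack/unpack round trip: latents (b t n d) and proprio (b t d) are flattened into a single (b t *) tensor so the flow math runs over both modalities at once, then split back apart before calling the network. A self-contained sketch:

import torch
from einops import pack, unpack

latents = torch.randn(2, 10, 4, 32)   # (b, t, n, d)
proprio = torch.randn(2, 10, 21)      # (b, t, dp)

packed, packed_shape = pack((latents, proprio), 'b t *')   # (2, 10, 4 * 32 + 21)
latents_out, proprio_out = unpack(packed, packed_shape, 'b t *')

assert torch.equal(latents_out, latents) and torch.equal(proprio_out, proprio)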
# determine the target for the loss
pred_target = None
@ -2759,46 +2905,45 @@ class DynamicsWorldModel(Module):
# x-space as in paper is in else clause
if is_v_space_pred:
pred_target = flow = latents - noise
pred_target = flow = data - noise
else:
pred_target = latents
pred_target = data
else:
# shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
# basically a consistency loss where you ensure the result of two half steps equals one full step
# dreamer then makes it work for x-space with some math
get_prediction_no_grad = torch.no_grad()(_get_prediction)
step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
half_step_size = 2 ** step_sizes_log2_minus_one
first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)
# first derive b'
if is_v_space_pred:
first_step_pred_flow = first_step_pred
else:
first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
first_times = self.get_times_from_signal_level(signal_levels, noised)
first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)
# take a half step
half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
half_step_size_align_left = align_dims_left(half_step_size, noised)
denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
# get second prediction for b''
signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)
if is_v_space_pred:
second_step_pred_flow = second_step_pred
else:
second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)
# pred target is sg(b' + b'') / 2
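Schematically, the consistency target enforced here: the flow predicted for a full step of size d should match the average of the two half-step flows b' and b'', under stop-gradient. A simplified v-space sketch (the hypothetical `model_flow` stands in for the no-grad network call):

import torch

def model_flow(x, t, step_size_log2):   # hypothetical stand-in for the network
    return torch.randn_like(x)

x, t = torch.randn(2, 8), torch.tensor(0.25)
d = 4 / 64                               # full step as a fraction of max_steps

b1 = model_flow(x, t, step_size_log2 = 1)               # b': flow at the start
x_half = x + b1 * (d / 2)                               # take half a step
b2 = model_flow(x_half, t + d / 2, step_size_log2 = 1)  # b'': flow at the midpoint

pred_target = ((b1 + b2) / 2).detach()                  # sg(b' + b'') / 2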
@ -2807,7 +2952,7 @@ class DynamicsWorldModel(Module):
# need to convert x-space to v-space
if is_x_space:
pred = (pred - noised_latents) / (1. - first_times)
pred = (pred - noised) / (1. - first_times)
maybe_shortcut_loss_weight = (1. - first_times) ** 2
# mse loss
@ -2820,6 +2965,8 @@ class DynamicsWorldModel(Module):
if exists(self.loss_weight_fn):
loss_weight = self.loss_weight_fn(times)
loss_weight = align_dims_left(loss_weight, flow_losses)
flow_losses = flow_losses * loss_weight
flow_loss = flow_losses.mean()
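`align_dims_left`, used throughout this commit, broadcasts a low-rank tensor such as `times` (b,) or `loss_weight` against a higher-rank one by appending singleton dims on the right. A minimal stand-in with the assumed semantics:

import torch

def align_dims_left(t, target):
    # append singleton dims so `t` broadcasts against `target`, aligned from the left
    return t.reshape(*t.shape, *((1,) * (target.ndim - t.ndim)))

times = torch.rand(2)                 # (b,)
losses = torch.randn(2, 10, 4, 32)    # (b, t, n, d)

assert align_dims_left(times, losses).shape == (2, 1, 1, 1)
weighted = losses * align_dims_left(times, losses)   # broadcasts cleanly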

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.69"
version = "0.0.70"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

View File

@ -2,6 +2,9 @@ import pytest
param = pytest.mark.parametrize
import torch
def exists(v):
return v is not None
@param('pred_orig_latent', (False, True))
@param('grouped_query_attn', (False, True))
@param('dynamics_with_video_input', (False, True))
@ -664,3 +667,55 @@ def test_online_rl(
)
trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
def test_proprioception():
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
tokenizer = VideoTokenizer(
512,
dim_latent = 32,
patch_size = 32,
encoder_depth = 2,
decoder_depth = 2,
time_block_every = 2,
attn_heads = 8,
image_height = 256,
image_width = 256,
attn_kwargs = dict(
query_heads = 16
)
)
dynamics = DynamicsWorldModel(
512,
num_agents = 1,
video_tokenizer = tokenizer,
dim_latent = 32,
dim_proprio = 21,
num_tasks = 4,
num_discrete_actions = 4,
num_residual_streams = 1
)
video = torch.randn(2, 3, 10, 256, 256)
rewards = torch.randn(2, 10)
proprio = torch.randn(2, 10, 21)
discrete_actions = torch.randint(0, 4, (2, 10, 1))
tasks = torch.randint(0, 4, (2,))
loss = dynamics(
video = video,
rewards = rewards,
tasks = tasks,
proprio = proprio,
discrete_actions = discrete_actions
)
loss.backward()
generations = dynamics.generate(
4,
batch_size = 2,
)
assert exists(generations.proprio)
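Assuming the repository's usual pytest setup, the new test runs in isolation with `pytest -k test_proprioception`.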