incorporate proprioception into the dynamics world model

parent 35c1db4c7d
commit a9b728c611
@@ -5,7 +5,7 @@ from math import ceil, log2
 from random import random
 from contextlib import nullcontext
 from collections import namedtuple
-from functools import partial
+from functools import partial, wraps
 from dataclasses import dataclass, asdict
 
 import torch
@@ -75,6 +75,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
 class Experience:
     latents: Tensor
     video: Tensor | None = None
+    proprio: Tensor | None = None
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
@@ -130,6 +131,9 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
 def xnor(x, y):
     return not (x ^ y)
 
+def has_at_least_one(*bools):
+    return sum([*map(int, bools)]) > 0
+
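The two helpers above are small boolean conveniences used later in the diff: `xnor` checks that two flags agree (both set or both unset), and `has_at_least_one` is an any-style check over mixed truthy values. A quick illustration of the intended truth tables:

def xnor(x, y):
    return not (x ^ y)

def has_at_least_one(*bools):
    return sum([*map(int, bools)]) > 0

assert xnor(True, True) and xnor(False, False)   # flags agree
assert not xnor(True, False)                     # flags disagree
assert has_at_least_one(False, True)
assert not has_at_least_one(False, False)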
@@ -1731,6 +1735,7 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yields better results, rather than velocity (x-space vs v-space)
@@ -1816,6 +1821,19 @@ class DynamicsWorldModel(Module):
             Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
         )
 
+        # proprioception
+
+        self.has_proprio = exists(dim_proprio)
+        self.dim_proprio = dim_proprio
+
+        if self.has_proprio:
+            self.to_proprio_token = nn.Linear(dim_proprio, dim)
+
+            self.to_proprio_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, dim_proprio)
+            )
+
         # register tokens
 
         self.num_register_tokens = num_register_tokens
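A minimal sketch of what the two new modules do, with made-up dimensions (`torch.nn.RMSNorm` stands in for the repo's own `RMSNorm`): each proprioceptive state vector becomes one extra token per time step, and the same token is later projected back out as the proprio prediction.

import torch
from torch import nn

dim_proprio, dim = 21, 512

to_proprio_token = nn.Linear(dim_proprio, dim)
to_proprio_pred = nn.Sequential(nn.RMSNorm(dim), nn.Linear(dim, dim_proprio))

proprio = torch.randn(2, 10, dim_proprio)       # (b t dp)
proprio_token = to_proprio_token(proprio)       # (b t d) - one token per frame
pred_proprio = to_proprio_pred(proprio_token)   # (b t dp) - back to proprio space

assert pred_proprio.shape == proprio.shape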
@@ -2221,7 +2239,7 @@ class DynamicsWorldModel(Module):
             entropy_loss * self.policy_entropy_weight
         )
 
-        # maye take policy optimizer step
+        # maybe take policy optimizer step
 
         if exists(policy_optim):
             total_policy_loss.backward()
@@ -2271,10 +2289,11 @@ class DynamicsWorldModel(Module):
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
-        return_time_kv_cache = False
+        return_time_kv_cache = False,
+
     ): # (b t n d) | (b c t h w)
 
+        has_proprio = self.has_proprio
         was_training = self.training
         self.eval()
 
@@ -2301,7 +2320,14 @@ class DynamicsWorldModel(Module):
 
         latents = empty((batch_size, 0, *latent_shape), device = self.device)
 
-        past_context_noise = latents.clone()
+        past_latents_context_noise = latents.clone()
+
+        # maybe internal state
+
+        if has_proprio:
+            proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
+
+            past_proprio_context_noise = proprio.clone()
 
         # maybe return actions
 
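The zero-length `empty` tensors are autoregressive accumulators: the time axis starts at length 0 and each generated frame is concatenated on. A toy version of the pattern:

import torch

frames = torch.empty(2, 0, 8)          # (batch, time = 0, dim)

for _ in range(4):
    new_frame = torch.randn(2, 1, 8)   # one denoised frame per outer loop
    frames = torch.cat((frames, new_frame), dim = 1)

assert frames.shape == (2, 4, 8)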
@@ -2327,17 +2353,35 @@ class DynamicsWorldModel(Module):
         while latents.shape[1] < time_steps:
 
             curr_time_steps = latents.shape[1]
 
             noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
 
+            noised_proprio = None
+
+            if has_proprio:
+                noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
+
             for step in range(num_steps):
                 is_last_step = (step + 1) == num_steps
 
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
 
-                noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
+                # noising past latent context
+
+                noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
 
                 noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
 
+                # handle proprio
+
+                noised_proprio_with_context = None
+
+                if has_proprio:
+                    noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
+                    noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
+
+                # proper signal levels
+
                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
                 pred, (agent_embed, next_time_kv_cache) = self.forward(
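The `lerp` calls re-noise already-generated context frames with a small, fixed amount of noise before feeding them back in (the paragraph after eq. (8) of the Dreamer 4 paper). Since `a.lerp(b, w) = a + w * (b - a)`, a small `context_signal_noise` keeps the context mostly signal, and the per-frame noise tensors are sampled once and reused so the context stays temporally consistent. A sketch of the arithmetic:

import torch

latents = torch.randn(1, 4, 8)                  # clean past frames
past_noise = torch.randn_like(latents)          # fixed per-frame noise, reused every step
context_signal_noise = 0.1

noised_context = latents.lerp(past_noise, context_signal_noise)

# equivalently: 90% signal, 10% of the same fixed noise
alt = (1 - context_signal_noise) * latents + context_signal_noise * past_noise
assert torch.allclose(noised_context, alt, atol = 1e-6)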
@@ -2348,6 +2392,7 @@ class DynamicsWorldModel(Module):
                     tasks = tasks,
                     discrete_actions = decoded_discrete_actions,
                     continuous_actions = decoded_continuous_actions,
+                    proprio = noised_proprio_with_context,
                     time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
                     return_pred_only = True,
@@ -2357,24 +2402,44 @@ class DynamicsWorldModel(Module):
                 if use_time_kv_cache and is_last_step:
                     time_kv_cache = next_time_kv_cache
 
+                # maybe proprio
+
+                if has_proprio:
+                    pred, pred_proprio = pred
+
+                # unpack pred
+
                 _, pred = unpack(pred, pack_context_shape, 'b * n d')
 
+                if has_proprio:
+                    _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
+
                 # derive flow, based on whether in x-space or not
 
-                if self.pred_orig_latent:
-                    times = self.get_times_from_signal_level(signal_levels, noised_latent)
-                    flow = (pred - noised_latent) / (1. - times)
-                else:
-                    flow = pred
+                def denoise_step(pred, noised, signal_levels):
+                    if self.pred_orig_latent:
+                        times = self.get_times_from_signal_level(signal_levels)
+                        aligned_times = align_dims_left(times, noised)
+
+                        flow = (pred - noised) / (1. - aligned_times)
+                    else:
+                        flow = pred
+
+                    return flow * (step_size / self.max_steps)
 
                 # denoise
 
-                noised_latent += flow * (step_size / self.max_steps)
+                noised_latent += denoise_step(pred, noised_latent, signal_levels)
+
+                if has_proprio:
+                    noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)
 
             denoised_latent = noised_latent # it is now denoised
 
+            if has_proprio:
+                denoised_proprio = noised_proprio
+
             # take care of the rewards by predicting on the agent token embedding on the last denoising step
 
             if return_rewards_per_frame:
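The new `denoise_step` closure, shared by latents and proprio, folds the x-space conversion and the Euler update together: when the network predicts the clean data x0 rather than velocity, the implied flow at time t is (pred - noised) / (1 - t). A toy check of that identity on the linear noising path, assuming a perfect x0 prediction:

import torch

t, dt = 0.25, 0.125
x0 = torch.randn(8)            # pretend the model predicted the clean data exactly
noise = torch.randn(8)
xt = noise.lerp(x0, t)         # (1 - t) * noise + t * x0

flow = (x0 - xt) / (1. - t)    # x-space prediction converted to a velocity
xt_next = xt + flow * dt       # one Euler step

# the step lands exactly back on the noising path at time t + dt
assert torch.allclose(xt_next, noise.lerp(x0, t + dt), atol = 1e-6)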
@@ -2421,7 +2486,14 @@ class DynamicsWorldModel(Module):
 
             # add new fixed context noise for the temporal consistency
 
-            past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
+            past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
+
+            # handle proprio
+
+            if has_proprio:
+                proprio = cat((proprio, denoised_proprio), dim = 1)
+
+                past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)
 
         # restore state
 
@@ -2443,7 +2515,7 @@ class DynamicsWorldModel(Module):
 
         # only return video or latent if not requesting anything else, for first stage training
 
-        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
             out = video if return_decoded_video else latents
 
             if not return_time_kv_cache:
@@ -2456,6 +2528,7 @@ class DynamicsWorldModel(Module):
         gen = Experience(
             latents = latents,
             video = video,
+            proprio = proprio if has_proprio else None,
             step_size = step_size,
             agent_index = agent_index,
             is_from_world_model = True
@@ -2492,6 +2565,7 @@ class DynamicsWorldModel(Module):
         continuous_actions = None, # (b t na) | (b t-1 na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
+        proprio = None, # (b t dp)
         time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
@@ -2587,16 +2661,17 @@ class DynamicsWorldModel(Module):
 
         # times is from 0 to 1
 
-        times = self.get_times_from_signal_level(signal_levels, latents)
+        times = self.get_times_from_signal_level(signal_levels)
 
         if not latent_is_noised:
             # get the noise
 
             noise = randn_like(latents)
+            aligned_times = align_dims_left(times, latents)
 
             # noise from 0 as noise to 1 as data
 
-            noised_latents = noise.lerp(latents, times)
+            noised_latents = noise.lerp(latents, aligned_times)
 
         else:
             noised_latents = latents
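`align_dims_left` is the repo's broadcasting helper; the diff now calls `get_times_from_signal_level` without the reference tensor and aligns dims explicitly. A sketch of the assumed semantics (append trailing singleton dims so per-frame times broadcast over token and feature axes) - an illustration, not the library's exact code:

import torch

def align_dims_left(t, target):
    # pad with trailing singleton dims until t broadcasts against target
    return t.reshape(*t.shape, *((1,) * (target.ndim - t.ndim)))

times = torch.rand(2, 10)               # (b t), one time per frame (diffusion forcing)
latents = torch.randn(2, 10, 16, 32)    # (b t n d)

noise = torch.randn_like(latents)
noised = noise.lerp(latents, align_dims_left(times, latents))
assert noised.shape == latents.shape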
@@ -2644,6 +2719,27 @@ class DynamicsWorldModel(Module):
 
             reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
 
+        # maybe proprioception
+
+        assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
+
+        noised_proprio = None
+
+        if self.has_proprio:
+
+            if not latent_is_noised:
+                # get the noise
+
+                proprio_noise = randn_like(proprio)
+                aligned_times = align_dims_left(times, proprio)
+
+                # noise from 0 as noise to 1 as data
+
+                noised_proprio = proprio_noise.lerp(proprio, aligned_times)
+
+            else:
+                noised_proprio = proprio
+
         # maybe create the action tokens
 
         if exists(discrete_actions) or exists(continuous_actions):
@@ -2672,7 +2768,7 @@ class DynamicsWorldModel(Module):
 
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2694,6 +2790,13 @@ class DynamicsWorldModel(Module):
 
             registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
 
+            # maybe proprio
+
+            if exists(noised_proprio):
+                proprio_token = self.to_proprio_token(noised_proprio)
+            else:
+                proprio_token = registers[:, :, 0:0]
+
             # determine signal + step size embed for their diffusion forcing + shortcut
 
             signal_embed = self.signal_levels_embed(signal_levels)
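The `registers[:, :, 0:0]` fallback keeps the token layout uniform whether or not proprioception is enabled: slicing a zero-length axis yields a (b, t, 0, d) tensor that contributes zero tokens to the einops pack, while a real proprio token of shape (b, t, d) contributes exactly one. A sketch of the trick:

import torch
from einops import pack

b, t, d = 2, 10, 512
space_tokens = torch.randn(b, t, 64, d)
registers = torch.randn(b, t, 8, d)

no_proprio = registers[:, :, 0:0]       # (b, t, 0, d) - packs as zero tokens
one_proprio = torch.randn(b, t, d)      # (b, t, d) - packs as a single token

tokens, _ = pack([space_tokens, no_proprio, registers], 'b t * d')
assert tokens.shape == (b, t, 64 + 0 + 8, d)

tokens, _ = pack([space_tokens, one_proprio, registers], 'b t * d')
assert tokens.shape == (b, t, 64 + 1 + 8, d)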
@@ -2706,7 +2809,7 @@ class DynamicsWorldModel(Module):
 
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
 
             # attention
 
@@ -2714,7 +2817,7 @@ class DynamicsWorldModel(Module):
 
             # unpack
 
-            flow_token, space_tokens, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
 
@@ -2722,6 +2825,15 @@ class DynamicsWorldModel(Module):
 
             pred = self.to_latent_pred(space_tokens)
 
+            # maybe proprio
+
+            if self.has_proprio:
+                pred_proprio = self.to_proprio_pred(proprio_token)
+
+                pred = (pred, pred_proprio)
+
             # returning
 
             if not return_agent_tokens:
                 return pred
 
@@ -2736,7 +2848,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
 
         if return_pred_only:
             if not return_intermediates:
@@ -2744,6 +2856,40 @@ class DynamicsWorldModel(Module):
 
             return pred, (encoded_agent_tokens, next_time_kv_cache)
 
+        # pack the predictions to calculate flow for different modalities all at once
+
+        if self.has_proprio:
+            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
+
+            noised, _ = pack((noised_latents, noised_proprio), 'b t *')
+            data, _ = pack((latents, proprio), 'b t *')
+            noise, _ = pack((noise, proprio_noise), 'b t *')
+        else:
+            noised = noised_latents
+            data = latents
+
+        # wrapper function for maybe unpacking and packing modalities for doing flow math in unison
+
+        def maybe_pack_unpack(fn):
+            @wraps(fn)
+            @torch.no_grad()
+            def inner(noised, *args, **kwargs):
+
+                noised_proprio = None
+
+                if self.has_proprio:
+                    noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
+
+                pred = fn(noised, noised_proprio, *args, **kwargs)
+
+                if self.has_proprio:
+                    pred, _ = pack(pred, 'b t *')
+
+                return pred
+            return inner
+
+        wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
+
         # determine the target for the loss
 
         pred_target = None
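`maybe_pack_unpack` lets the shortcut math below treat latents and proprio as one flattened tensor: both modalities are packed along the last axis before each no-grad network call and unpacked again inside it. The round trip it relies on:

import torch
from einops import pack, unpack

noised_latents = torch.randn(2, 10, 16, 32)   # (b t n d)
noised_proprio = torch.randn(2, 10, 21)       # (b t dp)

noised, shape = pack((noised_latents, noised_proprio), 'b t *')
assert noised.shape == (2, 10, 16 * 32 + 21)  # flattened and concatenated per frame

lat, prop = unpack(noised, shape, 'b t *')
assert lat.shape == (2, 10, 16, 32) and prop.shape == (2, 10, 21)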
@@ -2759,46 +2905,45 @@ class DynamicsWorldModel(Module):
             # x-space as in paper is in else clause
 
             if is_v_space_pred:
-                pred_target = flow = latents - noise
+                pred_target = flow = data - noise
             else:
-                pred_target = latents
+                pred_target = data
         else:
             # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
 
             # basically a consistency loss where you ensure quantity of two half steps equals one step
             # dreamer then makes it work for x-space with some math
 
-            get_prediction_no_grad = torch.no_grad()(_get_prediction)
-
             step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
             half_step_size = 2 ** step_sizes_log2_minus_one
 
-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
+            first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)
 
             # first derive b'
 
             if is_v_space_pred:
                 first_step_pred_flow = first_step_pred
             else:
-                first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
-                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
+                first_times = self.get_times_from_signal_level(signal_levels, noised)
+
+                first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)
 
             # take a half step
 
-            half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
+            half_step_size_align_left = align_dims_left(half_step_size, noised)
 
-            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
+            denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
 
             # get second prediction for b''
 
             signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
+            second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)
 
             if is_v_space_pred:
                 second_step_pred_flow = second_step_pred
             else:
-                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
-                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
+                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
+                second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)
 
             # pred target is sg(b' + b'') / 2
 
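For reference, the shortcut objective being assembled here (Frans et al.): one step of size d must match two composed steps of size d/2, so the training target for the d-sized prediction is sg((b' + b'') / 2), where b' is the flow at the current signal level and b'' the flow after a half step. Averaging the two half-step flows gives exactly the velocity whose single full step lands at the two-half-step endpoint:

import torch

x, d = torch.randn(8), 0.25

b1 = torch.randn(8)                 # b': flow predicted at x for step size d/2
x_half = x + b1 * (d / 2)
b2 = torch.randn(8)                 # b'': flow predicted at x_half for step size d/2

target = (b1 + b2) / 2              # sg((b' + b'') / 2)
assert torch.allclose(x + target * d, x_half + b2 * (d / 2), atol = 1e-6)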
@@ -2807,7 +2952,7 @@ class DynamicsWorldModel(Module):
             # need to convert x-space to v-space
 
             if is_x_space:
-                pred = (pred - noised_latents) / (1. - first_times)
+                pred = (pred - noised) / (1. - first_times)
                 maybe_shortcut_loss_weight = (1. - first_times) ** 2
 
         # mse loss
@@ -2820,6 +2965,8 @@ class DynamicsWorldModel(Module):
 
         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
+            loss_weight = align_dims_left(loss_weight, flow_losses)
+
             flow_losses = flow_losses * loss_weight
 
         flow_loss = flow_losses.mean()
 
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.69"
+version = "0.0.70"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
 
@@ -2,6 +2,9 @@ import pytest
 param = pytest.mark.parametrize
 import torch
 
+def exists(v):
+    return v is not None
+
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
@@ -664,3 +667,55 @@ def test_online_rl(
     )
 
     trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
+
+def test_proprioception():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 2,
+        decoder_depth = 2,
+        time_block_every = 2,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        dim_proprio = 21,
+        num_tasks = 4,
+        num_discrete_actions = 4,
+        num_residual_streams = 1
+    )
+
+    video = torch.randn(2, 3, 10, 256, 256)
+    rewards = torch.randn(2, 10)
+    proprio = torch.randn(2, 10, 21)
+    discrete_actions = torch.randint(0, 4, (2, 10, 1))
+    tasks = torch.randint(0, 4, (2,))
+
+    loss = dynamics(
+        video = video,
+        rewards = rewards,
+        tasks = tasks,
+        proprio = proprio,
+        discrete_actions = discrete_actions
+    )
+
+    loss.backward()
+
+    generations = dynamics.generate(
+        4,
+        batch_size = 2,
+    )
+
+    assert exists(generations.proprio)