incorporate proprioception into the dynamics world model
commit a9b728c611 (parent 35c1db4c7d)
@@ -5,7 +5,7 @@ from math import ceil, log2
 from random import random
 from contextlib import nullcontext
 from collections import namedtuple
-from functools import partial
+from functools import partial, wraps
 from dataclasses import dataclass, asdict
 
 import torch
@@ -75,6 +75,7 @@ WorldModelLosses = namedtuple('WorldModelLosses', ('flow', 'rewards', 'discrete_
 class Experience:
     latents: Tensor
     video: Tensor | None = None
+    proprio: Tensor | None = None
     rewards: Tensor | None = None
     actions: tuple[Tensor, Tensor] | None = None
     log_probs: tuple[Tensor, Tensor] | None = None
@@ -130,6 +131,9 @@ def default(v, d):
 def first(arr):
     return arr[0]
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def has_at_least_one(*bools):
     return sum([*map(int, bools)]) > 0
 
@@ -1731,6 +1735,7 @@ class DynamicsWorldModel(Module):
         num_latent_tokens = None,
         num_agents = 1,
         num_tasks = 0,
+        dim_proprio = None,
         reward_encoder_kwargs: dict = dict(),
         depth = 4,
         pred_orig_latent = True, # directly predicting the original x0 data yield better results, rather than velocity (x-space vs v-space)
@@ -1816,6 +1821,19 @@ class DynamicsWorldModel(Module):
             Rearrange('b t s (n d) -> b t (s n) d', n = latent_tokens_to_space)
         )
 
+        # proprioception
+
+        self.has_proprio = exists(dim_proprio)
+        self.dim_proprio = dim_proprio
+
+        if self.has_proprio:
+            self.to_proprio_token = nn.Linear(dim_proprio, dim)
+
+            self.to_proprio_pred = Sequential(
+                RMSNorm(dim),
+                nn.Linear(dim, dim_proprio)
+            )
+
         # register tokens
 
         self.num_register_tokens = num_register_tokens
@@ -2221,7 +2239,7 @@ class DynamicsWorldModel(Module):
             entropy_loss * self.policy_entropy_weight
         )
 
-        # maye take policy optimizer step
+        # maybe take policy optimizer step
 
         if exists(policy_optim):
             total_policy_loss.backward()
@@ -2271,10 +2289,11 @@ class DynamicsWorldModel(Module):
         return_rewards_per_frame = False,
         return_agent_actions = False,
         return_log_probs_and_values = False,
-        return_time_kv_cache = False
+        return_time_kv_cache = False,
 
     ): # (b t n d) | (b c t h w)
 
+        has_proprio = self.has_proprio
         was_training = self.training
         self.eval()
 
@@ -2301,7 +2320,14 @@ class DynamicsWorldModel(Module):
 
         latents = empty((batch_size, 0, *latent_shape), device = self.device)
 
-        past_context_noise = latents.clone()
+        past_latents_context_noise = latents.clone()
 
+        # maybe internal state
+
+        if has_proprio:
+            proprio = empty((batch_size, 0, self.dim_proprio), device = self.device)
+
+            past_proprio_context_noise = proprio.clone()
+
         # maybe return actions
 
@@ -2327,17 +2353,35 @@ class DynamicsWorldModel(Module):
         while latents.shape[1] < time_steps:
 
             curr_time_steps = latents.shape[1]
 
             noised_latent = randn((batch_size, 1, *latent_shape), device = self.device)
 
+            noised_proprio = None
+
+            if has_proprio:
+                noised_proprio = randn((batch_size, 1, self.dim_proprio), device = self.device)
+
             for step in range(num_steps):
                 is_last_step = (step + 1) == num_steps
 
                 signal_levels = full((batch_size, 1), step * step_size, dtype = torch.long, device = self.device)
 
-                noised_context = latents.lerp(past_context_noise, context_signal_noise) # the paragraph after eq (8)
+                # noising past latent context
+
+                noised_context = latents.lerp(past_latents_context_noise, context_signal_noise) # the paragraph after eq (8)
 
                 noised_latent_with_context, pack_context_shape = pack((noised_context, noised_latent), 'b * n d')
 
+                # handle proprio
+
+                noised_proprio_with_context = None
+
+                if has_proprio:
+                    noised_proprio_context = proprio.lerp(past_proprio_context_noise, context_signal_noise)
+                    noised_proprio_with_context, _ = pack((noised_proprio_context, noised_proprio), 'b * d')
+
+                # proper signal levels
+
                 signal_levels_with_context = F.pad(signal_levels, (curr_time_steps, 0), value = self.max_steps - 1)
 
                 pred, (agent_embed, next_time_kv_cache) = self.forward(
@@ -2348,6 +2392,7 @@ class DynamicsWorldModel(Module):
                     tasks = tasks,
                     discrete_actions = decoded_discrete_actions,
                     continuous_actions = decoded_continuous_actions,
+                    proprio = noised_proprio_with_context,
                     time_kv_cache = time_kv_cache,
                     latent_is_noised = True,
                     return_pred_only = True,
@@ -2357,24 +2402,44 @@ class DynamicsWorldModel(Module):
                 if use_time_kv_cache and is_last_step:
                     time_kv_cache = next_time_kv_cache
 
+                # maybe proprio
+
+                if has_proprio:
+                    pred, pred_proprio = pred
+
                 # unpack pred
 
                 _, pred = unpack(pred, pack_context_shape, 'b * n d')
 
+                if has_proprio:
+                    _, pred_proprio = unpack(pred_proprio, pack_context_shape, 'b * d')
+
                 # derive flow, based on whether in x-space or not
 
-                if self.pred_orig_latent:
-                    times = self.get_times_from_signal_level(signal_levels, noised_latent)
-                    flow = (pred - noised_latent) / (1. - times)
-                else:
-                    flow = pred
+                def denoise_step(pred, noised, signal_levels):
+                    if self.pred_orig_latent:
+                        times = self.get_times_from_signal_level(signal_levels)
+                        aligned_times = align_dims_left(times, noised)
+
+                        flow = (pred - noised) / (1. - aligned_times)
+                    else:
+                        flow = pred
+
+                    return flow * (step_size / self.max_steps)
 
                 # denoise
 
-                noised_latent += flow * (step_size / self.max_steps)
+                noised_latent += denoise_step(pred, noised_latent, signal_levels)
+
+                if has_proprio:
+                    noised_proprio += denoise_step(pred_proprio, noised_proprio, signal_levels)
 
             denoised_latent = noised_latent # it is now denoised
 
+            if has_proprio:
+                denoised_proprio = noised_proprio
+
             # take care of the rewards by predicting on the agent token embedding on the last denoising step
 
             if return_rewards_per_frame:
@@ -2421,7 +2486,14 @@ class DynamicsWorldModel(Module):
 
             # add new fixed context noise for the temporal consistency
 
-            past_context_noise = cat((past_context_noise, randn_like(denoised_latent)), dim = 1)
+            past_latents_context_noise = cat((past_latents_context_noise, randn_like(denoised_latent)), dim = 1)
 
+            # handle proprio
+
+            if has_proprio:
+                proprio = cat((proprio, denoised_proprio), dim = 1)
+
+                past_proprio_context_noise = cat((past_proprio_context_noise, randn_like(denoised_proprio)), dim = 1)
+
         # restore state
 
@@ -2443,7 +2515,7 @@ class DynamicsWorldModel(Module):
 
         # only return video or latent if not requesting anything else, for first stage training
 
-        if not has_at_least_one(return_rewards_per_frame, return_agent_actions):
+        if not has_at_least_one(return_rewards_per_frame, return_agent_actions, has_proprio):
             out = video if return_decoded_video else latents
 
             if not return_time_kv_cache:
@@ -2456,6 +2528,7 @@ class DynamicsWorldModel(Module):
         gen = Experience(
             latents = latents,
             video = video,
+            proprio = proprio if has_proprio else None,
             step_size = step_size,
             agent_index = agent_index,
             is_from_world_model = True
@@ -2492,6 +2565,7 @@ class DynamicsWorldModel(Module):
         continuous_actions = None, # (b t na) | (b t-1 na)
         discrete_action_types = None, # (na)
         continuous_action_types = None, # (na)
+        proprio = None, # (b t dp)
         time_kv_cache = None,
         return_pred_only = False,
         latent_is_noised = False,
@@ -2587,16 +2661,17 @@ class DynamicsWorldModel(Module):
 
         # times is from 0 to 1
 
-        times = self.get_times_from_signal_level(signal_levels, latents)
+        times = self.get_times_from_signal_level(signal_levels)
 
         if not latent_is_noised:
             # get the noise
 
             noise = randn_like(latents)
+            aligned_times = align_dims_left(times, latents)
 
             # noise from 0 as noise to 1 as data
 
-            noised_latents = noise.lerp(latents, times)
+            noised_latents = noise.lerp(latents, aligned_times)
 
         else:
             noised_latents = latents
@@ -2644,6 +2719,27 @@ class DynamicsWorldModel(Module):
 
             reward_tokens = add('1 d, b t d', self.reward_learned_embed, reward_tokens)
 
+        # maybe proprioception
+
+        assert xnor(self.has_proprio, exists(proprio)), 'proprio must be passed in if `dim_proprio` is set and vice versa'
+
+        noised_proprio = None
+
+        if self.has_proprio:
+
+            if not latent_is_noised:
+                # get the noise
+
+                proprio_noise = randn_like(proprio)
+                aligned_times = align_dims_left(times, proprio)
+
+                # noise from 0 as noise to 1 as data
+
+                noised_proprio = proprio_noise.lerp(proprio, aligned_times)
+
+            else:
+                noised_proprio = proprio
+
         # maybe create the action tokens
 
         if exists(discrete_actions) or exists(continuous_actions):
@@ -2672,7 +2768,7 @@ class DynamicsWorldModel(Module):
 
         # main function, needs to be defined as such for shortcut training - additional calls for consistency loss
 
-        def get_prediction(noised_latents, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
+        def get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, action_tokens, reward_tokens, agent_tokens, return_agent_tokens = False, return_time_kv_cache = False):
             # latents to spatial tokens
 
             space_tokens = self.latents_to_spatial_tokens(noised_latents)
@@ -2694,6 +2790,13 @@ class DynamicsWorldModel(Module):
 
             registers = repeat(self.register_tokens, 's d -> b t s d', b = batch, t = time)
 
+            # maybe proprio
+
+            if exists(noised_proprio):
+                proprio_token = self.to_proprio_token(noised_proprio)
+            else:
+                proprio_token = registers[:, :, 0:0]
+
             # determine signal + step size embed for their diffusion forcing + shortcut
 
             signal_embed = self.signal_levels_embed(signal_levels)
@@ -2706,7 +2809,7 @@ class DynamicsWorldModel(Module):
 
             # pack to tokens for attending
 
-            tokens, packed_tokens_shape = pack([flow_token, space_tokens, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
+            tokens, packed_tokens_shape = pack([flow_token, space_tokens, proprio_token, registers, action_tokens, reward_tokens, agent_tokens], 'b t * d')
 
             # attention
 
@@ -2714,7 +2817,7 @@ class DynamicsWorldModel(Module):
 
             # unpack
 
-            flow_token, space_tokens, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
+            flow_token, space_tokens, proprio_token, register_tokens, action_tokens, reward_tokens, agent_tokens = unpack(tokens, packed_tokens_shape, 'b t * d')
 
             # pooling
 
@@ -2722,6 +2825,15 @@ class DynamicsWorldModel(Module):
 
             pred = self.to_latent_pred(space_tokens)
 
+            # maybe proprio
+
+            if self.has_proprio:
+                pred_proprio = self.to_proprio_pred(proprio_token)
+
+                pred = (pred, pred_proprio)
+
+            # returning
+
             if not return_agent_tokens:
                 return pred
 
@@ -2736,7 +2848,7 @@ class DynamicsWorldModel(Module):
 
         # forward the network
 
-        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
+        pred, (encoded_agent_tokens, next_time_kv_cache) = _get_prediction(noised_latents, noised_proprio, signal_levels, step_sizes_log2, return_agent_tokens = True, return_time_kv_cache = True)
 
         if return_pred_only:
             if not return_intermediates:
@@ -2744,6 +2856,40 @@ class DynamicsWorldModel(Module):
 
             return pred, (encoded_agent_tokens, next_time_kv_cache)
 
+        # pack the predictions to calculate flow for different modalities all at once
+
+        if self.has_proprio:
+            pred, for_flow_loss_packed_shape = pack(pred, 'b t *')
+
+            noised, _ = pack((noised_latents, noised_proprio), 'b t *')
+            data, _ = pack((latents, proprio), 'b t *')
+            noise, _ = pack((noise, proprio_noise), 'b t *')
+        else:
+            noised = noised_latents
+            data = latents
+
+        # wrapper function for maybe unpacking and packing modalities for doing flow math in unison
+
+        def maybe_pack_unpack(fn):
+            @wraps(fn)
+            @torch.no_grad()
+            def inner(noised, *args, **kwargs):
+
+                noised_proprio = None
+
+                if self.has_proprio:
+                    noised, noised_proprio = unpack(noised, for_flow_loss_packed_shape, 'b t *')
+
+                pred = fn(noised, noised_proprio, *args, **kwargs)
+
+                if self.has_proprio:
+                    pred, _ = pack(pred, 'b t *')
+
+                return pred
+            return inner
+
+        wrapped_get_prediction = maybe_pack_unpack(_get_prediction)
+
         # determine the target for the loss
 
         pred_target = None
@@ -2759,46 +2905,45 @@ class DynamicsWorldModel(Module):
             # x-space as in paper is in else clause
 
             if is_v_space_pred:
-                pred_target = flow = latents - noise
+                pred_target = flow = data - noise
             else:
-                pred_target = latents
+                pred_target = data
         else:
             # shortcut training - Frans et al. https://arxiv.org/abs/2410.12557
 
             # basically a consistency loss where you ensure quantity of two half steps equals one step
             # dreamer then makes it works for x-space with some math
 
-            get_prediction_no_grad = torch.no_grad()(_get_prediction)
 
            step_sizes_log2_minus_one = step_sizes_log2 - 1 # which equals d / 2
            half_step_size = 2 ** step_sizes_log2_minus_one
 
-            first_step_pred = get_prediction_no_grad(noised_latents, signal_levels, step_sizes_log2_minus_one)
+            first_step_pred = wrapped_get_prediction(noised, signal_levels, step_sizes_log2_minus_one)
 
            # first derive b'
 
            if is_v_space_pred:
                first_step_pred_flow = first_step_pred
            else:
-                first_times = self.get_times_from_signal_level(signal_levels, noised_latents)
-                first_step_pred_flow = (first_step_pred - noised_latents) / (1. - first_times)
+                first_times = self.get_times_from_signal_level(signal_levels, noised)
+                first_step_pred_flow = (first_step_pred - noised) / (1. - first_times)
 
            # take a half step
 
-            half_step_size_align_left = align_dims_left(half_step_size, noised_latents)
+            half_step_size_align_left = align_dims_left(half_step_size, noised)
 
-            denoised_latent = noised_latents + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
+            denoised = noised + first_step_pred_flow * (half_step_size_align_left / self.max_steps)
 
            # get second prediction for b''
 
            signal_levels_plus_half_step = signal_levels + half_step_size[:, None]
-            second_step_pred = get_prediction_no_grad(denoised_latent, signal_levels_plus_half_step, step_sizes_log2_minus_one)
+            second_step_pred = wrapped_get_prediction(denoised, signal_levels_plus_half_step, step_sizes_log2_minus_one)
 
            if is_v_space_pred:
                second_step_pred_flow = second_step_pred
            else:
-                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised_latent)
-                second_step_pred_flow = (second_step_pred - denoised_latent) / (1. - second_times)
+                second_times = self.get_times_from_signal_level(signal_levels_plus_half_step, denoised)
+                second_step_pred_flow = (second_step_pred - denoised) / (1. - second_times)
 
            # pred target is sg(b' + b'') / 2
 
@@ -2807,7 +2952,7 @@ class DynamicsWorldModel(Module):
             # need to convert x-space to v-space
 
             if is_x_space:
-                pred = (pred - noised_latents) / (1. - first_times)
+                pred = (pred - noised) / (1. - first_times)
                 maybe_shortcut_loss_weight = (1. - first_times) ** 2
 
         # mse loss
@@ -2820,6 +2965,8 @@ class DynamicsWorldModel(Module):
 
         if exists(self.loss_weight_fn):
             loss_weight = self.loss_weight_fn(times)
+            loss_weight = align_dims_left(loss_weight, flow_losses)
+
             flow_losses = flow_losses * loss_weight
 
         flow_loss = flow_losses.mean()
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.69"
+version = "0.0.70"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -2,6 +2,9 @@ import pytest
 param = pytest.mark.parametrize
 import torch
 
+def exists(v):
+    return v is not None
+
 @param('pred_orig_latent', (False, True))
 @param('grouped_query_attn', (False, True))
 @param('dynamics_with_video_input', (False, True))
@@ -664,3 +667,55 @@ def test_online_rl(
     )
 
     trainer(mock_env, num_episodes = 2, env_is_vectorized = vectorized)
+
+def test_proprioception():
+    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
+
+    tokenizer = VideoTokenizer(
+        512,
+        dim_latent = 32,
+        patch_size = 32,
+        encoder_depth = 2,
+        decoder_depth = 2,
+        time_block_every = 2,
+        attn_heads = 8,
+        image_height = 256,
+        image_width = 256,
+        attn_kwargs = dict(
+            query_heads = 16
+        )
+    )
+
+    dynamics = DynamicsWorldModel(
+        512,
+        num_agents = 1,
+        video_tokenizer = tokenizer,
+        dim_latent = 32,
+        dim_proprio = 21,
+        num_tasks = 4,
+        num_discrete_actions = 4,
+        num_residual_streams = 1
+    )
+
+    video = torch.randn(2, 3, 10, 256, 256)
+    rewards = torch.randn(2, 10)
+    proprio = torch.randn(2, 10, 21)
+    discrete_actions = torch.randint(0, 4, (2, 10, 1))
+    tasks = torch.randint(0, 4, (2,))
+
+    loss = dynamics(
+        video = video,
+        rewards = rewards,
+        tasks = tasks,
+        proprio = proprio,
+        discrete_actions = discrete_actions
+    )
+
+    loss.backward()
+
+    generations = dynamics.generate(
+        4,
+        batch_size = 2,
+    )
+
+    assert exists(generations.proprio)
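For quick reference, below is a minimal usage sketch of the new proprioception pathway, condensed from the test_proprioception test added above; the specific dimensions (dim_proprio = 21, 256 x 256 frames, 10 time steps) are illustrative only.

import torch
from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

tokenizer = VideoTokenizer(
    512,
    dim_latent = 32,
    patch_size = 32,
    encoder_depth = 2,
    decoder_depth = 2,
    time_block_every = 2,
    attn_heads = 8,
    image_height = 256,
    image_width = 256,
    attn_kwargs = dict(query_heads = 16)
)

# passing dim_proprio enables the proprioception pathway in the dynamics world model

dynamics = DynamicsWorldModel(
    512,
    num_agents = 1,
    video_tokenizer = tokenizer,
    dim_latent = 32,
    dim_proprio = 21,
    num_tasks = 4,
    num_discrete_actions = 4,
    num_residual_streams = 1
)

# training: proprio has shape (batch, time, dim_proprio) and is noised / denoised alongside the latents

loss = dynamics(
    video = torch.randn(2, 3, 10, 256, 256),
    rewards = torch.randn(2, 10),
    tasks = torch.randint(0, 4, (2,)),
    proprio = torch.randn(2, 10, 21),
    discrete_actions = torch.randint(0, 4, (2, 10, 1))
)

loss.backward()

# generation: the returned Experience now also carries the denoised proprioception

generations = dynamics.generate(4, batch_size = 2)
assert generations.proprio is not None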