take a step towards variable-length experiences during training
parent 77a40e8701
commit 3d5617d769
@@ -155,6 +155,15 @@ def is_power_two(num):

+def is_empty(t):
+    return t.numel() == 0
+
+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
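For reference, the new lens_to_mask yields a (batch, max_len) boolean mask where entry (i, j) is true iff j < lens[i]; the einx.less call amounts to broadcasting seq < lens[:, None]. A minimal standalone check (pure-torch rewrite, assumed equivalent, not part of the commit):

import torch

def lens_to_mask(lens, max_len = None):
    # pure-torch stand-in for einx.less('j, i -> i j', seq, lens):
    # compares each position j against each length i -> (batch, max_len) bool mask
    if max_len is None:
        max_len = lens.amax().item()
    seq = torch.arange(max_len, device = lens.device)
    return seq[None, :] < lens[:, None]

print(lens_to_mask(torch.tensor([1, 3]), max_len = 4))
# tensor([[ True, False, False, False],
#         [ True,  True,  True, False]])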
@@ -2581,6 +2590,7 @@ class DynamicsWorldModel(Module):
         *,
         video = None,           # (b v? c t vh vw)
         latents = None,         # (b t v? n d) | (b t v? d)
+        lens = None,            # (b)
         signal_levels = None,   # () | (b) | (b t)
         step_sizes = None,      # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -3014,7 +3024,19 @@ class DynamicsWorldModel(Module):

         flow_losses = flow_losses * loss_weight

-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+        else:
+            flow_loss = flow_losses.mean()

         # now take care of the agent token losses
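Worth noting: indexing with the boolean mask before .mean() averages over valid frames only, which is not the same as zeroing padded entries and taking a plain mean. A small illustration (a (b, t) loss tensor assumed for simplicity):

import torch

flow_losses = torch.tensor([[1., 2., 3., 4.],
                            [5., 6., 7., 8.]])
mask = torch.tensor([[True, True, False, False],
                     [True, True, True,  True]])

# boolean indexing keeps only the 6 valid entries before averaging,
# so padded timesteps don't drag the loss toward zero
masked_mean = flow_losses[mask].mean()     # mean over 6 valid entries
naive_mean  = (flow_losses * mask).mean()  # wrong: divides by all 8

print(masked_mean, naive_mean)  # tensor(4.8333) tensor(3.6250)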
@@ -3037,7 +3059,10 @@ class DynamicsWorldModel(Module):

         reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)

-        reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+        if is_var_len:
+            reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+        else:
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss
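Since reward_losses carries a trailing mtp axis, boolean indexing with the (b, t-1) mask flattens only the leading dims, and .mean(dim = 0) still returns one loss per multi-token-prediction step. A shape-only sketch (dimension sizes are placeholders):

import torch

b, t_minus_1, mtp = 2, 3, 4
reward_losses = torch.randn(b, t_minus_1, mtp)
loss_mask_without_last = torch.tensor([[True, False, False],
                                       [True, True,  True]])

# (b, t-1) bool mask over the leading dims selects valid rows -> (num_valid, mtp);
# mean over dim 0 preserves the per-step mtp losses
reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
print(reward_loss.shape)  # torch.Size([4])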
@@ -3080,12 +3105,20 @@ class DynamicsWorldModel(Module):

         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)

-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)

-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')

         # handle loss normalization
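The action losses get the same treatment, except mtp leads in the einops pattern, hence the rearrange to 'b t na mtp' so the sequence mask can index the leading dims. A quick check that the masked route reduces to the fixed-length one under a full mask (shapes are hypothetical):

import torch
from einops import rearrange, reduce

mtp, b, t, na = 2, 2, 3, 4
neg_log_probs = torch.randn(mtp, b, t, na)
full_mask = torch.ones(b, t, dtype = torch.bool)

# move mtp to the back so a (b, t) mask can index the leading dims
losses = rearrange(neg_log_probs, 'mtp b t na -> b t na mtp')
masked = reduce(losses[full_mask], '... mtp -> mtp', 'mean')

# with nothing masked, this matches the original fixed-length reduction
fixed = reduce(neg_log_probs, 'mtp b t na -> mtp', 'mean')
assert torch.allclose(masked, fixed)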
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.71"
+version = "0.0.72"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -95,8 +97,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))

+    lens = None
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
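For callers, lens pairs with right-padded batches, as the test's torch.randint(1, 4, (2,)) suggests: each entry caps how many timesteps count as real experience for that batch element. A generic, hypothetical collate sketch (plain torch, not from the repo) for producing the padded (b, t, ...) layout plus the matching lens tensor:

import torch
import torch.nn.functional as F

# three experiences of lengths 2, 3, 1, each with feature dim 8
seqs = [torch.randn(t, 8) for t in (2, 3, 1)]

# right-pad every sequence to the longest, then stack into one batch
max_t = max(s.shape[0] for s in seqs)
batch = torch.stack([F.pad(s, (0, 0, 0, max_t - s.shape[0])) for s in seqs])
lens = torch.tensor([s.shape[0] for s in seqs])

print(batch.shape, lens)  # torch.Size([3, 3, 8]) tensor([2, 3, 1])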