take a step towards variable-length experiences during training
parent 77a40e8701
commit 3d5617d769
@@ -155,6 +155,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0

+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
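Note (illustration, not part of the diff): the new lens_to_mask helper turns per-sample lengths into a (batch, time) boolean mask that is True on real frames. The einx.less('j, i -> i j', seq, t) call is a broadcasted seq[j] < t[i] compare; a minimal plain-PyTorch sketch of the same thing:

    import torch

    lens = torch.tensor([2, 4])          # per-sample sequence lengths
    seq = torch.arange(5)                # frame indices 0..4

    # equivalent to einx.less('j, i -> i j', seq, lens)
    mask = seq[None, :] < lens[:, None]  # (b, max_len)
    # tensor([[ True,  True, False, False, False],
    #         [ True,  True,  True,  True, False]])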
@@ -2581,6 +2590,7 @@ class DynamicsWorldModel(Module):
         *,
         video = None, # (b v? c t vh vw)
         latents = None, # (b t v? n d) | (b t v? d)
+        lens = None, # (b)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
@@ -3014,6 +3024,18 @@ class DynamicsWorldModel(Module):

         flow_losses = flow_losses * loss_weight

-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()

         # now take care of the agent token losses
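Note: the variable-length branch masks the flow losses with boolean indexing before taking the mean, so padded frames neither contribute zeros nor dilute the average. A toy sketch, assuming flow_losses is laid out as (b, t):

    import torch

    flow_losses = torch.rand(2, 5)                         # assumed (b, t) per-frame losses
    lens = torch.tensor([2, 4])
    loss_mask = torch.arange(5)[None, :] < lens[:, None]   # what lens_to_mask(lens, 5) yields

    flow_loss = flow_losses[loss_mask].mean()              # mean over the 6 valid frames only
    # flow_losses.mean() would instead average over all 10 positions, padding included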
@@ -3037,6 +3059,9 @@ class DynamicsWorldModel(Module):

         reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)

-        reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+        if is_var_len:
+            reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+        else:
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)

         # maybe autoregressive action loss
@@ -3080,11 +3105,19 @@ class DynamicsWorldModel(Module):
         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)

-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')

         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)

-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')

         # handle loss normalization
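Note: the reward and action losses index with loss_mask_without_last, i.e. the mask minus its final column, which fits these predictions covering one fewer time step than the frames. A sketch of the rearrange-then-mask pattern with hypothetical toy shapes, assuming (mtp, b, t-1, na) log probs:

    import torch
    from einops import rearrange, reduce

    mtp, b, t, na = 2, 2, 5, 3
    log_probs = torch.randn(mtp, b, t - 1, na)               # assumed shape per the 'mtp b t na' pattern
    loss_mask_without_last = torch.rand(b, t - 1) > 0.5      # stand-in for the real mask

    losses = rearrange(-log_probs, 'mtp b t na -> b t na mtp')
    masked = losses[loss_mask_without_last]                  # (num_valid, na, mtp): the (b, t-1) mask consumes both leading dims
    action_loss = reduce(masked, '... mtp -> mtp', 'mean')   # keeps one loss per multi-token prediction step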
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.71"
+version = "0.0.72"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -16,6 +16,7 @@ def exists(v):
 @param('num_residual_streams', (1, 4))
 @param('add_reward_embed_to_agent_token', (False, True))
 @param('use_time_kv_cache', (False, True))
+@param('var_len', (False, True))
 def test_e2e(
     pred_orig_latent,
     grouped_query_attn,
@@ -27,7 +28,8 @@ def test_e2e(
     condition_on_actions,
     num_residual_streams,
     add_reward_embed_to_agent_token,
-    use_time_kv_cache
+    use_time_kv_cache,
+    var_len
 ):
     from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel
@@ -95,8 +97,13 @@ def test_e2e(
     if condition_on_actions:
         actions = torch.randint(0, 4, (2, 3, 1))

+    lens = None
+    if var_len:
+        lens = torch.randint(1, 4, (2,))
+
     flow_loss = dynamics(
         **dynamics_input,
+        lens = lens,
         tasks = tasks,
         signal_levels = signal_levels,
         step_sizes_log2 = step_sizes_log2,
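Note: variable lengths remain opt-in end to end; leaving lens = None keeps the previous fixed-length behavior. A usage sketch mirroring the test above (setup and constructor arguments elided):

    import torch

    lens = torch.randint(1, 4, (2,))   # 1 to 3 valid frames per sample in a batch of 2

    flow_loss = dynamics(
        **dynamics_input,
        lens = lens,                   # frames past each sample's length drop out of all losses
        tasks = tasks,
        signal_levels = signal_levels,
        step_sizes_log2 = step_sizes_log2
    )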