addressing issues raised by an independent researcher with LLM assistance
This commit is contained in:
parent 60681fce1d
commit d756d1bb8c
@@ -1331,6 +1331,12 @@ class Attention(Module):
         q = self.q_heads_rmsnorm(q)
         k = self.k_heads_rmsnorm(k)
 
+        # rotary
+
+        if exists(rotary_pos_emb):
+            q = apply_rotations(rotary_pos_emb, q)
+            k = apply_rotations(rotary_pos_emb, k)
+
         # caching
 
         if exists(kv_cache):
@@ -1338,12 +1344,6 @@ class Attention(Module):
             k = cat((ck, k), dim = -2)
             v = cat((cv, v), dim = -2)
 
-        # rotary
-
-        if exists(rotary_pos_emb):
-            q = apply_rotations(rotary_pos_emb, q)
-            k = apply_rotations(rotary_pos_emb, k)
-
         # attention
 
         attend_fn = default(attend_fn, naive_attend)
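Why the rotary block moved above the cache concat: keys already in the cache were rotated at their own absolute positions when first computed, so rotating the concatenated tensor would rotate them a second time. A minimal sketch of the invariant, using a generic GPT-NeoX-style rotary helper rather than the repo's apply_rotations (helper names and dim_head are illustrative):

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(pos, x, dim_head = 64):
    # pos holds the absolute positions of exactly the tokens in x
    inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2).float() / dim_head))
    freqs = torch.einsum('i, j -> i j', pos.float(), inv_freq)
    freqs = torch.cat((freqs, freqs), dim = -1)
    return x * freqs.cos() + rotate_half(x) * freqs.sin()

seq_len, dim_head = 4, 64

# cached keys were rotated at positions 0..3 when originally computed
k_cache = apply_rotary(torch.arange(seq_len), torch.randn(seq_len, dim_head))

# decoding one new token at absolute position 4: rotate first, then concat
k_new = apply_rotary(torch.tensor([seq_len]), torch.randn(1, dim_head))
k = torch.cat((k_cache, k_new), dim = -2)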
@@ -1507,12 +1507,11 @@ class AxialSpaceTimeTransformer(Module):
 
         has_kv_cache = exists(kv_cache)
 
-
         if has_kv_cache:
             past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]
 
             rotary_seq_len = 1
-            rotary_pos_offset = past_tokens.shape[-2]
+            rotary_pos_offset = past_tokens.shape[1]
         else:
             rotary_seq_len = time
             rotary_pos_offset = 0
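The offset fix hinges on the token layout in this axial transformer: assuming tokens are (batch, time, space, dim), shape[-2] reads the spatial size where the rotary offset needs the number of cached timesteps. A toy illustration with invented shapes:

import torch

tokens = torch.randn(2, 10, 64, 512)                  # (batch, time, space, dim)
past_tokens, tokens = tokens[:, :-1], tokens[:, -1:]

print(past_tokens.shape[-2])   # 64 -> spatial size, the wrong offset
print(past_tokens.shape[1])    # 9  -> cached timesteps, the right offset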
@@ -1687,6 +1686,7 @@ class VideoTokenizer(Module):
             time_block_every = time_block_every,
             num_special_spatial_tokens = num_latent_tokens,
             num_residual_streams = num_residual_streams,
+            special_attend_only_itself = True,
             final_norm = True
         )
 
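Reading special_attend_only_itself = True literally, the special (latent) tokens are restricted to attend only to themselves. A hypothetical mask showing that reading; build_mask and its True-may-attend convention are invented here, not the repo's API:

import torch

def build_mask(num_special, seq_len):
    mask = torch.ones(seq_len, seq_len, dtype = torch.bool)   # True = may attend
    mask[:num_special] = torch.eye(seq_len, dtype = torch.bool)[:num_special]
    return mask

mask = build_mask(num_special = 2, seq_len = 5)
# rows 0-1 (special tokens) are one-hot on the diagonal; other rows attend everywhere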
@@ -2456,8 +2456,8 @@ class DynamicsWorldModel(Module):
         if exists(experience.lens):
             mask_for_gae = lens_to_mask(experience.lens, time)
 
-            rewards = rewards.masked_fill(mask_for_gae, 0.)
-            old_values = old_values.masked_fill(mask_for_gae, 0.)
+            rewards = rewards.masked_fill(~mask_for_gae, 0.)
+            old_values = old_values.masked_fill(~mask_for_gae, 0.)
 
         # calculate returns
 
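The sign flip assumes the usual lens_to_mask convention of True at valid timesteps; since masked_fill writes where the mask is True, the positions to zero are the padded ones, ~mask_for_gae. A small sketch with an assumed lens_to_mask:

import torch

def lens_to_mask(lens, max_len):
    # True where the timestep falls inside the sequence length
    return torch.arange(max_len) < lens[:, None]

lens = torch.tensor([2, 4])
mask_for_gae = lens_to_mask(lens, 5)   # [[T,T,F,F,F],[T,T,T,T,F]]

rewards = torch.ones(2, 5)
rewards = rewards.masked_fill(~mask_for_gae, 0.)   # zeroes only the padded tail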
@@ -2492,7 +2492,7 @@ class DynamicsWorldModel(Module):
 
         # mean, var - todo - handle distributed
 
-        returns_mean, returns_var = returns.mean(), returns.var()
+        returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()
 
         # ema
 
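returns_for_stats is defined outside this hunk; a natural reading is that it gathers only the returns at valid timesteps, so the padded zeros stop biasing the normalization statistics. A sketch of that reading (the boolean gather below is assumed, not the repo's code):

import torch

returns = torch.randn(2, 5)
mask = torch.tensor([[1, 1, 0, 0, 0], [1, 1, 1, 1, 0]], dtype = torch.bool)

returns_for_stats = returns[mask]   # 1d tensor of only the valid entries
returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()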
@@ -3510,7 +3510,7 @@ class DynamicsWorldModel(Module):
 
         reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
 
-        reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
+        reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)
 
         if is_var_len:
             reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
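The same inversion applied to the reward head: with reduction = 'none', cross entropy yields a per-position loss that can be zeroed at padded positions before averaging. Shapes below are illustrative only:

import torch
import torch.nn.functional as F

reward_pred = torch.randn(2, 10, 5)        # (batch, num_reward_bins, time)
reward_targets = torch.randint(0, 10, (2, 5))
reward_loss_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype = torch.bool)

reward_losses = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')   # (2, 5)
reward_losses = reward_losses.masked_fill(~reward_loss_mask, 0.)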
@@ -3554,7 +3554,7 @@ class DynamicsWorldModel(Module):
             discrete_mask = rearrange(discrete_mask, 'b t mtp -> mtp b t')
 
         if exists(continuous_actions):
-            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(discrete_actions, self.multi_token_pred_len)
+            continuous_action_targets, continuous_mask = create_multi_token_prediction_targets(continuous_actions, self.multi_token_pred_len)
             continuous_action_targets = rearrange(continuous_action_targets, 'b t mtp ... -> mtp b t ...')
             continuous_mask = rearrange(continuous_mask, 'b t mtp -> mtp b t')
 
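This fix routes continuous_actions, not discrete_actions, into the continuous target builder. The real create_multi_token_prediction_targets lives elsewhere in the file; below is only a plausible sketch of its contract, per-timestep targets for the next pred_len actions plus a validity mask:

import torch

def create_multi_token_prediction_targets(actions, pred_len):
    # actions: (batch, time, ...) -> targets: (batch, time, pred_len, ...)
    b, t = actions.shape[:2]
    targets = torch.stack([actions.roll(-i, dims = 1) for i in range(pred_len)], dim = 2)
    mask = (torch.arange(t)[:, None] + torch.arange(pred_len)[None, :]) < t
    return targets, mask.expand(b, t, pred_len)   # mask kills wrapped-around slots

continuous_actions = torch.randn(2, 6, 3)   # (batch, time, action_dim)
targets, mask = create_multi_token_prediction_targets(continuous_actions, pred_len = 2)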
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.1.0"
+version = "0.1.2"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }