a bit of dropout to rewards as state

This commit is contained in:
lucidrains 2025-10-08 06:45:25 -07:00
parent c8f75caa40
commit 612f5f5dd1

View File

@ -69,6 +69,9 @@ def first(arr):
def divisible_by(num, den):
return (num % den) == 0
def sample_prob(prob):
return random() < prob
def is_power_two(num):
return log2(num).is_integer()
@ -1083,7 +1086,8 @@ class DynamicsModel(Module):
num_future_predictions = 8, # they do multi-token prediction of 8 steps forward
prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes
add_reward_embed_to_agent_token = False,
reward_loss_weight = 0.1
add_reward_embed_dropout = 0.1,
reward_loss_weight = 0.1,
):
super().__init__()
@ -1259,6 +1263,9 @@ class DynamicsModel(Module):
): # (b t n d) | (b c t h w)
was_training = self.training
self.eval()
assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
@ -1330,6 +1337,10 @@ class DynamicsModel(Module):
latents = cat((latents, noised_latent), dim = 1)
# restore state
self.train(was_training)
# returning video
has_tokenizer = exists(self.video_tokenizer)
@ -1435,7 +1446,7 @@ class DynamicsModel(Module):
if not is_inference:
no_shortcut_train = random() < self.prob_no_shortcut_train
no_shortcut_train = sample_prob(self.prob_no_shortcut_train)
if no_shortcut_train:
# if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
@ -1487,7 +1498,10 @@ class DynamicsModel(Module):
if exists(rewards):
two_hot_encoding = self.reward_encoder(rewards)
if self.add_reward_embed_to_agent_token:
if (
self.add_reward_embed_to_agent_token and
(not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
):
reward_embeds = self.reward_encoder.embed(two_hot_encoding)
pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case