a bit of dropout to rewards as state
parent c8f75caa40
commit 612f5f5dd1
@@ -69,6 +69,9 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0

+def sample_prob(prob):
+    return random() < prob
+
 def is_power_two(num):
     return log2(num).is_integer()

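The new sample_prob helper is just a uniform draw compared against a threshold. A minimal standalone sketch with an illustrative usage loop (the loop is not part of the commit):

from random import random

def sample_prob(prob):
    # True with probability `prob`, False otherwise
    return random() < prob

# illustrative only: roughly 10% of 10,000 draws should come back True
hits = sum(sample_prob(0.1) for _ in range(10_000))
print(hits / 10_000)   # ~0.1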
@@ -1083,7 +1086,8 @@ class DynamicsModel(Module):
         num_future_predictions = 8,                # they do multi-token prediction of 8 steps forward
         prob_no_shortcut_train = None,             # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
-        reward_loss_weight = 0.1
+        add_reward_embed_dropout = 0.1,
+        reward_loss_weight = 0.1,
     ):
         super().__init__()

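For context, a hedged sketch of what the constructor change amounts to; only the arguments touched by this commit are shown, and the attribute storage is an assumption, not copied from the repo:

from torch.nn import Module

class DynamicsModel(Module):
    def __init__(
        self,
        add_reward_embed_to_agent_token = False,
        add_reward_embed_dropout = 0.1,   # new: probability of dropping the reward embedding during training
        reward_loss_weight = 0.1,
    ):
        super().__init__()
        self.add_reward_embed_to_agent_token = add_reward_embed_to_agent_token
        self.add_reward_embed_dropout = add_reward_embed_dropout      # assumed to be stored as-is for use in forward
        self.reward_loss_weight = reward_loss_weight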
@@ -1259,6 +1263,9 @@ class DynamicsModel(Module):
     ): # (b t n d) | (b c t h w)

+        was_training = self.training
+        self.eval()
+
         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

@@ -1330,6 +1337,10 @@ class DynamicsModel(Module):
             latents = cat((latents, noised_latent), dim = 1)

+        # restore state
+
+        self.train(was_training)
+
         # returning video

         has_tokenizer = exists(self.video_tokenizer)
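The two sampling hunks above follow the usual save-and-restore pattern around .eval(): record the training flag, switch to eval for the rollout, then restore it at the end. A self-contained sketch of that pattern (the sample body is a placeholder, not the model's actual rollout):

import torch
from torch import nn

class Sampler(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.no_grad()
    def sample(self):
        was_training = self.training   # remember the current mode
        self.eval()                    # disable dropout etc. while sampling

        out = self.proj(torch.randn(1, 8))  # placeholder rollout

        self.train(was_training)       # restore state
        return out

sampler = Sampler()
sampler.train()
_ = sampler.sample()
assert sampler.training   # mode restored after sampling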
@@ -1435,7 +1446,7 @@ class DynamicsModel(Module):
         if not is_inference:

-            no_shortcut_train = random() < self.prob_no_shortcut_train
+            no_shortcut_train = sample_prob(self.prob_no_shortcut_train)

             if no_shortcut_train:
                 # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
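A hedged sketch of the shortcut / no-shortcut split using the new helper. The step-size values are only a guess inferred from the comment in the hunk (smallest step 1 / d_min, larger shortcut sizes otherwise) and are not taken from the repo:

from random import random, choice

def sample_prob(prob):
    return random() < prob

def sample_step_size(prob_no_shortcut_train, num_step_sizes):
    # assumption: candidate step sizes are powers of two, and "no shortcut"
    # training always uses step size 1, i.e. every step is 1 / d_min
    if sample_prob(prob_no_shortcut_train):
        return 1
    return choice([2 ** i for i in range(num_step_sizes)])

print(sample_step_size(prob_no_shortcut_train = 0.25, num_step_sizes = 4))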
@@ -1487,7 +1498,10 @@ class DynamicsModel(Module):
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)

-            if self.add_reward_embed_to_agent_token:
+            if (
+                self.add_reward_embed_to_agent_token and
+                (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
+            ):
                 reward_embeds = self.reward_encoder.embed(two_hot_encoding)

                 pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training; at inference it is not known yet, so this edge case needs handling

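The gate in the last hunk is the point of the commit: when add_reward_embed_to_agent_token is enabled, the reward embedding is added to the agent tokens only 1 - add_reward_embed_dropout of the time during training, while at inference it is always used. A minimal hedged sketch of that gate with made-up tensor shapes and an assumed alignment between rewards and agent tokens:

import torch
from random import random

def sample_prob(prob):
    return random() < prob

def maybe_add_reward_embed(agent_tokens, reward_embeds, training, dropout_prob = 0.1):
    # agent_tokens: (batch, time, dim), reward_embeds: (batch, time or time - 1, dim)
    # during training, drop the reward conditioning with probability `dropout_prob`
    # ("a bit of noise goes a long way"); at inference it is always added
    if training and sample_prob(dropout_prob):
        return agent_tokens

    if reward_embeds.shape[1] == agent_tokens.shape[1]:
        reward_embeds = reward_embeds[:, :-1]   # pop the last reward during training

    # assumed alignment: the reward at step t conditions the agent token at step t + 1
    agent_tokens = agent_tokens.clone()
    agent_tokens[:, 1:] = agent_tokens[:, 1:] + reward_embeds
    return agent_tokens

tokens  = torch.randn(2, 6, 16)
rewards = torch.randn(2, 6, 16)
print(maybe_add_reward_embed(tokens, rewards, training = True).shape)   # torch.Size([2, 6, 16])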