a bit of dropout to rewards as state
parent c8f75caa40
commit 612f5f5dd1
@@ -69,6 +69,9 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def sample_prob(prob):
+    return random() < prob
+
 def is_power_two(num):
     return log2(num).is_integer()
 
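The new sample_prob helper is a simple Bernoulli draw used throughout the rest of the commit. A self-contained sketch of the helpers above, assuming the file imports random and log2 from the standard library:

from math import log2
from random import random

def divisible_by(num, den):
    return (num % den) == 0

def sample_prob(prob):
    # Bernoulli draw: True with probability `prob`
    return random() < prob

def is_power_two(num):
    return log2(num).is_integer()

# quick sanity checks
print(sample_prob(0.))   # always False
print(sample_prob(1.))   # always True, since random() is in [0, 1)
print(is_power_two(8))   # True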
@@ -1083,7 +1086,8 @@ class DynamicsModel(Module):
         num_future_predictions = 8,              # they do multi-token prediction of 8 steps forward
         prob_no_shortcut_train = None,           # probability of no shortcut training, defaults to 1 / num_step_sizes
         add_reward_embed_to_agent_token = False,
-        reward_loss_weight = 0.1
+        add_reward_embed_dropout = 0.1,
+        reward_loss_weight = 0.1,
     ):
         super().__init__()
 
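For orientation, a hedged sketch of just the hyperparameters this commit touches; the names and defaults are copied from the hunk above, but the real DynamicsModel constructor takes many more arguments that are elided here:

# hypothetical kwargs dict, not a complete constructor call
dynamics_model_kwargs = dict(
    num_future_predictions = 8,              # multi-token prediction of 8 steps forward
    prob_no_shortcut_train = None,           # defaults to 1 / num_step_sizes inside the model
    add_reward_embed_to_agent_token = True,
    add_reward_embed_dropout = 0.1,          # new in this commit: drop reward embeds on ~10% of training steps
    reward_loss_weight = 0.1
)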
@@ -1259,6 +1263,9 @@ class DynamicsModel(Module):
 
     ): # (b t n d) | (b c t h w)
 
+        was_training = self.training
+        self.eval()
+
         assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2'
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
 
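Putting the model into eval() before sampling matters because modules such as nn.Dropout behave differently in train and eval mode; a minimal PyTorch illustration (generic, not code from this repo):

import torch
from torch import nn

drop = nn.Dropout(p = 0.5)
x = torch.ones(4)

drop.train()
print(drop(x))   # roughly half the entries zeroed, the survivors scaled to 2.

drop.eval()
print(drop(x))   # identity at evaluation time: tensor([1., 1., 1., 1.])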
@@ -1330,6 +1337,10 @@ class DynamicsModel(Module):
 
             latents = cat((latents, noised_latent), dim = 1)
 
+        # restore state
+
+        self.train(was_training)
+
         # returning video
 
         has_tokenizer = exists(self.video_tokenizer)
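The save / eval / restore dance around sampling can also be packaged as a context manager. A sketch of an equivalent helper, assuming a hypothetical eval_mode name (the repo simply inlines the three statements as in the hunks above):

from contextlib import contextmanager
import torch

@contextmanager
def eval_mode(module: torch.nn.Module):
    # remember the current mode, run the body in eval, then restore it
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        module.train(was_training)

# usage sketch
# with eval_mode(dynamics_model):
#     sampled = dynamics_model(...)   # whatever sampling call applies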
@@ -1435,7 +1446,7 @@ class DynamicsModel(Module):
 
         if not is_inference:
 
-            no_shortcut_train = random() < self.prob_no_shortcut_train
+            no_shortcut_train = sample_prob(self.prob_no_shortcut_train)
 
             if no_shortcut_train:
                 # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min
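The change above is behavior preserving, since sample_prob(p) is exactly random() < p. With the documented default of 1 / num_step_sizes, roughly one training batch in num_step_sizes skips shortcut training; a small sketch with a hypothetical num_step_sizes:

from random import random

def sample_prob(prob):
    return random() < prob

num_step_sizes = 4                               # hypothetical value
prob_no_shortcut_train = 1. / num_step_sizes     # the documented default

# over many batches, about a quarter should take the no-shortcut branch
hits = sum(sample_prob(prob_no_shortcut_train) for _ in range(100_000))
print(hits / 100_000)   # ≈ 0.25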
@@ -1487,7 +1498,10 @@ class DynamicsModel(Module):
         if exists(rewards):
             two_hot_encoding = self.reward_encoder(rewards)
 
-            if self.add_reward_embed_to_agent_token:
+            if (
+                self.add_reward_embed_to_agent_token and
+                (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way
+            ):
                 reward_embeds = self.reward_encoder.embed(two_hot_encoding)
 
                 pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case
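Pulled out of the model, the new gating condition reads as follows; a standalone sketch with plain variables standing in for the self.* attributes:

from random import random

def sample_prob(prob):
    return random() < prob

add_reward_embed_to_agent_token = True
add_reward_embed_dropout = 0.1

def should_add_reward_embed(training):
    return (
        add_reward_embed_to_agent_token and
        (not training or not sample_prob(add_reward_embed_dropout))
    )

print(should_add_reward_embed(training = False))   # inference: always True

# during training it is True ~90% of the time, i.e. the reward embedding is
# dropped from the agent token on roughly add_reward_embed_dropout of batches,
# which is the "bit of dropout to rewards as state" the commit title refers to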