diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index f44793e..b5c8cd0 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -69,6 +69,9 @@ def first(arr): def divisible_by(num, den): return (num % den) == 0 +def sample_prob(prob): + return random() < prob + def is_power_two(num): return log2(num).is_integer() @@ -1083,7 +1086,8 @@ class DynamicsModel(Module): num_future_predictions = 8, # they do multi-token prediction of 8 steps forward prob_no_shortcut_train = None, # probability of no shortcut training, defaults to 1 / num_step_sizes add_reward_embed_to_agent_token = False, - reward_loss_weight = 0.1 + add_reward_embed_dropout = 0.1, + reward_loss_weight = 0.1, ): super().__init__() @@ -1259,6 +1263,9 @@ class DynamicsModel(Module): ): # (b t n d) | (b c t h w) + was_training = self.training + self.eval() + assert log2(num_steps).is_integer(), f'number of steps {num_steps} must be a power of 2' assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}' @@ -1330,6 +1337,10 @@ class DynamicsModel(Module): latents = cat((latents, noised_latent), dim = 1) + # restore state + + self.train(was_training) + # returning video has_tokenizer = exists(self.video_tokenizer) @@ -1435,7 +1446,7 @@ class DynamicsModel(Module): if not is_inference: - no_shortcut_train = random() < self.prob_no_shortcut_train + no_shortcut_train = sample_prob(self.prob_no_shortcut_train) if no_shortcut_train: # if no shortcut training, step sizes are just 1 and noising is all steps, where each step is 1 / d_min @@ -1487,7 +1498,10 @@ class DynamicsModel(Module): if exists(rewards): two_hot_encoding = self.reward_encoder(rewards) - if self.add_reward_embed_to_agent_token: + if ( + self.add_reward_embed_to_agent_token and + (not self.training or not sample_prob(self.add_reward_embed_dropout)) # a bit of noise goes a long way + ): reward_embeds = self.reward_encoder.embed(two_hot_encoding) pop_last_reward = int(reward_embeds.shape[1] == agent_tokens.shape[1]) # the last reward is popped off during training, during inference, it is not known yet, so need to handle this edge case