diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 83e8658..fdc7c99 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -2171,7 +2171,8 @@ class DynamicsWorldModel(Module):
         latent_is_noised = False,
         return_all_losses = False,
         return_agent_tokens = False,
-        add_autoregressive_action_loss = False
+        add_autoregressive_action_loss = False,
+        update_loss_ema = None
     ):
         # handle video or latents

@@ -2567,16 +2568,16 @@ class DynamicsWorldModel(Module):
         losses = WorldModelLosses(flow_loss, reward_loss, discrete_action_loss, continuous_action_loss)

         if exists(self.flow_loss_normalizer):
-            flow_loss = self.flow_loss_normalizer(flow_loss)
+            flow_loss = self.flow_loss_normalizer(flow_loss, update_ema = update_loss_ema)

         if exists(rewards) and exists(self.reward_loss_normalizer):
-            reward_loss = self.reward_loss_normalizer(reward_loss)
+            reward_loss = self.reward_loss_normalizer(reward_loss, update_ema = update_loss_ema)

         if exists(discrete_actions) and exists(self.discrete_actions_loss_normalizer):
-            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss)
+            discrete_action_loss = self.discrete_actions_loss_normalizer(discrete_action_loss, update_ema = update_loss_ema)

         if exists(continuous_actions) and exists(self.continuous_actions_loss_normalizer):
-            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss)
+            continuous_action_loss = self.continuous_actions_loss_normalizer(continuous_action_loss, update_ema = update_loss_ema)

         # gather losses - they sum across the multi token prediction steps for rewards and actions - eq (9)

diff --git a/pyproject.toml b/pyproject.toml
index 0d9f1d7..27e68f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.41"
+version = "0.0.42"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
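
The diff threads a new `update_loss_ema` flag from `DynamicsWorldModel.forward` into each of the per-loss normalizers as `update_ema = update_loss_ema`. Below is a minimal sketch, not the repository's actual implementation, of what such a normalizer could look like: it divides a loss by a running EMA of its magnitude and only refreshes that EMA when the flag allows it. The class name `LossNormalizer`, the `decay`/`eps` parameters, and the fallback to `self.training` when the flag is `None` are all assumptions made for illustration.

```python
# Hypothetical sketch of a loss normalizer with an `update_ema` flag,
# matching the `update_ema = update_loss_ema` call sites in the diff above.
# Not taken from dreamer4 — names and internals are illustrative assumptions.

import torch
from torch.nn import Module

def exists(v):
    return v is not None

class LossNormalizer(Module):
    def __init__(self, decay = 0.99, eps = 1e-5):
        super().__init__()
        self.decay = decay
        self.eps = eps
        # running estimate of the loss scale, stored as a buffer so it is
        # saved with the model but never touched by the optimizer
        self.register_buffer('ema_loss', torch.ones(1))

    @torch.no_grad()
    def update_ema_(self, loss):
        # move the EMA a small step toward the current (detached) loss value
        self.ema_loss.lerp_(loss.detach().mean(), 1. - self.decay)

    def forward(self, loss, update_ema = None):
        # assumed behavior: a `None` flag falls back to the module's
        # training mode, so the default call sites keep their old behavior
        update_ema = update_ema if exists(update_ema) else self.training

        if update_ema:
            self.update_ema_(loss)

        # scale the loss to roughly unit magnitude
        return loss / self.ema_loss.clamp(min = self.eps)
```

Under this reading, a caller could pass `update_loss_ema = False` during evaluation or auxiliary forward passes so the normalization statistics accumulated during training are used but not perturbed; this interpretation is inferred from the signature change, not confirmed by the diff itself.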