diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py index 467c091..dd3d5ac 100644 --- a/dreamer4/dreamer4.py +++ b/dreamer4/dreamer4.py @@ -1817,6 +1817,8 @@ class DynamicsWorldModel(Module): continuous_action_loss_weight: float | list[float] = 1., num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037 num_residual_streams = 1, + keep_reward_ema_stats = False, + reward_ema_decay = 0.998, gae_discount_factor = 0.997, gae_lambda = 0.95, ppo_eps_clip = 0.2, @@ -2022,6 +2024,14 @@ class DynamicsWorldModel(Module): self.value_clip = value_clip self.policy_entropy_weight = value_clip + # rewards related + + self.keep_reward_ema_stats = keep_reward_ema_stats + self.reward_ema_decay = reward_ema_decay + + self.register_buffer('ema_returns_mean', tensor(0.)) + self.register_buffer('ema_returns_var', tensor(1.)) + # loss related self.flow_loss_normalizer = LossNormalizer(1) @@ -2267,11 +2277,32 @@ class DynamicsWorldModel(Module): world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext + # maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3 + + if self.keep_reward_ema_stats: + ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var + + decay = 1. - self.reward_ema_decay + + # todo - handle distributed + + returns_mean, returns_var = returns.mean(), returns.var() + + ema_returns_mean.lerp_(returns_mean, decay) + ema_returns_var.lerp_(returns_var, decay) + + ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt() + + normed_returns = (returns - ema_returns_mean) / ema_returns_std + normed_old_values = (old_values - ema_returns_mean) / ema_returns_std + + advantage = normed_returns - normed_old_values + else: + advantage = returns - old_values + # apparently they just use the sign of the advantage # https://arxiv.org/abs/2410.04166v1 - advantage = returns - old_values - if use_signed_advantage: advantage = advantage.sign() else: diff --git a/pyproject.toml b/pyproject.toml index 2691c8b..34ca4e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dreamer4" -version = "0.0.80" +version = "0.0.81" description = "Dreamer 4" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }