From fd1e87983b5d28a3ae4b264522433728a9b8169e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 27 Oct 2025 09:08:26 -0700
Subject: [PATCH] quantile filter

---
 dreamer4/dreamer4.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index dd3d5ac..7d0b920 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
         num_residual_streams = 1,
         keep_reward_ema_stats = False,
         reward_ema_decay = 0.998,
+        reward_quantile_filter = (0.05, 0.95),
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
@@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
         self.keep_reward_ema_stats = keep_reward_ema_stats
         self.reward_ema_decay = reward_ema_decay
 
+        self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
+
         self.register_buffer('ema_returns_mean', tensor(0.))
         self.register_buffer('ema_returns_var', tensor(1.))
 
@@ -2284,13 +2287,22 @@ class DynamicsWorldModel(Module):
 
         decay = 1. - self.reward_ema_decay
 
-        # todo - handle distributed
+        # quantile filter
+
+        lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
+        returns_for_stats = returns.clamp(lo, hi)
+
+        # mean, var - todo - handle distributed
 
         returns_mean, returns_var = returns.mean(), returns.var()
 
+        # ema
+
         ema_returns_mean.lerp_(returns_mean, decay)
         ema_returns_var.lerp_(returns_var, decay)
 
+        # normalize
+
         ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
 
         normed_returns = (returns - ema_returns_mean) / ema_returns_std
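
A minimal standalone sketch of the return normalization this patch introduces. The ReturnNormalizer module and the usage at the bottom are illustrative, not part of dreamer4, and the sketch assumes the quantile-clamped returns are what feed the running mean/variance statistics:

    import torch
    from torch import nn, tensor

    class ReturnNormalizer(nn.Module):
        # hypothetical standalone module illustrating the patch:
        # clamp returns to the (0.05, 0.95) quantiles before updating the
        # exponential-moving-average mean / variance used for normalization

        def __init__(
            self,
            reward_quantile_filter = (0.05, 0.95),
            reward_ema_decay = 0.998
        ):
            super().__init__()
            self.reward_ema_decay = reward_ema_decay

            # quantile bounds kept as a non-persistent buffer, as in the patch

            self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)

            self.register_buffer('ema_returns_mean', tensor(0.))
            self.register_buffer('ema_returns_var', tensor(1.))

        @torch.no_grad()
        def forward(self, returns):
            decay = 1. - self.reward_ema_decay

            # quantile filter - clamp outlier returns before computing statistics

            lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
            returns_for_stats = returns.clamp(lo, hi)

            # batch statistics from the clamped returns (assumed intent of the filter)

            returns_mean, returns_var = returns_for_stats.mean(), returns_for_stats.var()

            # ema update of the running statistics

            self.ema_returns_mean.lerp_(returns_mean, decay)
            self.ema_returns_var.lerp_(returns_var, decay)

            # normalize the raw (unclamped) returns with the running statistics

            ema_returns_std = self.ema_returns_var.clamp(min = 1e-5).sqrt()
            return (returns - self.ema_returns_mean) / ema_returns_std

    # usage

    normalizer = ReturnNormalizer()
    normed = normalizer(torch.randn(1024) * 3. + 1.)

Clamping to the 5th/95th percentiles keeps a handful of extreme returns from dominating the EMA statistics, while the normalization itself is still applied to the unclamped returns.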