quantile filter

This commit is contained in:
lucidrains 2025-10-27 09:08:26 -07:00
parent fe79bfa951
commit fd1e87983b

View File

@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
num_residual_streams = 1,
keep_reward_ema_stats = False,
reward_ema_decay = 0.998,
reward_quantile_filter = (0.05, 0.95),
gae_discount_factor = 0.997,
gae_lambda = 0.95,
ppo_eps_clip = 0.2,
@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
self.keep_reward_ema_stats = keep_reward_ema_stats
self.reward_ema_decay = reward_ema_decay
self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
self.register_buffer('ema_returns_mean', tensor(0.))
self.register_buffer('ema_returns_var', tensor(1.))
@ -2284,13 +2287,22 @@ class DynamicsWorldModel(Module):
decay = 1. - self.reward_ema_decay
# todo - handle distributed
# quantile filter
lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
returns_for_stats = returns.clamp(lo, hi)
# mean, var - todo - handle distributed
returns_mean, returns_var = returns.mean(), returns.var()
# ema
ema_returns_mean.lerp_(returns_mean, decay)
ema_returns_var.lerp_(returns_var, decay)
# normalize
ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
normed_returns = (returns - ema_returns_mean) / ema_returns_std