quantile filter
This commit is contained in:
parent
fe79bfa951
commit
fd1e87983b
@ -1819,6 +1819,7 @@ class DynamicsWorldModel(Module):
|
||||
num_residual_streams = 1,
|
||||
keep_reward_ema_stats = False,
|
||||
reward_ema_decay = 0.998,
|
||||
reward_quantile_filter = (0.05, 0.95),
|
||||
gae_discount_factor = 0.997,
|
||||
gae_lambda = 0.95,
|
||||
ppo_eps_clip = 0.2,
|
||||
@ -2029,6 +2030,8 @@ class DynamicsWorldModel(Module):
|
||||
self.keep_reward_ema_stats = keep_reward_ema_stats
|
||||
self.reward_ema_decay = reward_ema_decay
|
||||
|
||||
self.register_buffer('reward_quantile_filter', tensor(reward_quantile_filter), persistent = False)
|
||||
|
||||
self.register_buffer('ema_returns_mean', tensor(0.))
|
||||
self.register_buffer('ema_returns_var', tensor(1.))
|
||||
|
||||
@ -2284,13 +2287,22 @@ class DynamicsWorldModel(Module):
|
||||
|
||||
decay = 1. - self.reward_ema_decay
|
||||
|
||||
# todo - handle distributed
|
||||
# quantile filter
|
||||
|
||||
lo, hi = torch.quantile(returns, self.reward_quantile_filter).tolist()
|
||||
returns_for_stats = returns.clamp(lo, hi)
|
||||
|
||||
# mean, var - todo - handle distributed
|
||||
|
||||
returns_mean, returns_var = returns.mean(), returns.var()
|
||||
|
||||
# ema
|
||||
|
||||
ema_returns_mean.lerp_(returns_mean, decay)
|
||||
ema_returns_var.lerp_(returns_var, decay)
|
||||
|
||||
# normalize
|
||||
|
||||
ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
|
||||
|
||||
normed_returns = (returns - ema_returns_mean) / ema_returns_std
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user