optionally keep track of returns statistics and normalize with them before advantage

This commit is contained in:
lucidrains 2025-10-27 09:02:08 -07:00
parent f808b1c1d2
commit fe79bfa951
2 changed files with 34 additions and 3 deletions

View File

@ -1817,6 +1817,8 @@ class DynamicsWorldModel(Module):
continuous_action_loss_weight: float | list[float] = 1.,
num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
num_residual_streams = 1,
keep_reward_ema_stats = False,
reward_ema_decay = 0.998,
gae_discount_factor = 0.997,
gae_lambda = 0.95,
ppo_eps_clip = 0.2,
@ -2022,6 +2024,14 @@ class DynamicsWorldModel(Module):
self.value_clip = value_clip
self.policy_entropy_weight = value_clip
# rewards related
self.keep_reward_ema_stats = keep_reward_ema_stats
self.reward_ema_decay = reward_ema_decay
self.register_buffer('ema_returns_mean', tensor(0.))
self.register_buffer('ema_returns_var', tensor(1.))
# loss related
self.flow_loss_normalizer = LossNormalizer(1)
@ -2267,11 +2277,32 @@ class DynamicsWorldModel(Module):
world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
# maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
if self.keep_reward_ema_stats:
ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
decay = 1. - self.reward_ema_decay
# todo - handle distributed
returns_mean, returns_var = returns.mean(), returns.var()
ema_returns_mean.lerp_(returns_mean, decay)
ema_returns_var.lerp_(returns_var, decay)
ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
normed_returns = (returns - ema_returns_mean) / ema_returns_std
normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
advantage = normed_returns - normed_old_values
else:
advantage = returns - old_values
# apparently they just use the sign of the advantage
# https://arxiv.org/abs/2410.04166v1
advantage = returns - old_values
if use_signed_advantage:
advantage = advantage.sign()
else:

View File

@ -1,6 +1,6 @@
[project]
name = "dreamer4"
version = "0.0.80"
version = "0.0.81"
description = "Dreamer 4"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }