optionally keep track of returns statistics and normalize with them before advantage
parent f808b1c1d2 · commit fe79bfa951
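This mirrors the return normalization used in dreamer v3: a running mean and variance of the returns are tracked with an exponential moving average, and both the returns and the old value estimates are shifted and scaled by those statistics before the advantage is taken. Below is a minimal standalone sketch of that idea, assuming batched returns and value predictions; the names are illustrative and not the module's internals.

import torch
from torch import tensor

ema_decay = 0.998                        # same default as reward_ema_decay below
ema_mean, ema_var = tensor(0.), tensor(1.)

returns = torch.randn(8, 16)             # imagined rollout returns
old_values = torch.randn(8, 16)          # value head predictions for the same states

# update the running return statistics with an exponential moving average
w = 1. - ema_decay
ema_mean = ema_mean.lerp(returns.mean(), w)
ema_var = ema_var.lerp(returns.var(), w)

ema_std = ema_var.clamp(min = 1e-5).sqrt()

# shift and scale returns and values with the same statistics before taking the difference
normed_returns = (returns - ema_mean) / ema_std
normed_values = (old_values - ema_mean) / ema_std
advantage = normed_returns - normed_values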
@@ -1817,6 +1817,8 @@ class DynamicsWorldModel(Module):
         continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
+        keep_reward_ema_stats = False,
+        reward_ema_decay = 0.998,
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
@@ -2022,6 +2024,14 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = policy_entropy_weight
 
+        # rewards related
+
+        self.keep_reward_ema_stats = keep_reward_ema_stats
+        self.reward_ema_decay = reward_ema_decay
+
+        self.register_buffer('ema_returns_mean', tensor(0.))
+        self.register_buffer('ema_returns_var', tensor(1.))
+
         # loss related
 
         self.flow_loss_normalizer = LossNormalizer(1)
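The running statistics are stored with register_buffer rather than as parameters, so they are saved and restored with the state dict and follow the module across devices, but never receive gradients. A tiny illustration of that behavior (the MiniStats module is hypothetical, only to show how such buffers behave):

import torch
from torch import nn, tensor

class MiniStats(nn.Module):
    def __init__(self):
        super().__init__()
        # mirrors how the EMA statistics above are stored as buffers
        self.register_buffer('ema_returns_mean', tensor(0.))
        self.register_buffer('ema_returns_var', tensor(1.))

stats = MiniStats()

print(list(stats.state_dict().keys()))   # ['ema_returns_mean', 'ema_returns_var']
print(list(stats.parameters()))          # [] - buffers are not trainable parameters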
@@ -2267,11 +2277,32 @@ class DynamicsWorldModel(Module):
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
 
+        # maybe keep track of returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
+
+        if self.keep_reward_ema_stats:
+            ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
+
+            decay = 1. - self.reward_ema_decay
+
+            # todo - handle distributed
+
+            returns_mean, returns_var = returns.mean(), returns.var()
+
+            ema_returns_mean.lerp_(returns_mean, decay)
+            ema_returns_var.lerp_(returns_var, decay)
+
+            ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
+
+            normed_returns = (returns - ema_returns_mean) / ema_returns_std
+            normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
+
+            advantage = normed_returns - normed_old_values
+        else:
+            advantage = returns - old_values
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        advantage = returns - old_values
-
         if use_signed_advantage:
             advantage = advantage.sign()
         else:
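Note that Tensor.lerp_(end, weight) moves the buffer toward the new batch statistic by weight, i.e. x becomes x + weight * (end - x), so passing decay = 1. - self.reward_ema_decay yields the usual EMA update in which a decay of 0.998 keeps 99.8% of the old statistic each step. A quick sanity check with made-up numbers:

import torch
from torch import tensor

reward_ema_decay = 0.998
ema_mean = tensor(0.)
batch_mean = tensor(5.)

# lerp_ with weight (1 - decay) ...
lerped = ema_mean.clone().lerp_(batch_mean, 1. - reward_ema_decay)

# ... matches the textbook EMA update: decay * old + (1 - decay) * new
manual = reward_ema_decay * ema_mean + (1. - reward_ema_decay) * batch_mean

assert torch.allclose(lerped, manual)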
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.80"
+version = "0.0.81"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }