optionally keep track of returns statistics and normalize with them before advantage
commit fe79bfa951
parent f808b1c1d2
@@ -1817,6 +1817,8 @@ class DynamicsWorldModel(Module):
         continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
+        keep_reward_ema_stats = False,
+        reward_ema_decay = 0.998,
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
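Note, not part of the commit: the new reward_ema_decay hyperparameter sets how slowly the running returns statistics move. Each update blends in the latest batch statistic with weight 1 - decay, so the default of 0.998 gives an effective averaging window of roughly 1 / (1 - 0.998) = 500 updates. A minimal sketch of that update rule in plain torch (variable names here are illustrative, not from the repo):

import torch

decay = 0.998
weight = 1. - decay                   # 0.002 - weight given to the newest batch statistic
effective_window = 1. / weight        # ~500 updates dominate the running mean

ema_mean = torch.tensor(0.)

for _ in range(1000):
    batch_returns = torch.randn(64)   # stand-in for a batch of computed returns
    ema_mean.lerp_(batch_returns.mean(), weight)  # ema <- ema * decay + batch_mean * weight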
@@ -2022,6 +2024,14 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # rewards related
+
+        self.keep_reward_ema_stats = keep_reward_ema_stats
+        self.reward_ema_decay = reward_ema_decay
+
+        self.register_buffer('ema_returns_mean', tensor(0.))
+        self.register_buffer('ema_returns_var', tensor(1.))
+
         # loss related
 
         self.flow_loss_normalizer = LossNormalizer(1)
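Note, not part of the commit: registering ema_returns_mean / ema_returns_var with register_buffer keeps them in the module's state dict and moves them with .to(device), without making them trainable parameters. A self-contained sketch of that pattern under the same defaults (the class name ReturnsEMAStats and its update method are hypothetical, not the repo's API):

import torch
from torch import nn, tensor

class ReturnsEMAStats(nn.Module):
    def __init__(self, decay = 0.998):
        super().__init__()
        self.decay = decay

        # buffers: saved with checkpoints, follow the module across devices, never receive gradients
        self.register_buffer('ema_returns_mean', tensor(0.))
        self.register_buffer('ema_returns_var', tensor(1.))

    @torch.no_grad()
    def update(self, returns):
        weight = 1. - self.decay
        self.ema_returns_mean.lerp_(returns.mean(), weight)
        self.ema_returns_var.lerp_(returns.var(), weight)

stats = ReturnsEMAStats()
stats.update(torch.randn(128))        # running mean / var drift toward the batch statistics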
@@ -2267,11 +2277,32 @@ class DynamicsWorldModel(Module):
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
 
+        # maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
+
+        if self.keep_reward_ema_stats:
+            ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
+
+            decay = 1. - self.reward_ema_decay
+
+            # todo - handle distributed
+
+            returns_mean, returns_var = returns.mean(), returns.var()
+
+            ema_returns_mean.lerp_(returns_mean, decay)
+            ema_returns_var.lerp_(returns_var, decay)
+
+            ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
+
+            normed_returns = (returns - ema_returns_mean) / ema_returns_std
+            normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
+
+            advantage = normed_returns - normed_old_values
+        else:
+            advantage = returns - old_values
 
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        advantage = returns - old_values
 
         if use_signed_advantage:
             advantage = advantage.sign()
         else:
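Note, not part of the commit: a standalone sketch of the logic this hunk adds (the function name and signature are hypothetical, not the repo's API). The EMA statistics are updated from the batch of returns, then both returns and old value predictions are normalized with the EMA mean and std before taking their difference as the advantage, optionally collapsed to its sign per https://arxiv.org/abs/2410.04166v1:

import torch

def ema_normalized_advantage(
    returns,               # (batch,) discounted returns
    old_values,            # (batch,) value predictions from the behavior policy
    ema_mean,              # 0-d running mean buffer, updated in place
    ema_var,               # 0-d running variance buffer, updated in place
    decay = 0.998,
    use_signed_advantage = False,
    eps = 1e-5
):
    weight = 1. - decay

    # fold the batch statistics into the running statistics
    ema_mean.lerp_(returns.mean(), weight)
    ema_var.lerp_(returns.var(), weight)

    std = ema_var.clamp(min = eps).sqrt()

    # normalize returns and values with the same statistics, then take the difference
    normed_returns = (returns - ema_mean) / std
    normed_old_values = (old_values - ema_mean) / std

    advantage = normed_returns - normed_old_values

    if use_signed_advantage:
        advantage = advantage.sign()

    return advantage

Since the same mean is subtracted from both terms it cancels in the difference, so the normalization amounts to rescaling the raw advantage by the EMA standard deviation, in the spirit of DreamerV3's return normalization. The "# todo - handle distributed" presumably means all-reducing the batch mean and variance across ranks before the lerp, so every replica tracks the same statistics.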
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.80"
+version = "0.0.81"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }