From 9101a49cddaf5dc06fb160647c93698d8dda42f1 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 9 Oct 2025 08:59:54 -0700
Subject: [PATCH] handle continuous value normalization if stats passed in

---
 dreamer4/dreamer4.py  | 16 +++++++++++++++-
 tests/test_dreamer.py |  3 ++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/dreamer4/dreamer4.py b/dreamer4/dreamer4.py
index 9b2fac5..2998f28 100644
--- a/dreamer4/dreamer4.py
+++ b/dreamer4/dreamer4.py
@@ -282,6 +282,7 @@ class ActionEmbedder(Module):
         *,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
+        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
     ):
         super().__init__()
 
@@ -298,6 +299,11 @@ class ActionEmbedder(Module):
         self.num_continuous_action_types = num_continuous_actions
         self.continuous_action_embed = Embedding(num_continuous_actions, dim)
 
+        self.continuous_need_norm = exists(continuous_norm_stats)
+
+        if self.continuous_need_norm:
+            self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
+
         # defaults
 
         self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
@@ -360,6 +366,12 @@ class ActionEmbedder(Module):
 
         continuous_action_embed = self.continuous_action_embed(continuous_action_types)
 
+        # maybe normalization
+
+        if self.continuous_need_norm:
+            norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
+            continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
+
         # continuous embed is just the selected continuous action type with the scale
 
         continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
@@ -1205,6 +1217,7 @@ class DynamicsWorldModel(Module):
         add_reward_embed_dropout = 0.1,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
+        continuous_norm_stats = None,
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
@@ -1303,7 +1316,8 @@ class DynamicsWorldModel(Module):
         self.action_embedder = ActionEmbedder(
             dim = dim,
             num_discrete_actions = num_discrete_actions,
-            num_continuous_actions = num_continuous_actions
+            num_continuous_actions = num_continuous_actions,
+            continuous_norm_stats = continuous_norm_stats
         )
 
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
diff --git a/tests/test_dreamer.py b/tests/test_dreamer.py
index bd0f028..6b89f8d 100644
--- a/tests/test_dreamer.py
+++ b/tests/test_dreamer.py
@@ -220,7 +220,8 @@ def test_action_embedder():
 
     embedder = ActionEmbedder(
         512,
-        num_continuous_actions = 2
+        num_continuous_actions = 2,
+        continuous_norm_stats = ((0., 2.), (1., 1.)) # (mean, std) for normalizing each action
    )
 
     actions = torch.randn((2, 3, 2))
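
Usage, mirroring the updated test: one (mean, std) pair is passed per
continuous action, and the stats are registered as a buffer on the module.
The import path is assumed from the repo layout shown in the diff.

    import torch
    from dreamer4.dreamer4 import ActionEmbedder

    embedder = ActionEmbedder(
        512,
        num_continuous_actions = 2,
        continuous_norm_stats = ((0., 2.), (1., 1.))  # action 0: mean 0., std 2. - action 1: mean 1., std 1.
    )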
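
For reference, a standalone sketch of the arithmetic the new branch in
ActionEmbedder.forward applies when stats are present (tensor shapes
illustrative, not required by the module):

    import torch

    stats = torch.tensor(((0., 2.), (1., 1.)))   # (num_continuous_actions, 2), last dim is (mean, std)
    continuous_actions = torch.randn((2, 3, 2))  # (..., num_continuous_actions)

    # z-score the raw continuous values, guarding against a zero std
    norm_mean, norm_std = stats.unbind(dim = -1)
    normed = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)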