handle continuous value normalization if stats passed in

This commit is contained in:
lucidrains 2025-10-09 08:59:54 -07:00
parent 31f4363be7
commit 9101a49cdd
2 changed files with 17 additions and 2 deletions

View File

@ -282,6 +282,7 @@ class ActionEmbedder(Module):
*,
num_discrete_actions: int | tuple[int, ...] = 0,
num_continuous_actions = 0,
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
):
super().__init__()
@ -298,6 +299,11 @@ class ActionEmbedder(Module):
self.num_continuous_action_types = num_continuous_actions
self.continuous_action_embed = Embedding(num_continuous_actions, dim)
self.continuous_need_norm = exists(continuous_norm_stats)
if self.continuous_need_norm:
self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
# defaults
self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
@ -360,6 +366,12 @@ class ActionEmbedder(Module):
continuous_action_embed = self.continuous_action_embed(continuous_action_types)
# maybe normalization
if self.continuous_need_norm:
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
# continuous embed is just the selected continuous action type with the scale
continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
@ -1205,6 +1217,7 @@ class DynamicsWorldModel(Module):
add_reward_embed_dropout = 0.1,
num_discrete_actions: int | tuple[int, ...] = 0,
num_continuous_actions = 0,
continuous_norm_stats = None,
reward_loss_weight = 0.1,
value_head_mlp_depth = 3,
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
@ -1303,7 +1316,8 @@ class DynamicsWorldModel(Module):
self.action_embedder = ActionEmbedder(
dim = dim,
num_discrete_actions = num_discrete_actions,
num_continuous_actions = num_continuous_actions
num_continuous_actions = num_continuous_actions,
continuous_norm_stats = continuous_norm_stats
)
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token

View File

@ -220,7 +220,8 @@ def test_action_embedder():
embedder = ActionEmbedder(
512,
num_continuous_actions = 2
num_continuous_actions = 2,
continuous_norm_stats = ((0., 2.), (1., 1.)) # (mean, std) for normalizing each action
)
actions = torch.randn((2, 3, 2))