handle continuous value normalization if stats passed in
This commit is contained in:
parent
31f4363be7
commit
9101a49cdd
@ -282,6 +282,7 @@ class ActionEmbedder(Module):
|
||||
*,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -298,6 +299,11 @@ class ActionEmbedder(Module):
|
||||
self.num_continuous_action_types = num_continuous_actions
|
||||
self.continuous_action_embed = Embedding(num_continuous_actions, dim)
|
||||
|
||||
self.continuous_need_norm = exists(continuous_norm_stats)
|
||||
|
||||
if self.continuous_need_norm:
|
||||
self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
|
||||
|
||||
# defaults
|
||||
|
||||
self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
|
||||
@ -360,6 +366,12 @@ class ActionEmbedder(Module):
|
||||
|
||||
continuous_action_embed = self.continuous_action_embed(continuous_action_types)
|
||||
|
||||
# maybe normalization
|
||||
|
||||
if self.continuous_need_norm:
|
||||
norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
|
||||
continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
|
||||
|
||||
# continuous embed is just the selected continuous action type with the scale
|
||||
|
||||
continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
|
||||
@ -1205,6 +1217,7 @@ class DynamicsWorldModel(Module):
|
||||
add_reward_embed_dropout = 0.1,
|
||||
num_discrete_actions: int | tuple[int, ...] = 0,
|
||||
num_continuous_actions = 0,
|
||||
continuous_norm_stats = None,
|
||||
reward_loss_weight = 0.1,
|
||||
value_head_mlp_depth = 3,
|
||||
num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
|
||||
@ -1303,7 +1316,8 @@ class DynamicsWorldModel(Module):
|
||||
self.action_embedder = ActionEmbedder(
|
||||
dim = dim,
|
||||
num_discrete_actions = num_discrete_actions,
|
||||
num_continuous_actions = num_continuous_actions
|
||||
num_continuous_actions = num_continuous_actions,
|
||||
continuous_norm_stats = continuous_norm_stats
|
||||
)
|
||||
|
||||
# each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
|
||||
|
||||
@ -220,7 +220,8 @@ def test_action_embedder():
|
||||
|
||||
embedder = ActionEmbedder(
|
||||
512,
|
||||
num_continuous_actions = 2
|
||||
num_continuous_actions = 2,
|
||||
continuous_norm_stats = ((0., 2.), (1., 1.)) # (mean, std) for normalizing each action
|
||||
)
|
||||
|
||||
actions = torch.randn((2, 3, 2))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user