handle continuous value normalization if stats passed in
parent 31f4363be7
commit 9101a49cdd
@@ -282,6 +282,7 @@ class ActionEmbedder(Module):
         *,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
+        continuous_norm_stats: tuple[tuple[float, float], ...] | None = None
     ):
         super().__init__()
 
@@ -298,6 +299,11 @@ class ActionEmbedder(Module):
         self.num_continuous_action_types = num_continuous_actions
         self.continuous_action_embed = Embedding(num_continuous_actions, dim)
 
+        self.continuous_need_norm = exists(continuous_norm_stats)
+
+        if self.continuous_need_norm:
+            self.register_buffer('continuous_norm_stats', tensor(continuous_norm_stats))
+
         # defaults
 
         self.register_buffer('default_discrete_action_types', arange(self.num_discrete_action_types), persistent = False)
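
Aside (not part of the commit): with one (mean, std) pair per continuous action, the registered buffer is a (num_continuous_actions, 2) tensor, so unbinding the last dim later recovers the per-action means and stds. A minimal sketch, assuming two continuous actions:

    import torch

    continuous_norm_stats = ((0., 2.), (1., 1.))   # hypothetical (mean, std) per continuous action
    stats = torch.tensor(continuous_norm_stats)    # shape (num_continuous_actions, 2)
    norm_mean, norm_std = stats.unbind(dim = -1)   # each of shape (num_continuous_actions,)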
@@ -360,6 +366,12 @@ class ActionEmbedder(Module):
 
         continuous_action_embed = self.continuous_action_embed(continuous_action_types)
 
+        # maybe normalization
+
+        if self.continuous_need_norm:
+            norm_mean, norm_std = self.continuous_norm_stats.unbind(dim = -1)
+            continuous_actions = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)
+
         # continuous embed is just the selected continuous action type with the scale
 
         continuous_embeds = einx.multiply('na d, ... na -> ... na d', continuous_action_embed, continuous_actions)
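
Note (not part of the commit): the normalization broadcasts over the leading batch/time dims, since the mean and std tensors only carry the trailing action dim, and the std is clamped to avoid division by zero. A rough sketch of that step with hypothetical stats:

    import torch

    norm_mean = torch.tensor([0., 1.])                 # hypothetical per-action means
    norm_std  = torch.tensor([2., 1.])                 # hypothetical per-action stds

    continuous_actions = torch.randn(2, 3, 2)          # (batch, time, num continuous actions)
    normalized = (continuous_actions - norm_mean) / norm_std.clamp(min = 1e-6)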
@@ -1205,6 +1217,7 @@ class DynamicsWorldModel(Module):
         add_reward_embed_dropout = 0.1,
         num_discrete_actions: int | tuple[int, ...] = 0,
         num_continuous_actions = 0,
+        continuous_norm_stats = None,
         reward_loss_weight = 0.1,
         value_head_mlp_depth = 3,
         num_latent_genes = 0 # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
@@ -1303,7 +1316,8 @@ class DynamicsWorldModel(Module):
         self.action_embedder = ActionEmbedder(
             dim = dim,
             num_discrete_actions = num_discrete_actions,
-            num_continuous_actions = num_continuous_actions
+            num_continuous_actions = num_continuous_actions,
+            continuous_norm_stats = continuous_norm_stats
         )
 
         # each agent token will have the reward embedding of the previous time step - but could eventually just give reward its own token
@@ -220,7 +220,8 @@ def test_action_embedder():
 
     embedder = ActionEmbedder(
         512,
-        num_continuous_actions = 2
+        num_continuous_actions = 2,
+        continuous_norm_stats = ((0., 2.), (1., 1.)) # (mean, std) for normalizing each action
     )
 
     actions = torch.randn((2, 3, 2))
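
For reference (not part of the commit): with the stats used in this test, the first continuous action is scaled by 1/2 and the second is shifted down by 1 before being multiplied into its action type embedding. A quick check of just the normalization step:

    import torch

    stats = torch.tensor(((0., 2.), (1., 1.)))
    mean, std = stats.unbind(dim = -1)

    actions = torch.tensor([[4., 3.]])
    print((actions - mean) / std.clamp(min = 1e-6))    # tensor([[2., 2.]])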