will organize the unembedding parameters under the actor optimizer

This commit is contained in:
lucidrains 2025-10-11 06:55:57 -07:00
parent 563b269f8a
commit 02558d1f08

View File

@ -344,6 +344,12 @@ class ActionEmbedder(Module):
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
def embed_parameters(self):
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
def unembed_parameters(self):
return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
@property
def device(self):
return self.discrete_action_offsets.device