will organize the unembedding parameters under the actor optimizer
This commit is contained in:
parent
563b269f8a
commit
02558d1f08
@ -344,6 +344,12 @@ class ActionEmbedder(Module):
|
||||
|
||||
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
||||
|
||||
def embed_parameters(self):
|
||||
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
|
||||
|
||||
def unembed_parameters(self):
|
||||
return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.discrete_action_offsets.device
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user