will organize the unembedding parameters under the actor optimizer
This commit is contained in:
parent
563b269f8a
commit
02558d1f08
@ -344,6 +344,12 @@ class ActionEmbedder(Module):
|
|||||||
|
|
||||||
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
self.continuous_action_unembed = Parameter(torch.randn(num_continuous_actions, unembed_dim, 2) * 1e-2)
|
||||||
|
|
||||||
|
def embed_parameters(self):
|
||||||
|
return set([*self.discrete_action_embed.parameters(), *self.continuous_action_embed.parameters()])
|
||||||
|
|
||||||
|
def unembed_parameters(self):
|
||||||
|
return set([*self.discrete_action_unembed.parameters(), *self.continuous_action_unembed.parameters()])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self):
|
def device(self):
|
||||||
return self.discrete_action_offsets.device
|
return self.discrete_action_offsets.device
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user