bug fix when using normal_1
This commit is contained in:
parent
6924abdd3e
commit
d94a719421
@ -804,7 +804,7 @@ class ActionHead(nn.Module):
|
|||||||
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
||||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||||
elif self._dist == "normal_1":
|
elif self._dist == "normal_1":
|
||||||
x = self._dist_layer(x)
|
mean = self._dist_layer(x)
|
||||||
dist = torchd.normal.Normal(mean, 1)
|
dist = torchd.normal.Normal(mean, 1)
|
||||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||||
elif self._dist == "trunc_normal":
|
elif self._dist == "trunc_normal":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user