bug fix when using normal_1
This commit is contained in:
parent
6924abdd3e
commit
d94a719421
@ -804,7 +804,7 @@ class ActionHead(nn.Module):
|
||||
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "normal_1":
|
||||
x = self._dist_layer(x)
|
||||
mean = self._dist_layer(x)
|
||||
dist = torchd.normal.Normal(mean, 1)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "trunc_normal":
|
||||
|
Loading…
x
Reference in New Issue
Block a user