bug fix when using normal_1

This commit is contained in:
NM512 2023-07-27 10:01:40 +09:00
parent 6924abdd3e
commit d94a719421

View File

@ -804,7 +804,7 @@ class ActionHead(nn.Module):
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "normal_1":
x = self._dist_layer(x)
mean = self._dist_layer(x)
dist = torchd.normal.Normal(mean, 1)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "trunc_normal":