diff --git a/networks.py b/networks.py index b628b73..5d05a9b 100644 --- a/networks.py +++ b/networks.py @@ -804,7 +804,7 @@ class ActionHead(nn.Module): dist = torchd.normal.Normal(torch.tanh(mean), std) dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "normal_1": - x = self._dist_layer(x) + mean = self._dist_layer(x) dist = torchd.normal.Normal(mean, 1) dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "trunc_normal":