From d94a719421401baab3a38696ab23f15b9c98b4ae Mon Sep 17 00:00:00 2001 From: NM512 Date: Thu, 27 Jul 2023 10:01:40 +0900 Subject: [PATCH] bug fix when using normal_1 --- networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks.py b/networks.py index b628b73..5d05a9b 100644 --- a/networks.py +++ b/networks.py @@ -804,7 +804,7 @@ class ActionHead(nn.Module): dist = torchd.normal.Normal(torch.tanh(mean), std) dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "normal_1": - x = self._dist_layer(x) + mean = self._dist_layer(x) dist = torchd.normal.Normal(mean, 1) dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "trunc_normal":