From 59939222d181deccbee88522eeaa582171fc3a1c Mon Sep 17 00:00:00 2001 From: NM512 Date: Tue, 24 Sep 2024 00:16:12 +0900 Subject: [PATCH] clean code --- networks.py | 12 ++++++------ tools.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/networks.py b/networks.py index c5d7f6a..2517b3b 100644 --- a/networks.py +++ b/networks.py @@ -681,7 +681,7 @@ class MLP(nn.Module): return self.dist(self._dist, mean, std, self._shape) def dist(self, dist, mean, std, shape): - if self._dist == "tanh_normal": + if dist == "tanh_normal": mean = torch.tanh(mean) std = F.softplus(std) + self._min_std dist = torchd.normal.Normal(mean, std) @@ -690,7 +690,7 @@ class MLP(nn.Module): ) dist = torchd.independent.Independent(dist, 1) dist = tools.SampleDist(dist) - elif self._dist == "normal": + elif dist == "normal": std = (self._max_std - self._min_std) * torch.sigmoid( std + 2.0 ) + self._min_std @@ -698,21 +698,21 @@ class MLP(nn.Module): dist = tools.ContDist( torchd.independent.Independent(dist, 1), absmax=self._absmax ) - elif self._dist == "normal_std_fixed": + elif dist == "normal_std_fixed": dist = torchd.normal.Normal(mean, self._std) dist = tools.ContDist( torchd.independent.Independent(dist, 1), absmax=self._absmax ) - elif self._dist == "trunc_normal": + elif dist == "trunc_normal": mean = torch.tanh(mean) std = 2 * torch.sigmoid(std / 2) + self._min_std dist = tools.SafeTruncatedNormal(mean, std, -1, 1) dist = tools.ContDist( torchd.independent.Independent(dist, 1), absmax=self._absmax ) - elif self._dist == "onehot": + elif dist == "onehot": dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio) - elif self._dist == "onehot_gumble": + elif dist == "onehot_gumble": dist = tools.ContDist( torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax ) diff --git a/tools.py b/tools.py index 87f2633..c968e68 100644 --- a/tools.py +++ b/tools.py @@ -441,7 +441,7 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): def sample(self, sample_shape=(), seed=None): if seed is not None: raise ValueError("need to check") - sample = super().sample(sample_shape) + sample = super().sample(sample_shape).detach() probs = super().probs while len(probs.shape) < len(sample.shape): probs = probs[None]