clean code

This commit is contained in:
NM512 2024-09-24 00:16:12 +09:00
parent 4e50f302cd
commit 59939222d1
2 changed files with 7 additions and 7 deletions

View File

@ -681,7 +681,7 @@ class MLP(nn.Module):
return self.dist(self._dist, mean, std, self._shape) return self.dist(self._dist, mean, std, self._shape)
def dist(self, dist, mean, std, shape): def dist(self, dist, mean, std, shape):
if self._dist == "tanh_normal": if dist == "tanh_normal":
mean = torch.tanh(mean) mean = torch.tanh(mean)
std = F.softplus(std) + self._min_std std = F.softplus(std) + self._min_std
dist = torchd.normal.Normal(mean, std) dist = torchd.normal.Normal(mean, std)
@ -690,7 +690,7 @@ class MLP(nn.Module):
) )
dist = torchd.independent.Independent(dist, 1) dist = torchd.independent.Independent(dist, 1)
dist = tools.SampleDist(dist) dist = tools.SampleDist(dist)
elif self._dist == "normal": elif dist == "normal":
std = (self._max_std - self._min_std) * torch.sigmoid( std = (self._max_std - self._min_std) * torch.sigmoid(
std + 2.0 std + 2.0
) + self._min_std ) + self._min_std
@ -698,21 +698,21 @@ class MLP(nn.Module):
dist = tools.ContDist( dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax torchd.independent.Independent(dist, 1), absmax=self._absmax
) )
elif self._dist == "normal_std_fixed": elif dist == "normal_std_fixed":
dist = torchd.normal.Normal(mean, self._std) dist = torchd.normal.Normal(mean, self._std)
dist = tools.ContDist( dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax torchd.independent.Independent(dist, 1), absmax=self._absmax
) )
elif self._dist == "trunc_normal": elif dist == "trunc_normal":
mean = torch.tanh(mean) mean = torch.tanh(mean)
std = 2 * torch.sigmoid(std / 2) + self._min_std std = 2 * torch.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1) dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.ContDist( dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax torchd.independent.Independent(dist, 1), absmax=self._absmax
) )
elif self._dist == "onehot": elif dist == "onehot":
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio) dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble": elif dist == "onehot_gumble":
dist = tools.ContDist( dist = tools.ContDist(
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
) )

View File

@ -441,7 +441,7 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
def sample(self, sample_shape=(), seed=None): def sample(self, sample_shape=(), seed=None):
if seed is not None: if seed is not None:
raise ValueError("need to check") raise ValueError("need to check")
sample = super().sample(sample_shape) sample = super().sample(sample_shape).detach()
probs = super().probs probs = super().probs
while len(probs.shape) < len(sample.shape): while len(probs.shape) < len(sample.shape):
probs = probs[None] probs = probs[None]