clean code

This commit is contained in:
NM512 2024-09-24 00:16:12 +09:00
parent 4e50f302cd
commit 59939222d1
2 changed files with 7 additions and 7 deletions

View File

@ -681,7 +681,7 @@ class MLP(nn.Module):
return self.dist(self._dist, mean, std, self._shape)
def dist(self, dist, mean, std, shape):
if self._dist == "tanh_normal":
if dist == "tanh_normal":
mean = torch.tanh(mean)
std = F.softplus(std) + self._min_std
dist = torchd.normal.Normal(mean, std)
@ -690,7 +690,7 @@ class MLP(nn.Module):
)
dist = torchd.independent.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == "normal":
elif dist == "normal":
std = (self._max_std - self._min_std) * torch.sigmoid(
std + 2.0
) + self._min_std
@ -698,21 +698,21 @@ class MLP(nn.Module):
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "normal_std_fixed":
elif dist == "normal_std_fixed":
dist = torchd.normal.Normal(mean, self._std)
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "trunc_normal":
elif dist == "trunc_normal":
mean = torch.tanh(mean)
std = 2 * torch.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "onehot":
elif dist == "onehot":
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
elif dist == "onehot_gumble":
dist = tools.ContDist(
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
)

View File

@ -441,7 +441,7 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
def sample(self, sample_shape=(), seed=None):
if seed is not None:
raise ValueError("need to check")
sample = super().sample(sample_shape)
sample = super().sample(sample_shape).detach()
probs = super().probs
while len(probs.shape) < len(sample.shape):
probs = probs[None]