clean code

parent 4e50f302cd
commit 59939222d1

networks.py (12 changed lines)

@@ -681,7 +681,7 @@ class MLP(nn.Module):
         return self.dist(self._dist, mean, std, self._shape)
 
     def dist(self, dist, mean, std, shape):
-        if self._dist == "tanh_normal":
+        if dist == "tanh_normal":
             mean = torch.tanh(mean)
             std = F.softplus(std) + self._min_std
             dist = torchd.normal.Normal(mean, std)
@@ -690,7 +690,7 @@ class MLP(nn.Module):
             )
             dist = torchd.independent.Independent(dist, 1)
             dist = tools.SampleDist(dist)
-        elif self._dist == "normal":
+        elif dist == "normal":
             std = (self._max_std - self._min_std) * torch.sigmoid(
                 std + 2.0
             ) + self._min_std
@@ -698,21 +698,21 @@ class MLP(nn.Module):
             dist = tools.ContDist(
                 torchd.independent.Independent(dist, 1), absmax=self._absmax
             )
-        elif self._dist == "normal_std_fixed":
+        elif dist == "normal_std_fixed":
             dist = torchd.normal.Normal(mean, self._std)
             dist = tools.ContDist(
                 torchd.independent.Independent(dist, 1), absmax=self._absmax
             )
-        elif self._dist == "trunc_normal":
+        elif dist == "trunc_normal":
             mean = torch.tanh(mean)
             std = 2 * torch.sigmoid(std / 2) + self._min_std
             dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
             dist = tools.ContDist(
                 torchd.independent.Independent(dist, 1), absmax=self._absmax
             )
-        elif self._dist == "onehot":
+        elif dist == "onehot":
             dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
-        elif self._dist == "onehot_gumble":
+        elif dist == "onehot_gumble":
             dist = tools.ContDist(
                 torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
             )
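
The networks.py hunks make every branch in MLP.dist test the dist argument instead of the self._dist attribute, so the helper builds whichever distribution the caller names while the default call still forwards self._dist. A minimal sketch of that dispatch pattern, under stated assumptions (TinyHead and its arguments are illustrative and not code from this repository):

import torch
import torch.distributions as torchd

class TinyHead:
    """Hypothetical stand-in for the MLP head; only the dispatch is shown."""

    def __init__(self, dist="normal", min_std=0.1, max_std=1.0):
        self._dist = dist
        self._min_std = min_std
        self._max_std = max_std

    def __call__(self, mean, std):
        # Mirrors `return self.dist(self._dist, mean, std, self._shape)`:
        # the configured distribution name is forwarded explicitly.
        return self.dist(self._dist, mean, std)

    def dist(self, dist, mean, std):
        # After the change the branches compare the `dist` argument, not the
        # instance attribute, so a caller may request a different head.
        if dist == "normal":
            std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std
            return torchd.independent.Independent(torchd.normal.Normal(mean, std), 1)
        raise NotImplementedError(dist)

head = TinyHead()
print(head(torch.zeros(3, 2), torch.zeros(3, 2)).sample().shape)  # torch.Size([3, 2])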

tools.py (2 changed lines)

@@ -441,7 +441,7 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
     def sample(self, sample_shape=(), seed=None):
         if seed is not None:
             raise ValueError("need to check")
-        sample = super().sample(sample_shape)
+        sample = super().sample(sample_shape).detach()
         probs = super().probs
         while len(probs.shape) < len(sample.shape):
             probs = probs[None]
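
The tools.py change detaches the categorical draw in OneHotDist.sample, so no gradient is taken through the hard one-hot sample itself; gradients are presumably restored by a straight-through step over the probabilities (the lines after this hunk are not shown in the diff). A small self-contained sketch of that straight-through pattern, with illustrative tensors that are not from tools.py:

import torch
import torch.distributions as torchd

logits = torch.zeros(4, requires_grad=True)
base = torchd.one_hot_categorical.OneHotCategorical(logits=logits)

sample = base.sample().detach()           # hard one-hot draw, no gradient path
probs = base.probs                        # differentiable class probabilities
sample = sample + probs - probs.detach()  # value of the hard sample,
                                          # gradient of the probabilities

weights = torch.arange(4.0)               # arbitrary downstream use of the sample
(sample * weights).sum().backward()
print(logits.grad)                        # non-zero: gradients reach the logits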