clean code
This commit is contained in:
parent
4e50f302cd
commit
59939222d1
12
networks.py
12
networks.py
@ -681,7 +681,7 @@ class MLP(nn.Module):
|
|||||||
return self.dist(self._dist, mean, std, self._shape)
|
return self.dist(self._dist, mean, std, self._shape)
|
||||||
|
|
||||||
def dist(self, dist, mean, std, shape):
|
def dist(self, dist, mean, std, shape):
|
||||||
if self._dist == "tanh_normal":
|
if dist == "tanh_normal":
|
||||||
mean = torch.tanh(mean)
|
mean = torch.tanh(mean)
|
||||||
std = F.softplus(std) + self._min_std
|
std = F.softplus(std) + self._min_std
|
||||||
dist = torchd.normal.Normal(mean, std)
|
dist = torchd.normal.Normal(mean, std)
|
||||||
@ -690,7 +690,7 @@ class MLP(nn.Module):
|
|||||||
)
|
)
|
||||||
dist = torchd.independent.Independent(dist, 1)
|
dist = torchd.independent.Independent(dist, 1)
|
||||||
dist = tools.SampleDist(dist)
|
dist = tools.SampleDist(dist)
|
||||||
elif self._dist == "normal":
|
elif dist == "normal":
|
||||||
std = (self._max_std - self._min_std) * torch.sigmoid(
|
std = (self._max_std - self._min_std) * torch.sigmoid(
|
||||||
std + 2.0
|
std + 2.0
|
||||||
) + self._min_std
|
) + self._min_std
|
||||||
@ -698,21 +698,21 @@ class MLP(nn.Module):
|
|||||||
dist = tools.ContDist(
|
dist = tools.ContDist(
|
||||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||||
)
|
)
|
||||||
elif self._dist == "normal_std_fixed":
|
elif dist == "normal_std_fixed":
|
||||||
dist = torchd.normal.Normal(mean, self._std)
|
dist = torchd.normal.Normal(mean, self._std)
|
||||||
dist = tools.ContDist(
|
dist = tools.ContDist(
|
||||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||||
)
|
)
|
||||||
elif self._dist == "trunc_normal":
|
elif dist == "trunc_normal":
|
||||||
mean = torch.tanh(mean)
|
mean = torch.tanh(mean)
|
||||||
std = 2 * torch.sigmoid(std / 2) + self._min_std
|
std = 2 * torch.sigmoid(std / 2) + self._min_std
|
||||||
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
|
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
|
||||||
dist = tools.ContDist(
|
dist = tools.ContDist(
|
||||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||||
)
|
)
|
||||||
elif self._dist == "onehot":
|
elif dist == "onehot":
|
||||||
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
|
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
|
||||||
elif self._dist == "onehot_gumble":
|
elif dist == "onehot_gumble":
|
||||||
dist = tools.ContDist(
|
dist = tools.ContDist(
|
||||||
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
|
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
|
||||||
)
|
)
|
||||||
|
2
tools.py
2
tools.py
@ -441,7 +441,7 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
|||||||
def sample(self, sample_shape=(), seed=None):
|
def sample(self, sample_shape=(), seed=None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
raise ValueError("need to check")
|
raise ValueError("need to check")
|
||||||
sample = super().sample(sample_shape)
|
sample = super().sample(sample_shape).detach()
|
||||||
probs = super().probs
|
probs = super().probs
|
||||||
while len(probs.shape) < len(sample.shape):
|
while len(probs.shape) < len(sample.shape):
|
||||||
probs = probs[None]
|
probs = probs[None]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user