limit action values in sampling stage
This commit is contained in:
parent
a9e85e8b7c
commit
a27711ab96
@ -239,9 +239,10 @@ class ImagBehavior(nn.Module):
|
||||
"learned",
|
||||
config.actor["min_std"],
|
||||
config.actor["max_std"],
|
||||
config.actor["temp"],
|
||||
absmax=1.0,
|
||||
temp=config.actor["temp"],
|
||||
unimix_ratio=config.actor["unimix_ratio"],
|
||||
outscale=1.0,
|
||||
outscale=config.actor["outscale"],
|
||||
name="Actor",
|
||||
)
|
||||
self.value = networks.MLP(
|
||||
|
30
networks.py
30
networks.py
@ -200,9 +200,8 @@ class RSSM(nn.Module):
|
||||
return dist
|
||||
|
||||
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
|
||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
|
||||
# initialize all prev_state
|
||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||
@ -246,7 +245,6 @@ class RSSM(nn.Module):
|
||||
# this is used for making future image
|
||||
def img_step(self, prev_state, prev_action, embed=None, sample=True):
|
||||
# (batch, stoch, discrete_num)
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
prev_stoch = prev_state["stoch"]
|
||||
if self._discrete:
|
||||
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
|
||||
@ -644,6 +642,7 @@ class MLP(nn.Module):
|
||||
std=1.0,
|
||||
min_std=0.1,
|
||||
max_std=1.0,
|
||||
absmax=None,
|
||||
temp=0.1,
|
||||
unimix_ratio=0.01,
|
||||
outscale=1.0,
|
||||
@ -660,12 +659,13 @@ class MLP(nn.Module):
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._symlog_inputs = symlog_inputs
|
||||
self._device = device
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._absmax = absmax
|
||||
self._temp = temp
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._symlog_inputs = symlog_inputs
|
||||
self._device = device
|
||||
|
||||
self.layers = nn.Sequential()
|
||||
for index in range(self._layers):
|
||||
@ -738,23 +738,33 @@ class MLP(nn.Module):
|
||||
std + 2.0
|
||||
) + self._min_std
|
||||
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||
)
|
||||
elif self._dist == "normal_std_fixed":
|
||||
dist = torchd.normal.Normal(mean, self._std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||
)
|
||||
elif self._dist == "trunc_normal":
|
||||
mean = torch.tanh(mean)
|
||||
std = 2 * torch.sigmoid(std / 2) + self._min_std
|
||||
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||
)
|
||||
elif self._dist == "onehot":
|
||||
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
|
||||
elif self._dist == "onehot_gumble":
|
||||
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
|
||||
dist = tools.ContDist(
|
||||
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
|
||||
)
|
||||
elif dist == "huber":
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
|
||||
tools.UnnormalizedHuber(mean, std, 1.0),
|
||||
len(shape),
|
||||
absmax=self._absmax,
|
||||
)
|
||||
)
|
||||
elif dist == "binary":
|
||||
|
13
tools.py
13
tools.py
@ -562,10 +562,11 @@ class SymlogDist:
|
||||
|
||||
|
||||
class ContDist:
|
||||
def __init__(self, dist=None):
|
||||
def __init__(self, dist=None, absmax=None):
|
||||
super().__init__()
|
||||
self._dist = dist
|
||||
self.mean = dist.mean
|
||||
self.absmax = absmax
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._dist, name)
|
||||
@ -574,10 +575,16 @@ class ContDist:
|
||||
return self._dist.entropy()
|
||||
|
||||
def mode(self):
|
||||
return self._dist.mean
|
||||
out = self._dist.mean
|
||||
if self.absmax is not None:
|
||||
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
|
||||
return out
|
||||
|
||||
def sample(self, sample_shape=()):
|
||||
return self._dist.rsample(sample_shape)
|
||||
out = self._dist.rsample(sample_shape)
|
||||
if self.absmax is not None:
|
||||
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
|
||||
return out
|
||||
|
||||
def log_prob(self, x):
|
||||
return self._dist.log_prob(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user