limit action values in sampling stage
parent a9e85e8b7c
commit a27711ab96
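For context: this commit stops clamping prev_action inside RSSM.obs_step and RSSM.img_step and instead limits values where they are produced, through a new absmax argument on tools.ContDist that networks.MLP threads through to its continuous heads (the actor head passes absmax=1.0). Below is a minimal sketch of the rescaling ContDist applies after this change; the tensor values are only illustrative and torch is the usual PyTorch import.

    import torch

    # Rescale values into [-absmax, absmax]: entries already inside the bound are
    # untouched, larger ones are scaled down. The scale factor is detached, so
    # gradients still flow through the raw values themselves.
    absmax = 1.0
    x = torch.tensor([-2.0, -0.5, 0.3, 1.5])
    out = x * (absmax / torch.clip(torch.abs(x), min=absmax)).detach()
    print(out)  # tensor([-1.0000, -0.5000,  0.3000,  1.0000])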
@@ -239,9 +239,10 @@ class ImagBehavior(nn.Module):
             "learned",
             config.actor["min_std"],
             config.actor["max_std"],
-            config.actor["temp"],
+            absmax=1.0,
+            temp=config.actor["temp"],
             unimix_ratio=config.actor["unimix_ratio"],
-            outscale=1.0,
+            outscale=config.actor["outscale"],
             name="Actor",
         )
         self.value = networks.MLP(
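The keyword arguments in the hunk above are read from config.actor. A hedged sketch of the entries that call expects, with values that mirror the MLP defaults visible in the networks.py hunks below (they are only illustrative, not taken from the repo's config files):

    # Illustrative stand-in for the actor section of the config (a plain dict here).
    actor = {
        "min_std": 0.1,
        "max_std": 1.0,
        "temp": 0.1,
        "unimix_ratio": 0.01,
        "outscale": 1.0,
    }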
networks.py (30 changed lines)
@@ -200,9 +200,8 @@ class RSSM(nn.Module):
         return dist

     def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
-        # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
+        # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
         # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
-        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()

         # initialize all prev_state
         if prev_state == None or torch.sum(is_first) == len(is_first):
@@ -246,7 +245,6 @@ class RSSM(nn.Module):
     # this is used for making future image
     def img_step(self, prev_state, prev_action, embed=None, sample=True):
         # (batch, stoch, discrete_num)
-        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
         prev_stoch = prev_state["stoch"]
         if self._discrete:
             shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@@ -644,6 +642,7 @@ class MLP(nn.Module):
         std=1.0,
         min_std=0.1,
         max_std=1.0,
+        absmax=None,
         temp=0.1,
         unimix_ratio=0.01,
         outscale=1.0,
@@ -660,12 +659,13 @@ class MLP(nn.Module):
         norm = getattr(torch.nn, norm)
         self._dist = dist
         self._std = std
-        self._symlog_inputs = symlog_inputs
-        self._device = device
         self._min_std = min_std
         self._max_std = max_std
+        self._absmax = absmax
         self._temp = temp
         self._unimix_ratio = unimix_ratio
+        self._symlog_inputs = symlog_inputs
+        self._device = device

         self.layers = nn.Sequential()
         for index in range(self._layers):
@@ -738,23 +738,33 @@ class MLP(nn.Module):
                 std + 2.0
             ) + self._min_std
             dist = torchd.normal.Normal(torch.tanh(mean), std)
-            dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
+            dist = tools.ContDist(
+                torchd.independent.Independent(dist, 1), absmax=self._absmax
+            )
         elif self._dist == "normal_std_fixed":
             dist = torchd.normal.Normal(mean, self._std)
-            dist = tools.ContDist(torchd.independent.Independent(dist, 1))
+            dist = tools.ContDist(
+                torchd.independent.Independent(dist, 1), absmax=self._absmax
+            )
         elif self._dist == "trunc_normal":
             mean = torch.tanh(mean)
             std = 2 * torch.sigmoid(std / 2) + self._min_std
             dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
-            dist = tools.ContDist(torchd.independent.Independent(dist, 1))
+            dist = tools.ContDist(
+                torchd.independent.Independent(dist, 1), absmax=self._absmax
+            )
         elif self._dist == "onehot":
             dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
         elif self._dist == "onehot_gumble":
-            dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
+            dist = tools.ContDist(
+                torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
+            )
         elif dist == "huber":
             dist = tools.ContDist(
                 torchd.independent.Independent(
-                    tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
+                    tools.UnnormalizedHuber(mean, std, 1.0),
+                    len(shape),
+                    absmax=self._absmax,
                 )
             )
         elif dist == "binary":
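A condensed, hedged sketch of what the patched branches now do: build a base distribution and wrap it in tools.ContDist with the head's absmax. The helper mirrors only the "normal_std_fixed" branch; build_fixed_std_head is an illustrative name, not a function in the repo, and it assumes the repo's tools module (patched in the tools.py hunks below) is importable.

    import torch
    import torch.distributions as torchd

    import tools  # the repo module patched in the tools.py hunks below

    def build_fixed_std_head(mean, std, absmax=None):
        # Mirrors the "normal_std_fixed" branch: a diagonal Normal whose samples
        # and mode are rescaled into [-absmax, absmax] by ContDist when absmax is
        # set; with absmax=None (the default) behavior is unchanged.
        base = torchd.independent.Independent(torchd.normal.Normal(mean, std), 1)
        return tools.ContDist(base, absmax=absmax)

    head = build_fixed_std_head(torch.zeros(6), 0.3, absmax=1.0)
    action = head.sample()  # every element lies in [-1.0, 1.0]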
tools.py (13 changed lines)
@@ -562,10 +562,11 @@ class SymlogDist:


 class ContDist:
-    def __init__(self, dist=None):
+    def __init__(self, dist=None, absmax=None):
         super().__init__()
         self._dist = dist
         self.mean = dist.mean
+        self.absmax = absmax

     def __getattr__(self, name):
         return getattr(self._dist, name)
@@ -574,10 +575,16 @@ class ContDist:
         return self._dist.entropy()

     def mode(self):
-        return self._dist.mean
+        out = self._dist.mean
+        if self.absmax is not None:
+            out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
+        return out

     def sample(self, sample_shape=()):
-        return self._dist.rsample(sample_shape)
+        out = self._dist.rsample(sample_shape)
+        if self.absmax is not None:
+            out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
+        return out

     def log_prob(self, x):
         return self._dist.log_prob(x)
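A small usage sketch of the patched ContDist showing why the scale factor is detached: the limiting only rescales the output values, while gradients still reach the distribution parameters, and log_prob still delegates to the wrapped distribution unchanged. The values are illustrative; torchd stands for torch.distributions, as in the repo.

    import torch
    import torch.distributions as torchd

    import tools  # patched as in the hunks above

    mean = torch.zeros(3, requires_grad=True)
    base = torchd.independent.Independent(torchd.normal.Normal(mean, 1.0), 1)
    dist = tools.ContDist(base, absmax=1.0)

    action = dist.sample()      # rsample under the hood, rescaled into [-1, 1]
    logp = dist.log_prob(action)  # computed by the wrapped Normal, no rescaling
    action.sum().backward()     # gradients still reach mean because the scale is detached
    print(mean.grad)            # each entry equals the (detached) scale applied to that element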