limit action values in sampling stage

This commit is contained in:
NM512 2024-01-05 11:42:45 +09:00
parent a9e85e8b7c
commit a27711ab96
3 changed files with 33 additions and 15 deletions

View File

@ -239,9 +239,10 @@ class ImagBehavior(nn.Module):
"learned",
config.actor["min_std"],
config.actor["max_std"],
config.actor["temp"],
absmax=1.0,
temp=config.actor["temp"],
unimix_ratio=config.actor["unimix_ratio"],
outscale=1.0,
outscale=config.actor["outscale"],
name="Actor",
)
self.value = networks.MLP(

View File

@ -200,9 +200,8 @@ class RSSM(nn.Module):
return dist
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
# initialize all prev_state
if prev_state == None or torch.sum(is_first) == len(is_first):
@ -246,7 +245,6 @@ class RSSM(nn.Module):
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
# (batch, stoch, discrete_num)
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@ -644,6 +642,7 @@ class MLP(nn.Module):
std=1.0,
min_std=0.1,
max_std=1.0,
absmax=None,
temp=0.1,
unimix_ratio=0.01,
outscale=1.0,
@ -660,12 +659,13 @@ class MLP(nn.Module):
norm = getattr(torch.nn, norm)
self._dist = dist
self._std = std
self._symlog_inputs = symlog_inputs
self._device = device
self._min_std = min_std
self._max_std = max_std
self._absmax = absmax
self._temp = temp
self._unimix_ratio = unimix_ratio
self._symlog_inputs = symlog_inputs
self._device = device
self.layers = nn.Sequential()
for index in range(self._layers):
@ -738,23 +738,33 @@ class MLP(nn.Module):
std + 2.0
) + self._min_std
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "normal_std_fixed":
dist = torchd.normal.Normal(mean, self._std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "trunc_normal":
mean = torch.tanh(mean)
std = 2 * torch.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "onehot":
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
dist = tools.ContDist(
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
)
elif dist == "huber":
dist = tools.ContDist(
torchd.independent.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
tools.UnnormalizedHuber(mean, std, 1.0),
len(shape),
absmax=self._absmax,
)
)
elif dist == "binary":

View File

@ -562,10 +562,11 @@ class SymlogDist:
class ContDist:
def __init__(self, dist=None):
def __init__(self, dist=None, absmax=None):
super().__init__()
self._dist = dist
self.mean = dist.mean
self.absmax = absmax
def __getattr__(self, name):
return getattr(self._dist, name)
@ -574,10 +575,16 @@ class ContDist:
return self._dist.entropy()
def mode(self):
return self._dist.mean
out = self._dist.mean
if self.absmax is not None:
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
return out
def sample(self, sample_shape=()):
return self._dist.rsample(sample_shape)
out = self._dist.rsample(sample_shape)
if self.absmax is not None:
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
return out
def log_prob(self, x):
return self._dist.log_prob(x)