From a27711ab9645762ca1e664b16eade07dde1755b9 Mon Sep 17 00:00:00 2001
From: NM512
Date: Fri, 5 Jan 2024 11:42:45 +0900
Subject: [PATCH] limit action values in sampling stage

---
 models.py   |  5 +++--
 networks.py | 32 +++++++++++++++++++-----------
 tools.py    | 13 ++++++++++---
 3 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/models.py b/models.py
index 69f6b53..ace95b4 100644
--- a/models.py
+++ b/models.py
@@ -239,9 +239,10 @@ class ImagBehavior(nn.Module):
             "learned",
             config.actor["min_std"],
             config.actor["max_std"],
-            config.actor["temp"],
+            absmax=1.0,
+            temp=config.actor["temp"],
             unimix_ratio=config.actor["unimix_ratio"],
-            outscale=1.0,
+            outscale=config.actor["outscale"],
             name="Actor",
         )
         self.value = networks.MLP(
diff --git a/networks.py b/networks.py
index 3616e2e..1ad5772 100644
--- a/networks.py
+++ b/networks.py
@@ -200,9 +200,8 @@ class RSSM(nn.Module):
         return dist
 
     def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
-        # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
+        # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
         # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
-        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
 
         # initialize all prev_state
         if prev_state == None or torch.sum(is_first) == len(is_first):
@@ -246,7 +245,6 @@ class RSSM(nn.Module):
     # this is used for making future image
     def img_step(self, prev_state, prev_action, embed=None, sample=True):
         # (batch, stoch, discrete_num)
-        prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
         prev_stoch = prev_state["stoch"]
         if self._discrete:
             shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@@ -644,6 +642,7 @@ class MLP(nn.Module):
         std=1.0,
         min_std=0.1,
         max_std=1.0,
+        absmax=None,
         temp=0.1,
         unimix_ratio=0.01,
         outscale=1.0,
@@ -660,12 +659,13 @@ class MLP(nn.Module):
         norm = getattr(torch.nn, norm)
         self._dist = dist
         self._std = std
-        self._symlog_inputs = symlog_inputs
-        self._device = device
         self._min_std = min_std
         self._max_std = max_std
+        self._absmax = absmax
         self._temp = temp
         self._unimix_ratio = unimix_ratio
+        self._symlog_inputs = symlog_inputs
+        self._device = device
 
         self.layers = nn.Sequential()
         for index in range(self._layers):
@@ -738,23 +738,33 @@ class MLP(nn.Module):
                 std + 2.0
             ) + self._min_std
             dist = torchd.normal.Normal(torch.tanh(mean), std)
-            dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
+            dist = tools.ContDist(
+                torchd.independent.Independent(dist, 1), absmax=self._absmax
+            )
         elif self._dist == "normal_std_fixed":
             dist = torchd.normal.Normal(mean, self._std)
-            dist = tools.ContDist(torchd.independent.Independent(dist, 1))
+            dist = tools.ContDist(
+                torchd.independent.Independent(dist, 1), absmax=self._absmax
+            )
         elif self._dist == "trunc_normal":
             mean = torch.tanh(mean)
             std = 2 * torch.sigmoid(std / 2) + self._min_std
             dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
-            dist = tools.ContDist(torchd.independent.Independent(dist, 1))
+            dist = tools.ContDist(
+                torchd.independent.Independent(dist, 1), absmax=self._absmax
+            )
         elif self._dist == "onehot":
             dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
         elif self._dist == "onehot_gumble":
-            dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
+            dist = tools.ContDist(
+                torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
+            )
         elif dist == "huber":
             dist = tools.ContDist(
                 torchd.independent.Independent(
-                    tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
-                )
+                    tools.UnnormalizedHuber(mean, std, 1.0),
+                    len(shape),
+                ),
+                absmax=self._absmax,
             )
         elif dist == "binary":
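The networks.py hunks above only thread the new absmax argument through to each ContDist; the clipping itself lives in the tools.py diff below. A minimal standalone sketch of that rescaling trick, assuming only PyTorch (the soft_clip name is illustrative, not part of the patch):

import torch

def soft_clip(out, absmax=1.0):
    # Values with |out| <= absmax are multiplied by exactly 1.0;
    # larger values are rescaled onto [-absmax, absmax].
    scale = absmax / torch.clip(torch.abs(out), min=absmax)
    # Detaching keeps the rescale out of the autograd graph, so the
    # backward pass sees a constant factor: gradients are damped for
    # out-of-range values but never zeroed, unlike a hard clamp.
    return out * scale.detach()

out = torch.tensor([-2.0, -0.5, 3.0], requires_grad=True)
print(soft_clip(out))             # tensor([-1.0000, -0.5000, 1.0000], ...)
soft_clip(out).sum().backward()
print(out.grad)                   # tensor([0.5000, 1.0000, 0.3333])
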
"huber": dist = tools.ContDist( torchd.independent.Independent( - tools.UnnormalizedHuber(mean, std, 1.0), len(shape) + tools.UnnormalizedHuber(mean, std, 1.0), + len(shape), + absmax=self._absmax, ) ) elif dist == "binary": diff --git a/tools.py b/tools.py index 1aff067..cb09056 100644 --- a/tools.py +++ b/tools.py @@ -562,10 +562,11 @@ class SymlogDist: class ContDist: - def __init__(self, dist=None): + def __init__(self, dist=None, absmax=None): super().__init__() self._dist = dist self.mean = dist.mean + self.absmax = absmax def __getattr__(self, name): return getattr(self._dist, name) @@ -574,10 +575,16 @@ class ContDist: return self._dist.entropy() def mode(self): - return self._dist.mean + out = self._dist.mean + if self.absmax is not None: + out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach() + return out def sample(self, sample_shape=()): - return self._dist.rsample(sample_shape) + out = self._dist.rsample(sample_shape) + if self.absmax is not None: + out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach() + return out def log_prob(self, x): return self._dist.log_prob(x)