learnable initial state options for RSSM

This commit is contained in:
NM512 2023-04-29 07:54:03 +09:00
parent 1328ff1088
commit 0eb66997fb
4 changed files with 60 additions and 12 deletions

View File

@ -63,6 +63,7 @@ defaults:
weight_decay: 0.0
unimix_ratio: 0.01
action_unimix_ratio: 0.01
initial: 'learned'
# Training
batch_size: 16

View File

@ -110,9 +110,10 @@ class Dreamer(nn.Module):
)
else:
latent, action = state
embed = self._wm.encoder(self._wm.preprocess(obs))
obs = self._wm.preprocess(obs)
embed = self._wm.encoder(obs)
latent, _ = self._wm.dynamics.obs_step(
latent, action, embed, self._config.collect_dyn_sample
latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
)
if self._config.eval_state_mean:
latent["stoch"] = latent["mean"]

View File

@ -66,6 +66,7 @@ class WorldModel(nn.Module):
config.dyn_min_std,
config.dyn_cell,
config.unimix_ratio,
config.initial,
config.num_actions,
embed_size,
config.device,
@ -95,6 +96,7 @@ class WorldModel(nn.Module):
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
else:
self.heads["reward"] = networks.DenseHead(
@ -106,6 +108,7 @@ class WorldModel(nn.Module):
config.norm,
dist=config.reward_head,
outscale=0.0,
device=config.device,
)
self.heads["cont"] = networks.DenseHead(
feat_size, # pytorch version
@ -115,6 +118,7 @@ class WorldModel(nn.Module):
config.act,
config.norm,
dist="binary",
device=config.device,
)
for name in config.grad_heads:
assert name in self.heads, name
@ -140,7 +144,9 @@ class WorldModel(nn.Module):
with tools.RequiresGrad(self):
with torch.cuda.amp.autocast(self._use_amp):
embed = self.encoder(data)
post, prior = self.dynamics.observe(embed, data["action"])
post, prior = self.dynamics.observe(
embed, data["action"], data["is_first"]
)
kl_free = tools.schedule(self._config.kl_free, self._step)
dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
rep_scale = tools.schedule(self._config.rep_scale, self._step)
@ -204,7 +210,9 @@ class WorldModel(nn.Module):
data = self.preprocess(data)
embed = self.encoder(data)
states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5])
states, _ = self.dynamics.observe(
embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
)
recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
init = {k: v[:, -1] for k, v in states.items()}
@ -257,6 +265,7 @@ class ImagBehavior(nn.Module):
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
else:
self.value = networks.DenseHead(
@ -268,6 +277,7 @@ class ImagBehavior(nn.Module):
config.norm,
config.value_head,
outscale=0.0,
device=config.device,
)
if config.slow_value_target:
self._slow_value = copy.deepcopy(self.value)

View File

@ -28,6 +28,7 @@ class RSSM(nn.Module):
min_std=0.1,
cell="gru",
unimix_ratio=0.01,
initial="learned",
num_actions=None,
embed=None,
device=None,
@ -48,6 +49,7 @@ class RSSM(nn.Module):
self._std_act = std_act
self._temp_post = temp_post
self._unimix_ratio = unimix_ratio
self._initial = initial
self._embed = embed
self._device = device
@ -112,6 +114,12 @@ class RSSM(nn.Module):
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._obs_stat_layer.apply(tools.weight_init)
if self._initial == "learned":
self.W = torch.nn.Parameter(
torch.zeros((1, self._deter), device=torch.device(self._device)),
requires_grad=True,
)
def initial(self, batch_size):
deter = torch.zeros(batch_size, self._deter).to(self._device)
if self._discrete:
@ -131,19 +139,27 @@ class RSSM(nn.Module):
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
deter=deter,
)
return state
if self._initial == "zeros":
return state
elif self._initial == "learned":
state["deter"] = torch.tanh(self.W).repeat(batch_size, 1)
state["stoch"] = self.get_stoch(state["deter"])
return state
else:
raise NotImplementedError(self._initial)
def observe(self, embed, action, state=None):
def observe(self, embed, action, is_first, state=None):
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(action.shape[0])
# (batch, time, ch) -> (time, batch, ch)
embed, action = swap(embed), swap(action)
embed, action, is_first = swap(embed), swap(action), swap(is_first)
# prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
post, prior = tools.static_scan(
lambda prev_state, prev_act, embed: self.obs_step(
prev_state[0], prev_act, embed
lambda prev_state, prev_act, embed, is_first: self.obs_step(
prev_state[0], prev_act, embed, is_first
),
(action, embed),
(action, embed, is_first),
(state, state),
)
@ -184,10 +200,22 @@ class RSSM(nn.Module):
)
return dist
def obs_step(self, prev_state, prev_action, embed, sample=True):
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
if torch.sum(is_first) > 0:
is_first = is_first[:, None]
prev_action *= 1.0 - is_first
init_state = self.initial(len(is_first))
for key, val in prev_state.items():
is_first_r = torch.reshape(
is_first,
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
)
val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
@ -242,6 +270,12 @@ class RSSM(nn.Module):
prior = {"stoch": stoch, "deter": deter, **stats}
return prior
def get_stoch(self, deter):
x = self._img_out_layers(deter)
stats = self._suff_stats_layer("ims", x)
dist = self.get_dist(stats)
return dist.mode()
def _suff_stats_layer(self, name, x):
if self._discrete:
if name == "ims":
@ -435,6 +469,7 @@ class DenseHead(nn.Module):
dist="normal",
std=1.0,
outscale=1.0,
device="cuda",
):
super(DenseHead, self).__init__()
self._shape = (shape,) if isinstance(shape, int) else shape
@ -446,6 +481,7 @@ class DenseHead(nn.Module):
self._norm = norm
self._dist = dist
self._std = std
self._device = device
layers = []
for index in range(self._layers):
@ -491,7 +527,7 @@ class DenseHead(nn.Module):
)
)
if self._dist == "twohot_symlog":
return tools.TwoHotDistSymlog(logits=mean)
return tools.TwoHotDistSymlog(logits=mean, device=self._device)
raise NotImplementedError(self._dist)