learnable initial state options for RSSM
This commit is contained in:
parent
1328ff1088
commit
0eb66997fb
@ -63,6 +63,7 @@ defaults:
|
||||
weight_decay: 0.0
|
||||
unimix_ratio: 0.01
|
||||
action_unimix_ratio: 0.01
|
||||
initial: 'learned'
|
||||
|
||||
# Training
|
||||
batch_size: 16
|
||||
|
@ -110,9 +110,10 @@ class Dreamer(nn.Module):
|
||||
)
|
||||
else:
|
||||
latent, action = state
|
||||
embed = self._wm.encoder(self._wm.preprocess(obs))
|
||||
obs = self._wm.preprocess(obs)
|
||||
embed = self._wm.encoder(obs)
|
||||
latent, _ = self._wm.dynamics.obs_step(
|
||||
latent, action, embed, self._config.collect_dyn_sample
|
||||
latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
|
||||
)
|
||||
if self._config.eval_state_mean:
|
||||
latent["stoch"] = latent["mean"]
|
||||
|
14
models.py
14
models.py
@ -66,6 +66,7 @@ class WorldModel(nn.Module):
|
||||
config.dyn_min_std,
|
||||
config.dyn_cell,
|
||||
config.unimix_ratio,
|
||||
config.initial,
|
||||
config.num_actions,
|
||||
embed_size,
|
||||
config.device,
|
||||
@ -95,6 +96,7 @@ class WorldModel(nn.Module):
|
||||
config.norm,
|
||||
dist=config.reward_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.heads["reward"] = networks.DenseHead(
|
||||
@ -106,6 +108,7 @@ class WorldModel(nn.Module):
|
||||
config.norm,
|
||||
dist=config.reward_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
self.heads["cont"] = networks.DenseHead(
|
||||
feat_size, # pytorch version
|
||||
@ -115,6 +118,7 @@ class WorldModel(nn.Module):
|
||||
config.act,
|
||||
config.norm,
|
||||
dist="binary",
|
||||
device=config.device,
|
||||
)
|
||||
for name in config.grad_heads:
|
||||
assert name in self.heads, name
|
||||
@ -140,7 +144,9 @@ class WorldModel(nn.Module):
|
||||
with tools.RequiresGrad(self):
|
||||
with torch.cuda.amp.autocast(self._use_amp):
|
||||
embed = self.encoder(data)
|
||||
post, prior = self.dynamics.observe(embed, data["action"])
|
||||
post, prior = self.dynamics.observe(
|
||||
embed, data["action"], data["is_first"]
|
||||
)
|
||||
kl_free = tools.schedule(self._config.kl_free, self._step)
|
||||
dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
|
||||
rep_scale = tools.schedule(self._config.rep_scale, self._step)
|
||||
@ -204,7 +210,9 @@ class WorldModel(nn.Module):
|
||||
data = self.preprocess(data)
|
||||
embed = self.encoder(data)
|
||||
|
||||
states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5])
|
||||
states, _ = self.dynamics.observe(
|
||||
embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
|
||||
)
|
||||
recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
|
||||
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
|
||||
init = {k: v[:, -1] for k, v in states.items()}
|
||||
@ -257,6 +265,7 @@ class ImagBehavior(nn.Module):
|
||||
config.norm,
|
||||
config.value_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
else:
|
||||
self.value = networks.DenseHead(
|
||||
@ -268,6 +277,7 @@ class ImagBehavior(nn.Module):
|
||||
config.norm,
|
||||
config.value_head,
|
||||
outscale=0.0,
|
||||
device=config.device,
|
||||
)
|
||||
if config.slow_value_target:
|
||||
self._slow_value = copy.deepcopy(self.value)
|
||||
|
52
networks.py
52
networks.py
@ -28,6 +28,7 @@ class RSSM(nn.Module):
|
||||
min_std=0.1,
|
||||
cell="gru",
|
||||
unimix_ratio=0.01,
|
||||
initial="learned",
|
||||
num_actions=None,
|
||||
embed=None,
|
||||
device=None,
|
||||
@ -48,6 +49,7 @@ class RSSM(nn.Module):
|
||||
self._std_act = std_act
|
||||
self._temp_post = temp_post
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._initial = initial
|
||||
self._embed = embed
|
||||
self._device = device
|
||||
|
||||
@ -112,6 +114,12 @@ class RSSM(nn.Module):
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
|
||||
if self._initial == "learned":
|
||||
self.W = torch.nn.Parameter(
|
||||
torch.zeros((1, self._deter), device=torch.device(self._device)),
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
def initial(self, batch_size):
|
||||
deter = torch.zeros(batch_size, self._deter).to(self._device)
|
||||
if self._discrete:
|
||||
@ -131,19 +139,27 @@ class RSSM(nn.Module):
|
||||
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
deter=deter,
|
||||
)
|
||||
return state
|
||||
if self._initial == "zeros":
|
||||
return state
|
||||
elif self._initial == "learned":
|
||||
state["deter"] = torch.tanh(self.W).repeat(batch_size, 1)
|
||||
state["stoch"] = self.get_stoch(state["deter"])
|
||||
return state
|
||||
else:
|
||||
raise NotImplementedError(self._initial)
|
||||
|
||||
def observe(self, embed, action, state=None):
|
||||
def observe(self, embed, action, is_first, state=None):
|
||||
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
|
||||
if state is None:
|
||||
state = self.initial(action.shape[0])
|
||||
# (batch, time, ch) -> (time, batch, ch)
|
||||
embed, action = swap(embed), swap(action)
|
||||
embed, action, is_first = swap(embed), swap(action), swap(is_first)
|
||||
# prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
|
||||
post, prior = tools.static_scan(
|
||||
lambda prev_state, prev_act, embed: self.obs_step(
|
||||
prev_state[0], prev_act, embed
|
||||
lambda prev_state, prev_act, embed, is_first: self.obs_step(
|
||||
prev_state[0], prev_act, embed, is_first
|
||||
),
|
||||
(action, embed),
|
||||
(action, embed, is_first),
|
||||
(state, state),
|
||||
)
|
||||
|
||||
@ -184,10 +200,22 @@ class RSSM(nn.Module):
|
||||
)
|
||||
return dist
|
||||
|
||||
def obs_step(self, prev_state, prev_action, embed, sample=True):
|
||||
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
|
||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
|
||||
if torch.sum(is_first) > 0:
|
||||
is_first = is_first[:, None]
|
||||
prev_action *= 1.0 - is_first
|
||||
init_state = self.initial(len(is_first))
|
||||
for key, val in prev_state.items():
|
||||
is_first_r = torch.reshape(
|
||||
is_first,
|
||||
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
|
||||
)
|
||||
val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||
|
||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||
if self._shared:
|
||||
post = self.img_step(prev_state, prev_action, embed, sample)
|
||||
@ -242,6 +270,12 @@ class RSSM(nn.Module):
|
||||
prior = {"stoch": stoch, "deter": deter, **stats}
|
||||
return prior
|
||||
|
||||
def get_stoch(self, deter):
|
||||
x = self._img_out_layers(deter)
|
||||
stats = self._suff_stats_layer("ims", x)
|
||||
dist = self.get_dist(stats)
|
||||
return dist.mode()
|
||||
|
||||
def _suff_stats_layer(self, name, x):
|
||||
if self._discrete:
|
||||
if name == "ims":
|
||||
@ -435,6 +469,7 @@ class DenseHead(nn.Module):
|
||||
dist="normal",
|
||||
std=1.0,
|
||||
outscale=1.0,
|
||||
device="cuda",
|
||||
):
|
||||
super(DenseHead, self).__init__()
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
@ -446,6 +481,7 @@ class DenseHead(nn.Module):
|
||||
self._norm = norm
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._device = device
|
||||
|
||||
layers = []
|
||||
for index in range(self._layers):
|
||||
@ -491,7 +527,7 @@ class DenseHead(nn.Module):
|
||||
)
|
||||
)
|
||||
if self._dist == "twohot_symlog":
|
||||
return tools.TwoHotDistSymlog(logits=mean)
|
||||
return tools.TwoHotDistSymlog(logits=mean, device=self._device)
|
||||
raise NotImplementedError(self._dist)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user