learnable initial state options for RSSM

parent 1328ff1088 · commit 0eb66997fb
configs.yaml

@@ -63,6 +63,7 @@ defaults:
   weight_decay: 0.0
   unimix_ratio: 0.01
   action_unimix_ratio: 0.01
+  initial: 'learned'
 
   # Training
   batch_size: 16
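The new `initial` key selects how the RSSM constructs its initial latent state: 'zeros' keeps the previous all-zero behaviour, while 'learned' derives the initial deterministic state from a trained parameter (see RSSM.initial in networks.py below; any other value raises NotImplementedError). Assuming the repo's usual pattern of exposing every config key as a command-line flag, the option can be overridden per run, e.g. --initial zeros.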
dreamer.py

@@ -110,9 +110,10 @@ class Dreamer(nn.Module):
             )
         else:
             latent, action = state
-        embed = self._wm.encoder(self._wm.preprocess(obs))
+        obs = self._wm.preprocess(obs)
+        embed = self._wm.encoder(obs)
         latent, _ = self._wm.dynamics.obs_step(
-            latent, action, embed, self._config.collect_dyn_sample
+            latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
         )
         if self._config.eval_state_mean:
             latent["stoch"] = latent["mean"]
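Splitting preprocess out of the encoder call keeps the preprocessed dict in scope so the policy can pass obs["is_first"] into obs_step. is_first is the usual episode-boundary flag: 1.0 on the first observation after a reset, 0.0 otherwise. A minimal sketch of one way such a flag can be derived from done flags (illustrative only, not the repo's collector code):

import torch

# "done" at t=2 ends an episode, so t=3 starts a new one: is_first is the
# done flag shifted right by one step, with the stream's first step marked too.
done = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0])
is_first = torch.cat([torch.ones(1), done[:-1]])  # -> [1., 0., 0., 1., 0.]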
models.py (14 changes)
@@ -66,6 +66,7 @@ class WorldModel(nn.Module):
             config.dyn_min_std,
             config.dyn_cell,
             config.unimix_ratio,
+            config.initial,
             config.num_actions,
             embed_size,
             config.device,
@@ -95,6 +96,7 @@ class WorldModel(nn.Module):
                 config.norm,
                 dist=config.reward_head,
                 outscale=0.0,
+                device=config.device,
             )
         else:
             self.heads["reward"] = networks.DenseHead(
@@ -106,6 +108,7 @@ class WorldModel(nn.Module):
                 config.norm,
                 dist=config.reward_head,
                 outscale=0.0,
+                device=config.device,
             )
         self.heads["cont"] = networks.DenseHead(
             feat_size,  # pytorch version
@@ -115,6 +118,7 @@ class WorldModel(nn.Module):
             config.act,
             config.norm,
             dist="binary",
+            device=config.device,
         )
         for name in config.grad_heads:
             assert name in self.heads, name
@@ -140,7 +144,9 @@ class WorldModel(nn.Module):
         with tools.RequiresGrad(self):
             with torch.cuda.amp.autocast(self._use_amp):
                 embed = self.encoder(data)
-                post, prior = self.dynamics.observe(embed, data["action"])
+                post, prior = self.dynamics.observe(
+                    embed, data["action"], data["is_first"]
+                )
                 kl_free = tools.schedule(self._config.kl_free, self._step)
                 dyn_scale = tools.schedule(self._config.dyn_scale, self._step)
                 rep_scale = tools.schedule(self._config.rep_scale, self._step)
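observe now consumes the per-step is_first flags stored in the replay batch, so sequences that cross an episode boundary reset the recurrent state mid-sequence instead of leaking state across episodes. A sketch of the expected call with illustrative shapes (the sizes are assumptions, not the repo's defaults):

import torch

batch, T, embed_dim, num_actions = 16, 64, 1536, 6  # illustrative sizes
embed = torch.randn(batch, T, embed_dim)
action = torch.randn(batch, T, num_actions)
is_first = torch.zeros(batch, T)  # set to 1.0 wherever an episode starts
# post, prior = world_model.dynamics.observe(embed, action, is_first)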
@@ -204,7 +210,9 @@ class WorldModel(nn.Module):
         data = self.preprocess(data)
         embed = self.encoder(data)
 
-        states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5])
+        states, _ = self.dynamics.observe(
+            embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
+        )
         recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6]
         reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
         init = {k: v[:, -1] for k, v in states.items()}
@@ -257,6 +265,7 @@ class ImagBehavior(nn.Module):
                 config.norm,
                 config.value_head,
                 outscale=0.0,
+                device=config.device,
             )
         else:
             self.value = networks.DenseHead(
@@ -268,6 +277,7 @@ class ImagBehavior(nn.Module):
                 config.norm,
                 config.value_head,
                 outscale=0.0,
+                device=config.device,
             )
         if config.slow_value_target:
             self._slow_value = copy.deepcopy(self.value)
networks.py (52 changes)
@@ -28,6 +28,7 @@ class RSSM(nn.Module):
         min_std=0.1,
         cell="gru",
         unimix_ratio=0.01,
+        initial="learned",
         num_actions=None,
         embed=None,
         device=None,
@@ -48,6 +49,7 @@
         self._std_act = std_act
         self._temp_post = temp_post
         self._unimix_ratio = unimix_ratio
+        self._initial = initial
         self._embed = embed
         self._device = device
 
@@ -112,6 +114,12 @@
             self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
             self._obs_stat_layer.apply(tools.weight_init)
 
+        if self._initial == "learned":
+            self.W = torch.nn.Parameter(
+                torch.zeros((1, self._deter), device=torch.device(self._device)),
+                requires_grad=True,
+            )
+
     def initial(self, batch_size):
         deter = torch.zeros(batch_size, self._deter).to(self._device)
         if self._discrete:
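With initial="learned", the RSSM owns a single 1 × deter parameter W. Starting it at zeros makes the learned variant initially identical to 'zeros'; it only drifts away as gradients from the world-model loss flow into it. A standalone sketch of the construction (sizes illustrative, not the repo's defaults):

import torch

deter, batch_size = 512, 16  # illustrative sizes
W = torch.nn.Parameter(torch.zeros(1, deter))  # trained jointly with the model
deter0 = torch.tanh(W).repeat(batch_size, 1)  # same learned state for every row

The tanh squashes the learned state into (-1, 1), the range a GRU hidden state occupies, so downstream layers see inputs at a familiar scale.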
@@ -131,19 +139,27 @@
                 stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
                 deter=deter,
             )
-        return state
+        if self._initial == "zeros":
+            return state
+        elif self._initial == "learned":
+            state["deter"] = torch.tanh(self.W).repeat(batch_size, 1)
+            state["stoch"] = self.get_stoch(state["deter"])
+            return state
+        else:
+            raise NotImplementedError(self._initial)
 
-    def observe(self, embed, action, state=None):
+    def observe(self, embed, action, is_first, state=None):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
         if state is None:
             state = self.initial(action.shape[0])
         # (batch, time, ch) -> (time, batch, ch)
-        embed, action = swap(embed), swap(action)
+        embed, action, is_first = swap(embed), swap(action), swap(is_first)
+        # prev_state[0] selects the posterior from obs_step's (posterior, prior) return
         post, prior = tools.static_scan(
-            lambda prev_state, prev_act, embed: self.obs_step(
-                prev_state[0], prev_act, embed
+            lambda prev_state, prev_act, embed, is_first: self.obs_step(
+                prev_state[0], prev_act, embed, is_first
             ),
-            (action, embed),
+            (action, embed, is_first),
             (state, state),
         )
 
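static_scan folds obs_step over the time axis, feeding each step's (posterior, prior) output back in as the next step's previous state; the (state, state) argument seeds both slots of that pair. A simplified model of the helper's behaviour in this call (a sketch, not the repo's implementation, which also re-stacks the outputs into tensors):

def static_scan_sketch(fn, inputs, start):
    last, outputs = start, []
    for xs in zip(*inputs):  # walk the leading (time) dimension
        last = fn(last, *xs)  # here fn returns (post, prior)
        outputs.append(last)
    return outputs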
@@ -184,10 +200,22 @@
             )
         return dist
 
-    def obs_step(self, prev_state, prev_action, embed, sample=True):
+    def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
         # if shared is True, prior and posterior share the same networks (inp_layers, _img_out_layers, _ims_stat_layer)
         # otherwise, the posterior uses a separate network (_obs_out_layers) with prior deter and embed as inputs
         prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
+
+        if torch.sum(is_first) > 0:
+            is_first = is_first[:, None]
+            prev_action *= 1.0 - is_first
+            init_state = self.initial(len(is_first))
+            for key, val in prev_state.items():
+                is_first_r = torch.reshape(
+                    is_first,
+                    is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
+                )
+                # write the reset state back; a bare `val = ...` would be a no-op
+                prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
+
         prior = self.img_step(prev_state, prev_action, None, sample)
         if self._shared:
             post = self.img_step(prev_state, prev_action, embed, sample)
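Where is_first is 1, the previous action is zeroed and every carried-over state tensor is blended toward the (possibly learned) initial state; rows with is_first = 0 pass through untouched, so one batch can mix mid-episode steps with freshly reset ones. Note the write-back on the last line of the loop: assigning only to the loop variable val would discard the reset, so the result must go into prev_state[key]. A tiny runnable demonstration of the masking arithmetic (shapes illustrative):

import torch

is_first = torch.tensor([[1.0], [0.0], [0.0]])  # row 0 starts a new episode
prev = torch.ones(3, 4)  # carried-over state, e.g. "deter"
init = torch.full((3, 4), 0.5)  # fresh initial state
mixed = prev * (1.0 - is_first) + init * is_first
# mixed[0] == init[0]; rows 1 and 2 equal prev, unchanged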
@@ -242,6 +270,12 @@
         prior = {"stoch": stoch, "deter": deter, **stats}
         return prior
 
+    def get_stoch(self, deter):
+        x = self._img_out_layers(deter)
+        stats = self._suff_stats_layer("ims", x)
+        dist = self.get_dist(stats)
+        return dist.mode()
+
     def _suff_stats_layer(self, name, x):
         if self._discrete:
             if name == "ims":
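get_stoch routes a deterministic state through the imagination head's sufficient-statistics layer and returns the resulting distribution's mode, so the stoch paired with the learned initial deter is exactly what img_step would predict from that deter, rather than an independently trained quantity.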
@@ -435,6 +469,7 @@ class DenseHead(nn.Module):
         dist="normal",
         std=1.0,
         outscale=1.0,
+        device="cuda",
     ):
         super(DenseHead, self).__init__()
         self._shape = (shape,) if isinstance(shape, int) else shape
@@ -446,6 +481,7 @@
         self._norm = norm
         self._dist = dist
         self._std = std
+        self._device = device
 
         layers = []
         for index in range(self._layers):
@@ -491,7 +527,7 @@
             )
         )
         if self._dist == "twohot_symlog":
-            return tools.TwoHotDistSymlog(logits=mean)
+            return tools.TwoHotDistSymlog(logits=mean, device=self._device)
         raise NotImplementedError(self._dist)
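Threading device through DenseHead lets the two-hot head build its internals on the same device as the logits; presumably the distribution keeps a fixed grid of bucket centers, along the lines of (assumed internals, values illustrative):

import torch

device = "cpu"  # illustrative
buckets = torch.linspace(-20.0, 20.0, 255, device=device)  # assumed bucket grid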