unified the place to initialize the latents

parent 49d12baa48
commit e0f2017e28

dreamer.py (15 lines changed)
@@ -59,15 +59,6 @@ class Dreamer(nn.Module):
 
     def __call__(self, obs, reset, state=None, training=True):
         step = self._step
-        if self._should_reset(step):
-            state = None
-        if state is not None and reset.any():
-            mask = 1 - reset
-            for key in state[0].keys():
-                for i in range(state[0][key].shape[0]):
-                    state[0][key][i] *= mask[i]
-            for i in range(len(state[1])):
-                state[1][i] *= mask[i]
         if training:
             steps = (
                 self._config.pretrain
@@ -96,11 +87,7 @@ class Dreamer(nn.Module):
 
     def _policy(self, obs, state, training):
         if state is None:
-            batch_size = len(obs["image"])
-            latent = self._wm.dynamics.initial(len(obs["image"]))
-            action = torch.zeros((batch_size, self._config.num_actions)).to(
-                self._config.device
-            )
+            latent = action = None
         else:
             latent, action = state
         obs = self._wm.preprocess(obs)
@@ -202,7 +202,7 @@ class WorldModel(nn.Module):
         ]
         reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
         init = {k: v[:, -1] for k, v in states.items()}
-        prior = self.dynamics.imagine(data["action"][:6, 5:], init)
+        prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
         openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
         reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
         # observed image is given until 5 steps
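Note: with this change, _policy hands latent = action = None straight to the world
model on a fresh episode, and the RSSM is expected to build the initial latents
itself. A minimal sketch of the resulting call path, assuming an
obs_step(prev_state, prev_action, embed, is_first) entry point as suggested by the
networks.py hunks below (the encoder call is likewise an assumption, not shown in
this diff):

    latent = action = None                  # fresh episode: no state yet
    obs = wm.preprocess(obs)
    embed = wm.encoder(obs)                 # assumed encoder call
    # obs_step now initializes prev_state/prev_action internally when they are None
    latent, _ = wm.dynamics.obs_step(latent, action, embed, obs["is_first"])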

networks.py (16 lines changed)
@@ -51,6 +51,7 @@ class RSSM(nn.Module):
         self._temp_post = temp_post
         self._unimix_ratio = unimix_ratio
         self._initial = initial
+        self._num_actions = num_actions
         self._embed = embed
         self._device = device
 
@@ -151,8 +152,6 @@ class RSSM(nn.Module):
 
     def observe(self, embed, action, is_first, state=None):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
-        if state is None:
-            state = self.initial(action.shape[0])
         # (batch, time, ch) -> (time, batch, ch)
         embed, action, is_first = swap(embed), swap(action), swap(is_first)
         # prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
@@ -169,10 +168,8 @@ class RSSM(nn.Module):
         prior = {k: swap(v) for k, v in prior.items()}
         return post, prior
 
-    def imagine(self, action, state=None):
+    def imagine_with_action(self, action, state):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
-        if state is None:
-            state = self.initial(action.shape[0])
         assert isinstance(state, dict), state
         action = action
         action = swap(action)
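Usage sketch: since imagine_with_action no longer defaults to self.initial(),
callers must pass an explicit state dict, e.g. the last posterior of an observed
segment, exactly as the WorldModel hunk above now does:

    init = {k: v[:, -1] for k, v in states.items()}
    prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)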
@@ -206,7 +203,14 @@ class RSSM(nn.Module):
         # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
         prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
 
-        if torch.sum(is_first) > 0:
+        # initialize all prev_state
+        if prev_state == None or torch.sum(is_first) == len(is_first):
+            prev_state = self.initial(len(is_first))
+            prev_action = torch.zeros((len(is_first), self._num_actions)).to(
+                self._device
+            )
+        # overwrite the prev_state only where is_first=True
+        elif torch.sum(is_first) > 0:
             is_first = is_first[:, None]
             prev_action *= 1.0 - is_first
             init_state = self.initial(len(is_first))
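The elif branch above is where the hunk cuts off; in the full method the fresh
init_state still has to be blended into prev_state per batch element. A hedged
sketch of that remaining step (the blend itself is an assumption consistent with
the "overwrite the prev_state only where is_first=True" comment; broadcasting of
is_first over extra state dimensions is likewise assumed):

    for key, val in prev_state.items():
        # expand is_first to match the rank of each state tensor (assumed)
        mask = is_first.reshape(
            is_first.shape + (1,) * (len(val.shape) - len(is_first.shape))
        )
        # keep the old state where the episode continues, fresh state where it restarts
        prev_state[key] = val * (1.0 - mask) + init_state[key] * mask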