diff --git a/dreamer.py b/dreamer.py
index 6855c8c..218f6fd 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -59,15 +59,6 @@ class Dreamer(nn.Module):
 
     def __call__(self, obs, reset, state=None, training=True):
         step = self._step
-        if self._should_reset(step):
-            state = None
-        if state is not None and reset.any():
-            mask = 1 - reset
-            for key in state[0].keys():
-                for i in range(state[0][key].shape[0]):
-                    state[0][key][i] *= mask[i]
-            for i in range(len(state[1])):
-                state[1][i] *= mask[i]
         if training:
             steps = (
                 self._config.pretrain
@@ -96,11 +87,7 @@ class Dreamer(nn.Module):
 
     def _policy(self, obs, state, training):
         if state is None:
-            batch_size = len(obs["image"])
-            latent = self._wm.dynamics.initial(len(obs["image"]))
-            action = torch.zeros((batch_size, self._config.num_actions)).to(
-                self._config.device
-            )
+            latent = action = None
         else:
             latent, action = state
         obs = self._wm.preprocess(obs)
diff --git a/models.py b/models.py
index 388fb36..417fa6a 100644
--- a/models.py
+++ b/models.py
@@ -202,7 +202,7 @@ class WorldModel(nn.Module):
         ]
         reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
         init = {k: v[:, -1] for k, v in states.items()}
-        prior = self.dynamics.imagine(data["action"][:6, 5:], init)
+        prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
         openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
         reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
         # observed image is given until 5 steps
diff --git a/networks.py b/networks.py
index b1edfcc..38769b7 100644
--- a/networks.py
+++ b/networks.py
@@ -51,6 +51,7 @@ class RSSM(nn.Module):
         self._temp_post = temp_post
         self._unimix_ratio = unimix_ratio
         self._initial = initial
+        self._num_actions = num_actions
         self._embed = embed
         self._device = device
 
@@ -151,8 +152,6 @@ class RSSM(nn.Module):
 
     def observe(self, embed, action, is_first, state=None):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
-        if state is None:
-            state = self.initial(action.shape[0])
         # (batch, time, ch) -> (time, batch, ch)
         embed, action, is_first = swap(embed), swap(action), swap(is_first)
         # prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
@@ -169,10 +168,8 @@ class RSSM(nn.Module):
         prior = {k: swap(v) for k, v in prior.items()}
         return post, prior
 
-    def imagine(self, action, state=None):
+    def imagine_with_action(self, action, state):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
-        if state is None:
-            state = self.initial(action.shape[0])
         assert isinstance(state, dict), state
         action = action
         action = swap(action)
@@ -206,7 +203,14 @@ class RSSM(nn.Module):
         # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
         prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
 
-        if torch.sum(is_first) > 0:
+        # initialize all prev_state
+        if prev_state == None or torch.sum(is_first) == len(is_first):
+            prev_state = self.initial(len(is_first))
+            prev_action = torch.zeros((len(is_first), self._num_actions)).to(
+                self._device
+            )
+        # overwrite the prev_state only where is_first=True
+        elif torch.sum(is_first) > 0:
            is_first = is_first[:, None]
            prev_action *= 1.0 - is_first
            init_state = self.initial(len(is_first))
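
The net effect of the diff is to centralize episode-reset handling in `RSSM.obs_step`: callers such as `Dreamer._policy` and `RSSM.observe` no longer build an initial latent/action themselves and can simply pass `None`. Below is a minimal standalone sketch (not code from the repo; `obs_step_init` and the toy `initial` lambda are hypothetical names) restating the new branch that the last networks.py hunk adds, so its behavior can be checked in isolation:

```python
import torch

# Hypothetical helper: a standalone restatement of the initialization branch
# the networks.py hunk adds inside RSSM.obs_step.
def obs_step_init(prev_state, prev_action, is_first, initial, num_actions, device="cpu"):
    if prev_state is None or torch.sum(is_first) == len(is_first):
        # No carried state, or every batch row starts a new episode:
        # build a fresh state and a zero action for the whole batch.
        prev_state = initial(len(is_first))
        prev_action = torch.zeros((len(is_first), num_actions), device=device)
    elif torch.sum(is_first) > 0:
        # Mixed batch: zero the action only for rows where is_first=True;
        # obs_step then blends those rows of prev_state with initial().
        prev_action = prev_action * (1.0 - is_first[:, None])
    return prev_state, prev_action

# Toy stand-in for RSSM.initial: a zero deterministic state per batch row.
initial = lambda bs: {"deter": torch.zeros(bs, 4)}
state, action = obs_step_init(None, None, torch.ones(3), initial, num_actions=2)
print(state["deter"].shape, action.shape)  # torch.Size([3, 4]) torch.Size([3, 2])
```

Because resets are now resolved per batch row inside `obs_step`, the old `Dreamer.__call__` masking loop and the `state is None` fallbacks in `observe`/`imagine` become redundant, which is why the diff deletes them; the `imagine` -> `imagine_with_action` rename likewise makes the remaining caller pass an explicit state instead of relying on a default.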