unified the place to initialize the latents

NM512 2024-01-05 10:09:13 +09:00
parent 49d12baa48
commit e0f2017e28
3 changed files with 12 additions and 21 deletions
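This commit moves latent/action initialization out of the callers and into the RSSM itself: Dreamer._policy now passes latent = action = None, and obs_step builds the zeroed state when there is no previous state or when every sequence starts fresh. Below is a minimal standalone sketch of that pattern; the function name, the toy deter/stoch sizes, and the stand-in initial() are assumptions for illustration, not the repo's code.

import torch

def init_latents_if_needed(prev_state, prev_action, is_first, num_actions, device="cpu"):
    # Stand-in for RSSM.initial(): a zeroed latent dict (deter/stoch sizes are made up).
    def initial(batch):
        return {
            "deter": torch.zeros(batch, 8, device=device),
            "stoch": torch.zeros(batch, 4, device=device),
        }

    batch = len(is_first)
    if prev_state is None or torch.sum(is_first) == len(is_first):
        # No carried-over state, or every sequence restarts: build everything from zeros.
        prev_state = initial(batch)
        prev_action = torch.zeros(batch, num_actions, device=device)
    elif torch.sum(is_first) > 0:
        # Only some sequences restart: zero those rows and splice in a fresh state.
        mask = 1.0 - is_first[:, None].float()
        prev_action = prev_action * mask
        init_state = initial(batch)
        prev_state = {
            k: v * mask + init_state[k] * (1.0 - mask) for k, v in prev_state.items()
        }
    return prev_state, prev_action

# Example: the policy path now starts from None and lets the dynamics do the work.
state, action = init_latents_if_needed(None, None, torch.tensor([True, True, True]), num_actions=6)

With this logic in one place, the per-caller reset and masking code removed in the first hunk below is no longer needed.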

@@ -59,15 +59,6 @@ class Dreamer(nn.Module):
     def __call__(self, obs, reset, state=None, training=True):
         step = self._step
-        if self._should_reset(step):
-            state = None
-        if state is not None and reset.any():
-            mask = 1 - reset
-            for key in state[0].keys():
-                for i in range(state[0][key].shape[0]):
-                    state[0][key][i] *= mask[i]
-            for i in range(len(state[1])):
-                state[1][i] *= mask[i]
         if training:
             steps = (
                 self._config.pretrain
@@ -96,11 +87,7 @@ class Dreamer(nn.Module):
     def _policy(self, obs, state, training):
         if state is None:
-            batch_size = len(obs["image"])
-            latent = self._wm.dynamics.initial(len(obs["image"]))
-            action = torch.zeros((batch_size, self._config.num_actions)).to(
-                self._config.device
-            )
+            latent = action = None
         else:
             latent, action = state
         obs = self._wm.preprocess(obs)

@@ -202,7 +202,7 @@ class WorldModel(nn.Module):
         ]
         reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
         init = {k: v[:, -1] for k, v in states.items()}
-        prior = self.dynamics.imagine(data["action"][:6, 5:], init)
+        prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
         openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
         reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
         # observed image is given until 5 steps
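Because the renamed imagine_with_action no longer defaults to a fresh initial state, callers such as the open-loop prediction above must pass the starting latents explicitly. The sketch below only illustrates that calling convention; the toy shapes and the stand-in rollout are assumptions, not the repo's actual dynamics.

import torch

def imagine_with_action_sketch(action, state):
    # Mirrors the new signature: the starting state is required, not defaulted.
    assert isinstance(state, dict), state
    outs = {k: [] for k in state}
    for t in range(action.shape[1]):
        # Toy transition: decay each latent and mix in the mean action (illustrative only).
        state = {k: 0.9 * v + action[:, t].mean(-1, keepdim=True) for k, v in state.items()}
        for k, v in state.items():
            outs[k].append(v)
    return {k: torch.stack(v, dim=1) for k, v in outs.items()}

# Last posterior of each of 6 toy sequences becomes the explicit starting state.
states = {"deter": torch.zeros(6, 5, 8), "stoch": torch.zeros(6, 5, 4)}
init = {k: v[:, -1] for k, v in states.items()}
prior = imagine_with_action_sketch(torch.randn(6, 10, 3), init)  # 10 open-loop steps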

@@ -51,6 +51,7 @@ class RSSM(nn.Module):
         self._temp_post = temp_post
         self._unimix_ratio = unimix_ratio
         self._initial = initial
+        self._num_actions = num_actions
         self._embed = embed
         self._device = device
@@ -151,8 +152,6 @@
     def observe(self, embed, action, is_first, state=None):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
-        if state is None:
-            state = self.initial(action.shape[0])
         # (batch, time, ch) -> (time, batch, ch)
         embed, action, is_first = swap(embed), swap(action), swap(is_first)
         # prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
@@ -169,10 +168,8 @@
         prior = {k: swap(v) for k, v in prior.items()}
         return post, prior
 
-    def imagine(self, action, state=None):
+    def imagine_with_action(self, action, state):
         swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
-        if state is None:
-            state = self.initial(action.shape[0])
         assert isinstance(state, dict), state
         action = action
         action = swap(action)
@@ -206,7 +203,14 @@
         # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
         prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
-        if torch.sum(is_first) > 0:
+        # initialize all prev_state
+        if prev_state == None or torch.sum(is_first) == len(is_first):
+            prev_state = self.initial(len(is_first))
+            prev_action = torch.zeros((len(is_first), self._num_actions)).to(
+                self._device
+            )
+        # overwrite the prev_state only where is_first=True
+        elif torch.sum(is_first) > 0:
             is_first = is_first[:, None]
             prev_action *= 1.0 - is_first
             init_state = self.initial(len(is_first))