unified the place to initialize the latents
This commit is contained in:
parent
49d12baa48
commit
e0f2017e28
15
dreamer.py
15
dreamer.py
@ -59,15 +59,6 @@ class Dreamer(nn.Module):
|
|||||||
|
|
||||||
def __call__(self, obs, reset, state=None, training=True):
|
def __call__(self, obs, reset, state=None, training=True):
|
||||||
step = self._step
|
step = self._step
|
||||||
if self._should_reset(step):
|
|
||||||
state = None
|
|
||||||
if state is not None and reset.any():
|
|
||||||
mask = 1 - reset
|
|
||||||
for key in state[0].keys():
|
|
||||||
for i in range(state[0][key].shape[0]):
|
|
||||||
state[0][key][i] *= mask[i]
|
|
||||||
for i in range(len(state[1])):
|
|
||||||
state[1][i] *= mask[i]
|
|
||||||
if training:
|
if training:
|
||||||
steps = (
|
steps = (
|
||||||
self._config.pretrain
|
self._config.pretrain
|
||||||
@ -96,11 +87,7 @@ class Dreamer(nn.Module):
|
|||||||
|
|
||||||
def _policy(self, obs, state, training):
|
def _policy(self, obs, state, training):
|
||||||
if state is None:
|
if state is None:
|
||||||
batch_size = len(obs["image"])
|
latent = action = None
|
||||||
latent = self._wm.dynamics.initial(len(obs["image"]))
|
|
||||||
action = torch.zeros((batch_size, self._config.num_actions)).to(
|
|
||||||
self._config.device
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
latent, action = state
|
latent, action = state
|
||||||
obs = self._wm.preprocess(obs)
|
obs = self._wm.preprocess(obs)
|
||||||
|
@ -202,7 +202,7 @@ class WorldModel(nn.Module):
|
|||||||
]
|
]
|
||||||
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
|
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
|
||||||
init = {k: v[:, -1] for k, v in states.items()}
|
init = {k: v[:, -1] for k, v in states.items()}
|
||||||
prior = self.dynamics.imagine(data["action"][:6, 5:], init)
|
prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
|
||||||
openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
|
openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
|
||||||
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
||||||
# observed image is given until 5 steps
|
# observed image is given until 5 steps
|
||||||
|
16
networks.py
16
networks.py
@ -51,6 +51,7 @@ class RSSM(nn.Module):
|
|||||||
self._temp_post = temp_post
|
self._temp_post = temp_post
|
||||||
self._unimix_ratio = unimix_ratio
|
self._unimix_ratio = unimix_ratio
|
||||||
self._initial = initial
|
self._initial = initial
|
||||||
|
self._num_actions = num_actions
|
||||||
self._embed = embed
|
self._embed = embed
|
||||||
self._device = device
|
self._device = device
|
||||||
|
|
||||||
@ -151,8 +152,6 @@ class RSSM(nn.Module):
|
|||||||
|
|
||||||
def observe(self, embed, action, is_first, state=None):
|
def observe(self, embed, action, is_first, state=None):
|
||||||
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
|
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
|
||||||
if state is None:
|
|
||||||
state = self.initial(action.shape[0])
|
|
||||||
# (batch, time, ch) -> (time, batch, ch)
|
# (batch, time, ch) -> (time, batch, ch)
|
||||||
embed, action, is_first = swap(embed), swap(action), swap(is_first)
|
embed, action, is_first = swap(embed), swap(action), swap(is_first)
|
||||||
# prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
|
# prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
|
||||||
@ -169,10 +168,8 @@ class RSSM(nn.Module):
|
|||||||
prior = {k: swap(v) for k, v in prior.items()}
|
prior = {k: swap(v) for k, v in prior.items()}
|
||||||
return post, prior
|
return post, prior
|
||||||
|
|
||||||
def imagine(self, action, state=None):
|
def imagine_with_action(self, action, state):
|
||||||
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
|
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
|
||||||
if state is None:
|
|
||||||
state = self.initial(action.shape[0])
|
|
||||||
assert isinstance(state, dict), state
|
assert isinstance(state, dict), state
|
||||||
action = action
|
action = action
|
||||||
action = swap(action)
|
action = swap(action)
|
||||||
@ -206,7 +203,14 @@ class RSSM(nn.Module):
|
|||||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||||
|
|
||||||
if torch.sum(is_first) > 0:
|
# initialize all prev_state
|
||||||
|
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||||
|
prev_state = self.initial(len(is_first))
|
||||||
|
prev_action = torch.zeros((len(is_first), self._num_actions)).to(
|
||||||
|
self._device
|
||||||
|
)
|
||||||
|
# overwrite the prev_state only where is_first=True
|
||||||
|
elif torch.sum(is_first) > 0:
|
||||||
is_first = is_first[:, None]
|
is_first = is_first[:, None]
|
||||||
prev_action *= 1.0 - is_first
|
prev_action *= 1.0 - is_first
|
||||||
init_state = self.initial(len(is_first))
|
init_state = self.initial(len(is_first))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user