diff --git a/dreamer.py b/dreamer.py
index 36eb633..0832818 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -55,7 +55,9 @@ class Dreamer(nn.Module):
         self._task_behavior = models.ImagBehavior(
             config, self._wm, config.behavior_stop_grad
         )
-        if config.compile and os.name != 'nt':  # compilation is not supported on windows
+        if (
+            config.compile and os.name != "nt"
+        ):  # compilation is not supported on windows
             self._wm = torch.compile(self._wm)
             self._task_behavior = torch.compile(self._task_behavior)
         reward = lambda f, s, a: self._wm.heads["reward"](f).mean()
@@ -156,7 +158,6 @@ class Dreamer(nn.Module):
         post, context, mets = self._wm._train(data)
         metrics.update(mets)
         start = post
-        # start['deter'] (16, 64, 512)
         reward = lambda f, s, a: self._wm.heads["reward"](
             self._wm.dynamics.get_feat(s)
         ).mode()
diff --git a/exploration.py b/exploration.py
index 5eefbf9..bb2e60b 100644
--- a/exploration.py
+++ b/exploration.py
@@ -58,7 +58,9 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size + config.num_actions if config.disag_action_cond else 0,  # pytorch version
+            inp_dim=feat_size + config.num_actions
+            if config.disag_action_cond
+            else 0,  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
@@ -93,7 +95,9 @@ class Plan2Explore(nn.Module):
         }[self._config.disag_target]
         inputs = context["feat"]
         if self._config.disag_action_cond:
-            inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1)
+            inputs = torch.concat(
+                [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
+            )
         metrics.update(self._train_ensemble(inputs, target))
         metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics
diff --git a/models.py b/models.py
index 8402d10..8fadf7d 100644
--- a/models.py
+++ b/models.py
@@ -399,9 +399,6 @@ class ImagBehavior(nn.Module):
         if self._config.future_entropy and self._config.actor_state_entropy() > 0:
             reward += self._config.actor_state_entropy() * state_ent
         value = self.value(imag_feat).mode()
-        # value(15, 960, ch)
-        # action(15, 960, ch)
-        # discount(15, 960, ch)
         target = tools.lambda_return(
             reward[:-1],
             value[:-1],
diff --git a/networks.py b/networks.py
index 3a767fa..aade584 100644
--- a/networks.py
+++ b/networks.py
@@ -215,7 +215,9 @@ class RSSM(nn.Module):
                     is_first,
                     is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
                 )
-                prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r
+                prev_state[key] = (
+                    val * (1.0 - is_first_r) + init_state[key] * is_first_r
+                )
 
         prior = self.img_step(prev_state, prev_action, None, sample)
         if self._shared: