diff --git a/dreamer.py b/dreamer.py index 12ef8a2..36eb633 100644 --- a/dreamer.py +++ b/dreamer.py @@ -55,10 +55,10 @@ class Dreamer(nn.Module): self._task_behavior = models.ImagBehavior( config, self._wm, config.behavior_stop_grad ) - if config.compile: + if config.compile and os.name != 'nt': # compilation is not supported on windows self._wm = torch.compile(self._wm) self._task_behavior = torch.compile(self._task_behavior) - reward = lambda f, s, a: self._wm.heads["reward"](f).mean + reward = lambda f, s, a: self._wm.heads["reward"](f).mean() self._expl_behavior = dict( greedy=lambda: self._task_behavior, random=lambda: expl.Random(config, act_space), diff --git a/exploration.py b/exploration.py index f195bb8..5eefbf9 100644 --- a/exploration.py +++ b/exploration.py @@ -58,7 +58,7 @@ class Plan2Explore(nn.Module): "feat": config.dyn_stoch + config.dyn_deter, }[self._config.disag_target] kw = dict( - inp_dim=feat_size, # pytorch version + inp_dim=feat_size + (config.num_actions if config.disag_action_cond else 0), # pytorch version (actions added only when action-conditioned) shape=size, layers=config.disag_layers, units=config.disag_units, @@ -93,7 +93,7 @@ class Plan2Explore(nn.Module): }[self._config.disag_target] inputs = context["feat"] if self._config.disag_action_cond: - inputs = torch.concat([inputs, data["action"]], -1) + inputs = torch.concat([inputs, torch.Tensor(data["action"]).to(self._config.device)], -1) metrics.update(self._train_ensemble(inputs, target)) metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1]) return None, metrics @@ -110,10 +110,7 @@ class Plan2Explore(nn.Module): disag = torch.log(disag) reward = self._config.expl_intr_scale * disag if self._config.expl_extr_scale: - reward += torch.cast( - self._config.expl_extr_scale * self._reward(feat, state, action), - torch.float32, - ) + reward += self._config.expl_extr_scale * self._reward(feat, state, action) return reward def _train_ensemble(self, inputs, targets): diff --git a/networks.py b/networks.py index 
9c58faf..3a767fa 100644 --- a/networks.py +++ b/networks.py @@ -215,7 +215,7 @@ class RSSM(nn.Module): is_first, is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)), ) - val = val * (1.0 - is_first_r) + init_state[key] * is_first_r + prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r prior = self.img_step(prev_state, prev_action, None, sample) if self._shared: