From 6bf120237b04cfa1ed3e9b7caf0edc92f190281c Mon Sep 17 00:00:00 2001
From: WeberSamuel <37903003+WeberSamuel@users.noreply.github.com>
Date: Sat, 19 Aug 2023 09:42:54 +0200
Subject: [PATCH] Fix Plan2Explore

---
 exploration.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/exploration.py b/exploration.py
index bb2e60b..56e10ba 100644
--- a/exploration.py
+++ b/exploration.py
@@ -58,9 +58,9 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size + config.num_actions
+            inp_dim=feat_size + (config.num_actions
             if config.disag_action_cond
-            else 0,  # pytorch version
+            else 0),  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
@@ -72,7 +72,7 @@ class Plan2Explore(nn.Module):
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
         self._model_opt = tools.Optimizer(
             "explorer",
-            self.parameters(),
+            self._networks.parameters(),
             config.model_lr,
             config.opt_eps,
             config.grad_clip,
@@ -80,7 +80,7 @@ class Plan2Explore(nn.Module):
         )
 
     def train(self, start, context, data):
-        with tools.RequiresGrad(self):
+        with tools.RequiresGrad(self._networks):
             metrics = {}
             stoch = start["stoch"]
             if self._config.dyn_discrete:
@@ -99,7 +99,7 @@ class Plan2Explore(nn.Module):
                 [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
             )
             metrics.update(self._train_ensemble(inputs, target))
-            metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
+        metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics
 
     def _intrinsic_reward(self, feat, state, action):
@@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
             [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
         )
         loss = -torch.mean(likes)
-        metrics = self._model_opt(loss, self.parameters())
+        metrics = self._model_opt(loss, self._networks.parameters())
         return metrics
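
Note: the first hunk fixes an operator-precedence bug. Python's conditional
expression binds more loosely than `+`, so without parentheses the whole sum
`feat_size + config.num_actions` became the `if` branch, and `inp_dim`
silently collapsed to 0 whenever `disag_action_cond` was false. A minimal,
self-contained illustration (the values are made up; only the expression
shape mirrors the patch):

    # buggy: parses as (feat_size + num_actions) if action_cond else 0
    feat_size, num_actions, action_cond = 100, 6, False
    inp_dim = feat_size + num_actions if action_cond else 0
    assert inp_dim == 0  # feat_size is silently dropped

    # fixed, as in the hunk: only num_actions is conditional
    inp_dim = feat_size + (num_actions if action_cond else 0)
    assert inp_dim == 100

The remaining hunks narrow the optimizer, the `RequiresGrad` context, and the
gradient step in `_train_ensemble` from `self.parameters()` to
`self._networks.parameters()`, so the explorer optimizer only updates the
disagreement ensemble, while `self._behavior._train` runs outside the
ensemble's `RequiresGrad` block and updates the behavior policy with its own
optimizer.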