diff --git a/exploration.py b/exploration.py
index bb2e60b..56e10ba 100644
--- a/exploration.py
+++ b/exploration.py
@@ -58,9 +58,9 @@ class Plan2Explore(nn.Module):
             "feat": config.dyn_stoch + config.dyn_deter,
         }[self._config.disag_target]
         kw = dict(
-            inp_dim=feat_size + config.num_actions
+            inp_dim=feat_size + (config.num_actions
             if config.disag_action_cond
-            else 0,  # pytorch version
+            else 0),  # pytorch version
             shape=size,
             layers=config.disag_layers,
             units=config.disag_units,
@@ -72,7 +72,7 @@ class Plan2Explore(nn.Module):
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
         self._model_opt = tools.Optimizer(
             "explorer",
-            self.parameters(),
+            self._networks.parameters(),
             config.model_lr,
             config.opt_eps,
             config.grad_clip,
@@ -80,7 +80,7 @@ class Plan2Explore(nn.Module):
         )
 
     def train(self, start, context, data):
-        with tools.RequiresGrad(self):
+        with tools.RequiresGrad(self._networks):
             metrics = {}
             stoch = start["stoch"]
             if self._config.dyn_discrete:
@@ -99,7 +99,7 @@ class Plan2Explore(nn.Module):
                     [inputs, torch.Tensor(data["action"]).to(self._config.device)], -1
                 )
             metrics.update(self._train_ensemble(inputs, target))
-            metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
+        metrics.update(self._behavior._train(start, self._intrinsic_reward)[-1])
         return None, metrics
 
     def _intrinsic_reward(self, feat, state, action):
@@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
                 [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
             )
             loss = -torch.mean(likes)
-        metrics = self._model_opt(loss, self.parameters())
+        metrics = self._model_opt(loss, self._networks.parameters())
         return metrics