diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 5a4f663..e52af9e 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -24,6 +24,7 @@ class DQNPolicy(BasePolicy):
         you do not use the target network). Default to 0.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param bool is_double: use double dqn. Default to True.
 
     .. seealso::
 
@@ -39,6 +40,7 @@ class DQNPolicy(BasePolicy):
         estimation_step: int = 1,
         target_update_freq: int = 0,
         reward_normalization: bool = False,
+        is_double: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -56,6 +58,7 @@ class DQNPolicy(BasePolicy):
             self.model_old = deepcopy(self.model)
             self.model_old.eval()
         self._rew_norm = reward_normalization
+        self._is_double = is_double
 
     def set_eps(self, eps: float) -> None:
         """Set the eps for epsilon-greedy exploration."""
@@ -79,8 +82,10 @@ class DQNPolicy(BasePolicy):
             target_q = self(batch, model="model_old", input="obs_next").logits
         else:
             target_q = result.logits
-        target_q = target_q[np.arange(len(result.act)), result.act]
-        return target_q
+        if self._is_double:
+            return target_q[np.arange(len(result.act)), result.act]
+        else:  # Nature DQN, over estimate
+            return target_q.max(dim=1)[0]
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
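To illustrate what the new `is_double` flag changes in the target computation, here is a minimal standalone sketch. The `target_q` values and `act` indices below are made up for demonstration and are not taken from the library; only the two indexing expressions mirror the diff above.

```python
import numpy as np
import torch

# Hypothetical Q-values from the target ("old") network for 4 next-state
# observations and 3 actions (illustrative numbers only).
target_q = torch.tensor([[0.1, 0.9, 0.3],
                         [0.5, 0.2, 0.4],
                         [0.7, 0.6, 0.8],
                         [0.0, 1.0, 0.2]])
# Hypothetical greedy actions chosen by the online network on obs_next.
act = np.array([0, 1, 2, 0])

# is_double=True: evaluate the target network at the online network's argmax.
double_target = target_q[np.arange(len(act)), act]  # tensor([0.1, 0.2, 0.8, 0.0])

# is_double=False: Nature DQN takes the max over the target network's own
# estimates, which tends to over-estimate the return.
nature_target = target_q.max(dim=1)[0]  # tensor([0.9, 0.5, 0.8, 1.0])
```

From user code the behavior is toggled via the new constructor argument, e.g. `DQNPolicy(model, optim, is_double=False)` to fall back to the Nature DQN target; the default `is_double=True` keeps the Double DQN estimate.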