Allow researchers to choose whether to use Double DQN (#368)

This commit is contained in:
Ark 2021-05-21 10:53:34 +08:00 committed by GitHub
parent 8f7bc65ac7
commit 655d5fb14f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -24,6 +24,7 @@ class DQNPolicy(BasePolicy):
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param bool is_double: use double dqn. Default to True.
.. seealso::
@ -39,6 +40,7 @@ class DQNPolicy(BasePolicy):
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
is_double: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@ -56,6 +58,7 @@ class DQNPolicy(BasePolicy):
self.model_old = deepcopy(self.model)
self.model_old.eval()
self._rew_norm = reward_normalization
self._is_double = is_double
def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""
@ -79,8 +82,10 @@ class DQNPolicy(BasePolicy):
target_q = self(batch, model="model_old", input="obs_next").logits
else:
target_q = result.logits
target_q = target_q[np.arange(len(result.act)), result.act]
return target_q
if self._is_double:
return target_q[np.arange(len(result.act)), result.act]
else: # Nature DQN, over estimate
return target_q.max(dim=1)[0]
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray