Allow researchers to choose whether to use Double DQN (#368)
This commit is contained in:
parent
8f7bc65ac7
commit
655d5fb14f
@ -24,6 +24,7 @@ class DQNPolicy(BasePolicy):
|
||||
you do not use the target network). Default to 0.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param bool is_double: use double dqn. Default to True.
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -39,6 +40,7 @@ class DQNPolicy(BasePolicy):
|
||||
estimation_step: int = 1,
|
||||
target_update_freq: int = 0,
|
||||
reward_normalization: bool = False,
|
||||
is_double: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
@ -56,6 +58,7 @@ class DQNPolicy(BasePolicy):
|
||||
self.model_old = deepcopy(self.model)
|
||||
self.model_old.eval()
|
||||
self._rew_norm = reward_normalization
|
||||
self._is_double = is_double
|
||||
|
||||
def set_eps(self, eps: float) -> None:
|
||||
"""Set the eps for epsilon-greedy exploration."""
|
||||
@ -79,8 +82,10 @@ class DQNPolicy(BasePolicy):
|
||||
target_q = self(batch, model="model_old", input="obs_next").logits
|
||||
else:
|
||||
target_q = result.logits
|
||||
target_q = target_q[np.arange(len(result.act)), result.act]
|
||||
return target_q
|
||||
if self._is_double:
|
||||
return target_q[np.arange(len(result.act)), result.act]
|
||||
else: # Nature DQN, over estimate
|
||||
return target_q.max(dim=1)[0]
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
|
Loading…
x
Reference in New Issue
Block a user