Allow researchers to choose whether to use Double DQN (#368)
This commit is contained in:
parent
8f7bc65ac7
commit
655d5fb14f
@ -24,6 +24,7 @@ class DQNPolicy(BasePolicy):
|
|||||||
you do not use the target network). Default to 0.
|
you do not use the target network). Default to 0.
|
||||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||||
Default to False.
|
Default to False.
|
||||||
|
:param bool is_double: use double dqn. Default to True.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
@ -39,6 +40,7 @@ class DQNPolicy(BasePolicy):
|
|||||||
estimation_step: int = 1,
|
estimation_step: int = 1,
|
||||||
target_update_freq: int = 0,
|
target_update_freq: int = 0,
|
||||||
reward_normalization: bool = False,
|
reward_normalization: bool = False,
|
||||||
|
is_double: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -56,6 +58,7 @@ class DQNPolicy(BasePolicy):
|
|||||||
self.model_old = deepcopy(self.model)
|
self.model_old = deepcopy(self.model)
|
||||||
self.model_old.eval()
|
self.model_old.eval()
|
||||||
self._rew_norm = reward_normalization
|
self._rew_norm = reward_normalization
|
||||||
|
self._is_double = is_double
|
||||||
|
|
||||||
def set_eps(self, eps: float) -> None:
|
def set_eps(self, eps: float) -> None:
|
||||||
"""Set the eps for epsilon-greedy exploration."""
|
"""Set the eps for epsilon-greedy exploration."""
|
||||||
@ -79,8 +82,10 @@ class DQNPolicy(BasePolicy):
|
|||||||
target_q = self(batch, model="model_old", input="obs_next").logits
|
target_q = self(batch, model="model_old", input="obs_next").logits
|
||||||
else:
|
else:
|
||||||
target_q = result.logits
|
target_q = result.logits
|
||||||
target_q = target_q[np.arange(len(result.act)), result.act]
|
if self._is_double:
|
||||||
return target_q
|
return target_q[np.arange(len(result.act)), result.act]
|
||||||
|
else: # Nature DQN, over estimate
|
||||||
|
return target_q.max(dim=1)[0]
|
||||||
|
|
||||||
def process_fn(
|
def process_fn(
|
||||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||||
|
Loading…
x
Reference in New Issue
Block a user