diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index b337e9c..58df054 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -150,9 +150,10 @@ class DQNPolicy(BasePolicy):
         # add eps to act
         if eps is None:
             eps = self.eps
-        for i in range(len(q)):
-            if np.random.rand() < eps:
-                act[i] = np.random.randint(q.shape[1])
+        if not np.isclose(eps, 0):
+            for i in range(len(q)):
+                if np.random.rand() < eps:
+                    act[i] = np.random.randint(q.shape[1])
         return Batch(logits=q, act=act, state=h)
 
     def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: