# Patch for tianshou/policy/imitation/discrete_bcq.py (class DiscreteBCQPolicy).
# What it changes:
#   * adds the module-level names `float_info = torch.finfo(torch.float32)` and
#     `INF = float_info.max` (the largest finite float32 value);
#   * in the action-masking code, replaces `np.inf * mask` with `INF * mask`
#     before the argmax over `q_value`.
# Why: `mask` is built as `(ratio < self._log_tau).float()`, i.e. a 0/1 float
# tensor. Under IEEE 754, `inf * 0` is NaN, so the old expression
# `q_value - np.inf * mask` produced NaN for every entry where mask == 0
# (the actions that are supposed to stay eligible). With a finite constant,
# `INF * 0` is 0, so eligible Q-values pass through unchanged, while entries
# with mask == 1 have the float32 maximum subtracted from them.
# NOTE(review): this presumably relies on `q_value` holding ordinary finite
# values so that masked entries always rank below unmasked ones - confirm dtype.
# NOTE(review): the hunk headers describe multi-line hunks; the single-line
# layout below looks like an extraction artifact and the original line breaks
# cannot be recovered from this view.
diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 0ab54c1..d32d4ef 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -15,6 +15,9 @@ from tianshou.data.types import ( from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +float_info = torch.finfo(torch.float32) +INF = float_info.max + class DiscreteBCQPolicy(DQNPolicy): """Implementation of discrete BCQ algorithm. arXiv:1910.01708. @@ -128,7 +131,7 @@ class DiscreteBCQPolicy(DQNPolicy): # mask actions for argmax ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values mask = (ratio < self._log_tau).float() - act = (q_value - np.inf * mask).argmax(dim=-1) + act = (q_value - INF * mask).argmax(dim=-1) result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) return cast(ImitationBatchProtocol, result)