Bugfix/discrete bcq inf (#995)

Fixes a small bug in DiscreteBCQPolicy's action masking: np.inf was used where a torch-based finite maximum (torch.finfo(torch.float32).max) is needed.

Closes #963
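Not part of the commit, but for context: IEEE 754 arithmetic gives `inf * 0 = nan`, so `q_value - np.inf * mask` turns exactly the unmasked (eligible) entries into NaN, while a large finite constant leaves them untouched. A minimal sketch with made-up tensors:

```python
import numpy as np
import torch

q_value = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([0.0, 1.0, 0.0])  # 1.0 marks actions to exclude

# Old behaviour: inf * 0 = nan, so every *unmasked* entry becomes nan
# and the subsequent argmax is meaningless.
bad = q_value - np.inf * mask
print(bad)  # tensor([nan, -inf, nan])

# Fixed behaviour: a large finite value only suppresses masked entries.
INF = torch.finfo(torch.float32).max
good = q_value - INF * mask
print(good.argmax())  # tensor(2): the best unmasked action
```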

---------

Co-authored-by: ivan.rodriguez <ivan.rodriguez@unternehmertum.de>
Michael Panchenko 2023-11-24 11:17:40 +01:00 committed by GitHub
parent 31fa0325fa
commit f134bc20b5

@@ -15,6 +15,9 @@ from tianshou.data.types import (
 from tianshou.policy import DQNPolicy
 from tianshou.policy.base import TLearningRateScheduler
 
+float_info = torch.finfo(torch.float32)
+INF = float_info.max
+
 class DiscreteBCQPolicy(DQNPolicy):
     """Implementation of discrete BCQ algorithm. arXiv:1910.01708.
 
@@ -128,7 +131,7 @@ class DiscreteBCQPolicy(DQNPolicy):
         # mask actions for argmax
         ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
         mask = (ratio < self._log_tau).float()
-        act = (q_value - np.inf * mask).argmax(dim=-1)
+        act = (q_value - INF * mask).argmax(dim=-1)
         result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits)
         return cast(ImitationBatchProtocol, result)
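
For reference, this masking step implements discrete BCQ's action filter (arXiv:1910.01708): actions whose imitation probability falls below tau times that of the most likely action are pushed to -INF before the argmax over Q-values. A standalone sketch with invented numbers (the logits, Q-values, and tau = 0.3 are all arbitrary):

```python
import torch

INF = torch.finfo(torch.float32).max

# Invented values for a single state with three actions.
imitation_logits = torch.tensor([[0.1, 2.0, 1.9]])
q_value = torch.tensor([[5.0, 1.0, 3.0]])
log_tau = torch.log(torch.tensor(0.3)).item()  # unlikelihood threshold

# ratio = log p(a) - log p(a_best); actions with ratio < log(tau) get masked.
ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
mask = (ratio < log_tau).float()
act = (q_value - INF * mask).argmax(dim=-1)
print(act)  # tensor([2]): action 0 has the highest Q but is masked out
```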