Bugfix/discrete bcq inf (#995)
Fixes a small bug caused by using np.inf instead of a torch-based infinity substitute (torch.finfo(torch.float32).max).

Closes #963

Co-authored-by: ivan.rodriguez <ivan.rodriguez@unternehmertum.de>
parent: 31fa0325fa
commit: f134bc20b5
@@ -15,6 +15,9 @@ from tianshou.data.types import (
 from tianshou.policy import DQNPolicy
 from tianshou.policy.base import TLearningRateScheduler
 
+float_info = torch.finfo(torch.float32)
+INF = float_info.max
+
 
 class DiscreteBCQPolicy(DQNPolicy):
     """Implementation of discrete BCQ algorithm. arXiv:1910.01708.
@@ -128,7 +131,7 @@ class DiscreteBCQPolicy(DQNPolicy):
         # mask actions for argmax
         ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
         mask = (ratio < self._log_tau).float()
-        act = (q_value - np.inf * mask).argmax(dim=-1)
+        act = (q_value - INF * mask).argmax(dim=-1)
 
         result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits)
         return cast(ImitationBatchProtocol, result)
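A minimal sketch of why the change matters, using toy tensors rather than code from this commit: IEEE-754 arithmetic gives inf * 0 = nan, so `np.inf * mask` turns every allowed action's entry (where mask == 0) into NaN and makes the subsequent argmax unreliable, while the finite `torch.finfo(torch.float32).max` leaves allowed entries untouched (max * 0 = 0) and still pushes masked entries to the bottom.

```python
import numpy as np
import torch

q_value = torch.tensor([1.0, 3.0, 2.0])
mask = torch.tensor([1.0, 0.0, 1.0])  # 1.0 = action masked out, 0.0 = action allowed

# Old behaviour: IEEE-754 inf * 0 = nan poisons the allowed entry.
print(np.inf * mask)            # tensor([inf, nan, inf])
print(q_value - np.inf * mask)  # tensor([-inf, nan, -inf]) -> argmax over NaN is unreliable

# New behaviour: a large finite constant keeps allowed entries intact
# (max * 0 = 0) while masked entries still become hugely negative.
INF = torch.finfo(torch.float32).max
print(q_value - INF * mask)                   # tensor([-3.4028e+38, 3.0000e+00, -3.4028e+38])
print((q_value - INF * mask).argmax(dim=-1))  # tensor(1): the allowed action wins
```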