Bugfix/discrete bcq inf (#995)
Fixes a small bug caused by using np.inf instead of a torch-based maximum float value.

Closes #963

Co-authored-by: ivan.rodriguez <ivan.rodriguez@unternehmertum.de>
parent 31fa0325fa
commit f134bc20b5
```diff
@@ -15,6 +15,9 @@ from tianshou.data.types import (
 from tianshou.policy import DQNPolicy
 from tianshou.policy.base import TLearningRateScheduler
 
+float_info = torch.finfo(torch.float32)
+INF = float_info.max
+
 
 class DiscreteBCQPolicy(DQNPolicy):
     """Implementation of discrete BCQ algorithm. arXiv:1910.01708.
```
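For reference, a quick sketch (plain PyTorch, not part of the commit) of what the new module-level constant evaluates to: `torch.finfo(torch.float32).max` is the largest finite float32 value rather than a true infinity, so multiplying it by 0.0 still yields 0.0.

```python
# Illustration only: what the new INF constant in discrete_bcq.py evaluates to.
import torch

float_info = torch.finfo(torch.float32)
INF = float_info.max

print(INF)                             # 3.4028234663852886e+38
print(torch.isinf(torch.tensor(INF)))  # tensor(False) -> finite
print(torch.tensor(INF) * 0.0)         # tensor(0.) -> no NaN, unlike inf * 0.0
```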
```diff
@@ -128,7 +131,7 @@ class DiscreteBCQPolicy(DQNPolicy):
         # mask actions for argmax
         ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
         mask = (ratio < self._log_tau).float()
-        act = (q_value - np.inf * mask).argmax(dim=-1)
+        act = (q_value - INF * mask).argmax(dim=-1)
 
         result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits)
         return cast(ImitationBatchProtocol, result)
```
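Why replacing np.inf matters here (a hedged reading of the fix; issue #963 is not reproduced above): in IEEE floating-point arithmetic, inf * 0.0 is NaN, so `q_value - np.inf * mask` turns the entries of the actions that should stay available (mask == 0) into NaN and makes the subsequent argmax unreliable. With the finite INF constant, unmasked entries are left unchanged and masked entries are pushed far below any realistic Q-value. A minimal sketch with made-up q_value and mask tensors:

```python
# Minimal sketch (hypothetical tensors) of the masked-argmax behaviour before and after the fix.
import numpy as np
import torch

INF = torch.finfo(torch.float32).max

q_value = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[0.0, 1.0, 0.0]])  # action 1 is masked out; actions 0 and 2 stay available

buggy = q_value - np.inf * mask  # inf * 0.0 == nan -> tensor([[nan, -inf, nan]])
fixed = q_value - INF * mask     # tensor([[1.0, ~-3.4e38, 3.0]])

print(buggy)                 # NaNs leak into the unmasked actions
print(fixed.argmax(dim=-1))  # tensor([2]) -> the best unmasked action, as intended
```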