Using dist.mode instead of logits.argmax (#1066)
Changed all occurrences where an action is selected deterministically:

- **from**: using the outputs of the actor network directly (e.g., `logits.argmax`)
- **to**: using the mode of the PyTorch distribution (`dist.mode`)

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
This commit is contained in:
parent 7c970df53f
commit 1aee41fa9c
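For context (not part of the commit): for the distributions used in these policies, `Distribution.mode` coincides with what the replaced expressions computed. For a `Categorical` built from logits the mode is the argmax over the logits, and for an `Independent(Normal(loc, scale), 1)` it is `loc`. A minimal standalone sketch of this equivalence, assuming a PyTorch version recent enough to provide the `mode` property on distributions:

import torch
from torch.distributions import Categorical, Independent, Normal

# Discrete case: Categorical.mode equals the argmax over the logits.
logits = torch.randn(4, 6)
cat = Categorical(logits=logits)
assert torch.equal(cat.mode, logits.argmax(dim=-1))

# Continuous case: the mode of a diagonal Gaussian is its location (mean).
loc, scale = torch.randn(4, 3), torch.rand(4, 3) + 0.1
gauss = Independent(Normal(loc, scale), 1)
assert torch.equal(gauss.mode, loc)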
@@ -3,7 +3,6 @@ import os
 import gymnasium as gym
 import numpy as np
 import pytest
 import torch
 from torch.utils.tensorboard import SummaryWriter

@@ -58,7 +57,6 @@ def get_args() -> argparse.Namespace:
     return parser.parse_known_args()[0]
 
 
 @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
 def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     # if you want to use python vector env, please refer to other test scripts
     # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)

@@ -109,7 +109,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits)
         if self.deterministic_eval and not self.training:
-            act = logits.argmax(axis=-1)
+            act = dist.mode
         else:
             act = dist.sample()
         return Batch(logits=logits, act=act, state=hidden, dist=dist)

@@ -158,19 +158,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         batch: BatchWithReturnsProtocol
         return batch
 
-    def _get_deterministic_action(self, logits: torch.Tensor) -> torch.Tensor:
-        if self.action_type == "discrete":
-            return logits.argmax(-1)
-        if self.action_type == "continuous":
-            # assume that the mode of the distribution is the first element
-            # of the actor's output (the "logits")
-            return logits[0]
-        raise RuntimeError(
-            f"Unknown action type: {self.action_type}. "
-            f"This should not happen and might be a bug."
-            f"Supported action types are: 'discrete' and 'continuous'.",
-        )
-
     def forward(
         self,
         batch: ObsBatchProtocol,

@@ -198,7 +185,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
 
         # in this case, the dist is unused!
         if self.deterministic_eval and not self.training:
-            act = self._get_deterministic_action(logits)
+            act = dist.mode
         else:
             act = dist.sample()
         result = Batch(logits=logits, act=act, state=hidden, dist=dist)

@@ -153,7 +153,10 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
         loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
         loc, scale = loc_scale
         dist = Independent(Normal(loc, scale), 1)
-        act = loc if self.deterministic_eval and not self.training else dist.rsample()
+        if self.deterministic_eval and not self.training:
+            act = dist.mode
+        else:
+            act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
         # apply correction for Tanh squashing when computing logprob from Gaussian
         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.

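The trailing comment lines above refer to the change-of-variables correction for a Tanh-squashed Gaussian (SAC paper, Eq. 21): log pi(a|s) = log mu(u|s) - sum_i log(1 - tanh^2(u_i)). A minimal sketch of how such a correction is typically applied; the helper name and the eps clamp are assumptions for illustration, not code from this commit:

import torch

def tanh_squash_with_logprob(u: torch.Tensor, log_prob_u: torch.Tensor, eps: float = 1e-6):
    # u: pre-squash Gaussian sample (or its mode); log_prob_u: log prob of u with shape (..., 1).
    act = torch.tanh(u)
    # Subtract sum_i log(1 - tanh(u_i)^2); eps keeps the log argument positive when tanh saturates.
    log_prob = log_prob_u - torch.log(1.0 - act.pow(2) + eps).sum(-1, keepdim=True)
    return act, log_prob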
@@ -56,7 +56,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
         This is useful when solving "hard exploration" problems.
         "default" is equivalent to GaussianNoise(sigma=0.1).
     :param deterministic_eval: whether to use deterministic action
-        (mean of Gaussian policy) in evaluation mode instead of stochastic
+        (mode of Gaussian policy) in evaluation mode instead of stochastic
         action sampled by the policy. Does not affect training.
     :param action_scaling: whether to map actions from range [-1, 1]
         to range[action_spaces.low, action_spaces.high].

@@ -177,7 +177,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self.deterministic_eval and not self.training:
-            act = logits[0]
+            act = dist.mode
         else:
             act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
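As the condition `if self.deterministic_eval and not self.training` in every touched policy shows, the deterministic branch only fires in evaluation mode, so training rollouts keep sampling stochastically ("Does not affect training."). A hypothetical usage sketch; the `policy` name is illustrative, not taken from this commit:

policy.train()  # dist.sample() / dist.rsample(): stochastic actions for exploration
policy.eval()   # with deterministic_eval=True: act = dist.mode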