Using dist.mode instead of logits.argmax (#1066)

Changed all occurrences where an action is selected deterministically:

- **from**: taking the raw outputs of the actor network (e.g. `logits.argmax` or the Gaussian mean).
- **to**: using the mode of the PyTorch distribution (`dist.mode`); see the sketch below.
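
For reference only (not part of the commit), a minimal sketch of why this is behavior-preserving for the policies touched here: for a `Categorical` the mode is the argmax over the logits, and for an `Independent(Normal(...), 1)` the mode is the mean `loc`. It assumes a PyTorch version in which the `Distribution.mode` property exists (added around 1.13).

```python
import torch
from torch.distributions import Categorical, Independent, Normal

# Discrete case (DiscreteSACPolicy, PGPolicy): the mode of a Categorical
# built from logits equals the argmax over those logits.
logits = torch.randn(4, 6)
assert torch.equal(Categorical(logits=logits).mode, logits.argmax(-1))

# Continuous case (SACPolicy, REDQPolicy): the mode of a diagonal Gaussian
# is its mean, i.e. the old `act = loc` / `act = logits[0]` behavior.
loc, scale = torch.randn(4, 3), torch.rand(4, 3) + 0.1
assert torch.allclose(Independent(Normal(loc, scale), 1).mode, loc)
```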

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Erni, 2024-03-03 00:09:39 +01:00 (committed by GitHub)
parent 7c970df53f
commit 1aee41fa9c
5 changed files with 8 additions and 20 deletions

View File

@@ -3,7 +3,6 @@ import os
import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.utils.tensorboard import SummaryWriter
@@ -58,7 +57,6 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
# if you want to use python vector env, please refer to other test scripts
# train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)

View File

@@ -109,7 +109,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits)
         if self.deterministic_eval and not self.training:
-            act = logits.argmax(axis=-1)
+            act = dist.mode
         else:
             act = dist.sample()
         return Batch(logits=logits, act=act, state=hidden, dist=dist)

View File

@@ -158,19 +158,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         batch: BatchWithReturnsProtocol
         return batch
 
-    def _get_deterministic_action(self, logits: torch.Tensor) -> torch.Tensor:
-        if self.action_type == "discrete":
-            return logits.argmax(-1)
-        if self.action_type == "continuous":
-            # assume that the mode of the distribution is the first element
-            # of the actor's output (the "logits")
-            return logits[0]
-        raise RuntimeError(
-            f"Unknown action type: {self.action_type}. "
-            f"This should not happen and might be a bug."
-            f"Supported action types are: 'discrete' and 'continuous'.",
-        )
-
     def forward(
         self,
         batch: ObsBatchProtocol,
@@ -198,7 +185,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # in this case, the dist is unused!
         if self.deterministic_eval and not self.training:
-            act = self._get_deterministic_action(logits)
+            act = dist.mode
         else:
             act = dist.sample()
         result = Batch(logits=logits, act=act, state=hidden, dist=dist)

View File

@@ -153,7 +153,10 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
         loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
         loc, scale = loc_scale
         dist = Independent(Normal(loc, scale), 1)
-        act = loc if self.deterministic_eval and not self.training else dist.rsample()
+        if self.deterministic_eval and not self.training:
+            act = dist.mode
+        else:
+            act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
         # apply correction for Tanh squashing when computing logprob from Gaussian
         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.

View File

@@ -56,7 +56,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
         This is useful when solving "hard exploration" problems.
         "default" is equivalent to GaussianNoise(sigma=0.1).
     :param deterministic_eval: whether to use deterministic action
-        (mean of Gaussian policy) in evaluation mode instead of stochastic
+        (mode of Gaussian policy) in evaluation mode instead of stochastic
         action sampled by the policy. Does not affect training.
     :param action_scaling: whether to map actions from range [-1, 1]
         to range[action_spaces.low, action_spaces.high].
@@ -177,7 +177,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self.deterministic_eval and not self.training:
-            act = logits[0]
+            act = dist.mode
         else:
             act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)