Using dist.mode instead of logits.argmax (#1066)
Changed all occurrences where an action is selected deterministically:

- **from**: using the outputs of the actor network directly
- **to**: using the mode of the corresponding PyTorch distribution

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
parent 7c970df53f
commit 1aee41fa9c
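Below, for orientation, is a minimal sketch (not part of the diff) of the selection pattern the changed policies converge on; the `select_action` helper and the plain `deterministic` flag are illustrative assumptions, not Tianshou API:

```python
import torch
from torch.distributions import Categorical, Distribution

def select_action(dist: Distribution, deterministic: bool) -> torch.Tensor:
    # Deterministic evaluation now asks the distribution for its mode
    # instead of post-processing the raw actor output (e.g. argmax on logits).
    return dist.mode if deterministic else dist.sample()

# Toy usage with a discrete actor head:
dist = Categorical(logits=torch.randn(4, 3))
print(select_action(dist, deterministic=True))   # per-row highest-probability action
print(select_action(dist, deterministic=False))  # stochastic sample
```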
@@ -3,7 +3,6 @@ import os
 
 import gymnasium as gym
 import numpy as np
-import pytest
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
@@ -58,7 +57,6 @@ def get_args() -> argparse.Namespace:
     return parser.parse_known_args()[0]
 
 
-@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
 def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     # if you want to use python vector env, please refer to other test scripts
     # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)
@@ -109,7 +109,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits)
         if self.deterministic_eval and not self.training:
-            act = logits.argmax(axis=-1)
+            act = dist.mode
         else:
             act = dist.sample()
         return Batch(logits=logits, act=act, state=hidden, dist=dist)
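For this discrete head the behaviour is unchanged: `Categorical.mode` picks the highest-probability index, and since softmax is monotonic that matches the previous `logits.argmax(axis=-1)`. A quick check, outside the diff, with an assumed batch of logits:

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(8, 5)
dist = Categorical(logits=logits)
# softmax preserves ordering, so argmax over probs == argmax over logits
assert torch.equal(dist.mode, logits.argmax(dim=-1))
```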
@@ -158,19 +158,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         batch: BatchWithReturnsProtocol
         return batch
 
-    def _get_deterministic_action(self, logits: torch.Tensor) -> torch.Tensor:
-        if self.action_type == "discrete":
-            return logits.argmax(-1)
-        if self.action_type == "continuous":
-            # assume that the mode of the distribution is the first element
-            # of the actor's output (the "logits")
-            return logits[0]
-        raise RuntimeError(
-            f"Unknown action type: {self.action_type}. "
-            f"This should not happen and might be a bug."
-            f"Supported action types are: 'discrete' and 'continuous'.",
-        )
-
     def forward(
         self,
         batch: ObsBatchProtocol,
@@ -198,7 +185,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
 
         # in this case, the dist is unused!
         if self.deterministic_eval and not self.training:
-            act = self._get_deterministic_action(logits)
+            act = dist.mode
         else:
             act = dist.sample()
         result = Batch(logits=logits, act=act, state=hidden, dist=dist)
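The removed `_get_deterministic_action` helper dispatched on `action_type`: argmax over logits for discrete actions, and the first actor output (the Gaussian mean) for continuous ones. `dist.mode` yields the same result in both branches, so the dispatch can live inside the distribution. A hedged sketch of the continuous case (`loc`/`scale` are stand-in actor outputs; the discrete case is checked after the DiscreteSACPolicy hunk above):

```python
import torch
from torch.distributions import Independent, Normal

loc, scale = torch.randn(4, 2), torch.rand(4, 2) + 0.1
dist = Independent(Normal(loc, scale), 1)
# The mode of a diagonal Gaussian is its mean, i.e. exactly what the old
# helper returned as logits[0] for continuous action spaces.
assert torch.equal(dist.mode, loc)
```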
@@ -153,7 +153,10 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
         loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
         loc, scale = loc_scale
         dist = Independent(Normal(loc, scale), 1)
-        act = loc if self.deterministic_eval and not self.training else dist.rsample()
+        if self.deterministic_eval and not self.training:
+            act = dist.mode
+        else:
+            act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
         # apply correction for Tanh squashing when computing logprob from Gaussian
         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
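Only the deterministic branch changes here: `dist.mode` of this Gaussian equals `loc`, so evaluation behaviour stays the same, while the stochastic branch keeps `rsample()` so gradients still flow through the reparameterization trick. A small illustration under assumed shapes (not Tianshou code):

```python
import torch
from torch.distributions import Independent, Normal

loc = torch.zeros(4, 2, requires_grad=True)
dist = Independent(Normal(loc, torch.full((4, 2), 0.5)), 1)

act = dist.rsample()   # reparameterized sample: loc + scale * eps
act.sum().backward()   # gradient reaches the actor output `loc`
assert loc.grad is not None
```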
@@ -56,7 +56,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         This is useful when solving "hard exploration" problems.
         "default" is equivalent to GaussianNoise(sigma=0.1).
     :param deterministic_eval: whether to use deterministic action
-        (mean of Gaussian policy) in evaluation mode instead of stochastic
+        (mode of Gaussian policy) in evaluation mode instead of stochastic
         action sampled by the policy. Does not affect training.
     :param action_scaling: whether to map actions from range [-1, 1]
         to range[action_spaces.low, action_spaces.high].
@@ -177,7 +177,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self.deterministic_eval and not self.training:
-            act = logits[0]
+            act = dist.mode
         else:
             act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)