Using dist.mode instead of logits.argmax (#1066)
changed all the occurrences where an action is selected deterministically - **from**: using the outputs of the actor network. - **to**: using the mode of the PyTorch distribution. --------- Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
This commit is contained in:
		
							parent
							
								
									7c970df53f
								
							
						
					
					
						commit
						1aee41fa9c
					
				| @ -3,7 +3,6 @@ import os | ||||
| 
 | ||||
| import gymnasium as gym | ||||
| import numpy as np | ||||
| import pytest | ||||
| import torch | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| 
 | ||||
| @ -58,7 +57,6 @@ def get_args() -> argparse.Namespace: | ||||
|     return parser.parse_known_args()[0] | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") | ||||
| def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: | ||||
|     # if you want to use python vector env, please refer to other test scripts | ||||
|     # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) | ||||
|  | ||||
| @ -109,7 +109,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): | ||||
|         logits, hidden = self.actor(batch.obs, state=state, info=batch.info) | ||||
|         dist = Categorical(logits=logits) | ||||
|         if self.deterministic_eval and not self.training: | ||||
|             act = logits.argmax(axis=-1) | ||||
|             act = dist.mode | ||||
|         else: | ||||
|             act = dist.sample() | ||||
|         return Batch(logits=logits, act=act, state=hidden, dist=dist) | ||||
|  | ||||
| @ -158,19 +158,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): | ||||
|         batch: BatchWithReturnsProtocol | ||||
|         return batch | ||||
| 
 | ||||
|     def _get_deterministic_action(self, logits: torch.Tensor) -> torch.Tensor: | ||||
|         if self.action_type == "discrete": | ||||
|             return logits.argmax(-1) | ||||
|         if self.action_type == "continuous": | ||||
|             # assume that the mode of the distribution is the first element | ||||
|             # of the actor's output (the "logits") | ||||
|             return logits[0] | ||||
|         raise RuntimeError( | ||||
|             f"Unknown action type: {self.action_type}. " | ||||
|             f"This should not happen and might be a bug." | ||||
|             f"Supported action types are: 'discrete' and 'continuous'.", | ||||
|         ) | ||||
| 
 | ||||
|     def forward( | ||||
|         self, | ||||
|         batch: ObsBatchProtocol, | ||||
| @ -198,7 +185,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): | ||||
| 
 | ||||
|         # in this case, the dist is unused! | ||||
|         if self.deterministic_eval and not self.training: | ||||
|             act = self._get_deterministic_action(logits) | ||||
|             act = dist.mode | ||||
|         else: | ||||
|             act = dist.sample() | ||||
|         result = Batch(logits=logits, act=act, state=hidden, dist=dist) | ||||
|  | ||||
| @ -153,7 +153,10 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): | ||||
|         loc_scale, h = self.actor(batch.obs, state=state, info=batch.info) | ||||
|         loc, scale = loc_scale | ||||
|         dist = Independent(Normal(loc, scale), 1) | ||||
|         act = loc if self.deterministic_eval and not self.training else dist.rsample() | ||||
|         if self.deterministic_eval and not self.training: | ||||
|             act = dist.mode | ||||
|         else: | ||||
|             act = dist.rsample() | ||||
|         log_prob = dist.log_prob(act).unsqueeze(-1) | ||||
|         # apply correction for Tanh squashing when computing logprob from Gaussian | ||||
|         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. | ||||
|  | ||||
| @ -56,7 +56,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t | ||||
|         This is useful when solving "hard exploration" problems. | ||||
|         "default" is equivalent to GaussianNoise(sigma=0.1). | ||||
|     :param deterministic_eval: whether to use deterministic action | ||||
|         (mean of Gaussian policy) in evaluation mode instead of stochastic | ||||
|         (mode of Gaussian policy) in evaluation mode instead of stochastic | ||||
|         action sampled by the policy. Does not affect training. | ||||
|     :param action_scaling: whether to map actions from range [-1, 1] | ||||
|         to range[action_spaces.low, action_spaces.high]. | ||||
| @ -177,7 +177,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t | ||||
|         assert isinstance(logits, tuple) | ||||
|         dist = Independent(Normal(*logits), 1) | ||||
|         if self.deterministic_eval and not self.training: | ||||
|             act = logits[0] | ||||
|             act = dist.mode | ||||
|         else: | ||||
|             act = dist.rsample() | ||||
|         log_prob = dist.log_prob(act).unsqueeze(-1) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user