Using dist.mode instead of logits.argmax (#1066)
Changed all occurrences where an action is selected deterministically:
- **from**: using the raw outputs of the actor network
- **to**: using the mode of the PyTorch distribution

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
parent 7c970df53f
commit 1aee41fa9c
@@ -3,7 +3,6 @@ import os
 
 import gymnasium as gym
 import numpy as np
-import pytest
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
@@ -58,7 +57,6 @@ def get_args() -> argparse.Namespace:
     return parser.parse_known_args()[0]
 
 
-@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
 def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     # if you want to use python vector env, please refer to other test scripts
     # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)
@@ -109,7 +109,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
         logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits)
         if self.deterministic_eval and not self.training:
-            act = logits.argmax(axis=-1)
+            act = dist.mode
         else:
             act = dist.sample()
         return Batch(logits=logits, act=act, state=hidden, dist=dist)
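For the discrete case this is behavior-preserving: in recent PyTorch releases, Categorical exposes a mode property that returns the index of the most probable class, and since softmax is monotonic this coincides with an argmax over the raw logits. A minimal sketch of the equivalence (the example logits are made up for illustration):

    import torch
    from torch.distributions import Categorical

    logits = torch.tensor([[0.1, 2.0, -0.5], [1.5, 0.3, 0.9]])  # illustrative logits
    dist = Categorical(logits=logits)
    # dist.mode returns the per-row index of the highest probability,
    # which is the same index the old logits.argmax(axis=-1) produced.
    assert torch.equal(dist.mode, logits.argmax(dim=-1))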
@@ -158,19 +158,6 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         batch: BatchWithReturnsProtocol
         return batch
 
-    def _get_deterministic_action(self, logits: torch.Tensor) -> torch.Tensor:
-        if self.action_type == "discrete":
-            return logits.argmax(-1)
-        if self.action_type == "continuous":
-            # assume that the mode of the distribution is the first element
-            # of the actor's output (the "logits")
-            return logits[0]
-        raise RuntimeError(
-            f"Unknown action type: {self.action_type}. "
-            f"This should not happen and might be a bug."
-            f"Supported action types are: 'discrete' and 'continuous'.",
-        )
-
     def forward(
         self,
         batch: ObsBatchProtocol,
@@ -198,7 +185,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
 
         # in this case, the dist is unused!
         if self.deterministic_eval and not self.training:
-            act = self._get_deterministic_action(logits)
+            act = dist.mode
         else:
             act = dist.sample()
         result = Batch(logits=logits, act=act, state=hidden, dist=dist)
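With this change the per-action-type dispatch in the removed _get_deterministic_action helper is no longer needed: whatever distribution the policy's dist_fn builds is expected to provide its own mode. A rough sketch of the idea, using the two distribution families that appear in this diff (shapes and values are hypothetical):

    import torch
    from torch.distributions import Categorical, Independent, Normal

    # Discrete and continuous policies now share the same deterministic code path:
    dists = [
        Categorical(logits=torch.randn(4, 3)),                        # discrete head
        Independent(Normal(torch.zeros(4, 2), torch.ones(4, 2)), 1),  # continuous head
    ]
    for dist in dists:
        act = dist.mode  # no branching on action_type required
        print(act.shape)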
@@ -153,7 +153,10 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
         loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
         loc, scale = loc_scale
         dist = Independent(Normal(loc, scale), 1)
-        act = loc if self.deterministic_eval and not self.training else dist.rsample()
+        if self.deterministic_eval and not self.training:
+            act = dist.mode
+        else:
+            act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
         # apply correction for Tanh squashing when computing logprob from Gaussian
         # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
@@ -56,7 +56,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         This is useful when solving "hard exploration" problems.
         "default" is equivalent to GaussianNoise(sigma=0.1).
     :param deterministic_eval: whether to use deterministic action
-        (mean of Gaussian policy) in evaluation mode instead of stochastic
+        (mode of Gaussian policy) in evaluation mode instead of stochastic
         action sampled by the policy. Does not affect training.
     :param action_scaling: whether to map actions from range [-1, 1]
         to range[action_spaces.low, action_spaces.high].
@@ -177,7 +177,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self.deterministic_eval and not self.training:
-            act = logits[0]
+            act = dist.mode
         else:
             act = dist.rsample()
         log_prob = dist.log_prob(act).unsqueeze(-1)
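For the Gaussian policies (REDQPolicy and SACPolicy above, and the continuous branch of the removed PGPolicy helper), the mode of Independent(Normal(loc, scale), 1) is simply loc, so act = dist.mode reproduces the previous act = loc / act = logits[0]; the tanh squashing and log-prob correction that follow are unchanged. A small sketch of that equivalence (tensor shapes and values are illustrative):

    import torch
    from torch.distributions import Independent, Normal

    loc, scale = torch.randn(4, 2), torch.rand(4, 2) + 0.1  # illustrative actor outputs
    dist = Independent(Normal(loc, scale), 1)
    # The mode of a diagonal Gaussian is its mean, so this matches the
    # old deterministic branches that returned loc / logits[0] directly.
    assert torch.equal(dist.mode, loc)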