Fix type annotations of dist_fn
This commit is contained in:
		
							parent
							
								
									a161a9cf58
								
							
						
					
					
						commit
						22dfc4ed2e
					
				| @ -1,15 +1,11 @@ | ||||
| from abc import ABC, abstractmethod | ||||
| from collections.abc import Callable | ||||
| from typing import TypeAlias | ||||
| 
 | ||||
| import torch | ||||
| 
 | ||||
| from tianshou.highlevel.env import Environments, EnvType | ||||
| from tianshou.policy.modelfree.pg import TDistParams | ||||
| from tianshou.policy.modelfree.pg import TDistributionFunction | ||||
| from tianshou.utils.string import ToStringMixin | ||||
| 
 | ||||
| TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution] | ||||
| 
 | ||||
| 
 | ||||
| class DistributionFunctionFactory(ToStringMixin, ABC): | ||||
|     @abstractmethod | ||||
| @ -23,7 +19,7 @@ class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): | ||||
|         return self._dist_fn | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _dist_fn(p: TDistParams) -> torch.distributions.Distribution: | ||||
|     def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution: | ||||
|         return torch.distributions.Categorical(logits=p) | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -14,8 +14,8 @@ from tianshou.highlevel.params.alpha import AutoAlphaFactory | ||||
| from tianshou.highlevel.params.dist_fn import ( | ||||
|     DistributionFunctionFactory, | ||||
|     DistributionFunctionFactoryDefault, | ||||
|     TDistributionFunction, | ||||
| ) | ||||
| from tianshou.policy.modelfree.pg import TDistributionFunction | ||||
| from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory | ||||
| from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory | ||||
| from tianshou.highlevel.params.noise import NoiseFactory | ||||
|  | ||||
| @ -10,7 +10,7 @@ from tianshou.data import ReplayBuffer, to_numpy, to_torch | ||||
| from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol | ||||
| from tianshou.policy import PPOPolicy | ||||
| from tianshou.policy.base import TLearningRateScheduler | ||||
| from tianshou.policy.modelfree.pg import TDistParams | ||||
| from tianshou.policy.modelfree.pg import TDistributionFunction | ||||
| 
 | ||||
| 
 | ||||
| class GAILPolicy(PPOPolicy): | ||||
| @ -62,7 +62,7 @@ class GAILPolicy(PPOPolicy): | ||||
|         actor: torch.nn.Module, | ||||
|         critic: torch.nn.Module, | ||||
|         optim: torch.optim.Optimizer, | ||||
|         dist_fn: Callable[[TDistParams], torch.distributions.Distribution], | ||||
|         dist_fn: TDistributionFunction, | ||||
|         action_space: gym.Space, | ||||
|         expert_buffer: ReplayBuffer, | ||||
|         disc_net: torch.nn.Module, | ||||
|  | ||||
| @ -11,7 +11,7 @@ from tianshou.data import ReplayBuffer, to_torch_as | ||||
| from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol | ||||
| from tianshou.policy import PGPolicy | ||||
| from tianshou.policy.base import TLearningRateScheduler | ||||
| from tianshou.policy.modelfree.pg import TDistParams | ||||
| from tianshou.policy.modelfree.pg import TDistributionFunction | ||||
| from tianshou.utils.net.common import ActorCritic | ||||
| 
 | ||||
| 
 | ||||
| @ -50,7 +50,7 @@ class A2CPolicy(PGPolicy): | ||||
|         actor: torch.nn.Module, | ||||
|         critic: torch.nn.Module, | ||||
|         optim: torch.optim.Optimizer, | ||||
|         dist_fn: Callable[[TDistParams], torch.distributions.Distribution], | ||||
|         dist_fn: TDistributionFunction, | ||||
|         action_space: gym.Space, | ||||
|         vf_coef: float = 0.5, | ||||
|         ent_coef: float = 0.01, | ||||
|  | ||||
| @ -1,4 +1,3 @@ | ||||
| from collections.abc import Callable | ||||
| from typing import Any, Literal | ||||
| 
 | ||||
| import gymnasium as gym | ||||
| @ -12,7 +11,7 @@ from tianshou.data import Batch, ReplayBuffer | ||||
| from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol | ||||
| from tianshou.policy import A2CPolicy | ||||
| from tianshou.policy.base import TLearningRateScheduler | ||||
| from tianshou.policy.modelfree.pg import TDistParams | ||||
| from tianshou.policy.modelfree.pg import TDistributionFunction | ||||
| 
 | ||||
| 
 | ||||
| class NPGPolicy(A2CPolicy): | ||||
| @ -47,7 +46,7 @@ class NPGPolicy(A2CPolicy): | ||||
|         actor: torch.nn.Module, | ||||
|         critic: torch.nn.Module, | ||||
|         optim: torch.optim.Optimizer, | ||||
|         dist_fn: Callable[[TDistParams], torch.distributions.Distribution], | ||||
|         dist_fn: TDistributionFunction, | ||||
|         action_space: gym.Space, | ||||
|         optim_critic_iters: int = 5, | ||||
|         actor_step_size: float = 0.5, | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| import warnings | ||||
| from collections.abc import Callable | ||||
| from typing import Any, Literal, TypeAlias, cast | ||||
| from typing import Any, Literal, cast, TypeAlias, Callable | ||||
| 
 | ||||
| import gymnasium as gym | ||||
| import numpy as np | ||||
| @ -17,7 +16,7 @@ from tianshou.policy import BasePolicy | ||||
| from tianshou.policy.base import TLearningRateScheduler | ||||
| from tianshou.utils import RunningMeanStd | ||||
| 
 | ||||
| TDistParams: TypeAlias = torch.Tensor | [torch.Tensor, torch.Tensor] | ||||
| TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution] | ||||
| 
 | ||||
| 
 | ||||
| class PGPolicy(BasePolicy): | ||||
| @ -56,7 +55,7 @@ class PGPolicy(BasePolicy): | ||||
|         *, | ||||
|         actor: torch.nn.Module, | ||||
|         optim: torch.optim.Optimizer, | ||||
|         dist_fn: Callable[[TDistParams], torch.distributions.Distribution], | ||||
|         dist_fn: TDistributionFunction, | ||||
|         action_space: gym.Space, | ||||
|         discount_factor: float = 0.99, | ||||
|         # TODO: rename to return_normalization? | ||||
|  | ||||
| @ -1,4 +1,3 @@ | ||||
| from collections.abc import Callable | ||||
| from typing import Any, Literal | ||||
| 
 | ||||
| import gymnasium as gym | ||||
| @ -10,7 +9,7 @@ from tianshou.data import ReplayBuffer, to_torch_as | ||||
| from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol | ||||
| from tianshou.policy import A2CPolicy | ||||
| from tianshou.policy.base import TLearningRateScheduler | ||||
| from tianshou.policy.modelfree.pg import TDistParams | ||||
| from tianshou.policy.modelfree.pg import TDistributionFunction | ||||
| from tianshou.utils.net.common import ActorCritic | ||||
| 
 | ||||
| 
 | ||||
| @ -58,7 +57,7 @@ class PPOPolicy(A2CPolicy): | ||||
|         actor: torch.nn.Module, | ||||
|         critic: torch.nn.Module, | ||||
|         optim: torch.optim.Optimizer, | ||||
|         dist_fn: Callable[[TDistParams], torch.distributions.Distribution], | ||||
|         dist_fn: TDistributionFunction, | ||||
|         action_space: gym.Space, | ||||
|         eps_clip: float = 0.2, | ||||
|         dual_clip: float | None = None, | ||||
|  | ||||
| @ -1,5 +1,4 @@ | ||||
| import warnings | ||||
| from collections.abc import Callable | ||||
| from typing import Any, Literal | ||||
| 
 | ||||
| import gymnasium as gym | ||||
| @ -10,7 +9,7 @@ from torch.distributions import kl_divergence | ||||
| from tianshou.data import Batch | ||||
| from tianshou.policy import NPGPolicy | ||||
| from tianshou.policy.base import TLearningRateScheduler | ||||
| from tianshou.policy.modelfree.pg import TDistParams | ||||
| from tianshou.policy.modelfree.pg import TDistributionFunction | ||||
| 
 | ||||
| 
 | ||||
| class TRPOPolicy(NPGPolicy): | ||||
| @ -47,7 +46,7 @@ class TRPOPolicy(NPGPolicy): | ||||
|         actor: torch.nn.Module, | ||||
|         critic: torch.nn.Module, | ||||
|         optim: torch.optim.Optimizer, | ||||
|         dist_fn: Callable[[TDistParams], torch.distributions.Distribution], | ||||
|         dist_fn: TDistributionFunction, | ||||
|         action_space: gym.Space, | ||||
|         max_kl: float = 0.01, | ||||
|         backtrack_coeff: float = 0.8, | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user