45 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			45 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from abc import ABC, abstractmethod
 | |
| 
 | |
| import torch
 | |
| 
 | |
| from tianshou.highlevel.env import Environments, EnvType
 | |
| from tianshou.policy.modelfree.pg import TDistributionFunction
 | |
| from tianshou.utils.string import ToStringMixin
 | |
| 
 | |
| 
 | |
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract factory for the distribution function used with a set of environments.

    Concrete subclasses implement :meth:`create_dist_fn`.
    """

    @abstractmethod
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Create the distribution function for the given environments.

        :param envs: the environments for which the distribution function is created.
        :return: the distribution function (presumably a callable that maps
            network outputs to a ``torch.distributions.Distribution`` -- confirm
            against the definition of ``TDistributionFunction``).
        """
 | |
| 
 | |
| 
 | |
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory whose distribution function builds a categorical distribution from logits."""

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
        """Build a categorical distribution, passing ``p`` as its logits.

        :param p: tensor handed to ``Categorical`` via the ``logits`` keyword.
        :return: the resulting ``torch.distributions.Categorical`` instance.
        """
        categorical_cls = torch.distributions.Categorical
        return categorical_cls(logits=p)

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the categorical distribution function after checking the env type.

        :param envs: the environments; their type is checked via ``assert_discrete``
            (presumably fails for non-discrete environments -- confirm in ``EnvType``).
        :return: the static function that constructs the categorical distribution.
        """
        env_type = envs.get_type()
        env_type.assert_discrete(self)
        return self._dist_fn
 | |
| 
 | |
| 
 | |
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory whose distribution function builds independent normal distributions."""

    @staticmethod
    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
        """Wrap a normal distribution so that one batch dimension becomes an event dimension.

        :param p: positional tensors forwarded unchanged to ``torch.distributions.Normal``
            (presumably location and scale -- confirm against the caller).
        :return: ``Independent(Normal(*p), 1)``.
        """
        base_distribution = torch.distributions.Normal(*p)
        # The second argument (1) is the number of batch dims reinterpreted as event dims.
        return torch.distributions.Independent(base_distribution, 1)

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the Gaussian distribution function after checking the env type.

        :param envs: the environments; their type is checked via ``assert_continuous``
            (presumably fails for non-continuous environments -- confirm in ``EnvType``).
        :return: the static function that constructs the independent normal distribution.
        """
        env_type = envs.get_type()
        env_type.assert_continuous(self)
        return self._dist_fn
 | |
| 
 | |
| 
 | |
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
    """Factory that picks the distribution function according to the environment type."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Delegate to the categorical or Gaussian factory depending on ``envs.get_type()``.

        :param envs: the environments whose type selects the delegate factory.
        :return: the distribution function created by the selected factory.
        :raises ValueError: if the environment type is neither ``EnvType.DISCRETE``
            nor ``EnvType.CONTINUOUS``; the offending type is the exception argument.
        """
        env_type = envs.get_type()
        factory: DistributionFunctionFactory
        if env_type == EnvType.DISCRETE:
            factory = DistributionFunctionFactoryCategorical()
        elif env_type == EnvType.CONTINUOUS:
            factory = DistributionFunctionFactoryIndependentGaussians()
        else:
            raise ValueError(env_type)
        return factory.create_dist_fn(envs)
 |