45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
from abc import ABC, abstractmethod
|
|
|
|
import torch
|
|
|
|
from tianshou.highlevel.env import Environments, EnvType
|
|
from tianshou.policy.modelfree.pg import TDistributionFunction
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
|
|
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract factory for distribution functions.

    Concrete subclasses implement :meth:`create_dist_fn`, which returns a
    ``TDistributionFunction`` for the given environments.
    """

    @abstractmethod
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Create the distribution function for the given environments.

        :param envs: the environments the distribution function is created for
        :return: the distribution function
        """
|
|
|
|
|
|
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory whose distribution function builds a categorical distribution."""

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
        """Build a categorical distribution, passing ``p`` as its ``logits``.

        :param p: tensor handed to ``torch.distributions.Categorical`` as ``logits``
        :return: the resulting categorical distribution
        """
        distribution = torch.distributions.Categorical(logits=p)
        return distribution

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the categorical distribution function.

        :param envs: the environments; their type is checked via ``assert_discrete``
        :return: the function that builds the categorical distribution
        """
        env_type = envs.get_type()
        # NOTE(review): presumably fails for non-discrete environment types;
        # confirm against the `assert_discrete` implementation.
        env_type.assert_discrete(self)
        return self._dist_fn
|
|
|
|
|
|
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory whose distribution function builds independent normal distributions."""

    @staticmethod
    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
        """Build a normal distribution wrapped in ``Independent``.

        :param p: tensors forwarded positionally to ``torch.distributions.Normal``
            (presumably location and scale, in that order; verify against callers)
        :return: the normal distribution wrapped in ``Independent`` with one
            reinterpreted batch dimension
        """
        base_normal = torch.distributions.Normal(*p)
        return torch.distributions.Independent(
            base_normal,
            reinterpreted_batch_ndims=1,
        )

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the independent-Gaussians distribution function.

        :param envs: the environments; their type is checked via ``assert_continuous``
        :return: the function that builds the independent normal distribution
        """
        env_type = envs.get_type()
        # NOTE(review): presumably fails for non-continuous environment types;
        # confirm against the `assert_continuous` implementation.
        env_type.assert_continuous(self)
        return self._dist_fn
|
|
|
|
|
|
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
    """Factory that picks the distribution function from the environment type."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Delegate to the factory that matches the environment type.

        Discrete environments use the categorical factory and continuous
        environments use the independent-Gaussians factory.

        :param envs: the environments whose type selects the factory
        :return: the distribution function created by the selected factory
        :raises ValueError: if the environment type is neither discrete nor
            continuous; the environment type is the exception's argument
        """
        env_type = envs.get_type()
        factory: DistributionFunctionFactory
        if env_type == EnvType.DISCRETE:
            factory = DistributionFunctionFactoryCategorical()
        elif env_type == EnvType.CONTINUOUS:
            factory = DistributionFunctionFactoryIndependentGaussians()
        else:
            raise ValueError(env_type)
        return factory.create_dist_fn(envs)
|