2023-10-18 20:44:17 +02:00

45 lines
1.6 KiB
Python

from abc import ABC, abstractmethod
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract factory producing the distribution function for a set of environments.

    Concrete subclasses implement :meth:`create_dist_fn`.
    """

    @abstractmethod
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Create the distribution function to use for the given environments.

        :param envs: the environments for which the distribution function is needed
        :return: the distribution function
        """
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory yielding a function that builds a categorical distribution from logits."""

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
        """Wrap the given tensor, interpreted as logits, in a ``Categorical`` distribution.

        :param p: tensor passed to ``Categorical`` as its ``logits`` argument
        :return: the resulting categorical distribution
        """
        categorical = torch.distributions.Categorical(logits=p)
        return categorical

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the categorical distribution function after checking the environment type.

        :param envs: the environments; their type is checked via ``assert_discrete``
            (presumably raising if the environments are not discrete -- verify against
            ``EnvType``)
        :return: the function mapping logits to a categorical distribution
        """
        env_type = envs.get_type()
        env_type.assert_discrete(self)
        return self._dist_fn
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory yielding a function that builds independent normal distributions."""

    @staticmethod
    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
        """Build a ``Normal`` from the given tensors and wrap it in ``Independent``.

        :param p: the tensors forwarded positionally to ``torch.distributions.Normal``
        :return: an ``Independent`` distribution over the normal distribution, with one
            batch dimension reinterpreted as an event dimension
        """
        gaussian = torch.distributions.Normal(*p)
        return torch.distributions.Independent(gaussian, reinterpreted_batch_ndims=1)

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the Gaussian distribution function after checking the environment type.

        :param envs: the environments; their type is checked via ``assert_continuous``
            (presumably raising if the environments are not continuous -- verify against
            ``EnvType``)
        :return: the function mapping the given tensors to independent normal distributions
        """
        env_type = envs.get_type()
        env_type.assert_continuous(self)
        return self._dist_fn
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
    """Factory that picks the distribution function based on the environment type.

    Discrete environments are delegated to the categorical factory, continuous
    environments to the independent-Gaussians factory.
    """

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Create the distribution function matching the type of the given environments.

        :param envs: the environments whose type determines the delegate factory
        :return: the distribution function created by the delegate factory
        :raises ValueError: if the environment type is neither discrete nor continuous
        """
        env_type = envs.get_type()
        if env_type == EnvType.DISCRETE:
            return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
        if env_type == EnvType.CONTINUOUS:
            return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
        # Unknown environment type: report it to the caller.
        raise ValueError(envs.get_type())