52 lines
1.8 KiB
Python
Raw Normal View History

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract factory for "distribution functions".

    A distribution function, as declared by the return annotation of
    :meth:`create_dist_fn`, is a callable that accepts a single argument and
    yields a ``torch.distributions.Distribution``.
    """

    @abstractmethod
    def create_dist_fn(
        self,
        envs: Environments,
    ) -> Callable[[Any], torch.distributions.Distribution]:
        """Build the distribution function for the given environments.

        The annotation here is deliberately generic; concrete subclasses
        declare the true (more specific) return type.

        :param envs: the environments the distribution function is created for
        :return: a callable producing a ``torch.distributions.Distribution``
        """
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory whose distribution function yields a ``Categorical`` distribution."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
        """Return the categorical distribution function.

        Before returning, ``assert_discrete`` is invoked on the environment
        type with this factory as the requesting object (presumably rejecting
        non-discrete environments -- see ``EnvType.assert_discrete``).

        :param envs: the environments the distribution function is created for
        :return: the static function :meth:`_dist_fn`
        """
        env_type = envs.get_type()
        env_type.assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
        """Wrap ``p`` in a ``Categorical``, passing it as the ``logits`` argument.

        :param p: tensor handed to ``Categorical`` as logits (not as probabilities)
        :return: the resulting categorical distribution
        """
        return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory whose distribution function yields independent normal distributions."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
        """Return the independent-Gaussians distribution function.

        Before returning, ``assert_continuous`` is invoked on the environment
        type with this factory as the requesting object (presumably rejecting
        non-continuous environments -- see ``EnvType.assert_continuous``).

        :param envs: the environments the distribution function is created for
        :return: the static function :meth:`_dist_fn`
        """
        env_type = envs.get_type()
        env_type.assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
        """Build an ``Independent`` wrapper around a ``Normal`` distribution.

        :param loc_scale: a pair ``(loc, scale)`` forwarded to ``Normal``
        :return: ``Independent(Normal(loc, scale), 1)``, i.e. the rightmost
            batch dimension of the normal is reinterpreted as an event dimension
        """
        mean, std = loc_scale
        base_normal = torch.distributions.Normal(mean, std)
        return torch.distributions.Independent(base_normal, reinterpreted_batch_ndims=1)
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
    """Factory that picks a concrete factory based on the environment type."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
        """Delegate to the factory matching ``envs.get_type()``.

        ``EnvType.DISCRETE`` is served by
        :class:`DistributionFunctionFactoryCategorical` and
        ``EnvType.CONTINUOUS`` by
        :class:`DistributionFunctionFactoryIndependentGaussians`.

        :param envs: the environments the distribution function is created for
        :return: the distribution function created by the selected factory
        :raises ValueError: if the environment type is neither of the above;
            the exception argument is the environment type
        """
        env_type = envs.get_type()
        delegate: DistributionFunctionFactory
        if env_type == EnvType.DISCRETE:
            delegate = DistributionFunctionFactoryCategorical()
        elif env_type == EnvType.CONTINUOUS:
            delegate = DistributionFunctionFactoryIndependentGaussians()
        else:
            raise ValueError(envs.get_type())
        return delegate.create_dist_fn(envs)