2023-09-28 14:28:03 +02:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from collections.abc import Callable
|
|
|
|
from typing import TypeAlias
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from tianshou.highlevel.env import Environments, EnvType
|
|
|
|
from tianshou.policy.modelfree.pg import TDistParams
|
2023-10-05 13:15:24 +02:00
|
|
|
from tianshou.utils.string import ToStringMixin
|
2023-09-28 14:28:03 +02:00
|
|
|
|
|
|
|
# Signature of a distribution function: a callable that takes distribution
# parameters (``TDistParams``) and returns a ``torch.distributions.Distribution``.
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
|
|
|
|
|
|
|
|
|
2023-10-05 13:15:24 +02:00
|
|
|
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract base for factories that create a distribution function for a set of environments."""

    @abstractmethod
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Create the distribution function to use with the given environments.

        :param envs: the environments for which the distribution function is created
        :return: a callable mapping distribution parameters to a
            ``torch.distributions.Distribution``
        """
|
|
|
|
|
|
|
|
|
2023-10-06 14:32:21 +02:00
|
|
|
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory for a distribution function that builds a categorical distribution from logits."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the categorical distribution function after checking the environment type.

        :param envs: the environments; their type is checked via ``assert_discrete``
        :return: a callable that turns logits into a ``torch.distributions.Categorical``
        :raises AssertionError: presumably raised by ``assert_discrete`` itself when the
            environments are not discrete (verify against ``EnvType``)
        """
        # Call the check as a plain statement: ``assert_discrete`` is expected to raise on
        # its own, so wrapping it in ``assert`` would instead test its return value (which
        # fails if that value is falsy, e.g. ``None``) and would drop the check entirely
        # when Python runs with ``-O``.
        envs.get_type().assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
        """Build a categorical distribution, passing ``p`` as the ``logits`` argument."""
        return torch.distributions.Categorical(logits=p)
|
2023-09-28 14:28:03 +02:00
|
|
|
|
|
|
|
|
2023-10-06 14:32:21 +02:00
|
|
|
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory for a distribution function that builds independent normal distributions."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the Gaussian distribution function after checking the environment type.

        :param envs: the environments; their type is checked via ``assert_continuous``
        :return: a callable that turns its positional arguments into a
            ``torch.distributions.Independent`` wrapping a ``Normal``
        :raises AssertionError: presumably raised by ``assert_continuous`` itself when the
            environments are not continuous (verify against ``EnvType``)
        """
        # Call the check as a plain statement: ``assert_continuous`` is expected to raise on
        # its own, so wrapping it in ``assert`` would instead test its return value (which
        # fails if that value is falsy, e.g. ``None``) and would drop the check entirely
        # when Python runs with ``-O``.
        envs.get_type().assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Independent:
        """Build a ``Normal`` from the positional arguments and reinterpret one batch dim as event dim.

        The arguments are forwarded unchanged to ``torch.distributions.Normal``; the
        resulting distribution is wrapped in ``Independent`` with
        ``reinterpreted_batch_ndims=1``.
        """
        return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
2023-09-28 14:28:03 +02:00
|
|
|
|
|
|
|
|
|
|
|
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
    """Factory that picks the distribution function based on the type of the environments."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Delegate to the factory matching the environment type.

        Discrete environments are handled by ``DistributionFunctionFactoryCategorical``,
        continuous ones by ``DistributionFunctionFactoryIndependentGaussians``.

        :param envs: the environments whose type selects the delegate factory
        :return: the distribution function created by the selected factory
        :raises ValueError: if the environment type is neither discrete nor continuous
        """
        env_type = envs.get_type()
        delegate: DistributionFunctionFactory
        if env_type == EnvType.DISCRETE:
            delegate = DistributionFunctionFactoryCategorical()
        elif env_type == EnvType.CONTINUOUS:
            delegate = DistributionFunctionFactoryIndependentGaussians()
        else:
            raise ValueError(env_type)
        return delegate.create_dist_fn(envs)
|