Closes #917

### Internal Improvements

- Better variable names related to model outputs (logits, dist input, etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc., instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032

### Breaking Changes

- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
from abc import ABC, abstractmethod
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
import torch
|
|
|
|
from tianshou.highlevel.env import Environments, EnvType
|
|
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
|
|
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract factory for single-argument callables that build a torch distribution."""

    # The annotation below is intentionally generic; each subclass declares
    # the precise return type of its own distribution function.
    @abstractmethod
    def create_dist_fn(
        self,
        envs: Environments,
    ) -> Callable[[Any], torch.distributions.Distribution]:
        """Create the distribution function for the given environments.

        :param envs: the environments for which the distribution function is created
        :return: a callable that takes one argument and returns a
            ``torch.distributions.Distribution``
        """
|
|
|
|
|
|
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory whose distribution function builds a ``Categorical`` from logits."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
        """Return the categorical distribution function.

        The environment type is first asked to assert that it is discrete,
        with this factory passed along as the requesting object.

        :param envs: the environments whose type is checked
        :return: the function creating a categorical distribution
        """
        env_type = envs.get_type()
        env_type.assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
        """Build a categorical distribution, passing ``p`` as its logits.

        :param p: the tensor handed to ``Categorical`` via the ``logits`` argument
        :return: the resulting categorical distribution
        """
        categorical = torch.distributions.Categorical(logits=p)
        return categorical
|
|
|
|
|
|
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory whose distribution function wraps a ``Normal`` in ``Independent``."""

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
        """Return the independent-Gaussians distribution function.

        The environment type is first asked to assert that it is continuous,
        with this factory passed along as the requesting object.

        :param envs: the environments whose type is checked
        :return: the function creating the independent normal distribution
        """
        env_type = envs.get_type()
        env_type.assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
        """Build an ``Independent`` distribution over a ``Normal``.

        :param loc_scale: a pair ``(loc, scale)`` forwarded to ``Normal``
        :return: the ``Normal`` wrapped in ``Independent`` with one
            reinterpreted batch dimension
        """
        mean, std = loc_scale
        base_distribution = torch.distributions.Normal(mean, std)
        return torch.distributions.Independent(
            base_distribution,
            reinterpreted_batch_ndims=1,
        )
|
|
|
|
|
|
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
|
|
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
|
|
match envs.get_type():
|
|
case EnvType.DISCRETE:
|
|
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
|
|
case EnvType.CONTINUOUS:
|
|
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
|
|
case _:
|
|
raise ValueError(envs.get_type())
|