Erni bf0d632108
Naming and typing improvements in Actor/Critic/Policy forwards (#1032)
Closes #917 

### Internal Improvements
- Clearer variable names for model outputs (logits, distribution
inputs, etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like
`Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the
presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see
associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the
distribution type. #1032

### Breaking Changes

- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to
take a single argument in both
continuous and discrete cases. #1032

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2024-04-01 17:14:17 +02:00

52 lines
1.8 KiB
Python

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC):
    """Abstract factory for a distribution function (``dist_fn``).

    A distribution function is a callable that takes a single argument and
    returns a :class:`torch.distributions.Distribution`. Concrete subclasses
    decide which distribution is built.
    """

    # The return annotation here is intentionally generic: each subclass
    # declares the true (more specific) return type.
    @abstractmethod
    def create_dist_fn(
        self,
        envs: Environments,
    ) -> Callable[[Any], torch.distributions.Distribution]:
        """Create the distribution function for the given environments.

        :param envs: the environments for which the function is created.
        :return: a callable mapping a single argument to a torch distribution.
        """
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory whose distribution function builds a categorical distribution."""

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
        """Build a categorical distribution, passing ``p`` as its ``logits``.

        :param p: tensor handed to ``Categorical`` via the ``logits`` keyword.
        :return: the resulting categorical distribution.
        """
        categorical = torch.distributions.Categorical(logits=p)
        return categorical

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
        """Return the categorical distribution function.

        :param envs: the environments; their type is checked through
            ``assert_discrete`` before the function is returned.
        :return: the static ``_dist_fn`` of this class.
        """
        env_type = envs.get_type()
        # NOTE(review): presumably raises when the environment type is not
        # discrete, using this factory to identify the requester -- confirm
        # against EnvType.assert_discrete.
        env_type.assert_discrete(self)
        return self._dist_fn
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory whose distribution function builds independent normal distributions."""

    @staticmethod
    def _dist_fn(
        loc_scale: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.distributions.Distribution:
        """Build an ``Independent`` wrapper around a ``Normal`` distribution.

        :param loc_scale: a pair ``(loc, scale)`` forwarded, in that order, to
            ``torch.distributions.Normal``.
        :return: the normal distribution wrapped in ``Independent`` with one
            batch dimension reinterpreted as an event dimension.
        """
        mean, std = loc_scale
        base_distribution = torch.distributions.Normal(mean, std)
        return torch.distributions.Independent(
            base_distribution,
            reinterpreted_batch_ndims=1,
        )

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
        """Return the independent-Gaussians distribution function.

        :param envs: the environments; their type is checked through
            ``assert_continuous`` before the function is returned.
        :return: the static ``_dist_fn`` of this class.
        """
        env_type = envs.get_type()
        # NOTE(review): presumably raises when the environment type is not
        # continuous, using this factory to identify the requester -- confirm
        # against EnvType.assert_continuous.
        env_type.assert_continuous(self)
        return self._dist_fn
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
    """Factory that picks the distribution function from the environment type.

    Discrete environments are delegated to
    ``DistributionFunctionFactoryCategorical`` and continuous ones to
    ``DistributionFunctionFactoryIndependentGaussians``.
    """

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
        """Create the distribution function matching the environment type.

        :param envs: the environments whose type selects the delegate factory.
        :return: the distribution function created by the delegate factory.
        :raises ValueError: if the environment type is neither
            ``EnvType.DISCRETE`` nor ``EnvType.CONTINUOUS``; the offending
            type is the exception argument.
        """
        env_type = envs.get_type()
        delegate: DistributionFunctionFactory
        if env_type == EnvType.DISCRETE:
            delegate = DistributionFunctionFactoryCategorical()
        elif env_type == EnvType.CONTINUOUS:
            delegate = DistributionFunctionFactoryIndependentGaussians()
        else:
            raise ValueError(env_type)
        return delegate.create_dist_fn(envs)