Dominik Jain cd79cf8661 Add A2C high-level API
* Add common base class for A2C and PPO agent factories
* Add default for dist_fn parameter, adding corresponding factories
* Add example mujoco_a2c_hl
2023-10-18 20:44:16 +02:00

36 lines
1.0 KiB
Python

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TypeAlias
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistParams
# Type of a function that turns distribution parameters (TDistParams) into a
# torch Distribution; this is what DistributionFunctionFactory.create_dist_fn returns.
# NOTE(review): _dist_fn_gaussian in this module takes its parameters as varargs
# (forwarded to torch.distributions.Normal), which does not literally match a
# single-argument Callable — confirm against how the policy invokes dist_fn.
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
class DistributionFunctionFactory(ABC):
    """Abstract factory for the distribution function used with a given set of environments."""

    @abstractmethod
    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Build the distribution function to use with ``envs``.

        :param envs: the environments the distribution function is created for
        :return: a callable mapping distribution parameters to a torch Distribution
        """
def _dist_fn_categorical(p):
    """Create a categorical distribution from unnormalised log-probabilities.

    :param p: logits, passed to ``torch.distributions.Categorical`` via its ``logits`` argument
    :return: the resulting ``Categorical`` distribution
    """
    categorical_cls = torch.distributions.Categorical
    return categorical_cls(logits=p)
def _dist_fn_gaussian(*params):
    """Create a diagonal Gaussian whose rightmost dimension is the event dimension.

    :param params: positional arguments forwarded unchanged to ``torch.distributions.Normal``
        (typically ``loc`` and ``scale``)
    :return: an ``Independent`` wrapper around the ``Normal`` distribution
    """
    base = torch.distributions.Normal(*params)
    # Reinterpret the rightmost batch dim as an event dim, so log_prob sums over it.
    return torch.distributions.Independent(base, reinterpreted_batch_ndims=1)
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
match envs.get_type():
case EnvType.DISCRETE:
return _dist_fn_categorical
case EnvType.CONTINUOUS:
return _dist_fn_gaussian
case _:
raise ValueError(envs.get_type())