Add DistributionFunctionFactory subclasses for discrete/continuous default

This commit is contained in:
Dominik Jain 2023-10-06 14:32:21 +02:00
parent a8dc75fbab
commit 1243894eb8
2 changed files with 20 additions and 11 deletions

View File

@ -6,7 +6,6 @@ from collections.abc import Sequence
from typing import Literal
from jsonargparse import CLI
from torch.distributions import Independent, Normal
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
@ -14,6 +13,7 @@ from tianshou.highlevel.experiment import (
ExperimentConfig,
PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
@ -62,9 +62,6 @@ def main(
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
def dist_fn(*logits):
return Independent(Normal(*logits), 1)
experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
.with_ppo_params(
@ -85,7 +82,7 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=dist_fn,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes)

View File

@ -17,20 +17,32 @@ class DistributionFunctionFactory(ToStringMixin, ABC):
pass
def _dist_fn_categorical(p):
return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory for a distribution function that produces categorical distributions.

    The factory validates that the environments are discrete before handing
    out the distribution function.
    """

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the categorical distribution function for the given environments.

        :param envs: the environments the distribution function is created for
        :return: a callable mapping logits to a ``torch.distributions.Categorical``
        :raises: whatever ``assert_discrete`` raises when the environment type
            is not discrete (presumably ``AssertionError`` - confirm against
            the ``EnvType`` implementation)
        """
        # Run the check as a plain statement instead of `assert <call>`:
        # under `python -O` an assert statement (and the validating call inside
        # it) is removed entirely, and if the helper returns None on success the
        # assert would fail even for valid discrete environments.
        envs.get_type().assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
        """Build a categorical distribution from the logits ``p``."""
        return torch.distributions.Categorical(logits=p)
def _dist_fn_gaussian(*p):
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory for a distribution function that produces independent Gaussians.

    The factory validates that the environments are continuous before handing
    out the distribution function.
    """

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Return the independent-Gaussians distribution function for the given environments.

        :param envs: the environments the distribution function is created for
        :return: a callable mapping the ``Normal`` parameters to a
            ``torch.distributions.Independent`` distribution
        :raises: whatever ``assert_continuous`` raises when the environment
            type is not continuous (presumably ``AssertionError`` - confirm
            against the ``EnvType`` implementation)
        """
        # Run the check as a plain statement instead of `assert <call>`:
        # under `python -O` an assert statement (and the validating call inside
        # it) is removed entirely, and if the helper returns None on success the
        # assert would fail even for valid continuous environments.
        envs.get_type().assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(*p: torch.Tensor) -> torch.distributions.Independent:
        """Wrap ``Normal(*p)`` so that its last batch dimension becomes an event dimension."""
        return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
# Factory that selects a distribution function based on the environment type
# reported by envs.get_type(): DISCRETE and CONTINUOUS are handled, any other
# type raises ValueError carrying that type.
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
match envs.get_type():
case EnvType.DISCRETE:
return _dist_fn_categorical
# NOTE(review): unreachable as written, because it follows the return above.
# This text looks like a diff rendering in which the removed and the added
# line appear side by side; presumably the delegating call below is the
# intended version - confirm against the actual file.
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
case EnvType.CONTINUOUS:
return _dist_fn_gaussian
# NOTE(review): same situation as in the DISCRETE case - unreachable as
# written; presumably the intended replacement for the line above - confirm.
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
case _:
# Neither discrete nor continuous: report the unsupported environment type.
raise ValueError(envs.get_type())