Add DistributionFunctionFactory subclasses for discrete/continuous default
This commit is contained in:
parent
a8dc75fbab
commit
1243894eb8
@ -6,7 +6,6 @@ from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from jsonargparse import CLI
|
||||
from torch.distributions import Independent, Normal
|
||||
|
||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
@ -14,6 +13,7 @@ from tianshou.highlevel.experiment import (
|
||||
ExperimentConfig,
|
||||
PPOExperimentBuilder,
|
||||
)
|
||||
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import PPOParams
|
||||
from tianshou.utils import logging
|
||||
@ -62,9 +62,6 @@ def main(
|
||||
|
||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||
|
||||
def dist_fn(*logits):
    """Build a distribution of independent Normal components.

    All positional arguments are forwarded unchanged to ``Normal``; the
    resulting distribution is wrapped in ``Independent`` with one
    reinterpreted batch dimension.
    """
    return Independent(Normal(*logits), 1)
|
||||
|
||||
experiment = (
|
||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||
.with_ppo_params(
|
||||
@ -85,7 +82,7 @@ def main(
|
||||
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
||||
if lr_decay
|
||||
else None,
|
||||
dist_fn=dist_fn,
|
||||
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
||||
),
|
||||
)
|
||||
.with_actor_factory_default(hidden_sizes)
|
||||
|
@ -17,20 +17,32 @@ class DistributionFunctionFactory(ToStringMixin, ABC):
|
||||
pass
|
||||
|
||||
|
||||
def _dist_fn_categorical(p):
    """Return a ``torch.distributions.Categorical`` that uses ``p`` as its logits."""
    return torch.distributions.Categorical(logits=p)
|
||||
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
    """Factory whose distribution function builds a categorical distribution from logits."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Check the environment type and return the categorical distribution function.

        :param envs: the environments; their type is checked via ``assert_discrete``.
        :return: a callable mapping logits to ``torch.distributions.Categorical``.
        """
        # Invoke the check as a plain statement rather than wrapping it in
        # ``assert``: an ``assert`` statement is stripped under ``python -O``
        # (taking the call with it) and fails unconditionally if the helper
        # returns None.
        # NOTE(review): assumes ``assert_discrete`` raises on a mismatch instead
        # of signalling it through its return value - confirm against EnvType.
        envs.get_type().assert_discrete(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(p):
        """Return a ``Categorical`` distribution that uses ``p`` as its logits."""
        return torch.distributions.Categorical(logits=p)
|
||||
|
||||
|
||||
def _dist_fn_gaussian(*p):
    """Return independent Normal components built from the positional arguments.

    The arguments are forwarded unchanged to ``torch.distributions.Normal``;
    the result is wrapped in ``Independent`` with one reinterpreted batch
    dimension.
    """
    return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
||||
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
    """Factory whose distribution function builds independent Normal components."""

    def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
        """Check the environment type and return the Gaussian distribution function.

        :param envs: the environments; their type is checked via ``assert_continuous``.
        :return: a callable forwarding its arguments to ``Normal`` and wrapping the
            result in ``Independent`` with one reinterpreted batch dimension.
        """
        # Invoke the check as a plain statement rather than wrapping it in
        # ``assert``: an ``assert`` statement is stripped under ``python -O``
        # (taking the call with it) and fails unconditionally if the helper
        # returns None.
        # NOTE(review): assumes ``assert_continuous`` raises on a mismatch instead
        # of signalling it through its return value - confirm against EnvType.
        envs.get_type().assert_continuous(self)
        return self._dist_fn

    @staticmethod
    def _dist_fn(*p):
        """Return ``Independent(Normal(*p), 1)``."""
        return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
||||
|
||||
|
||||
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
match envs.get_type():
|
||||
case EnvType.DISCRETE:
|
||||
return _dist_fn_categorical
|
||||
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
|
||||
case EnvType.CONTINUOUS:
|
||||
return _dist_fn_gaussian
|
||||
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
|
||||
case _:
|
||||
raise ValueError(envs.get_type())
|
||||
|
Loading…
x
Reference in New Issue
Block a user