Add DistributionFunctionFactory subclasses for discrete/continuous default
This commit is contained in:
parent
a8dc75fbab
commit
1243894eb8
@ -6,7 +6,6 @@ from collections.abc import Sequence
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
from torch.distributions import Independent, Normal
|
|
||||||
|
|
||||||
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
from examples.mujoco.mujoco_env import MujocoEnvFactory
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
@ -14,6 +13,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
ExperimentConfig,
|
ExperimentConfig,
|
||||||
PPOExperimentBuilder,
|
PPOExperimentBuilder,
|
||||||
)
|
)
|
||||||
|
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryIndependentGaussians
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams
|
from tianshou.highlevel.params.policy_params import PPOParams
|
||||||
from tianshou.utils import logging
|
from tianshou.utils import logging
|
||||||
@ -62,9 +62,6 @@ def main(
|
|||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
def dist_fn(*logits):
|
|
||||||
return Independent(Normal(*logits), 1)
|
|
||||||
|
|
||||||
experiment = (
|
experiment = (
|
||||||
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
|
||||||
.with_ppo_params(
|
.with_ppo_params(
|
||||||
@ -85,7 +82,7 @@ def main(
|
|||||||
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
|
||||||
if lr_decay
|
if lr_decay
|
||||||
else None,
|
else None,
|
||||||
dist_fn=dist_fn,
|
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory_default(hidden_sizes)
|
.with_actor_factory_default(hidden_sizes)
|
||||||
|
@ -17,20 +17,32 @@ class DistributionFunctionFactory(ToStringMixin, ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _dist_fn_categorical(p):
|
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
|
||||||
return torch.distributions.Categorical(logits=p)
|
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||||
|
assert envs.get_type().assert_discrete(self)
|
||||||
|
return self._dist_fn
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dist_fn(p):
|
||||||
|
return torch.distributions.Categorical(logits=p)
|
||||||
|
|
||||||
|
|
||||||
def _dist_fn_gaussian(*p):
|
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
|
||||||
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||||
|
assert envs.get_type().assert_continuous(self)
|
||||||
|
return self._dist_fn
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dist_fn(*p):
|
||||||
|
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
||||||
|
|
||||||
|
|
||||||
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
|
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
|
||||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||||
match envs.get_type():
|
match envs.get_type():
|
||||||
case EnvType.DISCRETE:
|
case EnvType.DISCRETE:
|
||||||
return _dist_fn_categorical
|
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
|
||||||
case EnvType.CONTINUOUS:
|
case EnvType.CONTINUOUS:
|
||||||
return _dist_fn_gaussian
|
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(envs.get_type())
|
raise ValueError(envs.get_type())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user