Naming and typing improvements in Actor/Critic/Policy forwards (#1032)

Closes #917 

### Internal Improvements
- Better variable names related to model outputs (logits, dist input
etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like
`Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the
presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see
associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the
distribution type. #1032

### Breaking Changes

- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to
take a single argument in both
continuous and discrete cases. #1032

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
This commit is contained in:
Erni 2024-04-01 17:14:17 +02:00 committed by GitHub
parent 5bf923c9bd
commit bf0d632108
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 342 additions and 245 deletions

View File

@ -13,6 +13,12 @@
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063 - Improved typing for `exploration_noise` and within Collector. #1063
- Better variable names related to model outputs (logits, dist input etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032
### Breaking Changes ### Breaking Changes
@ -21,6 +27,8 @@
explicitly or pass `reset_before_collect=True` . #1063 explicitly or pass `reset_before_collect=True` . #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 - Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
### Tests ### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081 - Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081

View File

@ -69,7 +69,7 @@
"from tianshou.policy import BasePolicy\n", "from tianshou.policy import BasePolicy\n",
"from tianshou.policy.modelfree.pg import (\n", "from tianshou.policy.modelfree.pg import (\n",
" PGTrainingStats,\n", " PGTrainingStats,\n",
" TDistributionFunction,\n", " TDistFnDiscrOrCont,\n",
" TPGTrainingStats,\n", " TPGTrainingStats,\n",
")\n", ")\n",
"from tianshou.utils import RunningMeanStd\n", "from tianshou.utils import RunningMeanStd\n",
@ -339,7 +339,7 @@
" *,\n", " *,\n",
" actor: torch.nn.Module,\n", " actor: torch.nn.Module,\n",
" optim: torch.optim.Optimizer,\n", " optim: torch.optim.Optimizer,\n",
" dist_fn: TDistributionFunction,\n", " dist_fn: TDistFnDiscrOrCont,\n",
" action_space: gym.Space,\n", " action_space: gym.Space,\n",
" discount_factor: float = 0.99,\n", " discount_factor: float = 0.99,\n",
" observation_space: gym.Space | None = None,\n", " observation_space: gym.Space | None = None,\n",

View File

@ -257,3 +257,8 @@ macOS
joblib joblib
master master
Panchenko Panchenko
BA
BH
BO
BD

View File

@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
# expert replay buffer # expert replay buffer
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))

View File

@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: A2CPolicy = A2CPolicy( policy: A2CPolicy = A2CPolicy(
actor=actor, actor=actor,

View File

@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy = NPGPolicy( policy: NPGPolicy = NPGPolicy(
actor=actor, actor=actor,

View File

@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy = PPOPolicy( policy: PPOPolicy = PPOPolicy(
actor=actor, actor=actor,

View File

@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PGPolicy = PGPolicy( policy: PGPolicy = PGPolicy(
actor=actor, actor=actor,

View File

@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: TRPOPolicy = TRPOPolicy( policy: TRPOPolicy = TRPOPolicy(
actor=actor, actor=actor,

View File

@ -2,7 +2,7 @@ import gymnasium as gym
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from torch.distributions import Categorical, Independent, Normal from torch.distributions import Categorical, Distribution, Independent, Normal
from tianshou.policy import PPOPolicy from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.common import ActorCritic, Net
@ -25,7 +25,11 @@ def policy(request):
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
action_shape=action_space.shape, action_shape=action_space.shape,
) )
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
elif action_type == "discrete": elif action_type == "discrete":
action_space = gym.spaces.Discrete(3) action_space = gym.spaces.Discrete(3)
actor = Actor( actor = Actor(

View File

@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent # replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy[NPGTrainingStats] = NPGPolicy( policy: NPGPolicy[NPGTrainingStats] = NPGPolicy(
actor=actor, actor=actor,

View File

@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent # replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( policy: PPOPolicy[PPOTrainingStats] = PPOPolicy(
actor=actor, actor=actor,

View File

@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent # replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = TRPOPolicy( policy: BasePolicy = TRPOPolicy(
actor=actor, actor=actor,

View File

@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent # replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = GAILPolicy( policy: BasePolicy = GAILPolicy(
actor=actor, actor=actor,

View File

@ -181,8 +181,9 @@ def get_agents(
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr) optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
agent: PPOPolicy = PPOPolicy( agent: PPOPolicy = PPOPolicy(
actor, actor,

View File

@ -1,40 +1,47 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch import torch
from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC): class DistributionFunctionFactory(ToStringMixin, ABC):
# True return type defined in subclasses
@abstractmethod @abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(
self,
envs: Environments,
) -> Callable[[Any], torch.distributions.Distribution]:
pass pass
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self) envs.get_type().assert_discrete(self)
return self._dist_fn return self._dist_fn
@staticmethod @staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution: def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p) return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory): class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self) envs.get_type().assert_continuous(self)
return self._dist_fn return self._dist_fn
@staticmethod @staticmethod
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution: def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
return torch.distributions.Independent(torch.distributions.Normal(*p), 1) loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
class DistributionFunctionFactoryDefault(DistributionFunctionFactory): class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type(): match envs.get_type():
case EnvType.DISCRETE: case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs) return DistributionFunctionFactoryCategorical().create_dist_fn(envs)

View File

@ -19,7 +19,7 @@ from tianshou.highlevel.params.dist_fn import (
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils import MultipleLRSchedulers from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
Does not affect training. Does not affect training.
""" """
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default" dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
""" """
This can either be a function which maps the model output to a torch distribution or a This can either be a function which maps the model output to a torch distribution or a
factory for the creation of such a function. factory for the creation of such a function.

View File

@ -213,10 +213,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
super().__init__() super().__init__()
self.observation_space = observation_space self.observation_space = observation_space
self.action_space = action_space self.action_space = action_space
self._action_type: Literal["discrete", "continuous"]
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary): if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
self.action_type = "discrete" self._action_type = "discrete"
elif isinstance(action_space, Box): elif isinstance(action_space, Box):
self.action_type = "continuous" self._action_type = "continuous"
else: else:
raise ValueError(f"Unsupported action space: {action_space}.") raise ValueError(f"Unsupported action space: {action_space}.")
self.agent_id = 0 self.agent_id = 0
@ -226,6 +227,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
self._compile() self._compile()
@property
def action_type(self) -> Literal["discrete", "continuous"]:
return self._action_type
def set_agent_id(self, agent_id: int) -> None: def set_agent_id(self, agent_id: int) -> None:
"""Set self.agent_id = agent_id, for MARL.""" """Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id self.agent_id = agent_id

View File

@ -16,6 +16,12 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
@dataclass(kw_only=True) @dataclass(kw_only=True)
class ImitationTrainingStats(TrainingStats): class ImitationTrainingStats(TrainingStats):
@ -72,9 +78,20 @@ class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTra
state: dict | BatchProtocol | np.ndarray | None = None, state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> ModelOutputBatchProtocol: ) -> ModelOutputBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits if self.action_type == "discrete":
result = Batch(logits=logits, act=act, state=hidden) # If it's discrete, the "actor" is usually a critic that maps obs to action_values
# which then could be turned into logits or a Categorical
action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
act_B = action_values_BA.argmax(dim=1)
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
elif self.action_type == "continuous":
# If it's continuous, the actor would usually deliver something like loc, scale determining a
# Gaussian dist
dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
else:
raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
return cast(ModelOutputBatchProtocol, result) return cast(ModelOutputBatchProtocol, result)
def learn( def learn(

View File

@ -34,8 +34,7 @@ TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteB
class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708. """Implementation of discrete BCQ algorithm. arXiv:1910.01708.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
:param imitator: a model following the rules in :param imitator: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.

View File

@ -25,8 +25,7 @@ TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteC
class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]):
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space. :param action_space: Env's action space.
:param min_q_weight: the weight for the cql loss. :param min_q_weight: the weight for the cql loss.

View File

@ -11,6 +11,7 @@ from tianshou.data import to_torch, to_torch_as
from tianshou.data.types import RolloutBatchProtocol from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass @dataclass
@ -26,8 +27,9 @@ TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteC
class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
:param actor: the actor network following the rules in :param actor: the actor network following the rules:
:class:`~tianshou.policy.BasePolicy`. (s -> logits) If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the action-value critic (i.e., Q function) :param critic: the action-value critic (i.e., Q function)
network. (s -> Q(s, \*)) network. (s -> Q(s, \*))
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | Actor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete, action_space: gym.spaces.Discrete,
discount_factor: float = 0.99, discount_factor: float = 0.99,

View File

@ -15,8 +15,11 @@ from tianshou.data import (
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import PPOPolicy from tianshou.policy import PPOPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.policy.modelfree.ppo import PPOTrainingStats
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -32,7 +35,9 @@ TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats)
class GAILPolicy(PPOPolicy[TGailTrainingStats]): class GAILPolicy(PPOPolicy[TGailTrainingStats]):
r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
expert_buffer: ReplayBuffer, expert_buffer: ReplayBuffer,
disc_net: torch.nn.Module, disc_net: torch.nn.Module,

View File

@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]):
"""Implementation of TD3+BC. arXiv:2106.06860. """Implementation of TD3+BC. arXiv:2106.06860.
:param actor: the actor network following the rules in :param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits) :class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.

View File

@ -11,8 +11,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import PGPolicy from tianshou.policy import PGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -30,7 +33,9 @@ TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats)
class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var]
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
vf_coef: float = 0.5, vf_coef: float = 0.5,
ent_coef: float = 0.01, ent_coef: float = 0.01,

View File

@ -31,7 +31,7 @@ TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats)
class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
"""Implementation of the Branching dual Q network arXiv:1711.08946. """Implementation of the Branching dual Q network arXiv:1711.08946.
:param model: BranchingNet mapping (obs, state, info) -> logits. :param model: BranchingNet mapping (obs, state, info) -> action_values_BA.
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead. :param estimation_step: the number of steps to look ahead.
@ -156,10 +156,10 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
model = getattr(self, model) model = getattr(self, model)
obs = batch.obs obs = batch.obs
# TODO: this is very contrived, see also iqn.py # TODO: this is very contrived, see also iqn.py
obs_next = obs.obs if hasattr(obs, "obs") else obs obs_next_BO = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info) action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info)
act = to_numpy(logits.max(dim=-1)[1]) act_B = to_numpy(action_values_BA.argmax(dim=-1))
result = Batch(logits=logits, act=act, state=hidden) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result) return cast(ModelOutputBatchProtocol, result)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:

View File

@ -23,8 +23,7 @@ TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats)
class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
"""Implementation of Categorical Deep Q-Network. arXiv:1707.06887. """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param num_atoms: the number of atoms in the support set of the :param num_atoms: the number of atoms in the support set of the

View File

@ -19,6 +19,7 @@ from tianshou.data.types import (
from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.continuous import Actor, Critic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -33,8 +34,7 @@ TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats)
class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
:param actor: The actor network following the rules in :param actor: The actor network following the rules (s -> actions)
:class:`~tianshou.policy.BasePolicy`. (s -> model_output)
:param actor_optim: The optimizer for actor network. :param actor_optim: The optimizer for actor network.
:param critic: The critic network. (s, a -> Q(s, a)) :param critic: The critic network. (s, a -> Q(s, a))
:param critic_optim: The optimizer for critic network. :param critic_optim: The optimizer for critic network.
@ -60,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
action_space: gym.Space, action_space: gym.Space,
tau: float = 0.005, tau: float = 0.005,

View File

@ -12,6 +12,7 @@ from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatch
from tianshou.policy import SACPolicy from tianshou.policy import SACPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.sac import SACTrainingStats from tianshou.policy.modelfree.sac import SACTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass @dataclass
@ -25,8 +26,7 @@ TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteS
class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
"""Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
:param actor: the actor network following the rules in :param actor: the actor network following the rules (s_B -> dist_input_BD)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.
@ -54,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete, action_space: gym.spaces.Discrete,
critic2: torch.nn.Module | None = None, critic2: torch.nn.Module | Critic | None = None,
critic2_optim: torch.optim.Optimizer | None = None, critic2_optim: torch.optim.Optimizer | None = None,
tau: float = 0.005, tau: float = 0.005,
gamma: float = 0.99, gamma: float = 0.99,
@ -105,13 +105,13 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
state: dict | Batch | np.ndarray | None = None, state: dict | Batch | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> Batch: ) -> Batch:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits) dist = Categorical(logits=logits_BA)
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.sample() act_B = dist.sample()
return Batch(logits=logits, act=act, state=hidden, dist=dist) return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch( obs_next_batch = Batch(

View File

@ -17,6 +17,7 @@ from tianshou.data.types import (
) )
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.common import Net
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -35,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
implemented in the network side, not here). implemented in the network side, not here).
:param model: a model following the rules in :param model: a model following the rules (s -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead. :param estimation_step: the number of steps to look ahead.
@ -60,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
model: torch.nn.Module, model: torch.nn.Module | Net,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
# TODO: type violates Liskov substitution principle # TODO: type violates Liskov substitution principle
action_space: gym.spaces.Discrete, action_space: gym.spaces.Discrete,
@ -201,12 +201,12 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
obs = batch.obs obs = batch.obs
# TODO: this is convoluted! See also other places where this is done. # TODO: this is convoluted! See also other places where this is done.
obs_next = obs.obs if hasattr(obs, "obs") else obs obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info) action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None)) q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None))
if self.max_action_num is None: if self.max_action_num is None:
self.max_action_num = q.shape[1] self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1]) act_B = to_numpy(q.argmax(dim=1))
result = Batch(logits=logits, act=act, state=hidden) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result) return cast(ModelOutputBatchProtocol, result)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:

View File

@ -27,8 +27,7 @@ TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats)
class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param fraction_model: a FractionProposalNetwork for :param fraction_model: a FractionProposalNetwork for
proposing fractions/quantiles given state. proposing fractions/quantiles given state.

View File

@ -29,8 +29,7 @@ TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats)
class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
"""Implementation of Implicit Quantile Network. arXiv:1806.06923. """Implementation of Implicit Quantile Network. arXiv:1806.06923.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param sample_size: the number of samples for policy evaluation. :param sample_size: the number of samples for policy evaluation.

View File

@ -12,7 +12,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
optim_critic_iters: int = 5, optim_critic_iters: int = 5,
actor_step_size: float = 0.5, actor_step_size: float = 0.5,

View File

@ -1,7 +1,7 @@
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast from typing import Any, Generic, Literal, TypeVar, cast
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
@ -24,9 +24,22 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils import RunningMeanStd from tianshou.utils import RunningMeanStd
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.discrete import Actor
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution] # Dimension Naming Convention
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution] # B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
TDistFnContinuous = Callable[
[tuple[torch.Tensor, torch.Tensor]],
torch.distributions.Distribution,
]
TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical]
TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -40,8 +53,9 @@ TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats)
class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
"""Implementation of REINFORCE algorithm. """Implementation of REINFORCE algorithm.
:param actor: mapping (s->model_output), should follow the rules in :param actor: the actor network following the rules:
:class:`~tianshou.policy.BasePolicy`. If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param optim: optimizer for actor network. :param optim: optimizer for actor network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
Maps model_output -> distribution. Typically a Gaussian distribution Maps model_output -> distribution. Typically a Gaussian distribution
@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | Actor,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
discount_factor: float = 0.99, discount_factor: float = 0.99,
# TODO: rename to return_normalization? # TODO: rename to return_normalization?
@ -175,20 +189,20 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation. more detailed explanation.
""" """
# TODO: rename? It's not really logits and there are particular # TODO - ALGO: marked for algorithm refactoring
# assumptions about the order of the output and on distribution type action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
if isinstance(logits, tuple): # therefore action_dist_input_BD is equivalent to logits_BA
dist = self.dist_fn(*logits) # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian)
else: # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
dist = self.dist_fn(logits) dist = self.dist_fn(action_dist_input_BD)
# in this case, the dist is unused!
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.sample() act_B = dist.sample()
result = Batch(logits=logits, act=act, state=hidden, dist=dist) # act is of dimension BA in continuous case and of dimension B in discrete
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
return cast(DistBatchProtocol, result) return cast(DistBatchProtocol, result)
# TODO: why does mypy complain? # TODO: why does mypy complain?

View File

@ -10,8 +10,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -29,7 +32,9 @@ TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var]
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
eps_clip: float = 0.2, eps_clip: float = 0.2,
dual_clip: float | None = None, dual_clip: float | None = None,

View File

@ -25,8 +25,7 @@ TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats)
class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
"""Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
:param model: a model following the rules in :param model: a model following the rules (s -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space. :param action_space: Env's action space.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].

View File

@ -12,6 +12,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.ddpg import DDPGTrainingStats from tianshou.policy.modelfree.ddpg import DDPGTrainingStats
from tianshou.utils.net.continuous import ActorProb
@dataclass @dataclass
@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
@ -150,23 +151,28 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
state: dict | Batch | np.ndarray | None = None, state: dict | Batch | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> Batch: ) -> Batch:
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info) (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale dist = Independent(Normal(loc_B, scale_B), 1)
dist = Independent(Normal(loc, scale), 1)
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.rsample() act_B = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1) log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian # apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation. # in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act) squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1, -1,
keepdim=True, keepdim=True,
) )
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob) return Batch(
logits=(loc_B, scale_B),
act=squashed_action,
state=h_BH,
dist=dist,
log_prob=log_prob,
)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch( obs_next_batch = Batch(

View File

@ -17,6 +17,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.conversion import to_optional_float from tianshou.utils.conversion import to_optional_float
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.optim import clone_optimizer from tianshou.utils.optim import clone_optimizer
@ -36,8 +37,7 @@ TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats)
class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var]
"""Implementation of Soft Actor-Critic. arXiv:1812.05905. """Implementation of Soft Actor-Critic. arXiv:1812.05905.
:param actor: the actor network following the rules in :param actor: the actor network following the rules (s -> dist_input_BD)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.
@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
@ -173,26 +173,25 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
state: dict | Batch | np.ndarray | None = None, state: dict | Batch | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> DistLogProbBatchProtocol: ) -> DistLogProbBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple) dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
dist = Independent(Normal(*logits), 1)
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.rsample() act_B = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1) log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian # apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation. # in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act) squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1, -1,
keepdim=True, keepdim=True,
) )
result = Batch( result = Batch(
logits=logits, logits=(loc_B, scale_B),
act=squashed_action, act=squashed_action,
state=hidden, state=hidden_BH,
dist=dist, dist=dist,
log_prob=log_prob, log_prob=log_prob,
) )

View File

@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t
"""Implementation of TD3, arXiv:1802.09477. """Implementation of TD3, arXiv:1802.09477.
:param actor: the actor network following the rules in :param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits) :class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.

View File

@ -11,7 +11,10 @@ from tianshou.data import Batch, SequenceSummaryStats
from tianshou.policy import NPGPolicy from tianshou.policy import NPGPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.npg import NPGTrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -25,7 +28,9 @@ TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats)
class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477. """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
max_kl: float = 0.01, max_kl: float = 0.01,
backtrack_coeff: float = 0.8, backtrack_coeff: float = 0.8,

View File

@ -610,6 +610,17 @@ class BaseActor(nn.Module, ABC):
def get_output_dim(self) -> int: def get_output_dim(self) -> int:
pass pass
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[Any, Any]:
# TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction.
# Return type needs to be more specific
pass
def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
"""Gets the given attribute from the given object or takes the alternative value if it is not present. """Gets the given attribute from the given object or takes the alternative value if it is not present.

View File

@ -1,4 +1,5 @@
import warnings import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
@ -9,6 +10,7 @@ from torch import nn
from tianshou.utils.net.common import ( from tianshou.utils.net.common import (
MLP, MLP,
BaseActor, BaseActor,
Net,
TActionShape, TActionShape,
TLinearLayer, TLinearLayer,
get_output_dim, get_output_dim,
@ -19,33 +21,27 @@ SIGMA_MAX = 2
class Actor(BaseActor): class Actor(BaseActor):
"""Simple actor network. """Simple actor network that directly outputs actions for continuous action space.
Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`.
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
:param preprocess_net: a self-defined preprocess_net which output a :param preprocess_net: a self-defined preprocess_net, see usage.
flattened hidden state. Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action. :param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net. preprocess_net.
:param max_action: the scale for the final action.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
action_shape: TActionShape, action_shape: TActionShape,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
max_action: float = 1.0, max_action: float = 1.0,
@ -77,42 +73,50 @@ class Actor(BaseActor):
state: Any = None, state: Any = None,
info: dict[str, Any] | None = None, info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]: ) -> tuple[torch.Tensor, Any]:
"""Mapping: obs -> logits -> action.""" """Mapping: s_B -> action_values_BA, hidden_state_BH | None.
if info is None:
info = {} Returns a tensor representing the actions directly, i.e, of shape
logits, hidden = self.preprocess(obs, state) `(n_actions, )`, and a hidden state (which may be None).
logits = self.max_action * torch.tanh(self.last(logits)) The hidden state is only not None if a recurrent net is used as part of the
return logits, hidden learning algorithm (support for RNNs is currently experimental).
"""
action_BA, hidden_BH = self.preprocess(obs, state)
action_BA = self.max_action * torch.tanh(self.last(action_BA))
return action_BA, hidden_BH
class Critic(nn.Module): class CriticBase(nn.Module, ABC):
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
class Critic(CriticBase):
"""Simple critic network. """Simple critic network.
It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value).
:param preprocess_net: a self-defined preprocess_net which output a :param preprocess_net: a self-defined preprocess_net, see usage.
flattened hidden state. Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param preprocess_net_output_dim: the output dimension of
preprocess_net. preprocess_net.
:param linear_layer: use this module as linear layer. Default to nn.Linear. :param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
:param linear_layer: use this module as linear layer.
:param flatten_input: whether to flatten input data for the last layer. :param flatten_input: whether to flatten input data for the last layer.
Default to True.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
device: str | int | torch.device = "cpu", device: str | int | torch.device = "cpu",
preprocess_net_output_dim: int | None = None, preprocess_net_output_dim: int | None = None,
@ -139,9 +143,7 @@ class Critic(nn.Module):
act: np.ndarray | torch.Tensor | None = None, act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None, info: dict[str, Any] | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a).""" """Mapping: (s_B, a_B) -> Q(s, a)_B."""
if info is None:
info = {}
obs = torch.as_tensor( obs = torch.as_tensor(
obs, obs,
device=self.device, device=self.device,
@ -154,41 +156,35 @@ class Critic(nn.Module):
dtype=torch.float32, dtype=torch.float32,
).flatten(1) ).flatten(1)
obs = torch.cat([obs, act], dim=1) obs = torch.cat([obs, act], dim=1)
logits, hidden = self.preprocess(obs) values_B, hidden_BH = self.preprocess(obs)
return self.last(logits) return self.last(values_B)
class ActorProb(BaseActor): class ActorProb(BaseActor):
"""Simple actor network (output with a Gauss distribution). """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian).
:param preprocess_net: a self-defined preprocess_net which output a Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`.
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action. :param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param unbounded: whether to apply tanh activation on final logits.
Default to False.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter. Default to False.
:param preprocess_net_output_dim: the output dimension of
preprocess_net. preprocess_net.
:param max_action: the scale for the final action logits.
:param unbounded: whether to apply tanh activation on final logits.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
# TODO: force kwargs, adjust downstream code # TODO: force kwargs, adjust downstream code
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
action_shape: TActionShape, action_shape: TActionShape,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
max_action: float = 1.0, max_action: float = 1.0,
@ -402,8 +398,7 @@ class Perturbation(nn.Module):
flattened hidden state. flattened hidden state.
:param max_action: the maximum value of each dimension of action. :param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on. :param device: which device to create this model on.
Default to cpu. :param phi: max perturbation parameter for BCQ.
:param phi: max perturbation parameter for BCQ. Default to 0.05.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
@ -449,7 +444,6 @@ class VAE(nn.Module):
:param latent_dim: the size of latent layer. :param latent_dim: the size of latent layer.
:param max_action: the maximum value of each dimension of action. :param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on. :param device: which device to create this model on.
Default to "cpu".
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.

View File

@ -7,17 +7,14 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from tianshou.data import Batch, to_torch from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim
class Actor(BaseActor): class Actor(BaseActor):
"""Simple actor network. """Simple actor network for discrete action spaces.
Will create an actor operated in discrete action space with structure of :param preprocess_net: a self-defined preprocess_net. Typically, an instance of
preprocess_net ---> action_shape. :class:`~tianshou.utils.net.common.Net`.
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param action_shape: a sequence of int for the shape of action. :param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains preprocess_net. Default to empty sequence (where the MLP now contains
@ -25,20 +22,15 @@ class Actor(BaseActor):
:param softmax_output: whether to apply a softmax layer over the last :param softmax_output: whether to apply a softmax layer over the last
layer's output. layer's output.
:param preprocess_net_output_dim: the output dimension of :param preprocess_net_output_dim: the output dimension of
preprocess_net. `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
action_shape: TActionShape, action_shape: TActionShape,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
softmax_output: bool = True, softmax_output: bool = True,
@ -71,43 +63,44 @@ class Actor(BaseActor):
obs: np.ndarray | torch.Tensor, obs: np.ndarray | torch.Tensor,
state: Any = None, state: Any = None,
info: dict[str, Any] | None = None, info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
r"""Mapping: s -> Q(s, \*).""" r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
if info is None:
info = {} Returns a tensor representing the values of each action, i.e, of shape
logits, hidden = self.preprocess(obs, state) `(n_actions, )`, and
logits = self.last(logits) a hidden state (which may be None). If `self.softmax_output` is True, they are the
probabilities for taking each action. Otherwise, they will be action values.
The hidden state is only
not None if a recurrent net is used as part of the learning algorithm.
"""
x, hidden_BH = self.preprocess(obs, state)
x = self.last(x)
if self.softmax_output: if self.softmax_output:
logits = F.softmax(logits, dim=-1) x = F.softmax(x, dim=-1)
return logits, hidden # If we computed softmax, output is probabilities, otherwise it's the non-normalized action values
output_BA = x
return output_BA, hidden_BH
class Critic(nn.Module): class Critic(nn.Module):
"""Simple critic network. """Simple critic network for discrete action spaces.
It will create an actor operated in discrete action space with structure of preprocess_net ---> 1(q value). :param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer). only a single linear layer).
:param last_size: the output dimension of Critic network. Default to 1. :param last_size: the output dimension of Critic network. Default to 1.
:param preprocess_net_output_dim: the output dimension of :param preprocess_net_output_dim: the output dimension of
preprocess_net. `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`..
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
last_size: int = 1, last_size: int = 1,
preprocess_net_output_dim: int | None = None, preprocess_net_output_dim: int | None = None,
@ -120,8 +113,10 @@ class Critic(nn.Module):
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
# TODO: make a proper interface!
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor: def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Mapping: s -> V(s).""" """Mapping: s_B -> V(s)_B."""
# TODO: don't use this mechanism for passing state
logits, _ = self.preprocess(obs, state=kwargs.get("state", None)) logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits) return self.last(logits)