Naming and typing improvements in Actor/Critic/Policy forwards (#1032)

Closes #917 

### Internal Improvements
- Better variable names related to model outputs (logits, dist input
etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like
`Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the
presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see
associated breaking change). #1032
- Use the distribution's `.mode` for deterministic evaluation instead of relying on knowledge of the concrete distribution type (see the sketch below). #1032
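
A minimal sketch of the deterministic-evaluation pattern now used in the policy `forward` methods. The flags and tensor names below (`deterministic_eval`, `training`, `logits_BA`, `loc_B`, `scale_B`) are illustrative stand-ins for the corresponding policy attributes and actor outputs:

```python
import torch
from torch.distributions import Categorical, Independent, Normal

# Illustrative actor outputs: a discrete head (B=4 observations, A=3 actions)
# and a continuous head (B=4 observations, 2-dimensional actions).
logits_BA = torch.randn(4, 3)
loc_B, scale_B = torch.zeros(4, 2), torch.ones(4, 2)

deterministic_eval, training = True, False

for dist in (Categorical(logits=logits_BA), Independent(Normal(loc_B, scale_B), 1)):
    # The same branch works for any torch Distribution exposing `.mode`,
    # so the policy no longer needs to special-case the distribution type.
    act_B = dist.mode if deterministic_eval and not training else dist.sample()
    print(type(dist).__name__, tuple(act_B.shape))
```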

### Breaking Changes

- Changed the interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
the continuous and discrete cases (see the sketch below). #1032
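
To make the change concrete, here is a short sketch of the old versus new `dist_fn` signatures. The function names are illustrative; the type aliases they correspond to (`TDistFnContinuous`, `TDistFnDiscrete`, `TDistFnDiscrOrCont`) are the ones added in `tianshou/policy/modelfree/pg.py` in this PR:

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

# Before: the continuous dist_fn unpacked the actor output via *args.
def dist_old(*logits: torch.Tensor) -> Distribution:
    return Independent(Normal(*logits), 1)

# After: dist_fn always receives the actor output as a single argument.
# Continuous case: a (loc, scale) tuple that the function unpacks itself.
def dist_continuous(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)

# Discrete case: a single logits tensor, as before.
def dist_discrete(logits: torch.Tensor) -> Categorical:
    return Categorical(logits=logits)

# PGPolicy.forward can then call `self.dist_fn(action_dist_input_BD)` uniformly.
loc, scale = torch.zeros(4, 2), torch.ones(4, 2)
assert dist_continuous((loc, scale)).mode.shape == (4, 2)
assert dist_discrete(torch.randn(4, 3)).mode.shape == (4,)
```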

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
Erni 2024-04-01 17:14:17 +02:00 committed by GitHub
parent 5bf923c9bd
commit bf0d632108
43 changed files with 342 additions and 245 deletions

View File

@ -13,6 +13,12 @@
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063
- Better variable names related to model outputs (logits, dist input etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032
### Breaking Changes
@ -21,6 +27,8 @@
explicitly or pass `reset_before_collect=True`. #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081

View File

@ -69,7 +69,7 @@
"from tianshou.policy import BasePolicy\n",
"from tianshou.policy.modelfree.pg import (\n",
" PGTrainingStats,\n",
" TDistributionFunction,\n",
" TDistFnDiscrOrCont,\n",
" TPGTrainingStats,\n",
")\n",
"from tianshou.utils import RunningMeanStd\n",
@ -339,7 +339,7 @@
" *,\n",
" actor: torch.nn.Module,\n",
" optim: torch.optim.Optimizer,\n",
" dist_fn: TDistributionFunction,\n",
" dist_fn: TDistFnDiscrOrCont,\n",
" action_space: gym.Space,\n",
" discount_factor: float = 0.99,\n",
" observation_space: gym.Space | None = None,\n",

View File

@ -257,3 +257,8 @@ macOS
joblib
master
Panchenko
BA
BH
BO
BD

View File

@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
# expert replay buffer
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))

View File

@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: A2CPolicy = A2CPolicy(
actor=actor,

View File

@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy = NPGPolicy(
actor=actor,

View File

@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy = PPOPolicy(
actor=actor,

View File

@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PGPolicy = PGPolicy(
actor=actor,

View File

@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: TRPOPolicy = TRPOPolicy(
actor=actor,

View File

@ -2,7 +2,7 @@ import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal
from torch.distributions import Categorical, Distribution, Independent, Normal
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
@ -25,7 +25,11 @@ def policy(request):
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
action_shape=action_space.shape,
)
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
elif action_type == "discrete":
action_space = gym.spaces.Discrete(3)
actor = Actor(

View File

@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy[NPGTrainingStats] = NPGPolicy(
actor=actor,

View File

@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy[PPOTrainingStats] = PPOPolicy(
actor=actor,

View File

@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = TRPOPolicy(
actor=actor,

View File

@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = GAILPolicy(
actor=actor,

View File

@ -181,8 +181,9 @@ def get_agents(
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
agent: PPOPolicy = PPOPolicy(
actor,

View File

@ -1,40 +1,47 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC):
# True return type defined in subclasses
@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(
self,
envs: Environments,
) -> Callable[[Any], torch.distributions.Distribution]:
pass
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self)
return self._dist_fn
@staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self)
return self._dist_fn
@staticmethod
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
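
For readers extending the high-level API, a hedged sketch of what a custom distribution-function factory could look like against the interface shown above. The class name and the temperature-scaling behaviour are hypothetical, chosen purely for illustration; only `DistributionFunctionFactory`, `Environments`, and `TDistFnDiscrete` come from this diff:

```python
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactory
from tianshou.policy.modelfree.pg import TDistFnDiscrete


class DistributionFunctionFactoryScaledCategorical(DistributionFunctionFactory):
    """Hypothetical factory: Categorical over temperature-scaled logits."""

    def __init__(self, temperature: float = 1.0) -> None:
        self.temperature = temperature

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
        envs.get_type().assert_discrete(self)
        temperature = self.temperature

        def dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical:
            # Matches TDistFnDiscrete: one tensor in, a Categorical out.
            return torch.distributions.Categorical(logits=logits / temperature)

        return dist_fn
```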

View File

@ -19,7 +19,7 @@ from tianshou.highlevel.params.dist_fn import (
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin
@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
Does not affect training.
"""
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
"""
This can either be a function which maps the model output to a torch distribution or a
factory for the creation of such a function.

View File

@ -213,10 +213,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
super().__init__()
self.observation_space = observation_space
self.action_space = action_space
self._action_type: Literal["discrete", "continuous"]
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
self.action_type = "discrete"
self._action_type = "discrete"
elif isinstance(action_space, Box):
self.action_type = "continuous"
self._action_type = "continuous"
else:
raise ValueError(f"Unsupported action space: {action_space}.")
self.agent_id = 0
@ -226,6 +227,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
self.lr_scheduler = lr_scheduler
self._compile()
@property
def action_type(self) -> Literal["discrete", "continuous"]:
return self._action_type
def set_agent_id(self, agent_id: int) -> None:
"""Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id

View File

@ -16,6 +16,12 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
@dataclass(kw_only=True)
class ImitationTrainingStats(TrainingStats):
@ -72,9 +78,20 @@ class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTra
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ModelOutputBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits
result = Batch(logits=logits, act=act, state=hidden)
# TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
if self.action_type == "discrete":
# If it's discrete, the "actor" is usually a critic that maps obs to action_values
# which then could be turned into logits or a Categorical
action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
act_B = action_values_BA.argmax(dim=1)
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
elif self.action_type == "continuous":
# If it's continuous, the actor would usually deliver something like loc, scale determining a
# Gaussian dist
dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
else:
raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
return cast(ModelOutputBatchProtocol, result)
def learn(

View File

@ -34,8 +34,7 @@ TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteB
class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
:param model: a model following the rules (s_B -> action_values_BA)
:param imitator: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
:param optim: a torch.optim for optimizing the model.

View File

@ -25,8 +25,7 @@ TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteC
class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]):
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space.
:param min_q_weight: the weight for the cql loss.

View File

@ -11,6 +11,7 @@ from tianshou.data import to_torch, to_torch_as
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass
@ -26,8 +27,9 @@ TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteC
class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the action-value critic (i.e., Q function)
network. (s -> Q(s, \*))
:param optim: a torch.optim for optimizing the model.
@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | Actor,
critic: torch.nn.Module | Critic,
optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete,
discount_factor: float = 0.99,

View File

@ -15,8 +15,11 @@ from tianshou.data import (
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import PPOPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.policy.modelfree.ppo import PPOTrainingStats
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -32,7 +35,9 @@ TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats)
class GAILPolicy(PPOPolicy[TGailTrainingStats]):
r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
expert_buffer: ReplayBuffer,
disc_net: torch.nn.Module,

View File

@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]):
"""Implementation of TD3+BC. arXiv:2106.06860.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.

View File

@ -11,8 +11,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import PGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -30,7 +33,9 @@ TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats)
class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var]
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
vf_coef: float = 0.5,
ent_coef: float = 0.01,

View File

@ -31,7 +31,7 @@ TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats)
class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
"""Implementation of the Branching dual Q network arXiv:1711.08946.
:param model: BranchingNet mapping (obs, state, info) -> logits.
:param model: BranchingNet mapping (obs, state, info) -> action_values_BA.
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead.
@ -156,10 +156,10 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
model = getattr(self, model)
obs = batch.obs
# TODO: this is very contrived, see also iqn.py
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
act = to_numpy(logits.max(dim=-1)[1])
result = Batch(logits=logits, act=act, state=hidden)
obs_next_BO = obs.obs if hasattr(obs, "obs") else obs
action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info)
act_B = to_numpy(action_values_BA.argmax(dim=-1))
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:

View File

@ -23,8 +23,7 @@ TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats)
class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
"""Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param num_atoms: the number of atoms in the support set of the

View File

@ -19,6 +19,7 @@ from tianshou.data.types import (
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.continuous import Actor, Critic
@dataclass(kw_only=True)
@ -33,8 +34,7 @@ TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats)
class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
:param actor: The actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> model_output)
:param actor: The actor network following the rules (s -> actions)
:param actor_optim: The optimizer for actor network.
:param critic: The critic network. (s, a -> Q(s, a))
:param critic_optim: The optimizer for critic network.
@ -60,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer,
action_space: gym.Space,
tau: float = 0.005,

View File

@ -12,6 +12,7 @@ from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatch
from tianshou.policy import SACPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.sac import SACTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass
@ -25,8 +26,7 @@ TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteS
class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
"""Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules (s_B -> dist_input_BD)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.
@ -54,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete,
critic2: torch.nn.Module | None = None,
critic2: torch.nn.Module | Critic | None = None,
critic2_optim: torch.optim.Optimizer | None = None,
tau: float = 0.005,
gamma: float = 0.99,
@ -105,13 +105,13 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits_BA)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.sample()
return Batch(logits=logits, act=act, state=hidden, dist=dist)
act_B = dist.sample()
return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch(

View File

@ -17,6 +17,7 @@ from tianshou.data.types import (
)
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.common import Net
@dataclass(kw_only=True)
@ -35,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
implemented in the network side, not here).
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead.
@ -60,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
def __init__(
self,
*,
model: torch.nn.Module,
model: torch.nn.Module | Net,
optim: torch.optim.Optimizer,
# TODO: type violates Liskov substitution principle
action_space: gym.spaces.Discrete,
@ -201,12 +201,12 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
obs = batch.obs
# TODO: this is convoluted! See also other places where this is done.
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None))
if self.max_action_num is None:
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
result = Batch(logits=logits, act=act, state=hidden)
act_B = to_numpy(q.argmax(dim=1))
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:

View File

@ -27,8 +27,7 @@ TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats)
class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param fraction_model: a FractionProposalNetwork for
proposing fractions/quantiles given state.

View File

@ -29,8 +29,7 @@ TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats)
class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
"""Implementation of Implicit Quantile Network. arXiv:1806.06923.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param sample_size: the number of samples for policy evaluation.

View File

@ -12,7 +12,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
optim_critic_iters: int = 5,
actor_step_size: float = 0.5,

View File

@ -1,7 +1,7 @@
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
from typing import Any, Generic, Literal, TypeVar, cast
import gymnasium as gym
import numpy as np
@ -24,9 +24,22 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils import RunningMeanStd
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.discrete import Actor
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
TDistFnContinuous = Callable[
[tuple[torch.Tensor, torch.Tensor]],
torch.distributions.Distribution,
]
TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical]
TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete
@dataclass(kw_only=True)
@ -40,8 +53,9 @@ TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats)
class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
"""Implementation of REINFORCE algorithm.
:param actor: mapping (s->model_output), should follow the rules in
:class:`~tianshou.policy.BasePolicy`.
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param optim: optimizer for actor network.
:param dist_fn: distribution class for computing the action.
Maps model_output -> distribution. Typically a Gaussian distribution
@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb | Actor,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
discount_factor: float = 0.99,
# TODO: rename to return_normalization?
@ -175,20 +189,20 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
# TODO: rename? It's not really logits and there are particular
# assumptions about the order of the output and on distribution type
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
# TODO - ALGO: marked for algorithm refactoring
action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
# in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
# therefore action_dist_input_BD is equivalent to logits_BA
# If continuous, dist_fn will typically map loc, scale to a distribution (usually a Gaussian)
# the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
dist = self.dist_fn(action_dist_input_BD)
# in this case, the dist is unused!
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.sample()
result = Batch(logits=logits, act=act, state=hidden, dist=dist)
act_B = dist.sample()
# act is of dimension BA in continuous case and of dimension B in discrete
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
return cast(DistBatchProtocol, result)
# TODO: why does mypy complain?

View File

@ -10,8 +10,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -29,7 +32,9 @@ TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var]
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
eps_clip: float = 0.2,
dual_clip: float | None = None,

View File

@ -25,8 +25,7 @@ TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats)
class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
"""Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space.
:param discount_factor: in [0, 1].

View File

@ -12,6 +12,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.ddpg import DDPGTrainingStats
from tianshou.utils.net.continuous import ActorProb
@dataclass
@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
@ -150,23 +151,28 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale
dist = Independent(Normal(loc, scale), 1)
(loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc_B, scale_B), 1)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
act_B = dist.rsample()
log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act)
squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1,
keepdim=True,
)
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob)
return Batch(
logits=(loc_B, scale_B),
act=squashed_action,
state=h_BH,
dist=dist,
log_prob=log_prob,
)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch(

View File

@ -17,6 +17,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.conversion import to_optional_float
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.optim import clone_optimizer
@ -36,8 +37,7 @@ TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats)
class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var]
"""Implementation of Soft Actor-Critic. arXiv:1812.05905.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules (s -> dist_input_BD)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.
@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
@ -173,26 +173,25 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> DistLogProbBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
(loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
act_B = dist.rsample()
log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act)
squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1,
keepdim=True,
)
result = Batch(
logits=logits,
logits=(loc_B, scale_B),
act=squashed_action,
state=hidden,
state=hidden_BH,
dist=dist,
log_prob=log_prob,
)

View File

@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t
"""Implementation of TD3, arXiv:1802.09477.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.

View File

@ -11,7 +11,10 @@ from tianshou.data import Batch, SequenceSummaryStats
from tianshou.policy import NPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.npg import NPGTrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -25,7 +28,9 @@ TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats)
class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
max_kl: float = 0.01,
backtrack_coeff: float = 0.8,

View File

@ -610,6 +610,17 @@ class BaseActor(nn.Module, ABC):
def get_output_dim(self) -> int:
pass
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[Any, Any]:
# TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction.
# Return type needs to be more specific
pass
def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
"""Gets the given attribute from the given object or takes the alternative value if it is not present.

View File

@ -1,4 +1,5 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any
@ -9,6 +10,7 @@ from torch import nn
from tianshou.utils.net.common import (
MLP,
BaseActor,
Net,
TActionShape,
TLinearLayer,
get_output_dim,
@ -19,33 +21,27 @@ SIGMA_MAX = 2
class Actor(BaseActor):
"""Simple actor network.
"""Simple actor network that directly outputs actions for continuous action space.
Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`.
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param max_action: the scale for the final action.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
@ -77,42 +73,50 @@ class Actor(BaseActor):
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
"""Mapping: obs -> logits -> action."""
if info is None:
info = {}
logits, hidden = self.preprocess(obs, state)
logits = self.max_action * torch.tanh(self.last(logits))
return logits, hidden
"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
Returns a tensor representing the actions directly, i.e., of shape
`(n_actions, )`, and a hidden state (which may be None).
The hidden state is only not None if a recurrent net is used as part of the
learning algorithm (support for RNNs is currently experimental).
"""
action_BA, hidden_BH = self.preprocess(obs, state)
action_BA = self.max_action * torch.tanh(self.last(action_BA))
return action_BA, hidden_BH
class Critic(nn.Module):
class CriticBase(nn.Module, ABC):
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
class Critic(CriticBase):
"""Simple critic network.
It will create a critic operating in continuous action space with the structure preprocess_net ---> 1 (Q value).
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param linear_layer: use this module as linear layer. Default to nn.Linear.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
:param linear_layer: use this module as linear layer.
:param flatten_input: whether to flatten input data for the last layer.
Default to True.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (),
device: str | int | torch.device = "cpu",
preprocess_net_output_dim: int | None = None,
@ -139,9 +143,7 @@ class Critic(nn.Module):
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
if info is None:
info = {}
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
obs = torch.as_tensor(
obs,
device=self.device,
@ -154,41 +156,35 @@ class Critic(nn.Module):
dtype=torch.float32,
).flatten(1)
obs = torch.cat([obs, act], dim=1)
logits, hidden = self.preprocess(obs)
return self.last(logits)
values_B, hidden_BH = self.preprocess(obs)
return self.last(values_B)
class ActorProb(BaseActor):
"""Simple actor network (output with a Gauss distribution).
"""Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian).
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param unbounded: whether to apply tanh activation on final logits.
Default to False.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter. Default to False.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param max_action: the scale for the final action logits.
:param unbounded: whether to apply tanh activation on final logits.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""
# TODO: force kwargs, adjust downstream code
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
@ -402,8 +398,7 @@ class Perturbation(nn.Module):
flattened hidden state.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
Default to cpu.
:param phi: max perturbation parameter for BCQ. Default to 0.05.
:param phi: max perturbation parameter for BCQ.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
@ -449,7 +444,6 @@ class VAE(nn.Module):
:param latent_dim: the size of latent layer.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
Default to "cpu".
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

View File

@ -7,17 +7,14 @@ import torch.nn.functional as F
from torch import nn
from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim
class Actor(BaseActor):
"""Simple actor network.
"""Simple actor network for discrete action spaces.
Will create an actor operated in discrete action space with structure of
preprocess_net ---> action_shape.
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
@ -25,20 +22,15 @@ class Actor(BaseActor):
:param softmax_output: whether to apply a softmax layer over the last
layer's output.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
softmax_output: bool = True,
@ -71,43 +63,44 @@ class Actor(BaseActor):
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*)."""
if info is None:
info = {}
logits, hidden = self.preprocess(obs, state)
logits = self.last(logits)
) -> tuple[torch.Tensor, torch.Tensor | None]:
r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
Returns a tensor representing the values of each action, i.e., of shape
`(n_actions, )`, and
a hidden state (which may be None). If `self.softmax_output` is True, they are the
probabilities for taking each action. Otherwise, they will be action values.
The hidden state is only
not None if a recurrent net is used as part of the learning algorithm.
"""
x, hidden_BH = self.preprocess(obs, state)
x = self.last(x)
if self.softmax_output:
logits = F.softmax(logits, dim=-1)
return logits, hidden
x = F.softmax(x, dim=-1)
# If we computed softmax, output is probabilities, otherwise it's the non-normalized action values
output_BA = x
return output_BA, hidden_BH
class Critic(nn.Module):
"""Simple critic network.
"""Simple critic network for discrete action spaces.
It will create a critic operating in discrete action space with the structure preprocess_net ---> 1 (Q value).
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param last_size: the output dimension of Critic network. Default to 1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
:ref:`build_the_network`.
"""
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (),
last_size: int = 1,
preprocess_net_output_dim: int | None = None,
@ -120,8 +113,10 @@ class Critic(nn.Module):
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
# TODO: make a proper interface!
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Mapping: s -> V(s)."""
"""Mapping: s_B -> V(s)_B."""
# TODO: don't use this mechanism for passing state
logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits)