Naming and typing improvements in Actor/Critic/Policy forwards (#1032)

Closes #917

### Internal Improvements
- Better variable names related to model outputs (logits, dist input etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc., instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032

### Breaking Changes
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
Parent: 5bf923c9bd
Commit: bf0d632108
@ -13,6 +13,12 @@
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063
- Better variable names related to model outputs (logits, dist input etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc.,
  instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032

### Breaking Changes
@ -21,6 +27,8 @@
  explicitly or pass `reset_before_collect=True`. #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
  continuous and discrete cases. #1032 (illustrated in the sketch below)

### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081
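For reference, a minimal sketch of how a user-defined `dist_fn` changes under the breaking change above; it mirrors the updated test code further down in this diff, and the function names here are purely illustrative:

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal


# Old convention: the continuous dist_fn was called as dist_fn(*logits),
# receiving loc and scale as two positional arguments.
def dist_old(*logits: torch.Tensor) -> Distribution:
    return Independent(Normal(*logits), 1)


# New convention (#1032): dist_fn always takes a single argument.
# Continuous case: a (loc, scale) tuple that is unpacked inside the function.
def dist_continuous(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)


# Discrete case: a single tensor of logits.
def dist_discrete(logits: torch.Tensor) -> Distribution:
    return Categorical(logits=logits)
```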
@ -69,7 +69,7 @@
|
||||
"from tianshou.policy import BasePolicy\n",
|
||||
"from tianshou.policy.modelfree.pg import (\n",
|
||||
" PGTrainingStats,\n",
|
||||
" TDistributionFunction,\n",
|
||||
" TDistFnDiscrOrCont,\n",
|
||||
" TPGTrainingStats,\n",
|
||||
")\n",
|
||||
"from tianshou.utils import RunningMeanStd\n",
|
||||
@ -339,7 +339,7 @@
|
||||
" *,\n",
|
||||
" actor: torch.nn.Module,\n",
|
||||
" optim: torch.optim.Optimizer,\n",
|
||||
" dist_fn: TDistributionFunction,\n",
|
||||
" dist_fn: TDistFnDiscrOrCont,\n",
|
||||
" action_space: gym.Space,\n",
|
||||
" discount_factor: float = 0.99,\n",
|
||||
" observation_space: gym.Space | None = None,\n",
|
||||
|
||||
@ -257,3 +257,8 @@ macOS
joblib
master
Panchenko
BA
BH
BO
BD
@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
# expert replay buffer
|
||||
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
|
||||
|
||||
@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: A2CPolicy = A2CPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: NPGPolicy = NPGPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: PPOPolicy = PPOPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: PGPolicy = PGPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: TRPOPolicy = TRPOPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -2,7 +2,7 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch.distributions import Categorical, Independent, Normal
|
||||
from torch.distributions import Categorical, Distribution, Independent, Normal
|
||||
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
@ -25,7 +25,11 @@ def policy(request):
|
||||
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
|
||||
action_shape=action_space.shape,
|
||||
)
|
||||
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
|
||||
|
||||
def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
elif action_type == "discrete":
|
||||
action_space = gym.spaces.Discrete(3)
|
||||
actor = Actor(
|
||||
|
||||
@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGaussian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: NPGPolicy[NPGTrainingStats] = NPGPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGaussian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: PPOPolicy[PPOTrainingStats] = PPOPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGaussian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: BasePolicy = TRPOPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGaussian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: BasePolicy = GAILPolicy(
|
||||
actor=actor,
|
||||
|
||||
@ -181,8 +181,9 @@ def get_agents(
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
agent: PPOPolicy = PPOPolicy(
|
||||
actor,
|
||||
|
||||
@ -1,40 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class DistributionFunctionFactory(ToStringMixin, ABC):
|
||||
# True return type defined in subclasses
|
||||
@abstractmethod
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(
|
||||
self,
|
||||
envs: Environments,
|
||||
) -> Callable[[Any], torch.distributions.Distribution]:
|
||||
pass
|
||||
|
||||
|
||||
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
|
||||
envs.get_type().assert_discrete(self)
|
||||
return self._dist_fn
|
||||
|
||||
@staticmethod
|
||||
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
|
||||
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
|
||||
return torch.distributions.Categorical(logits=p)
|
||||
|
||||
|
||||
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
|
||||
envs.get_type().assert_continuous(self)
|
||||
return self._dist_fn
|
||||
|
||||
@staticmethod
|
||||
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
|
||||
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
||||
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
|
||||
loc, scale = loc_scale
|
||||
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
|
||||
|
||||
|
||||
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
|
||||
match envs.get_type():
|
||||
case EnvType.DISCRETE:
|
||||
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
|
||||
|
||||
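A rough usage sketch of these factories under the new return types; `discrete_envs` and `continuous_envs` are placeholders for already-constructed `Environments` instances of the matching action-space type (so the snippet is not runnable as-is):

```python
import torch

from tianshou.highlevel.params.dist_fn import (
    DistributionFunctionFactoryCategorical,
    DistributionFunctionFactoryIndependentGaussians,
)

# Discrete: create_dist_fn returns a TDistFnDiscrete mapping logits -> Categorical.
dist_fn_discrete = DistributionFunctionFactoryCategorical().create_dist_fn(discrete_envs)
cat = dist_fn_discrete(torch.zeros(4, 3))  # Categorical over 3 actions, batch of 4

# Continuous: create_dist_fn returns a function that takes a single (loc, scale)
# tuple, matching the unified dist_fn interface.
dist_fn_cont = DistributionFunctionFactoryIndependentGaussians().create_dist_fn(continuous_envs)
gauss = dist_fn_cont((torch.zeros(4, 2), torch.ones(4, 2)))
```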
@ -19,7 +19,7 @@ from tianshou.highlevel.params.dist_fn import (
|
||||
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||
from tianshou.highlevel.params.noise import NoiseFactory
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
|
||||
whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
|
||||
Does not affect training.
|
||||
"""
|
||||
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
|
||||
dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
|
||||
"""
|
||||
This can either be a function which maps the model output to a torch distribution or a
|
||||
factory for the creation of such a function.
|
||||
|
||||
@ -213,10 +213,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
super().__init__()
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self._action_type: Literal["discrete", "continuous"]
|
||||
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
|
||||
self.action_type = "discrete"
|
||||
self._action_type = "discrete"
|
||||
elif isinstance(action_space, Box):
|
||||
self.action_type = "continuous"
|
||||
self._action_type = "continuous"
|
||||
else:
|
||||
raise ValueError(f"Unsupported action space: {action_space}.")
|
||||
self.agent_id = 0
|
||||
@ -226,6 +227,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self._compile()
|
||||
|
||||
@property
|
||||
def action_type(self) -> Literal["discrete", "continuous"]:
|
||||
return self._action_type
|
||||
|
||||
def set_agent_id(self, agent_id: int) -> None:
|
||||
"""Set self.agent_id = agent_id, for MARL."""
|
||||
self.agent_id = agent_id
|
||||
|
||||
@ -16,6 +16,12 @@ from tianshou.data.types import (
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
|
||||
# Dimension Naming Convention
|
||||
# B - Batch Size
|
||||
# A - Action
|
||||
# D - Dist input (usually 2, loc and scale)
|
||||
# H - Dimension of hidden, can be None
|
||||
|
||||
|
||||
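To make the naming convention concrete, a toy illustration (values and sizes are arbitrary):

```python
import torch

B, A, H = 4, 2, 64  # batch size, action dim, hidden dim

# Discrete case: the "actor" output is a table of action values, shape (B, A).
action_values_BA = torch.randn(B, A)
act_B = action_values_BA.argmax(dim=1)  # one action index per batch element

# Continuous case: the actor outputs the inputs of a distribution ("D"),
# typically the pair (loc, scale) for a Gaussian.
dist_input_BD = (torch.zeros(B, A), torch.ones(B, A))

# Hidden state is None for feed-forward actors, or of shape (B, H) for recurrent ones.
hidden_BH: torch.Tensor | None = None
```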
@dataclass(kw_only=True)
|
||||
class ImitationTrainingStats(TrainingStats):
|
||||
@ -72,9 +78,20 @@ class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTra
|
||||
state: dict | BatchProtocol | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ModelOutputBatchProtocol:
|
||||
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits
|
||||
result = Batch(logits=logits, act=act, state=hidden)
|
||||
# TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
|
||||
if self.action_type == "discrete":
|
||||
# If it's discrete, the "actor" is usually a critic that maps obs to action_values
|
||||
# which then could be turned into logits or a Categorical
|
||||
action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
act_B = action_values_BA.argmax(dim=1)
|
||||
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
|
||||
elif self.action_type == "continuous":
|
||||
# If it's continuous, the actor would usually deliver something like loc, scale determining a
|
||||
# Gaussian dist
|
||||
dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
|
||||
return cast(ModelOutputBatchProtocol, result)
|
||||
|
||||
def learn(
|
||||
|
||||
@ -34,8 +34,7 @@ TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteB
|
||||
class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
|
||||
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param imitator: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
|
||||
@ -25,8 +25,7 @@ TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteC
|
||||
class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]):
|
||||
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param action_space: Env's action space.
|
||||
:param min_q_weight: the weight for the cql loss.
|
||||
|
||||
@ -11,6 +11,7 @@ from tianshou.data import to_torch, to_torch_as
|
||||
from tianshou.data.types import RolloutBatchProtocol
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -26,8 +27,9 @@ TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteC
|
||||
class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
|
||||
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param critic: the action-value critic (i.e., Q function)
|
||||
network. (s -> Q(s, \*))
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | Actor,
|
||||
critic: torch.nn.Module | Critic,
|
||||
optim: torch.optim.Optimizer,
|
||||
action_space: gym.spaces.Discrete,
|
||||
discount_factor: float = 0.99,
|
||||
|
||||
@ -15,8 +15,11 @@ from tianshou.data import (
|
||||
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.policy.modelfree.ppo import PPOTrainingStats
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -32,7 +35,9 @@ TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats)
|
||||
class GAILPolicy(PPOPolicy[TGailTrainingStats]):
|
||||
r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
expert_buffer: ReplayBuffer,
|
||||
disc_net: torch.nn.Module,
|
||||
|
||||
@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]):
|
||||
"""Implementation of TD3+BC. arXiv:2106.06860.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
|
||||
:param actor_optim: the optimizer for actor network.
|
||||
:param critic: the first critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: the optimizer for the first critic network.
|
||||
|
||||
@ -11,8 +11,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
|
||||
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -30,7 +33,9 @@ TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats)
|
||||
class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var]
|
||||
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
vf_coef: float = 0.5,
|
||||
ent_coef: float = 0.01,
|
||||
|
||||
@ -31,7 +31,7 @@ TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats)
|
||||
class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
|
||||
"""Implementation of the Branching dual Q network arXiv:1711.08946.
|
||||
|
||||
:param model: BranchingNet mapping (obs, state, info) -> logits.
|
||||
:param model: BranchingNet mapping (obs, state, info) -> action_values_BA.
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param estimation_step: the number of steps to look ahead.
|
||||
@ -156,10 +156,10 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
|
||||
model = getattr(self, model)
|
||||
obs = batch.obs
|
||||
# TODO: this is very contrived, see also iqn.py
|
||||
obs_next = obs.obs if hasattr(obs, "obs") else obs
|
||||
logits, hidden = model(obs_next, state=state, info=batch.info)
|
||||
act = to_numpy(logits.max(dim=-1)[1])
|
||||
result = Batch(logits=logits, act=act, state=hidden)
|
||||
obs_next_BO = obs.obs if hasattr(obs, "obs") else obs
|
||||
action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info)
|
||||
act_B = to_numpy(action_values_BA.argmax(dim=-1))
|
||||
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
|
||||
return cast(ModelOutputBatchProtocol, result)
|
||||
|
||||
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
|
||||
|
||||
@ -23,8 +23,7 @@ TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats)
|
||||
class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
|
||||
"""Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param num_atoms: the number of atoms in the support set of the
|
||||
|
||||
@ -19,6 +19,7 @@ from tianshou.data.types import (
|
||||
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -33,8 +34,7 @@ TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats)
|
||||
class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
|
||||
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
|
||||
|
||||
:param actor: The actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> model_output)
|
||||
:param actor: The actor network following the rules (s -> actions)
|
||||
:param actor_optim: The optimizer for actor network.
|
||||
:param critic: The critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: The optimizer for critic network.
|
||||
@ -60,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | Actor,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic: torch.nn.Module | Critic,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
action_space: gym.Space,
|
||||
tau: float = 0.005,
|
||||
|
||||
@ -12,6 +12,7 @@ from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatch
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.sac import SACTrainingStats
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -25,8 +26,7 @@ TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteS
|
||||
class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
"""Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param actor: the actor network following the rules (s_B -> dist_input_BD)
|
||||
:param actor_optim: the optimizer for actor network.
|
||||
:param critic: the first critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: the optimizer for the first critic network.
|
||||
@ -54,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | Actor,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic: torch.nn.Module | Critic,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
action_space: gym.spaces.Discrete,
|
||||
critic2: torch.nn.Module | None = None,
|
||||
critic2: torch.nn.Module | Critic | None = None,
|
||||
critic2_optim: torch.optim.Optimizer | None = None,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
@ -105,13 +105,13 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
state: dict | Batch | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = Categorical(logits=logits)
|
||||
logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = Categorical(logits=logits_BA)
|
||||
if self.deterministic_eval and not self.training:
|
||||
act = dist.mode
|
||||
act_B = dist.mode
|
||||
else:
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=hidden, dist=dist)
|
||||
act_B = dist.sample()
|
||||
return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
obs_next_batch = Batch(
|
||||
|
||||
@ -17,6 +17,7 @@ from tianshou.data.types import (
|
||||
)
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -35,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
|
||||
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
|
||||
implemented in the network side, not here).
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param estimation_step: the number of steps to look ahead.
|
||||
@ -60,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: torch.nn.Module,
|
||||
model: torch.nn.Module | Net,
|
||||
optim: torch.optim.Optimizer,
|
||||
# TODO: type violates Liskov substitution principle
|
||||
action_space: gym.spaces.Discrete,
|
||||
@ -201,12 +201,12 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
|
||||
obs = batch.obs
|
||||
# TODO: this is convoluted! See also other places where this is done.
|
||||
obs_next = obs.obs if hasattr(obs, "obs") else obs
|
||||
logits, hidden = model(obs_next, state=state, info=batch.info)
|
||||
q = self.compute_q_value(logits, getattr(obs, "mask", None))
|
||||
action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info)
|
||||
q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None))
|
||||
if self.max_action_num is None:
|
||||
self.max_action_num = q.shape[1]
|
||||
act = to_numpy(q.max(dim=1)[1])
|
||||
result = Batch(logits=logits, act=act, state=hidden)
|
||||
act_B = to_numpy(q.argmax(dim=1))
|
||||
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
|
||||
return cast(ModelOutputBatchProtocol, result)
|
||||
|
||||
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
|
||||
|
||||
@ -27,8 +27,7 @@ TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats)
|
||||
class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
|
||||
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param fraction_model: a FractionProposalNetwork for
|
||||
proposing fractions/quantiles given state.
|
||||
|
||||
@ -29,8 +29,7 @@ TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats)
|
||||
class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
|
||||
"""Implementation of Implicit Quantile Network. arXiv:1806.06923.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param sample_size: the number of samples for policy evaluation.
|
||||
|
||||
@ -12,7 +12,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats
|
||||
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
|
||||
|
||||
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
optim_critic_iters: int = 5,
|
||||
actor_step_size: float = 0.5,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
from typing import Any, Generic, Literal, TypeVar, cast
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
@ -24,9 +24,22 @@ from tianshou.data.types import (
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.utils import RunningMeanStd
|
||||
from tianshou.utils.net.continuous import ActorProb
|
||||
from tianshou.utils.net.discrete import Actor
|
||||
|
||||
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
|
||||
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
|
||||
# Dimension Naming Convention
|
||||
# B - Batch Size
|
||||
# A - Action
|
||||
# D - Dist input (usually 2, loc and scale)
|
||||
# H - Dimension of hidden, can be None
|
||||
|
||||
TDistFnContinuous = Callable[
|
||||
[tuple[torch.Tensor, torch.Tensor]],
|
||||
torch.distributions.Distribution,
|
||||
]
|
||||
TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical]
|
||||
|
||||
TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete
|
||||
|
||||
|
||||
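For illustration, two hypothetical functions that satisfy these aliases (not part of the module itself):

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.policy.modelfree.pg import TDistFnContinuous, TDistFnDiscrete


def gaussian_dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)


def categorical_dist_fn(logits_BA: torch.Tensor) -> Categorical:
    return Categorical(logits=logits_BA)


# Both assignments type-check against the aliases defined above, so either
# function can be passed as dist_fn to PGPolicy and its subclasses.
continuous_fn: TDistFnContinuous = gaussian_dist_fn
discrete_fn: TDistFnDiscrete = categorical_dist_fn
```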
@dataclass(kw_only=True)
|
||||
@ -40,8 +53,9 @@ TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats)
|
||||
class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
|
||||
"""Implementation of REINFORCE algorithm.
|
||||
|
||||
:param actor: mapping (s->model_output), should follow the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`.
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param optim: optimizer for actor network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
Maps model_output -> distribution. Typically a Gaussian distribution
|
||||
@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | Actor,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
discount_factor: float = 0.99,
|
||||
# TODO: rename to return_normalization?
|
||||
@ -175,20 +189,20 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
|
||||
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
|
||||
more detailed explanation.
|
||||
"""
|
||||
# TODO: rename? It's not really logits and there are particular
|
||||
# assumptions about the order of the output and on distribution type
|
||||
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
# TODO - ALGO: marked for algorithm refactoring
|
||||
action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
# in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
|
||||
# therefore action_dist_input_BD is equivalent to logits_BA
|
||||
# If continuous, dist_fn will typically map loc, scale to a distribution (usually a Gaussian)
|
||||
# the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
|
||||
dist = self.dist_fn(action_dist_input_BD)
|
||||
|
||||
# in this case, the dist is unused!
|
||||
if self.deterministic_eval and not self.training:
|
||||
act = dist.mode
|
||||
act_B = dist.mode
|
||||
else:
|
||||
act = dist.sample()
|
||||
result = Batch(logits=logits, act=act, state=hidden, dist=dist)
|
||||
act_B = dist.sample()
|
||||
# act is of dimension BA in continuous case and of dimension B in discrete
|
||||
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
|
||||
return cast(DistBatchProtocol, result)
|
||||
|
||||
# TODO: why does mypy complain?
|
||||
|
||||
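The `deterministic_eval` branch above relies on `Distribution.mode`, which works uniformly across the distribution families used here, so the policy no longer needs to special-case the distribution type. A small standalone illustration:

```python
import torch
from torch.distributions import Categorical, Independent, Normal

# Discrete: .mode is the argmax of the logits/probs.
cat = Categorical(logits=torch.tensor([[0.1, 2.0, -1.0]]))
print(cat.mode)    # tensor([1])

# Continuous: .mode of Independent(Normal(loc, scale), 1) is simply loc.
gauss = Independent(Normal(torch.tensor([[0.3, -0.7]]), torch.ones(1, 2)), 1)
print(gauss.mode)  # tensor([[ 0.3000, -0.7000]])
```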
@ -10,8 +10,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
|
||||
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -29,7 +32,9 @@ TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
|
||||
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var]
|
||||
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
eps_clip: float = 0.2,
|
||||
dual_clip: float | None = None,
|
||||
|
||||
@ -25,8 +25,7 @@ TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats)
|
||||
class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
|
||||
"""Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param action_space: Env's action space.
|
||||
:param discount_factor: in [0, 1].
|
||||
|
||||
@ -12,6 +12,7 @@ from tianshou.exploration import BaseNoise
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.ddpg import DDPGTrainingStats
|
||||
from tianshou.utils.net.continuous import ActorProb
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
@ -150,23 +151,28 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
|
||||
state: dict | Batch | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
loc, scale = loc_scale
|
||||
dist = Independent(Normal(loc, scale), 1)
|
||||
(loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = Independent(Normal(loc_B, scale_B), 1)
|
||||
if self.deterministic_eval and not self.training:
|
||||
act = dist.mode
|
||||
act_B = dist.mode
|
||||
else:
|
||||
act = dist.rsample()
|
||||
log_prob = dist.log_prob(act).unsqueeze(-1)
|
||||
act_B = dist.rsample()
|
||||
log_prob = dist.log_prob(act_B).unsqueeze(-1)
|
||||
# apply correction for Tanh squashing when computing logprob from Gaussian
|
||||
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
|
||||
# in appendix C to get some understanding of this equation.
|
||||
squashed_action = torch.tanh(act)
|
||||
squashed_action = torch.tanh(act_B)
|
||||
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
|
||||
-1,
|
||||
keepdim=True,
|
||||
)
|
||||
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob)
|
||||
return Batch(
|
||||
logits=(loc_B, scale_B),
|
||||
act=squashed_action,
|
||||
state=h_BH,
|
||||
dist=dist,
|
||||
log_prob=log_prob,
|
||||
)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
obs_next_batch = Batch(
|
||||
|
||||
@ -17,6 +17,7 @@ from tianshou.exploration import BaseNoise
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.utils.conversion import to_optional_float
|
||||
from tianshou.utils.net.continuous import ActorProb
|
||||
from tianshou.utils.optim import clone_optimizer
|
||||
|
||||
|
||||
@ -36,8 +37,7 @@ TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats)
|
||||
class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var]
|
||||
"""Implementation of Soft Actor-Critic. arXiv:1812.05905.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param actor: the actor network following the rules (s -> dist_input_BD)
|
||||
:param actor_optim: the optimizer for actor network.
|
||||
:param critic: the first critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: the optimizer for the first critic network.
|
||||
@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
@ -173,26 +173,25 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
|
||||
state: dict | Batch | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> DistLogProbBatchProtocol:
|
||||
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
assert isinstance(logits, tuple)
|
||||
dist = Independent(Normal(*logits), 1)
|
||||
(loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
|
||||
if self.deterministic_eval and not self.training:
|
||||
act = dist.mode
|
||||
act_B = dist.mode
|
||||
else:
|
||||
act = dist.rsample()
|
||||
log_prob = dist.log_prob(act).unsqueeze(-1)
|
||||
act_B = dist.rsample()
|
||||
log_prob = dist.log_prob(act_B).unsqueeze(-1)
|
||||
# apply correction for Tanh squashing when computing logprob from Gaussian
|
||||
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
|
||||
# in appendix C to get some understanding of this equation.
|
||||
squashed_action = torch.tanh(act)
|
||||
squashed_action = torch.tanh(act_B)
|
||||
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
|
||||
-1,
|
||||
keepdim=True,
|
||||
)
|
||||
result = Batch(
|
||||
logits=logits,
|
||||
logits=(loc_B, scale_B),
|
||||
act=squashed_action,
|
||||
state=hidden,
|
||||
state=hidden_BH,
|
||||
dist=dist,
|
||||
log_prob=log_prob,
|
||||
)
|
||||
|
||||
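As a standalone reference for the Tanh-squashing correction that the (otherwise unchanged) SAC and REDQ forward logic applies, a minimal sketch reproducing the computation outside the policy; the epsilon mirrors the `__eps` attribute used above:

```python
import numpy as np
import torch
from torch.distributions import Independent, Normal

eps = float(np.finfo(np.float32).eps)

# Pre-squash Gaussian policy output for a batch of 4, action dim 2.
dist = Independent(Normal(torch.zeros(4, 2), torch.ones(4, 2)), 1)
act_B = dist.rsample()
log_prob = dist.log_prob(act_B).unsqueeze(-1)

# SAC paper (arXiv:1801.01290), Eq. 21:
#   log pi(a|s) = log mu(u|s) - sum_i log(1 - tanh(u_i)^2)
squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + eps).sum(-1, keepdim=True)
```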
@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t
|
||||
"""Implementation of TD3, arXiv:1802.09477.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
|
||||
:param actor_optim: the optimizer for actor network.
|
||||
:param critic: the first critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: the optimizer for the first critic network.
|
||||
|
||||
@ -11,7 +11,10 @@ from tianshou.data import Batch, SequenceSummaryStats
|
||||
from tianshou.policy import NPGPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.npg import NPGTrainingStats
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -25,7 +28,9 @@ TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats)
|
||||
class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
|
||||
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
max_kl: float = 0.01,
|
||||
backtrack_coeff: float = 0.8,
|
||||
|
||||
@ -610,6 +610,17 @@ class BaseActor(nn.Module, ABC):
|
||||
def get_output_dim(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
state: Any = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
# TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction.
|
||||
# Return type needs to be more specific
|
||||
pass
|
||||
|
||||
|
||||
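A minimal sketch of an actor honoring the `(obs, state, info) -> (output, hidden)` contract that the new abstract `forward` enforces; it deliberately subclasses plain `nn.Module` rather than `BaseActor` (which has further abstract methods) so that the example stays self-contained:

```python
from typing import Any

import numpy as np
import torch
from torch import nn


class TinyDeterministicActor(nn.Module):
    """Maps s_B -> action_BA plus a (possibly None) hidden state."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.net = nn.Linear(obs_dim, act_dim)

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        obs = torch.as_tensor(obs, dtype=torch.float32)
        action_BA = torch.tanh(self.net(obs))
        return action_BA, state  # hidden state stays None for feed-forward actors
```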
def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
|
||||
"""Gets the given attribute from the given object or takes the alternative value if it is not present.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
@ -9,6 +10,7 @@ from torch import nn
|
||||
from tianshou.utils.net.common import (
|
||||
MLP,
|
||||
BaseActor,
|
||||
Net,
|
||||
TActionShape,
|
||||
TLinearLayer,
|
||||
get_output_dim,
|
||||
@ -19,33 +21,27 @@ SIGMA_MAX = 2
|
||||
|
||||
|
||||
class Actor(BaseActor):
|
||||
"""Simple actor network.
|
||||
"""Simple actor network that directly outputs actions for continuous action space.
|
||||
Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`.
|
||||
|
||||
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
|
||||
|
||||
:param preprocess_net: a self-defined preprocess_net which output a
|
||||
flattened hidden state.
|
||||
:param preprocess_net: a self-defined preprocess_net, see usage.
|
||||
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
|
||||
:param action_shape: a sequence of int for the shape of action.
|
||||
:param hidden_sizes: a sequence of int for constructing the MLP after
|
||||
preprocess_net. Default to empty sequence (where the MLP now contains
|
||||
only a single linear layer).
|
||||
:param max_action: the scale for the final action logits. Default to
|
||||
1.
|
||||
:param preprocess_net_output_dim: the output dimension of
|
||||
preprocess_net.
|
||||
:param max_action: the scale for the final action.
|
||||
:param preprocess_net_output_dim: the output dimension of
|
||||
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
|
||||
|
||||
For advanced usage (how to customize the network), please refer to
|
||||
:ref:`build_the_network`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
|
||||
of how preprocess_net is suggested to be defined.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net: nn.Module,
|
||||
preprocess_net: nn.Module | Net,
|
||||
action_shape: TActionShape,
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
max_action: float = 1.0,
|
||||
@ -77,42 +73,50 @@ class Actor(BaseActor):
|
||||
state: Any = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> tuple[torch.Tensor, Any]:
|
||||
"""Mapping: obs -> logits -> action."""
|
||||
if info is None:
|
||||
info = {}
|
||||
logits, hidden = self.preprocess(obs, state)
|
||||
logits = self.max_action * torch.tanh(self.last(logits))
|
||||
return logits, hidden
|
||||
"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
|
||||
|
||||
Returns a tensor representing the actions directly, i.e., of shape
|
||||
`(n_actions, )`, and a hidden state (which may be None).
|
||||
The hidden state is only not None if a recurrent net is used as part of the
|
||||
learning algorithm (support for RNNs is currently experimental).
|
||||
"""
|
||||
action_BA, hidden_BH = self.preprocess(obs, state)
|
||||
action_BA = self.max_action * torch.tanh(self.last(action_BA))
|
||||
return action_BA, hidden_BH
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
class CriticBase(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
act: np.ndarray | torch.Tensor | None = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
|
||||
|
||||
|
||||
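And a matching sketch of a Q-critic implementing the `(s_B, a_B) -> Q(s, a)_B` mapping declared by `CriticBase`, again written against plain `nn.Module` to keep it self-contained:

```python
from typing import Any

import numpy as np
import torch
from torch import nn


class TinyQCritic(nn.Module):
    """Maps (s_B, a_B) -> Q(s, a)_B for continuous actions."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim + act_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        act: np.ndarray | torch.Tensor | None = None,
        info: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        obs = torch.as_tensor(obs, dtype=torch.float32).flatten(1)
        act = torch.as_tensor(act, dtype=torch.float32).flatten(1)  # act is assumed to be given
        return self.net(torch.cat([obs, act], dim=1))  # shape (B, 1)
```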
class Critic(CriticBase):
|
||||
"""Simple critic network.
|
||||
|
||||
It will create a critic operated in continuous action space with structure of preprocess_net ---> 1(q value).
|
||||
|
||||
:param preprocess_net: a self-defined preprocess_net which output a
|
||||
flattened hidden state.
|
||||
:param preprocess_net: a self-defined preprocess_net, see usage.
|
||||
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
|
||||
:param hidden_sizes: a sequence of int for constructing the MLP after
|
||||
preprocess_net. Default to empty sequence (where the MLP now contains
|
||||
only a single linear layer).
|
||||
:param preprocess_net_output_dim: the output dimension of
|
||||
preprocess_net.
|
||||
:param linear_layer: use this module as linear layer. Default to nn.Linear.
|
||||
:param preprocess_net_output_dim: the output dimension of
|
||||
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
|
||||
:param linear_layer: use this module as linear layer.
|
||||
:param flatten_input: whether to flatten input data for the last layer.
|
||||
Default to True.
|
||||
|
||||
For advanced usage (how to customize the network), please refer to
|
||||
:ref:`build_the_network`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
|
||||
of how preprocess_net is suggested to be defined.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net: nn.Module,
|
||||
preprocess_net: nn.Module | Net,
|
||||
hidden_sizes: Sequence[int] = (),
|
||||
device: str | int | torch.device = "cpu",
|
||||
preprocess_net_output_dim: int | None = None,
|
||||
@ -139,9 +143,7 @@ class Critic(nn.Module):
|
||||
act: np.ndarray | torch.Tensor | None = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Mapping: (s, a) -> logits -> Q(s, a)."""
|
||||
if info is None:
|
||||
info = {}
|
||||
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
|
||||
obs = torch.as_tensor(
|
||||
obs,
|
||||
device=self.device,
|
||||
@ -154,41 +156,35 @@ class Critic(nn.Module):
|
||||
dtype=torch.float32,
|
||||
).flatten(1)
|
||||
obs = torch.cat([obs, act], dim=1)
|
||||
logits, hidden = self.preprocess(obs)
|
||||
return self.last(logits)
|
||||
values_B, hidden_BH = self.preprocess(obs)
|
||||
return self.last(values_B)
|
||||
|
||||
|
class ActorProb(BaseActor):
"""Simple actor network (output with a Gauss distribution).
"""Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian).

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`.

:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param unbounded: whether to apply tanh activation on final logits.
Default to False.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter. Default to False.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param max_action: the scale for the final action logits.
:param unbounded: whether to apply tanh activation on final logits.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""

# TODO: force kwargs, adjust downstream code
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
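Connecting this to the `dist_fn` breaking change: `ActorProb` returns a `(mu, sigma)` tuple, and the unified `dist_fn` now receives that tuple as a single argument (in the discrete case it receives the logits tensor instead). An illustrative sketch, with hypothetical shapes and without the surrounding policy wiring:

    import torch
    from torch.distributions import Distribution, Independent, Normal

    from tianshou.utils.net.common import Net
    from tianshou.utils.net.continuous import ActorProb

    obs_dim, act_dim, batch_size = 4, 2, 8  # hypothetical shapes

    preprocess_net = Net(state_shape=obs_dim, hidden_sizes=[64, 64])
    actor = ActorProb(preprocess_net, action_shape=act_dim, unbounded=True)


    def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
        # single argument in both the continuous and the discrete case
        loc, scale = loc_scale
        return Independent(Normal(loc, scale), 1)


    (mu_BA, sigma_BA), hidden_BH = actor(torch.randn(batch_size, obs_dim))
    dist = dist_fn((mu_BA, sigma_BA))
    action_BA = dist.sample()  # or dist.mode for a deterministic action, per the changelog
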
@ -402,8 +398,7 @@ class Perturbation(nn.Module):
flattened hidden state.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
Default to cpu.
:param phi: max perturbation parameter for BCQ. Default to 0.05.
:param phi: max perturbation parameter for BCQ.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
@ -449,7 +444,6 @@ class VAE(nn.Module):
:param latent_dim: the size of latent layer.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
Default to "cpu".

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

@ -7,17 +7,14 @@ import torch.nn.functional as F
from torch import nn

from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim


class Actor(BaseActor):
"""Simple actor network.
"""Simple actor network for discrete action spaces.

Will create an actor operated in discrete action space with structure of
preprocess_net ---> action_shape.

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
@ -25,20 +22,15 @@ class Actor(BaseActor):
:param softmax_output: whether to apply a softmax layer over the last
layer's output.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""

def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
softmax_output: bool = True,
@ -71,43 +63,44 @@ class Actor(BaseActor):
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*)."""
if info is None:
info = {}
logits, hidden = self.preprocess(obs, state)
logits = self.last(logits)
) -> tuple[torch.Tensor, torch.Tensor | None]:
r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.

Returns a tensor representing the values of each action, i.e, of shape
`(n_actions, )`, and
a hidden state (which may be None). If `self.softmax_output` is True, they are the
probabilities for taking each action. Otherwise, they will be action values.
The hidden state is only
not None if a recurrent net is used as part of the learning algorithm.
"""
x, hidden_BH = self.preprocess(obs, state)
x = self.last(x)
if self.softmax_output:
logits = F.softmax(logits, dim=-1)
return logits, hidden
x = F.softmax(x, dim=-1)
# If we computed softmax, output is probabilities, otherwise it's the non-normalized action values
output_BA = x
return output_BA, hidden_BH

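An illustrative sketch of the two output modes of the discrete actor (probabilities when `softmax_output=True`, raw action values otherwise); shapes are made up and a `Categorical` is built directly from the output here only for demonstration:

    import torch
    from torch.distributions import Categorical

    from tianshou.utils.net.common import Net
    from tianshou.utils.net.discrete import Actor

    obs_dim, n_actions, batch_size = 4, 3, 8  # hypothetical shapes

    preprocess_net = Net(state_shape=obs_dim, hidden_sizes=[64, 64])
    actor = Actor(preprocess_net, action_shape=n_actions, softmax_output=True)

    output_BA, hidden_BH = actor(torch.randn(batch_size, obs_dim))
    dist = Categorical(probs=output_BA)  # probs, because softmax_output=True
    # With softmax_output=False the output would be unnormalized action values,
    # which would instead be passed as Categorical(logits=...).
    action_B = dist.sample()
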
class Critic(nn.Module):
"""Simple critic network.
"""Simple critic network for discrete action spaces.

It will create an actor operated in discrete action space with structure of preprocess_net ---> 1(q value).

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param last_size: the output dimension of Critic network. Default to 1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
:ref:`build_the_network`..
"""

def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (),
last_size: int = 1,
preprocess_net_output_dim: int | None = None,
@ -120,8 +113,10 @@ class Critic(nn.Module):
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)

# TODO: make a proper interface!
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Mapping: s -> V(s)."""
"""Mapping: s_B -> V(s)_B."""
# TODO: don't use this mechanism for passing state
logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits)

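For completeness, an illustrative call of the discrete critic's s_B -> V(s)_B mapping (hypothetical shapes; `last_size` greater than 1 is the pattern used e.g. for per-action critics, and the preprocess net is shared here only for brevity):

    import torch

    from tianshou.utils.net.common import Net
    from tianshou.utils.net.discrete import Critic

    obs_dim, batch_size = 4, 8  # hypothetical shapes

    preprocess_net = Net(state_shape=obs_dim, hidden_sizes=[64, 64])
    critic = Critic(preprocess_net)                  # last_size=1 -> state value V(s)_B
    v_B = critic(torch.randn(batch_size, obs_dim))   # shape (B, 1)

    qf = Critic(preprocess_net, last_size=5)         # one value per action (e.g. discrete SAC)
    q_BA = qf(torch.randn(batch_size, obs_dim))      # shape (B, 5)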