From f4e05d585aadf82ce6705d19ebca71d52a805725 Mon Sep 17 00:00:00 2001 From: Yuge Zhang Date: Tue, 27 Apr 2021 21:22:39 +0800 Subject: [PATCH] Support deterministic evaluation for onpolicy algorithms (#354) --- test/continuous/test_npg.py | 3 ++- test/continuous/test_sac_with_il.py | 2 +- test/discrete/test_a2c_with_il.py | 2 +- test/discrete/test_ppo.py | 3 ++- tianshou/policy/base.py | 6 ++++++ tianshou/policy/imitation/base.py | 15 ++++++--------- tianshou/policy/modelfree/a2c.py | 2 ++ tianshou/policy/modelfree/npg.py | 2 ++ tianshou/policy/modelfree/pg.py | 14 ++++++++++++-- tianshou/policy/modelfree/ppo.py | 2 ++ tianshou/policy/modelfree/trpo.py | 2 ++ 11 files changed, 38 insertions(+), 15 deletions(-) diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index d5172fa..9243313 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -96,7 +96,8 @@ def test_npg(args=get_args()): gae_lambda=args.gae_lambda, action_space=env.action_space, optim_critic_iters=args.optim_critic_iters, - actor_step_size=args.actor_step_size) + actor_step_size=args.actor_step_size, + deterministic_eval=True) # collector train_collector = Collector( policy, train_envs, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index c064ace..7ed05cf 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -145,7 +145,7 @@ def test_sac_with_il(args=get_args()): ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy( - net, optim, mode='continuous', action_space=env.action_space, + net, optim, action_space=env.action_space, action_scaling=True, action_bound_method="clip") il_test_collector = Collector( il_policy, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 8c188ca..ff34f9e 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -124,7 +124,7 @@ def test_a2c_with_il(args=get_args()): device=args.device) net = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) - il_policy = ImitationPolicy(net, optim, mode='discrete') + il_policy = ImitationPolicy(net, optim, action_space=env.action_space) il_test_collector = Collector( il_policy, DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index ae051a4..af9c3a7 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -89,7 +89,8 @@ def test_ppo(args=get_args()): reward_normalization=args.rew_norm, dual_clip=args.dual_clip, value_clip=args.value_clip, - action_space=env.action_space) + action_space=env.action_space, + deterministic_eval=True) # collector train_collector = Collector( policy, train_envs, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index a0dcb33..dff0f85 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -5,6 +5,7 @@ from torch import nn from numba import njit from abc import ABC, abstractmethod from typing import Any, Dict, Tuple, Union, Optional, Callable +from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy @@ -66,6 +67,11 @@ class BasePolicy(ABC, nn.Module): super().__init__() self.observation_space = observation_space self.action_space = action_space + self.action_type = "" + if isinstance(action_space, (Discrete, MultiDiscrete, 
MultiBinary)): + self.action_type = "discrete" + elif isinstance(action_space, Box): + self.action_type = "continuous" self.agent_id = 0 self.updating = False self.action_scaling = action_scaling diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index a618dd4..f94aa1d 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -13,8 +13,7 @@ class ImitationPolicy(BasePolicy): :param torch.nn.Module model: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> a) :param torch.optim.Optimizer optim: for optimizing the model. - :param str mode: indicate the imitation type ("continuous" or "discrete" - action space). Default to "continuous". + :param gym.Space action_space: env's action space. .. seealso:: @@ -26,15 +25,13 @@ class ImitationPolicy(BasePolicy): self, model: torch.nn.Module, optim: torch.optim.Optimizer, - mode: str = "continuous", **kwargs: Any, ) -> None: super().__init__(**kwargs) self.model = model self.optim = optim - assert mode in ["continuous", "discrete"], \ - f"Mode {mode} is not in ['continuous', 'discrete']." - self.mode = mode + assert self.action_type in ["continuous", "discrete"], \ + "Please specify action_space." def forward( self, @@ -43,7 +40,7 @@ class ImitationPolicy(BasePolicy): **kwargs: Any, ) -> Batch: logits, h = self.model(batch.obs, state=state, info=batch.info) - if self.mode == "discrete": + if self.action_type == "discrete": a = logits.max(dim=1)[1] else: a = logits @@ -51,11 +48,11 @@ class ImitationPolicy(BasePolicy): def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.optim.zero_grad() - if self.mode == "continuous": # regression + if self.action_type == "continuous": # regression a = self(batch).act a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) loss = F.mse_loss(a, a_) # type: ignore - elif self.mode == "discrete": # classification + elif self.action_type == "discrete": # classification a = F.log_softmax(self(batch).logits, dim=-1) a_ = to_torch(batch.act, dtype=torch.long, device=a.device) loss = F.nll_loss(a, a_) # type: ignore diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 433810d..9b27248 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -39,6 +39,8 @@ class A2CPolicy(PGPolicy): to use option "action_scaling" or "action_bound_method". Default to None. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update(). Default to None (no lr_scheduler). + :param bool deterministic_eval: whether to use deterministic action instead of + stochastic action sampled by the policy. Default to False. .. seealso:: diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 7cda9b9..d7bff01 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -42,6 +42,8 @@ class NPGPolicy(A2CPolicy): to use option "action_scaling" or "action_bound_method". Default to None. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update(). Default to None (no lr_scheduler). + :param bool deterministic_eval: whether to use deterministic action instead of + stochastic action sampled by the policy. Default to False. 
""" def __init__( diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 6e52569..0f21138 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -3,8 +3,8 @@ import numpy as np from typing import Any, Dict, List, Type, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, to_torch_as from tianshou.utils import RunningMeanStd +from tianshou.data import Batch, ReplayBuffer, to_torch_as class PGPolicy(BasePolicy): @@ -25,6 +25,8 @@ class PGPolicy(BasePolicy): to use option "action_scaling" or "action_bound_method". Default to None. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update(). Default to None (no lr_scheduler). + :param bool deterministic_eval: whether to use deterministic action instead of + stochastic action sampled by the policy. Default to False. .. seealso:: @@ -42,6 +44,7 @@ class PGPolicy(BasePolicy): action_scaling: bool = True, action_bound_method: str = "clip", lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, + deterministic_eval: bool = False, **kwargs: Any, ) -> None: super().__init__(action_scaling=action_scaling, @@ -55,6 +58,7 @@ class PGPolicy(BasePolicy): self._rew_norm = reward_normalization self.ret_rms = RunningMeanStd() self._eps = 1e-8 + self._deterministic_eval = deterministic_eval def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray @@ -103,7 +107,13 @@ class PGPolicy(BasePolicy): dist = self.dist_fn(*logits) else: dist = self.dist_fn(logits) - act = dist.sample() + if self._deterministic_eval and not self.training: + if self.action_type == "discrete": + act = logits.argmax(-1) + elif self.action_type == "continuous": + act = logits[0] + else: + act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) def learn( # type: ignore diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 8c0575d..beb2152 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -49,6 +49,8 @@ class PPOPolicy(A2CPolicy): to use option "action_scaling" or "action_bound_method". Default to None. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update(). Default to None (no lr_scheduler). + :param bool deterministic_eval: whether to use deterministic action instead of + stochastic action sampled by the policy. Default to False. .. seealso:: diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 78313e6..01c77cb 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -45,6 +45,8 @@ class TRPOPolicy(NPGPolicy): to use option "action_scaling" or "action_bound_method". Default to None. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update(). Default to None (no lr_scheduler). + :param bool deterministic_eval: whether to use deterministic action instead of + stochastic action sampled by the policy. Default to False. """ def __init__(