Support deterministic evaluation for onpolicy algorithms (#354)

Author: Yuge Zhang, 2021-04-27 21:22:39 +08:00 (committed by GitHub)
Parent: ff4d3cd714
Commit: f4e05d585a
11 changed files with 38 additions and 15 deletions

View File

@@ -96,7 +96,8 @@ def test_npg(args=get_args()):
         gae_lambda=args.gae_lambda,
         action_space=env.action_space,
         optim_critic_iters=args.optim_critic_iters,
-        actor_step_size=args.actor_step_size)
+        actor_step_size=args.actor_step_size,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,

View File

@@ -145,7 +145,7 @@ def test_sac_with_il(args=get_args()):
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(
-        net, optim, mode='continuous', action_space=env.action_space,
+        net, optim, action_space=env.action_space,
         action_scaling=True, action_bound_method="clip")
     il_test_collector = Collector(
         il_policy,

View File

@@ -124,7 +124,7 @@ def test_a2c_with_il(args=get_args()):
         device=args.device)
     net = Actor(net, args.action_shape, device=args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='discrete')
+    il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
     il_test_collector = Collector(
         il_policy,
         DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])

View File

@@ -89,7 +89,8 @@ def test_ppo(args=get_args()):
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
         value_clip=args.value_clip,
-        action_space=env.action_space)
+        action_space=env.action_space,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,

View File

@@ -5,6 +5,7 @@ from torch import nn
 from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Tuple, Union, Optional, Callable
+from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary

 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -66,6 +67,11 @@ class BasePolicy(ABC, nn.Module):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
+        self.action_type = ""
+        if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
+            self.action_type = "discrete"
+        elif isinstance(action_space, Box):
+            self.action_type = "continuous"
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
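The new `action_type` attribute is what the policies below branch on. A standalone sketch of the same inference, not part of this commit (the helper name `infer_action_type` is invented for illustration):

from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete


def infer_action_type(action_space) -> str:
    # mirrors the isinstance checks added to BasePolicy.__init__ above
    if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
        return "discrete"
    if isinstance(action_space, Box):
        return "continuous"
    return ""


print(infer_action_type(Discrete(4)))                          # "discrete"
print(infer_action_type(Box(low=-1.0, high=1.0, shape=(3,))))  # "continuous"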

View File

@@ -13,8 +13,7 @@ class ImitationPolicy(BasePolicy):
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
     :param torch.optim.Optimizer optim: for optimizing the model.
-    :param str mode: indicate the imitation type ("continuous" or "discrete"
-        action space). Default to "continuous".
+    :param gym.Space action_space: env's action space.

     .. seealso::
@@ -26,15 +25,13 @@ class ImitationPolicy(BasePolicy):
         self,
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        mode: str = "continuous",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.model = model
         self.optim = optim
-        assert mode in ["continuous", "discrete"], \
-            f"Mode {mode} is not in ['continuous', 'discrete']."
-        self.mode = mode
+        assert self.action_type in ["continuous", "discrete"], \
+            "Please specify action_space."

     def forward(
         self,
@@ -43,7 +40,7 @@ class ImitationPolicy(BasePolicy):
         **kwargs: Any,
     ) -> Batch:
         logits, h = self.model(batch.obs, state=state, info=batch.info)
-        if self.mode == "discrete":
+        if self.action_type == "discrete":
             a = logits.max(dim=1)[1]
         else:
             a = logits
@@ -51,11 +48,11 @@ class ImitationPolicy(BasePolicy):

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.optim.zero_grad()
-        if self.mode == "continuous":  # regression
+        if self.action_type == "continuous":  # regression
             a = self(batch).act
             a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
             loss = F.mse_loss(a, a_)  # type: ignore
-        elif self.mode == "discrete":  # classification
+        elif self.action_type == "discrete":  # classification
             a = F.log_softmax(self(batch).logits, dim=-1)
             a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
             loss = F.nll_loss(a, a_)  # type: ignore
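Since `mode` is removed, callers now pass `action_space` and the imitation type is inferred from it. A hedged usage sketch, assuming CartPole-v0 and a model that follows tianshou's `(obs, state, info) -> (logits, state)` convention; `DiscreteActor` below is an illustrative stand-in, not a library class:

import gym
import torch
from torch import nn

from tianshou.policy import ImitationPolicy


class DiscreteActor(nn.Module):
    # tiny actor obeying the (obs, state, info) -> (logits, state) rule
    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(),
                                 nn.Linear(64, act_dim))

    def forward(self, obs, state=None, info={}):
        return self.net(torch.as_tensor(obs, dtype=torch.float32)), state


env = gym.make("CartPole-v0")
actor = DiscreteActor(env.observation_space.shape[0], env.action_space.n)
optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
# before this commit: ImitationPolicy(actor, optim, mode="discrete")
policy = ImitationPolicy(actor, optim, action_space=env.action_space)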

View File

@@ -39,6 +39,8 @@ class A2CPolicy(PGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.

     .. seealso::
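For context, and not part of the diff: the flag only takes effect when the policy module is in eval mode, which is how the updated tests get deterministic test rollouts while training stays stochastic. A hedged sketch using the simplest on-policy class, PGPolicy (A2C, NPG, PPO and TRPO accept the same flag), on CartPole-v0; `ProbActor` is an illustrative helper assuming tianshou's `(obs, state, info) -> (probs, state)` model convention:

import gym
import numpy as np
import torch
from torch import nn

from tianshou.data import Batch
from tianshou.policy import PGPolicy


class ProbActor(nn.Module):
    # tiny actor emitting action probabilities, as torch.distributions.Categorical expects
    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(),
                                 nn.Linear(64, act_dim), nn.Softmax(dim=-1))

    def forward(self, obs, state=None, info={}):
        return self.net(torch.as_tensor(obs, dtype=torch.float32)), state


env = gym.make("CartPole-v0")
actor = ProbActor(env.observation_space.shape[0], env.action_space.n)
optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
policy = PGPolicy(actor, optim, dist_fn=torch.distributions.Categorical,
                  action_space=env.action_space, deterministic_eval=True)

batch = Batch(obs=np.zeros((1, 4), dtype=np.float32), info={})
policy.train()
print(policy(batch).act)  # sampled from the categorical distribution
policy.eval()
print(policy(batch).act)  # argmax of the actor output (deterministic)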

View File

@@ -42,6 +42,8 @@ class NPGPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """

     def __init__(

View File

@@ -3,8 +3,8 @@ import numpy as np
 from typing import Any, Dict, List, Type, Union, Optional

 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.utils import RunningMeanStd
+from tianshou.data import Batch, ReplayBuffer, to_torch_as


 class PGPolicy(BasePolicy):
@@ -25,6 +25,8 @@ class PGPolicy(BasePolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.

     .. seealso::
@@ -42,6 +44,7 @@ class PGPolicy(BasePolicy):
         action_scaling: bool = True,
         action_bound_method: str = "clip",
         lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
+        deterministic_eval: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(action_scaling=action_scaling,
@@ -55,6 +58,7 @@ class PGPolicy(BasePolicy):
         self._rew_norm = reward_normalization
         self.ret_rms = RunningMeanStd()
         self._eps = 1e-8
+        self._deterministic_eval = deterministic_eval

     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -103,7 +107,13 @@ class PGPolicy(BasePolicy):
             dist = self.dist_fn(*logits)
         else:
             dist = self.dist_fn(logits)
-        act = dist.sample()
+        if self._deterministic_eval and not self.training:
+            if self.action_type == "discrete":
+                act = logits.argmax(-1)
+            elif self.action_type == "continuous":
+                act = logits[0]
+        else:
+            act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)

     def learn(  # type: ignore
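A small standalone illustration, not part of the commit, of what the two deterministic branches compute: for a discrete policy the greedy action is the argmax of the network output, and for tianshou's continuous actors, which return a `(mu, sigma)` tuple, `logits[0]` is simply the mean action `mu`:

import torch
from torch.distributions import Categorical, Independent, Normal

# discrete: eval mode takes the greedy action instead of sampling
logits = torch.tensor([[0.1, 2.0, 0.3]])
dist = Categorical(logits=logits)
print(dist.sample())      # varies from run to run
print(logits.argmax(-1))  # always tensor([1])

# continuous: the actor returns (mu, sigma); eval mode uses the mean directly
mu, sigma = torch.zeros(1, 3), torch.ones(1, 3)
dist = Independent(Normal(mu, sigma), 1)
print(dist.sample())      # noisy
print(mu)                 # i.e. logits[0] in the diff above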

View File

@@ -49,6 +49,8 @@ class PPOPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.

     .. seealso::

View File

@@ -45,6 +45,8 @@ class TRPOPolicy(NPGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """

     def __init__(