Support deterministic evaluation for onpolicy algorithms (#354)

Author: Yuge Zhang
Date:   2021-04-27 21:22:39 +08:00 (committed by GitHub)
parent ff4d3cd714
commit f4e05d585a
11 changed files with 38 additions and 15 deletions
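In short, the on-policy policies (PG, A2C, NPG, TRPO, PPO) gain a `deterministic_eval` flag: when it is set and the policy is in eval mode, `forward()` returns the mode of the action distribution (argmax of the logits for discrete spaces, the Gaussian mean for continuous ones) instead of a sample. A minimal usage sketch, assuming Tianshou 0.4.x network helpers and a Pendulum environment; the layer sizes and hyperparameters below are illustrative, not part of the commit:

```python
import gym
import torch
from torch.distributions import Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

env = gym.make("Pendulum-v0")
# Illustrative actor/critic; any (mu, sigma)-producing actor works the same way.
actor = ActorProb(Net(env.observation_space.shape, hidden_sizes=[64, 64]),
                  env.action_space.shape)
critic = Critic(Net(env.observation_space.shape, hidden_sizes=[64, 64]))
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=3e-4)

def dist_fn(*logits):  # logits == (mu, sigma) in the continuous case
    return Independent(Normal(*logits), 1)

policy = PPOPolicy(
    actor, critic, optim, dist_fn,
    action_space=env.action_space,
    deterministic_eval=True)  # the new flag added by this commit

policy.train()  # during collection, actions are sampled from the distribution
policy.eval()   # during evaluation, forward() returns mu instead of a sample
```

The diffs below wire this flag through the test scripts, `BasePolicy`, and the on-policy policy classes.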


@@ -96,7 +96,8 @@ def test_npg(args=get_args()):
         gae_lambda=args.gae_lambda,
         action_space=env.action_space,
         optim_critic_iters=args.optim_critic_iters,
-        actor_step_size=args.actor_step_size)
+        actor_step_size=args.actor_step_size,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -145,7 +145,7 @@ def test_sac_with_il(args=get_args()):
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(
-        net, optim, mode='continuous', action_space=env.action_space,
+        net, optim, action_space=env.action_space,
         action_scaling=True, action_bound_method="clip")
     il_test_collector = Collector(
         il_policy,


@@ -124,7 +124,7 @@ def test_a2c_with_il(args=get_args()):
         device=args.device)
     net = Actor(net, args.action_shape, device=args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='discrete')
+    il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
     il_test_collector = Collector(
         il_policy,
         DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])


@@ -89,7 +89,8 @@ def test_ppo(args=get_args()):
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
         value_clip=args.value_clip,
-        action_space=env.action_space)
+        action_space=env.action_space,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -5,6 +5,7 @@
 from torch import nn
 from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Tuple, Union, Optional, Callable
+from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -66,6 +67,11 @@ class BasePolicy(ABC, nn.Module):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
+        self.action_type = ""
+        if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
+            self.action_type = "discrete"
+        elif isinstance(action_space, Box):
+            self.action_type = "continuous"
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling

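The inference rule added to `BasePolicy.__init__` is the only place `action_type` is set; the code downstream (for example `ImitationPolicy` and the new deterministic branch in `PGPolicy.forward`) just reads the string. A standalone restatement of that rule, where `infer_action_type` is a hypothetical helper used purely for illustration:

```python
import numpy as np
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary

def infer_action_type(action_space) -> str:
    """Hypothetical helper mirroring the new logic in BasePolicy.__init__."""
    if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
        return "discrete"
    if isinstance(action_space, Box):
        return "continuous"
    return ""  # unknown space, or no action_space given

assert infer_action_type(Discrete(4)) == "discrete"
assert infer_action_type(MultiBinary(3)) == "discrete"
assert infer_action_type(Box(-1.0, 1.0, shape=(2,), dtype=np.float32)) == "continuous"
assert infer_action_type(None) == ""
```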

@@ -13,8 +13,7 @@ class ImitationPolicy(BasePolicy):
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
     :param torch.optim.Optimizer optim: for optimizing the model.
-    :param str mode: indicate the imitation type ("continuous" or "discrete"
-        action space). Default to "continuous".
+    :param gym.Space action_space: env's action space.
 
     .. seealso::
@@ -26,15 +25,13 @@ class ImitationPolicy(BasePolicy):
         self,
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        mode: str = "continuous",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.model = model
         self.optim = optim
-        assert mode in ["continuous", "discrete"], \
-            f"Mode {mode} is not in ['continuous', 'discrete']."
-        self.mode = mode
+        assert self.action_type in ["continuous", "discrete"], \
+            "Please specify action_space."
 
     def forward(
         self,
@@ -43,7 +40,7 @@ class ImitationPolicy(BasePolicy):
         **kwargs: Any,
     ) -> Batch:
         logits, h = self.model(batch.obs, state=state, info=batch.info)
-        if self.mode == "discrete":
+        if self.action_type == "discrete":
             a = logits.max(dim=1)[1]
         else:
             a = logits
@@ -51,11 +48,11 @@ class ImitationPolicy(BasePolicy):
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.optim.zero_grad()
-        if self.mode == "continuous":  # regression
+        if self.action_type == "continuous":  # regression
             a = self(batch).act
             a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
             loss = F.mse_loss(a, a_)  # type: ignore
-        elif self.mode == "discrete":  # classification
+        elif self.action_type == "discrete":  # classification
             a = F.log_softmax(self(batch).logits, dim=-1)
             a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
             loss = F.nll_loss(a, a_)  # type: ignore

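With `mode` gone, `ImitationPolicy` decides between regression (continuous) and classification (discrete) from the `action_space` it passes to `BasePolicy`, and asserts that one was given. A sketch of the new call site for a discrete task, assuming the 0.4.x `Net`/`Actor` helpers (the layer size is illustrative):

```python
import gym
import torch

from tianshou.policy import ImitationPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor

env = gym.make("CartPole-v0")
net = Actor(Net(env.observation_space.shape, hidden_sizes=[64]),
            env.action_space.n)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

# Before this commit: ImitationPolicy(net, optim, mode="discrete")
# After: the discrete/continuous behaviour is inferred from action_space,
# and omitting it fails the new "Please specify action_space." assertion.
il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
```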

@@ -39,6 +39,8 @@ class A2CPolicy(PGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::


@@ -42,6 +42,8 @@ class NPGPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """
 
     def __init__(


@@ -3,8 +3,8 @@
 import numpy as np
 from typing import Any, Dict, List, Type, Union, Optional
 
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.utils import RunningMeanStd
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class PGPolicy(BasePolicy):
@@ -25,6 +25,8 @@ class PGPolicy(BasePolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::
@@ -42,6 +44,7 @@ class PGPolicy(BasePolicy):
         action_scaling: bool = True,
         action_bound_method: str = "clip",
         lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
+        deterministic_eval: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(action_scaling=action_scaling,
@@ -55,6 +58,7 @@ class PGPolicy(BasePolicy):
         self._rew_norm = reward_normalization
         self.ret_rms = RunningMeanStd()
         self._eps = 1e-8
+        self._deterministic_eval = deterministic_eval
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -103,6 +107,12 @@ class PGPolicy(BasePolicy):
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
             dist = self.dist_fn(logits)
-        act = dist.sample()
+        if self._deterministic_eval and not self.training:
+            if self.action_type == "discrete":
+                act = logits.argmax(-1)
+            elif self.action_type == "continuous":
+                act = logits[0]
+        else:
+            act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)

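The new branch only fires when `deterministic_eval` is set and the module is in eval mode (`not self.training`), so training-time exploration is unchanged. A hedged sketch of the behaviour difference on a toy discrete setup (the shapes, layer size, and dummy observation are illustrative):

```python
import numpy as np
import torch
from gym.spaces import Discrete
from torch.distributions import Categorical

from tianshou.data import Batch
from tianshou.policy import PGPolicy
from tianshou.utils.net.common import Net

net = Net((4,), 3, hidden_sizes=[64], softmax=True)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = PGPolicy(net, optim, Categorical,
                  action_space=Discrete(3),  # makes action_type == "discrete"
                  action_scaling=False,
                  deterministic_eval=True)

batch = Batch(obs=np.zeros((1, 4), dtype=np.float32), info={})

policy.train()
sampled = policy(batch).act  # sampled from the Categorical distribution

policy.eval()
greedy = policy(batch).act   # logits.argmax(-1): always the same action
```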

@@ -49,6 +49,8 @@ class PPOPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::


@@ -45,6 +45,8 @@ class TRPOPolicy(NPGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """
 
     def __init__(