Support deterministic evaluation for onpolicy algorithms (#354)
parent ff4d3cd714
commit f4e05d585a
@@ -96,7 +96,8 @@ def test_npg(args=get_args()):
         gae_lambda=args.gae_lambda,
         action_space=env.action_space,
         optim_critic_iters=args.optim_critic_iters,
-        actor_step_size=args.actor_step_size)
+        actor_step_size=args.actor_step_size,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -145,7 +145,7 @@ def test_sac_with_il(args=get_args()):
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(
-        net, optim, mode='continuous', action_space=env.action_space,
+        net, optim, action_space=env.action_space,
         action_scaling=True, action_bound_method="clip")
     il_test_collector = Collector(
         il_policy,
@@ -124,7 +124,7 @@ def test_a2c_with_il(args=get_args()):
         device=args.device)
     net = Actor(net, args.action_shape, device=args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='discrete')
+    il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
     il_test_collector = Collector(
         il_policy,
         DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
@@ -89,7 +89,8 @@ def test_ppo(args=get_args()):
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
         value_clip=args.value_clip,
-        action_space=env.action_space)
+        action_space=env.action_space,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
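A minimal evaluation sketch of the new flag, assuming a policy and test environments built as in the test scripts above (`policy` and `test_envs` are placeholders, not defined here). With `deterministic_eval=True`, the deterministic action is only used once the policy is switched to eval mode, since the added branch checks `not self.training`.

# Hypothetical evaluation snippet; `policy` and `test_envs` stand in for the
# objects constructed in the test scripts above.
from tianshou.data import Collector

test_collector = Collector(policy, test_envs)
policy.eval()  # self.training becomes False, enabling the deterministic branch
result = test_collector.collect(n_episode=10)
print(result)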
@@ -5,6 +5,7 @@ from torch import nn
 from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Tuple, Union, Optional, Callable
+from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
@@ -66,6 +67,11 @@ class BasePolicy(ABC, nn.Module):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
+        self.action_type = ""
+        if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
+            self.action_type = "discrete"
+        elif isinstance(action_space, Box):
+            self.action_type = "continuous"
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
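A small self-contained sketch of the inference rule added above. The helper `infer_action_type` is made up here for illustration and is not part of Tianshou; it only mirrors the isinstance checks in BasePolicy.__init__.

# Hypothetical standalone helper mirroring the checks added to BasePolicy.__init__.
from typing import Optional
import gym
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary

def infer_action_type(action_space: Optional[gym.Space]) -> str:
    if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
        return "discrete"
    if isinstance(action_space, Box):
        return "continuous"
    return ""  # unknown / unspecified action space

assert infer_action_type(Discrete(4)) == "discrete"
assert infer_action_type(Box(low=-1.0, high=1.0, shape=(3,))) == "continuous"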
@@ -13,8 +13,7 @@ class ImitationPolicy(BasePolicy):
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
     :param torch.optim.Optimizer optim: for optimizing the model.
-    :param str mode: indicate the imitation type ("continuous" or "discrete"
-        action space). Default to "continuous".
+    :param gym.Space action_space: env's action space.
 
     .. seealso::
 
@@ -26,15 +25,13 @@ class ImitationPolicy(BasePolicy):
         self,
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        mode: str = "continuous",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.model = model
         self.optim = optim
-        assert mode in ["continuous", "discrete"], \
-            f"Mode {mode} is not in ['continuous', 'discrete']."
-        self.mode = mode
+        assert self.action_type in ["continuous", "discrete"], \
+            "Please specify action_space."
 
     def forward(
         self,
@@ -43,7 +40,7 @@ class ImitationPolicy(BasePolicy):
         **kwargs: Any,
     ) -> Batch:
         logits, h = self.model(batch.obs, state=state, info=batch.info)
-        if self.mode == "discrete":
+        if self.action_type == "discrete":
             a = logits.max(dim=1)[1]
         else:
             a = logits
@@ -51,11 +48,11 @@ class ImitationPolicy(BasePolicy):
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.optim.zero_grad()
-        if self.mode == "continuous":  # regression
+        if self.action_type == "continuous":  # regression
             a = self(batch).act
             a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
             loss = F.mse_loss(a, a_)  # type: ignore
-        elif self.mode == "discrete":  # classification
+        elif self.action_type == "discrete":  # classification
             a = F.log_softmax(self(batch).logits, dim=-1)
             a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
             loss = F.nll_loss(a, a_)  # type: ignore
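With the `mode` argument removed, ImitationPolicy now derives the action type from `action_space`. A construction sketch under assumed conditions (a CartPole-style discrete task and a hand-rolled network; the `Net` class below is illustrative, not Tianshou's):

# Sketch only: a toy discrete setup mirroring the updated call sites above.
import gym
import torch
from tianshou.policy import ImitationPolicy

env = gym.make("CartPole-v0")

class Net(torch.nn.Module):
    def __init__(self, obs_dim: int, act_dim: int):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(obs_dim, 64), torch.nn.ReLU(),
            torch.nn.Linear(64, act_dim))

    def forward(self, obs, state=None, info={}):
        # ImitationPolicy expects (logits, hidden_state) back from the model
        logits = self.model(torch.as_tensor(obs, dtype=torch.float32))
        return logits, state

net = Net(env.observation_space.shape[0], env.action_space.n)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

# before this commit: ImitationPolicy(net, optim, mode='discrete')
il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
assert il_policy.action_type == "discrete"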
@@ -39,6 +39,8 @@ class A2CPolicy(PGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::
 
@@ -42,6 +42,8 @@ class NPGPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """
 
     def __init__(
@@ -3,8 +3,8 @@ import numpy as np
 from typing import Any, Dict, List, Type, Union, Optional
 
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.utils import RunningMeanStd
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class PGPolicy(BasePolicy):
@@ -25,6 +25,8 @@ class PGPolicy(BasePolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::
 
@@ -42,6 +44,7 @@ class PGPolicy(BasePolicy):
         action_scaling: bool = True,
         action_bound_method: str = "clip",
         lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
+        deterministic_eval: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(action_scaling=action_scaling,
@@ -55,6 +58,7 @@ class PGPolicy(BasePolicy):
         self._rew_norm = reward_normalization
         self.ret_rms = RunningMeanStd()
         self._eps = 1e-8
+        self._deterministic_eval = deterministic_eval
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -103,6 +107,12 @@ class PGPolicy(BasePolicy):
             dist = self.dist_fn(*logits)
         else:
             dist = self.dist_fn(logits)
-        act = dist.sample()
+        if self._deterministic_eval and not self.training:
+            if self.action_type == "discrete":
+                act = logits.argmax(-1)
+            elif self.action_type == "continuous":
+                act = logits[0]
+        else:
+            act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
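A toy sketch of the behaviour added to PGPolicy.forward above, outside of Tianshou, to make it concrete: in eval mode a discrete policy returns the argmax of the logits (and a Gaussian policy would return its mean, `logits[0]` of the `(mu, sigma)` tuple), while in training mode actions are still sampled from the distribution. The class below is illustrative only, not Tianshou code.

# Illustrative only: a stripped-down module reproducing the added branch.
import torch
from torch.distributions import Categorical

class TinyDiscretePolicy(torch.nn.Module):
    def __init__(self, deterministic_eval: bool = True):
        super().__init__()
        self.deterministic_eval = deterministic_eval

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        dist = Categorical(logits=logits)  # still built, as in PGPolicy
        if self.deterministic_eval and not self.training:
            return logits.argmax(-1)       # deterministic action at test time
        return dist.sample()               # stochastic action during training

policy = TinyDiscretePolicy()
logits = torch.tensor([[0.1, 2.0, -1.0]])
policy.eval()                        # self.training is False -> deterministic
assert policy(logits).item() == 1    # index of the largest logit
policy.train()                       # back to stochastic sampling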
@@ -49,6 +49,8 @@ class PPOPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::
 
@@ -45,6 +45,8 @@ class TRPOPolicy(NPGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """
 
     def __init__(