Support deterministic evaluation for onpolicy algorithms (#354)
This commit is contained in:
parent ff4d3cd714
commit f4e05d585a
@@ -96,7 +96,8 @@ def test_npg(args=get_args()):
         gae_lambda=args.gae_lambda,
         action_space=env.action_space,
         optim_critic_iters=args.optim_critic_iters,
-        actor_step_size=args.actor_step_size)
+        actor_step_size=args.actor_step_size,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
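As a minimal sketch of what this test change does at construction time: the new deterministic_eval flag is simply passed as a keyword argument when the on-policy policy is built. The actor, critic, dist_fn and learning rate below are placeholders, not the exact test setup.

# Sketch: enabling deterministic evaluation when constructing an on-policy policy.
import torch
from tianshou.policy import NPGPolicy

def build_policy(actor, critic, dist_fn, action_space):
    # actor/critic/dist_fn are assumed to be built elsewhere (placeholders).
    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=1e-3)
    return NPGPolicy(
        actor, critic, optim, dist_fn,
        action_space=action_space,
        deterministic_eval=True,  # greedy/mean actions when policy.eval() is active
    )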
@@ -145,7 +145,7 @@ def test_sac_with_il(args=get_args()):
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(
-        net, optim, mode='continuous', action_space=env.action_space,
+        net, optim, action_space=env.action_space,
         action_scaling=True, action_bound_method="clip")
     il_test_collector = Collector(
         il_policy,
@@ -124,7 +124,7 @@ def test_a2c_with_il(args=get_args()):
              device=args.device)
     net = Actor(net, args.action_shape, device=args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='discrete')
+    il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
     il_test_collector = Collector(
         il_policy,
         DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
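The two test changes above show the ImitationPolicy constructor change: the explicit mode string is gone and the discrete/continuous branch is inferred from action_space. A hedged, self-contained sketch of the new call pattern; the tiny model below is illustrative only.

# Before this commit (illustrative): ImitationPolicy(net, optim, mode='discrete')
# After: the imitation type is inferred from the action space.
import gym
import torch
from torch import nn
from tianshou.policy import ImitationPolicy

env = gym.make("CartPole-v0")

class Net(nn.Module):
    # Placeholder model mapping observations to action logits (obs -> logits, state).
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, act_dim))

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.model(obs), state

net = Net(env.observation_space.shape[0], env.action_space.n)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
# CartPole has a Discrete action space, so BasePolicy sets action_type to "discrete"
# and learn() uses the classification loss.
assert il_policy.action_type == "discrete"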
@@ -89,7 +89,8 @@ def test_ppo(args=get_args()):
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
         value_clip=args.value_clip,
-        action_space=env.action_space)
+        action_space=env.action_space,
+        deterministic_eval=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -5,6 +5,7 @@ from torch import nn
 from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Tuple, Union, Optional, Callable
+from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
@@ -66,6 +67,11 @@ class BasePolicy(ABC, nn.Module):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
+        self.action_type = ""
+        if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
+            self.action_type = "discrete"
+        elif isinstance(action_space, Box):
+            self.action_type = "continuous"
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
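The action-type inference added to BasePolicy.__init__ is a plain isinstance check over gym spaces; a standalone sketch of the same rule, outside the class:

# Standalone sketch of the inference rule added above.
import numpy as np
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary

def infer_action_type(action_space):
    if isinstance(action_space, (Discrete, MultiDiscrete, MultiBinary)):
        return "discrete"
    if isinstance(action_space, Box):
        return "continuous"
    return ""  # unknown or missing space: leave empty, as the policy does

assert infer_action_type(Discrete(4)) == "discrete"
assert infer_action_type(
    Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)) == "continuous"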
@@ -13,8 +13,7 @@ class ImitationPolicy(BasePolicy):
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
     :param torch.optim.Optimizer optim: for optimizing the model.
-    :param str mode: indicate the imitation type ("continuous" or "discrete"
-        action space). Default to "continuous".
+    :param gym.Space action_space: env's action space.
 
     .. seealso::
 
@@ -26,15 +25,13 @@ class ImitationPolicy(BasePolicy):
         self,
         model: torch.nn.Module,
         optim: torch.optim.Optimizer,
-        mode: str = "continuous",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.model = model
         self.optim = optim
-        assert mode in ["continuous", "discrete"], \
-            f"Mode {mode} is not in ['continuous', 'discrete']."
-        self.mode = mode
+        assert self.action_type in ["continuous", "discrete"], \
+            "Please specify action_space."
 
     def forward(
         self,
@@ -43,7 +40,7 @@ class ImitationPolicy(BasePolicy):
         **kwargs: Any,
     ) -> Batch:
         logits, h = self.model(batch.obs, state=state, info=batch.info)
-        if self.mode == "discrete":
+        if self.action_type == "discrete":
             a = logits.max(dim=1)[1]
         else:
             a = logits
@@ -51,11 +48,11 @@ class ImitationPolicy(BasePolicy):
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
         self.optim.zero_grad()
-        if self.mode == "continuous":  # regression
+        if self.action_type == "continuous":  # regression
             a = self(batch).act
             a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
             loss = F.mse_loss(a, a_)  # type: ignore
-        elif self.mode == "discrete":  # classification
+        elif self.action_type == "discrete":  # classification
             a = F.log_softmax(self(batch).logits, dim=-1)
             a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
             loss = F.nll_loss(a, a_)  # type: ignore
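Behind the renamed branch, the behavior-cloning loss choice is unchanged: mean-squared error against the expert action for continuous spaces, negative log-likelihood over log-softmax logits for discrete ones. A standalone illustration with dummy tensors (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

# Continuous case: regression onto the expert action.
pred_act = torch.randn(8, 2)        # policy output, batch of 2-D actions
expert_act = torch.randn(8, 2)      # actions from the demonstration buffer
mse = F.mse_loss(pred_act, expert_act)

# Discrete case: classification onto the expert action index.
logits = torch.randn(8, 4)          # unnormalized scores over 4 actions
expert_idx = torch.randint(0, 4, (8,))
nll = F.nll_loss(F.log_softmax(logits, dim=-1), expert_idx)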
@@ -39,6 +39,8 @@ class A2CPolicy(PGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::
 
@@ -42,6 +42,8 @@ class NPGPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """
 
     def __init__(
@@ -3,8 +3,8 @@ import numpy as np
 from typing import Any, Dict, List, Type, Union, Optional
 
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
 from tianshou.utils import RunningMeanStd
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class PGPolicy(BasePolicy):
@@ -25,6 +25,8 @@ class PGPolicy(BasePolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::
 
@@ -42,6 +44,7 @@ class PGPolicy(BasePolicy):
         action_scaling: bool = True,
         action_bound_method: str = "clip",
         lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
+        deterministic_eval: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(action_scaling=action_scaling,
@@ -55,6 +58,7 @@ class PGPolicy(BasePolicy):
         self._rew_norm = reward_normalization
         self.ret_rms = RunningMeanStd()
         self._eps = 1e-8
+        self._deterministic_eval = deterministic_eval
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -103,7 +107,13 @@ class PGPolicy(BasePolicy):
             dist = self.dist_fn(*logits)
         else:
             dist = self.dist_fn(logits)
-        act = dist.sample()
+        if self._deterministic_eval and not self.training:
+            if self.action_type == "discrete":
+                act = logits.argmax(-1)
+            elif self.action_type == "continuous":
+                act = logits[0]
+        else:
+            act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
     def learn(  # type: ignore
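At evaluation time the forward pass now short-circuits sampling: the discrete branch takes the argmax over logits, and the continuous branch takes the first element of the actor output, i.e. the mean of the usual (mu, sigma) Gaussian tuple. A sketch of the same selection rule outside the policy class; the function name and the (mu, sigma) layout are assumptions for illustration.

import torch

def select_action(logits, dist, action_type, deterministic_eval, training):
    # Mirrors the branch added to PGPolicy.forward (sketch, not the library code).
    if deterministic_eval and not training:
        if action_type == "discrete":
            return logits.argmax(-1)   # greedy action
        if action_type == "continuous":
            return logits[0]           # mean of the Gaussian (mu, sigma) tuple
    return dist.sample()               # default: stochastic action

# Discrete example: index 1 has the largest logit, so the greedy action is 1.
logits = torch.tensor([[0.1, 2.0, -1.0]])
dist = torch.distributions.Categorical(logits=logits)
assert select_action(logits, dist, "discrete", True, False).item() == 1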
@@ -49,6 +49,8 @@ class PPOPolicy(A2CPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
 
     .. seealso::
 
@@ -45,6 +45,8 @@ class TRPOPolicy(NPGPolicy):
         to use option "action_scaling" or "action_bound_method". Default to None.
     :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
         optimizer in each policy.update(). Default to None (no lr_scheduler).
+    :param bool deterministic_eval: whether to use deterministic action instead of
+        stochastic action sampled by the policy. Default to False.
     """
 
     def __init__(
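Usage-wise, the flag only changes behavior when the module is in eval mode (self.training is False). A rough evaluation flow, with policy and test_collector assumed to exist from earlier setup (placeholders, not code from this commit):

policy.eval()    # nn.Module eval mode, so self.training is False
result = test_collector.collect(n_episode=10)   # actions are now deterministic
policy.train()   # restore stochastic sampling for further training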