Remap action to fit gym's action space (#313)

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
ChenDRAG authored on 2021-03-21 16:45:50 +08:00; committed by GitHub
parent 0c7117dd55
commit 4d92952a7b
21 changed files with 145 additions and 77 deletions


@@ -117,9 +117,8 @@ def test_sac_bipedal(args=get_args()):
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(args.resume_path))
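
The scripts above used to collapse the Box bounds into the scalar pair [low[0], high[0]]; passing the whole env.action_space keeps per-dimension bounds for the new remapping. A standalone illustration with a made-up space (not taken from any test script):

    import numpy as np
    from gym import spaces

    # a Box whose dimensions have different bounds; keeping only low[0]/high[0]
    # would lose the second dimension's range, while the full space preserves it
    space = spaces.Box(low=np.array([-1.0, -2.0]), high=np.array([1.0, 2.0]), dtype=np.float32)
    print(space.low, space.high)  # [-1. -2.] [ 1.  2.]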


@@ -90,10 +90,10 @@ def test_sac(args=get_args()):
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        exploration_noise=OUNoise(0.0, args.noise_std))
+        exploration_noise=OUNoise(0.0, args.noise_std),
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -84,10 +84,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(


@@ -97,9 +97,8 @@ def test_sac(args=get_args()):
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(


@@ -95,11 +95,11 @@ def test_td3(args=get_args()):
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
-        noise_clip=args.noise_clip, estimation_step=args.n_step)
+        noise_clip=args.noise_clip, estimation_step=args.n_step,
+        action_space=env.action_space)
     # load a previous policy
     if args.resume_path:


@@ -77,11 +77,10 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -100,9 +100,8 @@ def test_ppo(args=get_args()):
         # dual_clip=args.dual_clip,
         # dual clip cause monotonically increasing log_std :)
         value_clip=args.value_clip,
-        # action_range=[env.action_space.low[0], env.action_space.high[0]],)
-        # if clip the action, ppo would not converge :)
-        gae_lambda=args.gae_lambda)
+        gae_lambda=args.gae_lambda,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
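
The deleted comment ("if clip the action, ppo would not converge") points at a real pitfall: when the stored action is clipped inside forward, the log-probability computed during the update no longer matches the action that was actually sampled, which biases the importance ratio. A toy, standalone torch sketch of that mismatch (illustrative numbers only, not library code):

    import torch
    from torch.distributions import Normal

    dist = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
    sampled = torch.tensor([1.7])        # action drawn from the Gaussian policy
    clipped = sampled.clamp(-1.0, 1.0)   # what a clipping forward() would store

    # the two log-probs differ, so ratios built from the clipped action are biased
    print(dist.log_prob(sampled), dist.log_prob(clipped))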


@@ -87,10 +87,9 @@ def test_sac_with_il(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -86,14 +86,14 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -80,7 +80,8 @@ def test_a2c_with_il(args=get_args()):
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
-        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
+        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -63,7 +63,8 @@ def test_pg(args=get_args()):
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma,
-                      reward_normalization=args.rew_norm)
+                      reward_normalization=args.rew_norm,
+                      action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -85,11 +85,11 @@ def test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=None,
         gae_lambda=args.gae_lambda,
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
-        value_clip=args.value_clip)
+        value_clip=args.value_clip,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,


@@ -219,8 +219,10 @@ class Collector(object):
                 act = self.policy.exploration_noise(act, self.data)
             self.data.update(policy=policy, act=act)

+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)

             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
             if self.preprocess_fn:
@@ -419,8 +421,10 @@ class AsyncCollector(Collector):
                 _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                 whole_data[ready_env_ids] = self.data  # lots of overhead

+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)

             # change self.data here because ready_env_ids has changed
             ready_env_ids = np.array([i["env_id"] for i in info])
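
A toy, self-contained sketch (not the Collector code itself) of the data flow introduced here: the raw action is what gets written to the buffer, while only its bounded and rescaled copy reaches env.step. The low/high values below are made up for the example:

    import numpy as np

    def map_action(act, low=-2.0, high=2.0):
        # same math as BasePolicy.map_action with action_bound_method="clip"
        act = np.clip(act, -1.0, 1.0)
        # ... and action_scaling=True
        return low + (high - low) * (act + 1.0) / 2.0

    raw_act = np.array([1.7])          # unbounded output of policy.forward
    buffer_entry = {"act": raw_act}    # the buffer keeps the raw action
    env_action = map_action(raw_act)   # the environment only sees the remapped copy
    print(buffer_entry["act"], env_action)  # [1.7] [2.]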


@@ -53,14 +53,20 @@ class BasePolicy(ABC, nn.Module):
     def __init__(
         self,
-        observation_space: gym.Space = None,
-        action_space: gym.Space = None
+        observation_space: Optional[gym.Space] = None,
+        action_space: Optional[gym.Space] = None,
+        action_scaling: bool = False,
+        action_bound_method: str = "",
     ) -> None:
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
         self.agent_id = 0
         self.updating = False
+        self.action_scaling = action_scaling
+        # can be one of ("clip", "tanh", ""), empty string means no bounding
+        assert action_bound_method in ("", "clip", "tanh")
+        self.action_bound_method = action_bound_method
         self._compile()

     def set_agent_id(self, agent_id: int) -> None:
@@ -114,6 +120,38 @@ class BasePolicy(ABC, nn.Module):
         """
         pass

+    def map_action(self, act: Union[Batch, np.ndarray]) -> Union[Batch, np.ndarray]:
+        """Map raw network output to action range in gym's env.action_space.
+
+        This function is called in :meth:`~tianshou.data.Collector.collect` and only
+        affects action sending to env. Remapped action will not be stored in buffer
+        and thus can be viewed as a part of env (a black box action transformation).
+
+        Action mapping includes 2 standard procedures: bounding and scaling. Bounding
+        procedure expects original action range is (-inf, inf) and maps it to [-1, 1],
+        while scaling procedure expects original action range is (-1, 1) and maps it
+        to [action_space.low, action_space.high]. Bounding procedure is applied first.
+
+        :param act: a data batch or numpy.ndarray which is the action taken by
+            policy.forward.
+
+        :return: action in the same form of input "act" but remap to the target action
+            space.
+        """
+        if isinstance(self.action_space, gym.spaces.Box) and \
+                isinstance(act, np.ndarray):
+            # currently this action mapping only supports np.ndarray action
+            if self.action_bound_method == "clip":
+                act = np.clip(act, -1.0, 1.0)
+            elif self.action_bound_method == "tanh":
+                act = np.tanh(act)
+            if self.action_scaling:
+                assert np.all(act >= -1.0) and np.all(act <= 1.0), \
+                    "action scaling only accepts raw action range = [-1, 1]"
+                low, high = self.action_space.low, self.action_space.high
+                act = low + (high - low) * (act + 1.0) / 2.0
+        return act
+
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
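
A numeric illustration of the two procedures described in the docstring above, standalone and using a made-up Box(low=-2, high=2) space rather than anything from the library:

    import numpy as np

    low, high = np.array([-2.0]), np.array([2.0])
    raw = np.array([3.0])                                 # actor output, unbounded

    bounded = np.tanh(raw)                                # action_bound_method="tanh" -> (-1, 1)
    scaled = low + (high - low) * (bounded + 1.0) / 2.0   # action_scaling=True
    print(bounded, scaled)                                # ~[0.995] ~[1.99], inside [low, high]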


@@ -31,6 +31,13 @@ class A2CPolicy(PGPolicy):
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
         Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::


@@ -16,8 +16,6 @@ class DDPGPolicy(BasePolicy):
     :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
     :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic_optim: the optimizer for critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param BaseNoise exploration_noise: the exploration noise,
@@ -25,6 +23,13 @@ class DDPGPolicy(BasePolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         Default to False.
     :param int estimation_step: the number of steps to look ahead. Default to 1.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
@@ -38,15 +43,17 @@ class DDPGPolicy(BasePolicy):
         actor_optim: Optional[torch.optim.Optimizer],
         critic: Optional[torch.nn.Module],
         critic_optim: Optional[torch.optim.Optimizer],
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
         reward_normalization: bool = False,
         estimation_step: int = 1,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)
@@ -62,9 +69,6 @@ class DDPGPolicy(BasePolicy):
         assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
         self._gamma = gamma
         self._noise = exploration_noise
-        self._range = action_range
-        self._action_bias = (action_range[0] + action_range[1]) / 2.0
-        self._action_scale = (action_range[1] - action_range[0]) / 2.0
         # it is only a little difference to use GaussianNoise
         # self.noise = OUNoise()
         self._rew_norm = reward_normalization
@@ -128,8 +132,6 @@ class DDPGPolicy(BasePolicy):
         model = getattr(self, model)
         obs = batch[input]
         actions, h = model(obs, state=state, info=batch.info)
-        actions += self._action_bias
-        actions = actions.clamp(self._range[0], self._range[1])
         return Batch(act=actions, state=h)

     @staticmethod
@@ -168,5 +170,4 @@ class DDPGPolicy(BasePolicy):
     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if self._noise:
             act = act + self._noise(act.shape)
-            act = act.clip(self._range[0], self._range[1])
         return act
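
The deleted _action_bias/_action_scale attributes and the new map_action scaling are the same affine map written two ways; a quick standalone check with made-up bounds:

    import numpy as np

    low, high = -2.0, 2.0
    bias, scale = (low + high) / 2.0, (high - low) / 2.0   # old _action_bias / _action_scale

    a = np.linspace(-1.0, 1.0, 5)                      # bounded action in [-1, 1]
    old_style = scale * a + bias                       # y * _action_scale + _action_bias
    new_style = low + (high - low) * (a + 1.0) / 2.0   # BasePolicy.map_action scaling
    assert np.allclose(old_style, new_style)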


@@ -49,10 +49,10 @@ class DiscreteSACPolicy(SACPolicy):
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
-                         critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
-                         reward_normalization, estimation_step,
-                         **kwargs)
+        super().__init__(
+            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+            tau, gamma, alpha, reward_normalization, estimation_step,
+            action_scaling=False, action_bound_method="", **kwargs)
         self._alpha: Union[float, torch.Tensor]

     def forward(  # type: ignore


@@ -15,6 +15,13 @@ class PGPolicy(BasePolicy):
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
@@ -29,9 +36,12 @@ class PGPolicy(BasePolicy):
         dist_fn: Type[torch.distributions.Distribution],
         discount_factor: float = 0.99,
         reward_normalization: bool = False,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if model is not None:
             self.model: torch.nn.Module = model
         self.optim = optim


@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from torch import nn
-from typing import Any, Dict, List, Type, Tuple, Union, Optional
+from typing import Any, Dict, List, Type, Union, Optional

 from tianshou.policy import PGPolicy
 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
@@ -23,8 +23,6 @@ class PPOPolicy(PGPolicy):
         paper. Default to 0.2.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: (float, float)
     :param float gae_lambda: in [0, 1], param for Generalized Advantage
         Estimation. Default to 0.95.
     :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
@@ -38,6 +36,13 @@ class PPOPolicy(PGPolicy):
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
         Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
@@ -56,7 +61,6 @@ class PPOPolicy(PGPolicy):
         eps_clip: float = 0.2,
         vf_coef: float = 0.5,
         ent_coef: float = 0.01,
-        action_range: Optional[Tuple[float, float]] = None,
         gae_lambda: float = 0.95,
         dual_clip: Optional[float] = None,
         value_clip: bool = True,
@@ -69,7 +73,6 @@ class PPOPolicy(PGPolicy):
         self._eps_clip = eps_clip
         self._weight_vf = vf_coef
         self._weight_ent = ent_coef
-        self._range = action_range
         self.actor = actor
         self.critic = critic
         self._batch = max_batchsize
@@ -135,8 +138,6 @@ class PPOPolicy(PGPolicy):
         else:
             dist = self.dist_fn(logits)
         act = dist.sample()
-        if self._range:
-            act = act.clamp(self._range[0], self._range[1])
         return Batch(logits=logits, act=act, state=h, dist=dist)

     def learn(  # type: ignore


@@ -6,7 +6,7 @@ from typing import Any, Dict, Tuple, Union, Optional

 from tianshou.policy import DDPGPolicy
 from tianshou.exploration import BaseNoise
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch_as


 class SACPolicy(DDPGPolicy):
@@ -21,8 +21,6 @@ class SACPolicy(DDPGPolicy):
     :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic2_optim: the optimizer for the second
         critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
@@ -36,6 +34,13 @@ class SACPolicy(DDPGPolicy):
     :param bool deterministic_eval: whether to use deterministic action (mean
         of Gaussian policy) instead of stochastic action sampled by the policy.
         Default to True.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "tanh".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
@@ -51,7 +56,6 @@
         critic1_optim: torch.optim.Optimizer,
         critic2: torch.nn.Module,
         critic2_optim: torch.optim.Optimizer,
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
@@ -59,11 +63,13 @@
         estimation_step: int = 1,
         exploration_noise: Optional[BaseNoise] = None,
         deterministic_eval: bool = True,
+        action_bound_method: str = "tanh",
         **kwargs: Any,
     ) -> None:
-        super().__init__(None, None, None, None, action_range, tau, gamma,
-                         exploration_noise, reward_normalization,
-                         estimation_step, **kwargs)
+        super().__init__(
+            None, None, None, None, tau, gamma, exploration_noise,
+            reward_normalization, estimation_step,
+            action_bound_method=action_bound_method, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -110,24 +116,26 @@
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self._deterministic_eval and not self.training:
-            x = logits[0]
+            act = logits[0]
         else:
-            x = dist.rsample()
-        y = torch.tanh(x)
-        act = y * self._action_scale + self._action_bias
-        # __eps is used to avoid log of zero/negative number.
-        y = self._action_scale * (1 - y.pow(2)) + self.__eps
-        # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
-        # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
-        # in appendix C to get some understanding of this equation.
-        log_prob = dist.log_prob(x).unsqueeze(-1)
-        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
+            act = dist.rsample()
+        log_prob = dist.log_prob(act).unsqueeze(-1)
+        if self.action_bound_method == "tanh" and self.action_space is not None:
+            # apply correction for Tanh squashing when computing logprob from Gaussian
+            # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+            # in appendix C to get some understanding of this equation.
+            if self.action_scaling:
+                action_scale = to_torch_as(
+                    (self.action_space.high - self.action_space.low) / 2.0, act)
+            else:
+                action_scale = 1.0  # type: ignore
+            squashed_action = torch.tanh(act)
+            log_prob = log_prob - torch.log(
+                action_scale * (1 - squashed_action.pow(2)) + self.__eps
+            ).sum(-1, keepdim=True)
         return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)

-    def _target_q(
-        self, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> torch.Tensor:
+    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         obs_next_result = self(batch, input='obs_next')
         a_ = obs_next_result.act
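
For reference, the squashing correction added above follows Eq. 21 / appendix C of the SAC paper: with a = scale * tanh(u), log p(a) = log p(u) - sum log(scale * (1 - tanh(u)^2)). A small standalone torch sketch with a made-up scalar action_scale (not library code):

    import torch
    from torch.distributions import Normal

    dist = Normal(torch.zeros(1), torch.ones(1))
    u = dist.rsample()                      # pre-squash Gaussian sample
    log_prob = dist.log_prob(u)             # log p(u)

    action_scale = 2.0                      # (high - low) / 2 for a hypothetical Box(-2, 2)
    eps = torch.finfo(torch.float32).eps    # avoids log(0)
    squashed = torch.tanh(u)
    log_prob = log_prob - torch.log(action_scale * (1 - squashed.pow(2)) + eps)
    print(action_scale * squashed, log_prob)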


@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from copy import deepcopy
-from typing import Any, Dict, Tuple, Optional
+from typing import Any, Dict, Optional

 from tianshou.policy import DDPGPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -20,8 +20,6 @@ class TD3Policy(DDPGPolicy):
     :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic2_optim: the optimizer for the second
         critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param float exploration_noise: the exploration noise, add to the action.
@@ -34,6 +32,13 @@
         Default to 0.5.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.

     .. seealso::
@@ -49,7 +54,6 @@
         critic1_optim: torch.optim.Optimizer,
         critic2: torch.nn.Module,
         critic2_optim: torch.optim.Optimizer,
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
@@ -60,7 +64,7 @@
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
+        super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, reward_normalization,
                          estimation_step, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
@@ -98,7 +102,6 @@
         if self._noise_clip > 0.0:
             noise = noise.clamp(-self._noise_clip, self._noise_clip)
         a_ += noise
-        a_ = a_.clamp(self._range[0], self._range[1])
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_))