Remap action to fit gym's action space (#313)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
Parent: 0c7117dd55
Commit: 4d92952a7b
@@ -117,9 +117,8 @@ def test_sac_bipedal(args=get_args()):
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(args.resume_path))
@@ -90,10 +90,10 @@ def test_sac(args=get_args()):
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        exploration_noise=OUNoise(0.0, args.noise_std))
+        exploration_noise=OUNoise(0.0, args.noise_std),
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -84,10 +84,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(
@@ -97,9 +97,8 @@ def test_sac(args=get_args()):
 
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # load a previous policy
     if args.resume_path:
         policy.load_state_dict(torch.load(
@@ -95,11 +95,11 @@ def test_td3(args=get_args()):
 
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
-        noise_clip=args.noise_clip, estimation_step=args.n_step)
+        noise_clip=args.noise_clip, estimation_step=args.n_step,
+        action_space=env.action_space)
 
     # load a previous policy
     if args.resume_path:
@@ -77,11 +77,10 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -100,9 +100,8 @@ def test_ppo(args=get_args()):
         # dual_clip=args.dual_clip,
         # dual clip cause monotonically increasing log_std :)
         value_clip=args.value_clip,
-        # action_range=[env.action_space.low[0], env.action_space.high[0]],)
-        # if clip the action, ppo would not converge :)
-        gae_lambda=args.gae_lambda)
+        gae_lambda=args.gae_lambda,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -87,10 +87,9 @@ def test_sac_with_il(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step, action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -86,14 +86,14 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
-        estimation_step=args.n_step)
+        estimation_step=args.n_step,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -80,7 +80,8 @@ def test_a2c_with_il(args=get_args()):
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
-        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
+        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -63,7 +63,8 @@ def test_pg(args=get_args()):
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma,
-                      reward_normalization=args.rew_norm)
+                      reward_normalization=args.rew_norm,
+                      action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -85,11 +85,11 @@ def test_ppo(args=get_args()):
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
         ent_coef=args.ent_coef,
-        action_range=None,
         gae_lambda=args.gae_lambda,
         reward_normalization=args.rew_norm,
         dual_clip=args.dual_clip,
-        value_clip=args.value_clip)
+        value_clip=args.value_clip,
+        action_space=env.action_space)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -219,8 +219,10 @@ class Collector(object):
                 act = self.policy.exploration_noise(act, self.data)
             self.data.update(policy=policy, act=act)
 
+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
 
             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
             if self.preprocess_fn:
@@ -419,8 +421,10 @@ class AsyncCollector(Collector):
                 _alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
                 whole_data[ready_env_ids] = self.data  # lots of overhead
 
+            # get bounded and remapped actions first (not saved into buffer)
+            action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
 
             # change self.data here because ready_env_ids has changed
             ready_env_ids = np.array([i["env_id"] for i in info])
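
For orientation, the per-step order both collectors now follow is sketched below. This is a schematic re-statement of the two hunks above, assuming policy, env, and data objects with the interfaces shown there; it is not the actual Collector code.

def collect_step_sketch(policy, env, data, ready_env_ids):
    # 1. raw network output, possibly with exploration noise applied
    act = policy.exploration_noise(data.act, data)
    data.update(act=act)                   # the raw act is what the buffer stores
    # 2. bound/scale a copy of it for the environment only
    action_remap = policy.map_action(data.act)
    # 3. the environment sees the remapped action, never the raw one
    obs_next, rew, done, info = env.step(action_remap, id=ready_env_ids)
    data.update(obs_next=obs_next, rew=rew, done=done, info=info)
    return data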
@@ -53,14 +53,20 @@ class BasePolicy(ABC, nn.Module):
 
     def __init__(
         self,
-        observation_space: gym.Space = None,
-        action_space: gym.Space = None
+        observation_space: Optional[gym.Space] = None,
+        action_space: Optional[gym.Space] = None,
+        action_scaling: bool = False,
+        action_bound_method: str = "",
     ) -> None:
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
         self.agent_id = 0
         self.updating = False
+        self.action_scaling = action_scaling
+        # can be one of ("clip", "tanh", ""), empty string means no bounding
+        assert action_bound_method in ("", "clip", "tanh")
+        self.action_bound_method = action_bound_method
         self._compile()
 
     def set_agent_id(self, agent_id: int) -> None:
@@ -114,6 +120,38 @@ class BasePolicy(ABC, nn.Module):
         """
         pass
 
+    def map_action(self, act: Union[Batch, np.ndarray]) -> Union[Batch, np.ndarray]:
+        """Map raw network output to action range in gym's env.action_space.
+
+        This function is called in :meth:`~tianshou.data.Collector.collect` and only
+        affects action sending to env. Remapped action will not be stored in buffer
+        and thus can be viewed as a part of env (a black box action transformation).
+
+        Action mapping includes 2 standard procedures: bounding and scaling. Bounding
+        procedure expects original action range is (-inf, inf) and maps it to [-1, 1],
+        while scaling procedure expects original action range is (-1, 1) and maps it
+        to [action_space.low, action_space.high]. Bounding procedure is applied first.
+
+        :param act: a data batch or numpy.ndarray which is the action taken by
+            policy.forward.
+
+        :return: action in the same form of input "act" but remap to the target action
+            space.
+        """
+        if isinstance(self.action_space, gym.spaces.Box) and \
+                isinstance(act, np.ndarray):
+            # currently this action mapping only supports np.ndarray action
+            if self.action_bound_method == "clip":
+                act = np.clip(act, -1.0, 1.0)
+            elif self.action_bound_method == "tanh":
+                act = np.tanh(act)
+            if self.action_scaling:
+                assert np.all(act >= -1.0) and np.all(act <= 1.0), \
+                    "action scaling only accepts raw action range = [-1, 1]"
+                low, high = self.action_space.low, self.action_space.high
+                act = low + (high - low) * (act + 1.0) / 2.0
+        return act
+
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
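
As a quick illustration of the bounding and scaling steps documented above (the numbers and helper below are made up for illustration, not part of the commit): with "clip" bounding and scaling into a Box(low=0, high=10) space, a raw output of 0.3 becomes 6.5, and an out-of-range 3.7 saturates at 10.

import numpy as np

def map_action_sketch(act, low, high, bound="clip", scale=True):
    # standalone re-statement of the Box branch of BasePolicy.map_action
    if bound == "clip":
        act = np.clip(act, -1.0, 1.0)
    elif bound == "tanh":
        act = np.tanh(act)
    if scale:
        act = low + (high - low) * (act + 1.0) / 2.0
    return act

print(map_action_sketch(np.array([0.3]), 0.0, 10.0))  # [6.5]
print(map_action_sketch(np.array([3.7]), 0.0, 10.0))  # [10.]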
@@ -31,6 +31,13 @@ class A2CPolicy(PGPolicy):
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
         Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -16,8 +16,6 @@ class DDPGPolicy(BasePolicy):
     :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
     :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic_optim: the optimizer for critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param BaseNoise exploration_noise: the exploration noise,
@@ -25,6 +23,13 @@ class DDPGPolicy(BasePolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         Default to False.
     :param int estimation_step: the number of steps to look ahead. Default to 1.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -38,15 +43,17 @@ class DDPGPolicy(BasePolicy):
         actor_optim: Optional[torch.optim.Optimizer],
         critic: Optional[torch.nn.Module],
         critic_optim: Optional[torch.optim.Optimizer],
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
         reward_normalization: bool = False,
         estimation_step: int = 1,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if actor is not None and actor_optim is not None:
             self.actor: torch.nn.Module = actor
             self.actor_old = deepcopy(actor)
@@ -62,9 +69,6 @@ class DDPGPolicy(BasePolicy):
         assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
         self._gamma = gamma
         self._noise = exploration_noise
-        self._range = action_range
-        self._action_bias = (action_range[0] + action_range[1]) / 2.0
-        self._action_scale = (action_range[1] - action_range[0]) / 2.0
         # it is only a little difference to use GaussianNoise
         # self.noise = OUNoise()
         self._rew_norm = reward_normalization
@@ -128,8 +132,6 @@ class DDPGPolicy(BasePolicy):
         model = getattr(self, model)
         obs = batch[input]
         actions, h = model(obs, state=state, info=batch.info)
-        actions += self._action_bias
-        actions = actions.clamp(self._range[0], self._range[1])
         return Batch(act=actions, state=h)
 
     @staticmethod
@@ -168,5 +170,4 @@ class DDPGPolicy(BasePolicy):
     def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
         if self._noise:
             act = act + self._noise(act.shape)
-            act = act.clip(self._range[0], self._range[1])
         return act
@@ -49,10 +49,10 @@ class DiscreteSACPolicy(SACPolicy):
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
-                         critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
-                         reward_normalization, estimation_step,
-                         **kwargs)
+        super().__init__(
+            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+            tau, gamma, alpha, reward_normalization, estimation_step,
+            action_scaling=False, action_bound_method="", **kwargs)
         self._alpha: Union[float, torch.Tensor]
 
     def forward(  # type: ignore
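
DiscreteSACPolicy can pass action_scaling=False and action_bound_method="" because map_action (added above) only transforms numpy actions in a gym.spaces.Box space; with a Discrete space the integer action passes through unchanged. A small check of that guard, with a space and action made up for illustration:

import gym
import numpy as np

space = gym.spaces.Discrete(4)
act = np.array([2])
# the Box-only guard in map_action means no bounding/scaling applies here
print(isinstance(space, gym.spaces.Box), act)  # False [2]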
@@ -15,6 +15,13 @@ class PGPolicy(BasePolicy):
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -29,9 +36,12 @@ class PGPolicy(BasePolicy):
         dist_fn: Type[torch.distributions.Distribution],
         discount_factor: float = 0.99,
         reward_normalization: bool = False,
+        action_scaling: bool = True,
+        action_bound_method: str = "clip",
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(action_scaling=action_scaling,
+                         action_bound_method=action_bound_method, **kwargs)
         if model is not None:
             self.model: torch.nn.Module = model
         self.optim = optim
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from torch import nn
-from typing import Any, Dict, List, Type, Tuple, Union, Optional
+from typing import Any, Dict, List, Type, Union, Optional
 
 from tianshou.policy import PGPolicy
 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
@@ -23,8 +23,6 @@ class PPOPolicy(PGPolicy):
         paper. Default to 0.2.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: (float, float)
     :param float gae_lambda: in [0, 1], param for Generalized Advantage
         Estimation. Default to 0.95.
     :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
@@ -38,6 +36,13 @@ class PPOPolicy(PGPolicy):
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
         Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -56,7 +61,6 @@ class PPOPolicy(PGPolicy):
         eps_clip: float = 0.2,
         vf_coef: float = 0.5,
         ent_coef: float = 0.01,
-        action_range: Optional[Tuple[float, float]] = None,
         gae_lambda: float = 0.95,
         dual_clip: Optional[float] = None,
         value_clip: bool = True,
@@ -69,7 +73,6 @@ class PPOPolicy(PGPolicy):
         self._eps_clip = eps_clip
         self._weight_vf = vf_coef
         self._weight_ent = ent_coef
-        self._range = action_range
         self.actor = actor
         self.critic = critic
         self._batch = max_batchsize
@@ -135,8 +138,6 @@ class PPOPolicy(PGPolicy):
         else:
             dist = self.dist_fn(logits)
         act = dist.sample()
-        if self._range:
-            act = act.clamp(self._range[0], self._range[1])
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
     def learn(  # type: ignore
@@ -6,7 +6,7 @@ from typing import Any, Dict, Tuple, Union, Optional
 
 from tianshou.policy import DDPGPolicy
 from tianshou.exploration import BaseNoise
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class SACPolicy(DDPGPolicy):
@@ -21,8 +21,6 @@ class SACPolicy(DDPGPolicy):
     :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic2_optim: the optimizer for the second
         critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
@@ -36,6 +34,13 @@ class SACPolicy(DDPGPolicy):
     :param bool deterministic_eval: whether to use deterministic action (mean
         of Gaussian policy) instead of stochastic action sampled by the policy.
         Default to True.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "tanh".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -51,7 +56,6 @@ class SACPolicy(DDPGPolicy):
         critic1_optim: torch.optim.Optimizer,
         critic2: torch.nn.Module,
         critic2_optim: torch.optim.Optimizer,
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
@@ -59,11 +63,13 @@ class SACPolicy(DDPGPolicy):
         estimation_step: int = 1,
         exploration_noise: Optional[BaseNoise] = None,
         deterministic_eval: bool = True,
+        action_bound_method: str = "tanh",
         **kwargs: Any,
     ) -> None:
-        super().__init__(None, None, None, None, action_range, tau, gamma,
-                         exploration_noise, reward_normalization,
-                         estimation_step, **kwargs)
+        super().__init__(
+            None, None, None, None, tau, gamma, exploration_noise,
+            reward_normalization, estimation_step,
+            action_bound_method=action_bound_method, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -110,24 +116,26 @@ class SACPolicy(DDPGPolicy):
         assert isinstance(logits, tuple)
         dist = Independent(Normal(*logits), 1)
         if self._deterministic_eval and not self.training:
-            x = logits[0]
+            act = logits[0]
         else:
-            x = dist.rsample()
-        y = torch.tanh(x)
-        act = y * self._action_scale + self._action_bias
-        # __eps is used to avoid log of zero/negative number.
-        y = self._action_scale * (1 - y.pow(2)) + self.__eps
-        # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
-        # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
-        # in appendix C to get some understanding of this equation.
-        log_prob = dist.log_prob(x).unsqueeze(-1)
-        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
-
+            act = dist.rsample()
+        log_prob = dist.log_prob(act).unsqueeze(-1)
+        if self.action_bound_method == "tanh" and self.action_space is not None:
+            # apply correction for Tanh squashing when computing logprob from Gaussian
+            # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+            # in appendix C to get some understanding of this equation.
+            if self.action_scaling:
+                action_scale = to_torch_as(
+                    (self.action_space.high - self.action_space.low) / 2.0, act)
+            else:
+                action_scale = 1.0  # type: ignore
+            squashed_action = torch.tanh(act)
+            log_prob = log_prob - torch.log(
+                action_scale * (1 - squashed_action.pow(2)) + self.__eps
+            ).sum(-1, keepdim=True)
         return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
 
-    def _target_q(
-        self, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> torch.Tensor:
+    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         obs_next_result = self(batch, input='obs_next')
         a_ = obs_next_result.act
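
The rewritten forward keeps the action as the raw Gaussian sample and defers the actual squashing/scaling to map_action, so only the log-probability gets the tanh Jacobian correction: for y = s * tanh(x), log p(y) = log p(x) - sum log(s * (1 - tanh(x)^2)). A minimal torch sketch of that correction, with a hypothetical scale s = (high - low) / 2 = 2:

import torch
from torch.distributions import Normal, Independent

eps = torch.finfo(torch.float32).eps
dist = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
x = dist.rsample()                # raw (pre-squash) action sample
scale = torch.full((3,), 2.0)     # hypothetical (high - low) / 2
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(
    scale * (1 - torch.tanh(x).pow(2)) + eps
).sum(-1, keepdim=True)
print(log_prob.shape)  # torch.Size([1])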
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from copy import deepcopy
-from typing import Any, Dict, Tuple, Optional
+from typing import Any, Dict, Optional
 
 from tianshou.policy import DDPGPolicy
 from tianshou.data import Batch, ReplayBuffer
@@ -20,8 +20,6 @@ class TD3Policy(DDPGPolicy):
     :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
     :param torch.optim.Optimizer critic2_optim: the optimizer for the second
         critic network.
-    :param action_range: the action range (minimum, maximum).
-    :type action_range: Tuple[float, float]
     :param float tau: param for soft update of the target network. Default to 0.005.
     :param float gamma: discount factor, in [0, 1]. Default to 0.99.
     :param float exploration_noise: the exploration noise, add to the action.
@@ -34,6 +32,13 @@ class TD3Policy(DDPGPolicy):
         Default to 0.5.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
 
     .. seealso::
 
@@ -49,7 +54,6 @@ class TD3Policy(DDPGPolicy):
         critic1_optim: torch.optim.Optimizer,
         critic2: torch.nn.Module,
         critic2_optim: torch.optim.Optimizer,
-        action_range: Tuple[float, float],
         tau: float = 0.005,
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
@@ -60,7 +64,7 @@ class TD3Policy(DDPGPolicy):
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
+        super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, reward_normalization,
                          estimation_step, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
@@ -98,7 +102,6 @@ class TD3Policy(DDPGPolicy):
             if self._noise_clip > 0.0:
                 noise = noise.clamp(-self._noise_clip, self._noise_clip)
             a_ += noise
-        a_ = a_.clamp(self._range[0], self._range[1])
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_))