Remap action to fit gym's action space (#313)

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
ChenDRAG 2021-03-21 16:45:50 +08:00 committed by GitHub
parent 0c7117dd55
commit 4d92952a7b
21 changed files with 145 additions and 77 deletions


@ -117,9 +117,8 @@ def test_sac_bipedal(args=get_args()):
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
estimation_step=args.n_step)
estimation_step=args.n_step, action_space=env.action_space)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path))


@ -90,10 +90,10 @@ def test_sac(args=get_args()):
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm,
exploration_noise=OUNoise(0.0, args.noise_std))
exploration_noise=OUNoise(0.0, args.noise_std),
action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -84,10 +84,9 @@ def test_ddpg(args=get_args()):
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
estimation_step=args.n_step)
estimation_step=args.n_step, action_space=env.action_space)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(


@ -97,9 +97,8 @@ def test_sac(args=get_args()):
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
estimation_step=args.n_step)
estimation_step=args.n_step, action_space=env.action_space)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(


@ -95,11 +95,11 @@ def test_td3(args=get_args()):
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip, estimation_step=args.n_step)
noise_clip=args.noise_clip, estimation_step=args.n_step,
action_space=env.action_space)
# load a previous policy
if args.resume_path:


@ -77,11 +77,10 @@ def test_ddpg(args=get_args()):
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=args.rew_norm,
estimation_step=args.n_step)
estimation_step=args.n_step, action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -100,9 +100,8 @@ def test_ppo(args=get_args()):
# dual_clip=args.dual_clip,
# dual clip cause monotonically increasing log_std :)
value_clip=args.value_clip,
# action_range=[env.action_space.low[0], env.action_space.high[0]],)
# if clip the action, ppo would not converge :)
gae_lambda=args.gae_lambda)
gae_lambda=args.gae_lambda,
action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -87,10 +87,9 @@ def test_sac_with_il(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm,
estimation_step=args.n_step)
estimation_step=args.n_step, action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -86,14 +86,14 @@ def test_td3(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=args.rew_norm,
estimation_step=args.n_step)
estimation_step=args.n_step,
action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -80,7 +80,8 @@ def test_a2c_with_il(args=get_args()):
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -63,7 +63,8 @@ def test_pg(args=get_args()):
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
dist = torch.distributions.Categorical
policy = PGPolicy(net, optim, dist, args.gamma,
reward_normalization=args.rew_norm)
reward_normalization=args.rew_norm,
action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -85,11 +85,11 @@ def test_ppo(args=get_args()):
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
action_range=None,
gae_lambda=args.gae_lambda,
reward_normalization=args.rew_norm,
dual_clip=args.dual_clip,
value_clip=args.value_clip)
value_clip=args.value_clip,
action_space=env.action_space)
# collector
train_collector = Collector(
policy, train_envs,


@ -219,8 +219,10 @@ class Collector(object):
act = self.policy.exploration_noise(act, self.data)
self.data.update(policy=policy, act=act)
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if self.preprocess_fn:
@ -419,8 +421,10 @@ class AsyncCollector(Collector):
_alloc_by_keys_diff(whole_data, self.data, self.env_num, False)
whole_data[ready_env_ids] = self.data # lots of overhead
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids)
obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
# change self.data here because ready_env_ids has changed
ready_env_ids = np.array([i["env_id"] for i in info])
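
In both collectors the ordering is now: the raw (possibly noisy) action produced by the policy is what gets stored in the buffer, while a bounded and rescaled copy from map_action is what actually reaches the environment. A rough sketch of that ordering, assuming that only exploration_noise, map_action and env.step are real API calls and everything else is illustrative rather than the actual Collector internals:

def collect_step(policy, env, data, ready_env_ids):
    # raw network output plus exploration noise; this is what the buffer sees
    act = policy.exploration_noise(data.act, data)
    data.act = act
    # bounded + rescaled copy, computed just before stepping and never stored
    action_remap = policy.map_action(act)
    # only the remapped action is sent to the (vectorized) env
    return env.step(action_remap, id=ready_env_ids)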


@ -53,14 +53,20 @@ class BasePolicy(ABC, nn.Module):
def __init__(
self,
observation_space: gym.Space = None,
action_space: gym.Space = None
observation_space: Optional[gym.Space] = None,
action_space: Optional[gym.Space] = None,
action_scaling: bool = False,
action_bound_method: str = "",
) -> None:
super().__init__()
self.observation_space = observation_space
self.action_space = action_space
self.agent_id = 0
self.updating = False
self.action_scaling = action_scaling
# can be one of ("clip", "tanh", ""), empty string means no bounding
assert action_bound_method in ("", "clip", "tanh")
self.action_bound_method = action_bound_method
self._compile()
def set_agent_id(self, agent_id: int) -> None:
@ -114,6 +120,38 @@ class BasePolicy(ABC, nn.Module):
"""
pass
def map_action(self, act: Union[Batch, np.ndarray]) -> Union[Batch, np.ndarray]:
"""Map raw network output to action range in gym's env.action_space.
This function is called in :meth:`~tianshou.data.Collector.collect` and only
affects action sending to env. Remapped action will not be stored in buffer
and thus can be viewed as a part of env (a black box action transformation).
Action mapping includes 2 standard procedures: bounding and scaling. Bounding
procedure expects original action range is (-inf, inf) and maps it to [-1, 1],
while scaling procedure expects original action range is (-1, 1) and maps it
to [action_space.low, action_space.high]. Bounding procedure is applied first.
:param act: a data batch or numpy.ndarray which is the action taken by
policy.forward.
:return: action in the same form of input "act" but remap to the target action
space.
"""
if isinstance(self.action_space, gym.spaces.Box) and \
isinstance(act, np.ndarray):
# currently this action mapping only supports np.ndarray action
if self.action_bound_method == "clip":
act = np.clip(act, -1.0, 1.0)
elif self.action_bound_method == "tanh":
act = np.tanh(act)
if self.action_scaling:
assert np.all(act >= -1.0) and np.all(act <= 1.0), \
"action scaling only accepts raw action range = [-1, 1]"
low, high = self.action_space.low, self.action_space.high
act = low + (high - low) * (act + 1.0) / 2.0
return act
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
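
As a standalone illustration of what map_action does with the default continuous-control settings (action_bound_method="clip", action_scaling=True), here is a minimal sketch over a hypothetical two-dimensional Box space; the helper below mirrors the logic above but is not Tianshou code:

import gym
import numpy as np

space = gym.spaces.Box(low=np.array([0.0, -2.0]), high=np.array([1.0, 2.0]),
                       dtype=np.float32)

def map_action_sketch(act: np.ndarray) -> np.ndarray:
    act = np.clip(act, -1.0, 1.0)                  # bounding: (-inf, inf) -> [-1, 1]
    low, high = space.low, space.high
    return low + (high - low) * (act + 1.0) / 2.0  # scaling: [-1, 1] -> [low, high]

raw = np.array([3.0, -0.5])    # raw actor output; first dim lies outside [-1, 1]
print(map_action_sketch(raw))  # [ 1. -1.] -- clipped, then rescaled per dimension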


@ -31,6 +31,13 @@ class A2CPolicy(PGPolicy):
depends on the size of available memory and the memory cost of the
model; should be as large as possible within the memory constraint.
Default to 256.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_space.low, action_space.high]. Default to True.
:param str action_bound_method: method to bound the action to range [-1, 1]; can be
"clip" (simply clip the action), "tanh" (apply tanh squashing), or an empty
string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: the env's action space; mandatory if you
want to use "action_scaling" or "action_bound_method". Default to None.
.. seealso::


@ -16,8 +16,6 @@ class DDPGPolicy(BasePolicy):
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic_optim: the optimizer for critic network.
:param action_range: the action range (minimum, maximum).
:type action_range: Tuple[float, float]
:param float tau: param for soft update of the target network. Default to 0.005.
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
:param BaseNoise exploration_noise: the exploration noise,
@ -25,6 +23,13 @@ class DDPGPolicy(BasePolicy):
:param bool reward_normalization: normalize the reward to Normal(0, 1),
Default to False.
:param int estimation_step: the number of steps to look ahead. Default to 1.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_space.low, action_space.high]. Default to True.
:param str action_bound_method: method to bound the action to range [-1, 1]; can be
"clip" (simply clip the action), "tanh" (apply tanh squashing), or an empty
string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: the env's action space; mandatory if you
want to use "action_scaling" or "action_bound_method". Default to None.
.. seealso::
@ -38,15 +43,17 @@ class DDPGPolicy(BasePolicy):
actor_optim: Optional[torch.optim.Optimizer],
critic: Optional[torch.nn.Module],
critic_optim: Optional[torch.optim.Optimizer],
action_range: Tuple[float, float],
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
reward_normalization: bool = False,
estimation_step: int = 1,
action_scaling: bool = True,
action_bound_method: str = "clip",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
super().__init__(action_scaling=action_scaling,
action_bound_method=action_bound_method, **kwargs)
if actor is not None and actor_optim is not None:
self.actor: torch.nn.Module = actor
self.actor_old = deepcopy(actor)
@ -62,9 +69,6 @@ class DDPGPolicy(BasePolicy):
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
self._gamma = gamma
self._noise = exploration_noise
self._range = action_range
self._action_bias = (action_range[0] + action_range[1]) / 2.0
self._action_scale = (action_range[1] - action_range[0]) / 2.0
# it is only a little difference to use GaussianNoise
# self.noise = OUNoise()
self._rew_norm = reward_normalization
@ -128,8 +132,6 @@ class DDPGPolicy(BasePolicy):
model = getattr(self, model)
obs = batch[input]
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
@staticmethod
@ -168,5 +170,4 @@ class DDPGPolicy(BasePolicy):
def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray:
if self._noise:
act = act + self._noise(act.shape)
act = act.clip(self._range[0], self._range[1])
return act
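
Two consequences of dropping action_range here are worth noting: exploration noise is no longer clipped inside the policy (bounding happens later in map_action, on a copy that never enters the buffer), and bounds are applied per dimension via action_space.low/high instead of the scalar [low[0], high[0]] pair the example scripts used to pass. A minimal sketch of the new noise path; the helper name is illustrative, not Tianshou code:

import numpy as np

def exploration_noise_sketch(act: np.ndarray, sigma: float = 0.1) -> np.ndarray:
    # the noisy action is returned (and later stored) without clipping; bounding
    # to [-1, 1] and rescaling to the Box bounds are deferred to map_action
    return act + np.random.normal(0.0, sigma, size=act.shape)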


@ -49,10 +49,10 @@ class DiscreteSACPolicy(SACPolicy):
estimation_step: int = 1,
**kwargs: Any,
) -> None:
super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
reward_normalization, estimation_step,
**kwargs)
super().__init__(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
tau, gamma, alpha, reward_normalization, estimation_step,
action_scaling=False, action_bound_method="", **kwargs)
self._alpha: Union[float, torch.Tensor]
def forward( # type: ignore


@ -15,6 +15,13 @@ class PGPolicy(BasePolicy):
:param dist_fn: distribution class for computing the action.
:type dist_fn: Type[torch.distributions.Distribution]
:param float discount_factor: in [0, 1]. Default to 0.99.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_space.low, action_space.high]. Default to True.
:param str action_bound_method: method to bound the action to range [-1, 1]; can be
"clip" (simply clip the action), "tanh" (apply tanh squashing), or an empty
string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: the env's action space; mandatory if you
want to use "action_scaling" or "action_bound_method". Default to None.
.. seealso::
@ -29,9 +36,12 @@ class PGPolicy(BasePolicy):
dist_fn: Type[torch.distributions.Distribution],
discount_factor: float = 0.99,
reward_normalization: bool = False,
action_scaling: bool = True,
action_bound_method: str = "clip",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
super().__init__(action_scaling=action_scaling,
action_bound_method=action_bound_method, **kwargs)
if model is not None:
self.model: torch.nn.Module = model
self.optim = optim


@ -1,7 +1,7 @@
import torch
import numpy as np
from torch import nn
from typing import Any, Dict, List, Type, Tuple, Union, Optional
from typing import Any, Dict, List, Type, Union, Optional
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
@ -23,8 +23,6 @@ class PPOPolicy(PGPolicy):
paper. Default to 0.2.
:param float vf_coef: weight for value loss. Default to 0.5.
:param float ent_coef: weight for entropy loss. Default to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param float gae_lambda: in [0, 1], param for Generalized Advantage
Estimation. Default to 0.95.
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
@ -38,6 +36,13 @@ class PPOPolicy(PGPolicy):
depends on the size of available memory and the memory cost of the
model; should be as large as possible within the memory constraint.
Default to 256.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_space.low, action_space.high]. Default to True.
:param str action_bound_method: method to bound the action to range [-1, 1]; can be
"clip" (simply clip the action), "tanh" (apply tanh squashing), or an empty
string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: the env's action space; mandatory if you
want to use "action_scaling" or "action_bound_method". Default to None.
.. seealso::
@ -56,7 +61,6 @@ class PPOPolicy(PGPolicy):
eps_clip: float = 0.2,
vf_coef: float = 0.5,
ent_coef: float = 0.01,
action_range: Optional[Tuple[float, float]] = None,
gae_lambda: float = 0.95,
dual_clip: Optional[float] = None,
value_clip: bool = True,
@ -69,7 +73,6 @@ class PPOPolicy(PGPolicy):
self._eps_clip = eps_clip
self._weight_vf = vf_coef
self._weight_ent = ent_coef
self._range = action_range
self.actor = actor
self.critic = critic
self._batch = max_batchsize
@ -135,8 +138,6 @@ class PPOPolicy(PGPolicy):
else:
dist = self.dist_fn(logits)
act = dist.sample()
if self._range:
act = act.clamp(self._range[0], self._range[1])
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn( # type: ignore


@ -6,7 +6,7 @@ from typing import Any, Dict, Tuple, Union, Optional
from tianshou.policy import DDPGPolicy
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer
from tianshou.data import Batch, ReplayBuffer, to_torch_as
class SACPolicy(DDPGPolicy):
@ -21,8 +21,6 @@ class SACPolicy(DDPGPolicy):
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param action_range: the action range (minimum, maximum).
:type action_range: Tuple[float, float]
:param float tau: param for soft update of the target network. Default to 0.005.
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
@ -36,6 +34,13 @@ class SACPolicy(DDPGPolicy):
:param bool deterministic_eval: whether to use deterministic action (mean
of Gaussian policy) instead of stochastic action sampled by the policy.
Default to True.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_space.low, action_space.high]. Default to True.
:param str action_bound_method: method to bound the action to range [-1, 1]; can be
"clip" (simply clip the action), "tanh" (apply tanh squashing), or an empty
string for no bounding. Default to "tanh".
:param Optional[gym.Space] action_space: the env's action space; mandatory if you
want to use "action_scaling" or "action_bound_method". Default to None.
.. seealso::
@ -51,7 +56,6 @@ class SACPolicy(DDPGPolicy):
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
action_range: Tuple[float, float],
tau: float = 0.005,
gamma: float = 0.99,
alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
@ -59,11 +63,13 @@ class SACPolicy(DDPGPolicy):
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
deterministic_eval: bool = True,
action_bound_method: str = "tanh",
**kwargs: Any,
) -> None:
super().__init__(None, None, None, None, action_range, tau, gamma,
exploration_noise, reward_normalization,
estimation_step, **kwargs)
super().__init__(
None, None, None, None, tau, gamma, exploration_noise,
reward_normalization, estimation_step,
action_bound_method=action_bound_method, **kwargs)
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
@ -110,24 +116,26 @@ class SACPolicy(DDPGPolicy):
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
if self._deterministic_eval and not self.training:
x = logits[0]
act = logits[0]
else:
x = dist.rsample()
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
# __eps is used to avoid log of zero/negative number.
y = self._action_scale * (1 - y.pow(2)) + self.__eps
# Compute logprob from Gaussian, and then apply correction for Tanh squashing.
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
if self.action_bound_method == "tanh" and self.action_space is not None:
# apply correction for Tanh squashing when computing logprob from Gaussian
# see Eq. 21 in Appendix C of the original SAC paper (arXiv 1801.01290)
# for the derivation of this correction term.
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
if self.action_scaling:
action_scale = to_torch_as(
(self.action_space.high - self.action_space.low) / 2.0, act)
else:
action_scale = 1.0 # type: ignore
squashed_action = torch.tanh(act)
log_prob = log_prob - torch.log(
action_scale * (1 - squashed_action.pow(2)) + self.__eps
).sum(-1, keepdim=True)
return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs: s_{t+n}
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
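
The rewritten forward keeps the pre-tanh Gaussian sample as act (squashing and rescaling are deferred to map_action) and corrects log_prob with the tanh change-of-variables term from Eq. 21 / Appendix C of the SAC paper. A small self-contained check of that correction, assuming a unit action scale and using plain torch distributions rather than Tianshou code:

import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

base = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
u = torch.tensor([0.3, -1.2, 2.0])        # pre-tanh sample, i.e. "act" above
squashed = torch.tanh(u)

# manual correction, mirroring the new SACPolicy.forward (action_scale = 1)
log_prob = base.log_prob(u) - torch.log(1.0 - squashed.pow(2) + 1e-6).sum(-1)

# reference: density of tanh(u) as computed by torch itself
ref = TransformedDistribution(base, [TanhTransform()])
print(log_prob, ref.log_prob(squashed))   # the two values agree up to the epsilon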


@ -1,7 +1,7 @@
import torch
import numpy as np
from copy import deepcopy
from typing import Any, Dict, Tuple, Optional
from typing import Any, Dict, Optional
from tianshou.policy import DDPGPolicy
from tianshou.data import Batch, ReplayBuffer
@ -20,8 +20,6 @@ class TD3Policy(DDPGPolicy):
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param action_range: the action range (minimum, maximum).
:type action_range: Tuple[float, float]
:param float tau: param for soft update of the target network. Default to 0.005.
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
:param float exploration_noise: the exploration noise, add to the action.
@ -34,6 +32,13 @@ class TD3Policy(DDPGPolicy):
Default to 0.5.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_space.low, action_space.high]. Default to True.
:param str action_bound_method: method to bound the action to range [-1, 1]; can be
"clip" (simply clip the action), "tanh" (apply tanh squashing), or an empty
string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: the env's action space; mandatory if you
want to use "action_scaling" or "action_bound_method". Default to None.
.. seealso::
@ -49,7 +54,6 @@ class TD3Policy(DDPGPolicy):
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
action_range: Tuple[float, float],
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
@ -60,7 +64,7 @@ class TD3Policy(DDPGPolicy):
estimation_step: int = 1,
**kwargs: Any,
) -> None:
super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, reward_normalization,
estimation_step, **kwargs)
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
@ -98,7 +102,6 @@ class TD3Policy(DDPGPolicy):
if self._noise_clip > 0.0:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise
a_ = a_.clamp(self._range[0], self._range[1])
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_))