Add auto alpha tuning and exploration noise for SAC. (#80)

Add the BaseNoise and GaussianNoise classes to model exploration noise.
Add a new SAC test on MountainCarContinuous-v0,
which should benefit from the two new features above.
danagi 2020-06-16 22:17:28 +08:00 committed by GitHub
parent 263e490b76
commit c59ad40aef
8 changed files with 255 additions and 49 deletions
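
For orientation, a minimal sketch of the user-facing change (not part of the commit): the off-policy policies now take a noise object instead of a float standard deviation. `actor`, `critic`, their optimizers and `env` are placeholders assumed to already exist.

from tianshou.policy import DDPGPolicy
from tianshou.exploration import GaussianNoise

# before this commit: DDPGPolicy(..., exploration_noise=0.1, ...)
# after: pass a BaseNoise instance; GaussianNoise(sigma=0.1) matches the old default behavior
policy = DDPGPolicy(
    actor, actor_optim, critic, critic_optim,          # assumed to exist
    tau=0.005, gamma=0.99,
    exploration_noise=GaussianNoise(sigma=0.1),
    action_range=[env.action_space.low[0], env.action_space.high[0]])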

View File

@@ -10,6 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import DDPGPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.exploration import GaussianNoise

 if __name__ == '__main__':
     from net import Actor, Critic
@@ -78,7 +79,7 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, args.exploration_noise,
+        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,

View File

@@ -0,0 +1,130 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.exploration import OUNoise

if __name__ == '__main__':
    from net import ActorProb, Critic
else:  # pytest
    from test.continuous.net import ActorProb, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='MountainCarContinuous-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=50000)
    parser.add_argument('--actor-lr', type=float, default=3e-4)
    parser.add_argument('--critic-lr', type=float, default=3e-4)
    parser.add_argument('--alpha-lr', type=float, default=3e-4)
    parser.add_argument('--noise_std', type=float, default=0.5)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--auto_alpha', type=bool, default=True)
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--step-per-epoch', type=int, default=2400)
    parser.add_argument('--collect-per-step', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--layer-num', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=80)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=1.0/35.0)
    parser.add_argument('--rew-norm', type=bool, default=False)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_sac(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    actor = ActorProb(
        args.layer_num, args.state_shape, args.action_shape,
        args.max_action, args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    critic1 = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        alpha = (target_entropy, log_alpha, alpha_optim)
    else:
        alpha = args.alpha

    policy = SACPolicy(
        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
        args.tau, args.gamma, alpha,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=args.rew_norm, ignore_done=True,
        exploration_noise=OUNoise(0.0, args.noise_std))
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.buffer_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'sac')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
    test_sac()

View File

@@ -10,6 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import TD3Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.exploration import GaussianNoise

 if __name__ == '__main__':
     from net import Actor, Critic
@@ -85,8 +86,8 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.exploration_noise, args.policy_noise,
-        args.update_actor_freq, args.noise_clip,
+        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
+        args.policy_noise, args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,

View File

@@ -1,5 +1,7 @@
-from tianshou.exploration.random import OUNoise
+from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise

 __all__ = [
+    'BaseNoise',
+    'GaussianNoise',
     'OUNoise',
 ]

View File

@@ -1,8 +1,42 @@
 import numpy as np
 from typing import Union, Optional
+from abc import ABC, abstractmethod


-class OUNoise(object):
+class BaseNoise(ABC, object):
+    """The action noise base class."""
+
+    def __init__(self, **kwargs) -> None:
+        super(BaseNoise, self).__init__()
+
+    @abstractmethod
+    def __call__(self, **kwargs) -> np.ndarray:
+        """Generate new noise."""
+        raise NotImplementedError
+
+    def reset(self, **kwargs) -> None:
+        """Reset to the initial state."""
+        pass
+
+
+class GaussianNoise(BaseNoise):
+    """Class for vanilla gaussian process,
+    used for exploration in DDPG by default.
+    """
+
+    def __init__(self,
+                 mu: float = 0.0,
+                 sigma: float = 1.0):
+        super().__init__()
+        self._mu = mu
+        assert 0 <= sigma, 'noise std should not be negative'
+        self._sigma = sigma
+
+    def __call__(self, size: tuple) -> np.ndarray:
+        return np.random.normal(self._mu, self._sigma, size)
+
+
+class OUNoise(BaseNoise):
     """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.

     Usage:
     ::
@@ -19,26 +53,31 @@ class OUNoise(object):
     """

     def __init__(self,
+                 mu: float = 0.0,
                  sigma: float = 0.3,
                  theta: float = 0.15,
                  dt: float = 1e-2,
                  x0: Optional[Union[float, np.ndarray]] = None
                  ) -> None:
-        self.alpha = theta * dt
-        self.beta = sigma * np.sqrt(dt)
-        self.x0 = x0
+        super(BaseNoise, self).__init__()
+        self._mu = mu
+        self._alpha = theta * dt
+        self._beta = sigma * np.sqrt(dt)
+        self._x0 = x0
         self.reset()

-    def __call__(self, size: tuple, mu: float = .1) -> np.ndarray:
+    def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray:
         """Generate new noise. Return a ``numpy.ndarray`` which size is equal
         to ``size``.
         """
-        if self.x is None or self.x.shape != size:
-            self.x = 0
-        r = self.beta * np.random.normal(size=size)
-        self.x = self.x + self.alpha * (mu - self.x) + r
-        return self.x
+        if self._x is None or self._x.shape != size:
+            self._x = 0
+        if mu is None:
+            mu = self._mu
+        r = self._beta * np.random.normal(size=size)
+        self._x = self._x + self._alpha * (mu - self._x) + r
+        return self._x

     def reset(self) -> None:
         """Reset to the initial state."""
-        self.x = None
+        self._x = self._x0
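
A brief usage sketch of the new noise classes (illustration only, not part of the diff); the shapes below are arbitrary:

from tianshou.exploration import GaussianNoise, OUNoise

g = GaussianNoise(mu=0.0, sigma=0.1)
print(g((2, 3)).shape)   # (2, 3): i.i.d. samples from N(0.0, 0.1**2)

ou = OUNoise(mu=0.0, sigma=0.3)
a = ou((2, 3))           # first call starts the process from x = 0
b = ou((2, 3))           # later calls are correlated through the internal state
ou.reset()               # drop the state back to x0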

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from typing import Dict, Tuple, Union, Optional

 from tianshou.policy import BasePolicy
-# from tianshou.exploration import OUNoise
+from tianshou.exploration import BaseNoise, GaussianNoise
 from tianshou.data import Batch, ReplayBuffer, to_torch_as

@@ -21,8 +21,8 @@ class DDPGPolicy(BasePolicy):
     :param float tau: param for soft update of the target network, defaults to
         0.005.
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
-    :param float exploration_noise: the noise intensity, add to the action,
-        defaults to 0.1.
+    :param BaseNoise exploration_noise: the exploration noise,
+        add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
     :param action_range: the action range (minimum, maximum).
     :type action_range: (float, float)
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
@@ -45,7 +45,8 @@ class DDPGPolicy(BasePolicy):
                  critic_optim: torch.optim.Optimizer,
                  tau: float = 0.005,
                  gamma: float = 0.99,
-                 exploration_noise: float = 0.1,
+                 exploration_noise: Optional[BaseNoise]
+                 = GaussianNoise(sigma=0.1),
                  action_range: Optional[Tuple[float, float]] = None,
                  reward_normalization: bool = False,
                  ignore_done: bool = False,
@@ -64,8 +65,7 @@ class DDPGPolicy(BasePolicy):
         self._tau = tau
         assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
         self._gamma = gamma
-        assert 0 <= exploration_noise, 'noise should not be negative'
-        self._eps = exploration_noise
+        self._noise = exploration_noise
         assert action_range is not None
         self._range = action_range
         self._action_bias = (action_range[0] + action_range[1]) / 2
@@ -77,9 +77,9 @@ class DDPGPolicy(BasePolicy):
         assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step

-    def set_eps(self, eps: float) -> None:
-        """Set the eps for exploration."""
-        self._eps = eps
+    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
+        """Set the exploration noise."""
+        self._noise = noise

     def train(self) -> None:
         """Set the module in training mode, except for the target network."""
@@ -106,7 +106,8 @@ class DDPGPolicy(BasePolicy):
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
         with torch.no_grad():
             target_q = self.critic_old(batch.obs_next, self(
-                batch, model='actor_old', input='obs_next', eps=0).act)
+                batch, model='actor_old', input='obs_next',
+                explorating=False).act)
         return target_q

     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
@@ -122,7 +123,7 @@ class DDPGPolicy(BasePolicy):
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
                 model: str = 'actor',
                 input: str = 'obs',
-                eps: Optional[float] = None,
+                explorating: bool = True,
                 **kwargs) -> Batch:
         """Compute action over the given batch data.
@@ -142,14 +143,8 @@ class DDPGPolicy(BasePolicy):
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
         logits += self._action_bias
-        if eps is None:
-            eps = self._eps
-        if eps > 0:
-            # noise = np.random.normal(0, eps, size=logits.shape)
-            # logits += to_torch(noise, device=logits.device)
-            # noise = self.noise(logits.shape, eps)
-            logits += torch.randn(
-                size=logits.shape, device=logits.device) * eps
+        if self.training and explorating:
+            logits += to_torch_as(self._noise(logits.shape), logits)
         logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)

@@ -161,7 +156,8 @@ class DDPGPolicy(BasePolicy):
         self.critic_optim.zero_grad()
         critic_loss.backward()
         self.critic_optim.step()
-        actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
+        action = self(batch, explorating=False).act
+        actor_loss = -self.critic(batch.obs, action).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
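
A small sketch of the new exploration interface on DDPGPolicy (not part of the diff; `policy` and `batch` are assumed to be an existing policy and data batch):

from tianshou.exploration import GaussianNoise

# the noise object can be swapped at any time, e.g. to anneal sigma between epochs
policy.set_exp_noise(GaussianNoise(sigma=0.05))

# noise is only applied while the policy is in training mode and the
# `explorating` flag keeps its default of True; this returns noise-free actions:
act = policy(batch, explorating=False).act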

View File

@@ -7,6 +7,7 @@ from typing import Dict, Tuple, Union, Optional
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.dist import DiagGaussian
 from tianshou.data import Batch, to_torch_as, ReplayBuffer
+from tianshou.exploration import BaseNoise


 class SACPolicy(DDPGPolicy):
@@ -28,13 +29,18 @@ class SACPolicy(DDPGPolicy):
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
     :param float exploration_noise: the noise intensity, add to the action,
         defaults to 0.1.
-    :param float alpha: entropy regularization coefficient, default to 0.2.
+    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
+        regularization coefficient, default to 0.2.
+        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
+        alpha is automatically tuned.
     :param action_range: the action range (minimum, maximum).
     :type action_range: (float, float)
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+    :param BaseNoise exploration_noise: add a noise to action for exploration.
+        This is useful when solving hard-exploration problems.

     .. seealso::
@@ -51,13 +57,15 @@ class SACPolicy(DDPGPolicy):
                  critic2_optim: torch.optim.Optimizer,
                  tau: float = 0.005,
                  gamma: float = 0.99,
-                 alpha: float = 0.2,
+                 alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
+                 or float = 0.2,
                  action_range: Optional[Tuple[float, float]] = None,
                  reward_normalization: bool = False,
                  ignore_done: bool = False,
                  estimation_step: int = 1,
+                 exploration_noise: Optional[BaseNoise] = None,
                  **kwargs) -> None:
-        super().__init__(None, None, None, None, tau, gamma, 0,
+        super().__init__(None, None, None, None, tau, gamma, exploration_noise,
                          action_range, reward_normalization, ignore_done,
                          estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
@@ -67,7 +75,18 @@ class SACPolicy(DDPGPolicy):
         self.critic2, self.critic2_old = critic2, deepcopy(critic2)
         self.critic2_old.eval()
         self.critic2_optim = critic2_optim
-        self._alpha = alpha
+
+        self._automatic_alpha_tuning = not isinstance(alpha, float)
+        if self._automatic_alpha_tuning:
+            self._target_entropy = alpha[0]
+            assert(alpha[1].shape == torch.Size([1])
+                   and alpha[1].requires_grad)
+            self._log_alpha = alpha[1]
+            self._alpha_optim = alpha[2]
+            self._alpha = self._log_alpha.exp()
+        else:
+            self._alpha = alpha
+
         self.__eps = np.finfo(np.float32).eps.item()

     def train(self) -> None:
@@ -92,7 +111,9 @@ class SACPolicy(DDPGPolicy):
     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
-                input: str = 'obs', **kwargs) -> Batch:
+                input: str = 'obs',
+                explorating: bool = True,
+                **kwargs) -> Batch:
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
@@ -100,9 +121,10 @@ class SACPolicy(DDPGPolicy):
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
-        log_prob = dist.log_prob(x) - torch.log(
-            self._action_scale * (1 - y.pow(2)) + self.__eps
-        ).sum(-1, keepdim=True)
+        y = self._action_scale * (1 - y.pow(2)) + self.__eps
+        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
+        if self._noise is not None and self.training and explorating:
+            act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
@@ -111,7 +133,7 @@ class SACPolicy(DDPGPolicy):
                   indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
-            obs_next_result = self(batch, input='obs_next')
+            obs_next_result = self(batch, input='obs_next', explorating=False)
             a_ = obs_next_result.act
             batch.act = to_torch_as(batch.act, a_)
             target_q = torch.min(
@@ -135,7 +157,7 @@ class SACPolicy(DDPGPolicy):
         critic2_loss.backward()
         self.critic2_optim.step()
         # actor
-        obs_result = self(batch)
+        obs_result = self(batch, explorating=False)
         a = obs_result.act
         current_q1a = self.critic1(batch.obs, a)
         current_q2a = self.critic2(batch.obs, a)
@@ -144,9 +166,22 @@ class SACPolicy(DDPGPolicy):
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
+
+        if self._automatic_alpha_tuning:
+            log_prob = (obs_result.log_prob + self._target_entropy).detach()
+            alpha_loss = -(self._log_alpha * log_prob).mean()
+            self._alpha_optim.zero_grad()
+            alpha_loss.backward()
+            self._alpha_optim.step()
+            self._alpha = self._log_alpha.exp()
+
         self.sync_weight()
-        return {
+
+        result = {
             'loss/actor': actor_loss.item(),
             'loss/critic1': critic1_loss.item(),
             'loss/critic2': critic2_loss.item(),
         }
+        if self._automatic_alpha_tuning:
+            result['loss/alpha'] = alpha_loss.item()
+        return result
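
A condensed sketch of how the tuple form of ``alpha`` is built (mirroring the test above, not additional code from the commit); `env`, the networks and the learning rate are placeholders:

import torch
import numpy as np
from tianshou.policy import SACPolicy
from tianshou.exploration import OUNoise

target_entropy = -np.prod(env.action_space.shape)   # env assumed to exist
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

policy = SACPolicy(
    actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,  # assumed to exist
    tau=0.005, gamma=0.99,
    alpha=(target_entropy, log_alpha, alpha_optim),
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    exploration_noise=OUNoise(0.0, 0.5))

# learn() then updates log_alpha with
#   alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
# so alpha grows while the policy's entropy stays below the target.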

View File

@@ -6,6 +6,7 @@ from typing import Dict, Tuple, Optional
 from tianshou.policy import DDPGPolicy
 from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise, GaussianNoise


 class TD3Policy(DDPGPolicy):
@@ -26,8 +27,8 @@ class TD3Policy(DDPGPolicy):
     :param float tau: param for soft update of the target network, defaults to
         0.005.
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
-    :param float exploration_noise: the noise intensity, add to the action,
-        defaults to 0.1.
+    :param float exploration_noise: the exploration noise, add to the action,
+        defaults to ``GaussianNoise(sigma=0.1)``
     :param float policy_noise: the noise used in updating policy network,
         default to 0.2.
     :param int update_actor_freq: the update frequency of actor network,
@@ -56,7 +57,8 @@ class TD3Policy(DDPGPolicy):
                  critic2_optim: torch.optim.Optimizer,
                  tau: float = 0.005,
                  gamma: float = 0.99,
-                 exploration_noise: float = 0.1,
+                 exploration_noise: Optional[BaseNoise]
+                 = GaussianNoise(sigma=0.1),
                  policy_noise: float = 0.2,
                  update_actor_freq: int = 2,
                  noise_clip: float = 0.5,