Add auto alpha tuning and exploration noise for SAC. (#80)
Add the BaseNoise and GaussianNoise classes to represent exploration noise. Add a new SAC test on MountainCarContinuous-v0, which should benefit from both new features.
parent 263e490b76
commit c59ad40aef
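For orientation, a minimal, hedged sketch (not part of the commit) of the automatic alpha tuning interface introduced below: instead of a fixed float, SACPolicy can now be given a (target_entropy, log_alpha, alpha_optim) tuple. The construction mirrors the new test file; the action shape and learning rate are illustrative only.

import numpy as np
import torch

action_shape = (1,)                               # e.g. MountainCarContinuous-v0
target_entropy = -np.prod(action_shape)           # the usual -dim(A) target
log_alpha = torch.zeros(1, requires_grad=True)    # log(alpha) is what gets optimized
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = (target_entropy, log_alpha, alpha_optim)  # pass as the alpha argument of SACPolicy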
@@ -10,6 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import DDPGPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.exploration import GaussianNoise

 if __name__ == '__main__':
     from net import Actor, Critic
@@ -78,7 +79,7 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, args.exploration_noise,
+        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
         [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
test/continuous/test_sac_with_mcc.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+import os
+import gym
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import SACPolicy
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
+from tianshou.env import VectorEnv
+from tianshou.exploration import OUNoise
+
+if __name__ == '__main__':
+    from net import ActorProb, Critic
+else:  # pytest
+    from test.continuous.net import ActorProb, Critic
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='MountainCarContinuous-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=50000)
+    parser.add_argument('--actor-lr', type=float, default=3e-4)
+    parser.add_argument('--critic-lr', type=float, default=3e-4)
+    parser.add_argument('--alpha-lr', type=float, default=3e-4)
+    parser.add_argument('--noise_std', type=float, default=0.5)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--tau', type=float, default=0.005)
+    parser.add_argument('--auto_alpha', type=bool, default=True)
+    parser.add_argument('--alpha', type=float, default=0.2)
+    parser.add_argument('--epoch', type=int, default=20)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
+    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=80)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=1.0/35.0)
+    parser.add_argument('--rew-norm', type=bool, default=False)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_sac(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]
+    # train_envs = gym.make(args.task)
+    train_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
+    # test_envs = gym.make(args.task)
+    test_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    actor = ActorProb(
+        args.layer_num, args.state_shape, args.action_shape,
+        args.max_action, args.device
+    ).to(args.device)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+    critic1 = Critic(
+        args.layer_num, args.state_shape, args.action_shape, args.device
+    ).to(args.device)
+    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
+    critic2 = Critic(
+        args.layer_num, args.state_shape, args.action_shape, args.device
+    ).to(args.device)
+    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        alpha = (target_entropy, log_alpha, alpha_optim)
+    else:
+        alpha = args.alpha
+
+    policy = SACPolicy(
+        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+        args.tau, args.gamma, alpha,
+        [env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=args.rew_norm, ignore_done=True,
+        exploration_noise=OUNoise(0.0, args.noise_std))
+    # collector
+    train_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(policy, test_envs)
+    # train_collector.collect(n_step=args.buffer_size)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'sac')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    # trainer
+    result = offpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+    assert stop_fn(result['best_reward'])
+    train_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
+
+
+if __name__ == '__main__':
+    test_sac()
@@ -10,6 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import TD3Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
+from tianshou.exploration import GaussianNoise

 if __name__ == '__main__':
     from net import Actor, Critic
@@ -85,8 +86,8 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.exploration_noise, args.policy_noise,
-        args.update_actor_freq, args.noise_clip,
+        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
+        args.policy_noise, args.update_actor_freq, args.noise_clip,
        [env.action_space.low[0], env.action_space.high[0]],
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
@@ -1,5 +1,7 @@
-from tianshou.exploration.random import OUNoise
+from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise

 __all__ = [
+    'BaseNoise',
+    'GaussianNoise',
     'OUNoise',
 ]
@@ -1,8 +1,42 @@
 import numpy as np
 from typing import Union, Optional
+from abc import ABC, abstractmethod


-class OUNoise(object):
+class BaseNoise(ABC, object):
+    """The action noise base class."""
+
+    def __init__(self, **kwargs) -> None:
+        super(BaseNoise, self).__init__()
+
+    @abstractmethod
+    def __call__(self, **kwargs) -> np.ndarray:
+        """Generate new noise."""
+        raise NotImplementedError
+
+    def reset(self, **kwargs) -> None:
+        """Reset to the initial state."""
+        pass
+
+
+class GaussianNoise(BaseNoise):
+    """Class for vanilla gaussian process,
+    used for exploration in DDPG by default.
+    """
+
+    def __init__(self,
+                 mu: float = 0.0,
+                 sigma: float = 1.0):
+        super().__init__()
+        self._mu = mu
+        assert 0 <= sigma, 'noise std should not be negative'
+        self._sigma = sigma
+
+    def __call__(self, size: tuple) -> np.ndarray:
+        return np.random.normal(self._mu, self._sigma, size)
+
+
+class OUNoise(BaseNoise):
     """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG.
     Usage:
     ::
@@ -19,26 +53,31 @@ class OUNoise(object):
     """

     def __init__(self,
+                 mu: float = 0.0,
                  sigma: float = 0.3,
                  theta: float = 0.15,
                  dt: float = 1e-2,
                  x0: Optional[Union[float, np.ndarray]] = None
                  ) -> None:
-        self.alpha = theta * dt
-        self.beta = sigma * np.sqrt(dt)
-        self.x0 = x0
+        super(BaseNoise, self).__init__()
+        self._mu = mu
+        self._alpha = theta * dt
+        self._beta = sigma * np.sqrt(dt)
+        self._x0 = x0
         self.reset()

-    def __call__(self, size: tuple, mu: float = .1) -> np.ndarray:
+    def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray:
         """Generate new noise. Return a ``numpy.ndarray`` which size is equal
         to ``size``.
         """
-        if self.x is None or self.x.shape != size:
-            self.x = 0
-        r = self.beta * np.random.normal(size=size)
-        self.x = self.x + self.alpha * (mu - self.x) + r
-        return self.x
+        if self._x is None or self._x.shape != size:
+            self._x = 0
+        if mu is None:
+            mu = self._mu
+        r = self._beta * np.random.normal(size=size)
+        self._x = self._x + self._alpha * (mu - self._x) + r
+        return self._x

     def reset(self) -> None:
         """Reset to the initial state."""
-        self.x = None
+        self._x = self._x0
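A small usage sketch of the noise classes defined above (not part of the commit; the shape and parameter values are arbitrary):

from tianshou.exploration import GaussianNoise, OUNoise

action_shape = (1,)

gauss = GaussianNoise(mu=0.0, sigma=0.1)  # draws an i.i.d. Gaussian sample per call
print(gauss(action_shape))                # ndarray of shape (1,)

ou = OUNoise(0.0, 0.5)                    # temporally correlated (Ornstein-Uhlenbeck) noise
for _ in range(3):
    print(ou(action_shape))               # consecutive samples are correlated
ou.reset()                                # drop the internal state, back to x0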
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from typing import Dict, Tuple, Union, Optional

 from tianshou.policy import BasePolicy
-# from tianshou.exploration import OUNoise
+from tianshou.exploration import BaseNoise, GaussianNoise
 from tianshou.data import Batch, ReplayBuffer, to_torch_as


@@ -21,8 +21,8 @@ class DDPGPolicy(BasePolicy):
     :param float tau: param for soft update of the target network, defaults to
         0.005.
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
-    :param float exploration_noise: the noise intensity, add to the action,
-        defaults to 0.1.
+    :param BaseNoise exploration_noise: the exploration noise,
+        add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
     :param action_range: the action range (minimum, maximum).
     :type action_range: (float, float)
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
@@ -45,7 +45,8 @@ class DDPGPolicy(BasePolicy):
                 critic_optim: torch.optim.Optimizer,
                 tau: float = 0.005,
                 gamma: float = 0.99,
-                exploration_noise: float = 0.1,
+                exploration_noise: Optional[BaseNoise]
+                = GaussianNoise(sigma=0.1),
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
@@ -64,8 +65,7 @@ class DDPGPolicy(BasePolicy):
         self._tau = tau
         assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
         self._gamma = gamma
-        assert 0 <= exploration_noise, 'noise should not be negative'
-        self._eps = exploration_noise
+        self._noise = exploration_noise
         assert action_range is not None
         self._range = action_range
         self._action_bias = (action_range[0] + action_range[1]) / 2
@@ -77,9 +77,9 @@ class DDPGPolicy(BasePolicy):
         assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step

-    def set_eps(self, eps: float) -> None:
-        """Set the eps for exploration."""
-        self._eps = eps
+    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
+        """Set the exploration noise."""
+        self._noise = noise

     def train(self) -> None:
         """Set the module in training mode, except for the target network."""
@@ -106,7 +106,8 @@ class DDPGPolicy(BasePolicy):
         batch = buffer[indice]  # batch.obs_next: s_{t+n}
         with torch.no_grad():
             target_q = self.critic_old(batch.obs_next, self(
-                batch, model='actor_old', input='obs_next', eps=0).act)
+                batch, model='actor_old', input='obs_next',
+                explorating=False).act)
         return target_q

     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
@@ -122,7 +123,7 @@ class DDPGPolicy(BasePolicy):
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
                 model: str = 'actor',
                 input: str = 'obs',
-                eps: Optional[float] = None,
+                explorating: bool = True,
                 **kwargs) -> Batch:
         """Compute action over the given batch data.

@@ -142,14 +143,8 @@ class DDPGPolicy(BasePolicy):
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
         logits += self._action_bias
-        if eps is None:
-            eps = self._eps
-        if eps > 0:
-            # noise = np.random.normal(0, eps, size=logits.shape)
-            # logits += to_torch(noise, device=logits.device)
-            # noise = self.noise(logits.shape, eps)
-            logits += torch.randn(
-                size=logits.shape, device=logits.device) * eps
+        if self.training and explorating:
+            logits += to_torch_as(self._noise(logits.shape), logits)
         logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)

@@ -161,7 +156,8 @@ class DDPGPolicy(BasePolicy):
         self.critic_optim.zero_grad()
         critic_loss.backward()
         self.critic_optim.step()
-        actor_loss = -self.critic(batch.obs, self(batch, eps=0).act).mean()
+        action = self(batch, explorating=False).act
+        actor_loss = -self.critic(batch.obs, action).mean()
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
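For reference, a hedged summary of how DDPG-style call sites change under this interface (only the import and the construction below actually run; policy and batch in the comments stand for any DDPG-style policy and data batch):

from tianshou.exploration import GaussianNoise

# old interface (removed):  policy.set_eps(0.1);  policy(batch, eps=0)
# new interface:            policy.set_exp_noise(GaussianNoise(sigma=0.1))
#                           policy(batch, explorating=False)  # no noise added
default_noise = GaussianNoise(sigma=0.1)  # the new default for DDPGPolicy and TD3Policy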
@@ -7,6 +7,7 @@ from typing import Dict, Tuple, Union, Optional
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.dist import DiagGaussian
 from tianshou.data import Batch, to_torch_as, ReplayBuffer
+from tianshou.exploration import BaseNoise


 class SACPolicy(DDPGPolicy):
@@ -28,13 +29,18 @@ class SACPolicy(DDPGPolicy):
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
-    :param float exploration_noise: the noise intensity, add to the action,
-        defaults to 0.1.
-    :param float alpha: entropy regularization coefficient, default to 0.2.
+    :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
+        regularization coefficient, default to 0.2.
+        If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
+        alpha is automatatically tuned.
     :param action_range: the action range (minimum, maximum).
     :type action_range: (float, float)
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to ``False``.
     :param bool ignore_done: ignore the done flag while training the policy,
         defaults to ``False``.
+    :param BaseNoise exploration_noise: add a noise to action for exploration.
+        This is useful when solving hard-exploration problem.

     .. seealso::

@@ -51,13 +57,15 @@ class SACPolicy(DDPGPolicy):
                 critic2_optim: torch.optim.Optimizer,
                 tau: float = 0.005,
                 gamma: float = 0.99,
-                alpha: float = 0.2,
+                alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
+                or float = 0.2,
                 action_range: Optional[Tuple[float, float]] = None,
                 reward_normalization: bool = False,
                 ignore_done: bool = False,
                 estimation_step: int = 1,
+                exploration_noise: Optional[BaseNoise] = None,
                 **kwargs) -> None:
-        super().__init__(None, None, None, None, tau, gamma, 0,
+        super().__init__(None, None, None, None, tau, gamma, exploration_noise,
                          action_range, reward_normalization, ignore_done,
                          estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
@@ -67,7 +75,18 @@ class SACPolicy(DDPGPolicy):
         self.critic2, self.critic2_old = critic2, deepcopy(critic2)
         self.critic2_old.eval()
         self.critic2_optim = critic2_optim
-        self._alpha = alpha
+
+        self._automatic_alpha_tuning = not isinstance(alpha, float)
+        if self._automatic_alpha_tuning:
+            self._target_entropy = alpha[0]
+            assert(alpha[1].shape == torch.Size([1])
+                   and alpha[1].requires_grad)
+            self._log_alpha = alpha[1]
+            self._alpha_optim = alpha[2]
+            self._alpha = self._log_alpha.exp()
+        else:
+            self._alpha = alpha
+
         self.__eps = np.finfo(np.float32).eps.item()

     def train(self) -> None:
@@ -92,7 +111,9 @@ class SACPolicy(DDPGPolicy):

     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
-                input: str = 'obs', **kwargs) -> Batch:
+                input: str = 'obs',
+                explorating: bool = True,
+                **kwargs) -> Batch:
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
@@ -100,9 +121,10 @@ class SACPolicy(DDPGPolicy):
         x = dist.rsample()
         y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
-        log_prob = dist.log_prob(x) - torch.log(
-            self._action_scale * (1 - y.pow(2)) + self.__eps
-        ).sum(-1, keepdim=True)
+        y = self._action_scale * (1 - y.pow(2)) + self.__eps
+        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
+        if self._noise is not None and self.training and explorating:
+            act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
         return Batch(
             logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
@@ -111,7 +133,7 @@ class SACPolicy(DDPGPolicy):
                   indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs: s_{t+n}
         with torch.no_grad():
-            obs_next_result = self(batch, input='obs_next')
+            obs_next_result = self(batch, input='obs_next', explorating=False)
             a_ = obs_next_result.act
             batch.act = to_torch_as(batch.act, a_)
             target_q = torch.min(
@@ -135,7 +157,7 @@ class SACPolicy(DDPGPolicy):
         critic2_loss.backward()
         self.critic2_optim.step()
         # actor
-        obs_result = self(batch)
+        obs_result = self(batch, explorating=False)
         a = obs_result.act
         current_q1a = self.critic1(batch.obs, a)
         current_q2a = self.critic2(batch.obs, a)
@@ -144,9 +166,22 @@ class SACPolicy(DDPGPolicy):
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
+
+        if self._automatic_alpha_tuning:
+            log_prob = (obs_result.log_prob + self._target_entropy).detach()
+            alpha_loss = -(self._log_alpha * log_prob).mean()
+            self._alpha_optim.zero_grad()
+            alpha_loss.backward()
+            self._alpha_optim.step()
+            self._alpha = self._log_alpha.exp()
+
         self.sync_weight()
-        return {
+
+        result = {
             'loss/actor': actor_loss.item(),
             'loss/critic1': critic1_loss.item(),
             'loss/critic2': critic2_loss.item(),
         }
+        if self._automatic_alpha_tuning:
+            result['loss/alpha'] = alpha_loss.item()
+        return result
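A standalone sketch of the temperature update that learn() performs when alpha is auto-tuned (not part of the commit; log_prob here is random stand-in data for obs_result.log_prob):

import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -1.0                     # -dim(A) for a one-dimensional action space
log_prob = torch.randn(128, 1)            # stand-in for the policy's log-probabilities

alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
alpha = log_alpha.exp()                   # the entropy coefficient used by the SAC losses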
@@ -6,6 +6,7 @@ from typing import Dict, Tuple, Optional

 from tianshou.policy import DDPGPolicy
 from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise, GaussianNoise


 class TD3Policy(DDPGPolicy):
@@ -26,8 +27,8 @@ class TD3Policy(DDPGPolicy):
     :param float tau: param for soft update of the target network, defaults to
         0.005.
     :param float gamma: discount factor, in [0, 1], defaults to 0.99.
-    :param float exploration_noise: the noise intensity, add to the action,
-        defaults to 0.1.
+    :param float exploration_noise: the exploration noise, add to the action,
+        defaults to ``GaussianNoise(sigma=0.1)``
     :param float policy_noise: the noise used in updating policy network,
         default to 0.2.
     :param int update_actor_freq: the update frequency of actor network,
@@ -56,7 +57,8 @@ class TD3Policy(DDPGPolicy):
                 critic2_optim: torch.optim.Optimizer,
                 tau: float = 0.005,
                 gamma: float = 0.99,
-                exploration_noise: float = 0.1,
+                exploration_noise: Optional[BaseNoise]
+                = GaussianNoise(sigma=0.1),
                 policy_noise: float = 0.2,
                 update_actor_freq: int = 2,
                 noise_clip: float = 0.5,