from copy import deepcopy
from typing import Any, Dict, Optional

import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import DDPGPolicy


class TD3Policy(DDPGPolicy):
    """Implementation of TD3, arXiv:1802.09477.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic1_optim: the optimizer for the first
        critic network.
    :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic2_optim: the optimizer for the second
        critic network.
    :param float tau: param for soft update of the target network.
        Default to 0.005.
    :param float gamma: discount factor, in [0, 1]. Default to 0.99.
    :param BaseNoise exploration_noise: the exploration noise, added to the
        action. Default to ``GaussianNoise(sigma=0.1)``.
    :param float policy_noise: the noise used in updating the policy network.
        Default to 0.2.
    :param int update_actor_freq: the update frequency of the actor network.
        Default to 2.
    :param float noise_clip: the clipping range used in updating the policy
        network. Default to 0.5.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param int estimation_step: the number of steps to look ahead for n-step
        returns. Default to 1.
    :param bool action_scaling: whether to map actions from range [-1, 1] to
        range [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1],
        can be either "clip" (for simply clipping the action) or empty string
        for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if
        you want to use option "action_scaling" or "action_bound_method".
        Default to None.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

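    # Construction sketch (illustrative only, not part of the library code):
    # `actor`, `critic1` and `critic2` are assumed to be continuous-control
    # networks such as those in tianshou.utils.net.continuous, and `env` an
    # already-created Gym environment.
    #
    #   policy = TD3Policy(
    #       actor, torch.optim.Adam(actor.parameters(), lr=3e-4),
    #       critic1, torch.optim.Adam(critic1.parameters(), lr=3e-4),
    #       critic2, torch.optim.Adam(critic2.parameters(), lr=3e-4),
    #       exploration_noise=GaussianNoise(sigma=0.1),
    #       action_space=env.action_space,
    #   )
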
    def __init__(
        self,
        actor: torch.nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1: torch.nn.Module,
        critic1_optim: torch.optim.Optimizer,
        critic2: torch.nn.Module,
        critic2_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
        policy_noise: float = 0.2,
        update_actor_freq: int = 2,
        noise_clip: float = 0.5,
        reward_normalization: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            actor, actor_optim, None, None, tau, gamma, exploration_noise,
            reward_normalization, estimation_step, **kwargs
        )
        # Twin critics with frozen target copies for clipped double-Q learning.
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic2_optim = critic2_optim
        self._policy_noise = policy_noise
        self._freq = update_actor_freq
        self._noise_clip = noise_clip
        self._cnt = 0  # gradient-step counter for delayed actor updates
        self._last = 0  # last recorded actor loss, reported between actor updates

    def train(self, mode: bool = True) -> "TD3Policy":
        # Toggle training mode on the online networks; target copies stay in eval mode.
        self.training = mode
        self.actor.train(mode)
        self.critic1.train(mode)
        self.critic2.train(mode)
        return self

    def sync_weight(self) -> None:
        # Polyak-average the online networks into their target copies.
        self.soft_update(self.critic1_old, self.critic1, self.tau)
        self.soft_update(self.critic2_old, self.critic2, self.tau)
        self.soft_update(self.actor_old, self.actor, self.tau)

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs: s_{t+n}
        act_ = self(batch, model="actor_old", input="obs_next").act
        # Target policy smoothing: perturb the target action with clipped noise.
        noise = torch.randn(size=act_.shape, device=act_.device) * self._policy_noise
        if self._noise_clip > 0.0:
            noise = noise.clamp(-self._noise_clip, self._noise_clip)
        act_ += noise
        # Clipped double-Q: take the element-wise minimum of the two target critics.
        target_q = torch.min(
            self.critic1_old(batch.obs_next, act_),
            self.critic2_old(batch.obs_next, act_),
        )
        return target_q

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        # critic 1&2
        td1, critic1_loss = self._mse_optimizer(
            batch, self.critic1, self.critic1_optim
        )
        td2, critic2_loss = self._mse_optimizer(
            batch, self.critic2, self.critic2_optim
        )
        batch.weight = (td1 + td2) / 2.0  # prio-buffer

        # actor: delayed policy update, once every self._freq critic updates
        if self._cnt % self._freq == 0:
            actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self._last = actor_loss.item()
            self.actor_optim.step()
            self.sync_weight()
        self._cnt += 1
        return {
            "loss/actor": self._last,
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
        }
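
# Update-loop sketch (illustrative only): `update` is inherited from BasePolicy;
# it samples transitions from a replay buffer (assumed to have been filled by a
# Collector elsewhere), computes n-step targets via `_target_q`, and calls
# `learn`, returning the loss dictionary above.
#
#   losses = policy.update(256, buffer)
#   print(losses["loss/actor"], losses["loss/critic1"], losses["loss/critic2"])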