170 lines
6.5 KiB
Python
Raw Normal View History

2020-03-18 21:45:41 +08:00
import torch
2020-03-21 15:31:31 +08:00
import numpy as np
2020-03-18 21:45:41 +08:00
from copy import deepcopy
from typing import Any, Dict, Tuple, Union, Optional
2020-03-18 21:45:41 +08:00
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise, GaussianNoise
2020-06-03 13:59:47 +08:00
from tianshou.data import Batch, ReplayBuffer, to_torch_as
2020-03-18 21:45:41 +08:00
class DDPGPolicy(BasePolicy):
    """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> action) It is called as
        ``actor(obs, state=state, info=info)`` and must return a pair
        ``(action, hidden_state)``.
    :param torch.optim.Optimizer actor_optim: the optimizer for actor network.
    :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
    :param torch.optim.Optimizer critic_optim: the optimizer for critic
        network.
    :param action_range: the action range (minimum, maximum).
    :type action_range: Tuple[float, float]
    :param float tau: param for soft update of the target network, defaults to
        0.005.
    :param float gamma: discount factor, in [0, 1], defaults to 0.99.
    :param BaseNoise exploration_noise: the exploration noise,
        add to the action, defaults to ``GaussianNoise(sigma=0.1)``. Note that
        this default instance is created once, when the class is defined, and
        is shared by every policy constructed without an explicit noise.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
    :param bool ignore_done: ignore the done flag while training the policy,
        defaults to False.
    :param int estimation_step: the number of steps to look ahead when
        computing the n-step return; must be at least 1, defaults to 1.

    If ``actor``/``actor_optim`` (or ``critic``/``critic_optim``) is None,
    the corresponding attributes are not created by this constructor;
    presumably a subclass sets them itself.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        actor: Optional[torch.nn.Module],
        actor_optim: Optional[torch.optim.Optimizer],
        critic: Optional[torch.nn.Module],
        critic_optim: Optional[torch.optim.Optimizer],
        action_range: Tuple[float, float],
        tau: float = 0.005,
        gamma: float = 0.99,
        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
        reward_normalization: bool = False,
        ignore_done: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Each network is wired up only when both the module and its
        # optimizer are given.
        if actor is not None and actor_optim is not None:
            self.actor: torch.nn.Module = actor
            # Target actor: a deep copy kept in eval mode and moved towards
            # ``actor`` by ``sync_weight``.
            self.actor_old = deepcopy(actor)
            self.actor_old.eval()
            self.actor_optim: torch.optim.Optimizer = actor_optim
        if critic is not None and critic_optim is not None:
            self.critic: torch.nn.Module = critic
            # Target critic, maintained the same way as the target actor.
            self.critic_old = deepcopy(critic)
            self.critic_old.eval()
            self.critic_optim: torch.optim.Optimizer = critic_optim
        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self._tau = tau
        assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
        self._gamma = gamma
        # OUNoise could be used instead; it makes only a little difference
        # compared with GaussianNoise.
        self._noise = exploration_noise
        self._range = action_range
        # ``forward`` shifts the actor output by the centre of the range.
        # The half-width is stored but not applied anywhere in this class
        # (presumably the actor already scales its output -- confirm).
        self._action_bias = (action_range[0] + action_range[1]) / 2.0
        self._action_scale = (action_range[1] - action_range[0]) / 2.0
        self._rm_done = ignore_done
        self._rew_norm = reward_normalization
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise (``None`` disables it)."""
        self._noise = noise

    def train(self, mode: bool = True) -> "DDPGPolicy":
        """Set the module in training mode, except for the target network.

        Only ``actor`` and ``critic`` are switched; ``actor_old`` and
        ``critic_old`` keep the eval mode set in ``__init__``.
        """
        self.training = mode
        self.actor.train(mode)
        self.critic.train(mode)
        return self

    def _soft_update(
        self, target: torch.nn.Module, source: torch.nn.Module
    ) -> None:
        """Move every parameter of ``target`` a fraction ``tau`` towards the
        matching parameter of ``source`` (in place)."""
        for tgt, src in zip(target.parameters(), source.parameters()):
            tgt.data.copy_(
                tgt.data * (1.0 - self._tau) + src.data * self._tau)

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        self._soft_update(self.actor_old, self.actor)
        self._soft_update(self.critic_old, self.critic)

    def _target_q(
        self, buffer: ReplayBuffer, indice: np.ndarray
    ) -> torch.Tensor:
        """Return ``critic_old(s', actor_old(s'))`` for the given buffer
        indices, computed without gradient tracking."""
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        with torch.no_grad():
            target_q = self.critic_old(
                batch.obs_next,
                self(batch, model='actor_old', input='obs_next').act)
        return target_q

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Pre-process a sampled batch before ``learn``.

        Zeroes the done flags when ``ignore_done`` is set, then delegates to
        ``compute_nstep_return`` with ``_target_q`` as the bootstrap value.
        ``learn`` later reads the result from ``batch.returns``.
        """
        if self._rm_done:
            batch.done = batch.done * 0.0
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)
        return batch

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :param batch: the input data; ``batch[input]`` is fed to the network
            together with ``batch.info``.
        :param state: the hidden state passed through to the network.
        :param str model: name of the attribute holding the network to run,
            e.g. ``"actor"`` or ``"actor_old"``.
        :param str input: key of ``batch`` used as the observation, e.g.
            ``"obs"`` or ``"obs_next"``.

        :return: A :class:`~tianshou.data.Batch` which has 2 keys:

            * ``act`` the action.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        actions, h = model(obs, state=state, info=batch.info)
        # Out-of-place adds on purpose: modifying the network output in place
        # breaks autograd in ``learn`` when the actor's last op saved its
        # output for backward (e.g. a final ``tanh``), even if the bias is 0.
        actions = actions + self._action_bias
        # Noise is only added when ``self.updating`` is false, i.e. not
        # during the update step.
        if self._noise and not self.updating:
            noise = to_torch_as(self._noise(actions.shape), actions)
            actions = actions + noise
        actions = actions.clamp(self._range[0], self._range[1])
        return Batch(act=actions, state=h)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Run one critic step, one actor step and a soft target update.

        The critic minimises the (optionally importance-weighted) squared TD
        error against ``batch.returns``. The actor maximises
        ``critic(obs, actor(obs))``.

        :return: a dict with the scalar ``"loss/actor"`` and
            ``"loss/critic"`` values.
        """
        weight = batch.pop("weight", 1.0)
        current_q = self.critic(batch.obs, batch.act).flatten()
        target_q = batch.returns.flatten()
        td = current_q - target_q
        critic_loss = (td.pow(2) * weight).mean()
        # The TD errors are written back for a prioritized buffer; detach
        # them so the batch does not keep a reference to the autograd graph.
        batch.weight = td.detach()  # prio-buffer
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        action = self(batch).act
        actor_loss = -self.critic(batch.obs, action).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        self.sync_weight()
        return {
            "loss/actor": actor_loss.item(),
            "loss/critic": critic_loss.item(),
        }