n+e 94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add `policy.eval()` in the "watch performance" part of all test scripts
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in `Batch`: `key in batch`, `batch.pop(key, deft)` (see the usage sketch below)
4. collect exactly n_episode episodes when n_episode is given as a list of per-env limits, and save fake data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: the horizontal axis now represents env steps instead of gradient steps; add test results to tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py so that the most common case is checked first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove the default value of `batch.split()` and add a `merge_last` argument (#185); see the usage sketch below
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve code coverage (from 90% to 95%) and remove dead code
16. fix some incorrect type annotations

These improvements also increase the training FPS: on my machine, the previous version runs at only ~1800 FPS, while this version reaches ~2050 FPS (faster than v0.2.4.post1).
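
A short usage sketch of the new `Batch` interface from items 3 and 10 (array contents are arbitrary placeholders):

    import numpy as np
    from tianshou.data import Batch

    batch = Batch(obs=np.zeros((10, 4)), rew=np.arange(10))
    assert 'obs' in batch             # new __contains__
    rew = batch.pop('rew', None)      # new pop() with a dict-style default
    # split() now requires an explicit size; merge_last=True folds a
    # trailing short chunk into the previous mini-batch
    for minibatch in batch.split(3, merge_last=True):
        pass
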
2020-08-27 12:15:18 +08:00


import torch
import numpy as np
from copy import deepcopy
from typing import Dict, Tuple, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.data import Batch, ReplayBuffer, to_torch_as
class DDPGPolicy(BasePolicy):
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic_optim: the optimizer for critic
network.
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
    :param BaseNoise exploration_noise: the exploration noise, added to the
        action, defaults to ``GaussianNoise(sigma=0.1)``.
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
    :param int estimation_step: the number of steps to look ahead, should be
        greater than 0, defaults to 1.
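
    A minimal construction sketch (``MyActor`` and ``MyCritic`` are
    placeholder ``torch.nn.Module`` classes, not part of tianshou)::

        actor, critic = MyActor(), MyCritic()  # s -> a and (s, a) -> Q(s, a)
        policy = DDPGPolicy(
            actor, torch.optim.Adam(actor.parameters(), lr=1e-3),
            critic, torch.optim.Adam(critic.parameters(), lr=1e-3),
            action_range=(-1.0, 1.0))
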
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise]
= GaussianNoise(sigma=0.1),
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs) -> None:
super().__init__(**kwargs)
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
self.actor_old.eval()
self.actor_optim = actor_optim
if critic is not None:
self.critic, self.critic_old = critic, deepcopy(critic)
self.critic_old.eval()
self.critic_optim = critic_optim
        assert 0 <= tau <= 1, 'tau should be in [0, 1]'
self._tau = tau
        assert 0 <= gamma <= 1, 'gamma should be in [0, 1]'
self._gamma = gamma
self._noise = exploration_noise
assert action_range is not None
self._range = action_range
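        # center and half-width of the action range; forward() shifts raw
        # actions by the center and clamps them into [low, high]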
self._action_bias = (action_range[0] + action_range[1]) / 2
self._action_scale = (action_range[1] - action_range[0]) / 2
        # using Gaussian noise instead of OU noise makes little difference
# self.noise = OUNoise()
self._rm_done = ignore_done
self._rew_norm = reward_normalization
        assert estimation_step > 0, 'estimation_step should be greater than 0'
self._n_step = estimation_step
def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
"""Set the exploration noise."""
self._noise = noise
    def train(self, mode: bool = True) -> torch.nn.Module:
"""Set the module in training mode, except for the target network."""
self.training = mode
self.actor.train(mode)
self.critic.train(mode)
return self
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
for o, n in zip(
self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
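        # n-step bootstrap target: Q_old(s_{t+n}, actor_old(s_{t+n})),
        # evaluated without exploration noise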
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next',
explorating=False).act)
return target_q
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
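        # with ignore_done, terminal flags are zeroed so the n-step return
        # bootstraps through episode ends; compute_nstep_return writes the
        # (optionally normalized) n-step target into batch.returns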
if self._rm_done:
batch.done = batch.done * 0.
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm)
return batch
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = 'actor',
input: str = 'obs',
explorating: bool = True,
**kwargs) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
* ``act`` the action.
* ``state`` the hidden state.
.. seealso::
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
model = getattr(self, model)
obs = getattr(batch, input)
actions, h = model(obs, state=state, info=batch.info)
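        # recenter the deterministic action, optionally add exploration noise
        # during training, and clamp it into the valid action range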
actions += self._action_bias
if self.training and explorating:
actions += to_torch_as(self._noise(actions.shape), actions)
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
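        # critic update: weighted MSE loss on the TD error; `weight` comes
        # from a prioritized replay buffer and defaults to 1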
weight = batch.pop('weight', 1.)
current_q = self.critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
td = current_q - target_q
critic_loss = (td.pow(2) * weight).mean()
batch.weight = td # prio-buffer
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
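        # actor update: deterministic policy gradient, i.e. maximize
        # Q(s, actor(s)) by minimizing its negation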
action = self(batch, explorating=False).act
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
self.sync_weight()
return {
'loss/actor': actor_loss.item(),
'loss/critic': critic_loss.item(),
}