n+e 94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add policy.eval() in the "watch performance" part of all test scripts
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in Batch: `key in batch`, `batch.pop(key, deft)` (see the usage sketch below the commit message)
4. collect exactly n_episode episodes when a list of per-env n_episode limits is given, and save fake data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: the horizontal axis now stands for env steps instead of gradient steps; add test results to tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py so that the most common case is checked first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove default value in batch.split() and add merge_last argument (#185)
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve code coverage (from 90% to 95%) and remove dead code
16. fix some incorrect type annotations

The above improvements also increase the training FPS: on my machine the previous version runs at only ~1800 FPS, while this version reaches ~2050 FPS (faster than v0.2.4.post1).
2020-08-27 12:15:18 +08:00
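
Items 3 and 10 above change the Batch API. Below is a minimal usage sketch, assuming Batch.pop behaves like dict.pop and Batch.split now has the signature split(size, shuffle=True, merge_last=False); the field names are made up for illustration:

    import numpy as np
    from tianshou.data import Batch

    b = Batch(obs=np.zeros((10, 4)), rew=np.arange(10))
    assert 'obs' in b                  # new __contains__ support: `key in batch`
    rew = b.pop('rew', None)           # pop an existing key, like dict.pop
    info = b.pop('info', 'absent')     # missing key -> the given default
    # split the remaining 10 samples into mini-batches of size 4; with
    # merge_last=True the trailing short chunk is merged into the previous one
    chunks = list(b.split(4, shuffle=False, merge_last=True))
    print([len(c) for c in chunks])    # expected: [4, 6]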


import torch
import numpy as np
from copy import deepcopy
from typing import Dict, Union, Optional

from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy


class DQNPolicy(BasePolicy):
"""Implementation of Deep Q Network. arXiv:1312.5602
Implementation of Double Q-Learning. arXiv:1509.06461
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
implemented in the network side, not here)
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int estimation_step: greater than 1, the number of steps to look
ahead.
:param int target_update_freq: the target network update frequency (``0``
if you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""

    def __init__(self,
                 model: torch.nn.Module,
                 optim: torch.optim.Optimizer,
                 discount_factor: float = 0.99,
                 estimation_step: int = 1,
                 target_update_freq: int = 0,
                 reward_normalization: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        self.eps = 0
        assert 0 <= discount_factor <= 1, \
            'discount_factor should be in [0, 1]'
        self._gamma = discount_factor
        assert estimation_step > 0, 'estimation_step should be greater than 0'
        self._n_step = estimation_step
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        self._cnt = 0
        if self._target:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
        self._rew_norm = reward_normalization

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self, mode: bool = True) -> torch.nn.Module:
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.model.train(mode)
        return self

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
        self.model_old.load_state_dict(self.model.state_dict())

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        if self._target:
            # Double DQN: pick the action with the online network but
            # evaluate it with the target network:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            a = self(batch, input='obs_next', eps=0).act
            with torch.no_grad():
                target_q = self(
                    batch, model='model_old', input='obs_next').logits
            target_q = target_q[np.arange(len(a)), a]
        else:
            # vanilla DQN: bootstrap with the maximum Q-value at s_{t+n}
            with torch.no_grad():
                target_q = self(batch, input='obs_next').logits.max(dim=1)[0]
        return target_q

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        """Compute the n-step return for Q-learning targets. More details can
        be found at :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
        """
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)
        return batch

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                model: str = 'model',
                input: str = 'obs',
                eps: Optional[float] = None,
                **kwargs) -> Batch:
        """Compute action over the given batch data.

        If you need to mask the action, please add a "mask" into batch.obs,
        for example, if we have an environment with three actions "0/1/2":
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        :param float eps: in [0, 1], for epsilon-greedy exploration method.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = getattr(batch, input)
        # unwrap the observation when a "mask" wrapper Batch is used
        obs_ = obs.obs if hasattr(obs, 'obs') else obs
        q, h = model(obs_, state=state, info=batch.info)
        act = to_numpy(q.max(dim=1)[1])
        has_mask = hasattr(obs, 'mask')
        if has_mask:
            # some of the actions are masked, they cannot be selected
            q_ = to_numpy(q)
            q_[~obs.mask] = -np.inf
            act = q_.argmax(axis=1)
        # epsilon-greedy: with probability eps, replace the greedy action
        # with a uniformly random (unmasked) action
        if eps is None:
            eps = self.eps
        if not np.isclose(eps, 0):
            for i in range(len(q)):
                if np.random.rand() < eps:
                    q_ = np.random.rand(*q[i].shape)
                    if has_mask:
                        q_[~obs.mask[i]] = -np.inf
                    act[i] = q_.argmax()
        return Batch(logits=q, act=act, state=h)

    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        if self._target and self._cnt % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        # per-sample importance weight from a prioritized replay buffer,
        # defaults to 1 for a plain replay buffer
        weight = batch.pop('weight', 1.)
        q = self(batch, eps=0.).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns, q).flatten()
        td = r - q
        loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer: feed the TD error back as new priority
        loss.backward()
        self.optim.step()
        self._cnt += 1
        return {'loss': loss.item()}
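
For context, a rough end-to-end sketch of how this policy might be exercised by hand; the Net class, the hand-filled buffer, and all hyper-parameters below are illustrative assumptions, not part of the file above (in practice a Collector fills the buffer and a trainer drives the loop):

    import numpy as np
    import torch
    from tianshou.data import Batch, ReplayBuffer
    from tianshou.policy import DQNPolicy

    class Net(torch.nn.Module):
        # follows the (s -> logits, state) convention expected by BasePolicy
        def __init__(self, obs_dim: int = 4, act_dim: int = 2):
            super().__init__()
            self.fc = torch.nn.Sequential(
                torch.nn.Linear(obs_dim, 64), torch.nn.ReLU(),
                torch.nn.Linear(64, act_dim))

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.fc(obs), state

    net = Net()
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    policy = DQNPolicy(net, optim, discount_factor=0.99,
                       estimation_step=3, target_update_freq=320)

    # hand-filled buffer; normally a Collector does this
    buf = ReplayBuffer(size=100)
    for i in range(8):
        buf.add(obs=np.zeros(4), act=i % 2, rew=1.0, done=(i == 7),
                obs_next=np.zeros(4), info={})

    policy.train()
    policy.set_eps(0.1)                             # epsilon-greedy exploration
    batch, indice = buf.sample(4)
    batch = policy.process_fn(batch, buf, indice)   # attach n-step returns
    print(policy.learn(batch))                      # {'loss': ...}

    policy.eval()                                   # cf. item 1 in the commit
    policy.set_eps(0.0)
    print(policy(Batch(obs=np.zeros((1, 4)), info={})).act)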