import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Union, Optional

from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
    to_torch, to_numpy


class DQNPolicy(BasePolicy):
    """Implementation of Deep Q Network. arXiv:1312.5602

    Implementation of Double Q-Learning. arXiv:1509.06461

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the
        model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: greater than or equal to 1, the number of
        steps to look ahead.
    :param int target_update_freq: the target network update frequency
        (``0`` if you do not use the target network).

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more
        detailed explanation.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 optim: torch.optim.Optimizer,
                 discount_factor: float = 0.99,
                 estimation_step: int = 1,
                 target_update_freq: Optional[int] = 0,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        self.eps = 0
        assert 0 <= discount_factor <= 1, \
            'discount_factor should be in [0, 1]'
        self._gamma = discount_factor
        assert estimation_step > 0, \
            'estimation_step should be greater than 0'
        self._n_step = estimation_step
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        self._cnt = 0
        if self._target:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self) -> None:
        """Set the module in training mode, except for the target network."""
        self.training = True
        self.model.train()

    def eval(self) -> None:
        """Set the module in evaluation mode, except for the target network."""
        self.training = False
        self.model.eval()

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
        self.model_old.load_state_dict(self.model.state_dict())

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> np.ndarray:
        data = buffer[indice]
        if self._target:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            a = self(data, input='obs_next', eps=0).act
            target_q = self(
                data, model='model_old', input='obs_next').logits
            target_q = to_numpy(target_q)
            target_q = target_q[np.arange(len(a)), a]
        else:
            target_q = self(data, input='obs_next').logits
            target_q = to_numpy(target_q).max(axis=1)
        return target_q
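
    # Shape note for ``_target_q`` (illustrative numbers): ``target_q`` has
    # shape (batch_size, n_actions) and ``a`` holds the greedy actions of the
    # online network, so ``target_q[np.arange(len(a)), a]`` picks one value
    # per row, e.g.
    #
    #     target_q = np.array([[0.2, 0.7],
    #                          [0.5, 0.1]])
    #     a = np.array([1, 0])
    #     target_q[np.arange(2), a]  # -> array([0.7, 0.5])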

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        r"""Compute the n-step return for Q-learning targets:

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a
            (Q_{new}(s_{t + n}, a))),

        where :math:`\gamma \in [0, 1]` is the discount factor and
        :math:`d_t` is the done flag of step :math:`t`. If there is no
        target network, :math:`Q_{old}` is equal to :math:`Q_{new}`.
        """
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q, self._gamma, self._n_step)
        if isinstance(buffer, PrioritizedReplayBuffer):
            # keep references so that learn() can update the priorities
            batch.update_weight = buffer.update_weight
            batch.indice = indice
        return batch

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                model: str = 'model',
                input: str = 'obs',
                eps: Optional[float] = None,
                **kwargs) -> Batch:
        """Compute action over the given batch data.

        :param float eps: in [0, 1], for epsilon-greedy exploration method.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = getattr(batch, input)
        q, h = model(obs, state=state, info=batch.info)
        act = to_numpy(q.max(dim=1)[1])
        # epsilon-greedy: with probability eps, replace the greedy action
        # with a uniformly random one
        if eps is None:
            eps = self.eps
        if not np.isclose(eps, 0):
            for i in range(len(q)):
                if np.random.rand() < eps:
                    act[i] = np.random.randint(q.shape[1])
        return Batch(logits=q, act=act, state=h)
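
    # Forward usage sketch (illustrative; assumes ``policy`` is a constructed
    # DQNPolicy and ``obs`` is a numpy array of stacked observations):
    #
    #     result = policy(Batch(obs=obs, info={}), eps=0.05)
    #     actions = result.act       # numpy array of chosen action indices
    #     q_values = result.logits   # raw Q-values from the network
    #
    # Calling with ``eps=0`` is fully greedy, which is how ``_target_q``
    # selects the Double DQN actions.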

    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        if self._target and self._cnt % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
        if hasattr(batch, 'update_weight'):
            # prioritized replay: refresh priorities with the TD error and
            # weight the squared error by the importance-sampling weights
            td = r - q
            batch.update_weight(batch.indice, to_numpy(td))
            impt_weight = to_torch(batch.impt_weight,
                                   device=q.device, dtype=torch.float)
            loss = (td.pow(2) * impt_weight).mean()
        else:
            loss = F.mse_loss(q, r)
        loss.backward()
        self.optim.step()
        self._cnt += 1
        return {'loss': loss.item()}
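
    # Learn-step sketch (illustrative; the exact sampling API may differ by
    # version): outside of the trainer one would typically do
    #
    #     batch, indice = buffer.sample(batch_size)
    #     batch = policy.process_fn(batch, buffer, indice)
    #     info = policy.learn(batch)   # returns {'loss': ...}
    #
    # The target network, when enabled, is re-synchronized every
    # ``target_update_freq`` calls to ``learn``.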