2020-03-14 21:48:31 +08:00
|
|
|
import torch
|
2020-03-15 17:41:00 +08:00
|
|
|
import numpy as np
|
2020-03-14 21:48:31 +08:00
|
|
|
from copy import deepcopy
|
2020-03-18 21:45:41 +08:00
|
|
|
import torch.nn.functional as F
|
2020-05-12 11:31:47 +08:00
|
|
|
from typing import Dict, Union, Optional
|
2020-03-14 21:48:31 +08:00
|
|
|
|
|
|
|
from tianshou.policy import BasePolicy
|
2020-05-29 14:45:21 +02:00
|
|
|
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
|
2020-06-03 13:59:47 +08:00
|
|
|
to_torch_as, to_numpy
|
2020-03-14 21:48:31 +08:00
|
|
|
|
|
|
|
|
2020-03-18 21:45:41 +08:00
|
|
|
class DQNPolicy(BasePolicy):
    """Implementation of Deep Q Network. arXiv:1312.5602

    Implementation of Double Q-Learning. arXiv:1509.06461

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: at least 1, the number of steps to look
        ahead.
    :param int target_update_freq: the target network update frequency (``0``
        or ``None`` if you do not use the target network).
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to ``False``.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 optim: torch.optim.Optimizer,
                 discount_factor: float = 0.99,
                 estimation_step: int = 1,
                 target_update_freq: Optional[int] = 0,
                 reward_normalization: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        # exploration rate used by ``forward`` when no explicit eps is given
        self.eps = 0
        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
        self._gamma = discount_factor
        assert estimation_step > 0, 'estimation_step should greater than 0'
        self._n_step = estimation_step
        # The annotation allows ``None``; treat it like ``0`` (no target
        # network) instead of crashing on ``None > 0``.
        if target_update_freq is None:
            target_update_freq = 0
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        # number of ``learn`` calls so far; drives the periodic target sync
        self._cnt = 0
        if self._target:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
        self._rew_norm = reward_normalization

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self, mode=True) -> torch.nn.Module:
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.model.train(mode)
        return self

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
        self.model_old.load_state_dict(self.model.state_dict())

    def _target_q(self, buffer: ReplayBuffer,
                  indice: np.ndarray) -> torch.Tensor:
        """Estimate the bootstrap value Q(s_{t+n}, a*) for the given indices.

        ``a*`` is the greedy action of the online network (respecting an
        action mask if ``obs_next`` carries one). With a target network the
        value is read from ``model_old`` (Double Q-learning); otherwise it is
        read from the online network itself.
        """
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        with torch.no_grad():
            # eps=0: the target must use the greedy (mask-aware) action
            online = self(batch, input='obs_next', eps=0)
            a = online.act
            if self._target:
                # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
                target_q = self(batch, model='model_old',
                                input='obs_next', eps=0).logits
            else:
                # Index by the mask-aware greedy action rather than taking a
                # raw max over logits, so masked actions are excluded on
                # every device.
                target_q = online.logits
            target_q = target_q[np.arange(len(a)), a]
        return target_q

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        """Compute the n-step return for Q-learning targets. More details can
        be found at :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
        """
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)
        if isinstance(buffer, PrioritizedReplayBuffer):
            # ``learn`` uses these to push new TD errors back into the buffer
            batch.update_weight = buffer.update_weight
            batch.indice = indice
        return batch

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                model: str = 'model',
                input: str = 'obs',
                eps: Optional[float] = None,
                **kwargs) -> Batch:
        """Compute action over the given batch data. If you need to mask the
        action, please add a "mask" into batch.obs, for example, if we have an
        environment that has "0/1/2" three actions:
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        :param str model: name of the attribute holding the network to run
            (``'model'`` or ``'model_old'``).
        :param str input: name of the batch key to read observations from
            (e.g. ``'obs'`` or ``'obs_next'``).
        :param float eps: in [0, 1], for epsilon-greedy exploration method.
            Defaults to ``self.eps`` when ``None``.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output (never modified by the
              action mask; the mask only affects ``act``).
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = getattr(batch, input)
        obs_ = obs.obs if hasattr(obs, 'obs') else obs
        q, h = model(obs_, state=state, info=batch.info)
        act = to_numpy(q.max(dim=1)[1])
        has_mask = hasattr(obs, 'mask')
        if has_mask:
            # Some actions are masked and cannot be selected. Work on a
            # private copy: on CPU ``to_numpy`` returns an array sharing
            # storage with ``q``, so writing -inf into it would silently
            # alter the returned logits (and the autograd graph) on CPU only.
            masked_q = to_numpy(q).copy()
            masked_q[~obs.mask] = -np.inf
            act = masked_q.argmax(axis=1)
        # add eps to act
        if eps is None:
            eps = self.eps
        if not np.isclose(eps, 0):
            # Each sample independently explores with probability eps; an
            # exploring sample takes the argmax of uniform random scores,
            # i.e. a uniformly random (unmasked) action.
            explore = np.random.rand(len(q)) < eps
            n_explore = int(explore.sum())
            if n_explore > 0:
                rand_q = np.random.rand(n_explore, *q.shape[1:])
                if has_mask:
                    rand_q[~obs.mask[explore]] = -np.inf
                act[explore] = rand_q.argmax(axis=1)
        return Batch(logits=q, act=act, state=h)

    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        """Run one gradient step on the squared TD error.

        Syncs the target network every ``target_update_freq`` calls (when a
        target network is used). With a prioritized buffer the per-sample TD
        errors are written back via ``batch.update_weight`` and the loss is
        weighted by ``batch.impt_weight``.

        :return: ``{'loss': <float>}``.
        """
        if self._target and self._cnt % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        q = self(batch, eps=0.).logits
        # Q-value of the action actually taken in each transition
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns, q).flatten()
        if hasattr(batch, 'update_weight'):
            td = r - q
            batch.update_weight(batch.indice, to_numpy(td))
            impt_weight = to_torch_as(batch.impt_weight, q)
            loss = (td.pow(2) * impt_weight).mean()
        else:
            loss = F.mse_loss(q, r)
        loss.backward()
        self.optim.step()
        self._cnt += 1
        return {'loss': loss.item()}
|