n+e 09692c84fe
fix numpy>=1.20 typing check (#323)
Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
2021-03-30 16:06:03 +08:00

183 lines
6.6 KiB
Python

import torch
import numpy as np
from copy import deepcopy
from typing import Any, Dict, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class DQNPolicy(BasePolicy):
    """Implementation of Deep Q Network. arXiv:1312.5602.

    Implementation of Double Q-Learning. arXiv:1509.06461.

    Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
    implemented in the network side, not here).

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits). It is called as
        ``model(obs, state=..., info=...)`` and must return a
        ``(logits, hidden_state)`` pair.
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network). Default to 0.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.

    Any extra keyword arguments are forwarded unchanged to
    :class:`~tianshou.policy.BasePolicy`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        # Exploration rate consumed by ``exploration_noise``; change it with
        # ``set_eps``. Starts at 0, i.e. no random actions.
        self.eps = 0.0
        assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
        self._gamma = discount_factor
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step
        # A target network is kept only for a positive update frequency.
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        # Number of completed ``learn`` calls; drives the target-network sync.
        self._iter = 0
        if self._target:
            # Independent copy of the online network. ``train()`` below never
            # touches it, so it stays in eval mode.
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
        self._rew_norm = reward_normalization

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self, mode: bool = True) -> "DQNPolicy":
        """Set the module in training mode, except for the target network.

        Only ``self.model`` is switched; ``self.model_old`` (if any) keeps the
        eval mode it was given in ``__init__``. Returns ``self``.
        """
        self.training = mode
        self.model.train(mode)
        return self

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network.

        Copies the online network's ``state_dict`` into ``self.model_old``.
        ``model_old`` only exists when ``target_update_freq > 0``.
        """
        self.model_old.load_state_dict(self.model.state_dict())  # type: ignore

    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
        """Return the bootstrap Q value used as the n-step target.

        The greedy action is picked by the online network (default
        ``model="model"``) on ``obs_next``, honouring an action mask if one is
        present. When a target network exists, that action is evaluated with
        ``self.model_old`` instead (Double Q-Learning). Otherwise the online
        network's own logits are used. This is the target function handed to
        ``compute_nstep_return`` in :meth:`process_fn`.

        :param buffer: the replay buffer to read transitions from.
        :param indice: indices of the transitions to evaluate.
        :return: one Q value per index, taken from the raw (unmasked) logits.
        """
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        result = self(batch, input="obs_next")
        if self._target:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            target_q = self(batch, model="model_old", input="obs_next").logits
        else:
            target_q = result.logits
        return target_q[np.arange(len(result.act)), result.act]

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Compute the n-step return for Q-learning targets.

        Delegates to ``compute_nstep_return`` with :meth:`_target_q` as the
        bootstrap function and this policy's discount factor, look-ahead
        length and reward-normalization flag.

        More details can be found at
        :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
        """
        return self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)

    def compute_q_value(
        self, logits: torch.Tensor, mask: Optional[np.ndarray]
    ) -> torch.Tensor:
        """Compute the q value based on the network's raw output and action mask.

        :param logits: the network's raw output.
        :param mask: ``None`` (no masking), or an array where truthy/1 entries
            mark available actions and 0/False entries unavailable ones.
            Presumably it has the same shape as ``logits``, e.g.
            (batch, action_num); TODO confirm for other layouts.
        :return: ``logits`` unchanged when ``mask`` is ``None``. Otherwise,
            every unavailable entry is shifted strictly below the smallest
            logit, so an argmax never picks it.
        """
        if mask is not None:
            # min_value <= -1, so logit + min_value <= max + (min - max - 1)
            # = min - 1, i.e. every masked value ends up below logits.min().
            min_value = logits.min() - logits.max() - 1.0
            logits = logits + to_torch_as(1 - mask, logits) * min_value
        return logits

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute the greedy action over the given batch data.

        If you need to mask the action, please add a "mask" into batch.obs, for
        example, if we have an environment that has "0/1/2" three actions:
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        This method is purely greedy and does not read ``self.eps``.
        Epsilon-greedy exploration is applied separately by
        :meth:`exploration_noise`.

        :param state: the hidden state passed through to the model.
        :param str model: name of the network attribute to use, e.g.
            ``"model"`` (online) or ``"model_old"`` (target). Default to
            ``"model"``.
        :param str input: the batch key holding the observation, e.g.
            ``"obs"`` or ``"obs_next"``. Default to ``"obs"``.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action (argmax of the masked q value, as a numpy \
                array).
            * ``logits`` the network's raw output (the mask is NOT applied).
            * ``state`` the hidden state.

        On the first call this also records ``self.max_action_num`` (the size
        of the q value's second dimension), which :meth:`exploration_noise`
        relies on.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        # Masked environments wrap the real observation as ``obs.obs``.
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        logits, h = model(obs_, state=state, info=batch.info)
        q = self.compute_q_value(logits, getattr(obs, "mask", None))
        if not hasattr(self, "max_action_num"):
            self.max_action_num = q.shape[1]
        act = to_numpy(q.max(dim=1)[1])
        return Batch(logits=logits, act=act, state=h)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Take one gradient step on the (optionally weighted) squared TD error.

        The target network is re-synced whenever the number of completed
        ``learn`` calls is a multiple of ``target_update_freq``, including the
        very first call. The loss is ``mean(weight * (returns - Q(s, a))**2)``,
        where ``weight`` is popped from the batch (default 1.0). Assumes
        ``batch.returns`` was filled by :meth:`process_fn`.

        After the step, ``batch.weight`` is overwritten with the TD error.
        Presumably a prioritized buffer consumes it to update priorities;
        verify against the caller.

        :return: ``{"loss": <float>}``.
        """
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns.flatten(), q)
        td = r - q
        loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}

    def exploration_noise(
        self, act: Union[np.ndarray, Batch], batch: Batch
    ) -> Union[np.ndarray, Batch]:
        """Apply epsilon-greedy noise to a numpy action array.

        Each entry is independently replaced with probability ``self.eps`` by
        a random action. The random action is uniform over the actions whose
        ``batch.obs.mask`` entry is set, or over all actions when there is no
        mask. A ``Batch`` action, or an ``eps`` that ``np.isclose`` treats as
        0, is returned untouched.

        Note: ``act`` is modified in place and also returned. This requires
        ``self.max_action_num``, which is only set after :meth:`forward` has
        run at least once.
        """
        if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
            bsz = len(act)
            rand_mask = np.random.rand(bsz) < self.eps
            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
            if hasattr(batch.obs, "mask"):
                # Available actions land in [1, 2) and unavailable ones stay
                # in [0, 1), so the argmax is a uniform pick among the
                # available actions.
                q += batch.obs.mask
            rand_act = q.argmax(axis=1)
            act[rand_mask] = rand_act[rand_mask]
        return act