wizardsheng 1eb6137645
Add QR-DQN algorithm (#276)
This is the PR for QR-DQN algorithm: https://arxiv.org/abs/1710.10044

1. add QR-DQN policy in tianshou/policy/modelfree/qrdqn.py.
2. add QR-DQN net in examples/atari/atari_network.py.
3. add QR-DQN atari example in examples/atari/atari_qrdqn.py.
4. add QR-DQN statement in tianshou/policy/init.py.
5. add QR-DQN unit test in test/discrete/test_qrdqn.py.
6. add QR-DQN atari results in examples/atari/results/qrdqn/.
7. add compute_q_value in DQNPolicy and C51Policy for simplify forward function.
8. move `with torch.no_grad():` from `_target_q` to BasePolicy

By running "python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64", get best_result': '19.8 ± 0.40', in epoch 8.
2021-01-28 09:27:05 +08:00

95 lines
3.7 KiB
Python

import torch
import warnings
import numpy as np
from typing import Any, Dict
import torch.nn.functional as F
from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer
class QRDQNPolicy(DQNPolicy):
"""Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int num_quantiles: the number of quantile midpoints in the inverse
cumulative distribution function of the value, defaults to 200.
:param int estimation_step: greater than 1, the number of steps to look
ahead.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
.. seealso::
Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
explanation.
"""
def __init__(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
num_quantiles: int = 200,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(model, optim, discount_factor, estimation_step,
target_update_freq, reward_normalization, **kwargs)
assert num_quantiles > 1, "num_quantiles should be greater than 1"
self._num_quantiles = num_quantiles
tau = torch.linspace(0, 1, self._num_quantiles + 1)
self.tau_hat = torch.nn.Parameter(
((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False)
warnings.filterwarnings("ignore", message="Using a target size")
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
a = self(batch, input="obs_next").act
next_dist = self(
batch, model="model_old", input="obs_next"
).logits
else:
next_b = self(batch, input="obs_next")
a = next_b.act
next_dist = next_b.logits
next_dist = next_dist[np.arange(len(a)), a, :]
return next_dist # shape: [bsz, num_quantiles]
def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor:
"""Compute the q value based on the network's raw output logits."""
return logits.mean(2)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
curr_dist = self(batch).logits
act = batch.act
curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = (u * (
self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()
).abs()).sum(-1).mean(1)
loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
loss.backward()
self.optim.step()
self._iter += 1
return {"loss": loss.item()}