import torch
import numpy as np
from copy import deepcopy
from typing import Any, Dict, Union, Optional

from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy


class DQNPolicy(BasePolicy):
    """Implementation of Deep Q Network. arXiv:1312.5602.

    Implementation of Double Q-Learning. arXiv:1509.06461.

    Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
    implemented in the network side, not here).

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network). Default to 0.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
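
    A minimal usage sketch (the network, shapes, and hyper-parameters below are
    illustrative assumptions, not part of this module)::

        model = Net(state_shape, action_shape)  # any nn.Module: s -> logits
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
        policy = DQNPolicy(model, optim, discount_factor=0.99,
                           estimation_step=3, target_update_freq=320)
        policy.set_eps(0.1)  # epsilon-greedy exploration while collecting data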
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        self.eps = 0.0
        assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
        self._gamma = discount_factor
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        self._iter = 0
        if self._target:
            # a copy of the online network kept in eval mode; it is only updated
            # by sync_weight() every self._freq gradient steps
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
        self._rew_norm = reward_normalization

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self, mode: bool = True) -> "DQNPolicy":
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.model.train(mode)
        return self

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
        self.model_old.load_state_dict(self.model.state_dict())  # type: ignore

    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        result = self(batch, input="obs_next")
        if self._target:
            # Double DQN: actions are selected by the online network (result.act)
            # but evaluated with the target network.
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            target_q = self(batch, model="model_old", input="obs_next").logits
        else:
            target_q = result.logits
        target_q = target_q[np.arange(len(result.act)), result.act]
        return target_q

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Compute the n-step return for Q-learning targets.

        More details can be found at
        :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
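
        Informally (ignoring episode termination, reward normalization, and the
        no-target-network case), the n-step target is

        .. math::

            G_t = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1}
                + \gamma^n Q_{\text{old}}\bigl(s_{t+n}, \arg\max_a Q(s_{t+n}, a)\bigr)

        where :math:`\gamma` is ``discount_factor`` and :math:`n` is
        ``estimation_step``.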
        """
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)
        return batch

    def compute_q_value(
        self, logits: torch.Tensor, mask: Optional[np.ndarray]
    ) -> torch.Tensor:
        """Compute the q value based on the network's raw output and action mask.
        """
        if mask is not None:
            # the masked q value should be smaller than logits.min()
            min_value = logits.min() - logits.max() - 1.0
            logits = logits + to_torch_as(1 - mask, logits) * min_value
        return logits

    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        If you need to mask the action, please add a "mask" into batch.obs. For
        example, for an environment with three actions (0/1/2)::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        Epsilon-greedy exploration is not applied here; it is handled in
        :meth:`exploration_noise` with the value set by :meth:`set_eps`.

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.
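
            For instance (shapes and names are only illustrative), with a batch
            of 4 observations and 3 actions::

                result = policy(batch)   # equivalent to policy.forward(batch)
                result.logits.shape      # torch.Size([4, 3]), raw Q-values
                result.act               # np.ndarray of 4 greedy action indices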

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
        # some environments wrap the raw observation together with extra keys
        # (e.g. an action "mask") inside batch.obs
        obs_ = obs.obs if hasattr(obs, "obs") else obs
        logits, h = model(obs_, state=state, info=batch.info)
        q = self.compute_q_value(logits, getattr(obs, "mask", None))
        if not hasattr(self, "max_action_num"):
            self.max_action_num = q.shape[1]
        act = to_numpy(q.max(dim=1)[1])
        return Batch(logits=logits, act=act, state=h)

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Update the network with a mini-batch of transitions (one gradient step)."""
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        # "weight" is set by the prioritized replay buffer (1.0 otherwise)
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = to_torch_as(batch.returns.flatten(), q)
        td = r - q
        # weighted MSE between the n-step return and the predicted Q value
        loss = (td.pow(2) * weight).mean()
        batch.weight = td  # prio-buffer: feed the TD error back for priority update
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}

    def exploration_noise(
        self, act: Union[np.ndarray, Batch], batch: Batch
    ) -> Union[np.ndarray, Batch]:
        """Add epsilon-greedy exploration to the greedy actions.

        With probability ``eps`` each action is replaced by a random action,
        restricted to the valid ones when an action mask is provided.
        """
        if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
            bsz = len(act)
            rand_mask = np.random.rand(bsz) < self.eps
            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
            if hasattr(batch.obs, "mask"):
                # adding the mask lifts available actions above unavailable ones,
                # so the argmax below only picks valid actions
                q += batch.obs.mask
            rand_act = q.argmax(axis=1)
            act[rand_mask] = rand_act[rand_mask]
        return act