Closes #914

Additional changes:

- Drop support for Python versions below 3.11
- Remove the 3rd-party and throughput tests, which simplifies the install and test pipeline
- Remove gym compatibility and shimmy
- Format with 3.11 conventions; in particular, add `zip(..., strict=True/False)` where possible

Since the additional tests and gym were complicating the CI pipeline (flaky and dist-dependent), it didn't make sense to fix the current tests in this PR only to delete them in the next one. So this PR changes the build and removes these tests at the same time.
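For reference, this is the kind of change the `zip(..., strict=True/False)` item refers to (an illustrative snippet, not taken from this diff):

```python
# Before: zip() silently truncates to the shortest iterable.
for ob, act in zip(observations, actions):
    ...

# After: strict=True raises ValueError on a length mismatch, while
# strict=False documents that truncation is intentional.
for ob, act in zip(observations, actions, strict=True):
    ...
```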
from copy import deepcopy
from typing import Any, cast

import numpy as np
import torch

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
    BatchWithReturnsProtocol,
    ModelOutputBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy


class DQNPolicy(BasePolicy):
    """Implementation of Deep Q Network. arXiv:1312.5602.

    Implementation of Double Q-Learning. arXiv:1509.06461.

    Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
    implemented on the network side, not here).

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network). Default to 0.
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param bool is_double: use double dqn. Default to True.
    :param bool clip_loss_grad: clip the gradient of the loss in accordance
        with nature14236; this amounts to using the Huber loss instead of
        the MSE loss. Default to False.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning
        rate in the optimizer in each policy.update(). Default to None
        (no lr_scheduler).

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
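
    A minimal usage sketch (the network and hyperparameters are illustrative;
    any module mapping observations to action logits works)::

        net = Net(state_shape, action_shape)
        optim = torch.optim.Adam(net.parameters(), lr=1e-3)
        policy = DQNPolicy(net, optim, discount_factor=0.99, target_update_freq=320)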
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        is_double: bool = True,
        clip_loss_grad: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.optim = optim
        self.eps = 0.0
        assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
        self._gamma = discount_factor
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step
        self._target = target_update_freq > 0
        self._freq = target_update_freq
        self._iter = 0
        if self._target:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()
        self._rew_norm = reward_normalization
        self._is_double = is_double
        self._clip_loss_grad = clip_loss_grad

    def set_eps(self, eps: float) -> None:
        """Set the eps for epsilon-greedy exploration."""
        self.eps = eps

    def train(self, mode: bool = True) -> "DQNPolicy":
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.model.train(mode)
        return self

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
        self.model_old.load_state_dict(self.model.state_dict())

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        result = self(batch, input="obs_next")
        if self._target:
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            target_q = self(batch, model="model_old", input="obs_next").logits
        else:
            target_q = result.logits
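        # Double DQN (arXiv:1509.06461): evaluate the online network's greedy
        # action under the (older) target values to reduce overestimation bias.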
        if self._is_double:
            return target_q[np.arange(len(result.act)), result.act]
        # vanilla (Nature) DQN: taking the max over the same network's
        # estimates tends to overestimate the target value
        return target_q.max(dim=1)[0]

    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> BatchWithReturnsProtocol:
        """Compute the n-step return for Q-learning targets.

        More details can be found at
        :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
        """
        return self.compute_nstep_return(
            batch,
            buffer,
            indices,
            self._target_q,
            self._gamma,
            self._n_step,
            self._rew_norm,
        )

    def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor:
        """Compute the q value based on the network's raw output and action mask."""
        if mask is not None:
            # push masked q values below logits.min() so that argmax
            # can never select an unavailable action
            min_value = logits.min() - logits.max() - 1.0
            logits = logits + to_torch_as(1 - mask, logits) * min_value
        return logits

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: dict | BatchProtocol | np.ndarray | None = None,
        model: str = "model",
        input: str = "obs",
        **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Compute action over the given batch data.

        If you need to mask actions, add a "mask" entry to batch.obs. For
        example, in an environment with three actions (0/1/2):
        ::

            batch == Batch(
                obs=Batch(
                    obs="original obs, with batch_size=1 for demonstration",
                    mask=np.array([[False, True, False]]),
                    # action 1 is available
                    # action 0 and 2 are unavailable
                ),
                ...
            )

        :return: A :class:`~tianshou.data.Batch` which has 3 keys:

            * ``act`` the action.
            * ``logits`` the network's raw output.
            * ``state`` the hidden state.

        .. seealso::

            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
            more detailed explanation.
        """
        model = getattr(self, model)
        obs = batch[input]
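        # when obs carries extra fields (e.g. an action mask), the raw
        # observation is nested under obs.obs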
        obs_next = obs.obs if hasattr(obs, "obs") else obs
        logits, hidden = model(obs_next, state=state, info=batch.info)
        q = self.compute_q_value(logits, getattr(obs, "mask", None))
        if not hasattr(self, "max_action_num"):
            self.max_action_num = q.shape[1]
        act = to_numpy(q.max(dim=1)[1])
        result = Batch(logits=logits, act=act, state=hidden)
        return cast(ModelOutputBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        returns = to_torch_as(batch.returns.flatten(), q)
        td_error = returns - q
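
        # nature14236 clips the TD error to [-1, 1]; the Huber loss is the
        # smooth equivalent of that clipping.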
        if self._clip_loss_grad:
            y = q.reshape(-1, 1)
            t = returns.reshape(-1, 1)
            loss = torch.nn.functional.huber_loss(y, t, reduction="mean")
        else:
            loss = (td_error.pow(2) * weight).mean()

        batch.weight = td_error  # td_error as new priority for prioritized replay
        loss.backward()
        self.optim.step()
        self._iter += 1
        return {"loss": loss.item()}

    def exploration_noise(
        self,
        act: np.ndarray | BatchProtocol,
        batch: RolloutBatchProtocol,
    ) -> np.ndarray | BatchProtocol:
        if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
            bsz = len(act)
            rand_mask = np.random.rand(bsz) < self.eps
            q = np.random.rand(bsz, self.max_action_num)  # [0, 1]
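            # adding the 0/1 mask lifts every available action above the
            # unavailable ones, so argmax can only pick valid actions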
            if hasattr(batch.obs, "mask"):
                q += batch.obs.mask
            rand_act = q.argmax(axis=1)
            act[rand_mask] = rand_act[rand_mask]
        return act