diff --git a/test/continuous/net.py b/test/continuous/net.py
index 199d250..043044a 100644
--- a/test/continuous/net.py
+++ b/test/continuous/net.py
@@ -2,6 +2,8 @@ import torch
 import numpy as np
 from torch import nn
 
+from tianshou.data import to_torch
+
 
 class Actor(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,
@@ -18,8 +20,7 @@ class Actor(nn.Module):
         self._max = max_action
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
@@ -44,8 +45,7 @@ class ActorProb(nn.Module):
         self._max = max_action
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
@@ -72,8 +72,7 @@ class Critic(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, a=None, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         if a is not None:
@@ -96,8 +95,7 @@ class RecurrentActorProb(nn.Module):
         self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
@@ -127,8 +125,7 @@ class RecurrentCritic(nn.Module):
         self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
 
     def forward(self, s, a=None):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
diff --git a/test/discrete/net.py b/test/discrete/net.py
index a4de019..1dcf783 100644
--- a/test/discrete/net.py
+++ b/test/discrete/net.py
@@ -3,6 +3,8 @@ import numpy as np
 from torch import nn
 import torch.nn.functional as F
 
+from tianshou.data import to_torch
+
 
 class Net(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
@@ -21,8 +23,7 @@ class Net(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
@@ -65,8 +66,7 @@ class Recurrent(nn.Module):
         self.fc2 = nn.Linear(128, np.prod(action_shape))
 
     def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
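All of the test networks above collapse the same two-line isinstance/torch.tensor guard into a single call to tianshou.data.to_torch. As a sanity check, here is a minimal sketch of the conversion behaviour those forward() methods rely on; to_torch_sketch is a hypothetical stand-in written for this note, not the library implementation:

import numpy as np
import torch

def to_torch_sketch(x, device='cpu', dtype=torch.float):
    # Hypothetical stand-in for tianshou.data.to_torch: accept either a numpy
    # array or a tensor and return a tensor on the requested device/dtype,
    # mirroring the old isinstance(...) branch that the diff removes.
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    return x.to(device=device, dtype=dtype)

obs = np.random.rand(4, 3)        # a batch of 4 fake observations
s = to_torch_sketch(obs)          # -> torch.float32 tensor of shape (4, 3)
assert s.dtype == torch.float32 and s.shape == (4, 3)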
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 75752d7..387cfed 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -2,7 +2,7 @@ import torch
 import numpy as np
 from torch import nn
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Union, Optional, Callable
 
 from tianshou.data import Batch, ReplayBuffer
 
@@ -113,6 +113,8 @@ class BasePolicy(ABC, nn.Module):
             to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage
             Estimation, should be in [0, 1], defaults to 0.95.
+
+        :return: a Batch. The result will be stored in batch.returns.
         """
         if v_s_ is None:
             v_s_ = np.zeros_like(batch.rew)
@@ -120,12 +122,61 @@
             if not isinstance(v_s_, np.ndarray):
                 v_s_ = np.array(v_s_, np.float)
             v_s_ = v_s_.reshape(batch.rew.shape)
-        batch.returns = np.roll(v_s_, 1, axis=0)
+        returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
-        delta = batch.rew + v_s_ * m - batch.returns
+        delta = batch.rew + v_s_ * m - returns
         m *= gae_lambda
         gae = 0.
         for i in range(len(batch.rew) - 1, -1, -1):
             gae = delta[i] + m[i] * gae
-            batch.returns[i] += gae
+            returns[i] += gae
+        batch.returns = returns
+        return batch
+
+    @staticmethod
+    def compute_nstep_return(
+        batch: Batch,
+        buffer: ReplayBuffer,
+        indice: np.ndarray,
+        target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray],
+        gamma: float = 0.99,
+        n_step: int = 1
+    ) -> np.ndarray:
+        r"""Compute n-step return for Q-learning targets:
+
+        .. math::
+            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
+            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})
+
+        , where :math:`\gamma` is the discount factor,
+        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
+        :math:`t`.
+
+        :param batch: a data batch, which is equal to buffer[indice].
+        :type batch: :class:`~tianshou.data.Batch`
+        :param buffer: a data buffer which contains several full-episode data
+            chronologically.
+        :type buffer: :class:`~tianshou.data.ReplayBuffer`
+        :param indice: sampled timestep.
+        :type indice: numpy.ndarray
+        :param float gamma: the discount factor, should be in [0, 1], defaults
+            to 0.99.
+        :param int n_step: the number of estimation step, should be an int
+            greater than 0, defaults to 1.
+
+        :return: a Batch. The result will be stored in batch.returns.
+ """ + returns = np.zeros_like(indice) + gammas = np.zeros_like(indice) + n_step + done, rew, buf_len = buffer.done, buffer.rew, len(buffer) + for n in range(n_step - 1, -1, -1): + now = (indice + n) % buf_len + gammas[done[now] > 0] = n + returns[done[now] > 0] = 0 + returns = rew[now] + gamma * returns + terminal = (indice + n_step - 1) % buf_len + target_q = target_q_fn(buffer, terminal) + target_q[gammas != n_step] = 0 + returns += (gamma ** gammas) * target_q + batch.returns = returns return batch diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 71c9866..f535079 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -68,6 +68,21 @@ class DQNPolicy(BasePolicy): """Synchronize the weight for the target network.""" self.model_old.load_state_dict(self.model.state_dict()) + def _target_q(self, buffer: ReplayBuffer, + indice: np.ndarray) -> np.ndarray: + data = buffer[indice] + if self._target: + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + a = self(data, input='obs_next', eps=0).act + target_q = self( + data, model='model_old', input='obs_next').logits + target_q = to_numpy(target_q) + target_q = target_q[np.arange(len(a)), a] + else: + target_q = self(data, input='obs_next').logits + target_q = to_numpy(target_q).max(axis=1) + return target_q + def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: r"""Compute the n-step return for Q-learning targets: @@ -82,46 +97,11 @@ class DQNPolicy(BasePolicy): :math:`t`. If there is no target network, the :math:`Q_{old}` is equal to :math:`Q_{new}`. """ - returns = np.zeros_like(indice) - gammas = np.zeros_like(indice) + self._n_step - for n in range(self._n_step - 1, -1, -1): - now = (indice + n) % len(buffer) - gammas[buffer.done[now] > 0] = n - returns[buffer.done[now] > 0] = 0 - returns = buffer.rew[now] + self._gamma * returns - terminal = (indice + self._n_step - 1) % len(buffer) - terminal_data = buffer[terminal] - if self._target: - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - a = self(terminal_data, input='obs_next', eps=0).act - target_q = self( - terminal_data, model='model_old', input='obs_next').logits - if isinstance(target_q, torch.Tensor): - target_q = to_numpy(target_q) - target_q = target_q[np.arange(len(a)), a] - else: - target_q = self(terminal_data, input='obs_next').logits - if isinstance(target_q, torch.Tensor): - target_q = to_numpy(target_q) - target_q = target_q.max(axis=1) - target_q[gammas != self._n_step] = 0 - returns += (self._gamma ** gammas) * target_q - batch.returns = returns + batch = self.compute_nstep_return( + batch, buffer, indice, self._target_q, self._gamma, self._n_step) if isinstance(buffer, PrioritizedReplayBuffer): - q = self(batch).logits - q = q[np.arange(len(q)), batch.act] - r = batch.returns - if isinstance(r, np.ndarray): - r = to_torch(r, device=q.device, dtype=q.dtype) - td = r - q - buffer.update_weight(indice, to_numpy(td)) - impt_weight = to_torch(batch.impt_weight, - device=q.device, dtype=torch.float) - loss = (td.pow(2) * impt_weight).mean() - if not hasattr(batch, 'loss'): - batch.loss = loss - else: - batch.loss += loss + batch.update_weight = buffer.update_weight + batch.indice = indice return batch def forward(self, batch: Batch, @@ -162,14 +142,16 @@ class DQNPolicy(BasePolicy): if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() - if hasattr(batch, 'loss'): - loss = batch.loss + q = self(batch).logits + q = q[np.arange(len(q)), batch.act] + 
+        r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
+        if hasattr(batch, 'update_weight'):
+            td = r - q
+            batch.update_weight(batch.indice, to_numpy(td))
+            impt_weight = to_torch(batch.impt_weight,
+                                   device=q.device, dtype=torch.float)
+            loss = (td.pow(2) * impt_weight).mean()
         else:
-            q = self(batch).logits
-            q = q[np.arange(len(q)), batch.act]
-            r = batch.returns
-            if isinstance(r, np.ndarray):
-                r = to_torch(r, device=q.device, dtype=q.dtype)
             loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py
index d994762..7dfc0b1 100644
--- a/tianshou/utils/moving_average.py
+++ b/tianshou/utils/moving_average.py
@@ -2,6 +2,8 @@ import torch
 import numpy as np
 from typing import Union
 
+from tianshou.data import to_numpy
+
 
 class MovAvg(object):
     """Class for moving average. It will automatically exclude the infinity and
@@ -32,7 +34,7 @@ class MovAvg(object):
         only one element, a python scalar, or a list of python scalar.
         """
         if isinstance(x, torch.Tensor):
-            x = x.item()
+            x = to_numpy(x.flatten())
         if isinstance(x, list) or isinstance(x, np.ndarray):
             for _ in x:
                 if _ not in self.banned:
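The new BasePolicy.compute_nstep_return and DQNPolicy._target_q split the old process_fn into a reusable return computation plus a policy-specific bootstrap. The backward recursion can be checked in isolation; the sketch below re-implements the same loop on plain numpy arrays (no ReplayBuffer, no policy), with target_q standing for already-evaluated Q_target values at the bootstrap indices. All names are local to this example.

import numpy as np

def nstep_return_sketch(rew, done, target_q, indice, gamma=0.99, n_step=3):
    # Same recursion as compute_nstep_return: accumulate up to n_step
    # discounted rewards, truncate at the first done flag inside the window,
    # and only add the bootstrap term when no terminal was crossed.
    buf_len = len(rew)
    returns = np.zeros_like(indice, dtype=float)
    gammas = np.zeros_like(indice) + n_step
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n       # remember where the episode ended
        returns[done[now] > 0] = 0.     # drop rewards beyond the terminal
        returns = rew[now] + gamma * returns
    target_q = np.where(gammas == n_step, target_q, 0.)
    return returns + (gamma ** gammas) * target_q

rew = np.array([1., 1., 1., 1.])
done = np.array([0, 0, 0, 1])
q_bootstrap = np.array([10., 10.])      # Q_target(s_{t+n}) per sampled index
print(nstep_return_sketch(rew, done, q_bootstrap, np.array([0, 2]), gamma=0.9))
# index 0: 1 + 0.9 + 0.81 + 0.9**3 * 10 = 10.0   (full 3-step bootstrap)
# index 2: 1 + 0.9                       = 1.9   (episode ends, no bootstrap)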
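The MovAvg change (x.item() -> to_numpy(x.flatten())) lets add() accept tensors with more than one element instead of raising inside .item(). A small usage sketch, assuming the MovAvg interface visible in the hunk above (add() iterates over array inputs and get() averages the cached values):

import torch
from tianshou.utils import MovAvg

stat = MovAvg()
stat.add(torch.tensor([1., 2., 3.]))   # old code: x.item() raises for >1 element
print(stat.get())                      # -> 2.0, mean of the three cached values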