compute_nstep_returns (item 2 of #51)

Trinkle23897 2020-06-02 22:29:50 +08:00
parent f818a2467b
commit ff81a18f42
5 changed files with 97 additions and 65 deletions

View File

@@ -2,6 +2,8 @@ import torch
import numpy as np
from torch import nn
from tianshou.data import to_torch
class Actor(nn.Module):
def __init__(self, layer_num, state_shape, action_shape,
@@ -18,8 +20,7 @@ class Actor(nn.Module):
self._max = max_action
def forward(self, s, **kwargs):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
s = to_torch(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)
@@ -44,8 +45,7 @@ class ActorProb(nn.Module):
self._max = max_action
def forward(self, s, **kwargs):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
s = to_torch(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)
@@ -72,8 +72,7 @@ class Critic(nn.Module):
self.model = nn.Sequential(*self.model)
def forward(self, s, a=None, **kwargs):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
s = to_torch(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
if a is not None:
@@ -96,8 +95,7 @@ class RecurrentActorProb(nn.Module):
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
def forward(self, s, **kwargs):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
s = to_torch(s, device=self.device, dtype=torch.float)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape during training has one more dimension (the sequence length) than during evaluation.
@@ -127,8 +125,7 @@ class RecurrentCritic(nn.Module):
self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
def forward(self, s, a=None):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
s = to_torch(s, device=self.device, dtype=torch.float)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape during training has one more dimension (the sequence length) than during evaluation.
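Each forward pass above now funnels its input through tianshou.data.to_torch instead of an explicit isinstance check. A rough sketch of the behaviour these call sites rely on, for array-like input only (the real helper covers more input types than this, so treat it as an assumption-laden sketch rather than the library's implementation):

import numpy as np
import torch

def to_torch_sketch(x, device='cpu', dtype=torch.float):
    # Illustrative stand-in for tianshou.data.to_torch on array-like input:
    # tensors are moved to the requested device/dtype, anything else is
    # converted through numpy first.
    if isinstance(x, torch.Tensor):
        return x.to(device=device, dtype=dtype)
    return torch.as_tensor(np.asarray(x), dtype=dtype, device=device)

s = to_torch_sketch(np.ones((4, 3)))    # ndarray -> FloatTensor on cpu
s = to_torch_sketch(torch.zeros(4, 3))  # Tensor  -> already cpu/float, unchanged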

View File

@@ -3,6 +3,8 @@ import numpy as np
from torch import nn
import torch.nn.functional as F
from tianshou.data import to_torch
class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
@@ -21,8 +23,7 @@ class Net(nn.Module):
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
s = to_torch(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)
@@ -65,8 +66,7 @@ class Recurrent(nn.Module):
self.fc2 = nn.Linear(128, np.prod(action_shape))
def forward(self, s, state=None, info={}):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
s = to_torch(s, device=self.device, dtype=torch.float)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape during training has one more dimension (the sequence length) than during evaluation.
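The comment above describes the shape convention used by the recurrent modules: during training s arrives stacked as [bsz, len, dim], while during collection/evaluation it is a single step of shape [bsz, dim]. A small illustration of that convention (the reshaping itself happens inside the module's forward pass, so this sketch only mirrors the idea):

import torch

bsz, length, dim = 8, 5, 4
s_train = torch.zeros(bsz, length, dim)  # sequences sampled from the replay buffer
s_eval = torch.zeros(bsz, dim)           # one step coming from the collector

# one way to give both inputs the rank an nn.LSTM expects (illustrative only)
if s_eval.dim() == 2:
    s_eval = s_eval.unsqueeze(1)         # -> [bsz, 1, dim]
assert s_eval.shape == (bsz, 1, dim)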

View File

@@ -2,7 +2,7 @@ import torch
import numpy as np
from torch import nn
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional
from typing import Dict, List, Union, Optional, Callable
from tianshou.data import Batch, ReplayBuffer
@@ -113,6 +115,8 @@ class BasePolicy(ABC, nn.Module):
to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage
Estimation, should be in [0, 1], defaults to 0.95.
:return: a Batch. The result will be stored in batch.returns.
"""
if v_s_ is None:
v_s_ = np.zeros_like(batch.rew)
@@ -120,12 +122,61 @@ class BasePolicy(ABC, nn.Module):
if not isinstance(v_s_, np.ndarray):
v_s_ = np.array(v_s_, np.float)
v_s_ = v_s_.reshape(batch.rew.shape)
batch.returns = np.roll(v_s_, 1, axis=0)
returns = np.roll(v_s_, 1, axis=0)
m = (1. - batch.done) * gamma
delta = batch.rew + v_s_ * m - batch.returns
delta = batch.rew + v_s_ * m - returns
m *= gae_lambda
gae = 0.
for i in range(len(batch.rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae
batch.returns[i] += gae
returns[i] += gae
batch.returns = returns
return batch
@staticmethod
def compute_nstep_return(
batch: Batch,
buffer: ReplayBuffer,
indice: np.ndarray,
target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray],
gamma: float = 0.99,
n_step: int = 1
) -> Batch:
r"""Compute n-step return for Q-learning targets:
.. math::
G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
\gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})
where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, and
:math:`d_t` is the done flag of step :math:`t`.
:param batch: a data batch, which is equal to buffer[indice].
:type batch: :class:`~tianshou.data.Batch`
:param buffer: a data buffer which contains several episodes of data in
chronological order.
:type buffer: :class:`~tianshou.data.ReplayBuffer`
:param indice: the sampled timesteps.
:type indice: numpy.ndarray
:param target_q_fn: a function that takes the buffer and the indices of
step t + n - 1 and returns the target Q values used for bootstrapping.
:type target_q_fn: Callable[[ReplayBuffer, numpy.ndarray], numpy.ndarray]
:param float gamma: the discount factor, should be in [0, 1], defaults
to 0.99.
:param int n_step: the number of estimation steps, should be an int
greater than 0, defaults to 1.
:return: a Batch. The result will be stored in batch.returns.
"""
returns = np.zeros_like(indice)
gammas = np.zeros_like(indice) + n_step
done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
for n in range(n_step - 1, -1, -1):
now = (indice + n) % buf_len
gammas[done[now] > 0] = n
returns[done[now] > 0] = 0
returns = rew[now] + gamma * returns
terminal = (indice + n_step - 1) % buf_len
target_q = target_q_fn(buffer, terminal)
target_q[gammas != n_step] = 0
returns += (gamma ** gammas) * target_q
batch.returns = returns
return batch
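The new compute_nstep_return staticmethod generalizes the n-step bookkeeping that previously lived inside DQNPolicy.process_fn. To make the backward loop concrete, here is a self-contained numpy sketch that reproduces the same computation on a toy buffer; the rewards, done flags and constant target Q values below are made up purely for illustration:

import numpy as np

rew = np.array([1., 1., 1., 1., 1., 1.])    # toy circular buffer of 6 transitions
done = np.array([0., 0., 0., 1., 0., 0.])   # the episode terminates at index 3
indice = np.array([0, 1, 2])                # sampled timesteps t
gamma, n_step = 0.9, 3
target_q = np.full(len(indice), 10.)        # pretend Q_target(s_{t+n}) == 10

returns = np.zeros(len(indice))
gammas = np.full(len(indice), n_step)       # exponent of the bootstrap term
buf_len = len(rew)
for n in range(n_step - 1, -1, -1):
    now = (indice + n) % buf_len
    # an episode boundary inside the window truncates the return: drop the
    # rewards accumulated after the terminal step and disable bootstrapping
    gammas[done[now] > 0] = n
    returns[done[now] > 0] = 0.
    returns = rew[now] + gamma * returns
terminal = (indice + n_step - 1) % buf_len  # where target_q_fn would be queried
target_q[gammas != n_step] = 0.             # no bootstrap across a terminal state
returns = returns + gamma ** gammas * target_q
print(returns)  # approx. [10.  2.71  1.9]

For t = 0 this is r_0 + γ r_1 + γ² r_2 + γ³ · 10 = 10.0, while the samples at t = 1 and t = 2 stop at the terminal transition (index 3) and receive no bootstrap term.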

View File

@@ -68,6 +68,21 @@ class DQNPolicy(BasePolicy):
"""Synchronize the weight for the target network."""
self.model_old.load_state_dict(self.model.state_dict())
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> np.ndarray:
data = buffer[indice]
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(data, input='obs_next', eps=0).act
target_q = self(
data, model='model_old', input='obs_next').logits
target_q = to_numpy(target_q)
target_q = target_q[np.arange(len(a)), a]
else:
target_q = self(data, input='obs_next').logits
target_q = to_numpy(target_q).max(axis=1)
return target_q
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
r"""Compute the n-step return for Q-learning targets:
@@ -82,46 +97,11 @@ class DQNPolicy(BasePolicy):
:math:`t`. If there is no target network, the :math:`Q_{old}` is equal
to :math:`Q_{new}`.
"""
returns = np.zeros_like(indice)
gammas = np.zeros_like(indice) + self._n_step
for n in range(self._n_step - 1, -1, -1):
now = (indice + n) % len(buffer)
gammas[buffer.done[now] > 0] = n
returns[buffer.done[now] > 0] = 0
returns = buffer.rew[now] + self._gamma * returns
terminal = (indice + self._n_step - 1) % len(buffer)
terminal_data = buffer[terminal]
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(terminal_data, input='obs_next', eps=0).act
target_q = self(
terminal_data, model='model_old', input='obs_next').logits
if isinstance(target_q, torch.Tensor):
target_q = to_numpy(target_q)
target_q = target_q[np.arange(len(a)), a]
else:
target_q = self(terminal_data, input='obs_next').logits
if isinstance(target_q, torch.Tensor):
target_q = to_numpy(target_q)
target_q = target_q.max(axis=1)
target_q[gammas != self._n_step] = 0
returns += (self._gamma ** gammas) * target_q
batch.returns = returns
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q, self._gamma, self._n_step)
if isinstance(buffer, PrioritizedReplayBuffer):
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = batch.returns
if isinstance(r, np.ndarray):
r = to_torch(r, device=q.device, dtype=q.dtype)
td = r - q
buffer.update_weight(indice, to_numpy(td))
impt_weight = to_torch(batch.impt_weight,
device=q.device, dtype=torch.float)
loss = (td.pow(2) * impt_weight).mean()
if not hasattr(batch, 'loss'):
batch.loss = loss
else:
batch.loss += loss
batch.update_weight = buffer.update_weight
batch.indice = indice
return batch
def forward(self, batch: Batch,
@@ -162,14 +142,16 @@ class DQNPolicy(BasePolicy):
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
if hasattr(batch, 'loss'):
loss = batch.loss
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
if hasattr(batch, 'update_weight'):
td = r - q
batch.update_weight(batch.indice, to_numpy(td))
impt_weight = to_torch(batch.impt_weight,
device=q.device, dtype=torch.float)
loss = (td.pow(2) * impt_weight).mean()
else:
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
r = batch.returns
if isinstance(r, np.ndarray):
r = to_torch(r, device=q.device, dtype=q.dtype)
loss = F.mse_loss(q, r)
loss.backward()
self.optim.step()
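When the buffer is a PrioritizedReplayBuffer, learn() now recomputes the TD error itself, pushes it back through batch.update_weight (stored by process_fn), and weights the squared error by the importance-sampling weights instead of taking the plain F.mse_loss branch. A minimal sketch of that weighted loss with made-up tensors:

import torch

q = torch.tensor([1.0, 2.5, 3.0], requires_grad=True)  # Q(s, a) gathered from the logits
r = torch.tensor([1.5, 2.0, 3.5])                       # n-step returns from process_fn
impt_weight = torch.tensor([0.5, 1.0, 2.0])             # importance-sampling weights

td = r - q                               # TD error, also used to update the priorities
loss = (td.pow(2) * impt_weight).mean()  # weighted squared error
loss.backward()                          # gradients flow into q only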

View File

@@ -2,6 +2,8 @@ import torch
import numpy as np
from typing import Union
from tianshou.data import to_numpy
class MovAvg(object):
"""Class for moving average. It will automatically exclude the infinity and
@@ -32,7 +34,7 @@ class MovAvg(object):
only one element, a python scalar, or a list of python scalars.
"""
if isinstance(x, torch.Tensor):
x = x.item()
x = to_numpy(x.flatten())
if isinstance(x, list) or isinstance(x, np.ndarray):
for _ in x:
if _ not in self.banned:
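Flattening through to_numpy means add() no longer has to receive a single-element tensor: every element of the flattened tensor is appended, with anything in self.banned (e.g. infinity) filtered out. A small usage sketch, assuming the usual MovAvg interface with add() and get():

import torch
from tianshou.utils import MovAvg  # assumption: MovAvg is exported from tianshou.utils

stat = MovAvg()
stat.add(3.0)                       # a plain python scalar
stat.add(torch.tensor([1.0, 2.0]))  # a multi-element tensor, flattened by to_numpy
print(stat.get())                   # running mean over 3.0, 1.0, 2.0 -> 2.0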