compute_nstep_returns (item 2 of #51)

parent f818a2467b
commit ff81a18f42

@@ -2,6 +2,8 @@ import torch
 import numpy as np
 from torch import nn
 
+from tianshou.data import to_torch
+
 
 class Actor(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,

@@ -18,8 +20,7 @@ class Actor(nn.Module):
         self._max = max_action
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)

@@ -44,8 +45,7 @@ class ActorProb(nn.Module):
         self._max = max_action
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)

@@ -72,8 +72,7 @@ class Critic(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, a=None, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         if a is not None:

@@ -96,8 +95,7 @@ class RecurrentActorProb(nn.Module):
         self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.

@@ -127,8 +125,7 @@ class RecurrentCritic(nn.Module):
         self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
 
     def forward(self, s, a=None):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
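
Every forward() above drops the manual isinstance/torch.tensor conversion in favour of tianshou.data.to_torch. Roughly speaking, to_torch moves array-like input onto the requested device and dtype and passes tensors through; a minimal sketch of that behaviour (an illustration only, not tianshou's actual implementation) is:

```python
import torch

def to_torch_sketch(x, device='cpu', dtype=torch.float):
    # numpy arrays, lists and existing tensors all come back as a
    # torch.Tensor on the requested device with the requested dtype.
    return torch.as_tensor(x, device=device, dtype=dtype)
```
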
@@ -3,6 +3,8 @@ import numpy as np
 from torch import nn
 import torch.nn.functional as F
 
+from tianshou.data import to_torch
+
 
 class Net(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',

@@ -21,8 +23,7 @@ class Net(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)

@@ -65,8 +66,7 @@ class Recurrent(nn.Module):
         self.fc2 = nn.Linear(128, np.prod(action_shape))
 
     def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
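
The shape comment in the recurrent networks ("[bsz, len, dim] (training) or [bsz, dim] (evaluation)") means that a single evaluation step carries no time axis and has to be viewed as a length-1 sequence before it reaches the RNN. A rough sketch of that reshaping (assumed behaviour for illustration, not quoted from this diff):

```python
import torch

s = torch.rand(32, 4)            # [bsz, dim], as seen at evaluation time
if s.dim() == 2:
    # treat the single step as a sequence of length 1: [bsz, 1, dim]
    s = s.view(s.shape[0], 1, -1)
assert s.shape == (32, 1, 4)
```
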
@@ -2,7 +2,7 @@ import torch
 import numpy as np
 from torch import nn
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Union, Optional, Callable
 
 from tianshou.data import Batch, ReplayBuffer
 

@@ -113,6 +113,8 @@ class BasePolicy(ABC, nn.Module):
             to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage
             Estimation, should be in [0, 1], defaults to 0.95.
+
+        :return: a Batch. The result will be stored in batch.returns.
         """
         if v_s_ is None:
             v_s_ = np.zeros_like(batch.rew)

@@ -120,12 +122,61 @@ class BasePolicy(ABC, nn.Module):
         if not isinstance(v_s_, np.ndarray):
             v_s_ = np.array(v_s_, np.float)
         v_s_ = v_s_.reshape(batch.rew.shape)
-        batch.returns = np.roll(v_s_, 1, axis=0)
+        returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
-        delta = batch.rew + v_s_ * m - batch.returns
+        delta = batch.rew + v_s_ * m - returns
         m *= gae_lambda
         gae = 0.
         for i in range(len(batch.rew) - 1, -1, -1):
             gae = delta[i] + m[i] * gae
-            batch.returns[i] += gae
+            returns[i] += gae
+        batch.returns = returns
         return batch
+
+    @staticmethod
+    def compute_nstep_return(
+            batch: Batch,
+            buffer: ReplayBuffer,
+            indice: np.ndarray,
+            target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray],
+            gamma: float = 0.99,
+            n_step: int = 1
+    ) -> np.ndarray:
+        r"""Compute n-step return for Q-learning targets:
+
+        .. math::
+            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
+            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})
+
+        , where :math:`\gamma` is the discount factor,
+        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
+        :math:`t`.
+
+        :param batch: a data batch, which is equal to buffer[indice].
+        :type batch: :class:`~tianshou.data.Batch`
+        :param buffer: a data buffer which contains several full-episode data
+            chronologically.
+        :type buffer: :class:`~tianshou.data.ReplayBuffer`
+        :param indice: sampled timestep.
+        :type indice: numpy.ndarray
+        :param float gamma: the discount factor, should be in [0, 1], defaults
+            to 0.99.
+        :param int n_step: the number of estimation step, should be an int
+            greater than 0, defaults to 1.
+
+        :return: a Batch. The result will be stored in batch.returns.
+        """
+        returns = np.zeros_like(indice)
+        gammas = np.zeros_like(indice) + n_step
+        done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
+        for n in range(n_step - 1, -1, -1):
+            now = (indice + n) % buf_len
+            gammas[done[now] > 0] = n
+            returns[done[now] > 0] = 0
+            returns = rew[now] + gamma * returns
+        terminal = (indice + n_step - 1) % buf_len
+        target_q = target_q_fn(buffer, terminal)
+        target_q[gammas != n_step] = 0
+        returns += (gamma ** gammas) * target_q
+        batch.returns = returns
+        return batch
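
To make the new compute_nstep_return concrete, here is a self-contained toy trace of the same backward accumulation (hypothetical numbers; a float array is used for the returns so the arithmetic stays exact):

```python
import numpy as np

rew = np.ones(6)                          # toy rewards
done = np.array([0., 0., 1., 0., 0., 0.])
indice = np.array([0, 1, 3])              # sampled start indices
n_step, gamma = 3, 0.9
target_q_all = np.full(6, 10.)            # stand-in for Q_target(s_{t+n})

returns = np.zeros(len(indice))
gammas = np.zeros_like(indice) + n_step
buf_len = len(rew)
for n in range(n_step - 1, -1, -1):
    now = (indice + n) % buf_len
    gammas[done[now] > 0] = n             # episode ends inside the n-step window
    returns[done[now] > 0] = 0.           # drop reward that leaks past "done"
    returns = rew[now] + gamma * returns
terminal = (indice + n_step - 1) % buf_len
target_q = target_q_all[terminal]
target_q[gammas != n_step] = 0.           # no bootstrap across episode boundaries
returns += (gamma ** gammas) * target_q
print(returns)  # [ 2.71  1.9  10. ] -- only index 3 bootstraps the full 3-step target
```
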
@@ -68,6 +68,21 @@ class DQNPolicy(BasePolicy):
         """Synchronize the weight for the target network."""
         self.model_old.load_state_dict(self.model.state_dict())
 
+    def _target_q(self, buffer: ReplayBuffer,
+                  indice: np.ndarray) -> np.ndarray:
+        data = buffer[indice]
+        if self._target:
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+            a = self(data, input='obs_next', eps=0).act
+            target_q = self(
+                data, model='model_old', input='obs_next').logits
+            target_q = to_numpy(target_q)
+            target_q = target_q[np.arange(len(a)), a]
+        else:
+            target_q = self(data, input='obs_next').logits
+            target_q = to_numpy(target_q).max(axis=1)
+        return target_q
+
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
         r"""Compute the n-step return for Q-learning targets:

@@ -82,46 +97,11 @@ class DQNPolicy(BasePolicy):
         :math:`t`. If there is no target network, the :math:`Q_{old}` is equal
         to :math:`Q_{new}`.
         """
-        returns = np.zeros_like(indice)
-        gammas = np.zeros_like(indice) + self._n_step
-        for n in range(self._n_step - 1, -1, -1):
-            now = (indice + n) % len(buffer)
-            gammas[buffer.done[now] > 0] = n
-            returns[buffer.done[now] > 0] = 0
-            returns = buffer.rew[now] + self._gamma * returns
-        terminal = (indice + self._n_step - 1) % len(buffer)
-        terminal_data = buffer[terminal]
-        if self._target:
-            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(terminal_data, input='obs_next', eps=0).act
-            target_q = self(
-                terminal_data, model='model_old', input='obs_next').logits
-            if isinstance(target_q, torch.Tensor):
-                target_q = to_numpy(target_q)
-            target_q = target_q[np.arange(len(a)), a]
-        else:
-            target_q = self(terminal_data, input='obs_next').logits
-            if isinstance(target_q, torch.Tensor):
-                target_q = to_numpy(target_q)
-            target_q = target_q.max(axis=1)
-        target_q[gammas != self._n_step] = 0
-        returns += (self._gamma ** gammas) * target_q
-        batch.returns = returns
+        batch = self.compute_nstep_return(
+            batch, buffer, indice, self._target_q, self._gamma, self._n_step)
         if isinstance(buffer, PrioritizedReplayBuffer):
-            q = self(batch).logits
-            q = q[np.arange(len(q)), batch.act]
-            r = batch.returns
-            if isinstance(r, np.ndarray):
-                r = to_torch(r, device=q.device, dtype=q.dtype)
-            td = r - q
-            buffer.update_weight(indice, to_numpy(td))
-            impt_weight = to_torch(batch.impt_weight,
-                                   device=q.device, dtype=torch.float)
-            loss = (td.pow(2) * impt_weight).mean()
-            if not hasattr(batch, 'loss'):
-                batch.loss = loss
-            else:
-                batch.loss += loss
+            batch.update_weight = buffer.update_weight
+            batch.indice = indice
         return batch
 
     def forward(self, batch: Batch,

@@ -162,14 +142,16 @@ class DQNPolicy(BasePolicy):
         if self._target and self._cnt % self._freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
-        if hasattr(batch, 'loss'):
-            loss = batch.loss
+        q = self(batch).logits
+        q = q[np.arange(len(q)), batch.act]
+        r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
+        if hasattr(batch, 'update_weight'):
+            td = r - q
+            batch.update_weight(batch.indice, to_numpy(td))
+            impt_weight = to_torch(batch.impt_weight,
+                                   device=q.device, dtype=torch.float)
+            loss = (td.pow(2) * impt_weight).mean()
         else:
-            q = self(batch).logits
-            q = q[np.arange(len(q)), batch.act]
-            r = batch.returns
-            if isinstance(r, np.ndarray):
-                r = to_torch(r, device=q.device, dtype=q.dtype)
             loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
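
_target_q is exactly the target_q_fn callback that compute_nstep_return expects: given the buffer and the terminal indices it returns the Q-targets as a numpy array. The commented line target_Q = Q_old(s_, argmax(Q_new(s_, *))) is the double-DQN trick; stripped of tianshou's Batch plumbing it amounts to something like the following (toy tensors, hypothetical shapes):

```python
import torch

batch_size, num_actions = 32, 4
q_new = torch.rand(batch_size, num_actions)  # online network on s_{t+n}
q_old = torch.rand(batch_size, num_actions)  # target network on s_{t+n}

a_star = q_new.argmax(dim=1)                        # choose actions with Q_new
target_q = q_old[torch.arange(batch_size), a_star]  # evaluate them with Q_old
target_q = target_q.numpy()                         # numpy, as _target_q returns
```
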
@@ -2,6 +2,8 @@ import torch
 import numpy as np
 from typing import Union
 
+from tianshou.data import to_numpy
+
 
 class MovAvg(object):
     """Class for moving average. It will automatically exclude the infinity and

@@ -32,7 +34,7 @@ class MovAvg(object):
         only one element, a python scalar, or a list of python scalar.
         """
         if isinstance(x, torch.Tensor):
-            x = x.item()
+            x = to_numpy(x.flatten())
         if isinstance(x, list) or isinstance(x, np.ndarray):
             for _ in x:
                 if _ not in self.banned:
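
The MovAvg change is a small robustness fix: Tensor.item() only works for single-element tensors, while flattening and converting with to_numpy lets add() accept a whole batch of statistics at once. For instance:

```python
import torch

losses = torch.tensor([0.5, 0.7, 0.6])
# losses.item() raises an error for multi-element tensors; flattening and
# converting to numpy lets the moving average ingest the whole batch.
values = losses.flatten().numpy()
print(values.mean())  # ~0.6
```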