compute_nstep_returns (item 2 of #51)
parent f818a2467b
commit ff81a18f42
@@ -2,6 +2,8 @@ import torch
 import numpy as np
 from torch import nn
 
+from tianshou.data import to_torch
+
 
 class Actor(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,
@@ -18,8 +20,7 @@ class Actor(nn.Module):
         self._max = max_action
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
@@ -44,8 +45,7 @@ class ActorProb(nn.Module):
         self._max = max_action
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
@@ -72,8 +72,7 @@ class Critic(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, a=None, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         if a is not None:
@@ -96,8 +95,7 @@ class RecurrentActorProb(nn.Module):
         self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
 
     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
@@ -127,8 +125,7 @@ class RecurrentCritic(nn.Module):
         self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
 
     def forward(self, s, a=None):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.

@@ -3,6 +3,8 @@ import numpy as np
 from torch import nn
 import torch.nn.functional as F
 
+from tianshou.data import to_torch
+
 
 class Net(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
@@ -21,8 +23,7 @@ class Net(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
@@ -65,8 +66,7 @@ class Recurrent(nn.Module):
         self.fc2 = nn.Linear(128, np.prod(action_shape))
 
     def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.

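Note: the repeated change above swaps the manual `torch.tensor(...)` conversion in each `forward()` for `tianshou.data.to_torch`. As a rough mental model only (the real helper handles more input types than this), the conversion behaves like the sketch below; the name `to_torch_sketch` is invented for illustration and is not part of the commit.

import numpy as np
import torch

def to_torch_sketch(x, device='cpu', dtype=torch.float):
    """Illustrative stand-in for to_torch: numpy in, tensor on device out."""
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    if isinstance(x, torch.Tensor):
        x = x.to(device=device, dtype=dtype)
    return x

s = to_torch_sketch(np.ones((4, 3)))  # same call pattern as in forward()
# -> a torch.float32 tensor on 'cpu', like the old torch.tensor(...) path
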
@@ -2,7 +2,7 @@ import torch
 import numpy as np
 from torch import nn
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Union, Optional, Callable
 
 from tianshou.data import Batch, ReplayBuffer
 
@@ -113,6 +113,8 @@ class BasePolicy(ABC, nn.Module):
             to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage
             Estimation, should be in [0, 1], defaults to 0.95.
+
+        :return: a Batch. The result will be stored in batch.returns.
         """
         if v_s_ is None:
             v_s_ = np.zeros_like(batch.rew)
@@ -120,12 +122,61 @@ class BasePolicy(ABC, nn.Module):
         if not isinstance(v_s_, np.ndarray):
             v_s_ = np.array(v_s_, np.float)
         v_s_ = v_s_.reshape(batch.rew.shape)
-        batch.returns = np.roll(v_s_, 1, axis=0)
+        returns = np.roll(v_s_, 1, axis=0)
         m = (1. - batch.done) * gamma
-        delta = batch.rew + v_s_ * m - batch.returns
+        delta = batch.rew + v_s_ * m - returns
         m *= gae_lambda
         gae = 0.
         for i in range(len(batch.rew) - 1, -1, -1):
             gae = delta[i] + m[i] * gae
-            batch.returns[i] += gae
+            returns[i] += gae
+        batch.returns = returns
+        return batch
+
+    @staticmethod
+    def compute_nstep_return(
+            batch: Batch,
+            buffer: ReplayBuffer,
+            indice: np.ndarray,
+            target_q_fn: Callable[[ReplayBuffer, np.ndarray], np.ndarray],
+            gamma: float = 0.99,
+            n_step: int = 1
+    ) -> np.ndarray:
+        r"""Compute n-step return for Q-learning targets:
+
+        .. math::
+            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
+            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})
+
+        , where :math:`\gamma` is the discount factor,
+        :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step
+        :math:`t`.
+
+        :param batch: a data batch, which is equal to buffer[indice].
+        :type batch: :class:`~tianshou.data.Batch`
+        :param buffer: a data buffer which contains several full-episode data
+            chronologically.
+        :type buffer: :class:`~tianshou.data.ReplayBuffer`
+        :param indice: sampled timestep.
+        :type indice: numpy.ndarray
+        :param float gamma: the discount factor, should be in [0, 1], defaults
+            to 0.99.
+        :param int n_step: the number of estimation step, should be an int
+            greater than 0, defaults to 1.
+
+        :return: a Batch. The result will be stored in batch.returns.
+        """
+        returns = np.zeros_like(indice)
+        gammas = np.zeros_like(indice) + n_step
+        done, rew, buf_len = buffer.done, buffer.rew, len(buffer)
+        for n in range(n_step - 1, -1, -1):
+            now = (indice + n) % buf_len
+            gammas[done[now] > 0] = n
+            returns[done[now] > 0] = 0
+            returns = rew[now] + gamma * returns
+        terminal = (indice + n_step - 1) % buf_len
+        target_q = target_q_fn(buffer, terminal)
+        target_q[gammas != n_step] = 0
+        returns += (gamma ** gammas) * target_q
+        batch.returns = returns
         return batch

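To see what the new `compute_nstep_return` produces, here is a small standalone re-run of the same backward recursion on a toy buffer. It is only a sketch: the helper name, rewards, done flags, and Q-values below are invented for illustration and are not part of the commit.

import numpy as np

def nstep_return_sketch(rew, done, target_q, indice, gamma=0.9, n_step=3):
    """Toy re-implementation of the loop in compute_nstep_return."""
    returns = np.zeros_like(indice, dtype=float)
    gammas = np.zeros_like(indice) + n_step
    buf_len = len(rew)
    for n in range(n_step - 1, -1, -1):
        now = (indice + n) % buf_len
        # an episode end inside the horizon truncates the n-step sum
        gammas[done[now] > 0] = n
        returns[done[now] > 0] = 0
        returns = rew[now] + gamma * returns
    terminal = (indice + n_step - 1) % buf_len
    q = target_q[terminal].astype(float)
    q[gammas != n_step] = 0  # no bootstrap across an episode boundary
    return returns + gamma ** gammas * q

rew = np.array([1., 1., 1., 1., 1., 1.])
done = np.array([0, 0, 1, 0, 0, 0])
target_q = np.full(6, 10.)  # pretend Q_target(s_{t+n}) is 10 everywhere
print(nstep_return_sketch(rew, done, target_q, indice=np.arange(4)))
# n-step targets: [2.71, 1.9, 1.0, 10.0]
# the episode ends at t=2, so indices 0-2 never bootstrap; index 3 gets
# 1 + 0.9 + 0.81 plus the discounted target 0.9**3 * 10.
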
@@ -68,6 +68,21 @@ class DQNPolicy(BasePolicy):
         """Synchronize the weight for the target network."""
         self.model_old.load_state_dict(self.model.state_dict())
 
+    def _target_q(self, buffer: ReplayBuffer,
+                  indice: np.ndarray) -> np.ndarray:
+        data = buffer[indice]
+        if self._target:
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+            a = self(data, input='obs_next', eps=0).act
+            target_q = self(
+                data, model='model_old', input='obs_next').logits
+            target_q = to_numpy(target_q)
+            target_q = target_q[np.arange(len(a)), a]
+        else:
+            target_q = self(data, input='obs_next').logits
+            target_q = to_numpy(target_q).max(axis=1)
+        return target_q
+
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
         r"""Compute the n-step return for Q-learning targets:
@@ -82,46 +97,11 @@ class DQNPolicy(BasePolicy):
         :math:`t`. If there is no target network, the :math:`Q_{old}` is equal
         to :math:`Q_{new}`.
         """
-        returns = np.zeros_like(indice)
-        gammas = np.zeros_like(indice) + self._n_step
-        for n in range(self._n_step - 1, -1, -1):
-            now = (indice + n) % len(buffer)
-            gammas[buffer.done[now] > 0] = n
-            returns[buffer.done[now] > 0] = 0
-            returns = buffer.rew[now] + self._gamma * returns
-        terminal = (indice + self._n_step - 1) % len(buffer)
-        terminal_data = buffer[terminal]
-        if self._target:
-            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(terminal_data, input='obs_next', eps=0).act
-            target_q = self(
-                terminal_data, model='model_old', input='obs_next').logits
-            if isinstance(target_q, torch.Tensor):
-                target_q = to_numpy(target_q)
-            target_q = target_q[np.arange(len(a)), a]
-        else:
-            target_q = self(terminal_data, input='obs_next').logits
-            if isinstance(target_q, torch.Tensor):
-                target_q = to_numpy(target_q)
-            target_q = target_q.max(axis=1)
-        target_q[gammas != self._n_step] = 0
-        returns += (self._gamma ** gammas) * target_q
-        batch.returns = returns
+        batch = self.compute_nstep_return(
+            batch, buffer, indice, self._target_q, self._gamma, self._n_step)
         if isinstance(buffer, PrioritizedReplayBuffer):
-            q = self(batch).logits
-            q = q[np.arange(len(q)), batch.act]
-            r = batch.returns
-            if isinstance(r, np.ndarray):
-                r = to_torch(r, device=q.device, dtype=q.dtype)
-            td = r - q
-            buffer.update_weight(indice, to_numpy(td))
-            impt_weight = to_torch(batch.impt_weight,
-                                   device=q.device, dtype=torch.float)
-            loss = (td.pow(2) * impt_weight).mean()
-            if not hasattr(batch, 'loss'):
-                batch.loss = loss
-            else:
-                batch.loss += loss
+            batch.update_weight = buffer.update_weight
+            batch.indice = indice
         return batch
 
     def forward(self, batch: Batch,
@@ -162,14 +142,16 @@ class DQNPolicy(BasePolicy):
         if self._target and self._cnt % self._freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
-        if hasattr(batch, 'loss'):
-            loss = batch.loss
+        q = self(batch).logits
+        q = q[np.arange(len(q)), batch.act]
+        r = to_torch(batch.returns, device=q.device, dtype=q.dtype)
+        if hasattr(batch, 'update_weight'):
+            td = r - q
+            batch.update_weight(batch.indice, to_numpy(td))
+            impt_weight = to_torch(batch.impt_weight,
+                                   device=q.device, dtype=torch.float)
+            loss = (td.pow(2) * impt_weight).mean()
         else:
-            q = self(batch).logits
-            q = q[np.arange(len(q)), batch.act]
-            r = batch.returns
-            if isinstance(r, np.ndarray):
-                r = to_torch(r, device=q.device, dtype=q.dtype)
             loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()

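With the buffer interaction moved out of `process_fn`, `learn()` now builds the prioritized-replay loss itself: the TD error weighted by the stored importance-sampling weights. A toy computation of that weighted MSE (the numbers are invented; only the formula matches the diff above):

import torch

q = torch.tensor([1.0, 2.0, 3.0])            # Q(s, a) from the online network
r = torch.tensor([1.5, 1.0, 3.5])            # n-step returns from process_fn
impt_weight = torch.tensor([0.5, 1.0, 2.0])  # toy importance-sampling weights

td = r - q
loss = (td.pow(2) * impt_weight).mean()      # same form as in learn()
print(loss.item())  # (0.5*0.25 + 1.0*1.0 + 2.0*0.25) / 3 ~= 0.5417
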
@@ -2,6 +2,8 @@ import torch
 import numpy as np
 from typing import Union
 
+from tianshou.data import to_numpy
+
 
 class MovAvg(object):
     """Class for moving average. It will automatically exclude the infinity and
@@ -32,7 +34,7 @@ class MovAvg(object):
         only one element, a python scalar, or a list of python scalar.
         """
         if isinstance(x, torch.Tensor):
-            x = x.item()
+            x = to_numpy(x.flatten())
         if isinstance(x, list) or isinstance(x, np.ndarray):
            for _ in x:
                if _ not in self.banned:
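The `MovAvg.add` change replaces `x.item()`, which only accepts single-element tensors, with a flatten-and-convert step, so a whole batch of values can be fed at once. A quick illustration with toy values (the conversion line is only an approximation of what to_numpy does here):

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# x.item() would raise: "only one element tensors can be converted ..."
values = x.flatten().detach().cpu().numpy()
print(values)  # [1. 2. 3. 4.] -- each entry then passes the banned-value check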