2020-09-02 13:03:32 +08:00
|
|
|
import gym
|
2020-05-12 11:31:47 +08:00
|
|
|
import torch
|
2020-04-14 21:11:06 +08:00
|
|
|
import numpy as np
|
2020-03-18 21:45:41 +08:00
|
|
|
from torch import nn
|
2020-09-02 13:03:32 +08:00
|
|
|
from numba import njit
|
2020-03-12 22:20:33 +08:00
|
|
|
from abc import ABC, abstractmethod
|
2020-09-12 15:39:01 +08:00
|
|
|
from typing import Any, List, Union, Mapping, Optional, Callable
|
2020-05-12 11:31:47 +08:00
|
|
|
|
2020-09-16 17:43:19 +08:00
|
|
|
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
|
2020-03-12 22:20:33 +08:00
|
|
|
|
|
|
|
|
2020-03-18 21:45:41 +08:00
|
|
|
class BasePolicy(ABC, nn.Module):
    """The base class for any RL policy.

    Tianshou aims to modularize RL algorithms. It comes with several classes
    of policies in Tianshou. All of the policy classes must inherit
    :class:`~tianshou.policy.BasePolicy`.

    A policy class typically has four parts:

    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
        including copying the target network and so on;
    * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
        observation;
    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
        the replay buffer (this function can interact with replay buffer);
    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
        batch of data.

    Most of the policies need a neural network to predict the action and an
    optimizer to optimize the policy. The rules of self-defined networks are:

    1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
        ``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \
        usage), and other information "info" provided by the environment.
    2. Output: some "logits", the next hidden state "state", and the \
        intermediate result during policy forwarding procedure "policy". The \
        "logits" could be a tuple instead of a ``torch.Tensor``. It depends on \
        how the policy processes the network output. For example, in PPO, the \
        return of the network might be ``(mu, sigma), state`` for Gaussian \
        policy. The "policy" can be a Batch of torch.Tensor or other things, \
        which will be stored in the replay buffer, and can be accessed in the \
        policy update process (e.g. in "policy.learn()", the "batch.policy" is \
        what you need).

    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
    you can use :class:`~tianshou.policy.BasePolicy` almost the same as
    ``torch.nn.Module``, for instance, loading and saving the model:
    ::

        torch.save(policy.state_dict(), "policy.pth")
        policy.load_state_dict(torch.load("policy.pth"))
    """

    def __init__(
        self,
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None
    ) -> None:
        super().__init__()
        self.observation_space = observation_space
        self.action_space = action_space
        # agent id for multi-agent RL; changed through set_agent_id()
        self.agent_id = 0
        # True only while update() is running (see update() / learn() docs)
        self.updating = False
        # call the njit helpers once so their compilation happens here
        self._compile()

    def set_agent_id(self, agent_id: int) -> None:
        """Set self.agent_id = agent_id, for MARL."""
        self.agent_id = agent_id

    def exploration_noise(
        self, act: Union[np.ndarray, Batch], batch: Batch
    ) -> Union[np.ndarray, Batch]:
        """Modify the action from policy.forward with exploration noise.

        The base implementation returns ``act`` unchanged.

        :param act: a data batch or numpy.ndarray which is the action taken by
            policy.forward.
        :param batch: the input batch for policy.forward, kept for advanced usage.

        :return: action in the same form of input "act" but with added exploration
            noise.
        """
        return act

    @abstractmethod
    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:

            * ``act`` a numpy.ndarray or a torch.Tensor, the action over \
                given batch data.
            * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
                internal state of the policy, ``None`` as default.

        Other keys are user-defined. It depends on the algorithm. For example,
        ::

            # some code
            return Batch(logits=..., act=..., state=None, dist=...)

        The keyword ``policy`` is reserved and the corresponding data will be
        stored into the replay buffer. For instance,
        ::

            # some code
            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
            # and in the sampled data batch, you can directly use
            # batch.policy.log_prob to get your data.
        """
        pass

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Pre-process the data from the provided replay buffer.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        The base implementation returns ``batch`` unchanged.
        """
        return batch

    @abstractmethod
    def learn(
        self, batch: Batch, **kwargs: Any
    ) -> Mapping[str, Union[float, List[float]]]:
        """Update policy with a given batch of data.

        :return: A dict which includes loss and its corresponding label.

        .. note::

            In order to distinguish the collecting state, updating state and
            testing state, you can check the policy state by ``self.training``
            and ``self.updating``. Please refer to :ref:`policy_state` for more
            detailed explanation.

        .. warning::

            If you use ``torch.distributions.Normal`` and
            ``torch.distributions.Categorical`` to calculate the log_prob,
            please be careful about the shape: Categorical distribution gives
            "[batch_size]" shape while Normal distribution gives "[batch_size,
            1]" shape. The auto-broadcasting of numerical operation with torch
            tensors will amplify this error.
        """
        pass

    def post_process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> None:
        """Post-process the data from the provided replay buffer.

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`. The base implementation
        calls ``buffer.update_weight(indice, batch.weight)`` when the buffer
        provides ``update_weight`` and the batch has a ``weight`` field.
        """
        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
            buffer.update_weight(indice, batch.weight)

    def update(
        self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
    ) -> Mapping[str, Union[float, List[float]]]:
        """Update the policy network and replay buffer.

        It includes 3 function steps: process_fn, learn, and post_process_fn.
        In addition, this function will change the value of ``self.updating``:
        it will be False before this function, True while executing
        :meth:`update`, and False again afterwards (also when one of the three
        steps raises). Please refer to :ref:`policy_state` for more detailed
        explanation.

        :param int sample_size: 0 means it will extract all the data from the
            buffer, otherwise it will sample a batch with given sample_size.
        :param ReplayBuffer buffer: the corresponding replay buffer. If it is
            None, nothing is done and an empty dict is returned.

        :return: the result of :meth:`learn`.
        """
        if buffer is None:
            return {}
        batch, indice = buffer.sample(sample_size)
        self.updating = True
        try:
            batch = self.process_fn(batch, buffer, indice)
            result = self.learn(batch, **kwargs)
            self.post_process_fn(batch, buffer, indice)
        finally:
            # never leave the policy flagged as "updating" after a failure
            self.updating = False
        return result

    @staticmethod
    def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray:
        """Value mask determines whether the obs_next of buffer[indice] is valid.

        For instance, usually "obs_next" after "done" flag is considered to be
        invalid, and its q/advantage value can provide meaningless (even
        misleading) information, and should be set to 0 by hand. But if "done"
        flag is generated because of the time limit of game length
        (info["TimeLimit.truncated"] is set to True in gym's settings),
        "obs_next" will instead be valid. Value mask is typically used for
        assisting in calculating the correct q/advantage value.

        This base implementation only looks at ``buffer.done``: every index
        whose done flag is set is treated as invalid.

        :param ReplayBuffer buffer: the corresponding replay buffer.
        :param numpy.ndarray indice: indices of replay buffer whose "obs_next"
            will be judged.

        :return: A bool type numpy.ndarray in the same shape with indice.
            "True" means "obs_next" of that buffer[indice] is valid.
        """
        # builtin ``bool`` instead of ``np.bool``: the alias was deprecated in
        # NumPy 1.20 and removed in NumPy 1.24 (AttributeError there)
        return ~buffer.done[indice].astype(bool)

    @staticmethod
    def compute_episodic_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        rew_norm: bool = False,
    ) -> Batch:
        """Compute returns over given batch.

        Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
        to calculate q function/reward to go of given batch.

        :param Batch batch: a data batch which contains several episodes of data
            in sequential order. Mind that the end of each finished episode of
            batch should be marked by done flag, unfinished (or collecting)
            episodes will be recognized by buffer.unfinished_index().
        :param ReplayBuffer buffer: the replay buffer the batch was taken from.
        :param np.ndarray indice: tell batch's location in buffer, batch is
            equal to buffer[indice].
        :param np.ndarray v_s_: the value function of all next states
            :math:`V(s')`. If None, ``gae_lambda`` must be 1.0 and zeros are
            used instead.
        :param float gamma: the discount factor, should be in [0, 1].
            Defaults to 0.99.
        :param float gae_lambda: the parameter for Generalized Advantage
            Estimation, should be in [0, 1]. Defaults to 0.95.
        :param bool rew_norm: normalize the returns to Normal(0, 1).
            Defaults to False.

        :return: a Batch. The result will be stored in batch.returns as a numpy
            array with shape (bsz, ).
        """
        rew = batch.rew
        if v_s_ is None:
            assert np.isclose(gae_lambda, 1.0)
            v_s_ = np.zeros_like(rew)
        else:
            # zero out next-state values that the value mask marks as invalid
            v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)

        # stop the backward accumulation at done steps and at the last step of
        # episodes that are still being collected
        end_flag = batch.done.copy()
        end_flag[np.isin(indice, buffer.unfinished_index())] = True
        returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
        if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
            returns = (returns - returns.mean()) / returns.std()
        batch.returns = returns
        return batch

    @staticmethod
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
    ) -> Batch:
        r"""Compute n-step return for Q-learning targets.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
        :math:`d_t` is the done flag of step :math:`t`.

        :param Batch batch: a data batch, which is equal to buffer[indice].
        :param ReplayBuffer buffer: the data buffer.
        :param np.ndarray indice: the buffer indices of ``batch``.
        :param function target_q_fn: a function which computes target Q value
            of "obs_next" given data buffer and wanted indices.
        :param float gamma: the discount factor, should be in [0, 1].
            Defaults to 0.99.
        :param int n_step: the number of estimation step, should be an int
            greater than 0. Defaults to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1).
            Defaults to False.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with the same shape as target_q_fn's return tensor.
        """
        rew = buffer.rew
        bsz = len(indice)
        if rew_norm:  # TODO: remove it or fix this bug
            bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
            mean, std = bfr.mean(), bfr.std()
            if np.isclose(std, 0, 1e-2):
                mean, std = 0.0, 1.0
        else:
            mean, std = 0.0, 1.0
        # indices[k] is the buffer index k steps after ``indice``
        indices = [indice]
        for _ in range(n_step - 1):
            indices.append(buffer.next(indices[-1]))
        indices = np.stack(indices)
        # terminal indicates buffer indexes nstep after 'indice',
        # and are truncated at the end of each episode
        terminal = indices[-1]
        with torch.no_grad():
            target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
        target_q = to_numpy(target_q_torch.reshape(bsz, -1))
        target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        target_q = _nstep_return(rew, end_flag, target_q,
                                 indices, gamma, n_step, mean, std)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch

    def _compile(self) -> None:
        """Call each numba helper once with tiny dummy arrays.

        This makes their compilation for these argument dtypes happen when the
        policy is constructed instead of at the first real call.
        """
        f64 = np.array([0, 1], dtype=np.float64)
        f32 = np.array([0, 1], dtype=np.float32)
        b = np.array([False, True], dtype=np.bool_)
        i64 = np.array([[0, 1]], dtype=np.int64)
        _gae_return(f64, f64, f64, b, 0.1, 0.1)
        _gae_return(f32, f32, f64, b, 0.1, 0.1)
        _episodic_return(f64, f64, b, 0.1, 0.1)
        _episodic_return(f32, f64, b, 0.1, 0.1)
        _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1, 0.0, 1.0)
|
2020-09-12 15:39:01 +08:00
|
|
|
|
2020-08-15 16:10:42 +08:00
|
|
|
|
2020-09-02 13:03:32 +08:00
|
|
|
@njit
def _gae_return(
    v_s: np.ndarray,
    v_s_: np.ndarray,
    rew: np.ndarray,
    end_flag: np.ndarray,
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    """Generalized Advantage Estimation computed in one backward sweep.

    For every step ``i`` the TD error is ``rew[i] + gamma * v_s_[i] - v_s[i]``.
    The advantage accumulates those errors backwards with factor
    ``gamma * gae_lambda``. The accumulation is cut wherever ``end_flag`` is
    set, because the factor is multiplied by ``1 - end_flag``.

    :return: float64 array with the same shape as ``rew``.
    """
    advantages = np.zeros(rew.shape)
    td_error = rew + v_s_ * gamma - v_s
    carry = (1.0 - end_flag) * (gamma * gae_lambda)
    running = 0.0
    step = len(rew)
    while step > 0:
        step -= 1
        running = td_error[step] + carry[step] * running
        advantages[step] = running
    return advantages
|
|
|
|
|
|
|
|
|
2021-02-19 10:33:49 +08:00
|
|
|
@njit
def _episodic_return(
    v_s_: np.ndarray,
    rew: np.ndarray,
    end_flag: np.ndarray,
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    """Compute (lambda-)returns from next-state values only.

    The value of the current state is approximated by the previous entry's
    next-state value (``np.roll`` by one). Adding it back to the GAE advantage
    turns the advantage into a return.

    Numba speedup: 4.1s -> 0.057s.
    """
    prev_value = np.roll(v_s_, 1)
    advantage = _gae_return(prev_value, v_s_, rew, end_flag, gamma, gae_lambda)
    return advantage + prev_value
|
|
|
|
|
|
|
|
|
2020-09-02 13:03:32 +08:00
|
|
|
@njit
def _nstep_return(
    rew: np.ndarray,
    end_flag: np.ndarray,
    target_q: np.ndarray,
    indices: np.ndarray,
    gamma: float,
    n_step: int,
    mean: float,
    std: float,
) -> np.ndarray:
    """Combine (normalized) n-step rewards with the discounted bootstrap value.

    :param rew: rewards indexed by buffer position.
    :param end_flag: per-buffer-position flag; > 0 where the trajectory stops
        (done step or last step of an unfinished episode).
    :param target_q: bootstrap values evaluated at ``indices[-1]``; the first
        dimension is the batch size.
    :param indices: array whose row ``n`` holds, for every sampled index, the
        buffer position ``n`` steps later (shape (n_step, bsz) as built by
        ``BasePolicy.compute_nstep_return``).
    :param gamma: discount factor.
    :param n_step: number of reward steps to accumulate.
    :param mean: value subtracted from each reward before summation.
    :param std: value each centered reward is divided by.

    :return: array with the same shape as ``target_q`` holding the targets.
    """
    gamma_buffer = np.ones(n_step + 1)
    for i in range(1, n_step + 1):
        gamma_buffer[i] = gamma_buffer[i - 1] * gamma
    target_shape = target_q.shape
    bsz = target_shape[0]
    # change target_q to 2d array
    target_q = target_q.reshape(bsz, -1)
    returns = np.zeros(target_q.shape)
    # exponent of gamma applied to the bootstrap; n_step if no stop is hit
    gammas = np.full(indices[0].shape, n_step)
    for n in range(n_step - 1, -1, -1):
        now = indices[n]
        # If the trajectory stops at offset n, rewards r_0..r_n (n + 1 terms)
        # are summed, so the bootstrap value (taken from obs_next of that
        # step) must be discounted by gamma ** (n + 1), not gamma ** n.
        gammas[end_flag[now] > 0] = n + 1
        returns[end_flag[now] > 0] = 0.0
        returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns
    target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
    return target_q.reshape(target_shape)
|