385 lines
15 KiB
Python
Raw Normal View History

import gym
2020-05-12 11:31:47 +08:00
import torch
2020-04-14 21:11:06 +08:00
import numpy as np
2020-03-18 21:45:41 +08:00
from torch import nn
from numba import njit
2020-03-12 22:20:33 +08:00
from abc import ABC, abstractmethod
from typing import Any, List, Union, Mapping, Optional, Callable
2020-05-12 11:31:47 +08:00
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
2020-03-12 22:20:33 +08:00
2020-03-18 21:45:41 +08:00
class BasePolicy(ABC, nn.Module):
    """The base class for any RL policy.

    Tianshou aims to modularize RL algorithms. It comes with several classes
    of policies, all of which must inherit
    :class:`~tianshou.policy.BasePolicy`.

    A policy class typically has the following parts:

    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
    including copying the target network and so on;
    * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
    observation;
    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
    the replay buffer (this function can interact with replay buffer);
    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
    batch of data;
    * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: post-process the \
    replay buffer after learning (e.g. update prioritized-replay weights).

    Most of the policies need a neural network to predict the action and an
    optimizer to optimize the policy. The rules of self-defined networks are:

    1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
    ``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \
    usage), and other information "info" provided by the environment.
    2. Output: some "logits", the next hidden state "state", and the \
    intermediate result during policy forwarding procedure "policy". The \
    "logits" could be a tuple instead of a ``torch.Tensor``. It depends on how \
    the policy processes the network output. For example, in PPO, the return \
    of the network might be ``(mu, sigma), state`` for Gaussian policy. The \
    "policy" can be a Batch of torch.Tensor or other things, which will be \
    stored in the replay buffer, and can be accessed in the policy update \
    process (e.g. in "policy.learn()", the "batch.policy" is what you need).

    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
    you can use :class:`~tianshou.policy.BasePolicy` almost the same as
    ``torch.nn.Module``, for instance, loading and saving the model:
    ::

        torch.save(policy.state_dict(), "policy.pth")
        policy.load_state_dict(torch.load("policy.pth"))
    """

    def __init__(
        self,
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
    ) -> None:
        super().__init__()
        self.observation_space = observation_space
        self.action_space = action_space
        # agent index, changed via set_agent_id() in multi-agent settings
        self.agent_id = 0
        # True only while update() is running (see update())
        self.updating = False
        # call the numba-jitted helpers once on tiny inputs so they are
        # compiled at construction time rather than on first real use
        self._compile()

    def set_agent_id(self, agent_id: int) -> None:
        """Set self.agent_id = agent_id, for MARL."""
        self.agent_id = agent_id

    def exploration_noise(
        self, act: Union[np.ndarray, Batch], batch: Batch
    ) -> Union[np.ndarray, Batch]:
        """Modify the action from policy.forward with exploration noise.

        The base implementation returns ``act`` unchanged.

        :param act: a data batch or numpy.ndarray which is the action taken by
            policy.forward.
        :param batch: the input batch for policy.forward, kept for advanced usage.

        :return: action in the same form of input "act" but with added exploration
            noise.
        """
        return act

    @abstractmethod
    def forward(
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:

            * ``act`` a numpy.ndarray or a torch.Tensor, the action over \
                given batch data.
            * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
                internal state of the policy, ``None`` as default.

        Other keys are user-defined. It depends on the algorithm. For example,
        ::

            # some code
            return Batch(logits=..., act=..., state=None, dist=...)

        The keyword ``policy`` is reserved and the corresponding data will be
        stored into the replay buffer. For instance,
        ::

            # some code
            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
            # and in the sampled data batch, you can directly use
            # batch.policy.log_prob to get your data.
        """
        pass

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        """Pre-process the data from the provided replay buffer.

        Used in :meth:`update`. Check out :ref:`process_fn` for more information.
        The base implementation returns ``batch`` unchanged.
        """
        return batch

    @abstractmethod
    def learn(
        self, batch: Batch, **kwargs: Any
    ) -> Mapping[str, Union[float, List[float]]]:
        """Update policy with a given batch of data.

        :return: A dict which includes loss and its corresponding label.

        .. note::

            In order to distinguish the collecting state, updating state and
            testing state, you can check the policy state by ``self.training``
            and ``self.updating``. Please refer to :ref:`policy_state` for more
            detailed explanation.

        .. warning::

            If you use ``torch.distributions.Normal`` and
            ``torch.distributions.Categorical`` to calculate the log_prob,
            please be careful about the shape: Categorical distribution gives
            "[batch_size]" shape while Normal distribution gives "[batch_size,
            1]" shape. The auto-broadcasting of numerical operation with torch
            tensors will amplify this error.
        """
        pass

    def post_process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> None:
        """Post-process the data from the provided replay buffer.

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`. Weights are only pushed back
        when the buffer has ``update_weight`` and the batch has ``weight``.
        """
        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
            buffer.update_weight(indice, batch.weight)

    def update(
        self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
    ) -> Mapping[str, Union[float, List[float]]]:
        """Update the policy network and replay buffer.

        It includes 3 function steps: process_fn, learn, and post_process_fn.
        In addition, this function will change the value of ``self.updating``:
        it will be False before this function and will be True when executing
        :meth:`update`. Please refer to :ref:`policy_state` for more detailed
        explanation.

        :param int sample_size: 0 means it will extract all the data from the
            buffer, otherwise it will sample a batch with given sample_size.
        :param ReplayBuffer buffer: the corresponding replay buffer.

        :return: the dict returned by :meth:`learn`, or an empty dict when
            ``buffer`` is None.
        """
        if buffer is None:
            return {}
        batch, indice = buffer.sample(sample_size)
        self.updating = True
        try:
            batch = self.process_fn(batch, buffer, indice)
            result = self.learn(batch, **kwargs)
            self.post_process_fn(batch, buffer, indice)
        finally:
            # Reset the flag even if any step raises, so a failed update does
            # not leave the policy permanently in the "updating" state.
            self.updating = False
        return result

    @staticmethod
    def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray:
        """Value mask determines whether the obs_next of buffer[indice] is valid.

        For instance, usually "obs_next" after "done" flag is considered to be invalid,
        and its q/advantage value can provide meaningless (even misleading)
        information, and should be set to 0 by hand. But if "done" flag is generated
        because timelimit of game length (info["TimeLimit.truncated"] is set to True in
        gym's settings), "obs_next" will instead be valid. Value mask is typically used
        for assisting in calculating the correct q/advantage value.

        :param ReplayBuffer buffer: the corresponding replay buffer.
        :param numpy.ndarray indice: indices of replay buffer whose "obs_next" will be
            judged.

        :return: A bool type numpy.ndarray in the same shape with indice. "True" means
            "obs_next" of that buffer[indice] is valid.
        """
        # np.bool was removed in NumPy 1.24; np.bool_ is the scalar bool type.
        return ~buffer.done[indice].astype(np.bool_)

    @staticmethod
    def compute_episodic_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        rew_norm: bool = False,
    ) -> Batch:
        """Compute returns over given batch.

        Use the implementation of Generalized Advantage Estimator
        (arXiv:1506.02438) to calculate q function/reward to go of given batch.

        :param Batch batch: a data batch which contains several episodes of data
            in sequential order. Mind that the end of each finished episode of batch
            should be marked by done flag, unfinished (or collecting) episodes will be
            recognized by buffer.unfinished_index().
        :param ReplayBuffer buffer: the replay buffer the batch was taken from.
        :param np.ndarray indice: tell batch's location in buffer, batch is
            equal to buffer[indice].
        :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
            When None, ``gae_lambda`` must be 1.0 and zeros are used.
        :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
        :param float gae_lambda: the parameter for Generalized Advantage Estimation,
            should be in [0, 1]. Default to 0.95.
        :param bool rew_norm: normalize the returns to Normal(0, 1). Default to False.

        :return: a Batch. The result will be stored in batch.returns as a numpy
            array with shape (bsz, ).
        """
        rew = batch.rew
        if v_s_ is None:
            assert np.isclose(gae_lambda, 1.0)
            v_s_ = np.zeros_like(rew)
        else:
            # zero out next-state values whose obs_next is invalid (after done)
            v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
        end_flag = batch.done.copy()
        # an episode still being collected also ends the backward accumulation
        end_flag[np.isin(indice, buffer.unfinished_index())] = True
        returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
        if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
            returns = (returns - returns.mean()) / returns.std()
        batch.returns = returns
        return batch

    @staticmethod
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
    ) -> Batch:
        r"""Compute n-step return for Q-learning targets.

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`,
        :math:`d_t` is the done flag of step :math:`t`.

        :param Batch batch: a data batch, which is equal to buffer[indice].
        :param ReplayBuffer buffer: the data buffer.
        :param function target_q_fn: a function which computes the target Q value
            of "obs_next" given data buffer and wanted indices.
        :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
        :param int n_step: the number of estimation steps, should be an int greater
            than 0. Default to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1), Default to False.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with the same shape as target_q_fn's return tensor.
        """
        rew = buffer.rew
        bsz = len(indice)
        if rew_norm:  # TODO: remove it or fix this bug
            bfr = rew[:min(len(buffer), 1000)]  # avoid large buffer
            mean, std = bfr.mean(), bfr.std()
            if np.isclose(std, 0, 1e-2):
                mean, std = 0.0, 1.0
        else:
            mean, std = 0.0, 1.0
        # row k holds the buffer indices reached after k steps from `indice`
        indices = [indice]
        for _ in range(n_step - 1):
            indices.append(buffer.next(indices[-1]))
        indices = np.stack(indices)
        # terminal indicates buffer indexes nstep after 'indice',
        # and are truncated at the end of each episode
        terminal = indices[-1]
        with torch.no_grad():
            target_q_torch = target_q_fn(buffer, terminal)  # (bsz, ?)
        target_q = to_numpy(target_q_torch.reshape(bsz, -1))
        target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1)
        end_flag = buffer.done.copy()
        end_flag[buffer.unfinished_index()] = True
        target_q = _nstep_return(rew, end_flag, target_q,
                                 indices, gamma, n_step, mean, std)

        batch.returns = to_torch_as(target_q, target_q_torch)
        if hasattr(batch, "weight"):  # prio buffer update
            batch.weight = to_torch_as(batch.weight, target_q_torch)
        return batch

    def _compile(self) -> None:
        """Call each numba kernel once on tiny dummy arrays to trigger JIT compilation.

        Both float64 and float32 inputs are exercised for the GAE helpers.
        """
        f64 = np.array([0, 1], dtype=np.float64)
        f32 = np.array([0, 1], dtype=np.float32)
        b = np.array([False, True], dtype=np.bool_)
        i64 = np.array([[0, 1]], dtype=np.int64)
        _gae_return(f64, f64, f64, b, 0.1, 0.1)
        _gae_return(f32, f32, f64, b, 0.1, 0.1)
        _episodic_return(f64, f64, b, 0.1, 0.1)
        _episodic_return(f32, f64, b, 0.1, 0.1)
        _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1, 0.0, 1.0)
@njit
def _gae_return(
    v_s: np.ndarray,
    v_s_: np.ndarray,
    rew: np.ndarray,
    end_flag: np.ndarray,
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    """Generalized Advantage Estimation over a flat sequence of transitions.

    Scans the sequence backwards, accumulating discounted TD errors. Wherever
    ``end_flag`` is set, the value carried over from the following step is
    dropped, so accumulation does not cross trajectory boundaries.

    :return: float64 array of advantages with the same shape as ``rew``.
    """
    decay = gamma * gae_lambda
    td_error = rew + v_s_ * gamma - v_s
    # per-step multiplier for the carried advantage; 0 at trajectory ends
    carry = (1.0 - end_flag) * decay
    advantage = np.zeros(rew.shape)
    running = 0.0
    step = len(rew)
    while step > 0:
        step -= 1
        running = td_error[step] + carry[step] * running
        advantage[step] = running
    return advantage
@njit
def _episodic_return(
    v_s_: np.ndarray,
    rew: np.ndarray,
    end_flag: np.ndarray,
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    """Return GAE advantages plus a shifted next-state-value baseline.

    Numba speedup: 4.1s -> 0.057s.
    """
    # The current-state values are taken to be the next-state values shifted
    # right by one step (element 0 wraps around to the last element).
    baseline = np.roll(v_s_, 1)
    advantage = _gae_return(baseline, v_s_, rew, end_flag, gamma, gae_lambda)
    return advantage + baseline
@njit
def _nstep_return(
    rew: np.ndarray,
    end_flag: np.ndarray,
    target_q: np.ndarray,
    indices: np.ndarray,
    gamma: float,
    n_step: int,
    mean: float,
    std: float,
) -> np.ndarray:
    """Combine discounted n-step rewards with a discounted bootstrap value.

    :param rew: reward array indexed by the buffer indices in ``indices``.
    :param end_flag: array indexed like ``rew``; non-zero where a trajectory
        ends (the bootstrap discount is cut at that step).
    :param target_q: bootstrap values for the last index of each window; the
        first axis is the batch, any trailing axes are flattened internally.
    :param indices: 2-D array whose row ``n`` holds, per batch entry, the
        index reached after ``n`` steps. NOTE(review): assumes a row stays at
        the trajectory's last index once the trajectory has ended — confirm
        against ``ReplayBuffer.next``.
    :param gamma: discount factor.
    :param n_step: number of reward steps; must equal ``len(indices)``.
    :param mean: subtracted from every reward before accumulation.
    :param std: every centered reward is divided by this.

    :return: array with the same shape as ``target_q``.
    """
    gamma_buffer = np.ones(n_step + 1)
    for i in range(1, n_step + 1):
        gamma_buffer[i] = gamma_buffer[i - 1] * gamma
    target_shape = target_q.shape
    bsz = target_shape[0]
    # change target_q to 2d array
    target_q = target_q.reshape(bsz, -1)
    returns = np.zeros(target_q.shape)
    gammas = np.full(indices[0].shape, n_step)
    for n in range(n_step - 1, -1, -1):
        now = indices[n]
        ends = end_flag[now] > 0
        # A window ending at step n bootstraps from obs_next of indices[n],
        # which lies n + 1 steps after the start, so it is discounted by
        # gamma ** (n + 1), not gamma ** n. (For done transitions target_q
        # was already masked to 0, so this only matters for truncation.)
        gammas[ends] = n + 1
        returns[ends] = 0.0
        returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns
    target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
    return target_q.reshape(target_shape)