import gym
import torch
import numpy as np
from torch import nn
from numba import njit
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional, Callable

from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
    to_torch_as, to_numpy


class BasePolicy(ABC, nn.Module):
    """The base class for any RL policy.

    Tianshou aims to modularize RL algorithms. Policies in Tianshou are
    therefore split into several classes, all of which must inherit
    :class:`~tianshou.policy.BasePolicy`.

    A policy class typically has four parts:

    * :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, \
        including copying the target network and so on;
    * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given \
        observation;
    * :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from \
        the replay buffer (this function can interact with the replay \
        buffer);
    * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given \
        batch of data.

    Most policies need a neural network to predict the action and an
    optimizer to optimize the policy. The rules for self-defined networks
    are:

    1. Input: observation "obs" (may be a ``numpy.ndarray``, a \
        ``torch.Tensor``, a dict, or any other type), hidden state "state" \
        (for RNN usage), and other information "info" provided by the \
        environment.
    2. Output: some "logits", the next hidden state "state", and the \
        intermediate result of the policy forwarding procedure "policy". \
        The "logits" could be a tuple instead of a ``torch.Tensor``, \
        depending on how the policy processes the network output. For \
        example, in PPO the return of the network might be \
        ``(mu, sigma), state`` for a Gaussian policy. The "policy" can be \
        a Batch of torch.Tensor or anything else, which will be stored in \
        the replay buffer and can be accessed in the policy update process \
        (e.g. in "policy.learn()", "batch.policy" is what you need).

    Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``,
    you can use :class:`~tianshou.policy.BasePolicy` almost the same way as
    ``torch.nn.Module``, for instance, loading and saving the model:
    ::

        torch.save(policy.state_dict(), "policy.pth")
        policy.load_state_dict(torch.load("policy.pth"))
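
    A minimal sketch of a self-defined policy (``model`` and ``optim`` are
    hypothetical placeholders, not part of Tianshou):
    ::

        class MyPolicy(BasePolicy):
            def __init__(self, model, optim):
                super().__init__()
                self.model = model
                self.optim = optim

            def forward(self, batch, state=None, **kwargs):
                # greedy action from the network's logits
                logits, state = self.model(batch.obs, state=state)
                act = logits.argmax(dim=-1)
                return Batch(logits=logits, act=act, state=state)

            def learn(self, batch, **kwargs):
                # compute a loss from the batch and take one gradient step
                loss = ...
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                return {"loss": loss.item()}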
"""

    def __init__(self,
                 observation_space: gym.Space = None,
                 action_space: gym.Space = None
                 ) -> None:
        super().__init__()
        self.observation_space = observation_space
        self.action_space = action_space
        self.agent_id = 0

    def set_agent_id(self, agent_id: int) -> None:
        """Set self.agent_id = agent_id, for MARL."""
        self.agent_id = agent_id

    @abstractmethod
    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                **kwargs) -> Batch:
        """Compute action over the given batch data.

        :return: A :class:`~tianshou.data.Batch` which MUST have the \
            following keys:

            * ``act`` a numpy.ndarray or a torch.Tensor, the action over \
                the given batch data.
            * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
                internal state of the policy, ``None`` as default.

        Other keys are user-defined and depend on the algorithm. For
        example,
        ::

            # some code
            return Batch(logits=..., act=..., state=None, dist=...)

        The keyword ``policy`` is reserved and the corresponding data will
        be stored into the replay buffer. For instance,
        ::

            # some code
            return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
            # and in the sampled data batch, you can directly use
            # batch.policy.log_prob to get your data.
        """
        pass

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        """Pre-process the data from the provided replay buffer.

        Used in :meth:`update`. Check out :ref:`process_fn` for more
        information.
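
        A typical use is to compute values needed by :meth:`learn` ahead of
        time, e.g. the discounted return. A sketch, assuming the policy has
        a hypothetical critic network ``self.critic``:
        ::

            def process_fn(self, batch, buffer, indice):
                v_s_ = self.critic(batch.obs_next)
                return self.compute_episodic_return(
                    batch, v_s_, gamma=0.99, gae_lambda=0.95)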
"""
return batch

    @abstractmethod
    def learn(self, batch: Batch, **kwargs
              ) -> Dict[str, Union[float, List[float]]]:
        """Update policy with a given batch of data.

        :return: A dict which includes the loss and its corresponding label.

        .. warning::

            If you use ``torch.distributions.Normal`` and
            ``torch.distributions.Categorical`` to calculate the log_prob,
            please be careful about the shape: the Categorical distribution
            gives a "[batch_size]" shape while the Normal distribution
            gives a "[batch_size, 1]" shape. The auto-broadcasting of
            numerical operations with torch tensors will amplify this
            error.
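
            A sketch of how the shapes go wrong (not Tianshou code):
            ::

                # log_prob from Categorical: shape [batch_size]
                # log_prob from Normal:     shape [batch_size, 1]
                # their difference broadcasts silently to
                # [batch_size, batch_size] instead of raising an error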
"""
2020-03-14 21:48:31 +08:00
pass

    def post_process_fn(self, batch: Batch,
                        buffer: ReplayBuffer, indice: np.ndarray) -> None:
        """Post-process the data from the provided replay buffer.

        Typical usage is to update the sampling weight in prioritized
        experience replay. Used in :meth:`update`.
        """
        if isinstance(buffer, PrioritizedReplayBuffer) \
                and hasattr(batch, 'weight'):
            buffer.update_weight(indice, batch.weight)

    def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
               *args, **kwargs) -> Dict[str, Union[float, List[float]]]:
        """Update the policy network and replay buffer.

        It consists of three steps: :meth:`process_fn`, :meth:`learn`, and
        :meth:`post_process_fn`.

        :param int sample_size: 0 means it will extract all the data from
            the buffer, otherwise it will sample a batch with the given
            sample_size.
        :param ReplayBuffer buffer: the corresponding replay buffer.
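
        A sketch of a typical call (assuming a collector has already filled
        ``buf``):
        ::

            # sample 64 transitions, then run
            # process_fn -> learn -> post_process_fn
            losses = policy.update(64, buf)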
"""
if buffer is None:
return {}
batch, indice = buffer.sample(sample_size)
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, *args, **kwargs)
self.post_process_fn(batch, buffer, indice)
return result

    @staticmethod
    def compute_episodic_return(
        batch: Batch,
        v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        rew_norm: bool = False,
    ) -> Batch:
        r"""Compute returns over given full-length episodes.

        Implementation of the Generalized Advantage Estimator
        (arXiv:1506.02438).

        :param batch: a data batch which contains several full-episode data
            chronologically.
        :type batch: :class:`~tianshou.data.Batch`
        :param v_s_: the value function of all next states :math:`V(s')`.
        :type v_s_: numpy.ndarray
        :param float gamma: the discount factor, should be in [0, 1],
            defaults to 0.99.
        :param float gae_lambda: the parameter for Generalized Advantage
            Estimation, should be in [0, 1], defaults to 0.95.
        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
            to False.

        :return: a Batch. The result will be stored in batch.returns as a
            numpy array with shape (bsz, ).
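
        Concretely, with the TD residual
        :math:`\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)`, the stored
        return is the GAE advantage plus the value baseline (the sum is
        truncated at episode boundaries by the done flag):

        .. math::
            \hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^l \delta_{t+l},
            \qquad
            \text{returns}_t = \hat{A}_t + V(s_t)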
"""
rew = batch.rew
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten()
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
2020-06-02 22:29:50 +08:00
batch.returns = returns
return batch

    @staticmethod
    def compute_nstep_return(
        batch: Batch,
        buffer: ReplayBuffer,
        indice: np.ndarray,
        target_q_fn: Callable[[ReplayBuffer, np.ndarray], torch.Tensor],
        gamma: float = 0.99,
        n_step: int = 1,
        rew_norm: bool = False,
    ) -> Batch:
        r"""Compute the n-step return for Q-learning targets:

        .. math::
            G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i +
            \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n})

        where :math:`\gamma` is the discount factor,
        :math:`\gamma \in [0, 1]`, and :math:`d_t` is the done flag of step
        :math:`t`.

        :param batch: a data batch, which is equal to buffer[indice].
        :type batch: :class:`~tianshou.data.Batch`
        :param buffer: a data buffer which contains several full-episode
            data chronologically.
        :type buffer: :class:`~tianshou.data.ReplayBuffer`
        :param indice: sampled timestep.
        :type indice: numpy.ndarray
        :param function target_q_fn: a function that receives the data of
            step :math:`t + n - 1` and computes the target Q value.
        :param float gamma: the discount factor, should be in [0, 1],
            defaults to 0.99.
        :param int n_step: the number of estimation steps, should be an int
            greater than 0, defaults to 1.
        :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
            to False.

        :return: a Batch. The result will be stored in batch.returns as a
            torch.Tensor with shape (bsz, ).
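
        A sketch of typical usage inside a DQN-style :meth:`process_fn`
        (``self._target_q`` is a hypothetical helper that evaluates the
        target network):
        ::

            def process_fn(self, batch, buffer, indice):
                return self.compute_nstep_return(
                    batch, buffer, indice, self._target_q,
                    gamma=0.99, n_step=3)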
"""
rew = buffer.rew
2020-06-03 13:59:47 +08:00
if rew_norm:
bfr = rew[:min(len(buffer), 1000)] # avoid large buffer
2020-06-03 13:59:47 +08:00
mean, std = bfr.mean(), bfr.std()
if np.isclose(std, 0, 1e-2):
mean, std = 0., 1.
2020-06-03 13:59:47 +08:00
else:
mean, std = 0., 1.
buf_len = len(buffer)
2020-06-02 22:29:50 +08:00
terminal = (indice + n_step - 1) % buf_len
target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, )
target_q = to_numpy(target_q_torch)
target_q = _nstep_return(rew, buffer.done, target_q, indice,
gamma, n_step, len(buffer), mean, std)
batch.returns = to_torch_as(target_q, target_q_torch)
# prio buffer update
if isinstance(buffer, PrioritizedReplayBuffer):
batch.weight = to_torch_as(batch.weight, target_q_torch)
2020-04-14 21:11:06 +08:00
return batch


@njit
def _episodic_return(
    v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray,
    gamma: float, gae_lambda: float,
) -> np.ndarray:
    """Numba speedup: 4.1s -> 0.057s."""
    returns = np.roll(v_s_, 1)  # shift V(s') one step to get baseline V(s)
    m = (1. - done) * gamma  # discount, zeroed at episode boundaries
    delta = rew + v_s_ * m - returns  # TD residual delta_t
    m *= gae_lambda
    gae = 0.
    for i in range(len(rew) - 1, -1, -1):
        gae = delta[i] + m[i] * gae  # GAE recursion, backwards in time
        returns[i] += gae  # return = value baseline + advantage
    return returns


@njit
def _nstep_return(
    rew: np.ndarray, done: np.ndarray, target_q: np.ndarray,
    indice: np.ndarray, gamma: float, n_step: int, buf_len: int,
    mean: float, std: float
) -> np.ndarray:
    """Numba speedup: 0.3s -> 0.15s."""
    returns = np.zeros(indice.shape)
    gammas = np.full(indice.shape, n_step)
    for n in range(n_step - 1, -1, -1):  # accumulate rewards backwards
        now = (indice + n) % buf_len
        gammas[done[now] > 0] = n  # episode ends inside the n-step window
        returns[done[now] > 0] = 0.  # drop rewards beyond the episode end
        returns = (rew[now] - mean) / std + gamma * returns
    target_q[gammas != n_step] = 0  # mask the Q target after a terminal
    target_q = target_q * (gamma ** gammas) + returns
    return target_q