From 140b1c2cab3480ca34fb97f08e816300bb2d3daf Mon Sep 17 00:00:00 2001 From: n+e <463003665@qq.com> Date: Thu, 6 Aug 2020 10:26:24 +0800 Subject: [PATCH] Improve PER (#159) - use segment tree to rewrite the previous PrioReplayBuffer code, add the test - enable all Q-learning algorithms to use PER --- README.md | 2 +- docs/index.rst | 2 +- test/base/test_buffer.py | 106 +++++++++++++- test/discrete/test_dqn.py | 26 +++- test/discrete/test_pdqn.py | 118 ---------------- test/throughput/test_buffer_profile.py | 16 ++- tianshou/data/__init__.py | 6 +- tianshou/data/buffer.py | 118 +++++++--------- tianshou/data/utils/__init__.py | 0 .../data/{utils.py => utils/converter.py} | 0 tianshou/data/utils/segtree.py | 133 ++++++++++++++++++ tianshou/policy/base.py | 10 +- tianshou/policy/modelfree/ddpg.py | 7 +- tianshou/policy/modelfree/dqn.py | 19 +-- tianshou/policy/modelfree/sac.py | 12 +- tianshou/policy/modelfree/td3.py | 11 +- 16 files changed, 368 insertions(+), 218 deletions(-) delete mode 100644 test/discrete/test_pdqn.py create mode 100644 tianshou/data/utils/__init__.py rename tianshou/data/{utils.py => utils/converter.py} (100%) create mode 100644 tianshou/data/utils/segtree.py diff --git a/README.md b/README.md index 414e288..2c027d5 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Here is Tianshou's other features: - Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training) - Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation) - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process) -- Support n-step returns estimation for all Q-learning based algorithms +- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms - Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning) In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. diff --git a/docs/index.rst b/docs/index.rst index 25f0410..9ef598a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,7 +28,7 @@ Here is Tianshou's other features: * Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training` * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env` * Support customized training process: :ref:`customize_training` -* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms +* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms * Support multi-agent RL: :doc:`/tutorials/tictactoe` 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ `_ diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index d9f59d2..4c6bc71 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,6 +1,9 @@ +import pytest import numpy as np +from timeit import timeit -from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer +from tianshou.data import Batch, PrioritizedReplayBuffer, \ + ReplayBuffer, SegmentTree if __name__ == '__main__': from env import MyTestEnv @@ -112,9 +115,110 @@ def test_update(): assert (buf2[-1].obs == buf1[0].obs).all() +def test_segtree(): + for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]): + realop = getattr(np, op) + # small test + actual_len = 8 + tree = SegmentTree(actual_len, op) # 1-15. 8-15 are leaf nodes + assert np.all([tree[i] == init for i in range(actual_len)]) + with pytest.raises(IndexError): + tree[actual_len] + naive = np.full([actual_len], init) + for _ in range(1000): + # random choose a place to perform single update + index = np.random.randint(actual_len) + value = np.random.rand() + naive[index] = value + tree[index] = value + for i in range(actual_len): + for j in range(i + 1, actual_len): + ref = realop(naive[i:j]) + out = tree.reduce(i, j) + assert np.allclose(ref, out) + # batch setitem + for _ in range(1000): + index = np.random.choice(actual_len, size=4) + value = np.random.rand(4) + naive[index] = value + tree[index] = value + assert np.allclose(realop(naive), tree.reduce()) + for i in range(10): + left = np.random.randint(actual_len) + right = np.random.randint(left + 1, actual_len + 1) + assert np.allclose(realop(naive[left:right]), + tree.reduce(left, right)) + # large test + actual_len = 16384 + tree = SegmentTree(actual_len, op) + naive = np.full([actual_len], init) + for _ in range(1000): + index = np.random.choice(actual_len, size=64) + value = np.random.rand(64) + naive[index] = value + tree[index] = value + assert np.allclose(realop(naive), tree.reduce()) + for i in range(10): + left = np.random.randint(actual_len) + right = np.random.randint(left + 1, actual_len + 1) + assert np.allclose(realop(naive[left:right]), + tree.reduce(left, right)) + + # test prefix-sum-idx + actual_len = 8 + tree = SegmentTree(actual_len) + naive = np.random.rand(actual_len) + tree[np.arange(actual_len)] = naive + for _ in range(1000): + scalar = np.random.rand() * naive.sum() + index = tree.get_prefix_sum_idx(scalar) + assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() + # corner case here + naive = np.ones(actual_len, np.int) + tree[np.arange(actual_len)] = naive + for scalar in range(actual_len): + index = tree.get_prefix_sum_idx(scalar * 1.) + assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() + tree = SegmentTree(10) + tree[np.arange(3)] = np.array([0.1, 0, 0.1]) + assert np.allclose(tree.get_prefix_sum_idx( + np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2]) + with pytest.raises(AssertionError): + tree.get_prefix_sum_idx(.2) + # test large prefix-sum-idx + actual_len = 16384 + tree = SegmentTree(actual_len) + naive = np.random.rand(actual_len) + tree[np.arange(actual_len)] = naive + for _ in range(1000): + scalar = np.random.rand() * naive.sum() + index = tree.get_prefix_sum_idx(scalar) + assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() + + # profile + if __name__ == '__main__': + size = 100000 + bsz = 64 + naive = np.random.rand(size) + tree = SegmentTree(size) + tree[np.arange(size)] = naive + + def sample_npbuf(): + return np.random.choice(size, bsz, p=naive / naive.sum()) + + def sample_tree(): + scalar = np.random.rand(bsz) * tree.reduce() + return tree.get_prefix_sum_idx(scalar) + + print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000)) + print('tree', timeit(sample_tree, setup=sample_tree, number=1000)) + + if __name__ == '__main__': test_replaybuffer() test_ignore_obs_next() test_stack() + test_segtree() + test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) test_update() diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 0455f70..ae4c4ce 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -8,9 +8,9 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.env import VectorEnv from tianshou.policy import DQNPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer def get_args(): @@ -33,6 +33,9 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', type=int, default=0) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -58,15 +61,20 @@ def test_dqn(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.layer_num, args.state_shape, - args.action_shape, args.device, - dueling=(2, 2)).to(args.device) + args.action_shape, args.device, # dueling=(1, 1) + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, target_update_freq=args.target_update_freq) + # buffer + if args.prioritized_replay > 0: + buf = PrioritizedReplayBuffer( + args.buffer_size, alpha=args.alpha, beta=args.beta) + else: + buf = ReplayBuffer(args.buffer_size) # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + train_collector = Collector(policy, train_envs, buf) test_collector = Collector(policy, test_envs) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size) @@ -114,5 +122,11 @@ def test_dqn(args=get_args()): collector.close() +def test_pdqn(args=get_args()): + args.prioritized_replay = 1 + args.gamma = .95 + test_dqn(args) + + if __name__ == '__main__': test_dqn(get_args()) diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py deleted file mode 100644 index b614f24..0000000 --- a/test/discrete/test_pdqn.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import gym -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.utils.net.common import Net -from tianshou.env import VectorEnv -from tianshou.policy import DQNPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0.05) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--layer-num', type=int, default=3) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', type=int, default=1) - parser.add_argument('--alpha', type=float, default=0.5) - parser.add_argument('--beta', type=float, default=0.5) - parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - args = parser.parse_known_args()[0] - return args - - -def test_pdqn(args=get_args()): - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - # train_envs = gym.make(args.task) - # you can also use tianshou.env.SubprocVectorEnv - train_envs = VectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = VectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.layer_num, args.state_shape, - args.action_shape, args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) - # collector - if args.prioritized_replay > 0: - buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, - beta=args.alpha, repeat_sample=True) - else: - buf = ReplayBuffer(args.buffer_size) - train_collector = Collector( - policy, train_envs, buf) - test_collector = Collector(policy, test_envs) - # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) - # log - log_path = os.path.join(args.logdir, args.task, 'dqn') - writer = SummaryWriter(log_path) - - def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - - def stop_fn(x): - return x >= env.spec.reward_threshold - - def train_fn(x): - policy.set_eps(args.eps_train) - - def test_fn(x): - policy.set_eps(args.eps_test) - - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) - - assert stop_fn(result['best_reward']) - train_collector.close() - test_collector.close() - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - collector.close() - - -if __name__ == '__main__': - test_pdqn(get_args()) diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index aec3268..88abdcb 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -1,8 +1,8 @@ -import numpy as np import pytest +import numpy as np from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer, - ReplayBuffer) + ReplayBuffer, SegmentTree) @pytest.fixture(scope="module") @@ -21,7 +21,7 @@ def data(): 'buffer': buffer, 'buffer2': buffer2, 'slice': slice(-3000, -1000, 2), - 'indexes': indexes + 'indexes': indexes, } @@ -77,5 +77,15 @@ def test_sample(data): buffer.sample(int(1e2)) +def test_segtree(data): + size = 100000 + tree = SegmentTree(size) + tree[np.arange(size)] = np.random.rand(size) + + for i in np.arange(1e5): + scalar = np.random.rand(64) * tree.reduce() + tree.get_prefix_sum_idx(scalar) + + if __name__ == '__main__': pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"]) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 5d097a0..f5f68e9 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,6 +1,7 @@ from tianshou.data.batch import Batch -from tianshou.data.utils import to_numpy, to_torch, \ +from tianshou.data.utils.converter import to_numpy, to_torch, \ to_torch_as +from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, \ ListReplayBuffer, PrioritizedReplayBuffer from tianshou.data.collector import Collector @@ -10,8 +11,9 @@ __all__ = [ 'to_numpy', 'to_torch', 'to_torch_as', + 'SegmentTree', 'ReplayBuffer', 'ListReplayBuffer', 'PrioritizedReplayBuffer', - 'Collector' + 'Collector', ] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 4491bee..1d7a80f 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,7 +1,9 @@ +import torch import numpy as np from typing import Any, Tuple, Union, Optional -from tianshou.data.batch import Batch, _create_value +from tianshou.data import Batch, SegmentTree, to_numpy +from tianshou.data.batch import _create_value class ReplayBuffer: @@ -313,7 +315,7 @@ class ReplayBuffer: done=self.done[index], obs_next=self.get(index, 'obs_next'), info=self.get(index, 'info'), - policy=self.get(index, 'policy') + policy=self.get(index, 'policy'), ) @@ -326,8 +328,8 @@ class ListReplayBuffer(ReplayBuffer): .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more - detailed explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. """ def __init__(self, **kwargs) -> None: @@ -353,31 +355,32 @@ class ListReplayBuffer(ReplayBuffer): class PrioritizedReplayBuffer(ReplayBuffer): - """Prioritized replay buffer implementation. + """Implementation of Prioritized Experience Replay. arXiv:1511.05952 :param float alpha: the prioritization exponent. :param float beta: the importance sample soft coefficient. - :param str mode: defaults to ``weight``. - :param bool replace: whether to sample with replacement .. seealso:: - Please refer to :class:`~tianshou.data.ReplayBuffer` for more - detailed explanation. + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. """ - def __init__(self, size: int, alpha: float, beta: float, - mode: str = 'weight', - replace: bool = False, **kwargs) -> None: - if mode != 'weight': - raise NotImplementedError + def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None: super().__init__(size, **kwargs) - self._alpha = alpha - self._beta = beta - self._weight_sum = 0.0 - self._amortization_freq = 50 - self._replace = replace - self._meta.weight = np.zeros(size, dtype=np.float64) + assert alpha > 0. and beta >= 0. + self._alpha, self._beta = alpha, beta + self._max_prio = 1. + self._min_prio = 1. + # bypass the check + self._weight = SegmentTree(size) + self.__eps = np.finfo(np.float32).eps.item() + + def __getattr__(self, key: str) -> Union['Batch', Any]: + """Return self.key""" + if key == 'weight': + return self._weight + return self._meta.__dict__[key] def add(self, obs: Union[dict, np.ndarray], @@ -387,68 +390,55 @@ class PrioritizedReplayBuffer(ReplayBuffer): obs_next: Optional[Union[dict, np.ndarray]] = None, info: dict = {}, policy: Optional[Union[dict, Batch]] = {}, - weight: float = 1.0, + weight: float = None, **kwargs) -> None: """Add a batch of data into replay buffer.""" - # we have to sacrifice some convenience for speed - self._weight_sum += np.abs(weight) ** self._alpha - \ - self._meta.weight[self._index] - self._add_to_buffer('weight', np.abs(weight) ** self._alpha) + if weight is None: + weight = self._max_prio + else: + weight = np.abs(weight) + self._max_prio = max(self._max_prio, weight) + self._min_prio = min(self._min_prio, weight) + self.weight[self._index] = weight ** self._alpha super().add(obs, act, rew, done, obs_next, info, policy) - @property - def replace(self): - return self._replace - - @replace.setter - def replace(self, v: bool): - self._replace = v - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with priority probability. \ - Return all the data in the buffer if batch_size is ``0``. + """Get a random sample from buffer with priority probability. Return + all the data in the buffer if batch_size is ``0``. :return: Sample data and its corresponding index inside the buffer. + + The ``weight`` in the returned Batch is the weight on loss function + to de-bias the sampling process (some transition tuples are sampled + more often so their losses are weighted less). """ - assert self._size > 0, 'cannot sample a buffer with size == 0 !' - p = None - if batch_size > 0 and (self._replace or batch_size <= self._size): - # sampling weight - p = (self.weight / self.weight.sum())[:self._size] - indice = np.random.choice( - self._size, batch_size, p=p, - replace=self._replace) - p = p[indice] # weight of each sample - elif batch_size == 0: - p = np.full(shape=self._size, fill_value=1.0 / self._size) + assert self._size > 0, 'Cannot sample a buffer with 0 size!' + if batch_size == 0: indice = np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index), ]) else: - raise ValueError( - f"batch_size should be less than {len(self)}, \ - or set replace=True") + scalar = np.random.rand(batch_size) * self.weight.reduce() + indice = self.weight.get_prefix_sum_idx(scalar) batch = self[indice] - batch["impt_weight"] = (self._size * p) ** (-self._beta) + # impt_weight + # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) + # simplified formula: (p_j/p_min)**(-beta) + batch.weight = (batch.weight / self._min_prio) ** (-self._beta) return batch, indice - def update_weight(self, indice: Union[slice, np.ndarray], - new_weight: np.ndarray) -> None: + def update_weight(self, indice: Union[np.ndarray], + new_weight: Union[np.ndarray, torch.Tensor]) -> None: """Update priority weight by indice in this buffer. - :param np.ndarray indice: indice you want to update weight - :param np.ndarray new_weight: new priority weight you want to update + :param np.ndarray indice: indice you want to update weight. + :param np.ndarray new_weight: new priority weight you want to update. """ - if self._replace: - if isinstance(indice, slice): - # convert slice to ndarray - indice = np.arange(indice.stop)[indice] - # remove the same values in indice - indice, unique_indice = np.unique( - indice, return_index=True) - new_weight = new_weight[unique_indice] - self.weight[indice] = np.power(np.abs(new_weight), self._alpha) + weight = np.abs(to_numpy(new_weight)) + self.__eps + self.weight[indice] = weight ** self._alpha + self._max_prio = max(self._max_prio, weight.max()) + self._min_prio = min(self._min_prio, weight.min()) def __getitem__(self, index: Union[ slice, int, np.integer, np.ndarray]) -> Batch: @@ -459,6 +449,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): done=self.done[index], obs_next=self.get(index, 'obs_next'), info=self.get(index, 'info'), - weight=self.weight[index], policy=self.get(index, 'policy'), + weight=self.weight[index], ) diff --git a/tianshou/data/utils/__init__.py b/tianshou/data/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tianshou/data/utils.py b/tianshou/data/utils/converter.py similarity index 100% rename from tianshou/data/utils.py rename to tianshou/data/utils/converter.py diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py new file mode 100644 index 0000000..60a60dd --- /dev/null +++ b/tianshou/data/utils/segtree.py @@ -0,0 +1,133 @@ +import numpy as np +from typing import Union, Optional +# from numba import njit + + +# numba version, 5x speed up +# with size=100000 and bsz=64 +# first block (vectorized np): 0.0923 (now) -> 0.0251 +# second block (for-loop): 0.2914 -> 0.0192 (future) +# @njit +def _get_prefix_sum_idx(value, bound, sums): + index = np.ones(value.shape, dtype=np.int64) + while index[0] < bound: + index *= 2 + direct = sums[index] < value + value -= sums[index] * direct + index += direct + # for _, s in enumerate(value): + # i = 1 + # while i < bound: + # l = i * 2 + # if sums[l] >= s: + # i = l + # else: + # s = s - sums[l] + # i = l + 1 + # index[_] = i + index -= bound + return index + + +class SegmentTree: + """Implementation of Segment Tree: store an array ``arr`` with size ``n`` + in a segment tree, support value update and fast query of ``min/max/sum`` + for the interval ``[left, right)`` in O(log n) time. + + The detailed procedure is as follows: + + 1. Pad the array to have length of power of 2, so that leaf nodes in the\ + segment tree have the same depth. + 2. Store the segment tree in a binary heap. + + :param int size: the size of segment tree. + :param str operation: the operation of segment tree. Choices are "sum", + "min" and "max". Default: "sum". + """ + + def __init__(self, size: int, + operation: str = 'sum') -> None: + bound = 1 + while bound < size: + bound *= 2 + self._size = size + self._bound = bound + assert operation in ['sum', 'min', 'max'], \ + f'Unknown operation {operation}.' + if operation == 'sum': + self._op, self._init_value = np.add, 0. + elif operation == 'min': + self._op, self._init_value = np.minimum, np.inf + else: + self._op, self._init_value = np.maximum, -np.inf + # assert isinstance(self._op, np.ufunc) + self._value = np.full([bound * 2], self._init_value) + + def __len__(self): + return self._size + + def __getitem__(self, index: Union[int, np.ndarray] + ) -> Union[float, np.ndarray]: + """Return self[index]""" + return self._value[index + self._bound] + + def __setitem__(self, index: Union[int, np.ndarray], + value: Union[float, np.ndarray]) -> None: + """Duplicate values in ``index`` are handled by numpy: later index + overwrites previous ones. + + :: + + >>> a = np.array([1, 2, 3, 4]) + >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7] + >>> print(a) + [6 7 3 4] + + """ + # TODO numba njit version + if isinstance(index, int): + index = np.array([index]) + assert np.all(0 <= index) and np.all(index < self._size) + if self._op is np.add: + assert np.all(0. <= value) + index = index + self._bound + self._value[index] = value + while index[0] > 1: + index //= 2 + self._value[index] = self._op( + self._value[index * 2], self._value[index * 2 + 1]) + + def reduce(self, start: Optional[int] = 0, + end: Optional[int] = None) -> float: + """Return operation(value[start:end]).""" + # TODO numba njit version + if start == 0 and end is None: + return self._value[1] + if end is None: + end = self._size + if end < 0: + end += self._size + # nodes in (start, end) should be aggregated + start, end = start + self._bound - 1, end + self._bound + result = self._init_value + while end - start > 1: # (start, end) interval is not empty + if start % 2 == 0: + result = self._op(result, self._value[start + 1]) + if end % 2 == 1: + result = self._op(result, self._value[end - 1]) + start, end = start // 2, end // 2 + return result + + def get_prefix_sum_idx( + self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]: + """Return the minimum index for each ``v`` in ``value`` so that + ``v <= sums[i]``, where sums[i] = \\sum_{j=0}^{i} arr[j]. + """ + assert self._op is np.add + assert np.all(value >= 0.) and np.all(value < self._value[1]) + single = False + if not isinstance(value, np.ndarray): + value = np.array([value]) + single = True + index = _get_prefix_sum_idx(value, self._bound, self._value) + return index.item() if single else index diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 9d7711c..cc1a593 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -4,7 +4,8 @@ from torch import nn from abc import ABC, abstractmethod from typing import Dict, List, Union, Optional, Callable -from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy +from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ + to_torch_as, to_numpy class BasePolicy(ABC, nn.Module): @@ -213,4 +214,11 @@ class BasePolicy(ABC, nn.Module): returns = to_torch_as(returns, target_q) gammas = to_torch_as(gamma ** gammas, target_q) batch.returns = target_q * gammas + returns + # prio buffer update + if isinstance(buffer, PrioritizedReplayBuffer): + batch.update_weight = buffer.update_weight + batch.indice = indice + batch.weight = to_torch_as(batch.weight, target_q) + else: + batch.weight = torch.ones_like(target_q) return batch diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index da4833a..2205102 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,7 +1,6 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Tuple, Union, Optional from tianshou.policy import BasePolicy @@ -144,7 +143,11 @@ class DDPGPolicy(BasePolicy): def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() - critic_loss = F.mse_loss(current_q, target_q) + td = current_q - target_q + if hasattr(batch, 'update_weight'): # prio-buffer + batch.update_weight(batch.indice, td) + critic_loss = (td.pow(2) * batch.weight).mean() + # critic_loss = F.mse_loss(current_q, target_q) self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 9bf60a6..c37dac5 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,12 +1,10 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ - to_torch_as, to_numpy +from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy class DQNPolicy(BasePolicy): @@ -95,9 +93,6 @@ class DQNPolicy(BasePolicy): batch = self.compute_nstep_return( batch, buffer, indice, self._target_q, self._gamma, self._n_step, self._rew_norm) - if isinstance(buffer, PrioritizedReplayBuffer): - batch.update_weight = buffer.update_weight - batch.indice = indice return batch def forward(self, batch: Batch, @@ -164,13 +159,11 @@ class DQNPolicy(BasePolicy): q = self(batch, eps=0.).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() - if hasattr(batch, 'update_weight'): - td = r - q - batch.update_weight(batch.indice, to_numpy(td)) - impt_weight = to_torch_as(batch.impt_weight, q) - loss = (td.pow(2) * impt_weight).mean() - else: - loss = F.mse_loss(q, r) + td = r - q + if hasattr(batch, 'update_weight'): # prio-buffer + batch.update_weight(batch.indice, td) + loss = (td.pow(2) * batch.weight).mean() + # loss = F.mse_loss(q, r) loss.backward() self.optim.step() self._cnt += 1 diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index b67a95b..ce4a5ba 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,7 +1,6 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Tuple, Union, Optional from tianshou.policy import DDPGPolicy @@ -141,16 +140,23 @@ class SACPolicy(DDPGPolicy): # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() - critic1_loss = F.mse_loss(current_q1, target_q) + td1 = current_q1 - target_q + critic1_loss = (td1.pow(2) * batch.weight).mean() + # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() self.critic1_optim.step() # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() - critic2_loss = F.mse_loss(current_q2, target_q) + td2 = current_q2 - target_q + critic2_loss = (td2.pow(2) * batch.weight).mean() + # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() + # prio-buffer + if hasattr(batch, 'update_weight'): + batch.update_weight(batch.indice, (td1 + td2) / 2.) # actor obs_result = self(batch, explorating=False) a = obs_result.act diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index d90f510..698145f 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,7 +1,6 @@ import torch import numpy as np from copy import deepcopy -import torch.nn.functional as F from typing import Dict, Tuple, Optional from tianshou.policy import DDPGPolicy @@ -119,16 +118,22 @@ class TD3Policy(DDPGPolicy): # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() - critic1_loss = F.mse_loss(current_q1, target_q) + td1 = current_q1 - target_q + critic1_loss = (td1.pow(2) * batch.weight).mean() + # critic1_loss = F.mse_loss(current_q1, target_q) self.critic1_optim.zero_grad() critic1_loss.backward() self.critic1_optim.step() # critic 2 current_q2 = self.critic2(batch.obs, batch.act).flatten() - critic2_loss = F.mse_loss(current_q2, target_q) + td2 = current_q2 - target_q + critic2_loss = (td2.pow(2) * batch.weight).mean() + # critic2_loss = F.mse_loss(current_q2, target_q) self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() + if hasattr(batch, 'update_weight'): # prio-buffer + batch.update_weight(batch.indice, (td1 + td2) / 2.) if self._cnt % self._freq == 0: actor_loss = -self.critic1( batch.obs, self(batch, eps=0).act).mean()