Improve PER (#159)
- use segment tree to rewrite the previous PrioReplayBuffer code, add the test - enable all Q-learning algorithms to use PER
This commit is contained in:
parent
312b7551cc
commit
140b1c2cab
@ -38,7 +38,7 @@ Here is Tianshou's other features:
|
||||
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
|
||||
- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
|
||||
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
|
||||
- Support n-step returns estimation for all Q-learning based algorithms
|
||||
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms
|
||||
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
|
||||
|
||||
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
|
||||
|
||||
@ -28,7 +28,7 @@ Here is Tianshou's other features:
|
||||
* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
|
||||
* Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
|
||||
* Support customized training process: :ref:`customize_training`
|
||||
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
|
||||
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms
|
||||
* Support multi-agent RL: :doc:`/tutorials/tictactoe`
|
||||
|
||||
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from timeit import timeit
|
||||
|
||||
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
|
||||
from tianshou.data import Batch, PrioritizedReplayBuffer, \
|
||||
ReplayBuffer, SegmentTree
|
||||
|
||||
if __name__ == '__main__':
|
||||
from env import MyTestEnv
|
||||
@ -112,9 +115,110 @@ def test_update():
|
||||
assert (buf2[-1].obs == buf1[0].obs).all()
|
||||
|
||||
|
||||
def test_segtree():
|
||||
for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
|
||||
realop = getattr(np, op)
|
||||
# small test
|
||||
actual_len = 8
|
||||
tree = SegmentTree(actual_len, op) # 1-15. 8-15 are leaf nodes
|
||||
assert np.all([tree[i] == init for i in range(actual_len)])
|
||||
with pytest.raises(IndexError):
|
||||
tree[actual_len]
|
||||
naive = np.full([actual_len], init)
|
||||
for _ in range(1000):
|
||||
# random choose a place to perform single update
|
||||
index = np.random.randint(actual_len)
|
||||
value = np.random.rand()
|
||||
naive[index] = value
|
||||
tree[index] = value
|
||||
for i in range(actual_len):
|
||||
for j in range(i + 1, actual_len):
|
||||
ref = realop(naive[i:j])
|
||||
out = tree.reduce(i, j)
|
||||
assert np.allclose(ref, out)
|
||||
# batch setitem
|
||||
for _ in range(1000):
|
||||
index = np.random.choice(actual_len, size=4)
|
||||
value = np.random.rand(4)
|
||||
naive[index] = value
|
||||
tree[index] = value
|
||||
assert np.allclose(realop(naive), tree.reduce())
|
||||
for i in range(10):
|
||||
left = np.random.randint(actual_len)
|
||||
right = np.random.randint(left + 1, actual_len + 1)
|
||||
assert np.allclose(realop(naive[left:right]),
|
||||
tree.reduce(left, right))
|
||||
# large test
|
||||
actual_len = 16384
|
||||
tree = SegmentTree(actual_len, op)
|
||||
naive = np.full([actual_len], init)
|
||||
for _ in range(1000):
|
||||
index = np.random.choice(actual_len, size=64)
|
||||
value = np.random.rand(64)
|
||||
naive[index] = value
|
||||
tree[index] = value
|
||||
assert np.allclose(realop(naive), tree.reduce())
|
||||
for i in range(10):
|
||||
left = np.random.randint(actual_len)
|
||||
right = np.random.randint(left + 1, actual_len + 1)
|
||||
assert np.allclose(realop(naive[left:right]),
|
||||
tree.reduce(left, right))
|
||||
|
||||
# test prefix-sum-idx
|
||||
actual_len = 8
|
||||
tree = SegmentTree(actual_len)
|
||||
naive = np.random.rand(actual_len)
|
||||
tree[np.arange(actual_len)] = naive
|
||||
for _ in range(1000):
|
||||
scalar = np.random.rand() * naive.sum()
|
||||
index = tree.get_prefix_sum_idx(scalar)
|
||||
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
|
||||
# corner case here
|
||||
naive = np.ones(actual_len, np.int)
|
||||
tree[np.arange(actual_len)] = naive
|
||||
for scalar in range(actual_len):
|
||||
index = tree.get_prefix_sum_idx(scalar * 1.)
|
||||
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
|
||||
tree = SegmentTree(10)
|
||||
tree[np.arange(3)] = np.array([0.1, 0, 0.1])
|
||||
assert np.allclose(tree.get_prefix_sum_idx(
|
||||
np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
|
||||
with pytest.raises(AssertionError):
|
||||
tree.get_prefix_sum_idx(.2)
|
||||
# test large prefix-sum-idx
|
||||
actual_len = 16384
|
||||
tree = SegmentTree(actual_len)
|
||||
naive = np.random.rand(actual_len)
|
||||
tree[np.arange(actual_len)] = naive
|
||||
for _ in range(1000):
|
||||
scalar = np.random.rand() * naive.sum()
|
||||
index = tree.get_prefix_sum_idx(scalar)
|
||||
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
|
||||
|
||||
# profile
|
||||
if __name__ == '__main__':
|
||||
size = 100000
|
||||
bsz = 64
|
||||
naive = np.random.rand(size)
|
||||
tree = SegmentTree(size)
|
||||
tree[np.arange(size)] = naive
|
||||
|
||||
def sample_npbuf():
|
||||
return np.random.choice(size, bsz, p=naive / naive.sum())
|
||||
|
||||
def sample_tree():
|
||||
scalar = np.random.rand(bsz) * tree.reduce()
|
||||
return tree.get_prefix_sum_idx(scalar)
|
||||
|
||||
print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
|
||||
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_replaybuffer()
|
||||
test_ignore_obs_next()
|
||||
test_stack()
|
||||
test_segtree()
|
||||
test_priortized_replaybuffer()
|
||||
test_priortized_replaybuffer(233333, 200000)
|
||||
test_update()
|
||||
|
||||
@ -8,9 +8,9 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
@ -33,6 +33,9 @@ def get_args():
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument('--prioritized-replay', type=int, default=0)
|
||||
parser.add_argument('--alpha', type=float, default=0.6)
|
||||
parser.add_argument('--beta', type=float, default=0.4)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
@ -58,15 +61,20 @@ def test_dqn(args=get_args()):
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = Net(args.layer_num, args.state_shape,
|
||||
args.action_shape, args.device,
|
||||
dueling=(2, 2)).to(args.device)
|
||||
args.action_shape, args.device, # dueling=(1, 1)
|
||||
).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = DQNPolicy(
|
||||
net, optim, args.gamma, args.n_step,
|
||||
target_update_freq=args.target_update_freq)
|
||||
# buffer
|
||||
if args.prioritized_replay > 0:
|
||||
buf = PrioritizedReplayBuffer(
|
||||
args.buffer_size, alpha=args.alpha, beta=args.beta)
|
||||
else:
|
||||
buf = ReplayBuffer(args.buffer_size)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
train_collector = Collector(policy, train_envs, buf)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size)
|
||||
@ -114,5 +122,11 @@ def test_dqn(args=get_args()):
|
||||
collector.close()
|
||||
|
||||
|
||||
def test_pdqn(args=get_args()):
|
||||
args.prioritized_replay = 1
|
||||
args.gamma = .95
|
||||
test_dqn(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dqn(get_args())
|
||||
|
||||
@ -1,118 +0,0 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.9)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--layer-num', type=int, default=3)
|
||||
parser.add_argument('--training-num', type=int, default=8)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument('--prioritized-replay', type=int, default=1)
|
||||
parser.add_argument('--alpha', type=float, default=0.5)
|
||||
parser.add_argument('--beta', type=float, default=0.5)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_pdqn(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = Net(args.layer_num, args.state_shape,
|
||||
args.action_shape, args.device).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = DQNPolicy(
|
||||
net, optim, args.gamma, args.n_step,
|
||||
target_update_freq=args.target_update_freq)
|
||||
# collector
|
||||
if args.prioritized_replay > 0:
|
||||
buf = PrioritizedReplayBuffer(
|
||||
args.buffer_size, alpha=args.alpha,
|
||||
beta=args.alpha, repeat_sample=True)
|
||||
else:
|
||||
buf = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(
|
||||
policy, train_envs, buf)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(x):
|
||||
policy.set_eps(args.eps_train)
|
||||
|
||||
def test_fn(x):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_pdqn(get_args())
|
||||
@ -1,8 +1,8 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer,
|
||||
ReplayBuffer)
|
||||
ReplayBuffer, SegmentTree)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -21,7 +21,7 @@ def data():
|
||||
'buffer': buffer,
|
||||
'buffer2': buffer2,
|
||||
'slice': slice(-3000, -1000, 2),
|
||||
'indexes': indexes
|
||||
'indexes': indexes,
|
||||
}
|
||||
|
||||
|
||||
@ -77,5 +77,15 @@ def test_sample(data):
|
||||
buffer.sample(int(1e2))
|
||||
|
||||
|
||||
def test_segtree(data):
|
||||
size = 100000
|
||||
tree = SegmentTree(size)
|
||||
tree[np.arange(size)] = np.random.rand(size)
|
||||
|
||||
for i in np.arange(1e5):
|
||||
scalar = np.random.rand(64) * tree.reduce()
|
||||
tree.get_prefix_sum_idx(scalar)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"])
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from tianshou.data.batch import Batch
|
||||
from tianshou.data.utils import to_numpy, to_torch, \
|
||||
from tianshou.data.utils.converter import to_numpy, to_torch, \
|
||||
to_torch_as
|
||||
from tianshou.data.utils.segtree import SegmentTree
|
||||
from tianshou.data.buffer import ReplayBuffer, \
|
||||
ListReplayBuffer, PrioritizedReplayBuffer
|
||||
from tianshou.data.collector import Collector
|
||||
@ -10,8 +11,9 @@ __all__ = [
|
||||
'to_numpy',
|
||||
'to_torch',
|
||||
'to_torch_as',
|
||||
'SegmentTree',
|
||||
'ReplayBuffer',
|
||||
'ListReplayBuffer',
|
||||
'PrioritizedReplayBuffer',
|
||||
'Collector'
|
||||
'Collector',
|
||||
]
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any, Tuple, Union, Optional
|
||||
|
||||
from tianshou.data.batch import Batch, _create_value
|
||||
from tianshou.data import Batch, SegmentTree, to_numpy
|
||||
from tianshou.data.batch import _create_value
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
@ -313,7 +315,7 @@ class ReplayBuffer:
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.get(index, 'info'),
|
||||
policy=self.get(index, 'policy')
|
||||
policy=self.get(index, 'policy'),
|
||||
)
|
||||
|
||||
|
||||
@ -326,8 +328,8 @@ class ListReplayBuffer(ReplayBuffer):
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
|
||||
detailed explanation.
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
@ -353,31 +355,32 @@ class ListReplayBuffer(ReplayBuffer):
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""Prioritized replay buffer implementation.
|
||||
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952
|
||||
|
||||
:param float alpha: the prioritization exponent.
|
||||
:param float beta: the importance sample soft coefficient.
|
||||
:param str mode: defaults to ``weight``.
|
||||
:param bool replace: whether to sample with replacement
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
|
||||
detailed explanation.
|
||||
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, alpha: float, beta: float,
|
||||
mode: str = 'weight',
|
||||
replace: bool = False, **kwargs) -> None:
|
||||
if mode != 'weight':
|
||||
raise NotImplementedError
|
||||
def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
|
||||
super().__init__(size, **kwargs)
|
||||
self._alpha = alpha
|
||||
self._beta = beta
|
||||
self._weight_sum = 0.0
|
||||
self._amortization_freq = 50
|
||||
self._replace = replace
|
||||
self._meta.weight = np.zeros(size, dtype=np.float64)
|
||||
assert alpha > 0. and beta >= 0.
|
||||
self._alpha, self._beta = alpha, beta
|
||||
self._max_prio = 1.
|
||||
self._min_prio = 1.
|
||||
# bypass the check
|
||||
self._weight = SegmentTree(size)
|
||||
self.__eps = np.finfo(np.float32).eps.item()
|
||||
|
||||
def __getattr__(self, key: str) -> Union['Batch', Any]:
|
||||
"""Return self.key"""
|
||||
if key == 'weight':
|
||||
return self._weight
|
||||
return self._meta.__dict__[key]
|
||||
|
||||
def add(self,
|
||||
obs: Union[dict, np.ndarray],
|
||||
@ -387,68 +390,55 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
obs_next: Optional[Union[dict, np.ndarray]] = None,
|
||||
info: dict = {},
|
||||
policy: Optional[Union[dict, Batch]] = {},
|
||||
weight: float = 1.0,
|
||||
weight: float = None,
|
||||
**kwargs) -> None:
|
||||
"""Add a batch of data into replay buffer."""
|
||||
# we have to sacrifice some convenience for speed
|
||||
self._weight_sum += np.abs(weight) ** self._alpha - \
|
||||
self._meta.weight[self._index]
|
||||
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
||||
if weight is None:
|
||||
weight = self._max_prio
|
||||
else:
|
||||
weight = np.abs(weight)
|
||||
self._max_prio = max(self._max_prio, weight)
|
||||
self._min_prio = min(self._min_prio, weight)
|
||||
self.weight[self._index] = weight ** self._alpha
|
||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||
|
||||
@property
|
||||
def replace(self):
|
||||
return self._replace
|
||||
|
||||
@replace.setter
|
||||
def replace(self, v: bool):
|
||||
self._replace = v
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
"""Get a random sample from buffer with priority probability. \
|
||||
Return all the data in the buffer if batch_size is ``0``.
|
||||
"""Get a random sample from buffer with priority probability. Return
|
||||
all the data in the buffer if batch_size is ``0``.
|
||||
|
||||
:return: Sample data and its corresponding index inside the buffer.
|
||||
|
||||
The ``weight`` in the returned Batch is the weight on loss function
|
||||
to de-bias the sampling process (some transition tuples are sampled
|
||||
more often so their losses are weighted less).
|
||||
"""
|
||||
assert self._size > 0, 'cannot sample a buffer with size == 0 !'
|
||||
p = None
|
||||
if batch_size > 0 and (self._replace or batch_size <= self._size):
|
||||
# sampling weight
|
||||
p = (self.weight / self.weight.sum())[:self._size]
|
||||
indice = np.random.choice(
|
||||
self._size, batch_size, p=p,
|
||||
replace=self._replace)
|
||||
p = p[indice] # weight of each sample
|
||||
elif batch_size == 0:
|
||||
p = np.full(shape=self._size, fill_value=1.0 / self._size)
|
||||
assert self._size > 0, 'Cannot sample a buffer with 0 size!'
|
||||
if batch_size == 0:
|
||||
indice = np.concatenate([
|
||||
np.arange(self._index, self._size),
|
||||
np.arange(0, self._index),
|
||||
])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"batch_size should be less than {len(self)}, \
|
||||
or set replace=True")
|
||||
scalar = np.random.rand(batch_size) * self.weight.reduce()
|
||||
indice = self.weight.get_prefix_sum_idx(scalar)
|
||||
batch = self[indice]
|
||||
batch["impt_weight"] = (self._size * p) ** (-self._beta)
|
||||
# impt_weight
|
||||
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
|
||||
# simplified formula: (p_j/p_min)**(-beta)
|
||||
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
|
||||
return batch, indice
|
||||
|
||||
def update_weight(self, indice: Union[slice, np.ndarray],
|
||||
new_weight: np.ndarray) -> None:
|
||||
def update_weight(self, indice: Union[np.ndarray],
|
||||
new_weight: Union[np.ndarray, torch.Tensor]) -> None:
|
||||
"""Update priority weight by indice in this buffer.
|
||||
|
||||
:param np.ndarray indice: indice you want to update weight
|
||||
:param np.ndarray new_weight: new priority weight you want to update
|
||||
:param np.ndarray indice: indice you want to update weight.
|
||||
:param np.ndarray new_weight: new priority weight you want to update.
|
||||
"""
|
||||
if self._replace:
|
||||
if isinstance(indice, slice):
|
||||
# convert slice to ndarray
|
||||
indice = np.arange(indice.stop)[indice]
|
||||
# remove the same values in indice
|
||||
indice, unique_indice = np.unique(
|
||||
indice, return_index=True)
|
||||
new_weight = new_weight[unique_indice]
|
||||
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||
weight = np.abs(to_numpy(new_weight)) + self.__eps
|
||||
self.weight[indice] = weight ** self._alpha
|
||||
self._max_prio = max(self._max_prio, weight.max())
|
||||
self._min_prio = min(self._min_prio, weight.min())
|
||||
|
||||
def __getitem__(self, index: Union[
|
||||
slice, int, np.integer, np.ndarray]) -> Batch:
|
||||
@ -459,6 +449,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.get(index, 'info'),
|
||||
weight=self.weight[index],
|
||||
policy=self.get(index, 'policy'),
|
||||
weight=self.weight[index],
|
||||
)
|
||||
|
||||
0
tianshou/data/utils/__init__.py
Normal file
0
tianshou/data/utils/__init__.py
Normal file
133
tianshou/data/utils/segtree.py
Normal file
133
tianshou/data/utils/segtree.py
Normal file
@ -0,0 +1,133 @@
|
||||
import numpy as np
|
||||
from typing import Union, Optional
|
||||
# from numba import njit
|
||||
|
||||
|
||||
# numba version, 5x speed up
|
||||
# with size=100000 and bsz=64
|
||||
# first block (vectorized np): 0.0923 (now) -> 0.0251
|
||||
# second block (for-loop): 0.2914 -> 0.0192 (future)
|
||||
# @njit
|
||||
def _get_prefix_sum_idx(value, bound, sums):
|
||||
index = np.ones(value.shape, dtype=np.int64)
|
||||
while index[0] < bound:
|
||||
index *= 2
|
||||
direct = sums[index] < value
|
||||
value -= sums[index] * direct
|
||||
index += direct
|
||||
# for _, s in enumerate(value):
|
||||
# i = 1
|
||||
# while i < bound:
|
||||
# l = i * 2
|
||||
# if sums[l] >= s:
|
||||
# i = l
|
||||
# else:
|
||||
# s = s - sums[l]
|
||||
# i = l + 1
|
||||
# index[_] = i
|
||||
index -= bound
|
||||
return index
|
||||
|
||||
|
||||
class SegmentTree:
|
||||
"""Implementation of Segment Tree: store an array ``arr`` with size ``n``
|
||||
in a segment tree, support value update and fast query of ``min/max/sum``
|
||||
for the interval ``[left, right)`` in O(log n) time.
|
||||
|
||||
The detailed procedure is as follows:
|
||||
|
||||
1. Pad the array to have length of power of 2, so that leaf nodes in the\
|
||||
segment tree have the same depth.
|
||||
2. Store the segment tree in a binary heap.
|
||||
|
||||
:param int size: the size of segment tree.
|
||||
:param str operation: the operation of segment tree. Choices are "sum",
|
||||
"min" and "max". Default: "sum".
|
||||
"""
|
||||
|
||||
def __init__(self, size: int,
|
||||
operation: str = 'sum') -> None:
|
||||
bound = 1
|
||||
while bound < size:
|
||||
bound *= 2
|
||||
self._size = size
|
||||
self._bound = bound
|
||||
assert operation in ['sum', 'min', 'max'], \
|
||||
f'Unknown operation {operation}.'
|
||||
if operation == 'sum':
|
||||
self._op, self._init_value = np.add, 0.
|
||||
elif operation == 'min':
|
||||
self._op, self._init_value = np.minimum, np.inf
|
||||
else:
|
||||
self._op, self._init_value = np.maximum, -np.inf
|
||||
# assert isinstance(self._op, np.ufunc)
|
||||
self._value = np.full([bound * 2], self._init_value)
|
||||
|
||||
def __len__(self):
|
||||
return self._size
|
||||
|
||||
def __getitem__(self, index: Union[int, np.ndarray]
|
||||
) -> Union[float, np.ndarray]:
|
||||
"""Return self[index]"""
|
||||
return self._value[index + self._bound]
|
||||
|
||||
def __setitem__(self, index: Union[int, np.ndarray],
|
||||
value: Union[float, np.ndarray]) -> None:
|
||||
"""Duplicate values in ``index`` are handled by numpy: later index
|
||||
overwrites previous ones.
|
||||
|
||||
::
|
||||
|
||||
>>> a = np.array([1, 2, 3, 4])
|
||||
>>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
|
||||
>>> print(a)
|
||||
[6 7 3 4]
|
||||
|
||||
"""
|
||||
# TODO numba njit version
|
||||
if isinstance(index, int):
|
||||
index = np.array([index])
|
||||
assert np.all(0 <= index) and np.all(index < self._size)
|
||||
if self._op is np.add:
|
||||
assert np.all(0. <= value)
|
||||
index = index + self._bound
|
||||
self._value[index] = value
|
||||
while index[0] > 1:
|
||||
index //= 2
|
||||
self._value[index] = self._op(
|
||||
self._value[index * 2], self._value[index * 2 + 1])
|
||||
|
||||
def reduce(self, start: Optional[int] = 0,
|
||||
end: Optional[int] = None) -> float:
|
||||
"""Return operation(value[start:end])."""
|
||||
# TODO numba njit version
|
||||
if start == 0 and end is None:
|
||||
return self._value[1]
|
||||
if end is None:
|
||||
end = self._size
|
||||
if end < 0:
|
||||
end += self._size
|
||||
# nodes in (start, end) should be aggregated
|
||||
start, end = start + self._bound - 1, end + self._bound
|
||||
result = self._init_value
|
||||
while end - start > 1: # (start, end) interval is not empty
|
||||
if start % 2 == 0:
|
||||
result = self._op(result, self._value[start + 1])
|
||||
if end % 2 == 1:
|
||||
result = self._op(result, self._value[end - 1])
|
||||
start, end = start // 2, end // 2
|
||||
return result
|
||||
|
||||
def get_prefix_sum_idx(
|
||||
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
|
||||
"""Return the minimum index for each ``v`` in ``value`` so that
|
||||
``v <= sums[i]``, where sums[i] = \\sum_{j=0}^{i} arr[j].
|
||||
"""
|
||||
assert self._op is np.add
|
||||
assert np.all(value >= 0.) and np.all(value < self._value[1])
|
||||
single = False
|
||||
if not isinstance(value, np.ndarray):
|
||||
value = np.array([value])
|
||||
single = True
|
||||
index = _get_prefix_sum_idx(value, self._bound, self._value)
|
||||
return index.item() if single else index
|
||||
@ -4,7 +4,8 @@ from torch import nn
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Union, Optional, Callable
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
|
||||
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
|
||||
to_torch_as, to_numpy
|
||||
|
||||
|
||||
class BasePolicy(ABC, nn.Module):
|
||||
@ -213,4 +214,11 @@ class BasePolicy(ABC, nn.Module):
|
||||
returns = to_torch_as(returns, target_q)
|
||||
gammas = to_torch_as(gamma ** gammas, target_q)
|
||||
batch.returns = target_q * gammas + returns
|
||||
# prio buffer update
|
||||
if isinstance(buffer, PrioritizedReplayBuffer):
|
||||
batch.update_weight = buffer.update_weight
|
||||
batch.indice = indice
|
||||
batch.weight = to_torch_as(batch.weight, target_q)
|
||||
else:
|
||||
batch.weight = torch.ones_like(target_q)
|
||||
return batch
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
@ -144,7 +143,11 @@ class DDPGPolicy(BasePolicy):
|
||||
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
|
||||
current_q = self.critic(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
critic_loss = F.mse_loss(current_q, target_q)
|
||||
td = current_q - target_q
|
||||
if hasattr(batch, 'update_weight'): # prio-buffer
|
||||
batch.update_weight(batch.indice, td)
|
||||
critic_loss = (td.pow(2) * batch.weight).mean()
|
||||
# critic_loss = F.mse_loss(current_q, target_q)
|
||||
self.critic_optim.zero_grad()
|
||||
critic_loss.backward()
|
||||
self.critic_optim.step()
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Union, Optional
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
|
||||
to_torch_as, to_numpy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
|
||||
|
||||
|
||||
class DQNPolicy(BasePolicy):
|
||||
@ -95,9 +93,6 @@ class DQNPolicy(BasePolicy):
|
||||
batch = self.compute_nstep_return(
|
||||
batch, buffer, indice, self._target_q,
|
||||
self._gamma, self._n_step, self._rew_norm)
|
||||
if isinstance(buffer, PrioritizedReplayBuffer):
|
||||
batch.update_weight = buffer.update_weight
|
||||
batch.indice = indice
|
||||
return batch
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
@ -164,13 +159,11 @@ class DQNPolicy(BasePolicy):
|
||||
q = self(batch, eps=0.).logits
|
||||
q = q[np.arange(len(q)), batch.act]
|
||||
r = to_torch_as(batch.returns, q).flatten()
|
||||
if hasattr(batch, 'update_weight'):
|
||||
td = r - q
|
||||
batch.update_weight(batch.indice, to_numpy(td))
|
||||
impt_weight = to_torch_as(batch.impt_weight, q)
|
||||
loss = (td.pow(2) * impt_weight).mean()
|
||||
else:
|
||||
loss = F.mse_loss(q, r)
|
||||
td = r - q
|
||||
if hasattr(batch, 'update_weight'): # prio-buffer
|
||||
batch.update_weight(batch.indice, td)
|
||||
loss = (td.pow(2) * batch.weight).mean()
|
||||
# loss = F.mse_loss(q, r)
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
self._cnt += 1
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import DDPGPolicy
|
||||
@ -141,16 +140,23 @@ class SACPolicy(DDPGPolicy):
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
critic1_loss = F.mse_loss(current_q1, target_q)
|
||||
td1 = current_q1 - target_q
|
||||
critic1_loss = (td1.pow(2) * batch.weight).mean()
|
||||
# critic1_loss = F.mse_loss(current_q1, target_q)
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward()
|
||||
self.critic1_optim.step()
|
||||
# critic 2
|
||||
current_q2 = self.critic2(batch.obs, batch.act).flatten()
|
||||
critic2_loss = F.mse_loss(current_q2, target_q)
|
||||
td2 = current_q2 - target_q
|
||||
critic2_loss = (td2.pow(2) * batch.weight).mean()
|
||||
# critic2_loss = F.mse_loss(current_q2, target_q)
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
# prio-buffer
|
||||
if hasattr(batch, 'update_weight'):
|
||||
batch.update_weight(batch.indice, (td1 + td2) / 2.)
|
||||
# actor
|
||||
obs_result = self(batch, explorating=False)
|
||||
a = obs_result.act
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Tuple, Optional
|
||||
|
||||
from tianshou.policy import DDPGPolicy
|
||||
@ -119,16 +118,22 @@ class TD3Policy(DDPGPolicy):
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
critic1_loss = F.mse_loss(current_q1, target_q)
|
||||
td1 = current_q1 - target_q
|
||||
critic1_loss = (td1.pow(2) * batch.weight).mean()
|
||||
# critic1_loss = F.mse_loss(current_q1, target_q)
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward()
|
||||
self.critic1_optim.step()
|
||||
# critic 2
|
||||
current_q2 = self.critic2(batch.obs, batch.act).flatten()
|
||||
critic2_loss = F.mse_loss(current_q2, target_q)
|
||||
td2 = current_q2 - target_q
|
||||
critic2_loss = (td2.pow(2) * batch.weight).mean()
|
||||
# critic2_loss = F.mse_loss(current_q2, target_q)
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
if hasattr(batch, 'update_weight'): # prio-buffer
|
||||
batch.update_weight(batch.indice, (td1 + td2) / 2.)
|
||||
if self._cnt % self._freq == 0:
|
||||
actor_loss = -self.critic1(
|
||||
batch.obs, self(batch, eps=0).act).mean()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user