Improve PER (#159)

- Use a segment tree to rewrite the previous PrioritizedReplayBuffer code, and add tests for it

- Enable all Q-learning-based algorithms to use PER (prioritized experience replay)
n+e 2020-08-06 10:26:24 +08:00 committed by GitHub
parent 312b7551cc
commit 140b1c2cab
16 changed files with 368 additions and 218 deletions

View File

@@ -38,7 +38,7 @@ Here is Tianshou's other features:
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation for all Q-learning based algorithms
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
In Chinese, Tianshou means divinely ordained, derived from "the gift of being born with". Tianshou is a reinforcement learning platform, and an RL algorithm does not learn from humans. Taking the name "Tianshou" means that there is no teacher to study with; the agent learns by itself through constant interaction with the environment.
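As a quick illustration of the prioritized-replay support listed above, here is a minimal sketch of plugging PrioritizedReplayBuffer into DQN training. It mirrors the CartPole test script changed in this commit; the hyper-parameters are assumptions taken from that script's defaults, not recommendations.

import gym
import torch
from tianshou.data import Collector, PrioritizedReplayBuffer
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(3, state_shape, action_shape, 'cpu')
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
# gamma=0.9 and n_step=3 are passed positionally, as in the test script
policy = DQNPolicy(net, optim, 0.9, 3, target_update_freq=320)
# alpha/beta follow the new command-line defaults added in this commit
buf = PrioritizedReplayBuffer(20000, alpha=0.6, beta=0.4)
train_collector = Collector(policy, env, buf)
train_collector.collect(n_step=64)  # prefill; training then proceeds as usual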

View File

@@ -28,7 +28,7 @@ Here is Tianshou's other features:
* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
* Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
* Support customized training process: :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms
* Support multi-agent RL: :doc:`/tutorials/tictactoe`
The Chinese documentation is available at `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_

View File

@@ -1,6 +1,9 @@
import pytest
import numpy as np
from timeit import timeit
from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
from tianshou.data import Batch, PrioritizedReplayBuffer, \
ReplayBuffer, SegmentTree
if __name__ == '__main__':
from env import MyTestEnv
@@ -112,9 +115,110 @@ def test_update():
assert (buf2[-1].obs == buf1[0].obs).all()
def test_segtree():
for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
realop = getattr(np, op)
# small test
actual_len = 8
tree = SegmentTree(actual_len, op) # 1-15. 8-15 are leaf nodes
assert np.all([tree[i] == init for i in range(actual_len)])
with pytest.raises(IndexError):
tree[actual_len]
naive = np.full([actual_len], init)
for _ in range(1000):
# random choose a place to perform single update
index = np.random.randint(actual_len)
value = np.random.rand()
naive[index] = value
tree[index] = value
for i in range(actual_len):
for j in range(i + 1, actual_len):
ref = realop(naive[i:j])
out = tree.reduce(i, j)
assert np.allclose(ref, out)
# batch setitem
for _ in range(1000):
index = np.random.choice(actual_len, size=4)
value = np.random.rand(4)
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for i in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
# large test
actual_len = 16384
tree = SegmentTree(actual_len, op)
naive = np.full([actual_len], init)
for _ in range(1000):
index = np.random.choice(actual_len, size=64)
value = np.random.rand(64)
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for i in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
# test prefix-sum-idx
actual_len = 8
tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# corner case here
naive = np.ones(actual_len, np.int)
tree[np.arange(actual_len)] = naive
for scalar in range(actual_len):
index = tree.get_prefix_sum_idx(scalar * 1.)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
tree = SegmentTree(10)
tree[np.arange(3)] = np.array([0.1, 0, 0.1])
assert np.allclose(tree.get_prefix_sum_idx(
np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
with pytest.raises(AssertionError):
tree.get_prefix_sum_idx(.2)
# test large prefix-sum-idx
actual_len = 16384
tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# profile
if __name__ == '__main__':
size = 100000
bsz = 64
naive = np.random.rand(size)
tree = SegmentTree(size)
tree[np.arange(size)] = naive
def sample_npbuf():
return np.random.choice(size, bsz, p=naive / naive.sum())
def sample_tree():
scalar = np.random.rand(bsz) * tree.reduce()
return tree.get_prefix_sum_idx(scalar)
print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
if __name__ == '__main__':
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update()

View File

@@ -8,9 +8,9 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
def get_args():
@@ -33,6 +33,9 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', type=int, default=0)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -58,15 +61,20 @@ def test_dqn(args=get_args()):
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
args.action_shape, args.device,
dueling=(2, 2)).to(args.device)
args.action_shape, args.device, # dueling=(1, 1)
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
target_update_freq=args.target_update_freq)
# buffer
if args.prioritized_replay > 0:
buf = PrioritizedReplayBuffer(
args.buffer_size, alpha=args.alpha, beta=args.beta)
else:
buf = ReplayBuffer(args.buffer_size)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
train_collector = Collector(policy, train_envs, buf)
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
@@ -114,5 +122,11 @@ def test_dqn(args=get_args()):
collector.close()
def test_pdqn(args=get_args()):
args.prioritized_replay = 1
args.gamma = .95
test_dqn(args)
if __name__ == '__main__':
test_dqn(get_args())

View File

@@ -1,118 +0,0 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils.net.common import Net
from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', type=int, default=1)
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--beta', type=float, default=0.5)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def test_pdqn(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
target_update_freq=args.target_update_freq)
# collector
if args.prioritized_replay > 0:
buf = PrioritizedReplayBuffer(
args.buffer_size, alpha=args.alpha,
beta=args.alpha, repeat_sample=True)
else:
buf = ReplayBuffer(args.buffer_size)
train_collector = Collector(
policy, train_envs, buf)
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
def train_fn(x):
policy.set_eps(args.eps_train)
def test_fn(x):
policy.set_eps(args.eps_test)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
if __name__ == '__main__':
test_pdqn(get_args())

View File

@@ -1,8 +1,8 @@
import numpy as np
import pytest
import numpy as np
from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer,
ReplayBuffer)
ReplayBuffer, SegmentTree)
@pytest.fixture(scope="module")
@@ -21,7 +21,7 @@ def data():
'buffer': buffer,
'buffer2': buffer2,
'slice': slice(-3000, -1000, 2),
'indexes': indexes
'indexes': indexes,
}
@@ -77,5 +77,15 @@ def test_sample(data):
buffer.sample(int(1e2))
def test_segtree(data):
size = 100000
tree = SegmentTree(size)
tree[np.arange(size)] = np.random.rand(size)
for i in np.arange(1e5):
scalar = np.random.rand(64) * tree.reduce()
tree.get_prefix_sum_idx(scalar)
if __name__ == '__main__':
pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"])

View File

@@ -1,6 +1,7 @@
from tianshou.data.batch import Batch
from tianshou.data.utils import to_numpy, to_torch, \
from tianshou.data.utils.converter import to_numpy, to_torch, \
to_torch_as
from tianshou.data.utils.segtree import SegmentTree
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector
@@ -10,8 +11,9 @@ __all__ = [
'to_numpy',
'to_torch',
'to_torch_as',
'SegmentTree',
'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer',
'Collector'
'Collector',
]

View File

@@ -1,7 +1,9 @@
import torch
import numpy as np
from typing import Any, Tuple, Union, Optional
from tianshou.data.batch import Batch, _create_value
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
class ReplayBuffer:
@@ -313,7 +315,7 @@ class ReplayBuffer:
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
policy=self.get(index, 'policy')
policy=self.get(index, 'policy'),
)
@@ -326,8 +328,8 @@ class ListReplayBuffer(ReplayBuffer):
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
detailed explanation.
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
explanation.
"""
def __init__(self, **kwargs) -> None:
@@ -353,31 +355,32 @@ class ListReplayBuffer(ReplayBuffer):
class PrioritizedReplayBuffer(ReplayBuffer):
"""Prioritized replay buffer implementation.
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952
:param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient.
:param str mode: defaults to ``weight``.
:param bool replace: whether to sample with replacement
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
detailed explanation.
Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
explanation.
"""
def __init__(self, size: int, alpha: float, beta: float,
mode: str = 'weight',
replace: bool = False, **kwargs) -> None:
if mode != 'weight':
raise NotImplementedError
def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
super().__init__(size, **kwargs)
self._alpha = alpha
self._beta = beta
self._weight_sum = 0.0
self._amortization_freq = 50
self._replace = replace
self._meta.weight = np.zeros(size, dtype=np.float64)
assert alpha > 0. and beta >= 0.
self._alpha, self._beta = alpha, beta
self._max_prio = 1.
self._min_prio = 1.
# bypass the check
self._weight = SegmentTree(size)
self.__eps = np.finfo(np.float32).eps.item()
def __getattr__(self, key: str) -> Union['Batch', Any]:
"""Return self.key"""
if key == 'weight':
return self._weight
return self._meta.__dict__[key]
def add(self,
obs: Union[dict, np.ndarray],
@@ -387,68 +390,55 @@ class PrioritizedReplayBuffer(ReplayBuffer):
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: dict = {},
policy: Optional[Union[dict, Batch]] = {},
weight: float = 1.0,
weight: float = None,
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
# we have to sacrifice some convenience for speed
self._weight_sum += np.abs(weight) ** self._alpha - \
self._meta.weight[self._index]
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
if weight is None:
weight = self._max_prio
else:
weight = np.abs(weight)
self._max_prio = max(self._max_prio, weight)
self._min_prio = min(self._min_prio, weight)
self.weight[self._index] = weight ** self._alpha
super().add(obs, act, rew, done, obs_next, info, policy)
@property
def replace(self):
return self._replace
@replace.setter
def replace(self, v: bool):
self._replace = v
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability. \
Return all the data in the buffer if batch_size is ``0``.
"""Get a random sample from buffer with priority probability. Return
all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
The ``weight`` in the returned Batch is the importance weight applied to
the loss function to de-bias the sampling process (transitions that are
sampled more often get a smaller weight in the loss).
"""
assert self._size > 0, 'cannot sample a buffer with size == 0 !'
p = None
if batch_size > 0 and (self._replace or batch_size <= self._size):
# sampling weight
p = (self.weight / self.weight.sum())[:self._size]
indice = np.random.choice(
self._size, batch_size, p=p,
replace=self._replace)
p = p[indice] # weight of each sample
elif batch_size == 0:
p = np.full(shape=self._size, fill_value=1.0 / self._size)
assert self._size > 0, 'Cannot sample a buffer with 0 size!'
if batch_size == 0:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
else:
raise ValueError(
f"batch_size should be less than {len(self)}, \
or set replace=True")
scalar = np.random.rand(batch_size) * self.weight.reduce()
indice = self.weight.get_prefix_sum_idx(scalar)
batch = self[indice]
batch["impt_weight"] = (self._size * p) ** (-self._beta)
# impt_weight
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
return batch, indice
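The simplification referred to in the comment above is just the importance-sampling weight from arXiv:1511.05952 with the common factors cancelled. Writing p_sum = sum_k p_k for the total priority mass stored in the tree and P(j) = p_j / p_sum for the sampling probability of transition j, and normalizing by the maximum weight (which corresponds to the minimum priority):

\[
w_j \;=\; \frac{\bigl(N \cdot P(j)\bigr)^{-\beta}}{\bigl(N \cdot P_{\min}\bigr)^{-\beta}}
      \;=\; \left(\frac{p_j / p_{\mathrm{sum}}}{p_{\min} / p_{\mathrm{sum}}}\right)^{-\beta}
      \;=\; \left(\frac{p_j}{p_{\min}}\right)^{-\beta},
\]

so neither N nor p_sum has to be tracked explicitly.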
def update_weight(self, indice: Union[slice, np.ndarray],
new_weight: np.ndarray) -> None:
def update_weight(self, indice: Union[np.ndarray],
new_weight: Union[np.ndarray, torch.Tensor]) -> None:
"""Update priority weight by indice in this buffer.
:param np.ndarray indice: indice you want to update weight
:param np.ndarray new_weight: new priority weight you want to update
:param np.ndarray indice: indices where the weight should be updated.
:param np.ndarray new_weight: new priority weights to assign at those indices.
"""
if self._replace:
if isinstance(indice, slice):
# convert slice to ndarray
indice = np.arange(indice.stop)[indice]
# remove the same values in indice
indice, unique_indice = np.unique(
indice, return_index=True)
new_weight = new_weight[unique_indice]
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
weight = np.abs(to_numpy(new_weight)) + self.__eps
self.weight[indice] = weight ** self._alpha
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
@@ -459,6 +449,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
weight=self.weight[index],
policy=self.get(index, 'policy'),
weight=self.weight[index],
)
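Taken together, the rewritten buffer can be exercised end to end as below; a standalone sketch with made-up transitions, only to show the sample/update cycle (the add() fields mirror the ReplayBuffer interface shown above):

import numpy as np
from tianshou.data import PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(size=16, alpha=0.6, beta=0.4)
for i in range(8):
    # new transitions get the current maximum priority by default
    buf.add(obs=i, act=0, rew=float(i), done=False, obs_next=i + 1)
batch, indice = buf.sample(4)
print(batch.weight)              # importance weights used to scale the loss
td_error = np.random.rand(4)     # in practice computed by the policy
buf.update_weight(indice, td_error)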

View File

View File

@@ -0,0 +1,133 @@
import numpy as np
from typing import Union, Optional
# from numba import njit
# numba version, 5x speed up
# with size=100000 and bsz=64
# first block (vectorized np): 0.0923 (now) -> 0.0251
# second block (for-loop): 0.2914 -> 0.0192 (future)
# @njit
def _get_prefix_sum_idx(value, bound, sums):
index = np.ones(value.shape, dtype=np.int64)
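    # descend from the root (node 1): at every level move to the left child;
    # whenever the left child's subtree sum is smaller than the remaining
    # query value, consume that sum and step to the right sibling instead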
while index[0] < bound:
index *= 2
direct = sums[index] < value
value -= sums[index] * direct
index += direct
# for _, s in enumerate(value):
# i = 1
# while i < bound:
# l = i * 2
# if sums[l] >= s:
# i = l
# else:
# s = s - sums[l]
# i = l + 1
# index[_] = i
index -= bound
return index
class SegmentTree:
"""Implementation of Segment Tree: store an array ``arr`` with size ``n``
in a segment tree, support value update and fast query of ``min/max/sum``
for the interval ``[left, right)`` in O(log n) time.
The detailed procedure is as follows:
1. Pad the array to have length of power of 2, so that leaf nodes in the\
segment tree have the same depth.
2. Store the segment tree in a binary heap.
:param int size: the size of segment tree.
:param str operation: the operation of segment tree. Choices are "sum",
"min" and "max". Default: "sum".
"""
def __init__(self, size: int,
operation: str = 'sum') -> None:
bound = 1
while bound < size:
bound *= 2
self._size = size
self._bound = bound
assert operation in ['sum', 'min', 'max'], \
f'Unknown operation {operation}.'
if operation == 'sum':
self._op, self._init_value = np.add, 0.
elif operation == 'min':
self._op, self._init_value = np.minimum, np.inf
else:
self._op, self._init_value = np.maximum, -np.inf
# assert isinstance(self._op, np.ufunc)
self._value = np.full([bound * 2], self._init_value)
def __len__(self):
return self._size
def __getitem__(self, index: Union[int, np.ndarray]
) -> Union[float, np.ndarray]:
"""Return self[index]"""
return self._value[index + self._bound]
def __setitem__(self, index: Union[int, np.ndarray],
value: Union[float, np.ndarray]) -> None:
"""Duplicate values in ``index`` are handled by numpy: later index
overwrites previous ones.
::
>>> a = np.array([1, 2, 3, 4])
>>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
>>> print(a)
[6 7 3 4]
"""
# TODO numba njit version
if isinstance(index, int):
index = np.array([index])
assert np.all(0 <= index) and np.all(index < self._size)
if self._op is np.add:
assert np.all(0. <= value)
index = index + self._bound
self._value[index] = value
while index[0] > 1:
index //= 2
self._value[index] = self._op(
self._value[index * 2], self._value[index * 2 + 1])
def reduce(self, start: Optional[int] = 0,
end: Optional[int] = None) -> float:
"""Return operation(value[start:end])."""
# TODO numba njit version
if start == 0 and end is None:
return self._value[1]
if end is None:
end = self._size
if end < 0:
end += self._size
# nodes in (start, end) should be aggregated
start, end = start + self._bound - 1, end + self._bound
result = self._init_value
while end - start > 1: # (start, end) interval is not empty
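            # if start is a left child, its right sibling lies entirely inside
            # (start, end); symmetrically, if end is a right child, its left
            # sibling lies entirely inside, so either can be folded into result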
if start % 2 == 0:
result = self._op(result, self._value[start + 1])
if end % 2 == 1:
result = self._op(result, self._value[end - 1])
start, end = start // 2, end // 2
return result
def get_prefix_sum_idx(
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
"""Return the minimum index for each ``v`` in ``value`` so that
``v <= sums[i]``, where sums[i] = \\sum_{j=0}^{i} arr[j].
"""
assert self._op is np.add
assert np.all(value >= 0.) and np.all(value < self._value[1])
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
single = True
index = _get_prefix_sum_idx(value, self._bound, self._value)
return index.item() if single else index
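A minimal usage sketch of the class above, mirroring what test_segtree exercises (the concrete numbers are made up for illustration):

import numpy as np
from tianshou.data import SegmentTree

tree = SegmentTree(8)                         # operation='sum' by default
tree[np.arange(8)] = np.arange(1, 9) / 10.0   # leaves: 0.1 ... 0.8
print(tree.reduce())                          # sum over everything: 3.6
print(tree.reduce(2, 5))                      # sum over arr[2:5]: 0.3 + 0.4 + 0.5
# inverse-CDF style query used by prioritized sampling
scalar = np.random.rand(4) * tree.reduce()
print(tree.get_prefix_sum_idx(scalar))        # four sampled leaf indices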

View File

@@ -4,7 +4,8 @@ from torch import nn
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional, Callable
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
to_torch_as, to_numpy
class BasePolicy(ABC, nn.Module):
@@ -213,4 +214,11 @@ class BasePolicy(ABC, nn.Module):
returns = to_torch_as(returns, target_q)
gammas = to_torch_as(gamma ** gammas, target_q)
batch.returns = target_q * gammas + returns
# prio buffer update
if isinstance(buffer, PrioritizedReplayBuffer):
batch.update_weight = buffer.update_weight
batch.indice = indice
batch.weight = to_torch_as(batch.weight, target_q)
else:
batch.weight = torch.ones_like(target_q)
return batch
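With this hook in place, a policy's learn() only has to scale its squared TD error by batch.weight (and feed the TD error back through batch.update_weight), which is exactly what the DQN/DDPG/SAC/TD3 changes below do. A toy tensor example of the weighted loss, with made-up numbers:

import torch

current_q = torch.tensor([1.0, 2.0, 3.0])
target_q = torch.tensor([1.5, 1.5, 3.5])
weight = torch.tensor([0.9, 1.0, 0.4])    # batch.weight from the prio buffer
td = current_q - target_q
loss = (td.pow(2) * weight).mean()        # replaces F.mse_loss(current_q, target_q)
print(loss)                               # tensor(0.1917)

When the buffer is a plain ReplayBuffer, batch.weight is all ones and the expression reduces to the ordinary MSE loss.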

View File

@@ -1,7 +1,6 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.policy import BasePolicy
@@ -144,7 +143,11 @@ class DDPGPolicy(BasePolicy):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
current_q = self.critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
critic_loss = F.mse_loss(current_q, target_q)
td = current_q - target_q
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, td)
critic_loss = (td.pow(2) * batch.weight).mean()
# critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()

View File

@@ -1,12 +1,10 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
to_torch_as, to_numpy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class DQNPolicy(BasePolicy):
@@ -95,9 +93,6 @@ class DQNPolicy(BasePolicy):
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm)
if isinstance(buffer, PrioritizedReplayBuffer):
batch.update_weight = buffer.update_weight
batch.indice = indice
return batch
def forward(self, batch: Batch,
@@ -164,13 +159,11 @@ class DQNPolicy(BasePolicy):
q = self(batch, eps=0.).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten()
if hasattr(batch, 'update_weight'):
td = r - q
batch.update_weight(batch.indice, to_numpy(td))
impt_weight = to_torch_as(batch.impt_weight, q)
loss = (td.pow(2) * impt_weight).mean()
else:
loss = F.mse_loss(q, r)
td = r - q
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, td)
loss = (td.pow(2) * batch.weight).mean()
# loss = F.mse_loss(q, r)
loss.backward()
self.optim.step()
self._cnt += 1

View File

@@ -1,7 +1,6 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.policy import DDPGPolicy
@@ -141,16 +140,23 @@ class SACPolicy(DDPGPolicy):
# critic 1
current_q1 = self.critic1(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
critic1_loss = F.mse_loss(current_q1, target_q)
td1 = current_q1 - target_q
critic1_loss = (td1.pow(2) * batch.weight).mean()
# critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
current_q2 = self.critic2(batch.obs, batch.act).flatten()
critic2_loss = F.mse_loss(current_q2, target_q)
td2 = current_q2 - target_q
critic2_loss = (td2.pow(2) * batch.weight).mean()
# critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
# prio-buffer
if hasattr(batch, 'update_weight'):
batch.update_weight(batch.indice, (td1 + td2) / 2.)
# actor
obs_result = self(batch, explorating=False)
a = obs_result.act

View File

@@ -1,7 +1,6 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from tianshou.policy import DDPGPolicy
@@ -119,16 +118,22 @@ class TD3Policy(DDPGPolicy):
# critic 1
current_q1 = self.critic1(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
critic1_loss = F.mse_loss(current_q1, target_q)
td1 = current_q1 - target_q
critic1_loss = (td1.pow(2) * batch.weight).mean()
# critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
current_q2 = self.critic2(batch.obs, batch.act).flatten()
critic2_loss = F.mse_loss(current_q2, target_q)
td2 = current_q2 - target_q
critic2_loss = (td2.pow(2) * batch.weight).mean()
# critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, (td1 + td2) / 2.)
if self._cnt % self._freq == 0:
actor_loss = -self.critic1(
batch.obs, self(batch, eps=0).act).mean()