Prioritized DQN (#30)
* add sum_tree.py * add prioritized replay buffer * del sum_tree.py * fix some format issues * fix weight_update bug * simply replace replaybuffer in test_dqn without weight update * weight default set to 1 * fix sampling bug when buffer is not full * rename parameter * fix formula error, add accuracy check * add PrioritizedDQN test * add test_pdqn.py * add update_weight() doc * add ref of prio dqn in readme.md and index.rst * restore test_dqn.py, fix args of test_pdqn.py
This commit is contained in:
parent
70290346ea
commit
b23749463e
@ -20,6 +20,7 @@
|
|||||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
||||||
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
|
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
|
||||||
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
|
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
|
||||||
|
- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf))
|
||||||
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
|
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
|
||||||
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
|
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
|
||||||
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
|
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
|
||||||
|
@ -11,6 +11,7 @@ Welcome to Tianshou!
|
|||||||
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||||
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
|
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
|
||||||
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
|
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
|
||||||
|
* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf`_
|
||||||
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
||||||
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||||
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
||||||
|
37
test/base/test_prioritized_replay_buffer.py
Normal file
37
test/base/test_prioritized_replay_buffer.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import numpy as np
|
||||||
|
from tianshou.data import PrioritizedReplayBuffer
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from env import MyTestEnv
|
||||||
|
else: # pytest
|
||||||
|
from test.base.env import MyTestEnv
|
||||||
|
|
||||||
|
|
||||||
|
def test_replaybuffer(size=32, bufsize=15):
|
||||||
|
env = MyTestEnv(size)
|
||||||
|
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
|
||||||
|
obs = env.reset()
|
||||||
|
action_list = [1] * 5 + [0] * 10 + [1] * 10
|
||||||
|
for i, a in enumerate(action_list):
|
||||||
|
obs_next, rew, done, info = env.step(a)
|
||||||
|
buf.add(obs, a, rew, done, obs_next, info, np.random.randn()-0.5)
|
||||||
|
obs = obs_next
|
||||||
|
assert np.isclose(np.sum((buf.weight/buf._weight_sum)[:buf._size]), 1,
|
||||||
|
rtol=1e-12)
|
||||||
|
data, indice = buf.sample(len(buf) // 2)
|
||||||
|
if len(buf)//2 == 0:
|
||||||
|
assert len(data) == len(buf)
|
||||||
|
else:
|
||||||
|
assert len(data) == len(buf)//2
|
||||||
|
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
|
||||||
|
assert np.isclose(buf._weight_sum, (buf.weight).sum())
|
||||||
|
data, indice = buf.sample(len(buf) // 2)
|
||||||
|
buf.update_weight(indice, -data.weight/2)
|
||||||
|
assert np.isclose(buf.weight[indice], np.power(
|
||||||
|
np.abs(-data.weight/2), buf._alpha)).all()
|
||||||
|
assert np.isclose(buf._weight_sum, (buf.weight).sum())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_replaybuffer(233333, 200000)
|
||||||
|
print("pass")
|
122
test/discrete/test_pdqn.py
Normal file
122
test/discrete/test_pdqn.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import os
|
||||||
|
import gym
|
||||||
|
import torch
|
||||||
|
import pprint
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from tianshou.env import VectorEnv
|
||||||
|
from tianshou.policy import DQNPolicy
|
||||||
|
from tianshou.trainer import offpolicy_trainer
|
||||||
|
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from net import Net
|
||||||
|
else: # pytest
|
||||||
|
from test.discrete.net import Net
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||||
|
parser.add_argument('--seed', type=int, default=1626)
|
||||||
|
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||||
|
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||||
|
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-3)
|
||||||
|
parser.add_argument('--gamma', type=float, default=0.9)
|
||||||
|
parser.add_argument('--n-step', type=int, default=3)
|
||||||
|
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||||
|
parser.add_argument('--epoch', type=int, default=10)
|
||||||
|
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||||
|
parser.add_argument('--collect-per-step', type=int, default=10)
|
||||||
|
parser.add_argument('--batch-size', type=int, default=64)
|
||||||
|
parser.add_argument('--layer-num', type=int, default=3)
|
||||||
|
parser.add_argument('--training-num', type=int, default=8)
|
||||||
|
parser.add_argument('--test-num', type=int, default=100)
|
||||||
|
parser.add_argument('--logdir', type=str, default='log')
|
||||||
|
parser.add_argument('--render', type=float, default=0.)
|
||||||
|
parser.add_argument('--prioritized-replay', type=int, default=1)
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.5)
|
||||||
|
parser.add_argument('--beta', type=float, default=0.5)
|
||||||
|
parser.add_argument(
|
||||||
|
'--device', type=str,
|
||||||
|
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
args = parser.parse_known_args()[0]
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def test_pdqn(args=get_args()):
|
||||||
|
env = gym.make(args.task)
|
||||||
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
|
# train_envs = gym.make(args.task)
|
||||||
|
# you can also use tianshou.env.SubprocVectorEnv
|
||||||
|
train_envs = VectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||||
|
# test_envs = gym.make(args.task)
|
||||||
|
test_envs = VectorEnv(
|
||||||
|
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||||
|
# seed
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
train_envs.seed(args.seed)
|
||||||
|
test_envs.seed(args.seed)
|
||||||
|
# model
|
||||||
|
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
|
||||||
|
net = net.to(args.device)
|
||||||
|
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||||
|
policy = DQNPolicy(
|
||||||
|
net, optim, args.gamma, args.n_step,
|
||||||
|
use_target_network=args.target_update_freq > 0,
|
||||||
|
target_update_freq=args.target_update_freq)
|
||||||
|
# collector
|
||||||
|
if args.prioritized_replay > 0:
|
||||||
|
buf = PrioritizedReplayBuffer(
|
||||||
|
args.buffer_size, alpha=args.alpha, beta=args.alpha)
|
||||||
|
else:
|
||||||
|
buf = ReplayBuffer(args.buffer_size)
|
||||||
|
train_collector = Collector(
|
||||||
|
policy, train_envs, buf)
|
||||||
|
test_collector = Collector(policy, test_envs)
|
||||||
|
# policy.set_eps(1)
|
||||||
|
train_collector.collect(n_step=args.batch_size)
|
||||||
|
# log
|
||||||
|
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||||
|
writer = SummaryWriter(log_path)
|
||||||
|
|
||||||
|
def save_fn(policy):
|
||||||
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||||
|
|
||||||
|
def stop_fn(x):
|
||||||
|
return x >= env.spec.reward_threshold
|
||||||
|
|
||||||
|
def train_fn(x):
|
||||||
|
policy.set_eps(args.eps_train)
|
||||||
|
|
||||||
|
def test_fn(x):
|
||||||
|
policy.set_eps(args.eps_test)
|
||||||
|
|
||||||
|
# trainer
|
||||||
|
result = offpolicy_trainer(
|
||||||
|
policy, train_collector, test_collector, args.epoch,
|
||||||
|
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||||
|
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||||
|
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||||
|
|
||||||
|
assert stop_fn(result['best_reward'])
|
||||||
|
train_collector.close()
|
||||||
|
test_collector.close()
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pprint.pprint(result)
|
||||||
|
# Let's watch its performance!
|
||||||
|
env = gym.make(args.task)
|
||||||
|
collector = Collector(policy, env)
|
||||||
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||||
|
collector.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_pdqn(get_args())
|
@ -258,14 +258,87 @@ class ListReplayBuffer(ReplayBuffer):
|
|||||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||||
"""docstring for PrioritizedReplayBuffer"""
|
"""docstring for PrioritizedReplayBuffer"""
|
||||||
|
|
||||||
def __init__(self, size, **kwargs):
|
def __init__(self, size, alpha: float, beta: float,
|
||||||
|
mode: str = 'weight', **kwargs):
|
||||||
|
if mode != 'weight':
|
||||||
|
raise NotImplementedError
|
||||||
super().__init__(size, **kwargs)
|
super().__init__(size, **kwargs)
|
||||||
|
self._alpha = alpha # prioritization exponent
|
||||||
|
self._beta = beta # importance sample soft coefficient
|
||||||
|
self._weight_sum = 0.0
|
||||||
|
self.weight = np.zeros(size, dtype=np.float64)
|
||||||
|
self._amortization_freq = 50
|
||||||
|
self._amortization_counter = 0
|
||||||
|
|
||||||
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
|
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=1.0):
|
||||||
raise NotImplementedError
|
"""Add a batch of data into replay buffer."""
|
||||||
|
self._weight_sum += np.abs(weight)**self._alpha - \
|
||||||
|
self.weight[self._index]
|
||||||
|
# we have to sacrifice some convenience for speed :(
|
||||||
|
self._add_to_buffer('weight', np.abs(weight)**self._alpha)
|
||||||
|
super().add(obs, act, rew, done, obs_next, info)
|
||||||
|
self._check_weight_sum()
|
||||||
|
|
||||||
def sample(self, batch_size):
|
def sample(self, batch_size: int = 0, importance_sample: bool = True):
|
||||||
raise NotImplementedError
|
""" Get a random sample from buffer with priority probability. \
|
||||||
|
Return all the data in the buffer if batch_size is ``0``.
|
||||||
|
|
||||||
|
:return: Sample data and its corresponding index inside the buffer.
|
||||||
|
"""
|
||||||
|
if batch_size > 0 and batch_size <= self._size:
|
||||||
|
# Multiple sampling of the same sample
|
||||||
|
# will cause weight update conflict
|
||||||
|
indice = np.random.choice(
|
||||||
|
self._size, batch_size,
|
||||||
|
p=(self.weight/self.weight.sum())[:self._size], replace=False)
|
||||||
|
# self._weight_sum is not work for the accuracy issue
|
||||||
|
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
||||||
|
elif batch_size == 0:
|
||||||
|
indice = np.concatenate([
|
||||||
|
np.arange(self._index, self._size),
|
||||||
|
np.arange(0, self._index),
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
# if batch_size larger than len(self),
|
||||||
|
# it will lead to a bug in update weight
|
||||||
|
raise ValueError("batch_size should be less than len(self)")
|
||||||
|
batch = self[indice]
|
||||||
|
if importance_sample:
|
||||||
|
impt_weight = Batch(
|
||||||
|
impt_weight=1/np.power(
|
||||||
|
self._size*(batch.weight/self._weight_sum), self._beta))
|
||||||
|
batch.append(impt_weight)
|
||||||
|
self._check_weight_sum()
|
||||||
|
return batch, indice
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
raise NotImplementedError
|
self._amortization_counter = 0
|
||||||
|
super().reset()
|
||||||
|
|
||||||
|
def update_weight(self, indice, new_weight: np.ndarray):
|
||||||
|
"""update priority weight by indice in this buffer
|
||||||
|
|
||||||
|
:param indice: indice you want to update weight
|
||||||
|
:param new_weight: new priority weight you wangt to update
|
||||||
|
"""
|
||||||
|
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
|
||||||
|
- self.weight[indice].sum()
|
||||||
|
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return Batch(
|
||||||
|
obs=self.get(index, 'obs'),
|
||||||
|
act=self.act[index],
|
||||||
|
rew=self.rew[index],
|
||||||
|
done=self.done[index],
|
||||||
|
obs_next=self.get(index, 'obs_next'),
|
||||||
|
info=self.info[index],
|
||||||
|
weight=self.weight[index]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_weight_sum(self):
|
||||||
|
# keep a accurate _weight_sum
|
||||||
|
self._amortization_counter += 1
|
||||||
|
if self._amortization_counter % self._amortization_freq == 0:
|
||||||
|
self._weight_sum = np.sum(self.weight)
|
||||||
|
self._amortization_counter = 0
|
||||||
|
@ -3,7 +3,7 @@ import numpy as np
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from tianshou.data import Batch
|
from tianshou.data import Batch, PrioritizedReplayBuffer
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
|
|
||||||
|
|
||||||
@ -98,6 +98,18 @@ class DQNPolicy(BasePolicy):
|
|||||||
target_q[gammas != self._n_step] = 0
|
target_q[gammas != self._n_step] = 0
|
||||||
returns += (self._gamma ** gammas) * target_q
|
returns += (self._gamma ** gammas) * target_q
|
||||||
batch.returns = returns
|
batch.returns = returns
|
||||||
|
if isinstance(buffer, PrioritizedReplayBuffer):
|
||||||
|
q = self(batch).logits
|
||||||
|
q = q[np.arange(len(q)), batch.act]
|
||||||
|
r = batch.returns
|
||||||
|
if isinstance(r, np.ndarray):
|
||||||
|
r = torch.tensor(r, device=q.device, dtype=q.dtype)
|
||||||
|
td = r-q
|
||||||
|
buffer.update_weight(indice, td.detach().numpy())
|
||||||
|
impt_weight = torch.tensor(batch.impt_weight,
|
||||||
|
device=q.device, dtype=torch.float)
|
||||||
|
loss = (td.pow(2)*impt_weight).mean()
|
||||||
|
batch.loss = loss
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def forward(self, batch, state=None,
|
def forward(self, batch, state=None,
|
||||||
@ -133,12 +145,15 @@ class DQNPolicy(BasePolicy):
|
|||||||
if self._target and self._cnt % self._freq == 0:
|
if self._target and self._cnt % self._freq == 0:
|
||||||
self.sync_weight()
|
self.sync_weight()
|
||||||
self.optim.zero_grad()
|
self.optim.zero_grad()
|
||||||
q = self(batch).logits
|
if hasattr(batch, 'loss'):
|
||||||
q = q[np.arange(len(q)), batch.act]
|
loss = batch.loss
|
||||||
r = batch.returns
|
else:
|
||||||
if isinstance(r, np.ndarray):
|
q = self(batch).logits
|
||||||
r = torch.tensor(r, device=q.device, dtype=q.dtype)
|
q = q[np.arange(len(q)), batch.act]
|
||||||
loss = F.mse_loss(q, r)
|
r = batch.returns
|
||||||
|
if isinstance(r, np.ndarray):
|
||||||
|
r = torch.tensor(r, device=q.device, dtype=q.dtype)
|
||||||
|
loss = F.mse_loss(q, r)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optim.step()
|
self.optim.step()
|
||||||
self._cnt += 1
|
self._cnt += 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user