diff --git a/README.md b/README.md
index 421d88f..fc10ed9 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
 - [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) with n-step returns
+- [Prioritized DQN (PDQN)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
diff --git a/docs/index.rst b/docs/index.rst
index 0981f84..c355258 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -11,6 +11,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_ with n-step returns
+* :class:`~tianshou.policy.DQNPolicy` `Prioritized DQN <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
 
diff --git a/test/base/test_prioritized_replay_buffer.py b/test/base/test_prioritized_replay_buffer.py
new file mode 100644
index 0000000..cca4ace
--- /dev/null
+++ b/test/base/test_prioritized_replay_buffer.py
@@ -0,0 +1,37 @@
+import numpy as np
+from tianshou.data import PrioritizedReplayBuffer
+
+if __name__ == '__main__':
+    from env import MyTestEnv
+else:  # pytest
+    from test.base.env import MyTestEnv
+
+
+def test_replaybuffer(size=32, bufsize=15):
+    env = MyTestEnv(size)
+    buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
+    obs = env.reset()
+    action_list = [1] * 5 + [0] * 10 + [1] * 10
+    for i, a in enumerate(action_list):
+        obs_next, rew, done, info = env.step(a)
+        buf.add(obs, a, rew, done, obs_next, info, np.random.randn()-0.5)
+        obs = obs_next
+        assert np.isclose(np.sum((buf.weight/buf._weight_sum)[:buf._size]), 1,
+                          rtol=1e-12)
+        data, indice = buf.sample(len(buf) // 2)
+        if len(buf)//2 == 0:
+            assert len(data) == len(buf)
+        else:
+            assert len(data) == len(buf)//2
+        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
+    assert np.isclose(buf._weight_sum, (buf.weight).sum())
+    data, indice = buf.sample(len(buf) // 2)
+    buf.update_weight(indice, -data.weight/2)
+    assert np.isclose(buf.weight[indice], np.power(
+        np.abs(-data.weight/2), buf._alpha)).all()
+    assert np.isclose(buf._weight_sum, (buf.weight).sum())
+
+
+if __name__ == "__main__":
+    test_replaybuffer(233333, 200000)
+    print("pass")
diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py
new file mode 100644
index 0000000..8450c48
--- /dev/null
+++ b/test/discrete/test_pdqn.py
@@ -0,0 +1,122 @@
+import os
+import gym
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.env import VectorEnv
+from tianshou.policy import DQNPolicy
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
+
+if __name__ == '__main__':
+    from net import Net
+else:  # pytest
+    from test.discrete.net import Net
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--eps-test', type=float, default=0.05)
+    parser.add_argument('--eps-train', type=float, default=0.1)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--n-step', type=int, default=3)
+    parser.add_argument('--target-update-freq', type=int, default=320)
+    parser.add_argument('--epoch', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=1000)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--prioritized-replay', type=int, default=1)
+    parser.add_argument('--alpha', type=float, default=0.5)
+    parser.add_argument('--beta', type=float, default=0.5)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_pdqn(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    # you can also use tianshou.env.SubprocVectorEnv
+    train_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
+    # test_envs = gym.make(args.task)
+    test_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    policy = DQNPolicy(
+        net, optim, args.gamma, args.n_step,
+        use_target_network=args.target_update_freq > 0,
+        target_update_freq=args.target_update_freq)
+    # collector
+    if args.prioritized_replay > 0:
+        buf = PrioritizedReplayBuffer(
+            args.buffer_size, alpha=args.alpha, beta=args.beta)
+    else:
+        buf = ReplayBuffer(args.buffer_size)
+    train_collector = Collector(
+        policy, train_envs, buf)
+    test_collector = Collector(policy, test_envs)
+    # policy.set_eps(1)
+    train_collector.collect(n_step=args.batch_size)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'dqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    def train_fn(x):
+        policy.set_eps(args.eps_train)
+
+    def test_fn(x):
+        policy.set_eps(args.eps_test)
+
+    # trainer
+    result = offpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, train_fn=train_fn, test_fn=test_fn,
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+
+    assert stop_fn(result['best_reward'])
+    train_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
+
+
+if __name__ == '__main__':
+    test_pdqn(get_args())
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 596e45a..76c0df0 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -258,14 +258,87 @@ class ListReplayBuffer(ReplayBuffer):
 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""
 
-    def __init__(self, size, **kwargs):
+    def __init__(self, size, alpha: float, beta: float,
+                 mode: str = 'weight', **kwargs):
+        if mode != 'weight':
+            raise NotImplementedError
         super().__init__(size, **kwargs)
+        self._alpha = alpha  # prioritization exponent
+        self._beta = beta  # importance sampling exponent
+        self._weight_sum = 0.0
+        self.weight = np.zeros(size, dtype=np.float64)
+        self._amortization_freq = 50
+        self._amortization_counter = 0
 
-    def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
-        raise NotImplementedError
+    def add(self, obs, act, rew, done, obs_next=0, info={}, weight=1.0):
+        """Add a batch of data into replay buffer."""
+        self._weight_sum += np.abs(weight)**self._alpha - \
+            self.weight[self._index]
+        # we have to sacrifice some convenience for speed :(
+        self._add_to_buffer('weight', np.abs(weight)**self._alpha)
+        super().add(obs, act, rew, done, obs_next, info)
+        self._check_weight_sum()
 
-    def sample(self, batch_size):
-        raise NotImplementedError
+    def sample(self, batch_size: int = 0, importance_sample: bool = True):
+        """Get a random sample from buffer with priority probability. \
+        Return all the data in the buffer if batch_size is ``0``.
+
+        :return: Sample data and its corresponding index inside the buffer.
+ """ + if batch_size > 0 and batch_size <= self._size: + # Multiple sampling of the same sample + # will cause weight update conflict + indice = np.random.choice( + self._size, batch_size, + p=(self.weight/self.weight.sum())[:self._size], replace=False) + # self._weight_sum is not work for the accuracy issue + # p=(self.weight/self._weight_sum)[:self._size], replace=False) + elif batch_size == 0: + indice = np.concatenate([ + np.arange(self._index, self._size), + np.arange(0, self._index), + ]) + else: + # if batch_size larger than len(self), + # it will lead to a bug in update weight + raise ValueError("batch_size should be less than len(self)") + batch = self[indice] + if importance_sample: + impt_weight = Batch( + impt_weight=1/np.power( + self._size*(batch.weight/self._weight_sum), self._beta)) + batch.append(impt_weight) + self._check_weight_sum() + return batch, indice def reset(self): - raise NotImplementedError + self._amortization_counter = 0 + super().reset() + + def update_weight(self, indice, new_weight: np.ndarray): + """update priority weight by indice in this buffer + + :param indice: indice you want to update weight + :param new_weight: new priority weight you wangt to update + """ + self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \ + - self.weight[indice].sum() + self.weight[indice] = np.power(np.abs(new_weight), self._alpha) + + def __getitem__(self, index): + return Batch( + obs=self.get(index, 'obs'), + act=self.act[index], + rew=self.rew[index], + done=self.done[index], + obs_next=self.get(index, 'obs_next'), + info=self.info[index], + weight=self.weight[index] + ) + + def _check_weight_sum(self): + # keep a accurate _weight_sum + self._amortization_counter += 1 + if self._amortization_counter % self._amortization_freq == 0: + self._weight_sum = np.sum(self.weight) + self._amortization_counter = 0 diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 3dd0687..d17b57b 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -3,7 +3,7 @@ import numpy as np from copy import deepcopy import torch.nn.functional as F -from tianshou.data import Batch +from tianshou.data import Batch, PrioritizedReplayBuffer from tianshou.policy import BasePolicy @@ -98,6 +98,18 @@ class DQNPolicy(BasePolicy): target_q[gammas != self._n_step] = 0 returns += (self._gamma ** gammas) * target_q batch.returns = returns + if isinstance(buffer, PrioritizedReplayBuffer): + q = self(batch).logits + q = q[np.arange(len(q)), batch.act] + r = batch.returns + if isinstance(r, np.ndarray): + r = torch.tensor(r, device=q.device, dtype=q.dtype) + td = r-q + buffer.update_weight(indice, td.detach().numpy()) + impt_weight = torch.tensor(batch.impt_weight, + device=q.device, dtype=torch.float) + loss = (td.pow(2)*impt_weight).mean() + batch.loss = loss return batch def forward(self, batch, state=None, @@ -133,12 +145,15 @@ class DQNPolicy(BasePolicy): if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() - q = self(batch).logits - q = q[np.arange(len(q)), batch.act] - r = batch.returns - if isinstance(r, np.ndarray): - r = torch.tensor(r, device=q.device, dtype=q.dtype) - loss = F.mse_loss(q, r) + if hasattr(batch, 'loss'): + loss = batch.loss + else: + q = self(batch).logits + q = q[np.arange(len(q)), batch.act] + r = batch.returns + if isinstance(r, np.ndarray): + r = torch.tensor(r, device=q.device, dtype=q.dtype) + loss = F.mse_loss(q, r) loss.backward() self.optim.step() self._cnt += 1