From 44f911bc31f34eb7e79a929587b5d78d1ebe6d9c Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Fri, 27 Mar 2020 09:04:29 +0800
Subject: [PATCH] add pytorch drl result
---
README.md | 27 +++++++++++++--------------
test/continuous/test_ddpg.py | 2 +-
test/continuous/test_ppo.py | 2 +-
test/continuous/test_sac.py | 2 +-
test/continuous/test_td3.py | 2 +-
test/discrete/test_a2c.py | 2 +-
test/discrete/test_dqn.py | 2 +-
test/discrete/test_pg.py | 2 +-
test/discrete/test_ppo.py | 2 +-
tianshou/data/collector.py | 3 +++
tianshou/trainer/utils.py | 6 ++++++
11 files changed, 30 insertions(+), 22 deletions(-)
diff --git a/README.md b/README.md
index d4474f0..b952ff3 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
-- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) + n-step returns
- [Advantage Actor-Critic (A2C)](http://incompleteideas.net/book/RLbook2018.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -47,23 +47,21 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
![testpg](docs/_static/images/testpg.gif)
-We select some of famous (>1k stars) reinforcement learning platform. Here is the table for other algorithms and platforms:
+We select some famous (>1k stars) reinforcement learning platforms. Here are the benchmark results for other algorithms and platforms on toy scenarios:
-| Platform | [Tianshou](https://github.com/thu-ml/tianshou)* | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
+| Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| ------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| GitHub Stars | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) |
| Algo \ ML platform | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
-| PG - CartPole | 9.03±4.18s | | | | |
-| DQN - CartPole | 20.94±11.38s | | | | |
-| A2C - CartPole | 11.72±3.85s | | | | |
-| PPO - CartPole | 35.25±16.47s | | | | |
-| DDPG - Pendulum | 46.95±24.31s | | | | |
-| SAC - Pendulum | 38.92±2.09s | None | | | |
-| TD3 - Pendulum | 48.39±7.22s | None | | | |
+| PG - CartPole | 9.03±4.18s | | | None | |
+| DQN - CartPole | 20.94±11.38s | | | 175.55±53.81s | |
+| A2C - CartPole | 11.72±3.85s | | | Error | |
+| PPO - CartPole | 35.25±16.47s | | | 29.16±15.46s | |
+| DDPG - Pendulum | 46.95±24.31s | | | 652.83±471.28s | |
+| SAC - Pendulum | 38.92±2.09s | None | | 808.21±405.70s | |
+| TD3 - Pendulum | 48.39±7.22s | None | | 619.33±324.97s | |
-The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes.
-
-*: Tianshou uses 10 seeds for testing in 10 epochs. We erase those trials which failed training within the given limitation.
+All of the platforms use at most 10 different seeds for testing. We exclude the trials in which training failed. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over 100 consecutive episodes.
### Reproducible
@@ -173,7 +171,7 @@ test_collector = Collector(policy, test_envs)
Let's train it:
```python
-result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, [1] * test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
+result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
```
Saving / loading trained policy (it's exactly the same as PyTorch nn.module):
@@ -213,6 +211,7 @@ If you find Tianshou useful, please cite it in your publications.
## TODO +- [ ] More examples on [mujoco, atari] benchmark - [ ] Prioritized replay buffer - [ ] RNN support - [ ] Multi-agent diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 5480bc7..cb100da 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -87,7 +87,7 @@ def test_ddpg(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, [1] * args.test_num, + args.step_per_epoch, args.collect_per_step, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index a5308be..3cfb9f2 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -96,7 +96,7 @@ def _test_ppo(args=get_args()): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() test_collector.close() diff --git a/test/continuous/test_sac.py b/test/continuous/test_sac.py index ba3d3fb..a900940 100644 --- a/test/continuous/test_sac.py +++ b/test/continuous/test_sac.py @@ -92,7 +92,7 @@ def test_sac(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, [1] * args.test_num, + args.step_per_epoch, args.collect_per_step, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 52876ca..78db493 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -96,7 +96,7 @@ def test_td3(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, [1] * args.test_num, + args.step_per_epoch, args.collect_per_step, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() diff --git a/test/discrete/test_a2c.py b/test/discrete/test_a2c.py index 6f826d7..20ef0ef 100644 --- a/test/discrete/test_a2c.py +++ b/test/discrete/test_a2c.py @@ -82,7 +82,7 @@ def test_a2c(args=get_args()): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() test_collector.close() diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b742621..3a11993 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -87,7 +87,7 @@ def test_dqn(args=get_args()): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, [1] * args.test_num, + args.step_per_epoch, args.collect_per_step, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, writer=writer) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index b823dd9..e0f4a08 
100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -130,7 +130,7 @@ def test_pg(args=get_args()): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() test_collector.close() diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 4ebb2c8..d21dd6c 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -87,7 +87,7 @@ def test_ppo(args=get_args()): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() test_collector.close() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1c2fe9d..0673ac7 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -54,6 +54,9 @@ class Collector(object): else: self.buffer.reset() + def get_env_num(self): + return self.env_num + def reset_env(self): self._obs = self.env.reset() self._act = self._rew = self._done = self._info = None diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index d5f2dca..40ae5a0 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,4 +1,5 @@ import time +import numpy as np def test_episode(policy, collector, test_fn, epoch, n_episode): @@ -7,6 +8,11 @@ def test_episode(policy, collector, test_fn, epoch, n_episode): policy.eval() if test_fn: test_fn(epoch) + if collector.get_env_num() > 1 and np.isscalar(n_episode): + n = collector.get_env_num() + n_ = np.zeros(n) + n_episode // n + n_[:n_episode % n] += 1 + n_episode = list(n_) return collector.collect(n_episode=n_episode)
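Note on the trainer change above: `test_episode` can now take a scalar `n_episode` and spread it over the parallel test environments reported by the new `Collector.get_env_num()`, which is why the test scripts pass `args.test_num` instead of `[1] * args.test_num`. The snippet below is a minimal standalone sketch of that splitting rule, using only NumPy; the helper name `split_episodes` is invented for illustration and is not part of the patch.

```python
import numpy as np

def split_episodes(n_episode, env_num):
    # Mirrors the logic added to tianshou/trainer/utils.py: every environment
    # gets n_episode // env_num episodes, and the first n_episode % env_num
    # environments each run one extra episode.
    n_ = np.zeros(env_num) + n_episode // env_num
    n_[:n_episode % env_num] += 1
    return list(n_)

# e.g. 10 test episodes over 4 environments -> [3.0, 3.0, 2.0, 2.0]
print(split_episodes(10, 4))
assert sum(split_episodes(10, 4)) == 10
```

Because `np.zeros` defaults to float64, the per-environment counts come back as floats, but the list always sums to the requested `n_episode`.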