add pytorch drl result

This commit is contained in:
Trinkle23897 2020-03-27 09:04:29 +08:00
parent 519f9f20d0
commit 44f911bc31
11 changed files with 30 additions and 22 deletions

View File

@ -14,7 +14,7 @@
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) + n-step returns
- [Advantage Actor-Critic (A2C)](http://incompleteideas.net/book/RLbook2018.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
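The DDQN entry above now notes n-step returns, i.e. bootstrapping the target over n future rewards rather than a single one. A hypothetical standalone sketch of that target (illustration only, not code from this repository):

```python
def n_step_target(rewards, gamma, bootstrap_value, n):
    # G_t = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * V(s_{t+n})
    discounted = sum(gamma ** k * r for k, r in enumerate(rewards[:n]))
    return discounted + gamma ** n * bootstrap_value

print(n_step_target([1.0, 1.0, 1.0], gamma=0.99, bootstrap_value=10.0, n=3))  # ~12.67
```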
@ -47,23 +47,21 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
![testpg](docs/_static/images/testpg.gif)
We select some famous (>1k stars) reinforcement learning platforms. Here is the table for other algorithms and platforms:
We select some famous (>1k stars) reinforcement learning platforms. Here are the benchmark results for other algorithms and platforms on toy scenarios:
| Platform | [Tianshou](https://github.com/thu-ml/tianshou)* | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| ------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| GitHub Stars | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) |
| Algo \ ML platform | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole | 9.03±4.18s | | | | |
| DQN - CartPole | 20.94±11.38s | | | | |
| A2C - CartPole | 11.72±3.85s | | | | |
| PPO - CartPole | 35.25±16.47s | | | | |
| DDPG - Pendulum | 46.95±24.31s | | | | |
| SAC - Pendulum | 38.92±2.09s | None | | | |
| TD3 - Pendulum | 48.39±7.22s | None | | | |
| PG - CartPole | 9.03±4.18s | | | None | |
| DQN - CartPole | 20.94±11.38s | | | 175.55±53.81s | |
| A2C - CartPole | 11.72±3.85s | | | Error | |
| PPO - CartPole | 35.25±16.47s | | | 29.16±15.46s | |
| DDPG - Pendulum | 46.95±24.31s | | | 652.83±471.28s | |
| SAC - Pendulum | 38.92±2.09s | None | | 808.21±405.70s | |
| TD3 - Pendulum | 48.39±7.22s | None | | 619.33±324.97s | |
The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over 100 consecutive episodes.
*: Tianshou uses 10 seeds for testing over 10 epochs. We discard the trials that failed to train within the given limit.
All of the platforms use at most 10 different seeds for testing. We discard the trials that failed to train. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over 100 consecutive episodes.
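For context, the 195.0 CartPole target matches the `reward_threshold` shipped in gym's `CartPole-v0` spec, while `Pendulum-v0` does not appear to define one, so a -250.0 target has to be supplied by hand. A minimal sketch of such a stop condition (an illustration under those assumptions, not the benchmark scripts themselves):

```python
import gym

# CartPole-v0 carries reward_threshold=195.0 in its gym spec;
# Pendulum-v0 has none, so the -250.0 target is supplied manually here.
env = gym.make('Pendulum-v0')
threshold = env.spec.reward_threshold
if threshold is None:
    threshold = -250.0
stop_fn = lambda mean_reward: mean_reward >= threshold
print(threshold, stop_fn(-100.0))  # -250.0 True
```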
### Reproducible
@ -173,7 +171,7 @@ test_collector = Collector(policy, test_envs)
Let's train it:
```python
result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, [1] * test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
```
Saving / loading the trained policy (it's exactly the same as with a PyTorch nn.Module):
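The README's own snippet falls outside this hunk; a minimal sketch of the standard `state_dict` round trip the sentence refers to (the toy module and file name below are placeholders):

```python
import torch
import torch.nn as nn

# A toy module stands in for the trained policy here.
policy = nn.Linear(4, 2)
torch.save(policy.state_dict(), 'policy.pth')
policy.load_state_dict(torch.load('policy.pth'))
```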
@ -213,6 +211,7 @@ If you find Tianshou useful, please cite it in your publications.
## TODO
- [ ] More examples on [mujoco, atari] benchmark
- [ ] Prioritized replay buffer
- [ ] RNN support
- [ ] Multi-agent

View File

@ -87,7 +87,7 @@ def test_ddpg(args=get_args()):
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()

View File

@ -96,7 +96,7 @@ def _test_ppo(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()

View File

@ -92,7 +92,7 @@ def test_sac(args=get_args()):
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()

View File

@ -96,7 +96,7 @@ def test_td3(args=get_args()):
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()

View File

@ -82,7 +82,7 @@ def test_a2c(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()

View File

@ -87,7 +87,7 @@ def test_dqn(args=get_args()):
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, writer=writer)

View File

@ -130,7 +130,7 @@ def test_pg(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()

View File

@ -87,7 +87,7 @@ def test_ppo(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()

View File

@ -54,6 +54,9 @@ class Collector(object):
        else:
            self.buffer.reset()

    def get_env_num(self):
        return self.env_num

    def reset_env(self):
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None

View File

@ -1,4 +1,5 @@
import time
import numpy as np
def test_episode(policy, collector, test_fn, epoch, n_episode):
@ -7,6 +8,11 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
    policy.eval()
    if test_fn:
        test_fn(epoch)
    if collector.get_env_num() > 1 and np.isscalar(n_episode):
        n = collector.get_env_num()
        n_ = np.zeros(n) + n_episode // n
        n_[:n_episode % n] += 1
        n_episode = list(n_)
    return collector.collect(n_episode=n_episode)
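The block above is what lets the trainers accept a plain integer for `n_episode` (hence the `[1] * args.test_num` to `args.test_num` changes throughout the tests): when several test environments collect in parallel, the integer is split as evenly as possible across them. A standalone sketch of that split with the same arithmetic (the helper name below is hypothetical):

```python
import numpy as np

def split_episodes(n_episode, env_num):
    # Distribute n_episode as evenly as possible over env_num environments,
    # e.g. 10 episodes over 4 envs -> [3.0, 3.0, 2.0, 2.0].
    n_ = np.zeros(env_num) + n_episode // env_num
    n_[:n_episode % env_num] += 1
    return list(n_)

print(split_episodes(10, 4))  # [3.0, 3.0, 2.0, 2.0]
```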