add pytorch drl result
commit 44f911bc31
parent 519f9f20d0
README.md
@@ -14,7 +14,7 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
-- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) + n-step returns
 - [Advantage Actor-Critic (A2C)](http://incompleteideas.net/book/RLbook2018.pdf)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
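The DDQN entry above now mentions n-step returns. For context only, here is a generic sketch of how an n-step bootstrap target is usually formed; this is not Tianshou's own implementation, and the helper name, shapes, and default hyperparameters are illustrative:

```python
import numpy as np

def n_step_target(rews, dones, next_q, gamma=0.99):
    """Generic n-step return target for one trajectory segment.

    rews:   rewards r_t, ..., r_{t+n-1}, shape (n,)
    dones:  termination flags for the same steps, shape (n,)
    next_q: bootstrap value max_a Q(s_{t+n}, a), a scalar
    """
    target, discount = 0.0, 1.0
    for r, d in zip(rews, dones):
        target += discount * r
        if d:                      # stop bootstrapping at episode end
            return target
        discount *= gamma
    return target + discount * next_q

# e.g. a 3-step target with no termination and a bootstrap value of 5.0
print(n_step_target(np.array([1., 1., 1.]), np.array([0, 0, 0]), next_q=5.0))
```

Larger n propagates reward information faster per update, at the cost of higher variance in the target.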
@@ -47,23 +47,21 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
 
-We select some of famous (>1k stars) reinforcement learning platform. Here is the table for other algorithms and platforms:
+We select some of the most famous (>1k-star) reinforcement learning platforms. Here are the benchmark results for other algorithms and platforms on toy scenarios:
 
-| Platform | [Tianshou](https://github.com/thu-ml/tianshou)* | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
+| Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
 | ------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
 | Algo \ ML platform | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
-| PG - CartPole | 9.03±4.18s | | | | |
-| DQN - CartPole | 20.94±11.38s | | | | |
-| A2C - CartPole | 11.72±3.85s | | | | |
-| PPO - CartPole | 35.25±16.47s | | | | |
-| DDPG - Pendulum | 46.95±24.31s | | | | |
-| SAC - Pendulum | 38.92±2.09s | None | | | |
-| TD3 - Pendulum | 48.39±7.22s | None | | | |
+| PG - CartPole | 9.03±4.18s | | | None | |
+| DQN - CartPole | 20.94±11.38s | | | 175.55±53.81s | |
+| A2C - CartPole | 11.72±3.85s | | | Error | |
+| PPO - CartPole | 35.25±16.47s | | | 29.16±15.46s | |
+| DDPG - Pendulum | 46.95±24.31s | | | 652.83±471.28s | |
+| SAC - Pendulum | 38.92±2.09s | None | | 808.21±405.70s | |
+| TD3 - Pendulum | 48.39±7.22s | None | | 619.33±324.97s | |
 
-The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes.
-
-*: Tianshou uses 10 seeds for testing in 10 epochs. We erase those trials which failed training within the given limitation.
+All of the platforms use at most 10 different seeds for testing. We drop the trials that failed to train. The reward threshold is 195.0 for CartPole and -250.0 for Pendulum over 100 consecutive episodes.
 
 ### Reproducible
 
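For readers who want to reproduce numbers like those in the table above, the stated methodology (wall-clock time until the reward threshold is first reached, several seeds, mean±std, failed trials dropped) can be sketched roughly as follows. This is an illustrative harness, not the script behind the table; `run_one_trial` is a hypothetical stand-in for any platform's training loop that returns True once the 100-episode average reward reaches the environment's threshold and False if it runs out of budget:

```python
import time
import numpy as np

def time_to_threshold(run_one_trial, seeds=range(10)):
    """Time each successful trial from scratch to threshold; report mean and std."""
    durations = []
    for seed in seeds:
        start = time.time()
        if run_one_trial(seed):          # failed trials are dropped, as in the table
            durations.append(time.time() - start)
    return np.mean(durations), np.std(durations)

# usage (hypothetical): mean_s, std_s = time_to_threshold(my_pg_cartpole_trial)
```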
@@ -173,7 +171,7 @@ test_collector = Collector(policy, test_envs)
 Let's train it:
 
 ```python
-result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, [1] * test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
+result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
 ```
 
 Saving / loading trained policy (it's exactly the same as PyTorch nn.module):
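The "Saving / loading" step mentioned above is the standard PyTorch `state_dict` workflow. A minimal sketch, assuming `policy` is the policy object built earlier in this README example; the file name is only an example:

```python
import torch

# save the trained policy's parameters
torch.save(policy.state_dict(), 'pg_policy.pth')

# later, restore them into a freshly constructed policy of the same class
policy.load_state_dict(torch.load('pg_policy.pth'))
```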
@@ -213,6 +211,7 @@ If you find Tianshou useful, please cite it in your publications.
 
 ## TODO
 
 - [ ] More examples on [mujoco, atari] benchmark
 - [ ] Prioritized replay buffer
 - [ ] RNN support
 - [ ] Multi-agent
 
@@ -87,7 +87,7 @@ def test_ddpg(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -96,7 +96,7 @@ def _test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -92,7 +92,7 @@ def test_sac(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -96,7 +96,7 @@ def test_td3(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -82,7 +82,7 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -87,7 +87,7 @@ def test_dqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
         stop_fn=stop_fn, writer=writer)
 
@@ -130,7 +130,7 @@ def test_pg(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -87,7 +87,7 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -54,6 +54,9 @@ class Collector(object):
         else:
             self.buffer.reset()
 
+    def get_env_num(self):
+        return self.env_num
+
     def reset_env(self):
         self._obs = self.env.reset()
         self._act = self._rew = self._done = self._info = None
@@ -1,4 +1,5 @@
 import time
+import numpy as np
 
 
 def test_episode(policy, collector, test_fn, epoch, n_episode):
@@ -7,6 +8,11 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
     policy.eval()
     if test_fn:
         test_fn(epoch)
+    if collector.get_env_num() > 1 and np.isscalar(n_episode):
+        n = collector.get_env_num()
+        n_ = np.zeros(n) + n_episode // n
+        n_[:n_episode % n] += 1
+        n_episode = list(n_)
     return collector.collect(n_episode=n_episode)
 
 
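The `test_episode` change above is what lets the trainer calls pass a plain integer (`args.test_num`) instead of a `[1] * test_num` list: when the collector manages several environments and receives a scalar episode count, that count is split as evenly as possible across them. A standalone sketch that only repeats the splitting arithmetic for illustration:

```python
import numpy as np

def split_episodes(n_episode, env_num):
    """Split a scalar episode count across env_num parallel environments,
    mirroring the logic added to test_episode above."""
    n_ = np.zeros(env_num) + n_episode // env_num
    n_[:n_episode % env_num] += 1
    return list(n_)

print(split_episodes(10, 3))  # three envs run 4, 3 and 3 episodes respectively
```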