Add PyTorch DRL result
This commit is contained in:
parent 519f9f20d0
commit 44f911bc31

README.md (27 changed lines)
@@ -14,7 +14,7 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
-- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf) + n-step returns
 - [Advantage Actor-Critic (A2C)](http://incompleteideas.net/book/RLbook2018.pdf)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
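The `+ n-step returns` tag added to the DDQN entry refers to multi-step bootstrapped targets. As a hedged illustration only (this is not code from the commit, and `n_step_target` is a hypothetical name), an n-step target sums n discounted rewards and then bootstraps from the target network:

```python
def n_step_target(rewards, bootstrap_q, done, gamma=0.99):
    """Hypothetical sketch of an n-step return target.

    rewards:      the n rewards r_t, ..., r_{t+n-1} observed after action a_t
    bootstrap_q:  max_a Q_target(s_{t+n}, a), used only if the episode continues
    done:         whether the episode terminated within these n steps
    """
    target = 0.0
    for k, r in enumerate(rewards):
        target += (gamma ** k) * r                        # discounted reward sum
    if not done:
        target += (gamma ** len(rewards)) * bootstrap_q   # bootstrap after n steps
    return target

# Example: three rewards of 1.0 and a bootstrap value of 10.0
# n_step_target([1.0, 1.0, 1.0], bootstrap_q=10.0, done=False) ≈ 12.67
```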
@@ -47,23 +47,21 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
 
-We select some of famous (>1k stars) reinforcement learning platform. Here is the table for other algorithms and platforms:
+We select some famous reinforcement learning platforms (>1k stars). Here are the benchmark results for other algorithms and platforms on toy scenarios:
 
-| Platform | [Tianshou](https://github.com/thu-ml/tianshou)* | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
+| Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
 | ------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
 | Algo \ ML platform | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
-| PG - CartPole | 9.03±4.18s | | | | |
+| PG - CartPole | 9.03±4.18s | | | None | |
-| DQN - CartPole | 20.94±11.38s | | | | |
+| DQN - CartPole | 20.94±11.38s | | | 175.55±53.81s | |
-| A2C - CartPole | 11.72±3.85s | | | | |
+| A2C - CartPole | 11.72±3.85s | | | Error | |
-| PPO - CartPole | 35.25±16.47s | | | | |
+| PPO - CartPole | 35.25±16.47s | | | 29.16±15.46s | |
-| DDPG - Pendulum | 46.95±24.31s | | | | |
+| DDPG - Pendulum | 46.95±24.31s | | | 652.83±471.28s | |
-| SAC - Pendulum | 38.92±2.09s | None | | | |
+| SAC - Pendulum | 38.92±2.09s | None | | 808.21±405.70s | |
-| TD3 - Pendulum | 48.39±7.22s | None | | | |
+| TD3 - Pendulum | 48.39±7.22s | None | | 619.33±324.97s | |
 
-The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes.
+Each platform is tested with at most 10 different seeds; trials that failed to train are excluded. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over 100 consecutive episodes.
-
-*: Tianshou uses 10 seeds for testing in 10 epochs. We erase those trials which failed training within the given limitation.
 
 ### Reproducible
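The threshold sentence above appears to follow the common Gym convention that an environment counts as solved once the average reward over the last 100 episodes reaches the threshold. A hedged, standalone sketch of that check (the `solved` helper below is hypothetical, not part of Tianshou or this commit):

```python
import numpy as np

def solved(episode_rewards, threshold, window=100):
    """True once the mean reward over the last `window` episodes reaches the
    threshold, e.g. 195.0 for CartPole or -250.0 for Pendulum."""
    if len(episode_rewards) < window:
        return False
    return float(np.mean(episode_rewards[-window:])) >= threshold
```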
@@ -173,7 +171,7 @@ test_collector = Collector(policy, test_envs)
 Let's train it:
 
 ```python
-result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, [1] * test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
+result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer)
 ```
 
 Saving / loading trained policy (it's exactly the same as PyTorch nn.module):
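As the context line above notes, a Tianshou policy is a regular `torch.nn.Module`, so checkpointing uses the standard PyTorch calls. A minimal sketch, assuming `policy` is the policy object built earlier in the README example and using an arbitrary file name:

```python
import torch

# save the trained parameters
torch.save(policy.state_dict(), 'policy.pth')

# later: rebuild `policy` with the same constructor arguments, then restore
policy.load_state_dict(torch.load('policy.pth'))
```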
@@ -213,6 +211,7 @@ If you find Tianshou useful, please cite it in your publications.
 ## TODO
 
+- [ ] More examples on [mujoco, atari] benchmark
 - [ ] Prioritized replay buffer
 - [ ] RNN support
 - [ ] Multi-agent
@@ -87,7 +87,7 @@ def test_ddpg(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -96,7 +96,7 @@ def _test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -92,7 +92,7 @@ def test_sac(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -96,7 +96,7 @@ def test_td3(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
@@ -82,7 +82,7 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -87,7 +87,7 @@ def test_dqn(args=get_args()):
     # trainer
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, [1] * args.test_num,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
         stop_fn=stop_fn, writer=writer)
@@ -130,7 +130,7 @@ def test_pg(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -87,7 +87,7 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        [1] * args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -54,6 +54,9 @@ class Collector(object):
         else:
             self.buffer.reset()
 
+    def get_env_num(self):
+        return self.env_num
+
     def reset_env(self):
         self._obs = self.env.reset()
         self._act = self._rew = self._done = self._info = None
@@ -1,4 +1,5 @@
 import time
+import numpy as np
 
 
 def test_episode(policy, collector, test_fn, epoch, n_episode):
@@ -7,6 +8,11 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
     policy.eval()
     if test_fn:
         test_fn(epoch)
+    if collector.get_env_num() > 1 and np.isscalar(n_episode):
+        n = collector.get_env_num()
+        n_ = np.zeros(n) + n_episode // n
+        n_[:n_episode % n] += 1
+        n_episode = list(n_)
     return collector.collect(n_episode=n_episode)
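The block added to `test_episode` spreads a scalar episode count over the test environments: every environment gets `n_episode // n` episodes, and the first `n_episode % n` environments take one extra. A small standalone check of that arithmetic (the `split_episodes` helper is introduced here only for illustration):

```python
import numpy as np

def split_episodes(n_episode, n_env):
    """Mirror the splitting rule added to test_episode for a scalar n_episode."""
    n_ = np.zeros(n_env) + n_episode // n_env   # base share per environment
    n_[:n_episode % n_env] += 1                 # first (n_episode % n_env) envs take one extra
    return list(n_)

# 100 episodes over 8 environments -> 13, 13, 13, 13, 12, 12, 12, 12 (sums to 100)
print(split_episodes(100, 8))
```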