make ppo discrete test script more general (#418)

This commit is contained in:
n+e 2021-08-15 21:37:37 +08:00 committed by GitHub
parent bba30f83d1
commit 5b7732a29b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 24 deletions

View File

@ -103,27 +103,29 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
### Comprehensive Functionality
| RL Platform | GitHub Stars | # of RL Alg. <sup>(1)</sup> | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------- | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- |
| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :heavy_check_mark: | PyTorch |
| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :x: | PyTorch |
| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
| [ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1<sup>(3)</sup> | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| | | | | | | | |
| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| RL Platform | GitHub Stars | # of Alg. <sup>(1)</sup> | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
| ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------ | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- |
| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7<sup> (3)</sup> | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :heavy_check_mark: | PyTorch |
| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :x: | PyTorch |
| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
| [ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1<sup> (4)</sup> | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| | | | | | | | |
| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
<sup>(1): access date: 2021-08-08</sup>
<sup>(2): not all algorithms support this feature</sup>
<sup>(3): super fast APPO!</sup>
<sup>(3): TQC and QR-DQN in [sb3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) instead of main repo</sup>
<sup>(4): super fast APPO!</sup>
### High quality software engineering standard

View File

@ -20,12 +20,12 @@ def get_args():
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--episode-per-collect', type=int, default=20)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--step-per-collect', type=int, default=2000)
parser.add_argument('--repeat-per-collect', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=20)
@ -41,15 +41,16 @@ def get_args():
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--rew-norm', type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=0)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--value-clip', type=int, default=0)
args = parser.parse_known_args()[0]
return args
def test_ppo(args=get_args()):
torch.set_num_threads(1) # for poor CPU
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
@ -90,7 +91,9 @@ def test_ppo(args=get_args()):
dual_clip=args.dual_clip,
value_clip=args.value_clip,
action_space=env.action_space,
deterministic_eval=True)
deterministic_eval=True,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv)
# collector
train_collector = Collector(
policy, train_envs,
@ -111,7 +114,7 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])