make ppo discrete test script more general (#418)
This commit is contained in:
		
							parent
							
								
									bba30f83d1
								
							
						
					
					
						commit
						5b7732a29b
					
				
							
								
								
									
										34
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								README.md
									
									
									
									
									
								
							| @ -103,27 +103,29 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma | ||||
| 
 | ||||
| ### Comprehensive Functionality | ||||
| 
 | ||||
| | RL Platform                                                  | GitHub Stars                                                 | # of RL Alg. <sup>(1)</sup> | Custom Env                  | Batch Training                    | RNN Support        | Nested Observation | Backend    | | ||||
| | ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------- | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- | | ||||
| | [Baselines](https://github.com/openai/baselines)             | [](https://github.com/openai/baselines/stargazers) | 9                           | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x:                | TF1        | | ||||
| | [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [](https://github.com/hill-a/stable-baselines/stargazers) | 11                          | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x:                | TF1        | | ||||
| | [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7                           | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :x:                | :heavy_check_mark: | PyTorch    | | ||||
| | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [](https://github.com/ray-project/ray/stargazers) | 16                          | :heavy_check_mark:          | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch | | ||||
| | [SpinningUp](https://github.com/openai/spinningup)           | [](https://github.com/openai/spinningupstargazers) | 6                           | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :x:                | :x:                | PyTorch    | | ||||
| | [Dopamine](https://github.com/google/dopamine)               | [](https://github.com/google/dopamine/stargazers) | 7                           | :x:                         | :x:                               | :x:                | :x:                | TF/JAX     | | ||||
| | [ACME](https://github.com/deepmind/acme)                     | [](https://github.com/deepmind/acme/stargazers) | 14                          | :heavy_check_mark: (dm_env) | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | TF/JAX     | | ||||
| | [keras-rl](https://github.com/keras-rl/keras-rl)             | [](https://github.com/keras-rl/keras-rlstargazers) | 7                           | :heavy_check_mark: (gym)    | :x:                               | :x:                | :x:                | Keras      | | ||||
| | [rlpyt](https://github.com/astooke/rlpyt)                    | [](https://github.com/astooke/rlpyt/stargazers) | 11                          | :x:                         | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | PyTorch    | | ||||
| | [ChainerRL](https://github.com/chainer/chainerrl)            | [](https://github.com/chainer/chainerrl/stargazers) | 18                          | :heavy_check_mark: (gym)    | :heavy_check_mark:                | :heavy_check_mark: | :x:                | Chainer    | | ||||
| | [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [](https://github.com/alex-petrenko/sample-factory/stargazers) | 1<sup>(3)</sup>             | :heavy_check_mark: (gym)    | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | PyTorch    | | ||||
| |                                                              |                                                              |                             |                             |                                   |                    |                    |            | | ||||
| | [Tianshou](https://github.com/thu-ml/tianshou)               | [](https://github.com/thu-ml/tianshou/stargazers) | 20                          | :heavy_check_mark: (gym)    | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | PyTorch    | | ||||
| | RL Platform                                                  | GitHub Stars                                                 | # of Alg. <sup>(1)</sup> | Custom Env                  | Batch Training                    | RNN Support        | Nested Observation | Backend    | | ||||
| | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------ | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- | | ||||
| | [Baselines](https://github.com/openai/baselines)             | [](https://github.com/openai/baselines/stargazers) | 9                        | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x:                | TF1        | | ||||
| | [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [](https://github.com/hill-a/stable-baselines/stargazers) | 11                       | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x:                | TF1        | | ||||
| | [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7<sup> (3)</sup>         | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :x:                | :heavy_check_mark: | PyTorch    | | ||||
| | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [](https://github.com/ray-project/ray/stargazers) | 16                       | :heavy_check_mark:          | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch | | ||||
| | [SpinningUp](https://github.com/openai/spinningup)           | [](https://github.com/openai/spinningup/stargazers) | 6                        | :heavy_check_mark: (gym)    | :heavy_minus_sign: <sup>(2)</sup> | :x:                | :x:                | PyTorch    | | ||||
| | [Dopamine](https://github.com/google/dopamine)               | [](https://github.com/google/dopamine/stargazers) | 7                        | :x:                         | :x:                               | :x:                | :x:                | TF/JAX     | | ||||
| | [ACME](https://github.com/deepmind/acme)                     | [](https://github.com/deepmind/acme/stargazers) | 14                       | :heavy_check_mark: (dm_env) | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | TF/JAX     | | ||||
| | [keras-rl](https://github.com/keras-rl/keras-rl)             | [](https://github.com/keras-rl/keras-rl/stargazers) | 7                        | :heavy_check_mark: (gym)    | :x:                               | :x:                | :x:                | Keras      | | ||||
| | [rlpyt](https://github.com/astooke/rlpyt)                    | [](https://github.com/astooke/rlpyt/stargazers) | 11                       | :x:                         | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | PyTorch    | | ||||
| | [ChainerRL](https://github.com/chainer/chainerrl)            | [](https://github.com/chainer/chainerrl/stargazers) | 18                       | :heavy_check_mark: (gym)    | :heavy_check_mark:                | :heavy_check_mark: | :x:                | Chainer    | | ||||
| | [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [](https://github.com/alex-petrenko/sample-factory/stargazers) | 1<sup> (4)</sup>         | :heavy_check_mark: (gym)    | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | PyTorch    | | ||||
| |                                                              |                                                              |                          |                             |                                   |                    |                    |            | | ||||
| | [Tianshou](https://github.com/thu-ml/tianshou)               | [](https://github.com/thu-ml/tianshou/stargazers) | 20                       | :heavy_check_mark: (gym)    | :heavy_check_mark:                | :heavy_check_mark: | :heavy_check_mark: | PyTorch    | | ||||
| 
 | ||||
| <sup>(1): access date: 2021-08-08</sup> | ||||
| 
 | ||||
| <sup>(2): not all algorithms support this feature</sup> | ||||
| 
 | ||||
| <sup>(3): super fast APPO!</sup> | ||||
| <sup>(3): TQC and QR-DQN are in [sb3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) instead of the main repo</sup> | ||||
| 
 | ||||
| <sup>(4): super fast APPO!</sup> | ||||
| 
 | ||||
| 
 | ||||
| ### High quality software engineering standard | ||||
|  | ||||
| @ -20,12 +20,12 @@ def get_args(): | ||||
|     parser.add_argument('--task', type=str, default='CartPole-v0') | ||||
|     parser.add_argument('--seed', type=int, default=0) | ||||
|     parser.add_argument('--buffer-size', type=int, default=20000) | ||||
|     parser.add_argument('--lr', type=float, default=1e-3) | ||||
|     parser.add_argument('--lr', type=float, default=3e-4) | ||||
|     parser.add_argument('--gamma', type=float, default=0.99) | ||||
|     parser.add_argument('--epoch', type=int, default=10) | ||||
|     parser.add_argument('--step-per-epoch', type=int, default=50000) | ||||
|     parser.add_argument('--episode-per-collect', type=int, default=20) | ||||
|     parser.add_argument('--repeat-per-collect', type=int, default=2) | ||||
|     parser.add_argument('--step-per-collect', type=int, default=2000) | ||||
|     parser.add_argument('--repeat-per-collect', type=int, default=10) | ||||
|     parser.add_argument('--batch-size', type=int, default=64) | ||||
|     parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) | ||||
|     parser.add_argument('--training-num', type=int, default=20) | ||||
| @ -41,15 +41,16 @@ def get_args(): | ||||
|     parser.add_argument('--eps-clip', type=float, default=0.2) | ||||
|     parser.add_argument('--max-grad-norm', type=float, default=0.5) | ||||
|     parser.add_argument('--gae-lambda', type=float, default=0.95) | ||||
|     parser.add_argument('--rew-norm', type=int, default=1) | ||||
|     parser.add_argument('--rew-norm', type=int, default=0) | ||||
|     parser.add_argument('--norm-adv', type=int, default=0) | ||||
|     parser.add_argument('--recompute-adv', type=int, default=0) | ||||
|     parser.add_argument('--dual-clip', type=float, default=None) | ||||
|     parser.add_argument('--value-clip', type=int, default=1) | ||||
|     parser.add_argument('--value-clip', type=int, default=0) | ||||
|     args = parser.parse_known_args()[0] | ||||
|     return args | ||||
| 
 | ||||
| 
 | ||||
| def test_ppo(args=get_args()): | ||||
|     torch.set_num_threads(1)  # for poor CPU | ||||
|     env = gym.make(args.task) | ||||
|     args.state_shape = env.observation_space.shape or env.observation_space.n | ||||
|     args.action_shape = env.action_space.shape or env.action_space.n | ||||
| @ -90,7 +91,9 @@ def test_ppo(args=get_args()): | ||||
|         dual_clip=args.dual_clip, | ||||
|         value_clip=args.value_clip, | ||||
|         action_space=env.action_space, | ||||
|         deterministic_eval=True) | ||||
|         deterministic_eval=True, | ||||
|         advantage_normalization=args.norm_adv, | ||||
|         recompute_advantage=args.recompute_adv) | ||||
|     # collector | ||||
|     train_collector = Collector( | ||||
|         policy, train_envs, | ||||
| @ -111,7 +114,7 @@ def test_ppo(args=get_args()): | ||||
|     result = onpolicy_trainer( | ||||
|         policy, train_collector, test_collector, args.epoch, | ||||
|         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, | ||||
|         episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, | ||||
|         step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, | ||||
|         logger=logger) | ||||
|     assert stop_fn(result['best_reward']) | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user