Add VizDoom PPO example and results (#533)
* update vizdoom ppo example
* update README with results
This commit is contained in:
parent 23fbc3b712
commit 97df511a13
				| @ -53,10 +53,6 @@ python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp | ||||
| 
 | ||||
| See [maps/README.md](maps/README.md) | ||||
| 
 | ||||
| ## Algorithms | ||||
| 
 | ||||
| The setting is exactly the same as Atari. You can definitely try more of the algorithms listed in the Atari example. | ||||
| 
 | ||||
| ## Reward | ||||
| 
 | ||||
| 1. a living reward hurts performance | ||||
| @ -64,3 +60,28 @@ The setting is exactly the same as Atari. You can definitely try more algorithms | ||||
| 3. a negative reward for health and ammo2 is really helpful for d3/d4 | ||||
| 4. using only a positive reward for health is really helpful for d1 | ||||
| 5. removing MOVE_BACKWARD may make training converge faster, but the final performance may be lower | ||||
| 
 | ||||
| ## Algorithms | ||||
| 
 | ||||
| The setting is exactly the same as Atari. You can definitely try more of the algorithms listed in the Atari example. | ||||
| 
 | ||||
| ### C51 (single run) | ||||
| 
 | ||||
| | task                        | best reward | reward curve                          | parameters                                                   | | ||||
| | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | | ||||
| | D2_navigation          | 747.52          |          | `python3 vizdoom_c51.py --task "D2_navigation"` | | ||||
| | D3_battle              | 1855.29          |          | `python3 vizdoom_c51.py --task "D3_battle"` | | ||||
| 
 | ||||
| ### PPO (single run) | ||||
| 
 | ||||
| | task                        | best reward | reward curve                          | parameters                                                   | | ||||
| | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | | ||||
| | D2_navigation          | 770.75          |          | `python3 vizdoom_ppo.py --task "D2_navigation"` | | ||||
| | D3_battle              | 320.59          |          | `python3 vizdoom_ppo.py --task "D3_battle"` | | ||||
| 
 | ||||
| ### PPO with ICM (single run) | ||||
| 
 | ||||
| | task                        | best reward | reward curve                          | parameters                                                   | | ||||
| | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | | ||||
| | D2_navigation          | 844.99          |          | `python3 vizdoom_ppo.py --task "D2_navigation" --icm-lr-scale 10` | | ||||
| | D3_battle              | 547.08          |          | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10` | | ||||
|  | ||||
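The reward notes in the README diff above are terse; purely as an illustration, here is a hypothetical gym-style wrapper sketching the kind of shaping note 3 describes (penalizing health and ammo2 loss). This is not the example's actual implementation, which lives in `examples/vizdoom/env.py`; the `info` keys `health` and `ammo2` used below are assumptions made for the sketch.

```python
import gym


class ShapedReward(gym.Wrapper):
    """Illustrative reward shaping in the spirit of README notes 3-4.

    Hypothetical: assumes the wrapped env reports `health` and `ammo2`
    in its info dict. The real shaping is done inside env.py.
    """

    def __init__(self, env):
        super().__init__(env)
        self._health = None
        self._ammo2 = None

    def reset(self, **kwargs):
        self._health = None
        self._ammo2 = None
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        health = info.get("health", 0.0)
        ammo2 = info.get("ammo2", 0.0)
        if self._health is not None:
            rew += min(health - self._health, 0.0)  # penalize losing health
            rew += min(ammo2 - self._ammo2, 0.0)    # penalize spending ammo2
        self._health, self._ammo2 = health, ammo2
        return obs, rew, done, info
```

Clamping the deltas at zero keeps this shaping purely penalizing, matching note 3; note 4 suggests that for d1 only a positive health reward helps.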
							
								
								
									
										
BIN  examples/vizdoom/results/c51/D2_navigation_rew.png (new binary file, 209 KiB, not shown)
BIN  examples/vizdoom/results/c51/D3_battle_rew.png (new binary file, 156 KiB, not shown)
BIN  examples/vizdoom/results/ppo/D2_navigation_rew.png (new binary file, 159 KiB, not shown)
BIN  examples/vizdoom/results/ppo/D3_battle_rew.png (new binary file, 164 KiB, not shown)
BIN  examples/vizdoom/results/ppo_icm/D2_navigation_rew.png (new binary file, 157 KiB, not shown)
BIN  examples/vizdoom/results/ppo_icm/D3_battle_rew.png (new binary file, 159 KiB, not shown)
| @ -6,11 +6,12 @@ import numpy as np | ||||
| import torch | ||||
| from env import Env | ||||
| from network import DQN | ||||
| from torch.optim.lr_scheduler import LambdaLR | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| 
 | ||||
| from tianshou.data import Collector, VectorReplayBuffer | ||||
| from tianshou.env import ShmemVectorEnv | ||||
| from tianshou.policy import A2CPolicy, ICMPolicy | ||||
| from tianshou.policy import ICMPolicy, PPOPolicy | ||||
| from tianshou.trainer import onpolicy_trainer | ||||
| from tianshou.utils import TensorboardLogger | ||||
| from tianshou.utils.net.common import ActorCritic | ||||
| @ -21,18 +22,28 @@ def get_args(): | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument('--task', type=str, default='D2_navigation') | ||||
|     parser.add_argument('--seed', type=int, default=0) | ||||
|     parser.add_argument('--buffer-size', type=int, default=2000000) | ||||
|     parser.add_argument('--lr', type=float, default=0.0001) | ||||
|     parser.add_argument('--buffer-size', type=int, default=100000) | ||||
|     parser.add_argument('--lr', type=float, default=0.00002) | ||||
|     parser.add_argument('--gamma', type=float, default=0.99) | ||||
|     parser.add_argument('--epoch', type=int, default=300) | ||||
|     parser.add_argument('--step-per-epoch', type=int, default=100000) | ||||
|     parser.add_argument('--episode-per-collect', type=int, default=10) | ||||
|     parser.add_argument('--update-per-step', type=float, default=0.1) | ||||
|     parser.add_argument('--update-per-step', type=int, default=1) | ||||
|     parser.add_argument('--batch-size', type=int, default=64) | ||||
|     parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) | ||||
|     parser.add_argument('--step-per-collect', type=int, default=1000) | ||||
|     parser.add_argument('--repeat-per-collect', type=int, default=4) | ||||
|     parser.add_argument('--batch-size', type=int, default=256) | ||||
|     parser.add_argument('--hidden-size', type=int, default=512) | ||||
|     parser.add_argument('--training-num', type=int, default=10) | ||||
|     parser.add_argument('--test-num', type=int, default=100) | ||||
|     parser.add_argument('--rew-norm', type=int, default=False) | ||||
|     parser.add_argument('--vf-coef', type=float, default=0.5) | ||||
|     parser.add_argument('--ent-coef', type=float, default=0.01) | ||||
|     parser.add_argument('--gae-lambda', type=float, default=0.95) | ||||
|     parser.add_argument('--lr-decay', type=int, default=True) | ||||
|     parser.add_argument('--max-grad-norm', type=float, default=0.5) | ||||
|     parser.add_argument('--eps-clip', type=float, default=0.2) | ||||
|     parser.add_argument('--dual-clip', type=float, default=None) | ||||
|     parser.add_argument('--value-clip', type=int, default=0) | ||||
|     parser.add_argument('--norm-adv', type=int, default=1) | ||||
|     parser.add_argument('--recompute-adv', type=int, default=0) | ||||
|     parser.add_argument('--logdir', type=str, default='log') | ||||
|     parser.add_argument('--render', type=float, default=0.) | ||||
|     parser.add_argument( | ||||
| @ -75,7 +86,7 @@ def get_args(): | ||||
|     return parser.parse_args() | ||||
| 
 | ||||
| 
 | ||||
| def test_a2c(args=get_args()): | ||||
| def test_ppo(args=get_args()): | ||||
|     args.cfg_path = f"maps/{args.task}.cfg" | ||||
|     args.wad_path = f"maps/{args.task}.wad" | ||||
|     args.res = (args.skip_num, 84, 84) | ||||
| @ -105,33 +116,65 @@ def test_a2c(args=get_args()): | ||||
|     test_envs.seed(args.seed) | ||||
|     # define model | ||||
|     net = DQN( | ||||
|         *args.state_shape, args.action_shape, device=args.device, features_only=True | ||||
|         *args.state_shape, | ||||
|         args.action_shape, | ||||
|         device=args.device, | ||||
|         features_only=True, | ||||
|         output_dim=args.hidden_size | ||||
|     ) | ||||
|     actor = Actor( | ||||
|         net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device | ||||
|     ) | ||||
|     critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device) | ||||
|     actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) | ||||
|     critic = Critic(net, device=args.device) | ||||
|     optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) | ||||
| 
 | ||||
|     lr_scheduler = None | ||||
|     if args.lr_decay: | ||||
|         # decay learning rate to 0 linearly | ||||
|         max_update_num = np.ceil( | ||||
|             args.step_per_epoch / args.step_per_collect | ||||
|         ) * args.epoch | ||||
| 
 | ||||
|         lr_scheduler = LambdaLR( | ||||
|             optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num | ||||
|         ) | ||||
| 
 | ||||
|     # define policy | ||||
|     dist = torch.distributions.Categorical | ||||
|     policy = A2CPolicy(actor, critic, optim, dist).to(args.device) | ||||
|     def dist(p): | ||||
|         return torch.distributions.Categorical(logits=p) | ||||
| 
 | ||||
|     policy = PPOPolicy( | ||||
|         actor, | ||||
|         critic, | ||||
|         optim, | ||||
|         dist, | ||||
|         discount_factor=args.gamma, | ||||
|         gae_lambda=args.gae_lambda, | ||||
|         max_grad_norm=args.max_grad_norm, | ||||
|         vf_coef=args.vf_coef, | ||||
|         ent_coef=args.ent_coef, | ||||
|         reward_normalization=args.rew_norm, | ||||
|         action_scaling=False, | ||||
|         lr_scheduler=lr_scheduler, | ||||
|         action_space=env.action_space, | ||||
|         eps_clip=args.eps_clip, | ||||
|         value_clip=args.value_clip, | ||||
|         dual_clip=args.dual_clip, | ||||
|         advantage_normalization=args.norm_adv, | ||||
|         recompute_advantage=args.recompute_adv | ||||
|     ).to(args.device) | ||||
|     if args.icm_lr_scale > 0: | ||||
|         feature_net = DQN( | ||||
|             *args.state_shape, | ||||
|             args.action_shape, | ||||
|             device=args.device, | ||||
|             features_only=True | ||||
|             features_only=True, | ||||
|             output_dim=args.hidden_size | ||||
|         ) | ||||
|         action_dim = np.prod(args.action_shape) | ||||
|         feature_dim = feature_net.output_dim | ||||
|         icm_net = IntrinsicCuriosityModule( | ||||
|             feature_net.net, | ||||
|             feature_dim, | ||||
|             action_dim, | ||||
|             hidden_sizes=args.hidden_sizes, | ||||
|             device=args.device | ||||
|             feature_net.net, feature_dim, action_dim, device=args.device | ||||
|         ) | ||||
|         icm_optim = torch.optim.adam(icm_net.parameters(), lr=args.lr) | ||||
|         icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) | ||||
|         policy = ICMPolicy( | ||||
|             policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale, | ||||
|             args.icm_forward_loss_weight | ||||
| @ -153,7 +196,8 @@ def test_a2c(args=get_args()): | ||||
|     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) | ||||
|     test_collector = Collector(policy, test_envs, exploration_noise=True) | ||||
|     # log | ||||
|     log_path = os.path.join(args.logdir, args.task, 'a2c') | ||||
|     log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo' | ||||
|     log_path = os.path.join(args.logdir, args.task, log_name) | ||||
|     writer = SummaryWriter(log_path) | ||||
|     writer.add_text("args", str(args)) | ||||
|     logger = TensorboardLogger(writer) | ||||
| @ -162,10 +206,15 @@ def test_a2c(args=get_args()): | ||||
|         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) | ||||
| 
 | ||||
|     def stop_fn(mean_rewards): | ||||
|         if env.spec.reward_threshold: | ||||
|             return mean_rewards >= env.spec.reward_threshold | ||||
|         elif 'Pong' in args.task: | ||||
|             return mean_rewards >= 20 | ||||
|         else: | ||||
|             return False | ||||
| 
 | ||||
|     def watch(): | ||||
|     # watch agent's performance | ||||
|     def watch(): | ||||
|         print("Setup test envs ...") | ||||
|         policy.eval() | ||||
|         test_envs.seed(args.seed) | ||||
| @ -210,7 +259,7 @@ def test_a2c(args=get_args()): | ||||
|         args.repeat_per_collect, | ||||
|         args.test_num, | ||||
|         args.batch_size, | ||||
|         episode_per_collect=args.episode_per_collect, | ||||
|         step_per_collect=args.step_per_collect, | ||||
|         stop_fn=stop_fn, | ||||
|         save_fn=save_fn, | ||||
|         logger=logger, | ||||
| @ -222,4 +271,4 @@ def test_a2c(args=get_args()): | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     test_a2c(get_args()) | ||||
|     test_ppo(get_args()) | ||||
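A note for readers skimming the diff: the new `--lr-decay` path builds a `LambdaLR` whose multiplier falls from 1 to 0 over `ceil(step_per_epoch / step_per_collect) * epoch` scheduler steps, so the learning rate decays linearly to zero by the end of training. A minimal, self-contained sketch of that schedule with the diff's default values; a toy `nn.Linear` stands in for the example's `ActorCritic(actor, critic)`:

```python
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

# Toy parameters standing in for ActorCritic(actor, critic) in the example.
model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=2e-5)

# The diff sizes the schedule as ceil(step_per_epoch / step_per_collect) * epoch
# scheduler steps and decays the learning rate to 0 over that many steps.
epoch, step_per_epoch, step_per_collect = 300, 100000, 1000
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch
lr_scheduler = LambdaLR(optim, lr_lambda=lambda n: 1 - n / max_update_num)

for _ in range(5):        # pretend a few policy updates have happened
    optim.step()
    lr_scheduler.step()
print(optim.param_groups[0]["lr"])  # slightly below the initial 2e-5
```

In the example itself the scheduler is handed to `PPOPolicy(..., lr_scheduler=lr_scheduler, ...)`, and the `max_update_num` sizing assumes one scheduler step per policy update.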