Add VizDoom PPO example and results (#533)
* update vizdoom ppo example
* update README with results
parent 23fbc3b712
commit 97df511a13
@@ -53,10 +53,6 @@ python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
 
 See [maps/README.md](maps/README.md)
 
-## Algorithms
-
-The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
-
 ## Reward
 
 1. living reward is bad
@@ -64,3 +60,28 @@ The setting is exactly the same as Atari. You can definitely try more algorithms
 3. negative reward for health and ammo2 is really helpful for d3/d4
 4. only with positive reward for health is really helpful for d1
 5. remove MOVE_BACKWARD may converge faster but the final performance may be lower
+
+## Algorithms
+
+The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
+
+### C51 (single run)
+
+| task          | best reward | reward curve | parameters                                       |
+| ------------- | ----------- | ------------ | ------------------------------------------------ |
+| D2_navigation | 747.52      |              | `python3 vizdoom_c51.py --task "D2_navigation"`  |
+| D3_battle     | 1855.29     |              | `python3 vizdoom_c51.py --task "D3_battle"`      |
+
+### PPO (single run)
+
+| task          | best reward | reward curve | parameters                                       |
+| ------------- | ----------- | ------------ | ------------------------------------------------ |
+| D2_navigation | 770.75      |              | `python3 vizdoom_ppo.py --task "D2_navigation"`  |
+| D3_battle     | 320.59      |              | `python3 vizdoom_ppo.py --task "D3_battle"`      |
+
+### PPO with ICM (single run)
+
+| task          | best reward | reward curve | parameters                                                         |
+| ------------- | ----------- | ------------ | ------------------------------------------------------------------ |
+| D2_navigation | 844.99      |              | `python3 vizdoom_ppo.py --task "D2_navigation" --icm-lr-scale 10`  |
+| D3_battle     | 547.08      |              | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10`      |
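The commands in the tables above can also be driven programmatically. This is not part of the commit; it is a minimal sketch that assumes `vizdoom_ppo.py` is importable from the example's directory and exposes `get_args()` and `test_ppo()` as in the updated script in this commit.

```python
# Hypothetical wrapper around the example's own entry points (get_args / test_ppo);
# it overrides a few parsed defaults instead of passing CLI flags.
from vizdoom_ppo import get_args, test_ppo

args = get_args()          # parses the defaults when no CLI flags are given
args.task = "D3_battle"    # same tasks as in the result tables above
args.icm_lr_scale = 10     # > 0 enables the ICM-wrapped PPO variant
test_ppo(args)
```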
							
								
								
									
										
BIN  examples/vizdoom/results/c51/D2_navigation_rew.png      (new binary file, 209 KiB, not shown)
BIN  examples/vizdoom/results/c51/D3_battle_rew.png          (new binary file, 156 KiB, not shown)
BIN  examples/vizdoom/results/ppo/D2_navigation_rew.png      (new binary file, 159 KiB, not shown)
BIN  examples/vizdoom/results/ppo/D3_battle_rew.png          (new binary file, 164 KiB, not shown)
BIN  examples/vizdoom/results/ppo_icm/D2_navigation_rew.png  (new binary file, 157 KiB, not shown)
BIN  examples/vizdoom/results/ppo_icm/D3_battle_rew.png      (new binary file, 159 KiB, not shown)
@@ -6,11 +6,12 @@ import numpy as np
 import torch
 from env import Env
 from network import DQN
+from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import ShmemVectorEnv
-from tianshou.policy import A2CPolicy, ICMPolicy
+from tianshou.policy import ICMPolicy, PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic
@@ -21,18 +22,28 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='D2_navigation')
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--buffer-size', type=int, default=2000000)
-    parser.add_argument('--lr', type=float, default=0.0001)
+    parser.add_argument('--buffer-size', type=int, default=100000)
+    parser.add_argument('--lr', type=float, default=0.00002)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=300)
     parser.add_argument('--step-per-epoch', type=int, default=100000)
-    parser.add_argument('--episode-per-collect', type=int, default=10)
-    parser.add_argument('--update-per-step', type=float, default=0.1)
-    parser.add_argument('--update-per-step', type=int, default=1)
-    parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument('--step-per-collect', type=int, default=1000)
+    parser.add_argument('--repeat-per-collect', type=int, default=4)
+    parser.add_argument('--batch-size', type=int, default=256)
+    parser.add_argument('--hidden-size', type=int, default=512)
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--rew-norm', type=int, default=False)
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--ent-coef', type=float, default=0.01)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--lr-decay', type=int, default=True)
+    parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--eps-clip', type=float, default=0.2)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=int, default=0)
+    parser.add_argument('--norm-adv', type=int, default=1)
+    parser.add_argument('--recompute-adv', type=int, default=0)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument(
@@ -75,7 +86,7 @@ def get_args():
     return parser.parse_args()
 
 
-def test_a2c(args=get_args()):
+def test_ppo(args=get_args()):
     args.cfg_path = f"maps/{args.task}.cfg"
     args.wad_path = f"maps/{args.task}.wad"
     args.res = (args.skip_num, 84, 84)
@@ -105,33 +116,65 @@ def test_a2c(args=get_args()):
     test_envs.seed(args.seed)
     # define model
     net = DQN(
-        *args.state_shape, args.action_shape, device=args.device, features_only=True
+        *args.state_shape,
+        args.action_shape,
+        device=args.device,
+        features_only=True,
+        output_dim=args.hidden_size
     )
-    actor = Actor(
-        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
-    )
-    critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
+    actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
+    critic = Critic(net, device=args.device)
     optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+
+    lr_scheduler = None
+    if args.lr_decay:
+        # decay learning rate to 0 linearly
+        max_update_num = np.ceil(
+            args.step_per_epoch / args.step_per_collect
+        ) * args.epoch
+
+        lr_scheduler = LambdaLR(
+            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+        )
+
     # define policy
-    dist = torch.distributions.Categorical
-    policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
+    def dist(p):
+        return torch.distributions.Categorical(logits=p)
+
+    policy = PPOPolicy(
+        actor,
+        critic,
+        optim,
+        dist,
+        discount_factor=args.gamma,
+        gae_lambda=args.gae_lambda,
+        max_grad_norm=args.max_grad_norm,
+        vf_coef=args.vf_coef,
+        ent_coef=args.ent_coef,
+        reward_normalization=args.rew_norm,
+        action_scaling=False,
+        lr_scheduler=lr_scheduler,
+        action_space=env.action_space,
+        eps_clip=args.eps_clip,
+        value_clip=args.value_clip,
+        dual_clip=args.dual_clip,
+        advantage_normalization=args.norm_adv,
+        recompute_advantage=args.recompute_adv
+    ).to(args.device)
     if args.icm_lr_scale > 0:
         feature_net = DQN(
             *args.state_shape,
             args.action_shape,
             device=args.device,
-            features_only=True
+            features_only=True,
+            output_dim=args.hidden_size
         )
         action_dim = np.prod(args.action_shape)
         feature_dim = feature_net.output_dim
         icm_net = IntrinsicCuriosityModule(
-            feature_net.net,
-            feature_dim,
-            action_dim,
-            hidden_sizes=args.hidden_sizes,
-            device=args.device
+            feature_net.net, feature_dim, action_dim, device=args.device
         )
-        icm_optim = torch.optim.adam(icm_net.parameters(), lr=args.lr)
+        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
         policy = ICMPolicy(
             policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
             args.icm_forward_loss_weight
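The `lr_scheduler` added in the hunk above decays the learning rate linearly to zero over the whole run. A minimal, self-contained sketch of the arithmetic, under the assumption (implied by the `max_update_num` computation) that the scheduler is stepped once per policy update, i.e. once per collect of `step_per_collect` steps:

```python
# Linear lr decay as in the hunk above: with the example's defaults there are
# ceil(step_per_epoch / step_per_collect) * epoch = 100 * 300 = 30000 updates,
# and LambdaLR multiplies the base lr by (1 - k / 30000) after the k-th update.
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

step_per_epoch, step_per_collect, epoch, base_lr = 100000, 1000, 300, 0.00002
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch  # 30000.0

param = torch.nn.Parameter(torch.zeros(1))          # stand-in for the real network
optim = torch.optim.Adam([param], lr=base_lr)
scheduler = LambdaLR(optim, lr_lambda=lambda k: 1 - k / max_update_num)

for update in range(3):
    optim.step()       # the gradient update would happen here
    scheduler.step()   # lr shrinks by base_lr / 30000 on each call
    print(update + 1, scheduler.get_last_lr()[0])
```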
@@ -153,7 +196,8 @@ def test_a2c(args=get_args()):
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
+    log_path = os.path.join(args.logdir, args.task, log_name)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)
@@ -162,10 +206,15 @@ def test_a2c(args=get_args()):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(mean_rewards):
-        return False
+        if env.spec.reward_threshold:
+            return mean_rewards >= env.spec.reward_threshold
+        elif 'Pong' in args.task:
+            return mean_rewards >= 20
+        else:
+            return False
 
+    # watch agent's performance
     def watch():
-        # watch agent's performance
         print("Setup test envs ...")
         policy.eval()
         test_envs.seed(args.seed)
@@ -210,7 +259,7 @@ def test_a2c(args=get_args()):
         args.repeat_per_collect,
         args.test_num,
         args.batch_size,
-        episode_per_collect=args.episode_per_collect,
+        step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
         save_fn=save_fn,
         logger=logger,
@@ -222,4 +271,4 @@ def test_a2c(args=get_args()):
 
 
 if __name__ == '__main__':
-    test_a2c(get_args())
+    test_ppo(get_args())
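One detail worth noting in the new policy definition: `Actor(..., softmax_output=False)` returns unnormalized scores, so the distribution is built with `logits=` rather than probabilities. A small, self-contained illustration (not part of the commit):

```python
# With softmax_output=False the actor head emits raw logits;
# Categorical(logits=...) applies the softmax internally.
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])             # actor output for one observation
dist = torch.distributions.Categorical(logits=logits)
action = dist.sample()                                # discrete action index
log_prob = dist.log_prob(action)                      # log pi(a|s), used in the PPO ratio
print(action.item(), log_prob.item(), dist.probs)     # probs are softmax(logits)
```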