| 
									
										
										
										
											2020-07-21 14:59:49 +08:00
										 |  |  | import os | 
					
						
							|  |  |  | import pprint | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | from copy import deepcopy | 
					
						
							|  |  |  | from torch.utils.tensorboard import SummaryWriter | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  | from tianshou.env import DummyVectorEnv | 
					
						
							| 
									
										
										
										
											2020-07-21 14:59:49 +08:00
										 |  |  | from tianshou.data import Collector | 
					
						
							|  |  |  | from tianshou.policy import RandomPolicy | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tic_tac_toe_env import TicTacToeEnv | 
					
						
							|  |  |  | from tic_tac_toe import get_parser, get_agents, train_agent, watch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_args(): | 
					
						
							|  |  |  |     parser = get_parser() | 
					
						
							|  |  |  |     parser.add_argument('--self_play_round', type=int, default=20) | 
					
						
							|  |  |  |     args = parser.parse_known_args()[0] | 
					
						
							|  |  |  |     return args | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def gomoku(args=get_args()): | 
					
						
							|  |  |  |     Collector._default_rew_metric = lambda x: x[args.agent_id - 1] | 
					
						
							|  |  |  |     if args.watch: | 
					
						
							|  |  |  |         watch(args) | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     policy, optim = get_agents(args) | 
					
						
							|  |  |  |     agent_learn = policy.policies[args.agent_id - 1] | 
					
						
							|  |  |  |     agent_opponent = policy.policies[2 - args.agent_id] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # log | 
					
						
							|  |  |  |     log_path = os.path.join(args.logdir, 'Gomoku', 'dqn') | 
					
						
							|  |  |  |     args.writer = SummaryWriter(log_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     opponent_pool = [agent_opponent] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def env_func(): | 
					
						
							|  |  |  |         return TicTacToeEnv(args.board_size, args.win_size) | 
					
						
							| 
									
										
										
										
											2020-08-19 15:00:24 +08:00
										 |  |  |     test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) | 
					
						
							| 
									
										
										
										
											2020-07-21 14:59:49 +08:00
										 |  |  |     for r in range(args.self_play_round): | 
					
						
							|  |  |  |         rews = [] | 
					
						
							|  |  |  |         agent_learn.set_eps(0.0) | 
					
						
							|  |  |  |         # compute the reward over previous learner | 
					
						
							|  |  |  |         for opponent in opponent_pool: | 
					
						
							|  |  |  |             policy.replace_policy(opponent, 3 - args.agent_id) | 
					
						
							|  |  |  |             test_collector = Collector(policy, test_envs) | 
					
						
							|  |  |  |             results = test_collector.collect(n_episode=100) | 
					
						
							|  |  |  |             rews.append(results['rew']) | 
					
						
							|  |  |  |         rews = np.array(rews) | 
					
						
							|  |  |  |         # weight opponent by their difficulty level | 
					
						
							|  |  |  |         rews = np.exp(-rews * 10.0) | 
					
						
							|  |  |  |         rews /= np.sum(rews) | 
					
						
							|  |  |  |         total_epoch = args.epoch | 
					
						
							|  |  |  |         args.epoch = 1 | 
					
						
							|  |  |  |         for epoch in range(total_epoch): | 
					
						
							|  |  |  |             # sample one opponent | 
					
						
							|  |  |  |             opp_id = np.random.choice(len(opponent_pool), size=1, p=rews) | 
					
						
							|  |  |  |             print(f'selection probability {rews.tolist()}') | 
					
						
							|  |  |  |             print(f'selected opponent {opp_id}') | 
					
						
							|  |  |  |             opponent = opponent_pool[opp_id.item(0)] | 
					
						
							|  |  |  |             agent = RandomPolicy() | 
					
						
							|  |  |  |             # previous learner can only be used for forward | 
					
						
							|  |  |  |             agent.forward = opponent.forward | 
					
						
							|  |  |  |             args.model_save_path = os.path.join( | 
					
						
							|  |  |  |                 args.logdir, 'Gomoku', 'dqn', | 
					
						
							|  |  |  |                 f'policy_round_{r}_epoch_{epoch}.pth') | 
					
						
							|  |  |  |             result, agent_learn = train_agent( | 
					
						
							|  |  |  |                 args, agent_learn=agent_learn, | 
					
						
							|  |  |  |                 agent_opponent=agent, optim=optim) | 
					
						
							|  |  |  |             print(f'round_{r}_epoch_{epoch}') | 
					
						
							|  |  |  |             pprint.pprint(result) | 
					
						
							|  |  |  |         learnt_agent = deepcopy(agent_learn) | 
					
						
							|  |  |  |         learnt_agent.set_eps(0.0) | 
					
						
							|  |  |  |         opponent_pool.append(learnt_agent) | 
					
						
							|  |  |  |         args.epoch = total_epoch | 
					
						
							|  |  |  |     if __name__ == '__main__': | 
					
						
							|  |  |  |         # Let's watch its performance! | 
					
						
							|  |  |  |         opponent = opponent_pool[-2] | 
					
						
							|  |  |  |         watch(args, agent_learn, opponent) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |     gomoku(get_args()) |