85 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			85 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import os | ||
|  | import pprint | ||
|  | import numpy as np | ||
|  | from copy import deepcopy | ||
|  | from torch.utils.tensorboard import SummaryWriter | ||
|  | 
 | ||
|  | from tianshou.env import VectorEnv | ||
|  | from tianshou.data import Collector | ||
|  | from tianshou.policy import RandomPolicy | ||
|  | 
 | ||
|  | from tic_tac_toe_env import TicTacToeEnv | ||
|  | from tic_tac_toe import get_parser, get_agents, train_agent, watch | ||
|  | 
 | ||
|  | 
 | ||
def get_args():
    """Build the command-line namespace for the self-play script.

    Starts from the parser returned by ``get_parser()`` and registers the
    extra integer option ``--self_play_round`` (default 20). Only the
    recognised arguments are returned; unknown ones are ignored.
    """
    cli = get_parser()
    cli.add_argument('--self_play_round', type=int, default=20)
    # parse_known_args() returns (namespace, leftovers); keep the namespace.
    known, _unknown = cli.parse_known_args()
    return known
|  | 
 | ||
|  | 
 | ||
def gomoku(args=None):
    """Train the learning agent over several rounds of self-play.

    Each round:

    1. The learner (exploration epsilon set to 0) is evaluated against every
       policy in the opponent pool; the ``'rew'`` entry of each collection
       result is recorded.
    2. The recorded rewards are turned into sampling probabilities with
       ``exp(-10 * rew)`` normalised to sum to one, so opponents against
       which the learner scored lower are sampled more often.
    3. For ``args.epoch`` iterations one opponent is sampled and
       ``train_agent`` is called with ``args.epoch`` temporarily set to 1.
    4. A deep copy of the learner (epsilon 0) is appended to the pool.

    Args:
        args: Namespace as produced by ``get_args()``. When ``None`` (the
            default) the command line is parsed at call time. The namespace
            is mutated: ``writer`` and ``model_save_path`` are set on it and
            ``epoch`` is temporarily overwritten during training.

    Returns:
        None.
    """
    # Parse lazily: a ``get_args()`` default would be evaluated once at
    # import time and the same (mutated) namespace shared by every call.
    if args is None:
        args = get_args()

    # Report the reward of the learning agent only.
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
    # agent_id is 1-based; the opponent occupies the other of the two slots.
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    args.writer = SummaryWriter(log_path)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)
    test_envs = VectorEnv([env_func for _ in range(args.test_num)])
    for r in range(args.self_play_round):
        rews = []
        agent_learn.set_eps(0.0)
        # compute the reward over previous learner
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rew'])
        rews = np.array(rews)
        # weight opponent by their difficulty level
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)
        total_epoch = args.epoch
        # train_agent is driven one epoch at a time from the loop below.
        args.epoch = 1
        try:
            for epoch in range(total_epoch):
                # sample one opponent
                opp_id = np.random.choice(
                    len(opponent_pool), size=1, p=rews)
                print(f'selection probability {rews.tolist()}')
                print(f'selected opponent {opp_id}')
                opponent = opponent_pool[opp_id.item(0)]
                agent = RandomPolicy()
                # previous learner can only be used for forward
                agent.forward = opponent.forward
                args.model_save_path = os.path.join(
                    args.logdir, 'Gomoku', 'dqn',
                    f'policy_round_{r}_epoch_{epoch}.pth')
                result, agent_learn = train_agent(
                    args, agent_learn=agent_learn,
                    agent_opponent=agent, optim=optim)
                print(f'round_{r}_epoch_{epoch}')
                pprint.pprint(result)
        finally:
            # Restore the caller's epoch count even if training raised.
            args.epoch = total_epoch
        learnt_agent = deepcopy(agent_learn)
        learnt_agent.set_eps(0.0)
        opponent_pool.append(learnt_agent)
    if __name__ == '__main__':
        # Let's watch its performance!
        # The last pool entry is a copy of the learner itself, so play the
        # one before it; with zero self-play rounds only the initial
        # opponent exists.
        if len(opponent_pool) > 1:
            opponent = opponent_pool[-2]
        else:
            opponent = opponent_pool[-1]
        watch(args, agent_learn, opponent)
|  | 
 | ||
|  | 
 | ||
if __name__ == '__main__':
    # Script entry point: parse the command line, then run self-play.
    cli_args = get_args()
    gomoku(cli_args)