23 lines
		
	
	
		
			544 B
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			23 lines
		
	
	
		
			544 B
		
	
	
	
		
			Python
		
	
	
	
	
	
|  | import pprint | ||
|  | from tianshou.data import Collector | ||
|  | from tic_tac_toe import get_args, train_agent, watch | ||
|  | 
 | ||
|  | 
 | ||
|  | def test_tic_tac_toe(args=get_args()): | ||
|  |     Collector._default_rew_metric = lambda x: x[args.agent_id - 1] | ||
|  |     if args.watch: | ||
|  |         watch(args) | ||
|  |         return | ||
|  | 
 | ||
|  |     result, agent = train_agent(args) | ||
|  |     assert result["best_reward"] >= args.win_rate | ||
|  | 
 | ||
|  |     if __name__ == '__main__': | ||
|  |         pprint.pprint(result) | ||
|  |         # Let's watch its performance! | ||
|  |         watch(args, agent) | ||
|  | 
 | ||
|  | 
 | ||
|  | if __name__ == '__main__': | ||
|  |     test_tic_tac_toe(get_args()) |