This is the third PR in the series of 6 commits described in #274; it refactors Collector to fix #245. See #274 for more detail. Changes in this PR: 1. refactor the collector to be cleaner, and split out AsyncCollector to support async venvs; 2. change the buffer.add API to add(batch, buffer_ids), and add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.); 3. add policy.exploration_noise(act, batch) -> act; 4. make a small change in BasePolicy.compute_*_returns; 5. move reward_metric from the collector to the trainer; 6. fix an np.asanyarray issue (different numpy versions produce different output); 7. set flake8 max line length to 88; 8. polish docs and fix tests. Co-authored-by: n+e <trinkle23897@gmail.com>
		
			
				
	
	
		
			21 lines
		
	
	
		
			441 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			21 lines
		
	
	
		
			441 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| import pprint
 | |
| from tic_tac_toe import get_args, train_agent, watch
 | |
| 
 | |
| 
 | |
def test_tic_tac_toe(args=None):
    """Train a tic-tac-toe agent and check that it reaches the target win rate.

    If ``args.watch`` is set, training is skipped and only ``watch(args)``
    is called. Otherwise the agent is trained with ``train_agent(args)``,
    and ``result["best_reward"]`` must be at least ``args.win_rate``.
    When this file is executed directly, the training result is also
    pretty-printed and the trained agent is watched.

    :param args: argument namespace consumed by ``train_agent``/``watch``.
        When ``None`` (the default), one is built with ``get_args()`` at
        call time.
    """
    # Build the default at call time rather than in the signature: a
    # ``get_args()`` default would run once when the ``def`` executes
    # (i.e. on import), and every call relying on the default would share
    # that one namespace object, including any mutations made to it.
    if args is None:
        args = get_args()

    if args.watch:
        watch(args)
        return

    result, agent = train_agent(args)
    assert result["best_reward"] >= args.win_rate

    # Only print and replay when this file is run as a script, not when it
    # is merely imported (e.g. collected by a test runner).
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Script entry point: build the argument namespace via get_args() and
    # hand it to the train-or-watch routine above.
    cli_args = get_args()
    test_tic_tac_toe(cli_args)
 |