From 0e4aa44ebb50e88bdb25b51e030b5e7ed230bf8a Mon Sep 17 00:00:00 2001 From: Wenbo Date: Wed, 17 Jan 2018 15:54:46 +0800 Subject: [PATCH] add deepcopy for hash, add some testing --- AlphaGo/game.py | 2 +- AlphaGo/go.py | 8 ++++---- AlphaGo/play.py | 13 ++++++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/AlphaGo/game.py b/AlphaGo/game.py index a299e97..9b4ba1e 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -90,7 +90,7 @@ class Game: evaluator = lambda state:self.model(role, state) mcts = MCTS(self.game_engine, evaluator, [latest_boards, color], self.size ** 2 + 1, role=role, debug=self.debug, inverse=True) - mcts.search(max_step=5) + mcts.search(max_step=100) if self.debug: file = open("mcts_debug.log", 'ab') np.savetxt(file, mcts.root.Q, header="\n" + role + " Q value : ", fmt='%.4f', newline=", ") diff --git a/AlphaGo/go.py b/AlphaGo/go.py index 84e6b7d..5d4d21e 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -309,9 +309,9 @@ class Go: liberty[reverse_color_ancestor_idx].remove(idx) def executor_do_move(self, history, history_hashtable, latest_boards, current_board, group_ancestors, liberty, stones, color, vertex): - print("===") - print(color, vertex) - print(group_ancestors, liberty, stones) + #print("===") + #print(color, vertex) + #print(group_ancestors, liberty, stones) if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex): return False idx = self._flatten(vertex) @@ -327,7 +327,7 @@ class Go: self._remove_liberty(idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones) history.append(copy.deepcopy(current_board)) latest_boards.append(copy.deepcopy(current_board)) - history_hashtable.add(tuple(current_board)) + history_hashtable.add(copy.deepcopy(tuple(current_board))) return True def _find_empty(self, current_board): diff --git a/AlphaGo/play.py b/AlphaGo/play.py index be54ad2..b877b87 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -1,3 +1,4 @@ +from 
__future__ import division import argparse import sys import re @@ -28,10 +29,13 @@ def play(engine, data_path): size = {"go": 9, "reversi": 8} show = ['.', 'X', 'O'] - evaluate_rounds = 1 + evaluate_rounds = 5 game_num = 0 + total_time = 0 + f=open('time.txt','w') #while True: while game_num < evaluate_rounds: + start = time.time() engine._game.model.check_latest_model() num = 0 pass_flag = [False, False] @@ -77,6 +81,13 @@ def play(engine, data_path): cPickle.dump(data, file) data.reset() game_num += 1 + + this_time = time.time() - start + total_time += this_time + f.write('time:'+ str(this_time)+'\n') + f.write('Avg time:' + str(total_time/evaluate_rounds)) + f.close() + + if __name__ == '__main__':