add deepcopy for hash, add some testing
This commit is contained in:
parent
e76ccaee80
commit
0e4aa44ebb
@ -90,7 +90,7 @@ class Game:
|
||||
evaluator = lambda state:self.model(role, state)
|
||||
mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
|
||||
self.size ** 2 + 1, role=role, debug=self.debug, inverse=True)
|
||||
mcts.search(max_step=5)
|
||||
mcts.search(max_step=100)
|
||||
if self.debug:
|
||||
file = open("mcts_debug.log", 'ab')
|
||||
np.savetxt(file, mcts.root.Q, header="\n" + role + " Q value : ", fmt='%.4f', newline=", ")
|
||||
|
@ -309,9 +309,9 @@ class Go:
|
||||
liberty[reverse_color_ancestor_idx].remove(idx)
|
||||
|
||||
def executor_do_move(self, history, history_hashtable, latest_boards, current_board, group_ancestors, liberty, stones, color, vertex):
|
||||
print("===")
|
||||
print(color, vertex)
|
||||
print(group_ancestors, liberty, stones)
|
||||
#print("===")
|
||||
#print(color, vertex)
|
||||
#print(group_ancestors, liberty, stones)
|
||||
if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex):
|
||||
return False
|
||||
idx = self._flatten(vertex)
|
||||
@ -327,7 +327,7 @@ class Go:
|
||||
self._remove_liberty(idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones)
|
||||
history.append(copy.deepcopy(current_board))
|
||||
latest_boards.append(copy.deepcopy(current_board))
|
||||
history_hashtable.add(tuple(current_board))
|
||||
history_hashtable.add(copy.deepcopy(tuple(current_board)))
|
||||
return True
|
||||
|
||||
def _find_empty(self, current_board):
|
||||
|
@ -1,3 +1,4 @@
|
||||
from __future__ import division
|
||||
import argparse
|
||||
import sys
|
||||
import re
|
||||
@ -28,10 +29,13 @@ def play(engine, data_path):
|
||||
size = {"go": 9, "reversi": 8}
|
||||
show = ['.', 'X', 'O']
|
||||
|
||||
evaluate_rounds = 1
|
||||
evaluate_rounds = 5
|
||||
game_num = 0
|
||||
total_time = 0
|
||||
f=open('time.txt','w')
|
||||
#while True:
|
||||
while game_num < evaluate_rounds:
|
||||
start = time.time()
|
||||
engine._game.model.check_latest_model()
|
||||
num = 0
|
||||
pass_flag = [False, False]
|
||||
@ -77,6 +81,13 @@ def play(engine, data_path):
|
||||
cPickle.dump(data, file)
|
||||
data.reset()
|
||||
game_num += 1
|
||||
|
||||
this_time = time.time() - start
|
||||
total += this_time
|
||||
f.write('time:'+ str(this_time)+'\n')
|
||||
f.write('Avg time:' + str(total/evaluate_rounds))
|
||||
f.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user