count the winning rate for each player
This commit is contained in:
parent
8780417378
commit
dcf293d637
1
AlphaGo/.gitignore
vendored
1
AlphaGo/.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
data
|
||||
checkpoints
|
||||
checkpoints_origin
|
||||
*.log
|
||||
|
29
AlphaGo/data_statistic.py
Normal file
29
AlphaGo/data_statistic.py
Normal file
@ -0,0 +1,29 @@
|
||||
import os
|
||||
import cPickle
|
||||
|
||||
class Data(object):
|
||||
def __init__(self):
|
||||
self.boards = []
|
||||
self.probs = []
|
||||
self.winner = 0
|
||||
|
||||
def file_to_training_data(file_name):
|
||||
with open(file_name, 'rb') as file:
|
||||
try:
|
||||
file.seek(0)
|
||||
data = cPickle.load(file)
|
||||
return data.winner
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
win_count = [0, 0, 0]
|
||||
file_list = os.listdir("./data")
|
||||
#print file_list
|
||||
for file in file_list:
|
||||
win_count[file_to_training_data("./data/" + file)] += 1
|
||||
print "Total play : " + str(len(file_list))
|
||||
print "Black wins : " + str(win_count[1])
|
||||
print "White wins : " + str(win_count[-1])
|
||||
|
@ -62,7 +62,7 @@ class Game:
|
||||
|
||||
def think(self, latest_boards, color):
|
||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color], self.size ** 2 + 1, inverse=True)
|
||||
mcts.search(max_step=20)
|
||||
mcts.search(max_step=100)
|
||||
temp = 1
|
||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
|
||||
|
Loading…
x
Reference in New Issue
Block a user