variable rename and delete redundant code

This commit is contained in:
Dong Yan 2017-12-26 22:19:10 +08:00
parent 0c3ff3bf37
commit 7f0565a5f6
2 changed files with 10 additions and 12 deletions

View File

@ -46,7 +46,8 @@ class Game:
else:
raise ValueError(name + " is an unknown game...")
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length, checkpoint_path=checkpoint_path)
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
checkpoint_path=checkpoint_path)
self.latest_boards = deque(maxlen=self.history_length)
for _ in range(self.history_length):
self.latest_boards.append(self.board)
@ -91,10 +92,6 @@ class Game:
# this function can be called directly to play the opponent's move
if vertex == utils.PASS:
return True
# TODO this implementation is not very elegant
if self.name == "go":
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
elif self.name == "reversi":
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
return res

View File

@ -129,6 +129,7 @@ class ActionNode(object):
self.mcts.action_selection_time += time.time() - head
return self.parent, self.action
else:
# self.next_state is None means we have reach the terminate state
self.mcts.action_selection_time += time.time() - head
return self.parent, self.action
@ -147,20 +148,20 @@ class ActionNode(object):
class MCTS(object):
def __init__(self, simulator, evaluator, root, action_num, method="UCT",
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
role="unknown", debug=False, inverse=False):
self.simulator = simulator
self.evaluator = evaluator
self.role = role
self.debug = debug
prior, _ = self.evaluator(root)
prior, _ = self.evaluator(start_state)
self.action_num = action_num
if method == "":
self.root = root
self.root = start_state
if method == "UCT":
self.root = UCTNode(None, None, root, action_num, prior, mcts=self, inverse=inverse)
self.root = UCTNode(None, None, start_state, action_num, prior, mcts=self, inverse=inverse)
if method == "TS":
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
self.root = TSNode(None, None, start_state, action_num, prior, inverse=inverse)
self.inverse = inverse
# time spend on each step
@ -191,7 +192,7 @@ class MCTS(object):
self.expansion_time += exp_time
self.backpropagation_time += back_time
step += 1
if (self.debug):
if self.debug:
file = open("mcts_profiling.txt", "a")
file.write("[" + str(self.role) + "]"
+ " sel " + '%.3f' % self.selection_time + " "