variable rename and delete redundant code
This commit is contained in:
parent
0c3ff3bf37
commit
7f0565a5f6
@ -46,7 +46,8 @@ class Game:
|
||||
else:
|
||||
raise ValueError(name + " is an unknown game...")
|
||||
|
||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length, checkpoint_path=checkpoint_path)
|
||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
||||
checkpoint_path=checkpoint_path)
|
||||
self.latest_boards = deque(maxlen=self.history_length)
|
||||
for _ in range(self.history_length):
|
||||
self.latest_boards.append(self.board)
|
||||
@ -91,10 +92,6 @@ class Game:
|
||||
# this function can be called directly to play the opponent's move
|
||||
if vertex == utils.PASS:
|
||||
return True
|
||||
# TODO this implementation is not very elegant
|
||||
if self.name == "go":
|
||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||
elif self.name == "reversi":
|
||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||
return res
|
||||
|
||||
|
@ -129,6 +129,7 @@ class ActionNode(object):
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.parent, self.action
|
||||
else:
|
||||
# self.next_state is None means we have reach the terminate state
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.parent, self.action
|
||||
|
||||
@ -147,20 +148,20 @@ class ActionNode(object):
|
||||
|
||||
|
||||
class MCTS(object):
|
||||
def __init__(self, simulator, evaluator, root, action_num, method="UCT",
|
||||
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
|
||||
role="unknown", debug=False, inverse=False):
|
||||
self.simulator = simulator
|
||||
self.evaluator = evaluator
|
||||
self.role = role
|
||||
self.debug = debug
|
||||
prior, _ = self.evaluator(root)
|
||||
prior, _ = self.evaluator(start_state)
|
||||
self.action_num = action_num
|
||||
if method == "":
|
||||
self.root = root
|
||||
self.root = start_state
|
||||
if method == "UCT":
|
||||
self.root = UCTNode(None, None, root, action_num, prior, mcts=self, inverse=inverse)
|
||||
self.root = UCTNode(None, None, start_state, action_num, prior, mcts=self, inverse=inverse)
|
||||
if method == "TS":
|
||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
||||
self.root = TSNode(None, None, start_state, action_num, prior, inverse=inverse)
|
||||
self.inverse = inverse
|
||||
|
||||
# time spend on each step
|
||||
@ -191,7 +192,7 @@ class MCTS(object):
|
||||
self.expansion_time += exp_time
|
||||
self.backpropagation_time += back_time
|
||||
step += 1
|
||||
if (self.debug):
|
||||
if self.debug:
|
||||
file = open("mcts_profiling.txt", "a")
|
||||
file.write("[" + str(self.role) + "]"
|
||||
+ " sel " + '%.3f' % self.selection_time + " "
|
||||
|
Loading…
x
Reference in New Issue
Block a user