variable rename and delete redundant code
This commit is contained in:
parent
0c3ff3bf37
commit
7f0565a5f6
@ -46,7 +46,8 @@ class Game:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(name + " is an unknown game...")
|
raise ValueError(name + " is an unknown game...")
|
||||||
|
|
||||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length, checkpoint_path=checkpoint_path)
|
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
||||||
|
checkpoint_path=checkpoint_path)
|
||||||
self.latest_boards = deque(maxlen=self.history_length)
|
self.latest_boards = deque(maxlen=self.history_length)
|
||||||
for _ in range(self.history_length):
|
for _ in range(self.history_length):
|
||||||
self.latest_boards.append(self.board)
|
self.latest_boards.append(self.board)
|
||||||
@ -91,11 +92,7 @@ class Game:
|
|||||||
# this function can be called directly to play the opponent's move
|
# this function can be called directly to play the opponent's move
|
||||||
if vertex == utils.PASS:
|
if vertex == utils.PASS:
|
||||||
return True
|
return True
|
||||||
# TODO this implementation is not very elegant
|
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||||
if self.name == "go":
|
|
||||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
|
||||||
elif self.name == "reversi":
|
|
||||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def think_play_move(self, color):
|
def think_play_move(self, color):
|
||||||
|
@ -129,6 +129,7 @@ class ActionNode(object):
|
|||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
else:
|
else:
|
||||||
|
# self.next_state is None means we have reached the terminal state
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
|
|
||||||
@ -147,20 +148,20 @@ class ActionNode(object):
|
|||||||
|
|
||||||
|
|
||||||
class MCTS(object):
|
class MCTS(object):
|
||||||
def __init__(self, simulator, evaluator, root, action_num, method="UCT",
|
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
|
||||||
role="unknown", debug=False, inverse=False):
|
role="unknown", debug=False, inverse=False):
|
||||||
self.simulator = simulator
|
self.simulator = simulator
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.role = role
|
self.role = role
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
prior, _ = self.evaluator(root)
|
prior, _ = self.evaluator(start_state)
|
||||||
self.action_num = action_num
|
self.action_num = action_num
|
||||||
if method == "":
|
if method == "":
|
||||||
self.root = root
|
self.root = start_state
|
||||||
if method == "UCT":
|
if method == "UCT":
|
||||||
self.root = UCTNode(None, None, root, action_num, prior, mcts=self, inverse=inverse)
|
self.root = UCTNode(None, None, start_state, action_num, prior, mcts=self, inverse=inverse)
|
||||||
if method == "TS":
|
if method == "TS":
|
||||||
self.root = TSNode(None, None, root, action_num, prior, inverse=inverse)
|
self.root = TSNode(None, None, start_state, action_num, prior, inverse=inverse)
|
||||||
self.inverse = inverse
|
self.inverse = inverse
|
||||||
|
|
||||||
# time spent on each step
|
# time spent on each step
|
||||||
@ -191,7 +192,7 @@ class MCTS(object):
|
|||||||
self.expansion_time += exp_time
|
self.expansion_time += exp_time
|
||||||
self.backpropagation_time += back_time
|
self.backpropagation_time += back_time
|
||||||
step += 1
|
step += 1
|
||||||
if (self.debug):
|
if self.debug:
|
||||||
file = open("mcts_profiling.txt", "a")
|
file = open("mcts_profiling.txt", "a")
|
||||||
file.write("[" + str(self.role) + "]"
|
file.write("[" + str(self.role) + "]"
|
||||||
+ " sel " + '%.3f' % self.selection_time + " "
|
+ " sel " + '%.3f' % self.selection_time + " "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user