solve the performance bottleneck by only hashing the last board
This commit is contained in:
parent
affd0319e2
commit
47676993fd
@ -222,6 +222,10 @@ class Go:
|
|||||||
new_color = -color
|
new_color = -color
|
||||||
return [history_boards, new_color], 0
|
return [history_boards, new_color], 0
|
||||||
|
|
||||||
|
def simulate_hashable_conversion(self, state):
|
||||||
|
# since go is MDP, we only need the last board for hashing
|
||||||
|
return tuple(state[0][-1])
|
||||||
|
|
||||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||||
if not self._rule_check(history, current_board, color, vertex):
|
if not self._rule_check(history, current_board, color, vertex):
|
||||||
return False
|
return False
|
||||||
|
@ -97,6 +97,10 @@ class Reversi:
|
|||||||
history_boards.append(new_board)
|
history_boards.append(new_board)
|
||||||
return [history_boards, 0 - color], 0
|
return [history_boards, 0 - color], 0
|
||||||
|
|
||||||
|
def simulate_hashable_conversion(self, state):
|
||||||
|
# since go is MDP, we only need the last board for hashing
|
||||||
|
return tuple(state[0][-1])
|
||||||
|
|
||||||
def _get_winner(self, board):
|
def _get_winner(self, board):
|
||||||
black_num, white_num = self._number_of_black_and_white(board)
|
black_num, white_num = self._number_of_black_and_white(board)
|
||||||
black_win = black_num - white_num
|
black_win = black_num - white_num
|
||||||
|
@ -1,17 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
import sys
|
|
||||||
import collections
|
|
||||||
|
|
||||||
c_puct = 5
|
c_puct = 5
|
||||||
|
|
||||||
def hashable_conversion(obj):
|
|
||||||
if isinstance(obj, collections.Hashable):
|
|
||||||
return obj
|
|
||||||
else:
|
|
||||||
return tuple(hashable_conversion(sub) for sub in obj)
|
|
||||||
|
|
||||||
class MCTSNode(object):
|
class MCTSNode(object):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
@ -109,7 +101,9 @@ class ActionNode(object):
|
|||||||
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self
|
return self
|
||||||
self.next_state_hashable = hashable_conversion(self.next_state)
|
head = time.time()
|
||||||
|
self.next_state_hashable = simulator.simulate_hashable_conversion(self.next_state)
|
||||||
|
self.mcts.hash_time += time.time() - head
|
||||||
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.children[self.next_state_hashable].selection(simulator)
|
return self.children[self.next_state_hashable].selection(simulator)
|
||||||
@ -153,9 +147,7 @@ class MCTS(object):
|
|||||||
self.state_selection_time = 0
|
self.state_selection_time = 0
|
||||||
self.simulate_sf_time = 0
|
self.simulate_sf_time = 0
|
||||||
self.valid_mask_time = 0
|
self.valid_mask_time = 0
|
||||||
self.ndarray2list_time = 0
|
self.hash_time = 0
|
||||||
self.list2tuple_time = 0
|
|
||||||
self.check = 0
|
|
||||||
|
|
||||||
def search(self, max_step=None, max_time=None):
|
def search(self, max_step=None, max_time=None):
|
||||||
step = 0
|
step = 0
|
||||||
@ -174,18 +166,16 @@ class MCTS(object):
|
|||||||
self.backpropagation_time += back_time
|
self.backpropagation_time += back_time
|
||||||
step += 1
|
step += 1
|
||||||
if self.debug:
|
if self.debug:
|
||||||
file = open("mcts_profiling.txt", "a")
|
file = open("mcts_profiling.log", "a")
|
||||||
file.write("[" + str(self.role) + "]"
|
file.write("[" + str(self.role) + "]"
|
||||||
+ " sel " + '%.3f' % self.selection_time + " "
|
+ " sel " + '%.3f' % self.selection_time + " "
|
||||||
+ " sel_sta " + '%.3f' % self.state_selection_time + " "
|
+ " sel_sta " + '%.3f' % self.state_selection_time + " "
|
||||||
+ " valid " + '%.3f' % self.valid_mask_time + " "
|
+ " valid " + '%.3f' % self.valid_mask_time + " "
|
||||||
+ " sel_act " + '%.3f' % self.action_selection_time + " "
|
+ " sel_act " + '%.3f' % self.action_selection_time + " "
|
||||||
+ " array2list " + '%.4f' % self.ndarray2list_time + " "
|
+ " hash " + '%.3f' % self.hash_time + " "
|
||||||
+ " check " + str(self.check) + " "
|
+ " step forward " + '%.3f' % self.simulate_sf_time + " "
|
||||||
+ " list2tuple " + '%.4f' % self.list2tuple_time + " \t"
|
+ " expansion " + '%.3f' % self.expansion_time + " "
|
||||||
+ " forward " + '%.3f' % self.simulate_sf_time + " "
|
+ " backprop " + '%.3f' % self.backpropagation_time + " "
|
||||||
+ " exp " + '%.3f' % self.expansion_time + " "
|
|
||||||
+ " bak " + '%.3f' % self.backpropagation_time + " "
|
|
||||||
+ "\n")
|
+ "\n")
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user