solve the performance bottleneck by only hashing the last board
This commit is contained in:
parent
affd0319e2
commit
47676993fd
@ -222,6 +222,10 @@ class Go:
|
||||
new_color = -color
|
||||
return [history_boards, new_color], 0
|
||||
|
||||
def simulate_hashable_conversion(self, state):
|
||||
# since go is MDP, we only need the last board for hashing
|
||||
return tuple(state[0][-1])
|
||||
|
||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||
if not self._rule_check(history, current_board, color, vertex):
|
||||
return False
|
||||
|
@ -97,6 +97,10 @@ class Reversi:
|
||||
history_boards.append(new_board)
|
||||
return [history_boards, 0 - color], 0
|
||||
|
||||
def simulate_hashable_conversion(self, state):
|
||||
# since go is MDP, we only need the last board for hashing
|
||||
return tuple(state[0][-1])
|
||||
|
||||
def _get_winner(self, board):
|
||||
black_num, white_num = self._number_of_black_and_white(board)
|
||||
black_win = black_num - white_num
|
||||
|
@ -1,17 +1,9 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import sys
|
||||
import collections
|
||||
|
||||
c_puct = 5
|
||||
|
||||
def hashable_conversion(obj):
|
||||
if isinstance(obj, collections.Hashable):
|
||||
return obj
|
||||
else:
|
||||
return tuple(hashable_conversion(sub) for sub in obj)
|
||||
|
||||
class MCTSNode(object):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
self.parent = parent
|
||||
@ -109,7 +101,9 @@ class ActionNode(object):
|
||||
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self
|
||||
self.next_state_hashable = hashable_conversion(self.next_state)
|
||||
head = time.time()
|
||||
self.next_state_hashable = simulator.simulate_hashable_conversion(self.next_state)
|
||||
self.mcts.hash_time += time.time() - head
|
||||
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
||||
self.mcts.action_selection_time += time.time() - head
|
||||
return self.children[self.next_state_hashable].selection(simulator)
|
||||
@ -153,9 +147,7 @@ class MCTS(object):
|
||||
self.state_selection_time = 0
|
||||
self.simulate_sf_time = 0
|
||||
self.valid_mask_time = 0
|
||||
self.ndarray2list_time = 0
|
||||
self.list2tuple_time = 0
|
||||
self.check = 0
|
||||
self.hash_time = 0
|
||||
|
||||
def search(self, max_step=None, max_time=None):
|
||||
step = 0
|
||||
@ -174,18 +166,16 @@ class MCTS(object):
|
||||
self.backpropagation_time += back_time
|
||||
step += 1
|
||||
if self.debug:
|
||||
file = open("mcts_profiling.txt", "a")
|
||||
file = open("mcts_profiling.log", "a")
|
||||
file.write("[" + str(self.role) + "]"
|
||||
+ " sel " + '%.3f' % self.selection_time + " "
|
||||
+ " sel_sta " + '%.3f' % self.state_selection_time + " "
|
||||
+ " valid " + '%.3f' % self.valid_mask_time + " "
|
||||
+ " sel_act " + '%.3f' % self.action_selection_time + " "
|
||||
+ " array2list " + '%.4f' % self.ndarray2list_time + " "
|
||||
+ " check " + str(self.check) + " "
|
||||
+ " list2tuple " + '%.4f' % self.list2tuple_time + " \t"
|
||||
+ " forward " + '%.3f' % self.simulate_sf_time + " "
|
||||
+ " exp " + '%.3f' % self.expansion_time + " "
|
||||
+ " bak " + '%.3f' % self.backpropagation_time + " "
|
||||
+ " hash " + '%.3f' % self.hash_time + " "
|
||||
+ " step forward " + '%.3f' % self.simulate_sf_time + " "
|
||||
+ " expansion " + '%.3f' % self.expansion_time + " "
|
||||
+ " backprop " + '%.3f' % self.backpropagation_time + " "
|
||||
+ "\n")
|
||||
file.close()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user