solve the performance bottleneck by only hashing the last board

This commit is contained in:
Dong Yan 2017-12-28 01:16:24 +08:00
parent affd0319e2
commit 47676993fd
3 changed files with 21 additions and 23 deletions

View File

@ -222,6 +222,10 @@ class Go:
new_color = -color
return [history_boards, new_color], 0
def simulate_hashable_conversion(self, state):
    """Return a hashable key for ``state`` built from its most recent board.

    ``state`` is read as ``state[0][-1]``: element 0 is a sequence of boards
    and only its last entry contributes to the key; every other part of
    ``state`` is ignored. Hashing just one board keeps key construction
    cheap compared with converting the whole history.

    NOTE(review): states that share a last board but differ in earlier
    boards map to the same key. ``_rule_check`` is passed the history, so
    confirm that merging such states is acceptable for the search.
    """
    boards = state[0]
    latest_board = boards[-1]
    return tuple(latest_board)
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
if not self._rule_check(history, current_board, color, vertex):
return False

View File

@ -97,6 +97,10 @@ class Reversi:
history_boards.append(new_board)
return [history_boards, 0 - color], 0
def simulate_hashable_conversion(self, state):
    """Build the hashable key of ``state`` from its newest board only.

    The first element of ``state`` is indexed as a sequence of boards; its
    final entry is copied into a tuple and returned. Earlier boards and the
    remaining elements of ``state`` do not affect the key.
    """
    newest_board = state[0][-1]
    key = tuple(newest_board)
    return key
def _get_winner(self, board):
black_num, white_num = self._number_of_black_and_white(board)
black_win = black_num - white_num

View File

@ -1,17 +1,9 @@
import numpy as np
import math
import time
import sys
import collections
# Module-level tuning constant. Presumably the exploration weight of the PUCT
# selection rule -- TODO confirm against its uses, which are not visible here.
c_puct = 5
try:  # Python 3.3+; the ``collections.Hashable`` alias was removed in 3.10
    from collections.abc import Hashable as _Hashable
except ImportError:  # Python 2 has no ``collections.abc``
    from collections import Hashable as _Hashable


def hashable_conversion(obj):
    """Recursively convert ``obj`` into an equivalent hashable value.

    Hashable objects (numbers, strings, ...) are returned unchanged. Any
    other object is iterated and rebuilt as a tuple whose items have been
    converted the same way, so nested lists become nested tuples.

    Args:
        obj: a hashable object, or an iterable of such objects (nested to
            any depth).

    Returns:
        ``obj`` itself when it is hashable and not a tuple, otherwise a
        tuple of the converted items.

    Raises:
        TypeError: if ``obj`` is neither hashable nor iterable.
    """
    # A tuple always passes the Hashable check, even when it contains
    # unhashable items such as lists, so tuples must be rebuilt item by item.
    if isinstance(obj, _Hashable) and not isinstance(obj, tuple):
        return obj
    return tuple(hashable_conversion(sub) for sub in obj)
class MCTSNode(object):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
self.parent = parent
@ -109,7 +101,9 @@ class ActionNode(object):
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
self.mcts.action_selection_time += time.time() - head
return self
self.next_state_hashable = hashable_conversion(self.next_state)
head = time.time()
self.next_state_hashable = simulator.simulate_hashable_conversion(self.next_state)
self.mcts.hash_time += time.time() - head
if self.next_state_hashable in self.children.keys(): # next state has already visited before
self.mcts.action_selection_time += time.time() - head
return self.children[self.next_state_hashable].selection(simulator)
@ -153,9 +147,7 @@ class MCTS(object):
self.state_selection_time = 0
self.simulate_sf_time = 0
self.valid_mask_time = 0
self.ndarray2list_time = 0
self.list2tuple_time = 0
self.check = 0
self.hash_time = 0
def search(self, max_step=None, max_time=None):
step = 0
@ -174,18 +166,16 @@ class MCTS(object):
self.backpropagation_time += back_time
step += 1
if self.debug:
file = open("mcts_profiling.txt", "a")
file = open("mcts_profiling.log", "a")
file.write("[" + str(self.role) + "]"
+ " sel " + '%.3f' % self.selection_time + " "
+ " sel_sta " + '%.3f' % self.state_selection_time + " "
+ " valid " + '%.3f' % self.valid_mask_time + " "
+ " sel_act " + '%.3f' % self.action_selection_time + " "
+ " array2list " + '%.4f' % self.ndarray2list_time + " "
+ " check " + str(self.check) + " "
+ " list2tuple " + '%.4f' % self.list2tuple_time + " \t"
+ " forward " + '%.3f' % self.simulate_sf_time + " "
+ " exp " + '%.3f' % self.expansion_time + " "
+ " bak " + '%.3f' % self.backpropagation_time + " "
+ " hash " + '%.3f' % self.hash_time + " "
+ " step forward " + '%.3f' % self.simulate_sf_time + " "
+ " expansion " + '%.3f' % self.expansion_time + " "
+ " backprop " + '%.3f' % self.backpropagation_time + " "
+ "\n")
file.close()