diff --git a/AlphaGo/go.py b/AlphaGo/go.py index 55f5a4a..987fe93 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -222,6 +222,10 @@ class Go: new_color = -color return [history_boards, new_color], 0 + def simulate_hashable_conversion(self, state): + # since go is MDP, we only need the last board for hashing + return tuple(state[0][-1]) + def executor_do_move(self, history, latest_boards, current_board, color, vertex): if not self._rule_check(history, current_board, color, vertex): return False diff --git a/AlphaGo/reversi.py b/AlphaGo/reversi.py index c6c8a5b..08a5ec5 100644 --- a/AlphaGo/reversi.py +++ b/AlphaGo/reversi.py @@ -97,6 +97,10 @@ class Reversi: history_boards.append(new_board) return [history_boards, 0 - color], 0 + def simulate_hashable_conversion(self, state): + # since Reversi is an MDP, we only need the last board for hashing + return tuple(state[0][-1]) + def _get_winner(self, board): black_num, white_num = self._number_of_black_and_white(board) black_win = black_num - white_num diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 4c23809..9625261 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -1,17 +1,9 @@ import numpy as np import math import time -import sys -import collections c_puct = 5 -def hashable_conversion(obj): - if isinstance(obj, collections.Hashable): - return obj - else: - return tuple(hashable_conversion(sub) for sub in obj) - class MCTSNode(object): def __init__(self, parent, action, state, action_num, prior, inverse=False): self.parent = parent @@ -109,7 +101,9 @@ class ActionNode(object): if self.next_state is None: # next_state is None means that self.parent.state is the terminate state self.mcts.action_selection_time += time.time() - head return self - self.next_state_hashable = hashable_conversion(self.next_state) + head = time.time() + self.next_state_hashable = simulator.simulate_hashable_conversion(self.next_state) + self.mcts.hash_time += time.time() - head if 
self.next_state_hashable in self.children.keys(): # next state has already visited before self.mcts.action_selection_time += time.time() - head return self.children[self.next_state_hashable].selection(simulator) @@ -153,9 +147,7 @@ class MCTS(object): self.state_selection_time = 0 self.simulate_sf_time = 0 self.valid_mask_time = 0 - self.ndarray2list_time = 0 - self.list2tuple_time = 0 - self.check = 0 + self.hash_time = 0 def search(self, max_step=None, max_time=None): step = 0 @@ -174,18 +166,16 @@ class MCTS(object): self.backpropagation_time += back_time step += 1 if self.debug: - file = open("mcts_profiling.txt", "a") + file = open("mcts_profiling.log", "a") file.write("[" + str(self.role) + "]" - + " sel " + '%.3f' % self.selection_time + " " - + " sel_sta " + '%.3f' % self.state_selection_time + " " - + " valid " + '%.3f' % self.valid_mask_time + " " - + " sel_act " + '%.3f' % self.action_selection_time + " " - + " array2list " + '%.4f' % self.ndarray2list_time + " " - + " check " + str(self.check) + " " - + " list2tuple " + '%.4f' % self.list2tuple_time + " \t" - + " forward " + '%.3f' % self.simulate_sf_time + " " - + " exp " + '%.3f' % self.expansion_time + " " - + " bak " + '%.3f' % self.backpropagation_time + " " + + " sel " + '%.3f' % self.selection_time + " " + + " sel_sta " + '%.3f' % self.state_selection_time + " " + + " valid " + '%.3f' % self.valid_mask_time + " " + + " sel_act " + '%.3f' % self.action_selection_time + " " + + " hash " + '%.3f' % self.hash_time + " " + + " step forward " + '%.3f' % self.simulate_sf_time + " " + + " expansion " + '%.3f' % self.expansion_time + " " + + " backprop " + '%.3f' % self.backpropagation_time + " " + "\n") file.close()