From 9f6098497336d989c70b7f6fc67ebf2bc4ad6e85 Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Wed, 27 Dec 2017 14:08:34 +0800 Subject: [PATCH] remove type_conversion function --- AlphaGo/model.py | 4 ++-- tianshou/core/mcts/mcts.py | 33 ++++++++++----------------------- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/AlphaGo/model.py b/AlphaGo/model.py index dbfc5ca..6fde6e5 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -161,8 +161,8 @@ class ResNet(object): raise ValueError( 'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history), self.history_length)) - state = self._history2state(history, color) - return self.sess.run([self.prob, self.v], feed_dict={self.x: state, self.is_training: False}) + eval_state = self._history2state(history, color) + return self.sess.run([self.prob, self.v], feed_dict={self.x: eval_state, self.is_training: False}) def _history2state(self, history, color): """ diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index f64b5a0..98ab056 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -6,11 +6,11 @@ import collections c_puct = 5 -def list2tuple(obj): +def hashable_conversion(obj): if isinstance(obj, collections.Hashable): return obj else: - return tuple(list2tuple(sub) for sub in obj) + return tuple(hashable_conversion(sub) for sub in obj) class MCTSNode(object): def __init__(self, parent, action, state, action_num, prior, inverse=False): @@ -79,7 +79,7 @@ class UCTNode(MCTSNode): self.mask = simulator.simulate_get_mask(self.state, range(self.action_num)) self.ucb[self.mask] = -float("Inf") - +# Code reserved for Thompson Sampling class TSNode(MCTSNode): def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False): super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse) @@ -97,22 +97,11 @@ class ActionNode(object): self.action = action self.children = {} self.next_state = None + self.next_state_hashable = None self.state_type = None self.reward = 0 self.mcts = mcts - def type_conversion_to_tuple(self): - t0 = time.time() - if isinstance(self.next_state, np.ndarray): - self.next_state = self.next_state.tolist() - t1 = time.time() - if isinstance(self.next_state, list): - self.next_state = list2tuple(self.next_state) - t2 = time.time() - self.mcts.ndarray2list_time += t1 - t0 - self.mcts.list2tuple_time += t2 - t1 - self.mcts.check += sys.getsizeof(object) - def selection(self, simulator): head = time.time() self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action) @@ -120,29 +109,28 @@ class ActionNode(object): if self.next_state is None: # next_state is None means that self.parent.state is the terminate state self.mcts.action_selection_time += time.time() - head return self.parent, self.action - self.origin_state = self.next_state - self.type_conversion_to_tuple() - if self.next_state in self.children.keys(): # next state has already visited before + self.next_state_hashable = hashable_conversion(self.next_state) + if self.next_state_hashable in self.children.keys(): # next state has already visited before self.mcts.action_selection_time += time.time() - head - return self.children[self.next_state].selection(simulator) + return self.children[self.next_state_hashable].selection(simulator) else: # next state is a new state never seen before self.mcts.action_selection_time += time.time() - head return self.parent, self.action def expansion(self, evaluator, action_num): if self.next_state is not None: + # note that self.next_state was assigned already at the selection function prior, value = evaluator(self.next_state) - self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior, + self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior, mcts=self.mcts, inverse=self.parent.inverse) return value - else: + else: # self.next_state is None means MCTS selected a terminate node return 0. def backpropagation(self, value): self.reward += value self.parent.backpropagation(self.action) - class MCTS(object): def __init__(self, simulator, evaluator, start_state, action_num, method="UCT", role="unknown", debug=False, inverse=False): @@ -214,6 +202,5 @@ class MCTS(object): t3 = time.time() return t1 - t0, t2 - t1, t3 - t2 - if __name__ == "__main__": pass