remove type_conversion function

This commit is contained in:
Dong Yan 2017-12-27 14:08:34 +08:00
parent a1f6044cba
commit 9f60984973
2 changed files with 12 additions and 25 deletions

View File

@ -161,8 +161,8 @@ class ResNet(object):
raise ValueError(
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
self.history_length))
state = self._history2state(history, color)
return self.sess.run([self.prob, self.v], feed_dict={self.x: state, self.is_training: False})
eval_state = self._history2state(history, color)
return self.sess.run([self.prob, self.v], feed_dict={self.x: eval_state, self.is_training: False})
def _history2state(self, history, color):
"""

View File

@ -6,11 +6,11 @@ import collections
c_puct = 5
try:  # Python 3.3+: the ABCs live in collections.abc (the collections.* aliases were removed in 3.10)
    from collections.abc import Hashable as _Hashable
except ImportError:  # Python 2 fallback
    from collections import Hashable as _Hashable


def hashable_conversion(obj):
    """Return an equivalent of ``obj`` that can be used as a dict key.

    Hashable leaves (numbers, strings, ``None``, ...) are returned unchanged.
    Any other object is iterated and rebuilt as a tuple whose elements are
    converted recursively, so nested lists become nested tuples.

    Tuples are always rebuilt element by element: a tuple passes the
    ``Hashable`` ABC check even when it contains an unhashable item such as a
    list, so it cannot be trusted as-is.

    :param obj: a hashable value or an arbitrarily nested iterable of them.
        Unhashable mappings are iterated like any other iterable, so only
        their keys are kept.
    :return: ``obj`` itself when it is a hashable non-tuple, otherwise a
        (possibly nested) tuple.
    :raises TypeError: if an unhashable object that is not iterable is found.
    """
    if isinstance(obj, tuple):
        return tuple(hashable_conversion(sub) for sub in obj)
    if isinstance(obj, _Hashable):
        return obj
    return tuple(hashable_conversion(sub) for sub in obj)


# Pre-rename name of this helper, kept so callers using the old name keep working.
list2tuple = hashable_conversion
class MCTSNode(object):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
@ -79,7 +79,7 @@ class UCTNode(MCTSNode):
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
self.ucb[self.mask] = -float("Inf")
# Code reserved for Thompson Sampling
class TSNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
@ -97,22 +97,11 @@ class ActionNode(object):
self.action = action
self.children = {}
self.next_state = None
self.next_state_hashable = None
self.state_type = None
self.reward = 0
self.mcts = mcts
def type_conversion_to_tuple(self):
    """Convert ``self.next_state`` in place to nested tuples and record timings.

    A numpy array is first turned into a nested list with ``tolist()``; a list
    (original or freshly produced) is then passed to ``list2tuple``. Any other
    value is left untouched. The time spent in each of the two steps is added
    to ``self.mcts.ndarray2list_time`` and ``self.mcts.list2tuple_time``.
    """
    stats = self.mcts
    begin = time.time()
    if isinstance(self.next_state, np.ndarray):
        self.next_state = self.next_state.tolist()
    after_tolist = time.time()
    if isinstance(self.next_state, list):
        self.next_state = list2tuple(self.next_state)
    after_tuple = time.time()
    stats.ndarray2list_time += after_tolist - begin
    stats.list2tuple_time += after_tuple - after_tolist
    # NOTE(review): ``object`` here is the builtin type, so this adds the same
    # constant on every call -- confirm whether the state's size was intended.
    stats.check += sys.getsizeof(object)
def selection(self, simulator):
head = time.time()
self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action)
@ -120,29 +109,28 @@ class ActionNode(object):
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
self.mcts.action_selection_time += time.time() - head
return self.parent, self.action
self.origin_state = self.next_state
self.type_conversion_to_tuple()
if self.next_state in self.children.keys(): # next state has already visited before
self.next_state_hashable = hashable_conversion(self.next_state)
if self.next_state_hashable in self.children.keys(): # next state has already visited before
self.mcts.action_selection_time += time.time() - head
return self.children[self.next_state].selection(simulator)
return self.children[self.next_state_hashable].selection(simulator)
else: # next state is a new state never seen before
self.mcts.action_selection_time += time.time() - head
return self.parent, self.action
def expansion(self, evaluator, action_num):
    """Create the child node for the state reached by the preceding selection.

    ``self.next_state`` and ``self.next_state_hashable`` are expected to have
    been assigned by ``selection`` before this method is called.

    :param evaluator: callable invoked with ``self.next_state``; its result is
        unpacked as ``(prior, value)``.
    :param action_num: forwarded unchanged to the new ``UCTNode``.
    :return: the value produced by ``evaluator``, or ``0.`` when
        ``self.next_state`` is None (a terminal node was selected).
    """
    if self.next_state is None:
        # selection leaves next_state as None when the parent state is terminal,
        # so there is nothing to expand or evaluate.
        return 0.
    prior, value = evaluator(self.next_state)
    # Children are keyed by the hashable form so the lookup in selection finds
    # them again; the node itself keeps the unconverted state.
    self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
                                                      mcts=self.mcts, inverse=self.parent.inverse)
    return value
def backpropagation(self, value):
    """Add ``value`` to this node's reward and continue backpropagation upwards.

    The parent is called with this node's action, not with ``value``.

    :param value: amount added to ``self.reward``.
    """
    self.reward += value
    parent = self.parent
    parent.backpropagation(self.action)
class MCTS(object):
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
role="unknown", debug=False, inverse=False):
@ -214,6 +202,5 @@ class MCTS(object):
t3 = time.time()
return t1 - t0, t2 - t1, t3 - t2
if __name__ == "__main__":
pass