Remove the type_conversion_to_tuple method (state keys are now built with hashable_conversion)
This commit is contained in:
parent
a1f6044cba
commit
9f60984973
@ -161,8 +161,8 @@ class ResNet(object):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||||
self.history_length))
|
self.history_length))
|
||||||
state = self._history2state(history, color)
|
eval_state = self._history2state(history, color)
|
||||||
return self.sess.run([self.prob, self.v], feed_dict={self.x: state, self.is_training: False})
|
return self.sess.run([self.prob, self.v], feed_dict={self.x: eval_state, self.is_training: False})
|
||||||
|
|
||||||
def _history2state(self, history, color):
|
def _history2state(self, history, color):
|
||||||
"""
|
"""
|
||||||
|
@ -6,11 +6,11 @@ import collections
|
|||||||
|
|
||||||
c_puct = 5
|
c_puct = 5
|
||||||
|
|
||||||
def list2tuple(obj):
|
def hashable_conversion(obj):
|
||||||
if isinstance(obj, collections.Hashable):
|
if isinstance(obj, collections.Hashable):
|
||||||
return obj
|
return obj
|
||||||
else:
|
else:
|
||||||
return tuple(list2tuple(sub) for sub in obj)
|
return tuple(hashable_conversion(sub) for sub in obj)
|
||||||
|
|
||||||
class MCTSNode(object):
|
class MCTSNode(object):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||||
@ -79,7 +79,7 @@ class UCTNode(MCTSNode):
|
|||||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||||
self.ucb[self.mask] = -float("Inf")
|
self.ucb[self.mask] = -float("Inf")
|
||||||
|
|
||||||
|
# Code reserved for Thompson Sampling
|
||||||
class TSNode(MCTSNode):
|
class TSNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
|
||||||
super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(TSNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
@ -97,22 +97,11 @@ class ActionNode(object):
|
|||||||
self.action = action
|
self.action = action
|
||||||
self.children = {}
|
self.children = {}
|
||||||
self.next_state = None
|
self.next_state = None
|
||||||
|
self.next_state_hashable = None
|
||||||
self.state_type = None
|
self.state_type = None
|
||||||
self.reward = 0
|
self.reward = 0
|
||||||
self.mcts = mcts
|
self.mcts = mcts
|
||||||
|
|
||||||
def type_conversion_to_tuple(self):
|
|
||||||
t0 = time.time()
|
|
||||||
if isinstance(self.next_state, np.ndarray):
|
|
||||||
self.next_state = self.next_state.tolist()
|
|
||||||
t1 = time.time()
|
|
||||||
if isinstance(self.next_state, list):
|
|
||||||
self.next_state = list2tuple(self.next_state)
|
|
||||||
t2 = time.time()
|
|
||||||
self.mcts.ndarray2list_time += t1 - t0
|
|
||||||
self.mcts.list2tuple_time += t2 - t1
|
|
||||||
self.mcts.check += sys.getsizeof(object)
|
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
head = time.time()
|
head = time.time()
|
||||||
self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action)
|
self.next_state, self.reward = simulator.simulate_step_forward(self.parent.state, self.action)
|
||||||
@ -120,29 +109,28 @@ class ActionNode(object):
|
|||||||
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
if self.next_state is None: # next_state is None means that self.parent.state is the terminate state
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
self.origin_state = self.next_state
|
self.next_state_hashable = hashable_conversion(self.next_state)
|
||||||
self.type_conversion_to_tuple()
|
if self.next_state_hashable in self.children.keys(): # next state has already visited before
|
||||||
if self.next_state in self.children.keys(): # next state has already visited before
|
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.children[self.next_state].selection(simulator)
|
return self.children[self.next_state_hashable].selection(simulator)
|
||||||
else: # next state is a new state never seen before
|
else: # next state is a new state never seen before
|
||||||
self.mcts.action_selection_time += time.time() - head
|
self.mcts.action_selection_time += time.time() - head
|
||||||
return self.parent, self.action
|
return self.parent, self.action
|
||||||
|
|
||||||
def expansion(self, evaluator, action_num):
|
def expansion(self, evaluator, action_num):
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
|
# note that self.next_state was assigned already at the selection function
|
||||||
prior, value = evaluator(self.next_state)
|
prior, value = evaluator(self.next_state)
|
||||||
self.children[self.next_state] = UCTNode(self, self.action, self.origin_state, action_num, prior,
|
self.children[self.next_state_hashable] = UCTNode(self, self.action, self.next_state, action_num, prior,
|
||||||
mcts=self.mcts, inverse=self.parent.inverse)
|
mcts=self.mcts, inverse=self.parent.inverse)
|
||||||
return value
|
return value
|
||||||
else:
|
else: # self.next_state is None means MCTS selected a terminate node
|
||||||
return 0.
|
return 0.
|
||||||
|
|
||||||
def backpropagation(self, value):
|
def backpropagation(self, value):
|
||||||
self.reward += value
|
self.reward += value
|
||||||
self.parent.backpropagation(self.action)
|
self.parent.backpropagation(self.action)
|
||||||
|
|
||||||
|
|
||||||
class MCTS(object):
|
class MCTS(object):
|
||||||
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
|
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
|
||||||
role="unknown", debug=False, inverse=False):
|
role="unknown", debug=False, inverse=False):
|
||||||
@ -214,6 +202,5 @@ class MCTS(object):
|
|||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
return t1 - t0, t2 - t1, t3 - t2
|
return t1 - t0, t2 - t1, t3 - t2
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user