Merge branch 'master' of https://github.com/sproblvem/tianshou
This commit is contained in:
commit
891c5b1e47
@@ -284,7 +284,7 @@ class ResNet(object):
|
|||||||
history.append(board)
|
history.append(board)
|
||||||
states.append(self._history2state(history, color))
|
states.append(self._history2state(history, color))
|
||||||
probs.append(np.array(prob).reshape(1, self.board_size ** 2 + 1))
|
probs.append(np.array(prob).reshape(1, self.board_size ** 2 + 1))
|
||||||
winner.append(np.array(data.winner).reshape(1, 1))
|
winner.append(np.array(data.winner * color).reshape(1, 1))
|
||||||
color *= -1
|
color *= -1
|
||||||
states = np.concatenate(states, axis=0)
|
states = np.concatenate(states, axis=0)
|
||||||
probs = np.concatenate(probs, axis=0)
|
probs = np.concatenate(probs, axis=0)
|
||||||
|
@@ -26,7 +26,7 @@ class MCTSNode(object):
|
|||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
self.Q = np.zeros([action_num])
|
self.Q = np.random.uniform(-1, 1, action_num) * (1e-6)
|
||||||
self.W = np.zeros([action_num])
|
self.W = np.zeros([action_num])
|
||||||
self.N = np.zeros([action_num])
|
self.N = np.zeros([action_num])
|
||||||
self.c_puct = c_puct
|
self.c_puct = c_puct
|
||||||
@@ -121,12 +121,14 @@ class ActionNode(object):
|
|||||||
|
|
||||||
class MCTS(object):
|
class MCTS(object):
|
||||||
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
|
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
|
||||||
role="unknown", debug=False, inverse=False):
|
role="unknown", debug=False, inverse=False, epsilon=0.25):
|
||||||
self.simulator = simulator
|
self.simulator = simulator
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.role = role
|
self.role = role
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
self.epsilon = epsilon
|
||||||
prior, _ = self.evaluator(start_state)
|
prior, _ = self.evaluator(start_state)
|
||||||
|
prior = (1 - self.epsilon) * prior + self.epsilon * np.random.dirichlet(1.0/action_num * np.ones([action_num]))
|
||||||
self.action_num = action_num
|
self.action_num = action_num
|
||||||
if method == "":
|
if method == "":
|
||||||
self.root = start_state
|
self.root = start_state
|
||||||
|
Loading…
x
Reference in New Issue
Block a user