This commit is contained in:
Tongzheng Ren 2018-01-08 21:21:08 +08:00
commit 891c5b1e47
2 changed files with 5 additions and 3 deletions

View File

@@ -284,7 +284,7 @@ class ResNet(object):
history.append(board) history.append(board)
states.append(self._history2state(history, color)) states.append(self._history2state(history, color))
probs.append(np.array(prob).reshape(1, self.board_size ** 2 + 1)) probs.append(np.array(prob).reshape(1, self.board_size ** 2 + 1))
winner.append(np.array(data.winner).reshape(1, 1)) winner.append(np.array(data.winner * color).reshape(1, 1))
color *= -1 color *= -1
states = np.concatenate(states, axis=0) states = np.concatenate(states, axis=0)
probs = np.concatenate(probs, axis=0) probs = np.concatenate(probs, axis=0)

View File

@@ -26,7 +26,7 @@ class MCTSNode(object):
class UCTNode(MCTSNode): class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False): def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
self.Q = np.zeros([action_num]) self.Q = np.random.uniform(-1, 1, action_num) * (1e-6)
self.W = np.zeros([action_num]) self.W = np.zeros([action_num])
self.N = np.zeros([action_num]) self.N = np.zeros([action_num])
self.c_puct = c_puct self.c_puct = c_puct
@@ -121,12 +121,14 @@ class ActionNode(object):
class MCTS(object): class MCTS(object):
def __init__(self, simulator, evaluator, start_state, action_num, method="UCT", def __init__(self, simulator, evaluator, start_state, action_num, method="UCT",
role="unknown", debug=False, inverse=False): role="unknown", debug=False, inverse=False, epsilon=0.25):
self.simulator = simulator self.simulator = simulator
self.evaluator = evaluator self.evaluator = evaluator
self.role = role self.role = role
self.debug = debug self.debug = debug
self.epsilon = epsilon
prior, _ = self.evaluator(start_state) prior, _ = self.evaluator(start_state)
prior = (1 - self.epsilon) * prior + self.epsilon * np.random.dirichlet(1.0/action_num * np.ones([action_num]))
self.action_num = action_num self.action_num = action_num
if method == "": if method == "":
self.root = start_state self.root = start_state