add debuf info for mcts and add softmax for the prior

This commit is contained in:
Dong Yan 2017-12-26 14:46:14 +08:00
parent 725fc2c04e
commit aa6b5434c6
2 changed files with 13 additions and 6 deletions

View File

@ -71,6 +71,13 @@ class Game:
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
mcts.search(max_step=100)
if self.debug:
file = open("mcts_debug.log", 'ab')
np.savetxt(file, mcts.root.Q, header="\nQ value : ", fmt='%.4f', newline=", ")
np.savetxt(file, mcts.root.W, header="\nW value : ", fmt='%.4f', newline=", ")
np.savetxt(file, mcts.root.N, header="\nN value : ", fmt="%d", newline=", ")
np.savetxt(file, mcts.root.prior, header="\nprior : ", fmt='%.4f', newline=", ")
file.close()
temp = 1
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
@ -119,7 +126,7 @@ class Game:
sys.stdout.flush()
if __name__ == "__main__":
print("test game.py")
#file = open("debug.txt", "a")
#file.write("mcts check\n")
#file.close()
game = Game(name="go", checkpoint_path="./checkpoint")
game.debug = True
game.think_play_move(utils.BLACK)

View File

@ -80,7 +80,7 @@ class Data(object):
class ResNet(object):
def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None):
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
"""
the resnet model
@ -161,7 +161,7 @@ class ResNet(object):
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
self.history_length))
state = self._history2state(history, color)
return self.sess.run([self.p, self.v], feed_dict={self.x: state, self.is_training: False})
return self.sess.run([tf.nn.softmax(self.p), self.v], feed_dict={self.x: state, self.is_training: False})
def _history2state(self, history, color):
"""