add debuf info for mcts and add softmax for the prior
This commit is contained in:
parent
725fc2c04e
commit
aa6b5434c6
@ -71,6 +71,13 @@ class Game:
|
||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
||||
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
||||
mcts.search(max_step=100)
|
||||
if self.debug:
|
||||
file = open("mcts_debug.log", 'ab')
|
||||
np.savetxt(file, mcts.root.Q, header="\nQ value : ", fmt='%.4f', newline=", ")
|
||||
np.savetxt(file, mcts.root.W, header="\nW value : ", fmt='%.4f', newline=", ")
|
||||
np.savetxt(file, mcts.root.N, header="\nN value : ", fmt="%d", newline=", ")
|
||||
np.savetxt(file, mcts.root.prior, header="\nprior : ", fmt='%.4f', newline=", ")
|
||||
file.close()
|
||||
temp = 1
|
||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
|
||||
@ -119,7 +126,7 @@ class Game:
|
||||
sys.stdout.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("test game.py")
|
||||
#file = open("debug.txt", "a")
|
||||
#file.write("mcts check\n")
|
||||
#file.close()
|
||||
game = Game(name="go", checkpoint_path="./checkpoint")
|
||||
game.debug = True
|
||||
game.think_play_move(utils.BLACK)
|
||||
|
||||
|
@ -80,7 +80,7 @@ class Data(object):
|
||||
|
||||
|
||||
class ResNet(object):
|
||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None):
|
||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
|
||||
"""
|
||||
the resnet model
|
||||
|
||||
@ -161,7 +161,7 @@ class ResNet(object):
|
||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||
self.history_length))
|
||||
state = self._history2state(history, color)
|
||||
return self.sess.run([self.p, self.v], feed_dict={self.x: state, self.is_training: False})
|
||||
return self.sess.run([tf.nn.softmax(self.p), self.v], feed_dict={self.x: state, self.is_training: False})
|
||||
|
||||
def _history2state(self, history, color):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user