add debug info for mcts and add softmax for the prior
This commit is contained in:
parent
725fc2c04e
commit
aa6b5434c6
@ -71,6 +71,13 @@ class Game:
|
|||||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
||||||
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
||||||
mcts.search(max_step=100)
|
mcts.search(max_step=100)
|
||||||
|
if self.debug:
|
||||||
|
file = open("mcts_debug.log", 'ab')
|
||||||
|
np.savetxt(file, mcts.root.Q, header="\nQ value : ", fmt='%.4f', newline=", ")
|
||||||
|
np.savetxt(file, mcts.root.W, header="\nW value : ", fmt='%.4f', newline=", ")
|
||||||
|
np.savetxt(file, mcts.root.N, header="\nN value : ", fmt="%d", newline=", ")
|
||||||
|
np.savetxt(file, mcts.root.prior, header="\nprior : ", fmt='%.4f', newline=", ")
|
||||||
|
file.close()
|
||||||
temp = 1
|
temp = 1
|
||||||
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
prob = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
|
||||||
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
|
choice = np.random.choice(self.size ** 2 + 1, 1, p=prob).tolist()[0]
|
||||||
@ -119,7 +126,7 @@ class Game:
|
|||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("test game.py")
|
game = Game(name="go", checkpoint_path="./checkpoint")
|
||||||
#file = open("debug.txt", "a")
|
game.debug = True
|
||||||
#file.write("mcts check\n")
|
game.think_play_move(utils.BLACK)
|
||||||
#file.close()
|
|
||||||
|
@ -80,7 +80,7 @@ class Data(object):
|
|||||||
|
|
||||||
|
|
||||||
class ResNet(object):
|
class ResNet(object):
|
||||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=20, checkpoint_path=None):
|
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, checkpoint_path=None):
|
||||||
"""
|
"""
|
||||||
the resnet model
|
the resnet model
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ class ResNet(object):
|
|||||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||||
self.history_length))
|
self.history_length))
|
||||||
state = self._history2state(history, color)
|
state = self._history2state(history, color)
|
||||||
return self.sess.run([self.p, self.v], feed_dict={self.x: state, self.is_training: False})
|
return self.sess.run([tf.nn.softmax(self.p), self.v], feed_dict={self.x: state, self.is_training: False})
|
||||||
|
|
||||||
def _history2state(self, history, color):
|
def _history2state(self, history, color):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user