modify AlphaGo patch

This commit is contained in:
Tongzheng Ren 2017-11-08 08:35:39 +08:00
commit 93dc10a728
3 changed files with 59 additions and 12 deletions

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
.idea .idea
leela-zero leela-zero
.pyc *.pyc
parameters parameters

View File

@ -4,6 +4,7 @@ import time
import os import os
import multi_gpu import multi_gpu
import tensorflow.contrib.layers as layers import tensorflow.contrib.layers as layers
import sys
def residual_block(input, is_training): def residual_block(input, is_training):
normalizer_params = {'is_training': is_training, normalizer_params = {'is_training': is_training,
@ -102,15 +103,30 @@ def train():
def forward(board): def forward(board):
result_path = "./results/" result_path = "./results/"
with multi_gpu.create_session() as sess: itflag = False
sess.run(tf.global_variables_initializer()) res = None
ckpt_file = tf.train.latest_checkpoint(result_path) if board is None:
if ckpt_file is not None: board = np.load("/home/yama/tongzheng/AG/self_play_204/d7d7d552b7be4b51883de99d74a8e51b.npz")
print('Restoring model from {}...'.format(ckpt_file)) board = board["boards"][100].reshape(-1, 19, 19, 17)
saver.restore(sess, ckpt_file) result_path = "../parameters/checkpoints"
else: itflag = True
raise ValueError("No model loaded") with multi_gpu.create_session() as sess:
return sess.run([p,v], feed_dict={x:board}) sess.run(tf.global_variables_initializer())
ckpt_file = tf.train.latest_checkpoint(result_path)
if ckpt_file is not None:
print('Restoring model from {}...'.format(ckpt_file))
saver.restore(sess, ckpt_file)
else:
raise ValueError("No model loaded")
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:board, is_training:itflag})
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
print(res)
#print(res[0].tolist()[0])
#print(np.argmax(res[0]))
return res
if __name__=="__main__": if __name__=='__main__':
train() #train()
if sys.argv[1] == "test":
forward(None)

View File

@ -28,3 +28,34 @@ MCTS
## agent (optional) ## agent (optional)
DQNAgent etc. DQNAgent etc.
## Potential Bugs:
0. Wrong calculation of eval value
UCTNode.cpp
```
106 if (to_move == FastBoard::WHITE) {
107 net_eval = 1.0f - net_eval;
108 }
309 if (tomove == FastBoard::WHITE) {
310 score = 1.0f - score;
311 }
```
1. Create children only on leaf nodes
UCTSearch.cpp
```
60 if (!node->has_children() && m_nodes < MAX_TREE_SIZE) {
61 float eval;
62 auto success = node->create_children(m_nodes, currstate, eval);
63 if (success) {
64 result = SearchResult(eval);
65 }
66 }
```