modify AlphaGo patch
This commit is contained in:
commit
93dc10a728
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,4 +1,4 @@
|
|||||||
.idea
|
.idea
|
||||||
leela-zero
|
leela-zero
|
||||||
.pyc
|
*.pyc
|
||||||
parameters
|
parameters
|
||||||
|
@ -4,6 +4,7 @@ import time
|
|||||||
import os
|
import os
|
||||||
import multi_gpu
|
import multi_gpu
|
||||||
import tensorflow.contrib.layers as layers
|
import tensorflow.contrib.layers as layers
|
||||||
|
import sys
|
||||||
|
|
||||||
def residual_block(input, is_training):
|
def residual_block(input, is_training):
|
||||||
normalizer_params = {'is_training': is_training,
|
normalizer_params = {'is_training': is_training,
|
||||||
@ -102,6 +103,13 @@ def train():
|
|||||||
|
|
||||||
def forward(board):
|
def forward(board):
|
||||||
result_path = "./results/"
|
result_path = "./results/"
|
||||||
|
itflag = False
|
||||||
|
res = None
|
||||||
|
if board is None:
|
||||||
|
board = np.load("/home/yama/tongzheng/AG/self_play_204/d7d7d552b7be4b51883de99d74a8e51b.npz")
|
||||||
|
board = board["boards"][100].reshape(-1, 19, 19, 17)
|
||||||
|
result_path = "../parameters/checkpoints"
|
||||||
|
itflag = True
|
||||||
with multi_gpu.create_session() as sess:
|
with multi_gpu.create_session() as sess:
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
ckpt_file = tf.train.latest_checkpoint(result_path)
|
ckpt_file = tf.train.latest_checkpoint(result_path)
|
||||||
@ -110,7 +118,15 @@ def forward(board):
|
|||||||
saver.restore(sess, ckpt_file)
|
saver.restore(sess, ckpt_file)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No model loaded")
|
raise ValueError("No model loaded")
|
||||||
return sess.run([p,v], feed_dict={x:board})
|
res = sess.run([tf.nn.softmax(p),v], feed_dict={x:board, is_training:itflag})
|
||||||
|
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
|
||||||
|
#res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
|
||||||
|
print(res)
|
||||||
|
#print(res[0].tolist()[0])
|
||||||
|
#print(np.argmax(res[0]))
|
||||||
|
return res
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__=='__main__':
|
||||||
train()
|
#train()
|
||||||
|
if sys.argv[1] == "test":
|
||||||
|
forward(None)
|
||||||
|
31
README.md
31
README.md
@ -28,3 +28,34 @@ MCTS
|
|||||||
## agent (optional)
|
## agent (optional)
|
||||||
|
|
||||||
DQNAgent etc.
|
DQNAgent etc.
|
||||||
|
|
||||||
|
## Pontential Bugs:
|
||||||
|
|
||||||
|
0. Wrong calculation of eval value
|
||||||
|
|
||||||
|
UCTNode.cpp
|
||||||
|
```
|
||||||
|
106 if (to_move == FastBoard::WHITE) {
|
||||||
|
107 net_eval = 1.0f - net_eval;
|
||||||
|
108 }
|
||||||
|
|
||||||
|
309 if (tomove == FastBoard::WHITE) {
|
||||||
|
310 score = 1.0f - score;
|
||||||
|
311 }
|
||||||
|
```
|
||||||
|
|
||||||
|
1. create children only on leaf node
|
||||||
|
|
||||||
|
UCTSearch.cpp
|
||||||
|
```
|
||||||
|
60 if (!node->has_children() && m_nodes < MAX_TREE_SIZE) {
|
||||||
|
61 float eval;
|
||||||
|
62 auto success = node->create_children(m_nodes, currstate, eval);
|
||||||
|
63 if (success) {
|
||||||
|
64 result = SearchResult(eval);
|
||||||
|
65 }
|
||||||
|
66 }
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user