minor fixed
This commit is contained in:
parent
1d7d4e14ef
commit
ab1b3775e7
@ -12,7 +12,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
|
||||
def residual_block(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
'updates_collections': None}
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||
h = layers.conv2d(input, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
|
||||
normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
@ -25,7 +25,7 @@ def residual_block(input, is_training):
|
||||
|
||||
def policy_heads(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
'updates_collections': None}
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||
h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu,
|
||||
normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
@ -36,7 +36,7 @@ def policy_heads(input, is_training):
|
||||
|
||||
def value_heads(input, is_training):
|
||||
normalizer_params = {'is_training': is_training,
|
||||
'updates_collections': None}
|
||||
'updates_collections': tf.GraphKeys.UPDATE_OPS}
|
||||
h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu,
|
||||
normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
@ -52,7 +52,7 @@ z = tf.placeholder(tf.float32, shape=[None, 1])
|
||||
pi = tf.placeholder(tf.float32, shape=[None, 362])
|
||||
|
||||
h = layers.conv2d(x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,
|
||||
normalizer_params={'is_training': is_training, 'updates_collections': None},
|
||||
normalizer_params={'is_training': is_training, 'updates_collections': tf.GraphKeys.UPDATE_OPS},
|
||||
weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
for i in range(19):
|
||||
h = residual_block(h, is_training)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
|
||||
def hex2board(hex):
|
||||
scale = 16
|
||||
@ -29,37 +30,45 @@ for n in name:
|
||||
if n[-4:]==".txt":
|
||||
text.append(n)
|
||||
print(text)
|
||||
total_start = -time.time()
|
||||
for t in text:
|
||||
start = -time.time()
|
||||
num = 0
|
||||
boards = np.zeros([0, 19, 19, 17], dtype='int8')
|
||||
board = np.zeros([1, 19, 19, 0], dtype='int8')
|
||||
win = np.zeros([0, 1], dtype='int8')
|
||||
p = np.zeros([0, 362])
|
||||
boards = []
|
||||
board = []
|
||||
win = []
|
||||
p = []
|
||||
flag = False
|
||||
for line in open(dir + t):
|
||||
if num % 19 == 0:
|
||||
flag = False
|
||||
if num % 19 < 16:
|
||||
new_board = hex2board(line)
|
||||
board = np.concatenate([board, new_board], axis=3)
|
||||
board.append(new_board)
|
||||
if num % 19 == 16:
|
||||
if int(line) == 0:
|
||||
new_board = np.ones([1, 19 ,19 ,1], dtype='int8')
|
||||
if int(line) == 1:
|
||||
new_board = np.zeros([1, 19, 19, 1], dtype='int8')
|
||||
board = np.concatenate([board, new_board], axis=3)
|
||||
boards = np.concatenate([boards, board], axis=0)
|
||||
board = np.zeros([1, 19, 19, 0], dtype='int8')
|
||||
board.append(new_board)
|
||||
board = np.concatenate(board, axis=3)
|
||||
boards.append(board)
|
||||
board = []
|
||||
if num % 19 == 17:
|
||||
if str2prob(line)[1] == True:
|
||||
p = np.concatenate([p,str2prob(line)[0]], axis=0)
|
||||
p.append(str2prob(line)[0])
|
||||
else:
|
||||
flag = True
|
||||
boards = boards[:-1]
|
||||
if num % 19 == 18:
|
||||
if flag == False:
|
||||
win = np.concatenate([win, np.array(float(line), dtype='int8').reshape(1,1)], axis=0)
|
||||
win.append(np.array(int(line), dtype='int8').reshape(1,1))
|
||||
num=num+1
|
||||
boards = np.concatenate(boards, axis=0)
|
||||
win = np.concatenate(win, axis=0)
|
||||
p = np.concatenate(p, axis=0)
|
||||
print("Boards Shape: {}, Wins Shape: {}, Probs Shape : {}".format(boards.shape[0], win.shape[0], p.shape[0]))
|
||||
print "Finished " + t
|
||||
print("Finished {} Time {}".format(t, time.time()+start))
|
||||
np.savez("/home/tongzheng/meta-data/"+t[:-4], boards=boards, win=win, p=p)
|
||||
del boards, board, win, p
|
||||
print("All finished! Time {}".format(time.time()+total_start))
|
Loading…
x
Reference in New Issue
Block a user