From ab1b3775e74ab9287cada2abf76460991a7eb2c8 Mon Sep 17 00:00:00 2001 From: Tongzheng Ren Date: Fri, 10 Nov 2017 14:40:23 +0800 Subject: [PATCH] minor fixed --- AlphaGo/Network.py | 8 ++++---- utils/text2data.py | 47 +++++++++++++++++++++++++++------------------- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/AlphaGo/Network.py b/AlphaGo/Network.py index a4f117f..d069927 100644 --- a/AlphaGo/Network.py +++ b/AlphaGo/Network.py @@ -12,7 +12,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "1" def residual_block(input, is_training): normalizer_params = {'is_training': is_training, - 'updates_collections': None} + 'updates_collections': tf.GraphKeys.UPDATE_OPS} h = layers.conv2d(input, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, weights_regularizer=layers.l2_regularizer(1e-4)) @@ -25,7 +25,7 @@ def residual_block(input, is_training): def policy_heads(input, is_training): normalizer_params = {'is_training': is_training, - 'updates_collections': None} + 'updates_collections': tf.GraphKeys.UPDATE_OPS} h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, weights_regularizer=layers.l2_regularizer(1e-4)) @@ -36,7 +36,7 @@ def policy_heads(input, is_training): def value_heads(input, is_training): normalizer_params = {'is_training': is_training, - 'updates_collections': None} + 'updates_collections': tf.GraphKeys.UPDATE_OPS} h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, weights_regularizer=layers.l2_regularizer(1e-4)) @@ -52,7 +52,7 @@ z = tf.placeholder(tf.float32, shape=[None, 1]) pi = tf.placeholder(tf.float32, shape=[None, 362]) h = layers.conv2d(x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm, - normalizer_params={'is_training': is_training, 'updates_collections': None}, + normalizer_params={'is_training': is_training, 'updates_collections': tf.GraphKeys.UPDATE_OPS}, weights_regularizer=layers.l2_regularizer(1e-4)) for i in range(19): h = residual_block(h, is_training) diff --git a/utils/text2data.py b/utils/text2data.py index 31d5ade..607c7f2 100644 --- a/utils/text2data.py +++ b/utils/text2data.py @@ -1,5 +1,6 @@ import numpy as np import os +import time def hex2board(hex): scale = 16 @@ -20,7 +21,7 @@ def str2prob(str): if np.sum(np.isnan(prob))==0: return prob, True else: - return 0, False + return 0, False dir = "/home/yama/leela-zero/data/sgf-txt-files/" name = os.listdir(dir) @@ -29,37 +30,45 @@ for n in name: if n[-4:]==".txt": text.append(n) print(text) +total_start = -time.time() for t in text: + start = -time.time() num = 0 - boards = np.zeros([0, 19, 19, 17], dtype='int8') - board = np.zeros([1, 19, 19, 0], dtype='int8') - win = np.zeros([0, 1], dtype='int8') - p = np.zeros([0, 362]) + boards = [] + board = [] + win = [] + p = [] flag = False for line in open(dir + t): - if num % 19 == 0: - flag = False + if num % 19 == 0: + flag = False if num % 19 < 16: new_board = hex2board(line) - board = np.concatenate([board, new_board], axis=3) + board.append(new_board) if num % 19 == 16: if int(line) == 0: new_board = np.ones([1, 19 ,19 ,1], dtype='int8') if int(line) == 1: new_board = np.zeros([1, 19, 19, 1], dtype='int8') - board = np.concatenate([board, new_board], axis=3) - boards = np.concatenate([boards, board], axis=0) - board = np.zeros([1, 19, 19, 0], dtype='int8') + board.append(new_board) + board = np.concatenate(board, axis=3) + boards.append(board) + board = [] if num % 19 == 17: - if str2prob(line)[1] == True: - p = np.concatenate([p,str2prob(line)[0]], axis=0) - else: - flag = True - boards = boards[:-1] + if str2prob(line)[1] == True: + p.append(str2prob(line)[0]) + else: + flag = True + boards = boards[:-1] if num % 19 == 18: - if flag == False: - win = np.concatenate([win, np.array(float(line), dtype='int8').reshape(1,1)], axis=0) + if flag == False: + win.append(np.array(int(line), dtype='int8').reshape(1,1)) num=num+1 + boards = np.concatenate(boards, axis=0) + win = np.concatenate(win, axis=0) + p = np.concatenate(p, axis=0) print("Boards Shape: {}, Wins Shape: {}, Probs Shape : {}".format(boards.shape[0], win.shape[0], p.shape[0])) - print "Finished " + t + print("Finished {} Time {}".format(t, time.time()+start)) np.savez("/home/tongzheng/meta-data/"+t[:-4], boards=boards, win=win, p=p) + del boards, board, win, p +print("All finished! Time {}".format(time.time()+total_start)) \ No newline at end of file