Tianshou/AlphaGo/data.py

84 lines
3.3 KiB
Python
Raw Normal View History

2017-11-04 22:16:43 +08:00
import os
2017-11-08 08:32:07 +08:00
import threading
2017-11-04 22:16:43 +08:00
import numpy as np
2017-11-08 08:32:07 +08:00
# Directory holding the per-game .npz records to merge (machine-specific).
path = "/home/yama/leela-zero/data/npz-files/"
# Sorted so the file-to-shard assignment is reproducible across runs
# (os.listdir returns entries in arbitrary order). Non-.npz entries are
# skipped because every listed file is opened with np.load and indexed by
# npz keys ("boards", "win", "p").
name = sorted(n for n in os.listdir(path) if n.endswith(".npz"))
print(len(name))
# Number of worker threads, and therefore of output shards.
thread_num = 17
# Files per shard, rounded down.
batch_num = len(name) // thread_num
2017-11-04 22:16:43 +08:00
2017-11-08 08:32:07 +08:00
# Board symmetries applied by integrate(), as (flip, k) pairs: optionally
# reverse spatial axis 1, then rotate k quarter-turns in the (1, 2) plane.
# Together these are all 8 symmetries of the square board: the four
# rotations first, then the four reflections.
_SYMMETRIES = (
    (False, 0),
    (False, 1),
    (False, 2),
    (False, 3),
    (True, 0),
    (True, 2),  # same values as reversing axis 2 only
    (True, 1),
    (True, 3),  # same values as np.rot90(x[:, :, ::-1], 1, (1, 2))
)


def _transform_planes(planes, k, flip):
    """Apply one board symmetry to ``planes``, whose spatial axes are 1 and 2."""
    if flip:
        planes = planes[:, ::-1]
    return np.rot90(planes, k, (1, 2))


def _transform_policy(ps, k, flip):
    """Apply one board symmetry to rows of 361 grid entries plus one extra entry.

    The first 361 columns are viewed as a 19x19 grid and transformed exactly
    like the board planes; the final column (presumably the pass move) is
    carried over unchanged.
    """
    grid = _transform_planes(ps[:, :-1].reshape(-1, 19, 19), k, flip)
    return np.concatenate([grid.reshape(-1, 361), ps[:, -1].reshape(-1, 1)], axis=1)


def integrate(name, index, data_dir=None, out_prefix="/home/tongzheng/data/data-"):
    """Merge one shard of .npz game files, augment it with board symmetries, save it.

    Every input file must provide arrays ``boards`` of shape (N, 19, 19, 17),
    ``win`` of shape (N, 1) and ``p`` of shape (N, 362). They are stacked
    along axis 0. Each of the 8 symmetries in ``_SYMMETRIES`` is then applied
    to the boards and to the 19x19 part of ``p``, so the saved shard holds 8x
    as many samples. ``win`` is repeated unchanged for every symmetry.

    Args:
        name: file names, relative to ``data_dir``, to merge into this shard.
        index: shard number; used in the progress messages and the output name.
        data_dir: directory containing the input files. Defaults to the
            module-level ``path``.
        out_prefix: the shard is written with ``np.savez`` to
            ``out_prefix + str(index)`` (numpy appends ``.npz``), under the
            keys ``boards``, ``wins`` and ``ps``.
    """
    directory = path if data_dir is None else data_dir

    # The empty float64 seed arrays preserve the output dtype: concatenating
    # with them promotes integer / lower-precision inputs to float64.
    board_parts = [np.zeros([0, 19, 19, 17])]
    win_parts = [np.zeros([0, 1])]
    p_parts = [np.zeros([0, 362])]
    for n in name:
        # Use the NpzFile as a context manager so its file handle is released;
        # indexing it reads each array fully into memory first.
        with np.load(os.path.join(directory, n)) as data:
            board_parts.append(data["boards"])
            win_parts.append(data["win"])
            p_parts.append(data["p"])

    # A single concatenate per array, rather than one per file, which would
    # recopy the growing shard each time.
    board_ori = np.concatenate(board_parts, axis=0)
    win_ori = np.concatenate(win_parts, axis=0)
    p_ori = np.concatenate(p_parts, axis=0)
    print("Integration {} Finished!".format(index))

    boards = [_transform_planes(board_ori, k, flip) for flip, k in _SYMMETRIES]
    ps = [_transform_policy(p_ori, k, flip) for flip, k in _SYMMETRIES]
    wins = [win_ori] * len(_SYMMETRIES)

    np.savez(out_prefix + str(index),
             boards=np.concatenate(boards, axis=0),
             wins=np.concatenate(wins, axis=0),
             ps=np.concatenate(ps, axis=0))
    print("Thread {} has finished.".format(index))
# Fan the file list out to thread_num worker threads (one output shard each),
# then wait for all of them to finish.
thread_list = []
for i in range(thread_num):
    start = batch_num * i
    # batch_num = len(name) // thread_num rounds down, so fixed-size slices
    # would silently skip the last len(name) % thread_num files. The final
    # thread therefore takes everything from its start to the end of the list.
    stop = len(name) if i == thread_num - 1 else batch_num * (i + 1)
    thread_list.append(threading.Thread(target=integrate, args=(name[start:stop], i)))
for thread in thread_list:
    thread.start()
for thread in thread_list:
    thread.join()