Tianshou/AlphaGo/self-play.py

104 lines
3.1 KiB
Python
Raw Normal View History

2017-12-05 23:17:20 +08:00
from game import Game
from engine import GTPEngine
2017-12-05 23:42:18 +08:00
import re
2017-12-08 17:05:33 +08:00
import numpy as np
2017-12-08 18:08:15 +08:00
import os
2017-12-08 17:05:33 +08:00
from collections import deque
import utils
import argparse
2017-12-05 23:17:20 +08:00
2017-12-08 17:05:33 +08:00
# Command-line options: directory that receives one .npz archive per game.
parser = argparse.ArgumentParser()
parser.add_argument('--result_path', type=str, default='./part1')
args = parser.parse_args()

if not os.path.exists(args.result_path):
    os.makedirs(args.result_path)

game = Game()
engine = GTPEngine(game_obj=game)

# Rolling window of the last 8 board positions, consumed by history2state to
# build the network input planes. Each entry is an independent array copy:
# appending `game.board` itself would store 8 references to one object, so
# if Game updates its board in place (not visible here) every entry would
# silently track the current position.
history = deque(maxlen=8)
for _ in range(8):
    history.append(np.array(game.board))

# Per-game training buffers: network inputs, move probabilities, outcome.
state = []
prob = []
winner = []

# Extracts the played vertex from a GTP response for display, e.g. "D4" or
# "Q16". Rows can have two digits, so a single-digit pattern would truncate
# "Q16" to "Q1".
pattern = r"[A-Z][0-9]{1,2}"

game.show_board()
def history2state(history, color):
    """Build the 17-plane network input for the side to move.

    Planes 0-7 mark the points whose value is 1 in history[0..7].
    Planes 8-15 mark the points whose value is -1 in the same boards.
    Plane 16 is all ones when ``color`` is ``utils.BLACK`` and all zeros
    otherwise. Which stone colour 1 / -1 encode is defined by Game/utils
    (presumably Black = 1; confirm there).

    Args:
        history: indexable sequence with at least 8 boards. Each board must
            convert to an array of ``game.size ** 2`` entries.
        color: colour of the player to move (``utils.BLACK`` or
            ``utils.WHITE``).

    Returns:
        float ndarray of shape ``(1, game.size, game.size, 17)``.
    """
    side = game.size
    ones = np.ones(side ** 2)
    planes = np.zeros([1, side, side, 17])
    for t in range(8):
        board = np.array(history[t])
        planes[0, :, :, t] = np.array(board == ones).reshape(side, side)
        planes[0, :, :, 8 + t] = np.array(board == -ones).reshape(side, side)
    if color == utils.BLACK:
        planes[0, :, :, 16] = np.ones([side, side])
    if color == utils.WHITE:
        planes[0, :, :, 16] = np.zeros([side, side])
    return planes
# Move counter within the current game; also used as the GTP command id.
num = 0
# Index of the current self-play game; names the output archive.
game_num = 0
# Whether each side's most recent move was a pass. The game stops once both
# are True, i.e. after two consecutive passes.
black_pass = False
white_pass = False

# Self-play driver: plays games forever, saving one .npz archive per game.
while True:
    print("Start game {}".format(game_num))
    # Stop on two consecutive passes or at a hard cap of 2 * board-area moves.
    while not (black_pass and white_pass) and num < game.size ** 2 * 2:
        # Black moves on even move numbers, White on odd ones.
        is_black = num % 2 == 0
        color = utils.BLACK if is_black else utils.WHITE
        # Record the network input for the position *before* this move.
        state.append(history2state(history, color))
        result = engine.run_cmd(
            str(num) + " genmove " + ("BLACK" if is_black else "WHITE"))
        num += 1
        match = re.search(pattern, result)
        print(match.group() if match is not None else "pass")
        # GTP vertices are case-insensitive, so accept "pass" in any case.
        passed = re.search("pass", result, re.IGNORECASE) is not None
        if is_black:
            black_pass = passed
        else:
            white_pass = passed
        # Push the position after this move into the rolling history so the
        # next input reflects the moves actually played. Store a copy so later
        # board updates cannot rewrite stored entries.
        history.append(np.array(game.board))
        game.show_board()
        # Move probabilities for this move, one row of size**2 + 1 entries
        # (presumably every board point plus pass - confirm in Game).
        prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
    print("Finished")
    print("\n")
    score = game.game_engine.executor_get_score(True)
    # A positive score is credited to Black; anything else (including 0)
    # is credited to White.
    winner = utils.BLACK if score > 0 else utils.WHITE
    state = np.concatenate(state, axis=0)
    prob = np.concatenate(prob, axis=0)
    # One outcome label per recorded position (num == number of moves).
    winner = np.ones([num, 1]) * winner
    assert state.shape[0] == prob.shape[0]
    assert state.shape[0] == winner.shape[0]
    np.savez(os.path.join(args.result_path, "game" + str(game_num)),
             state=state, prob=prob, winner=winner)

    # Reset per-game buffers, the engine's board and the history window.
    state = []
    prob = []
    winner = []
    num = 0
    black_pass = False
    white_pass = False
    engine.run_cmd(str(num) + " clear_board")
    history.clear()
    for _ in range(8):
        history.append(np.array(game.board))
    game_num += 1