combine gtp and network

Author: rtz19970824
Date:   2017-12-05 23:17:20 +08:00
Parent: 543d876f12
Commit: f9f63e6609
6 changed files with 205 additions and 50 deletions


@@ -8,6 +8,7 @@ from __future__ import print_function
 import utils
 import copy
 import tensorflow as tf
+import numpy as np
 from collections import deque
 import Network
@@ -185,7 +186,7 @@ class Game:
         self.board = [utils.EMPTY] * (self.size * self.size)
         self.strategy = strategy()
         # self.strategy = None
-        self.executor = Executor(game = self)
+        self.executor = Executor(game=self)
         self.history = []
         self.past = deque(maxlen=8)
         for i in range(8):
@@ -211,8 +212,8 @@ class Game:
     def set_komi(self, k):
         self.komi = k

-    def check_valid(self, vertex):
-        return True
+    def check_valid(self, color, vertex):
+        return self.executor.is_valid(color, vertex)

     def do_move(self, color, vertex):
         if vertex == utils.PASS:
@@ -224,7 +225,6 @@
         # move = self.strategy.gen_move(color)
         # return move
         move = self.strategy.gen_move(self.past, color)
-        print(move)
         self.do_move(color, move)
         return move
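Taken together, the hunks above make the GTP layer defer validity checks to the executor and move generation to the network strategy: check_valid(color, vertex) now forwards to executor.is_valid, and gen_move feeds the eight-position history (self.past) to strategy.gen_move before replaying the chosen move. A minimal sketch of the updated check_valid call site, assuming a utils.BLACK constant alongside the utils.WHITE used elsewhere in this code base:

    import utils
    # g is a Game instance; utils.BLACK is assumed here, mirroring utils.WHITE
    print(g.check_valid(utils.BLACK, (10, 9)))   # validity is now decided by the executor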

AlphaGo/self-play.py (new file, +26)

@@ -0,0 +1,26 @@
+from game import Game
+from engine import GTPEngine
+
+g = Game()
+g.show_board()
+e = GTPEngine(game_obj=g)
+
+num = 0
+black_pass = False
+white_pass = False
+
+while not (black_pass and white_pass):
+    if num % 2 == 0:
+        res = e.run_cmd("genmove BLACK")[0]
+        num += 1
+        print(res)
+        if res == (0, 0):
+            black_pass = True
+    else:
+        res = e.run_cmd("genmove WHITE")[0]
+        num += 1
+        print(res)
+        if res == (0, 0):
+            white_pass = True
+    g.show_board()
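self-play.py drives a full game by alternating genmove BLACK and genmove WHITE until both sides pass. The loop below is a compact sketch of the same idea, under the same assumptions the script makes: run_cmd("genmove ...")[0] returns the generated vertex, and (0, 0) denotes a pass.

    # Minimal self-play sketch (same assumptions as AlphaGo/self-play.py).
    def self_play(engine, game):
        passes = {"BLACK": False, "WHITE": False}
        color = "BLACK"
        while not all(passes.values()):
            vertex = engine.run_cmd("genmove " + color)[0]
            passes[color] = (vertex == (0, 0))   # (0, 0) is the pass vertex
            game.show_board()
            color = "WHITE" if color == "BLACK" else "BLACK"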


@@ -1,18 +1,22 @@
-import os,sys
+import os, sys
 sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
 import numpy as np
 import utils
 import time
+import copy
 import network_small
 import tensorflow as tf
 from collections import deque
 from tianshou.core.mcts.mcts import MCTS

+DELTA = [[1, 0], [-1, 0], [0, -1], [0, 1]]

 class GoEnv:
     def __init__(self, size=9, komi=6.5):
         self.size = size
-        self.komi = 6.5
+        self.komi = komi
         self.board = [utils.EMPTY] * (self.size * self.size)
         self.history = deque(maxlen=8)
@@ -102,13 +106,28 @@ class GoEnv:
         for b in block:
             self.board[self._flatten(b)] = utils.EMPTY

-    def is_valid(self, color, vertex):
+    # def is_valid(self, color, vertex):
+    def is_valid(self, state, action):
+        if action == 81:
+            vertex = (0, 0)
+        else:
+            vertex = (action / 9 + 1, action % 9 + 1)
+        if state[0, 0, 0, -1] == 1:
+            color = 1
+        else:
+            color = -1
+        self.history.clear()
+        for i in range(8):
+            self.history.append((state[:, :, :, i] - state[:, :, :, i + 8]).reshape(-1).tolist())
+        self.board = copy.copy(self.history[-1])
         ### in board
         if not self._in_board(vertex):
             return False

         ### already have stone
         if not self.board[self._flatten(vertex)] == utils.EMPTY:
+            # print(np.array(self.board).reshape(9, 9))
+            # print(vertex)
             return False

         ### check if it is qi
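The new is_valid(state, action) first reconstructs the board and the side to move from the 17-plane state tensor before running the geometric checks. Per data_process further down, planes 0-7 mark black (+1) stones over the last eight positions, planes 8-15 mark white (-1) stones, and plane 16 is all ones when black moves next. A sketch of that decoding (the function name is illustrative):

    import numpy as np

    def decode_state(state):
        # state: np.ndarray of shape [1, 9, 9, 17], laid out as in GoEnv
        color = 1 if state[0, 0, 0, -1] == 1 else -1   # plane 16: side to move
        board = state[0, :, :, 7] - state[0, :, :, 15] # +1 black, -1 white, 0 empty
        return color, board.reshape(-1).tolist()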
@@ -127,13 +146,12 @@ class GoEnv:
         id_ = self._flatten(vertex)
         if self.board[id_] == utils.EMPTY:
             self.board[id_] = color
-            self.history.append(self.board)
+            self.history.append(copy.copy(self.board))
             return True
         else:
             return False

     def step_forward(self, state, action):
-        # print(state)
         if state[0, 0, 0, -1] == 1:
             color = 1
         else:
@@ -141,7 +159,10 @@
         if action == 81:
             vertex = (0, 0)
         else:
-            vertex = (action / 9 + 1, action % 9)
+            vertex = (action % 9 + 1, action / 9 + 1)
+        # print(vertex)
+        # print(self.board)
+        self.board = (state[:, :, :, 7] - state[:, :, :, 15]).reshape(-1).tolist()
         self.do_move(color, vertex)
         new_state = np.concatenate(
             [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 9, 9, 1),
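step_forward now maps a flat action index to a 1-based vertex as (action % 9 + 1, action / 9 + 1), with action 81 reserved for pass, and rebuilds the board from the two most recent planes before applying the move. A sketch of that index/vertex convention and its inverse, assuming the usual row-major flattening (helper names are illustrative; integer division as in Python 2):

    def action_to_vertex(action):
        if action == 81:
            return (0, 0)                        # pass
        return (action % 9 + 1, action // 9 + 1)

    def vertex_to_action(vertex):
        if vertex == (0, 0):
            return 81
        x, y = vertex                            # both 1-based
        return (y - 1) * 9 + (x - 1)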
@@ -162,8 +183,8 @@ class strategy(object):
     def data_process(self, history, color):
         state = np.zeros([1, 9, 9, 17])
         for i in range(8):
-            state[0, :, :, i] = history[i] == 1
-            state[0, :, :, i + 8] = history[i] == -1
+            state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(81)).reshape(9, 9)
+            state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(81)).reshape(9, 9)
         if color == 1:
             state[0, :, :, 16] = np.ones([9, 9])
         if color == -1:
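data_process builds the 17-plane network input by comparing each of the eight stored flat boards against +1 and -1; the rewritten lines do this with explicit np.array comparisons and a reshape to 9x9. An equivalent, slightly more direct sketch of the same encoding (function name illustrative):

    import numpy as np

    def encode_history(history, color):
        # history: eight flat boards (length 81) with entries +1 / -1 / 0
        state = np.zeros([1, 9, 9, 17])
        for i, board in enumerate(history):
            board = np.array(board).reshape(9, 9)
            state[0, :, :, i] = (board == 1)         # black stones at step i
            state[0, :, :, i + 8] = (board == -1)    # white stones at step i
        state[0, :, :, 16] = 1 if color == 1 else 0  # side-to-move plane
        return state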
@@ -171,16 +192,15 @@ class strategy(object):
         return state

     def gen_move(self, history, color):
-        self.simulator.history = history
-        self.simulator.board = history[-1]
-        state = self.data_process(history, color)
-        prior = self.evaluator(state)[0]
-        mcts = MCTS(self.simulator, self.evaluator, state, 82, prior, inverse=True, max_step=100)
+        self.simulator.history = copy.copy(history)
+        self.simulator.board = copy.copy(history[-1])
+        state = self.data_process(self.simulator.history, color)
+        mcts = MCTS(self.simulator, self.evaluator, state, 82, inverse=True, max_step=10)
         temp = 1
         p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
         choice = np.random.choice(82, 1, p=p).tolist()[0]
         if choice == 81:
             move = (0, 0)
         else:
-            move = (choice / 9 + 1, choice % 9 + 1)
+            move = (choice % 9 + 1, choice / 9 + 1)
         return move
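gen_move now hands the raw state to MCTS (which computes its own prior from the evaluator, see the mcts.py hunks below) and samples a move from the root visit counts N with temperature 1, converting the sampled index back to a vertex with the same % 9 / / 9 convention as step_forward. A sketch of that selection step in isolation (function name illustrative):

    import numpy as np

    def sample_move(root_visits, temperature=1.0):
        # root_visits: length-82 array of visit counts at the MCTS root
        # (index 81 is the pass action), as in mcts.root.N above.
        n = np.asarray(root_visits, dtype=np.float64) ** temperature
        p = n / n.sum()
        choice = np.random.choice(82, p=p)
        if choice == 81:
            return (0, 0)                        # pass
        return (choice % 9 + 1, choice // 9 + 1)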


@@ -1,34 +1,134 @@
-# -*- coding: utf-8 -*-
-# vim:fenc=utf-8
-# $File: test.py
-# $Date: Fri Dec 01 01:37:22 2017 +0800
-# $Author: renyong15 © <mails.tsinghua.edu.cn>
-#
 from game import Game
 from engine import GTPEngine
 import utils

-g = Game()
-e = GTPEngine(game_obj=g)
-e.run_cmd("genmove BLACK")
+g = Game()
+e = GTPEngine(game_obj = g)
+res = e.run_cmd('1 protocol_version')
+print(e.known_commands)
+print(res)
+res = e.run_cmd('2 name')
+print(res)
+res = e.run_cmd('3 known_command quit')
+print(res)
+res = e.run_cmd('4 unknown_command quitagain')
+print(res)
+res = e.run_cmd('5 list_commands')
+print(res)
+res = e.run_cmd('6 komi 6')
+print(res)
+res = e.run_cmd('7 play BLACK C3')
+print(res)
+res = e.run_cmd('play BLACK C4')
+res = e.run_cmd('play BLACK C5')
+res = e.run_cmd('play BLACK C6')
+res = e.run_cmd('play BLACK D3')
+print(res)
+res = e.run_cmd('8 genmove BLACK')
+print(res)
+#g.show_board()
+print(g.check_valid((10, 9)))
+print(g.executor._neighbor((1,1)))
+print(g.do_move(utils.WHITE, (4, 6)))
+#g.show_board()
+res = e.run_cmd('play BLACK L10')
+res = e.run_cmd('play BLACK L11')
+res = e.run_cmd('play BLACK L12')
+res = e.run_cmd('play BLACK L13')
+res = e.run_cmd('play BLACK L14')
+res = e.run_cmd('play BLACK m15')
+res = e.run_cmd('play BLACK m9')
+res = e.run_cmd('play BLACK C9')
+res = e.run_cmd('play BLACK D9')
+res = e.run_cmd('play BLACK E9')
+res = e.run_cmd('play BLACK F9')
+res = e.run_cmd('play BLACK G9')
+res = e.run_cmd('play BLACK H9')
+res = e.run_cmd('play BLACK I9')
+res = e.run_cmd('play BLACK N9')
+res = e.run_cmd('play BLACK N15')
+res = e.run_cmd('play BLACK O10')
+res = e.run_cmd('play BLACK O11')
+res = e.run_cmd('play BLACK O12')
+res = e.run_cmd('play BLACK O13')
+res = e.run_cmd('play BLACK O14')
+res = e.run_cmd('play BLACK M12')
+res = e.run_cmd('play WHITE M10')
+res = e.run_cmd('play WHITE M11')
+res = e.run_cmd('play WHITE N10')
+res = e.run_cmd('play WHITE N11')
+res = e.run_cmd('play WHITE M13')
+res = e.run_cmd('play WHITE M14')
+res = e.run_cmd('play WHITE N13')
+res = e.run_cmd('play WHITE N14')
+print(res)
+res = e.run_cmd('play BLACK N12')
+print(res)
+#g.show_board()
+res = e.run_cmd('play BLACK P16')
+res = e.run_cmd('play BLACK P17')
+res = e.run_cmd('play BLACK P18')
+res = e.run_cmd('play BLACK P19')
+res = e.run_cmd('play BLACK Q16')
+res = e.run_cmd('play BLACK R16')
+res = e.run_cmd('play BLACK S16')
+res = e.run_cmd('play WHITE S18')
+res = e.run_cmd('play WHITE S17')
+res = e.run_cmd('play WHITE Q19')
+res = e.run_cmd('play WHITE Q18')
+res = e.run_cmd('play WHITE Q17')
+res = e.run_cmd('play WHITE R18')
+res = e.run_cmd('play WHITE R17')
+res = e.run_cmd('play BLACK S19')
+print(res)
+#g.show_board()
+res = e.run_cmd('play WHITE R19')
 g.show_board()
-e.run_cmd("genmove WHITE")
+res = e.run_cmd('play BLACK S19')
+print(res)
 g.show_board()
-e.run_cmd("genmove BLACK")
+res = e.run_cmd('play BLACK S19')
+print(res)
+res = e.run_cmd('play BLACK E17')
+res = e.run_cmd('play BLACK F16')
+res = e.run_cmd('play BLACK F18')
+res = e.run_cmd('play BLACK G17')
+res = e.run_cmd('play WHITE G16')
+res = e.run_cmd('play WHITE G18')
+res = e.run_cmd('play WHITE H17')
 g.show_board()
-e.run_cmd("genmove WHITE")
+res = e.run_cmd('play WHITE F17')
 g.show_board()
-e.run_cmd("genmove BLACK")
+res = e.run_cmd('play BLACK G17')
+print(res)
 g.show_board()
-e.run_cmd("genmove WHITE")
-g.show_board()
-e.run_cmd("genmove BLACK")
-g.show_board()
-e.run_cmd("genmove WHITE")
-g.show_board()
-e.run_cmd("genmove BLACK")
-g.show_board()
-e.run_cmd("genmove WHITE")
+res = e.run_cmd('play BLACK G19')
+res = e.run_cmd('play BLACK G17')
 g.show_board()


@@ -26,7 +26,7 @@ class MCTSNode(object):
         self.children = {}
         self.state = state
         self.action_num = action_num
-        self.prior = prior
+        self.prior = np.array(prior).reshape(-1)
         self.inverse = inverse

     def selection(self, simulator):
@@ -35,6 +35,8 @@ class MCTSNode(object):
     def backpropagation(self, action):
         raise NotImplementedError("Need to implement function backpropagation")

+    def valid_mask(self, simulator):
+        pass

 class UCTNode(MCTSNode):
     def __init__(self, parent, action, state, action_num, prior, inverse=False):
@@ -45,6 +47,7 @@ class UCTNode(MCTSNode):
         self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)

     def selection(self, simulator):
+        self.valid_mask(simulator)
         action = np.argmax(self.ucb)
         if action in self.children.keys():
             return self.children[action].selection(simulator)
@@ -66,6 +69,11 @@
         else:
             self.parent.backpropagation(self.children[action].reward)

+    def valid_mask(self, simulator):
+        for act in range(self.action_num - 1):
+            if not simulator.is_valid(self.state, act):
+                self.ucb[act] = -float("Inf")
+
 class TSNode(MCTSNode):
     def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
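UCTNode.selection now calls valid_mask before taking the argmax, so any action the simulator rejects gets a UCB score of minus infinity and can never be chosen; the last action (pass, index action_num - 1) is always left unmasked. A self-contained sketch of the same masking idea (names illustrative):

    import numpy as np

    def masked_argmax(ucb, is_valid):
        # is_valid: callable(action) -> bool; the final action (pass) stays available.
        ucb = np.array(ucb, dtype=np.float64)
        for act in range(len(ucb) - 1):
            if not is_valid(act):
                ucb[act] = -np.inf
        return int(np.argmax(ucb))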
@@ -78,7 +86,7 @@ class TSNode(MCTSNode):
         self.sigma = np.zeros([action_num])

-class ActionNode:
+class ActionNode(object):
     def __init__(self, parent, action):
         self.parent = parent
         self.action = action
@@ -120,18 +128,19 @@ class ActionNode:
                                 self.parent.inverse)
             return value
         else:
-            return 0
+            return 0.

     def backpropagation(self, value):
         self.reward += value
         self.parent.backpropagation(self.action)

-class MCTS:
-    def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", inverse=False, max_step=None,
+class MCTS(object):
+    def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False, max_step=None,
                  max_time=None):
         self.simulator = simulator
         self.evaluator = evaluator
+        prior, _ = self.evaluator(root)
         self.action_num = action_num
         if method == "":
             self.root = root


@@ -28,7 +28,7 @@ class TestEnv:
         if step == self.max_step:
             reward = int(np.random.uniform() < self.reward[num])
         else:
-            reward = 0
+            reward = 0.
         return new_state, reward
@@ -36,4 +36,4 @@ if __name__ == "__main__":
     env = TestEnv(2)
     rollout = rollout_policy(env, 2)
     evaluator = lambda state: rollout(state)
-    mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
+    mcts = MCTS(env, evaluator, [0, 0], 2, max_step=1e4)
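Because MCTS now evaluates the root itself via prior, _ = self.evaluator(root), the explicit prior argument is gone from the constructor and the evaluator is expected to return a (prior, value) pair. A hedged usage sketch with a trivial uniform evaluator standing in for the rollout policy above (env is the TestEnv(2) from the test; the evaluator name is illustrative):

    import numpy as np
    from tianshou.core.mcts.mcts import MCTS

    def uniform_evaluator(state):
        # The constructor now does `prior, _ = self.evaluator(root)`,
        # so the evaluator must return a (prior, value) pair.
        return np.ones(2) / 2.0, 0.0

    mcts = MCTS(env, uniform_evaluator, [0, 0], 2, max_step=1e4)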