Combine GTP and network

This commit is contained in:
rtz19970824 2017-12-05 23:17:20 +08:00
parent 543d876f12
commit f9f63e6609
6 changed files with 205 additions and 50 deletions

View File

@ -8,6 +8,7 @@ from __future__ import print_function
import utils
import copy
import tensorflow as tf
import numpy as np
from collections import deque
import Network
@ -185,7 +186,7 @@ class Game:
self.board = [utils.EMPTY] * (self.size * self.size)
self.strategy = strategy()
# self.strategy = None
self.executor = Executor(game = self)
self.executor = Executor(game=self)
self.history = []
self.past = deque(maxlen=8)
for i in range(8):
@ -211,8 +212,8 @@ class Game:
def set_komi(self, k):
self.komi = k
def check_valid(self, vertex):
return True
def check_valid(self, color, vertex):
return self.executor.is_valid(color, vertex)
def do_move(self, color, vertex):
if vertex == utils.PASS:
@ -224,7 +225,6 @@ class Game:
# move = self.strategy.gen_move(color)
# return move
move = self.strategy.gen_move(self.past, color)
print(move)
self.do_move(color, move)
return move
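
Note on the hunk above: Game.__init__ seeds self.past with eight empty frames so gen_move always receives a fixed-depth history. A minimal sketch of that rolling-window pattern (SIZE and EMPTY are assumed to match the board size and utils.EMPTY):

from collections import deque

SIZE, EMPTY = 9, 0  # assumed to match the board size and utils.EMPTY
past = deque(maxlen=8)
for _ in range(8):
    past.append([EMPTY] * (SIZE * SIZE))  # eight empty frames at start
# Appending after each move evicts the oldest frame, so `past` always
# holds exactly the last eight positions fed to the network.
past.append([EMPTY] * (SIZE * SIZE))
assert len(past) == 8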

AlphaGo/self-play.py (new file, 26 lines)
View File

@ -0,0 +1,26 @@
from game import Game
from engine import GTPEngine
g = Game()
g.show_board()
e = GTPEngine(game_obj=g)
num = 0
black_pass = False
white_pass = False
while not (black_pass and white_pass):
if num % 2 == 0:
res = e.run_cmd("genmove BLACK")[0]
num += 1
print(res)
if res == (0, 0):
black_pass = True
else:
res = e.run_cmd("genmove WHITE")[0]
num += 1
print(res)
if res == (0, 0):
white_pass = True
g.show_board()
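
The loop above alternates genmove commands and treats (0, 0) as the PASS sentinel. A slightly tighter sketch of the same loop (same Game/GTPEngine API; note this variant ends only on two consecutive passes, which is stricter than the one-way flags above):

from game import Game
from engine import GTPEngine

game = Game()
engine = GTPEngine(game_obj=game)
passed = {"BLACK": False, "WHITE": False}
color = "BLACK"
while not all(passed.values()):
    move = engine.run_cmd("genmove %s" % color)[0]
    print(color, move)
    passed[color] = (move == (0, 0))  # (0, 0) encodes PASS
    color = "WHITE" if color == "BLACK" else "BLACK"
game.show_board()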

View File

@ -1,18 +1,22 @@
import os,sys
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
import numpy as np
import utils
import time
import copy
import network_small
import tensorflow as tf
from collections import deque
from tianshou.core.mcts.mcts import MCTS
DELTA = [[1, 0], [-1, 0], [0, -1], [0, 1]]
class GoEnv:
def __init__(self, size=9, komi=6.5):
self.size = size
self.komi = 6.5
self.komi = komi
self.board = [utils.EMPTY] * (self.size * self.size)
self.history = deque(maxlen=8)
@ -102,13 +106,28 @@ class GoEnv:
for b in block:
self.board[self._flatten(b)] = utils.EMPTY
def is_valid(self, color, vertex):
# def is_valid(self, color, vertex):
def is_valid(self, state, action):
if action == 81:
vertex = (0, 0)
else:
vertex = (action / 9 + 1, action % 9 + 1)
if state[0, 0, 0, -1] == 1:
color = 1
else:
color = -1
self.history.clear()
for i in range(8):
self.history.append((state[:, :, :, i] - state[:, :, :, i + 8]).reshape(-1).tolist())
self.board = copy.copy(self.history[-1])
### in board
if not self._in_board(vertex):
return False
### already have stone
if not self.board[self._flatten(vertex)] == utils.EMPTY:
# print(np.array(self.board).reshape(9, 9))
# print(vertex)
return False
### check if it is qi
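
The new is_valid(state, action) signature decodes everything from the MCTS state tensor: action 81 is PASS, plane 16 encodes the side to move, and planes 0-7 / 8-15 hold the last eight black / white positions. A standalone sketch of that decoding (9x9 board assumed; // used so it also runs under Python 3):

import numpy as np

def decode(state, action):
    # Mirrors the decoding in is_valid above for a [1, 9, 9, 17] tensor.
    if action == 81:                       # action 81 encodes PASS
        vertex = (0, 0)
    else:
        vertex = (action // 9 + 1, action % 9 + 1)
    color = 1 if state[0, 0, 0, -1] == 1 else -1
    # Latest position: +1 for black stones, -1 for white stones.
    board = (state[:, :, :, 7] - state[:, :, :, 15]).reshape(-1).tolist()
    return color, vertex, board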
@ -127,13 +146,12 @@ class GoEnv:
id_ = self._flatten(vertex)
if self.board[id_] == utils.EMPTY:
self.board[id_] = color
self.history.append(self.board)
self.history.append(copy.copy(self.board))
return True
else:
return False
def step_forward(self, state, action):
# print(state)
if state[0, 0, 0, -1] == 1:
color = 1
else:
@ -141,7 +159,10 @@ class GoEnv:
if action == 81:
vertex = (0, 0)
else:
vertex = (action / 9 + 1, action % 9)
vertex = (action % 9 + 1, action / 9 + 1)
# print(vertex)
# print(self.board)
self.board = (state[:, :, :, 7] - state[:, :, :, 15]).reshape(-1).tolist()
self.do_move(color, vertex)
new_state = np.concatenate(
[state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 9, 9, 1),
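
This hunk flips the conversion to (action % 9 + 1, action / 9 + 1), and gen_move below now matches it, but is_valid above still uses the transposed order; if that is not intentional, centralizing the mapping would avoid the drift. A hedged pair of helpers (hypothetical names; integer division written // for Python 3 safety):

def action_to_vertex(action, size=9):
    # 1-based (x, y) in the order step_forward uses; action == size*size is PASS.
    if action == size * size:
        return (0, 0)
    return (action % size + 1, action // size + 1)

def vertex_to_action(vertex, size=9):
    if vertex == (0, 0):
        return size * size
    x, y = vertex
    return (y - 1) * size + (x - 1)

assert vertex_to_action(action_to_vertex(80)) == 80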
@ -162,8 +183,8 @@ class strategy(object):
def data_process(self, history, color):
state = np.zeros([1, 9, 9, 17])
for i in range(8):
state[0, :, :, i] = history[i] == 1
state[0, :, :, i + 8] = history[i] == -1
state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(81)).reshape(9, 9)
state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(81)).reshape(9, 9)
if color == 1:
state[0, :, :, 16] = np.ones([9, 9])
if color == -1:
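
The rewritten plane filling is equivalent to a plain boolean comparison per frame. A compact sketch of the same 17-plane encoding (NumPy assumed; the white-to-move branch is assumed to leave plane 16 at zero):

import numpy as np

def encode(history, color, size=9):
    # Mirrors data_process above: 8 black planes, 8 white planes, 1 colour plane.
    state = np.zeros([1, size, size, 17])
    for i in range(8):
        plane = np.array(history[i]).reshape(size, size)
        state[0, :, :, i] = (plane == 1)
        state[0, :, :, i + 8] = (plane == -1)
    if color == 1:
        state[0, :, :, 16] = 1
    return state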
@ -171,16 +192,15 @@ class strategy(object):
return state
def gen_move(self, history, color):
self.simulator.history = history
self.simulator.board = history[-1]
state = self.data_process(history, color)
prior = self.evaluator(state)[0]
mcts = MCTS(self.simulator, self.evaluator, state, 82, prior, inverse=True, max_step=100)
self.simulator.history = copy.copy(history)
self.simulator.board = copy.copy(history[-1])
state = self.data_process(self.simulator.history, color)
mcts = MCTS(self.simulator, self.evaluator, state, 82, inverse=True, max_step=10)
temp = 1
p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
choice = np.random.choice(82, 1, p=p).tolist()[0]
if choice == 81:
move = (0, 0)
else:
move = (choice / 9 + 1, choice % 9 + 1)
move = (choice % 9 + 1, choice / 9 + 1)
return move
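
gen_move now samples the move from MCTS visit counts with a temperature instead of taking an argmax. A small sketch of that sampling rule (the code above raises N to temp directly; the AlphaGo Zero convention is N**(1/tau), and the two coincide at temp = 1):

import numpy as np

def sample_action(visit_counts, temp=1.0):
    # pi(a) proportional to N(a)**temp, as in gen_move above.
    weights = np.asarray(visit_counts, dtype=np.float64) ** temp
    p = weights / weights.sum()
    return int(np.random.choice(len(p), p=p))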

View File

@ -1,34 +1,134 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: test.py
# $Date: Fri Dec 01 01:37:22 2017 +0800
# $Author: renyong15 <renyong15@mails.tsinghua.edu.cn>
#
from game import Game
from engine import GTPEngine
import utils
g = Game()
e = GTPEngine(game_obj=g)
e.run_cmd("genmove BLACK")
g = Game()
e = GTPEngine(game_obj = g)
res = e.run_cmd('1 protocol_version')
print(e.known_commands)
print(res)
res = e.run_cmd('2 name')
print(res)
res = e.run_cmd('3 known_command quit')
print(res)
res = e.run_cmd('4 unknown_command quitagain')
print(res)
res = e.run_cmd('5 list_commands')
print(res)
res = e.run_cmd('6 komi 6')
print(res)
res = e.run_cmd('7 play BLACK C3')
print(res)
res = e.run_cmd('play BLACK C4')
res = e.run_cmd('play BLACK C5')
res = e.run_cmd('play BLACK C6')
res = e.run_cmd('play BLACK D3')
print(res)
res = e.run_cmd('8 genmove BLACK')
print(res)
#g.show_board()
print(g.check_valid((10, 9)))
print(g.executor._neighbor((1,1)))
print(g.do_move(utils.WHITE, (4, 6)))
#g.show_board()
res = e.run_cmd('play BLACK L10')
res = e.run_cmd('play BLACK L11')
res = e.run_cmd('play BLACK L12')
res = e.run_cmd('play BLACK L13')
res = e.run_cmd('play BLACK L14')
res = e.run_cmd('play BLACK m15')
res = e.run_cmd('play BLACK m9')
res = e.run_cmd('play BLACK C9')
res = e.run_cmd('play BLACK D9')
res = e.run_cmd('play BLACK E9')
res = e.run_cmd('play BLACK F9')
res = e.run_cmd('play BLACK G9')
res = e.run_cmd('play BLACK H9')
res = e.run_cmd('play BLACK I9')
res = e.run_cmd('play BLACK N9')
res = e.run_cmd('play BLACK N15')
res = e.run_cmd('play BLACK O10')
res = e.run_cmd('play BLACK O11')
res = e.run_cmd('play BLACK O12')
res = e.run_cmd('play BLACK O13')
res = e.run_cmd('play BLACK O14')
res = e.run_cmd('play BLACK M12')
res = e.run_cmd('play WHITE M10')
res = e.run_cmd('play WHITE M11')
res = e.run_cmd('play WHITE N10')
res = e.run_cmd('play WHITE N11')
res = e.run_cmd('play WHITE M13')
res = e.run_cmd('play WHITE M14')
res = e.run_cmd('play WHITE N13')
res = e.run_cmd('play WHITE N14')
print(res)
res = e.run_cmd('play BLACK N12')
print(res)
#g.show_board()
res = e.run_cmd('play BLACK P16')
res = e.run_cmd('play BLACK P17')
res = e.run_cmd('play BLACK P18')
res = e.run_cmd('play BLACK P19')
res = e.run_cmd('play BLACK Q16')
res = e.run_cmd('play BLACK R16')
res = e.run_cmd('play BLACK S16')
res = e.run_cmd('play WHITE S18')
res = e.run_cmd('play WHITE S17')
res = e.run_cmd('play WHITE Q19')
res = e.run_cmd('play WHITE Q18')
res = e.run_cmd('play WHITE Q17')
res = e.run_cmd('play WHITE R18')
res = e.run_cmd('play WHITE R17')
res = e.run_cmd('play BLACK S19')
print(res)
#g.show_board()
res = e.run_cmd('play WHITE R19')
g.show_board()
e.run_cmd("genmove WHITE")
res = e.run_cmd('play BLACK S19')
print(res)
g.show_board()
e.run_cmd("genmove BLACK")
res = e.run_cmd('play BLACK S19')
print(res)
res = e.run_cmd('play BLACK E17')
res = e.run_cmd('play BLACK F16')
res = e.run_cmd('play BLACK F18')
res = e.run_cmd('play BLACK G17')
res = e.run_cmd('play WHITE G16')
res = e.run_cmd('play WHITE G18')
res = e.run_cmd('play WHITE H17')
g.show_board()
e.run_cmd("genmove WHITE")
res = e.run_cmd('play WHITE F17')
g.show_board()
e.run_cmd("genmove BLACK")
res = e.run_cmd('play BLACK G17')
print(res)
g.show_board()
e.run_cmd("genmove WHITE")
g.show_board()
e.run_cmd("genmove BLACK")
g.show_board()
e.run_cmd("genmove WHITE")
g.show_board()
e.run_cmd("genmove BLACK")
g.show_board()
e.run_cmd("genmove WHITE")
res = e.run_cmd('play BLACK G19')
res = e.run_cmd('play BLACK G17')
g.show_board()
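
The script above is one long sequence of run_cmd calls. A hedged helper that replays such a command list against a fresh engine and echoes the responses (same Game/GTPEngine API as above; the example sequence is hypothetical):

from game import Game
from engine import GTPEngine

def replay(commands, verbose=True):
    game = Game()
    engine = GTPEngine(game_obj=game)
    responses = [engine.run_cmd(cmd) for cmd in commands]
    if verbose:
        for cmd, res in zip(commands, responses):
            print(cmd, "->", res)
        game.show_board()
    return responses

replay(["komi 6", "play BLACK C3", "play WHITE D3", "genmove BLACK"])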

View File

@ -26,7 +26,7 @@ class MCTSNode(object):
self.children = {}
self.state = state
self.action_num = action_num
self.prior = prior
self.prior = np.array(prior).reshape(-1)
self.inverse = inverse
def selection(self, simulator):
@ -35,6 +35,8 @@ class MCTSNode(object):
def backpropagation(self, action):
raise NotImplementedError("Need to implement function backpropagation")
def valid_mask(self, simulator):
pass
class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
@ -45,6 +47,7 @@ class UCTNode(MCTSNode):
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
def selection(self, simulator):
self.valid_mask(simulator)
action = np.argmax(self.ucb)
if action in self.children.keys():
return self.children[action].selection(simulator)
@ -66,6 +69,11 @@ class UCTNode(MCTSNode):
else:
self.parent.backpropagation(self.children[action].reward)
def valid_mask(self, simulator):
for act in range(self.action_num - 1):
if not simulator.is_valid(self.state, act):
self.ucb[act] = -float("Inf")
class TSNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
@ -78,7 +86,7 @@ class TSNode(MCTSNode):
self.sigma = np.zeros([action_num])
class ActionNode:
class ActionNode(object):
def __init__(self, parent, action):
self.parent = parent
self.action = action
@ -120,18 +128,19 @@ class ActionNode:
self.parent.inverse)
return value
else:
return 0
return 0.
def backpropagation(self, value):
self.reward += value
self.parent.backpropagation(self.action)
class MCTS:
def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", inverse=False, max_step=None,
class MCTS(object):
def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False, max_step=None,
max_time=None):
self.simulator = simulator
self.evaluator = evaluator
prior, _ = self.evaluator(root)
self.action_num = action_num
if method == "":
self.root = root
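
Two behavioural changes in this file: valid_mask now disables illegal moves by setting their UCB scores to -inf before the argmax in selection, and MCTS no longer takes a prior argument, deriving it from the evaluator instead, which must therefore return a (prior, value) pair. A minimal evaluator stub matching that assumed contract (uniform prior, zero value):

import numpy as np

def uniform_evaluator(state, action_num=82):
    # Contract MCTS now expects: evaluator(state) -> (prior, value).
    prior = np.ones(action_num) / action_num
    return prior, 0.0

# mcts = MCTS(simulator, uniform_evaluator, root_state, 82, max_step=10)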

View File

@ -28,7 +28,7 @@ class TestEnv:
if step == self.max_step:
reward = int(np.random.uniform() < self.reward[num])
else:
reward = 0
reward = 0.
return new_state, reward
@ -36,4 +36,4 @@ if __name__ == "__main__":
env = TestEnv(2)
rollout = rollout_policy(env, 2)
evaluator = lambda state: rollout(state)
mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
mcts = MCTS(env, evaluator, [0, 0], 2, max_step=1e4)