combine gtp and network
parent 543d876f12
commit f9f63e6609
@@ -8,6 +8,7 @@ from __future__ import print_function
import utils
import copy
import tensorflow as tf
import numpy as np
from collections import deque

import Network
@@ -185,7 +186,7 @@ class Game:
        self.board = [utils.EMPTY] * (self.size * self.size)
        self.strategy = strategy()
        # self.strategy = None
        self.executor = Executor(game = self)
        self.executor = Executor(game=self)
        self.history = []
        self.past = deque(maxlen=8)
        for i in range(8):
@@ -211,8 +212,8 @@ class Game:
    def set_komi(self, k):
        self.komi = k

    def check_valid(self, vertex):
        return True
    def check_valid(self, color, vertex):
        return self.executor.is_valid(color, vertex)

    def do_move(self, color, vertex):
        if vertex == utils.PASS:
@@ -224,7 +225,6 @@ class Game:
        # move = self.strategy.gen_move(color)
        # return move
        move = self.strategy.gen_move(self.past, color)
        print(move)
        self.do_move(color, move)
        return move
26  AlphaGo/self-play.py  Normal file
@@ -0,0 +1,26 @@
from game import Game
from engine import GTPEngine

g = Game()


g.show_board()
e = GTPEngine(game_obj=g)

num = 0
black_pass = False
white_pass = False
while not (black_pass and white_pass):
    if num % 2 == 0:
        res = e.run_cmd("genmove BLACK")[0]
        num += 1
        print(res)
        if res == (0, 0):
            black_pass = True
    else:
        res = e.run_cmd("genmove WHITE")[0]
        num += 1
        print(res)
        if res == (0, 0):
            white_pass = True
g.show_board()
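The driver above alternates genmove between BLACK and WHITE and stops once both sides have passed, using the (0, 0) vertex as the pass sentinel. For reference, the same loop can be written around a single helper; this is an illustrative sketch (not part of the commit), assuming run_cmd("genmove <COLOR>")[0] keeps returning the chosen vertex:

from game import Game
from engine import GTPEngine

def self_play(max_moves=200):
    # Sketch of the loop above: (0, 0) is treated as a pass, and the game
    # ends once both BLACK and WHITE have passed.
    game = Game()
    engine = GTPEngine(game_obj=game)
    passed = {"BLACK": False, "WHITE": False}
    for turn in range(max_moves):
        color = "BLACK" if turn % 2 == 0 else "WHITE"
        move = engine.run_cmd("genmove " + color)[0]
        print(color, move)
        if move == (0, 0):
            passed[color] = True
        if passed["BLACK"] and passed["WHITE"]:
            break
    game.show_board()

if __name__ == "__main__":
    self_play()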
@@ -1,25 +1,29 @@
import os,sys
import os, sys

sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
import numpy as np
import utils
import time
import copy
import network_small
import tensorflow as tf
from collections import deque
from tianshou.core.mcts.mcts import MCTS

DELTA = [[1, 0], [-1, 0], [0, -1], [0, 1]]


class GoEnv:
    def __init__(self, size=9, komi=6.5):
        self.size = size
        self.komi = 6.5
        self.komi = komi
        self.board = [utils.EMPTY] * (self.size * self.size)
        self.history = deque(maxlen=8)

    def _flatten(self, vertex):
        x, y = vertex
        return (x - 1) * self.size + (y - 1)


    def _bfs(self, vertex, color, block, status, alive_break):
        block.append(vertex)
        status[self._flatten(vertex)] = True
@@ -102,13 +106,28 @@ class GoEnv:
        for b in block:
            self.board[self._flatten(b)] = utils.EMPTY

    def is_valid(self, color, vertex):
    # def is_valid(self, color, vertex):
    def is_valid(self, state, action):
        if action == 81:
            vertex = (0, 0)
        else:
            vertex = (action / 9 + 1, action % 9 + 1)
        if state[0, 0, 0, -1] == 1:
            color = 1
        else:
            color = -1
        self.history.clear()
        for i in range(8):
            self.history.append((state[:, :, :, i] - state[:, :, :, i + 8]).reshape(-1).tolist())
        self.board = copy.copy(self.history[-1])
        ### in board
        if not self._in_board(vertex):
            return False

        ### already have stone
        if not self.board[self._flatten(vertex)] == utils.EMPTY:
            # print(np.array(self.board).reshape(9, 9))
            # print(vertex)
            return False

        ### check if it is qi
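The reworked is_valid now takes the MCTS-facing (state, action) pair instead of (color, vertex): action 81 is the pass, the colour to move is read off the constant last feature plane, and the board plus 8-step history are rebuilt from the occupancy planes. As a standalone illustration (a sketch, not code from the commit), the plane decoding for the [1, 9, 9, 17] state layout used here looks roughly like this:

import numpy as np

def decode_state(state, size=9):
    # Planes 0..7: black stones of the last 8 boards; planes 8..15: white stones;
    # plane 16: all ones when black is to move. Their difference gives +1/-1/0 boards.
    color = 1 if state[0, 0, 0, -1] == 1 else -1
    history = [(state[:, :, :, i] - state[:, :, :, i + 8]).reshape(-1).tolist()
               for i in range(8)]
    board = list(history[-1])  # most recent position
    return color, history, board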
@@ -127,13 +146,12 @@ class GoEnv:
        id_ = self._flatten(vertex)
        if self.board[id_] == utils.EMPTY:
            self.board[id_] = color
            self.history.append(self.board)
            self.history.append(copy.copy(self.board))
            return True
        else:
            return False

    def step_forward(self, state, action):
        # print(state)
        if state[0, 0, 0, -1] == 1:
            color = 1
        else:
@@ -141,7 +159,10 @@ class GoEnv:
        if action == 81:
            vertex = (0, 0)
        else:
            vertex = (action / 9 + 1, action % 9)
            vertex = (action % 9 + 1, action / 9 + 1)
        # print(vertex)
        # print(self.board)
        self.board = (state[:, :, :, 7] - state[:, :, :, 15]).reshape(-1).tolist()
        self.do_move(color, vertex)
        new_state = np.concatenate(
            [state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 9, 9, 1),
@@ -162,8 +183,8 @@ class strategy(object):
    def data_process(self, history, color):
        state = np.zeros([1, 9, 9, 17])
        for i in range(8):
            state[0, :, :, i] = history[i] == 1
            state[0, :, :, i + 8] = history[i] == -1
            state[0, :, :, i] = np.array(np.array(history[i]) == np.ones(81)).reshape(9, 9)
            state[0, :, :, i + 8] = np.array(np.array(history[i]) == -np.ones(81)).reshape(9, 9)
        if color == 1:
            state[0, :, :, 16] = np.ones([9, 9])
        if color == -1:
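data_process is the inverse direction: it packs the 8-board history into the 17-plane tensor the network consumes, with planes 0-7 holding black occupancies, planes 8-15 white occupancies, and plane 16 a constant colour-to-move plane. A compact sketch under the same assumptions (each history entry a flat length-81 board with +1 black, -1 white), not taken from the commit:

import numpy as np

def encode_history(history, color, size=9):
    # history: iterable of 8 flat boards (length size*size), +1 black / -1 white / 0 empty
    # color:   +1 if black is to move, -1 if white is to move
    state = np.zeros([1, size, size, 17])
    for i, board in enumerate(history):
        board = np.array(board).reshape(size, size)
        state[0, :, :, i] = (board == 1)        # black stones, i-th past board
        state[0, :, :, i + 8] = (board == -1)   # white stones, i-th past board
    if color == 1:
        state[0, :, :, 16] = 1.0                # colour plane: all ones for black
    return state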
@@ -171,16 +192,15 @@ class strategy(object):
        return state

    def gen_move(self, history, color):
        self.simulator.history = history
        self.simulator.board = history[-1]
        state = self.data_process(history, color)
        prior = self.evaluator(state)[0]
        mcts = MCTS(self.simulator, self.evaluator, state, 82, prior, inverse=True, max_step=100)
        self.simulator.history = copy.copy(history)
        self.simulator.board = copy.copy(history[-1])
        state = self.data_process(self.simulator.history, color)
        mcts = MCTS(self.simulator, self.evaluator, state, 82, inverse=True, max_step=10)
        temp = 1
        p = mcts.root.N ** temp / np.sum(mcts.root.N ** temp)
        choice = np.random.choice(82, 1, p=p).tolist()[0]
        if choice == 81:
            move = (0, 0)
        else:
            move = (choice / 9 + 1, choice % 9 + 1)
            move = (choice % 9 + 1, choice / 9 + 1)
        return move
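Action indices 0-80 address the 81 points of the 9x9 board and index 81 is the pass move; the move handed back to GTP is a 1-based (x, y) vertex. A small sketch of the mapping in both directions, assuming the column-then-row convention of the new gen_move line (the code above relies on Python 2 integer /; under Python 3 it would need //):

def action_to_vertex(action, size=9):
    # 0..size*size-1 -> 1-based (x, y); size*size is the pass action
    if action == size * size:
        return (0, 0)                       # pass sentinel used throughout the diff
    return (action % size + 1, action // size + 1)

def vertex_to_action(vertex, size=9):
    if vertex == (0, 0):
        return size * size
    x, y = vertex
    return (y - 1) * size + (x - 1)

assert vertex_to_action(action_to_vertex(40)) == 40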
148  AlphaGo/test.py
@@ -1,34 +1,134 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: test.py
# $Date: Fri Dec 01 01:37:22 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
from game import Game
from engine import GTPEngine
import utils

g = Game()
e = GTPEngine(game_obj=g)

e.run_cmd("genmove BLACK")

g = Game()
e = GTPEngine(game_obj = g)
res = e.run_cmd('1 protocol_version')
print(e.known_commands)
print(res)

res = e.run_cmd('2 name')
print(res)

res = e.run_cmd('3 known_command quit')
print(res)

res = e.run_cmd('4 unknown_command quitagain')
print(res)

res = e.run_cmd('5 list_commands')
print(res)

res = e.run_cmd('6 komi 6')
print(res)

res = e.run_cmd('7 play BLACK C3')
print(res)

res = e.run_cmd('play BLACK C4')
res = e.run_cmd('play BLACK C5')
res = e.run_cmd('play BLACK C6')
res = e.run_cmd('play BLACK D3')
print(res)

res = e.run_cmd('8 genmove BLACK')
print(res)

#g.show_board()
print(g.check_valid((10, 9)))
print(g.executor._neighbor((1,1)))
print(g.do_move(utils.WHITE, (4, 6)))
#g.show_board()


res = e.run_cmd('play BLACK L10')
res = e.run_cmd('play BLACK L11')
res = e.run_cmd('play BLACK L12')
res = e.run_cmd('play BLACK L13')
res = e.run_cmd('play BLACK L14')
res = e.run_cmd('play BLACK m15')
res = e.run_cmd('play BLACK m9')
res = e.run_cmd('play BLACK C9')
res = e.run_cmd('play BLACK D9')
res = e.run_cmd('play BLACK E9')
res = e.run_cmd('play BLACK F9')
res = e.run_cmd('play BLACK G9')
res = e.run_cmd('play BLACK H9')
res = e.run_cmd('play BLACK I9')

res = e.run_cmd('play BLACK N9')
res = e.run_cmd('play BLACK N15')
res = e.run_cmd('play BLACK O10')
res = e.run_cmd('play BLACK O11')
res = e.run_cmd('play BLACK O12')
res = e.run_cmd('play BLACK O13')
res = e.run_cmd('play BLACK O14')
res = e.run_cmd('play BLACK M12')

res = e.run_cmd('play WHITE M10')
res = e.run_cmd('play WHITE M11')
res = e.run_cmd('play WHITE N10')
res = e.run_cmd('play WHITE N11')

res = e.run_cmd('play WHITE M13')
res = e.run_cmd('play WHITE M14')
res = e.run_cmd('play WHITE N13')
res = e.run_cmd('play WHITE N14')
print(res)

res = e.run_cmd('play BLACK N12')
print(res)
#g.show_board()

res = e.run_cmd('play BLACK P16')
res = e.run_cmd('play BLACK P17')
res = e.run_cmd('play BLACK P18')
res = e.run_cmd('play BLACK P19')
res = e.run_cmd('play BLACK Q16')
res = e.run_cmd('play BLACK R16')
res = e.run_cmd('play BLACK S16')

res = e.run_cmd('play WHITE S18')
res = e.run_cmd('play WHITE S17')
res = e.run_cmd('play WHITE Q19')
res = e.run_cmd('play WHITE Q18')
res = e.run_cmd('play WHITE Q17')
res = e.run_cmd('play WHITE R18')
res = e.run_cmd('play WHITE R17')
res = e.run_cmd('play BLACK S19')
print(res)
#g.show_board()

res = e.run_cmd('play WHITE R19')
g.show_board()
e.run_cmd("genmove WHITE")

res = e.run_cmd('play BLACK S19')
print(res)
g.show_board()
e.run_cmd("genmove BLACK")

res = e.run_cmd('play BLACK S19')
print(res)


res = e.run_cmd('play BLACK E17')
res = e.run_cmd('play BLACK F16')
res = e.run_cmd('play BLACK F18')
res = e.run_cmd('play BLACK G17')
res = e.run_cmd('play WHITE G16')
res = e.run_cmd('play WHITE G18')
res = e.run_cmd('play WHITE H17')
g.show_board()
e.run_cmd("genmove WHITE")

res = e.run_cmd('play WHITE F17')
g.show_board()
e.run_cmd("genmove BLACK")
g.show_board()
e.run_cmd("genmove WHITE")
g.show_board()
e.run_cmd("genmove BLACK")
g.show_board()
e.run_cmd("genmove WHITE")
g.show_board()
e.run_cmd("genmove BLACK")
g.show_board()
e.run_cmd("genmove WHITE")

res = e.run_cmd('play BLACK G17')
print(res)
g.show_board()

res = e.run_cmd('play BLACK G19')
res = e.run_cmd('play BLACK G17')
g.show_board()
@@ -26,7 +26,7 @@ class MCTSNode(object):
        self.children = {}
        self.state = state
        self.action_num = action_num
        self.prior = prior
        self.prior = np.array(prior).reshape(-1)
        self.inverse = inverse

    def selection(self, simulator):
@@ -35,6 +35,8 @@ class MCTSNode(object):
    def backpropagation(self, action):
        raise NotImplementedError("Need to implement function backpropagation")

    def valid_mask(self, simulator):
        pass

class UCTNode(MCTSNode):
    def __init__(self, parent, action, state, action_num, prior, inverse=False):
@@ -45,6 +47,7 @@ class UCTNode(MCTSNode):
        self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)

    def selection(self, simulator):
        self.valid_mask(simulator)
        action = np.argmax(self.ucb)
        if action in self.children.keys():
            return self.children[action].selection(simulator)
@@ -66,6 +69,11 @@ class UCTNode(MCTSNode):
        else:
            self.parent.backpropagation(self.children[action].reward)

    def valid_mask(self, simulator):
        for act in range(self.action_num - 1):
            if not simulator.is_valid(self.state, act):
                self.ucb[act] = -float("Inf")


class TSNode(MCTSNode):
    def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
@@ -78,7 +86,7 @@ class TSNode(MCTSNode):
        self.sigma = np.zeros([action_num])


class ActionNode:
class ActionNode(object):
    def __init__(self, parent, action):
        self.parent = parent
        self.action = action
@@ -120,18 +128,19 @@ class ActionNode:
                self.parent.inverse)
            return value
        else:
            return 0
            return 0.

    def backpropagation(self, value):
        self.reward += value
        self.parent.backpropagation(self.action)


class MCTS:
    def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", inverse=False, max_step=None,
class MCTS(object):
    def __init__(self, simulator, evaluator, root, action_num, method="UCT", inverse=False, max_step=None,
                 max_time=None):
        self.simulator = simulator
        self.evaluator = evaluator
        prior, _ = self.evaluator(root)
        self.action_num = action_num
        if method == "":
            self.root = root
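With this change the prior argument is dropped from MCTS.__init__; the constructor now queries the evaluator itself via prior, _ = self.evaluator(root), so the evaluator is expected to return a (prior, value) pair for the root state. A minimal sketch of the new calling convention (the uniform-prior evaluator below is only illustrative, not part of the commit):

import numpy as np

ACTION_NUM = 82  # 81 points on a 9x9 board plus the pass action

def uniform_evaluator(state):
    # The evaluator must return a (prior over actions, value estimate) pair;
    # MCTS now calls it on the root instead of receiving the prior directly.
    return np.ones(ACTION_NUM) / ACTION_NUM, 0.0

# Usage as in strategy.gen_move above, where the simulator exposes
# is_valid(state, action) and step_forward(state, action):
#   mcts = MCTS(simulator, uniform_evaluator, state, ACTION_NUM,
#               inverse=True, max_step=10)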
@@ -28,7 +28,7 @@ class TestEnv:
        if step == self.max_step:
            reward = int(np.random.uniform() < self.reward[num])
        else:
            reward = 0
            reward = 0.
        return new_state, reward
@@ -36,4 +36,4 @@ if __name__ == "__main__":
    env = TestEnv(2)
    rollout = rollout_policy(env, 2)
    evaluator = lambda state: rollout(state)
    mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
    mcts = MCTS(env, evaluator, [0, 0], 2, max_step=1e4)