Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
86bf94fde1
4
.gitignore
vendored
4
.gitignore
vendored
@ -4,8 +4,8 @@ leela-zero
|
|||||||
parameters
|
parameters
|
||||||
*.swp
|
*.swp
|
||||||
*.sublime*
|
*.sublime*
|
||||||
checkpoints
|
checkpoint
|
||||||
checkpoints_origin
|
|
||||||
*.json
|
*.json
|
||||||
.DS_Store
|
.DS_Store
|
||||||
data
|
data
|
||||||
|
.log
|
||||||
|
@ -183,7 +183,7 @@ class GTPEngine():
|
|||||||
return 'unknown player', False
|
return 'unknown player', False
|
||||||
|
|
||||||
def cmd_get_score(self, args, **kwargs):
|
def cmd_get_score(self, args, **kwargs):
|
||||||
return self._game.game_engine.executor_get_score(self._game.board, True), True
|
return self._game.game_engine.executor_get_score(self._game.board), True
|
||||||
|
|
||||||
def cmd_show_board(self, args, **kwargs):
|
def cmd_show_board(self, args, **kwargs):
|
||||||
return self._game.board, True
|
return self._game.board, True
|
||||||
@ -194,4 +194,4 @@ class GTPEngine():
|
|||||||
|
|
||||||
if __name__ == "main":
|
if __name__ == "main":
|
||||||
game = Game()
|
game = Game()
|
||||||
engine = GTPEngine(game_obj=Game)
|
engine = GTPEngine(game_obj=game)
|
||||||
|
@ -10,12 +10,14 @@ import copy
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sys, os
|
import sys, os
|
||||||
import go
|
|
||||||
import model
|
import model
|
||||||
from collections import deque
|
from collections import deque
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
||||||
from tianshou.core.mcts.mcts import MCTS
|
from tianshou.core.mcts.mcts import MCTS
|
||||||
|
|
||||||
|
import go
|
||||||
|
import reversi
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
'''
|
'''
|
||||||
Load the real game and trained weights.
|
Load the real game and trained weights.
|
||||||
@ -23,23 +25,32 @@ class Game:
|
|||||||
TODO : Maybe merge with the engine class in future,
|
TODO : Maybe merge with the engine class in future,
|
||||||
currently leave it untouched for interacting with Go UI.
|
currently leave it untouched for interacting with Go UI.
|
||||||
'''
|
'''
|
||||||
def __init__(self, size=9, komi=3.75, checkpoint_path=None):
|
def __init__(self, name="go", checkpoint_path=None):
|
||||||
self.size = size
|
self.name = name
|
||||||
self.komi = komi
|
if self.name == "go":
|
||||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
self.size = 9
|
||||||
self.history = []
|
self.komi = 3.75
|
||||||
self.latest_boards = deque(maxlen=8)
|
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||||
for _ in range(8):
|
self.history = []
|
||||||
self.latest_boards.append(self.board)
|
self.history_length = 8
|
||||||
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
|
self.latest_boards = deque(maxlen=8)
|
||||||
# self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
|
for _ in range(8):
|
||||||
# feed_dict={self.net.x: state, self.net.is_training: False})
|
self.latest_boards.append(self.board)
|
||||||
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
||||||
|
elif self.name == "reversi":
|
||||||
|
self.size = 8
|
||||||
|
self.history_length = 1
|
||||||
|
self.game_engine = reversi.Reversi()
|
||||||
|
self.board = self.game_engine.get_board()
|
||||||
|
else:
|
||||||
|
raise ValueError(name + " is an unknown game...")
|
||||||
|
|
||||||
|
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||||
self.history = []
|
self.history = []
|
||||||
for _ in range(8):
|
for _ in range(self.history_length):
|
||||||
self.latest_boards.append(self.board)
|
self.latest_boards.append(self.board)
|
||||||
|
|
||||||
def set_size(self, n):
|
def set_size(self, n):
|
||||||
@ -65,7 +76,11 @@ class Game:
|
|||||||
# this function can be called directly to play the opponent's move
|
# this function can be called directly to play the opponent's move
|
||||||
if vertex == utils.PASS:
|
if vertex == utils.PASS:
|
||||||
return True
|
return True
|
||||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
# TODO this implementation is not very elegant
|
||||||
|
if self.name == "go":
|
||||||
|
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||||
|
elif self.name == "reversi":
|
||||||
|
res = self.game_engine.executor_do_move(self.board, color, vertex)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def think_play_move(self, color):
|
def think_play_move(self, color):
|
||||||
|
@ -157,7 +157,7 @@ class Go:
|
|||||||
vertex = self._deflatten(action)
|
vertex = self._deflatten(action)
|
||||||
return vertex
|
return vertex
|
||||||
|
|
||||||
def _is_valid(self, history_boards, current_board, color, vertex):
|
def _rule_check(self, history_boards, current_board, color, vertex):
|
||||||
### in board
|
### in board
|
||||||
if not self._in_board(vertex):
|
if not self._in_board(vertex):
|
||||||
return False
|
return False
|
||||||
@ -176,30 +176,30 @@ class Go:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def simulate_is_valid(self, state, action):
|
def _is_valid(self, state, action):
|
||||||
history_boards, color = state
|
history_boards, color = state
|
||||||
vertex = self._action2vertex(action)
|
vertex = self._action2vertex(action)
|
||||||
current_board = history_boards[-1]
|
current_board = history_boards[-1]
|
||||||
|
|
||||||
if not self._is_valid(history_boards, current_board, color, vertex):
|
if not self._rule_check(history_boards, current_board, color, vertex):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not self._knowledge_prunning(current_board, color, vertex):
|
if not self._knowledge_prunning(current_board, color, vertex):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def simulate_is_valid_list(self, state, action_set):
|
def simulate_get_mask(self, state, action_set):
|
||||||
# find all the invalid actions
|
# find all the invalid actions
|
||||||
invalid_action_list = []
|
invalid_action_mask = []
|
||||||
for action_candidate in action_set[:-1]:
|
for action_candidate in action_set[:-1]:
|
||||||
# go through all the actions excluding pass
|
# go through all the actions excluding pass
|
||||||
if not self.simulate_is_valid(state, action_candidate):
|
if not self._is_valid(state, action_candidate):
|
||||||
invalid_action_list.append(action_candidate)
|
invalid_action_mask.append(action_candidate)
|
||||||
if len(invalid_action_list) < len(action_set) - 1:
|
if len(invalid_action_mask) < len(action_set) - 1:
|
||||||
invalid_action_list.append(action_set[-1])
|
invalid_action_mask.append(action_set[-1])
|
||||||
# forbid pass, if we have other choices
|
# forbid pass, if we have other choices
|
||||||
# TODO: In fact we should not do this. In some extreme cases, we should permit pass.
|
# TODO: In fact we should not do this. In some extreme cases, we should permit pass.
|
||||||
return invalid_action_list
|
return invalid_action_mask
|
||||||
|
|
||||||
def _do_move(self, board, color, vertex):
|
def _do_move(self, board, color, vertex):
|
||||||
if vertex == utils.PASS:
|
if vertex == utils.PASS:
|
||||||
@ -219,7 +219,7 @@ class Go:
|
|||||||
return [history_boards, new_color], 0
|
return [history_boards, new_color], 0
|
||||||
|
|
||||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||||
if not self._is_valid(history, current_board, color, vertex):
|
if not self._rule_check(history, current_board, color, vertex):
|
||||||
return False
|
return False
|
||||||
current_board[self._flatten(vertex)] = color
|
current_board[self._flatten(vertex)] = color
|
||||||
self._process_board(current_board, color, vertex)
|
self._process_board(current_board, color, vertex)
|
||||||
@ -280,7 +280,7 @@ class Go:
|
|||||||
elif color_estimate < 0:
|
elif color_estimate < 0:
|
||||||
return utils.WHITE
|
return utils.WHITE
|
||||||
|
|
||||||
def executor_get_score(self, current_board, is_unknown_estimation=False):
|
def executor_get_score(self, current_board):
|
||||||
'''
|
'''
|
||||||
is_unknown_estimation: whether use nearby stone to predict the unknown
|
is_unknown_estimation: whether use nearby stone to predict the unknown
|
||||||
return score from BLACK perspective.
|
return score from BLACK perspective.
|
||||||
@ -294,10 +294,8 @@ class Go:
|
|||||||
_board[self._flatten(vertex)] = utils.BLACK
|
_board[self._flatten(vertex)] = utils.BLACK
|
||||||
elif boarder_color == {utils.WHITE}:
|
elif boarder_color == {utils.WHITE}:
|
||||||
_board[self._flatten(vertex)] = utils.WHITE
|
_board[self._flatten(vertex)] = utils.WHITE
|
||||||
elif is_unknown_estimation:
|
|
||||||
_board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex)
|
|
||||||
else:
|
else:
|
||||||
_board[self._flatten(vertex)] =utils.UNKNOWN
|
_board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex)
|
||||||
score = 0
|
score = 0
|
||||||
for i in _board:
|
for i in _board:
|
||||||
if i == utils.BLACK:
|
if i == utils.BLACK:
|
||||||
@ -308,3 +306,42 @@ class Go:
|
|||||||
|
|
||||||
return score
|
return score
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
### do unit test for Go class
|
||||||
|
pure_test = [
|
||||||
|
0, 1, 0, 1, 0, 1, 0, 0, 0,
|
||||||
|
1, 0, 1, 0, 1, 0, 0, 0, 0,
|
||||||
|
0, 1, 0, 1, 0, 0, 1, 0, 0,
|
||||||
|
0, 0, 1, 0, 0, 1, 0, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 1, 1, 1, 0,
|
||||||
|
1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||||
|
1, 0, 1, 0, 0, 1, 1, 0, 0,
|
||||||
|
1, 1, 1, 0, 1, 0, 1, 0, 0,
|
||||||
|
0, 0, 0, 0, 1, 1, 1, 0, 0
|
||||||
|
]
|
||||||
|
|
||||||
|
pt_qry = [(1, 1), (1, 5), (3, 3), (4, 7), (7, 2), (8, 6)]
|
||||||
|
pt_ans = [True, True, True, True, True, True]
|
||||||
|
|
||||||
|
opponent_test = [
|
||||||
|
0, 1, 0, 1, 0, 1, 0,-1, 1,
|
||||||
|
1,-1, 0,-1, 1,-1, 0, 1, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 1,
|
||||||
|
1, 1,-1, 0, 1,-1, 1, 0, 0,
|
||||||
|
1, 0, 1, 0, 1, 0, 1, 0, 0,
|
||||||
|
-1,1, 1, 0, 1, 1, 1, 0, 0,
|
||||||
|
0, 1,-1, 0,-1,-1,-1, 0, 0,
|
||||||
|
1, 0, 1, 0,-1, 0,-1, 0, 0,
|
||||||
|
0, 1, 0, 0,-1,-1,-1, 0, 0
|
||||||
|
]
|
||||||
|
ot_qry = [(1, 1), (1, 5), (2, 9), (5, 2), (5, 6), (8, 6), (8, 2)]
|
||||||
|
ot_ans = [False, False, False, False, False, False, True]
|
||||||
|
|
||||||
|
go = Go(size=9, komi=3.75)
|
||||||
|
for i in range(6):
|
||||||
|
print (go._is_eye(pure_test, utils.BLACK, pt_qry[i]))
|
||||||
|
print("Test of pure eye\n")
|
||||||
|
|
||||||
|
for i in range(7):
|
||||||
|
print (go._is_eye(opponent_test, utils.BLACK, ot_qry[i]))
|
||||||
|
print("Test of eye surrend by opponents\n")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import random
|
||||||
import sys
|
import sys
|
||||||
import cPickle
|
import cPickle
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@ -104,7 +105,7 @@ class ResNet(object):
|
|||||||
self.window_length = 7000
|
self.window_length = 7000
|
||||||
self.save_freq = 5000
|
self.save_freq = 5000
|
||||||
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
self.training_data = {'states': deque(maxlen=self.window_length), 'probs': deque(maxlen=self.window_length),
|
||||||
'winner': deque(maxlen=self.window_length)}
|
'winner': deque(maxlen=self.window_length), 'length': deque(maxlen=self.window_length)}
|
||||||
|
|
||||||
def _build_network(self, residual_block_num, checkpoint_path):
|
def _build_network(self, residual_block_num, checkpoint_path):
|
||||||
"""
|
"""
|
||||||
@ -199,15 +200,15 @@ class ResNet(object):
|
|||||||
|
|
||||||
new_file_list = []
|
new_file_list = []
|
||||||
all_file_list = []
|
all_file_list = []
|
||||||
training_data = {}
|
training_data = {'states': [], 'probs': [], 'winner': []}
|
||||||
|
|
||||||
iters = 0
|
iters = 0
|
||||||
while True:
|
while True:
|
||||||
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
||||||
if new_file_list:
|
while new_file_list:
|
||||||
all_file_list = os.listdir(data_path)
|
all_file_list = os.listdir(data_path)
|
||||||
new_file_list.sort(
|
new_file_list.sort(
|
||||||
key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
|
key=lambda file: os.path.getmtime(data_path + file) if not os.path.isdir(data_path + file) else 0)
|
||||||
if new_file_list:
|
|
||||||
for file in new_file_list:
|
for file in new_file_list:
|
||||||
states, probs, winner = self._file_to_training_data(data_path + file)
|
states, probs, winner = self._file_to_training_data(data_path + file)
|
||||||
assert states.shape[0] == probs.shape[0]
|
assert states.shape[0] == probs.shape[0]
|
||||||
@ -215,32 +216,36 @@ class ResNet(object):
|
|||||||
self.training_data['states'].append(states)
|
self.training_data['states'].append(states)
|
||||||
self.training_data['probs'].append(probs)
|
self.training_data['probs'].append(probs)
|
||||||
self.training_data['winner'].append(winner)
|
self.training_data['winner'].append(winner)
|
||||||
if len(self.training_data['states']) == self.window_length:
|
self.training_data['length'].append(states.shape[0])
|
||||||
training_data['states'] = np.concatenate(self.training_data['states'], axis=0)
|
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
||||||
training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0)
|
|
||||||
training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0)
|
|
||||||
|
|
||||||
if len(self.training_data['states']) != self.window_length:
|
if len(self.training_data['states']) != self.window_length:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
data_num = training_data['states'].shape[0]
|
|
||||||
index = np.arange(data_num)
|
|
||||||
np.random.shuffle(index)
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
for i in range(batch_size):
|
||||||
|
game_num = random.randint(0, self.window_length-1)
|
||||||
|
state_num = random.randint(0, self.training_data['length'][game_num]-1)
|
||||||
|
training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
|
||||||
|
training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
|
||||||
|
training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
|
||||||
value_loss, policy_loss, reg, _ = self.sess.run(
|
value_loss, policy_loss, reg, _ = self.sess.run(
|
||||||
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
||||||
feed_dict={self.x: training_data['states'][index[:batch_size]],
|
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
||||||
self.z: training_data['winner'][index[:batch_size]],
|
self.z: np.concatenate(training_data['winner'], axis=0),
|
||||||
self.pi: training_data['probs'][index[:batch_size]],
|
self.pi: np.concatenate(training_data['probs'], axis=0),
|
||||||
self.is_training: True})
|
self.is_training: True})
|
||||||
|
|
||||||
print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
|
print("Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(iters,
|
||||||
time.time() - start_time,
|
time.time() - start_time,
|
||||||
value_loss,
|
value_loss,
|
||||||
policy_loss, reg))
|
policy_loss, reg))
|
||||||
iters += 1
|
|
||||||
if iters % self.save_freq == 0:
|
if iters % self.save_freq == 0:
|
||||||
save_path = "Iteration{}.ckpt".format(iters)
|
save_path = "Iteration{}.ckpt".format(iters)
|
||||||
self.saver.save(self.sess, self.checkpoint_path + save_path)
|
self.saver.save(self.sess, self.checkpoint_path + save_path)
|
||||||
|
for key in training_data.keys():
|
||||||
|
training_data[key] = []
|
||||||
|
iters += 1
|
||||||
|
|
||||||
def _file_to_training_data(self, file_name):
|
def _file_to_training_data(self, file_name):
|
||||||
read = False
|
read = False
|
||||||
@ -250,7 +255,7 @@ class ResNet(object):
|
|||||||
file.seek(0)
|
file.seek(0)
|
||||||
data = cPickle.load(file)
|
data = cPickle.load(file)
|
||||||
read = True
|
read = True
|
||||||
print("{} Loaded".format(file_name))
|
print("{} Loaded!".format(file_name))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
@ -276,6 +281,6 @@ class ResNet(object):
|
|||||||
return states, probs, winner
|
return states, probs, winner
|
||||||
|
|
||||||
|
|
||||||
if __name__=="__main__":
|
if __name__ == "__main__":
|
||||||
model = ResNet(board_size=9, action_num=82)
|
model = ResNet(board_size=9, action_num=82, history_length=8)
|
||||||
model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
|
model.train("file", data_path="./data/", batch_size=128, checkpoint_path="./checkpoint/")
|
||||||
|
@ -7,7 +7,6 @@ import time
|
|||||||
import os
|
import os
|
||||||
import cPickle
|
import cPickle
|
||||||
|
|
||||||
|
|
||||||
class Data(object):
|
class Data(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.boards = []
|
self.boards = []
|
||||||
|
@ -34,7 +34,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
daemon = Pyro4.Daemon() # make a Pyro daemon
|
daemon = Pyro4.Daemon() # make a Pyro daemon
|
||||||
ns = Pyro4.locateNS() # find the name server
|
ns = Pyro4.locateNS() # find the name server
|
||||||
player = Player(role = args.role, engine = engine)
|
player = Player(role=args.role, engine=engine)
|
||||||
print "Init " + args.role + " player finished"
|
print "Init " + args.role + " player finished"
|
||||||
uri = daemon.register(player) # register the greeting maker as a Pyro object
|
uri = daemon.register(player) # register the greeting maker as a Pyro object
|
||||||
print "Start on name " + args.role
|
print "Start on name " + args.role
|
||||||
|
@ -25,7 +25,6 @@ def find_correct_moves(own, enemy):
|
|||||||
mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom
|
mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom
|
||||||
return mobility
|
return mobility
|
||||||
|
|
||||||
|
|
||||||
def calc_flip(pos, own, enemy):
|
def calc_flip(pos, own, enemy):
|
||||||
"""return flip stones of enemy by bitboard when I place stone at pos.
|
"""return flip stones of enemy by bitboard when I place stone at pos.
|
||||||
|
|
||||||
@ -34,7 +33,6 @@ def calc_flip(pos, own, enemy):
|
|||||||
:param enemy: bitboard
|
:param enemy: bitboard
|
||||||
:return: flip stones of enemy when I place stone at pos.
|
:return: flip stones of enemy when I place stone at pos.
|
||||||
"""
|
"""
|
||||||
assert 0 <= pos <= 63, f"pos={pos}"
|
|
||||||
f1 = _calc_flip_half(pos, own, enemy)
|
f1 = _calc_flip_half(pos, own, enemy)
|
||||||
f2 = _calc_flip_half(63 - pos, rotate180(own), rotate180(enemy))
|
f2 = _calc_flip_half(63 - pos, rotate180(own), rotate180(enemy))
|
||||||
return f1 | rotate180(f2)
|
return f1 | rotate180(f2)
|
||||||
@ -125,27 +123,42 @@ class Reversi:
|
|||||||
self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank
|
self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank
|
||||||
self.color = None # 1 for black and -1 for white
|
self.color = None # 1 for black and -1 for white
|
||||||
self.action = None # number in 0~63
|
self.action = None # number in 0~63
|
||||||
self.winner = None
|
# self.winner = None
|
||||||
|
self.black_win = None
|
||||||
|
|
||||||
def simulate_is_valid(self, board, color):
|
def get_board(self, black=None, white=None):
|
||||||
|
self.black = black or (0b00001000 << 24 | 0b00010000 << 32)
|
||||||
|
self.white = white or (0b00010000 << 24 | 0b00001000 << 32)
|
||||||
|
self.board = self.bitboard2board()
|
||||||
|
return self.board
|
||||||
|
|
||||||
|
def simulate_get_mask(self, state, action_set):
|
||||||
|
history_boards, color = state
|
||||||
|
board = history_boards[-1]
|
||||||
self.board = board
|
self.board = board
|
||||||
self.color = color
|
self.color = color
|
||||||
self.board2bitboard()
|
self.board2bitboard()
|
||||||
own, enemy = self.get_own_and_enemy()
|
own, enemy = self.get_own_and_enemy()
|
||||||
mobility = find_correct_moves(own, enemy)
|
mobility = find_correct_moves(own, enemy)
|
||||||
valid_moves = bit_to_array(mobility, 64)
|
valid_moves = bit_to_array(mobility, 64)
|
||||||
|
valid_moves = np.argwhere(valid_moves)
|
||||||
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
|
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
|
||||||
return valid_moves
|
# TODO it seems that the pass move is not considered
|
||||||
|
invalid_action_mask = []
|
||||||
|
for action in action_set:
|
||||||
|
if action not in valid_moves:
|
||||||
|
invalid_action_mask.append(action)
|
||||||
|
return invalid_action_mask
|
||||||
|
|
||||||
def simulate_step_forward(self, board, color, vertex):
|
def simulate_step_forward(self, state, action):
|
||||||
self.board = board
|
self.board = state[0]
|
||||||
self.color = color
|
self.color = state[1]
|
||||||
self.board2bitboard()
|
self.board2bitboard()
|
||||||
self.vertex2action(vertex)
|
self.action = action
|
||||||
step_forward = self.step()
|
step_forward = self.step()
|
||||||
if step_forward:
|
if step_forward:
|
||||||
new_board = self.bitboard2board()
|
new_board = self.bitboard2board()
|
||||||
return new_board
|
return [new_board, 0 - self.color], 0
|
||||||
|
|
||||||
def executor_do_move(self, board, color, vertex):
|
def executor_do_move(self, board, color, vertex):
|
||||||
self.board = board
|
self.board = board
|
||||||
@ -155,20 +168,21 @@ class Reversi:
|
|||||||
step_forward = self.step()
|
step_forward = self.step()
|
||||||
if step_forward:
|
if step_forward:
|
||||||
new_board = self.bitboard2board()
|
new_board = self.bitboard2board()
|
||||||
return new_board
|
for i in range(64):
|
||||||
|
board[i] = new_board[i]
|
||||||
|
|
||||||
def executor_get_score(self, board):
|
def executor_get_score(self, board):
|
||||||
self.board = board
|
self.board = board
|
||||||
self._game_over()
|
self._game_over()
|
||||||
if self.winner is not None:
|
if self.black_win is not None:
|
||||||
return self.winner, 0 - self.winner
|
return self.black_win
|
||||||
else:
|
else:
|
||||||
ValueError("Game not finished!")
|
raise ValueError("Game not finished!")
|
||||||
|
|
||||||
def board2bitboard(self):
|
def board2bitboard(self):
|
||||||
count = 1
|
count = 1
|
||||||
if self.board is None:
|
if self.board is None:
|
||||||
ValueError("None board!")
|
raise ValueError("None board!")
|
||||||
self.black = 0
|
self.black = 0
|
||||||
self.white = 0
|
self.white = 0
|
||||||
for i in range(64):
|
for i in range(64):
|
||||||
@ -200,7 +214,7 @@ class Reversi:
|
|||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
if self.action < 0 or self.action > 63:
|
if self.action < 0 or self.action > 63:
|
||||||
ValueError("Wrong action!")
|
raise ValueError("Wrong action!")
|
||||||
if self.action is None:
|
if self.action is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -219,6 +233,7 @@ class Reversi:
|
|||||||
|
|
||||||
def _game_over(self):
|
def _game_over(self):
|
||||||
# self.done = True
|
# self.done = True
|
||||||
|
'''
|
||||||
if self.winner is None:
|
if self.winner is None:
|
||||||
black_num, white_num = self.number_of_black_and_white
|
black_num, white_num = self.number_of_black_and_white
|
||||||
if black_num > white_num:
|
if black_num > white_num:
|
||||||
@ -227,9 +242,12 @@ class Reversi:
|
|||||||
self.winner = -1
|
self.winner = -1
|
||||||
else:
|
else:
|
||||||
self.winner = 0
|
self.winner = 0
|
||||||
|
'''
|
||||||
|
if self.black_win is None:
|
||||||
|
black_num, white_num = self.number_of_black_and_white
|
||||||
|
self.black_win = black_num - white_num
|
||||||
|
|
||||||
def illegal_move_to_lose(self, action):
|
def illegal_move_to_lose(self, action):
|
||||||
logger.warning(f"Illegal action={action}, No Flipped!")
|
|
||||||
self._game_over()
|
self._game_over()
|
||||||
|
|
||||||
def get_own_and_enemy(self):
|
def get_own_and_enemy(self):
|
||||||
|
@ -79,7 +79,7 @@ while True:
|
|||||||
prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
|
prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
|
||||||
print("Finished")
|
print("Finished")
|
||||||
print("\n")
|
print("\n")
|
||||||
score = game.game_engine.executor_get_score(game.board, True)
|
score = game.game_engine.executor_get_score(game.board)
|
||||||
if score > 0:
|
if score > 0:
|
||||||
winner = utils.BLACK
|
winner = utils.BLACK
|
||||||
else:
|
else:
|
||||||
|
@ -1,266 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import sys
|
|
||||||
from game import Game
|
|
||||||
from engine import GTPEngine
|
|
||||||
import utils
|
|
||||||
import time
|
|
||||||
import copy
|
|
||||||
import network_small
|
|
||||||
import tensorflow as tf
|
|
||||||
from collections import deque
|
|
||||||
from tianshou.core.mcts.mcts import MCTS
|
|
||||||
|
|
||||||
DELTA = [[1, 0], [-1, 0], [0, -1], [0, 1]]
|
|
||||||
CORNER_OFFSET = [[-1, -1], [-1, 1], [1, 1], [1, -1]]
|
|
||||||
|
|
||||||
class GoEnv:
|
|
||||||
def __init__(self, size=9, komi=6.5):
|
|
||||||
self.size = size
|
|
||||||
self.komi = komi
|
|
||||||
self.board = [utils.EMPTY] * (self.size * self.size)
|
|
||||||
self.history = deque(maxlen=8)
|
|
||||||
|
|
||||||
def _set_board(self, board):
|
|
||||||
self.board = board
|
|
||||||
|
|
||||||
def _flatten(self, vertex):
|
|
||||||
x, y = vertex
|
|
||||||
return (x - 1) * self.size + (y - 1)
|
|
||||||
|
|
||||||
def _bfs(self, vertex, color, block, status, alive_break):
|
|
||||||
block.append(vertex)
|
|
||||||
status[self._flatten(vertex)] = True
|
|
||||||
nei = self._neighbor(vertex)
|
|
||||||
for n in nei:
|
|
||||||
if not status[self._flatten(n)]:
|
|
||||||
if self.board[self._flatten(n)] == color:
|
|
||||||
self._bfs(n, color, block, status, alive_break)
|
|
||||||
|
|
||||||
def _find_block(self, vertex, alive_break=False):
|
|
||||||
block = []
|
|
||||||
status = [False] * (self.size * self.size)
|
|
||||||
color = self.board[self._flatten(vertex)]
|
|
||||||
self._bfs(vertex, color, block, status, alive_break)
|
|
||||||
|
|
||||||
for b in block:
|
|
||||||
for n in self._neighbor(b):
|
|
||||||
if self.board[self._flatten(n)] == utils.EMPTY:
|
|
||||||
return False, block
|
|
||||||
return True, block
|
|
||||||
|
|
||||||
def _is_qi(self, color, vertex):
|
|
||||||
nei = self._neighbor(vertex)
|
|
||||||
for n in nei:
|
|
||||||
if self.board[self._flatten(n)] == utils.EMPTY:
|
|
||||||
return True
|
|
||||||
|
|
||||||
self.board[self._flatten(vertex)] = color
|
|
||||||
for n in nei:
|
|
||||||
if self.board[self._flatten(n)] == utils.another_color(color):
|
|
||||||
can_kill, block = self._find_block(n)
|
|
||||||
if can_kill:
|
|
||||||
self.board[self._flatten(vertex)] = utils.EMPTY
|
|
||||||
return True
|
|
||||||
|
|
||||||
### avoid suicide
|
|
||||||
can_kill, block = self._find_block(vertex)
|
|
||||||
if can_kill:
|
|
||||||
self.board[self._flatten(vertex)] = utils.EMPTY
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.board[self._flatten(vertex)] = utils.EMPTY
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _check_global_isomorphous(self, color, vertex):
|
|
||||||
##backup
|
|
||||||
_board = copy.copy(self.board)
|
|
||||||
self.board[self._flatten(vertex)] = color
|
|
||||||
self._process_board(color, vertex)
|
|
||||||
if self.board in self.history:
|
|
||||||
res = True
|
|
||||||
else:
|
|
||||||
res = False
|
|
||||||
|
|
||||||
self.board = _board
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _in_board(self, vertex):
|
|
||||||
x, y = vertex
|
|
||||||
if x < 1 or x > self.size: return False
|
|
||||||
if y < 1 or y > self.size: return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _neighbor(self, vertex):
|
|
||||||
x, y = vertex
|
|
||||||
nei = []
|
|
||||||
for d in DELTA:
|
|
||||||
_x = x + d[0]
|
|
||||||
_y = y + d[1]
|
|
||||||
if self._in_board((_x, _y)):
|
|
||||||
nei.append((_x, _y))
|
|
||||||
return nei
|
|
||||||
|
|
||||||
def _corner(self, vertex):
|
|
||||||
x, y = vertex
|
|
||||||
corner = []
|
|
||||||
for d in CORNER_OFFSET:
|
|
||||||
_x = x + d[0]
|
|
||||||
_y = y + d[1]
|
|
||||||
if self._in_board((_x, _y)):
|
|
||||||
corner.append((_x, _y))
|
|
||||||
return corner
|
|
||||||
|
|
||||||
def _process_board(self, color, vertex):
|
|
||||||
nei = self._neighbor(vertex)
|
|
||||||
for n in nei:
|
|
||||||
if self.board[self._flatten(n)] == utils.another_color(color):
|
|
||||||
can_kill, block = self._find_block(n, alive_break=True)
|
|
||||||
if can_kill:
|
|
||||||
for b in block:
|
|
||||||
self.board[self._flatten(b)] = utils.EMPTY
|
|
||||||
|
|
||||||
def _find_group(self, start):
|
|
||||||
color = self.board[self._flatten(start)]
|
|
||||||
#print ("color : ", color)
|
|
||||||
chain = set()
|
|
||||||
frontier = [start]
|
|
||||||
while frontier:
|
|
||||||
current = frontier.pop()
|
|
||||||
#print ("current : ", current)
|
|
||||||
chain.add(current)
|
|
||||||
for n in self._neighbor(current):
|
|
||||||
#print n, self._flatten(n), self.board[self._flatten(n)],
|
|
||||||
if self.board[self._flatten(n)] == color and not n in chain:
|
|
||||||
frontier.append(n)
|
|
||||||
return chain
|
|
||||||
|
|
||||||
def _is_eye(self, color, vertex):
|
|
||||||
nei = self._neighbor(vertex)
|
|
||||||
cor = self._corner(vertex)
|
|
||||||
ncolor = {color == self.board[self._flatten(n)] for n in nei}
|
|
||||||
if False in ncolor:
|
|
||||||
#print "not all neighbors are in same color with us"
|
|
||||||
return False
|
|
||||||
if set(nei) < self._find_group(nei[0]):
|
|
||||||
#print "all neighbors are in same group and same color with us"
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
opponent_number = [self.board[self._flatten(c)] for c in cor].count(-color)
|
|
||||||
opponent_propotion = float(opponent_number) / float(len(cor))
|
|
||||||
if opponent_propotion < 0.5:
|
|
||||||
#print "few opponents, real eye"
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
#print "many opponents, fake eye"
|
|
||||||
return False
|
|
||||||
|
|
||||||
# def is_valid(self, color, vertex):
|
|
||||||
def is_valid(self, state, action):
|
|
||||||
# state is the play board, the shape is [1, 9, 9, 17]
|
|
||||||
if action == self.size * self.size:
|
|
||||||
vertex = (0, 0)
|
|
||||||
else:
|
|
||||||
vertex = (action / self.size + 1, action % self.size + 1)
|
|
||||||
if state[0, 0, 0, -1] == utils.BLACK:
|
|
||||||
color = utils.BLACK
|
|
||||||
else:
|
|
||||||
color = utils.WHITE
|
|
||||||
self.history.clear()
|
|
||||||
for i in range(8):
|
|
||||||
self.history.append((state[:, :, :, i] - state[:, :, :, i + 8]).reshape(-1).tolist())
|
|
||||||
self.board = copy.copy(self.history[-1])
|
|
||||||
### in board
|
|
||||||
if not self._in_board(vertex):
|
|
||||||
return False
|
|
||||||
|
|
||||||
### already have stone
|
|
||||||
if not self.board[self._flatten(vertex)] == utils.EMPTY:
|
|
||||||
# print(np.array(self.board).reshape(9, 9))
|
|
||||||
# print(vertex)
|
|
||||||
return False
|
|
||||||
|
|
||||||
### check if it is qi
|
|
||||||
if not self._is_qi(color, vertex):
|
|
||||||
return False
|
|
||||||
|
|
||||||
### check if it is an eye of yourself
|
|
||||||
### assumptions : notice that this judgement requires that the state is an endgame
|
|
||||||
#if self._is_eye(color, vertex):
|
|
||||||
# return False
|
|
||||||
|
|
||||||
if self._check_global_isomorphous(color, vertex):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def do_move(self, color, vertex):
|
|
||||||
if vertex == utils.PASS:
|
|
||||||
return True
|
|
||||||
|
|
||||||
id_ = self._flatten(vertex)
|
|
||||||
if self.board[id_] == utils.EMPTY:
|
|
||||||
self.board[id_] = color
|
|
||||||
self.history.append(copy.copy(self.board))
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def step_forward(self, state, action):
|
|
||||||
if state[0, 0, 0, -1] == 1:
|
|
||||||
color = 1
|
|
||||||
else:
|
|
||||||
color = -1
|
|
||||||
if action == 81:
|
|
||||||
vertex = (0, 0)
|
|
||||||
else:
|
|
||||||
vertex = (action % 9 + 1, action / 9 + 1)
|
|
||||||
# print(vertex)
|
|
||||||
# print(self.board)
|
|
||||||
self.board = (state[:, :, :, 7] - state[:, :, :, 15]).reshape(-1).tolist()
|
|
||||||
self.do_move(color, vertex)
|
|
||||||
new_state = np.concatenate(
|
|
||||||
[state[:, :, :, 1:8], (np.array(self.board) == 1).reshape(1, 9, 9, 1),
|
|
||||||
state[:, :, :, 9:16], (np.array(self.board) == -1).reshape(1, 9, 9, 1),
|
|
||||||
np.array(1 - state[:, :, :, -1]).reshape(1, 9, 9, 1)],
|
|
||||||
axis=3)
|
|
||||||
return new_state, 0
|
|
||||||
|
|
||||||
|
|
||||||
pure_test = [
|
|
||||||
0, 1, 0, 1, 0, 1, 0, 0, 0,
|
|
||||||
1, 0, 1, 0, 1, 0, 0, 0, 0,
|
|
||||||
0, 1, 0, 1, 0, 0, 1, 0, 0,
|
|
||||||
0, 0, 1, 0, 0, 1, 0, 1, 0,
|
|
||||||
0, 0, 0, 0, 0, 1, 1, 1, 0,
|
|
||||||
1, 1, 1, 0, 0, 0, 0, 0, 0,
|
|
||||||
1, 0, 1, 0, 0, 1, 1, 0, 0,
|
|
||||||
1, 1, 1, 0, 1, 0, 1, 0, 0,
|
|
||||||
0, 0, 0, 0, 1, 1, 1, 0, 0
|
|
||||||
]
|
|
||||||
|
|
||||||
pt_qry = [(1, 1), (1, 5), (3, 3), (4, 7), (7, 2), (8, 6)]
|
|
||||||
pt_ans = [True, True, True, True, True, True]
|
|
||||||
|
|
||||||
opponent_test = [
|
|
||||||
0, 1, 0, 1, 0, 1, 0,-1, 1,
|
|
||||||
1,-1, 0,-1, 1,-1, 0, 1, 0,
|
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 1,
|
|
||||||
1, 1,-1, 0, 1,-1, 1, 0, 0,
|
|
||||||
1, 0, 1, 0, 1, 0, 1, 0, 0,
|
|
||||||
-1, 1, 1, 0, 1, 1, 1, 0, 0,
|
|
||||||
0, 1,-1, 0,-1,-1,-1, 0, 0,
|
|
||||||
1, 0, 1, 0,-1, 0,-1, 0, 0,
|
|
||||||
0, 1, 0, 0,-1,-1,-1, 0, 0
|
|
||||||
]
|
|
||||||
ot_qry = [(1, 1), (1, 5), (2, 9), (5, 2), (5, 6), (8, 2), (8, 6)]
|
|
||||||
ot_ans = [False, False, False, False, False, True, False]
|
|
||||||
|
|
||||||
#print (ge._find_group((6, 1)))
|
|
||||||
#print ge._is_eye(utils.BLACK, pt_qry[0])
|
|
||||||
ge = GoEnv()
|
|
||||||
ge._set_board(pure_test)
|
|
||||||
for i in range(6):
|
|
||||||
print (ge._is_eye(utils.BLACK, pt_qry[i]))
|
|
||||||
ge._set_board(opponent_test)
|
|
||||||
for i in range(7):
|
|
||||||
print (ge._is_eye(utils.BLACK, ot_qry[i]))
|
|
@ -73,7 +73,7 @@ class UCTNode(MCTSNode):
|
|||||||
def valid_mask(self, simulator):
|
def valid_mask(self, simulator):
|
||||||
# let all invalid actions be illeagel in mcts
|
# let all invalid actions be illeagel in mcts
|
||||||
if self.mask is None:
|
if self.mask is None:
|
||||||
self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num))
|
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||||
self.ucb[self.mask] = -float("Inf")
|
self.ucb[self.mask] = -float("Inf")
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user