Merge pull request #5 from sproblvem/union_set

add union set for do_move and is_valid
The modify on play.py should be removed, I will fix it on latter commit
This commit is contained in:
sproblvem 2018-02-23 15:01:17 +08:00 committed by GitHub
commit a0849fa213
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 161 additions and 65 deletions

View File

@ -40,6 +40,9 @@ class Game:
self.history_hashtable = set() self.history_hashtable = set()
self.game_engine = go.Go(size=self.size, komi=self.komi) self.game_engine = go.Go(size=self.size, komi=self.komi)
self.board = [utils.EMPTY] * (self.size ** 2) self.board = [utils.EMPTY] * (self.size ** 2)
self.group_ancestors = {} # key: idx, value: ancestor idx
self.liberty = {} # key: ancestor idx, value: set of liberty
self.stones = {} # key: ancestor idx, value: set of stones
elif self.name == "reversi": elif self.name == "reversi":
self.size = 8 self.size = 8
self.history_length = 1 self.history_length = 1
@ -62,6 +65,9 @@ class Game:
self.board = [utils.EMPTY] * (self.size ** 2) self.board = [utils.EMPTY] * (self.size ** 2)
del self.history[:] del self.history[:]
self.history_hashtable.clear() self.history_hashtable.clear()
self.group_ancestors.clear()
self.liberty.clear()
self.stones.clear()
if self.name == "reversi": if self.name == "reversi":
self.board = self.game_engine.get_board() self.board = self.game_engine.get_board()
for _ in range(self.history_length): for _ in range(self.history_length):
@ -109,7 +115,7 @@ class Game:
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex) res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
if self.name == "go": if self.name == "go":
res = self.game_engine.executor_do_move(self.history, self.history_hashtable, self.latest_boards, self.board, res = self.game_engine.executor_do_move(self.history, self.history_hashtable, self.latest_boards, self.board,
color, vertex) self.group_ancestors, self.liberty, self.stones, color, vertex)
return res return res
def think_play_move(self, color): def think_play_move(self, color):

View File

@ -44,6 +44,23 @@ class Go:
nei.append((_x, _y)) nei.append((_x, _y))
return nei return nei
def _neighbor_color(self, current_board, vertex, color):
# return neighbors which are listed in different colors
color_neighbor = [] # 1)neighbors in the same color
reverse_color_neighbor = [] # 2)neighbors in the reverse color
empty_neighbor = [] # 2)empty neighbors
reverse_color = utils.BLACK if color == utils.WHITE else utils.WHITE
for n in self._neighbor(vertex):
if current_board[self._flatten(n)] == color:
color_neighbor.append(self._flatten(n))
elif current_board[self._flatten(n)] == utils.EMPTY:
empty_neighbor.append(self._flatten(n))
elif current_board[self._flatten(n)] == reverse_color:
reverse_color_neighbor.append(self._flatten(n))
else:
raise ValueError("board have other positions excluding BLACK, WHITE and EMPTY")
return color_neighbor, reverse_color_neighbor, empty_neighbor
def _corner(self, vertex): def _corner(self, vertex):
x, y = vertex x, y = vertex
corner = [] corner = []
@ -56,13 +73,11 @@ class Go:
def _find_group(self, current_board, vertex): def _find_group(self, current_board, vertex):
color = current_board[self._flatten(vertex)] color = current_board[self._flatten(vertex)]
# print ("color : ", color)
chain = set() chain = set()
frontier = [vertex] frontier = [vertex]
has_liberty = False has_liberty = False
while frontier: while frontier:
current = frontier.pop() current = frontier.pop()
# print ("current : ", current)
chain.add(current) chain.add(current)
for n in self._neighbor(current): for n in self._neighbor(current):
if current_board[self._flatten(n)] == color and not n in chain: if current_board[self._flatten(n)] == color and not n in chain:
@ -71,21 +86,26 @@ class Go:
has_liberty = True has_liberty = True
return has_liberty, chain return has_liberty, chain
def _is_suicide(self, current_board, color, vertex): def _find_ancestor(self, group_ancestors, idx):
current_board[self._flatten(vertex)] = color # assume that we already take this move r = idx
suicide = False while group_ancestors[r] != r:
r = group_ancestors[r]
group_ancestors[idx] = r
return r
has_liberty, group = self._find_group(current_board, vertex) def _is_suicide(self, current_board, group_ancestors, liberty, color, vertex):
if not has_liberty: color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
suicide = True # no liberty, suicide if empty_neighbor:
for n in self._neighbor(vertex): return False # neighbors have empty spaces
if current_board[self._flatten(n)] == utils.another_color(color): elif color_neighbor: # neighbors have same color, they have liberties
opponent_liberty, group = self._find_group(current_board, n) for idx in color_neighbor:
if not opponent_liberty: if len(liberty[self._find_ancestor(group_ancestors, idx)]) > 1:
suicide = False # this move is able to take opponent's stone, not suicide return False
else: # neighbors have reverse color, they have only one liberty
current_board[self._flatten(vertex)] = utils.EMPTY # undo this move for idx in reverse_color_neighbor:
return suicide if len(liberty[self._find_ancestor(group_ancestors, idx)]) == 1:
return False
return True
def _process_board(self, current_board, color, vertex): def _process_board(self, current_board, color, vertex):
nei = self._neighbor(vertex) nei = self._neighbor(vertex)
@ -107,25 +127,15 @@ class Go:
return repeat return repeat
def _is_eye(self, current_board, color, vertex): def _is_eye(self, current_board, color, vertex):
nei = self._neighbor(vertex) # return is this position is an real eye of color
cor = self._corner(vertex) color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
ncolor = {color == current_board[self._flatten(n)] for n in nei} if reverse_color_neighbor or empty_neighbor: # not an eye
if False in ncolor:
# print "not all neighbors are in same color with us"
return False return False
_, group = self._find_group(current_board, nei[0]) cor = self._corner(vertex)
if set(nei) < group: opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
# print "all neighbors are in same group and same color with us" opponent_propotion = float(opponent_number) / float(len(cor))
return True # opponent_propotion<0.5 fake eye
else: return True if opponent_propotion < 0.5 else False
opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
opponent_propotion = float(opponent_number) / float(len(cor))
if opponent_propotion < 0.5:
# print "few opponents, real eye"
return True
else:
# print "many opponents, fake eye"
return False
def _knowledge_prunning(self, current_board, color, vertex): def _knowledge_prunning(self, current_board, color, vertex):
# forbid some stupid selfplay using human knowledge # forbid some stupid selfplay using human knowledge
@ -134,23 +144,6 @@ class Go:
# forbid position on its own eye. # forbid position on its own eye.
return True return True
def _is_game_finished(self, current_board, color):
'''
for each empty position, if it has both BLACK and WHITE neighbors, the game is still not finished
:return: return the game is finished
'''
board = copy.deepcopy(current_board)
empty_idx = [i for i, x in enumerate(board) if x == utils.EMPTY] # find all empty idx
for idx in empty_idx:
neighbor_idx = self._neighbor(self.deflatten(idx))
if len(neighbor_idx) > 1:
first_idx = neighbor_idx[0]
for other_idx in neighbor_idx[1:]:
if board[self.flatten(other_idx)] != board[self.flatten(first_idx)]:
return False
return True
def _action2vertex(self, action): def _action2vertex(self, action):
if action == self.size ** 2: if action == self.size ** 2:
vertex = (0, 0) vertex = (0, 0)
@ -158,7 +151,7 @@ class Go:
vertex = self._deflatten(action) vertex = self._deflatten(action)
return vertex return vertex
def _rule_check(self, history_hashtable, current_board, color, vertex, is_thinking=True): def _rule_check(self, history_hashtable, current_board, group_ancestors, liberty, color, vertex, is_thinking=True):
### in board ### in board
if not self._in_board(vertex): if not self._in_board(vertex):
if not is_thinking: if not is_thinking:
@ -174,7 +167,7 @@ class Go:
return False return False
### check if it is suicide ### check if it is suicide
if self._is_suicide(current_board, color, vertex): if self._is_suicide(current_board, group_ancestors, liberty, color, vertex):
if not is_thinking: if not is_thinking:
raise ValueError("Target point causes suicide, Current Board: {}, color: {}, vertex : {}".format(current_board, color, vertex)) raise ValueError("Target point causes suicide, Current Board: {}, color: {}, vertex : {}".format(current_board, color, vertex))
else: else:
@ -189,34 +182,59 @@ class Go:
return True return True
def _is_valid(self, state, action, history_hashtable): def _is_valid(self, state, action, history_hashtable, group_ancestors, liberty):
history_boards, color = state history_boards, color = state
vertex = self._action2vertex(action) vertex = self._action2vertex(action)
current_board = history_boards[-1] current_board = history_boards[-1]
if not self._rule_check(history_hashtable, current_board, color, vertex): if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex):
return False return False
if not self._knowledge_prunning(current_board, color, vertex): if not self._knowledge_prunning(current_board, color, vertex):
return False return False
return True return True
def _get_groups(self, board):
group_ancestors = {} # key: idx, value: ancestor idx
liberty = {} # key: ancestor idx, value: set of liberty
for idx, color in enumerate(board):
if color and idx not in group_ancestors:
# build group
group_ancestors[idx] = idx
color_neighbor, _, empty_neighbor = \
self._neighbor_color(board, self._deflatten(idx), color)
liberty[idx] = set(empty_neighbor)
group_list = copy.deepcopy(color_neighbor)
while group_list:
add_idx = group_list.pop()
if add_idx not in group_ancestors:
group_ancestors[add_idx] = idx
color_neighbor_add, _, empty_neighbor_add = \
self._neighbor_color(board, self._deflatten(add_idx), color)
group_list += color_neighbor_add
liberty[idx] |= set(empty_neighbor_add)
return group_ancestors, liberty
def simulate_get_mask(self, state, action_set): def simulate_get_mask(self, state, action_set):
# find all the invalid actions # find all the invalid actions
invalid_action_mask = [] invalid_action_mask = []
history_boards, color = state history_boards, color = state
group_ancestors, liberty = self._get_groups(history_boards[-1])
history_hashtable = set() history_hashtable = set()
for board in history_boards: for board in history_boards:
history_hashtable.add(tuple(board)) history_hashtable.add(tuple(board))
for action_candidate in action_set[:-1]: for action_candidate in action_set[:-1]:
# go through all the actions excluding pass # go through all the actions excluding pass
if not self._is_valid(state, action_candidate, history_hashtable): if not self._is_valid(state, action_candidate, history_hashtable, group_ancestors, liberty):
invalid_action_mask.append(action_candidate) invalid_action_mask.append(action_candidate)
if len(invalid_action_mask) < len(action_set) - 1: if len(invalid_action_mask) < len(action_set) - 1:
invalid_action_mask.append(action_set[-1]) invalid_action_mask.append(action_set[-1])
# forbid pass, if we have other choices # forbid pass, if we have other choices
# TODO: In fact we should not do this. In some extreme cases, we should permit pass. # TODO: In fact we should not do this. In some extreme cases, we should permit pass.
del history_hashtable del history_hashtable
del group_ancestors
del liberty
# del stones
return invalid_action_mask return invalid_action_mask
def _do_move(self, board, color, vertex): def _do_move(self, board, color, vertex):
@ -243,12 +261,70 @@ class Go:
# since go is MDP, we only need the last board for hashing # since go is MDP, we only need the last board for hashing
return tuple(state[0][-1]) return tuple(state[0][-1])
def executor_do_move(self, history, history_hashtable, latest_boards, current_board, color, vertex): def _join_group(self, idx, idx_list, empty_neighbor, group_ancestors, liberty, stones):
if not self._rule_check(history_hashtable, current_board, color, vertex, is_thinking=False): # idx joins its neighbors id_list
# raise ValueError("!!! We have more than four ko at the same time !!!") # empty_neighbor: empty neighbors of idx
color_ancestor = set()
for color_idx in idx_list:
color_ancestor.add(self._find_ancestor(group_ancestors, color_idx))
joined_ancestor = color_ancestor.pop()
liberty[joined_ancestor] |= set(empty_neighbor)
stones[joined_ancestor].add(idx)
group_ancestors[idx] = joined_ancestor
# add other groups
for color_idx in color_ancestor:
liberty[joined_ancestor] |= liberty[color_idx]
stones[joined_ancestor] |= stones[color_idx]
del liberty[color_idx]
for stone in stones[color_idx]:
group_ancestors[stone] = joined_ancestor
del stones[color_idx]
liberty[joined_ancestor].remove(idx)
def _add_captured_liberty(self, board, liberty, group_ancestors, stones):
for captured_stone in stones:
color_neighbor, reverse_color_neighbor, empty_neighbor = \
self._neighbor_color(board, self._deflatten(captured_stone), board[captured_stone])
assert not empty_neighbor # make sure no empty spaces
for reverse_color_idx in reverse_color_neighbor:
reverse_color_idx_ancestor = self._find_ancestor(group_ancestors, reverse_color_idx)
liberty[reverse_color_idx_ancestor].add(captured_stone)
def _remove_liberty(self, idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones):
# reverse_color_neighbor: stones near idx in the reverse color
reverse_color_ancestor = set()
for reverse_idx in reverse_color_neighbor:
reverse_color_ancestor.add(self._find_ancestor(group_ancestors, reverse_idx))
for reverse_color_ancestor_idx in reverse_color_ancestor:
if len(liberty[reverse_color_ancestor_idx]) == 1:
# capture this group if no liberty left
self._add_captured_liberty(current_board, liberty, group_ancestors, stones[reverse_color_ancestor_idx])
for captured_stone in stones[reverse_color_ancestor_idx]:
current_board[captured_stone] = utils.EMPTY
del group_ancestors[captured_stone]
del liberty[reverse_color_ancestor_idx]
del stones[reverse_color_ancestor_idx]
else:
# remove this liberty
liberty[reverse_color_ancestor_idx].remove(idx)
def executor_do_move(self, history, history_hashtable, latest_boards, current_board, group_ancestors, liberty, stones, color, vertex):
#print("===")
#print(color, vertex)
#print(group_ancestors, liberty, stones)
if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex):
return False return False
current_board[self._flatten(vertex)] = color idx = self._flatten(vertex)
self._process_board(current_board, color, vertex) current_board[idx] = color
color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
if color_neighbor: # join nearby groups
self._join_group(idx, color_neighbor, empty_neighbor, group_ancestors, liberty, stones)
else: # build a new group
group_ancestors[idx] = idx
liberty[idx] = set(empty_neighbor)
stones[idx] = {idx}
if reverse_color_neighbor: # remove liberty for nearby reverse color
self._remove_liberty(idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones)
history.append(copy.deepcopy(current_board)) history.append(copy.deepcopy(current_board))
latest_boards.append(copy.deepcopy(current_board)) latest_boards.append(copy.deepcopy(current_board))
history_hashtable.add(copy.deepcopy(tuple(current_board))) history_hashtable.add(copy.deepcopy(tuple(current_board)))
@ -362,6 +438,8 @@ if __name__ == "__main__":
1, 0, 1, 1, 1, 1, 1, -1, 0, 1, 0, 1, 1, 1, 1, 1, -1, 0,
1, 1, 0, 1, -1, -1, -1, -1, -1 1, 1, 0, 1, -1, -1, -1, -1, -1
] ]
'''
time0 = time.time() time0 = time.time()
score = go.executor_get_score(endgame) score = go.executor_get_score(endgame)
time1 = time.time() time1 = time.time()
@ -370,6 +448,7 @@ if __name__ == "__main__":
time2 = time.time() time2 = time.time()
print(score, time2 - time1) print(score, time2 - time1)
''' '''
'''
### do unit test for Go class ### do unit test for Go class
pure_test = [ pure_test = [
0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0,

View File

@ -1,3 +1,4 @@
from __future__ import division
import argparse import argparse
import sys import sys
import re import re
@ -28,10 +29,13 @@ def play(engine, data_path):
size = {"go": 9, "reversi": 8} size = {"go": 9, "reversi": 8}
show = ['.', 'X', 'O'] show = ['.', 'X', 'O']
# evaluate_rounds = 100 evaluate_rounds = 5
game_num = 0 game_num = 0
while True: total = 0
# while game_num < evaluate_rounds: f=open('time.txt','w')
#while True:
while game_num < evaluate_rounds:
start = time.time()
engine._game.model.check_latest_model() engine._game.model.check_latest_model()
num = 0 num = 0
pass_flag = [False, False] pass_flag = [False, False]
@ -77,6 +81,13 @@ def play(engine, data_path):
cPickle.dump(data, file) cPickle.dump(data, file)
data.reset() data.reset()
game_num += 1 game_num += 1
this_time = time.time() - start
total += this_time
f.write('time:'+ str(this_time)+'\n')
f.write('Avg time:' + str(total/evaluate_rounds))
f.close()
if __name__ == '__main__': if __name__ == '__main__':