Merge remote-tracking branch 'origin/master'

This commit is contained in:
haoshengzou 2017-12-23 17:25:37 +08:00
commit 238039b854
4 changed files with 151 additions and 53 deletions

View File

@ -212,11 +212,14 @@ class Go:
def simulate_step_forward(self, state, action): def simulate_step_forward(self, state, action):
# initialize the simulate_board from state # initialize the simulate_board from state
history_boards, color = state history_boards, color = state
vertex = self._action2vertex(action) if history_boards[-1] == history_boards[-2] and action is utils.PASS:
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex) return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
history_boards.append(new_board) else:
new_color = -color vertex = self._action2vertex(action)
return [history_boards, new_color], 0 new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
history_boards.append(new_board)
new_color = -color
return [history_boards, new_color], 0
def executor_do_move(self, history, latest_boards, current_board, color, vertex): def executor_do_move(self, history, latest_boards, current_board, color, vertex):
if not self._rule_check(history, current_board, color, vertex): if not self._rule_check(history, current_board, color, vertex):

View File

@ -1,7 +1,6 @@
import os import os
import time import time
import random import copy
import sys
import cPickle import cPickle
from collections import deque from collections import deque
@ -224,11 +223,21 @@ class ResNet(object):
else: else:
start_time = time.time() start_time = time.time()
for i in range(batch_size): for i in range(batch_size):
game_num = random.randint(0, self.window_length-1) priority = self.training_data['length'] / sum(self.training_data['length'])
state_num = random.randint(0, self.training_data['length'][game_num]-1) game_num = np.random.choice(self.window_length, 1, p=priority)
training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0)) state_num = np.random.randint(self.training_data['length'][game_num])
training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0)) rotate_times = np.random.randint(4)
training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0)) reflect_times = np.random.randint(2)
reflect_orientation = np.random.randint(2)
training_data['states'].append(
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
reflect_orientation, rotate_times))
training_data['probs'].append(
self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
reflect_orientation, rotate_times))
training_data['winner'].append(
self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
reflect_orientation, rotate_times))
value_loss, policy_loss, reg, _ = self.sess.run( value_loss, policy_loss, reg, _ = self.sess.run(
[self.value_loss, self.policy_loss, self.reg, self.train_op], [self.value_loss, self.policy_loss, self.reg, self.train_op],
feed_dict={self.x: np.concatenate(training_data['states'], axis=0), feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
@ -280,6 +289,55 @@ class ResNet(object):
winner = np.concatenate(winner, axis=0) winner = np.concatenate(winner, axis=0)
return states, probs, winner return states, probs, winner
def _preprocession(self, board, reflect_times=0, reflect_orientation=0, rotate_times=0):
"""
preprocessing for augmentation
:param board: a ndarray, board to process
:param reflect_times: an integer, how many times to reflect
:param reflect_orientation: an integer, which orientation to reflect
:param rotate_times: an integer, how many times to rotate
:return:
"""
new_board = copy.copy(board)
if new_board.ndim == 3:
np.expand_dims(new_board, axis=0)
new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
new_board = self._board_rotation(new_board, rotate_times)
return new_board
def _board_rotation(self, board, times):
"""
rotate the board for augmentation
note that board's shape should be [batch_size, board_size, board_size, channels]
:param board: a ndarray, shape [batch_size, board_size, board_size, channels]
:param times: an integer, how many times to rotate
:return:
"""
return np.rot90(board, times, (1, 2))
def _board_reflection(self, board, times, orientation):
"""
reflect the board for augmentation
note that board's shape should be [batch_size, board_size, board_size, channels]
:param board: a ndarray, shape [batch_size, board_size, board_size, channels]
:param times: an integer, how many times to reflect
:param orientation: an integer, which orientation to reflect
:return:
"""
new_board = copy.copy(board)
for _ in range(times):
if orientation == 0:
new_board = new_board[:, ::-1]
if orientation == 1:
new_board = new_board[:, :, ::-1]
return new_board
if __name__ == "__main__": if __name__ == "__main__":
model = ResNet(board_size=9, action_num=82, history_length=8) model = ResNet(board_size=9, action_num=82, history_length=8)

View File

@ -25,6 +25,7 @@ def find_correct_moves(own, enemy):
mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom
return mobility return mobility
def calc_flip(pos, own, enemy): def calc_flip(pos, own, enemy):
"""return flip stones of enemy by bitboard when I place stone at pos. """return flip stones of enemy by bitboard when I place stone at pos.
@ -123,8 +124,9 @@ class Reversi:
self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank
self.color = None # 1 for black and -1 for white self.color = None # 1 for black and -1 for white
self.action = None # number in 0~63 self.action = None # number in 0~63
# self.winner = None self.winner = None
self.black_win = None self.black_win = None
self.size = 8
def get_board(self, black=None, white=None): def get_board(self, black=None, white=None):
self.black = black or (0b00001000 << 24 | 0b00010000 << 32) self.black = black or (0b00001000 << 24 | 0b00010000 << 32)
@ -132,22 +134,29 @@ class Reversi:
self.board = self.bitboard2board() self.board = self.bitboard2board()
return self.board return self.board
def is_valid(self, is_next=False):
self.board2bitboard()
own, enemy = self.get_own_and_enemy(is_next)
mobility = find_correct_moves(own, enemy)
valid_moves = bit_to_array(mobility, 64)
valid_moves = np.argwhere(valid_moves)
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
return valid_moves
def simulate_get_mask(self, state, action_set): def simulate_get_mask(self, state, action_set):
history_boards, color = state history_boards, color = state
board = history_boards[-1] board = history_boards[-1]
self.board = board self.board = board
self.color = color self.color = color
self.board2bitboard() valid_moves = self.is_valid()
own, enemy = self.get_own_and_enemy()
mobility = find_correct_moves(own, enemy)
valid_moves = bit_to_array(mobility, 64)
valid_moves = np.argwhere(valid_moves)
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
# TODO it seems that the pass move is not considered # TODO it seems that the pass move is not considered
invalid_action_mask = [] if not len(valid_moves):
for action in action_set: invalid_action_mask = action_set[0:-1]
if action not in valid_moves: else:
invalid_action_mask.append(action) invalid_action_mask = []
for action in action_set:
if action not in valid_moves:
invalid_action_mask.append(action)
return invalid_action_mask return invalid_action_mask
def simulate_step_forward(self, state, action): def simulate_step_forward(self, state, action):
@ -155,21 +164,34 @@ class Reversi:
self.color = state[1] self.color = state[1]
self.board2bitboard() self.board2bitboard()
self.action = action self.action = action
step_forward = self.step() if self.action == 64:
if step_forward: valid_moves = self.is_valid(is_next=True)
new_board = self.bitboard2board() if not len(valid_moves):
return [new_board, 0 - self.color], 0 self._game_over()
return None, self.winner * self.color
else:
return [self.board, 0 - self.color], 0
self.step()
new_board = self.bitboard2board()
return [new_board, 0 - self.color], 0
def executor_do_move(self, board, color, vertex): def executor_do_move(self, board, color, vertex):
self.board = board self.board = board
self.color = color self.color = color
self.board2bitboard() self.board2bitboard()
self.vertex2action(vertex) self.action = self._flatten(vertex)
step_forward = self.step() if self.action == 64:
if step_forward: valid_moves = self.is_valid(is_next=True)
if not len(valid_moves):
return False
else:
return True
else:
self.step()
new_board = self.bitboard2board() new_board = self.bitboard2board()
for i in range(64): for i in range(64):
board[i] = new_board[i] board[i] = new_board[i]
return True
def executor_get_score(self, board): def executor_get_score(self, board):
self.board = board self.board = board
@ -191,13 +213,14 @@ class Reversi:
elif self.board[i] == -1: elif self.board[i] == -1:
self.white |= count self.white |= count
count *= 2 count *= 2
'''
def vertex2action(self, vertex): def vertex2action(self, vertex):
x, y = vertex x, y = vertex
if x == 0 and y == 0: if x == 0 and y == 0:
self.action = None self.action = None
else: else:
self.action = 8 * (x - 1) + y - 1 self.action = 8 * (x - 1) + y - 1
'''
def bitboard2board(self): def bitboard2board(self):
board = [] board = []
@ -214,46 +237,45 @@ class Reversi:
def step(self): def step(self):
if self.action < 0 or self.action > 63: if self.action < 0 or self.action > 63:
raise ValueError("Wrong action!") raise ValueError("Action not in the range of [0,63]!")
if self.action is None: if self.action is None:
return False raise ValueError("Action is None!")
own, enemy = self.get_own_and_enemy() own, enemy = self.get_own_and_enemy()
flipped = calc_flip(self.action, own, enemy) flipped = calc_flip(self.action, own, enemy)
if bit_count(flipped) == 0: if bit_count(flipped) == 0:
self.illegal_move_to_lose(self.action) # self.illegal_move_to_lose(self.action)
return False raise ValueError("Illegal action!")
own ^= flipped own ^= flipped
own |= 1 << self.action own |= 1 << self.action
enemy ^= flipped enemy ^= flipped
self.set_own_and_enemy(own, enemy) self.set_own_and_enemy(own, enemy)
return True
def _game_over(self): def _game_over(self):
# self.done = True # self.done = True
'''
if self.winner is None: if self.winner is None:
black_num, white_num = self.number_of_black_and_white black_num, white_num = self.number_of_black_and_white
if black_num > white_num: self.black_win = black_num - white_num
if self.black_win > 0:
self.winner = 1 self.winner = 1
elif black_num < white_num: elif self.black_win < 0:
self.winner = -1 self.winner = -1
else: else:
self.winner = 0 self.winner = 0
'''
if self.black_win is None:
black_num, white_num = self.number_of_black_and_white
self.black_win = black_num - white_num
def illegal_move_to_lose(self, action): def illegal_move_to_lose(self, action):
self._game_over() self._game_over()
def get_own_and_enemy(self): def get_own_and_enemy(self, is_next=False):
if self.color == 1: if is_next:
color = 0 - self.color
else:
color = self.color
if color == 1:
own, enemy = self.black, self.white own, enemy = self.black, self.white
elif self.color == -1: elif color == -1:
own, enemy = self.white, self.black own, enemy = self.white, self.black
else: else:
own, enemy = None, None own, enemy = None, None
@ -265,6 +287,17 @@ class Reversi:
else: else:
self.white, self.black = own, enemy self.white, self.black = own, enemy
def _deflatten(self, idx):
x = idx // self.size + 1
y = idx % self.size + 1
return (x, y)
def _flatten(self, vertex):
x, y = vertex
if (x == 0) and (y == 0):
return 64
return (x - 1) * self.size + (y - 1)
@property @property
def number_of_black_and_white(self): def number_of_black_and_white(self):
return bit_count(self.black), bit_count(self.white) return bit_count(self.black), bit_count(self.white)

View File

@ -38,6 +38,7 @@ class MCTSNode(object):
def valid_mask(self, simulator): def valid_mask(self, simulator):
pass pass
class UCTNode(MCTSNode): class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, inverse=False): def __init__(self, parent, action, state, action_num, prior, inverse=False):
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
@ -71,10 +72,13 @@ class UCTNode(MCTSNode):
self.parent.backpropagation(self.children[action].reward) self.parent.backpropagation(self.children[action].reward)
def valid_mask(self, simulator): def valid_mask(self, simulator):
# let all invalid actions be illeagel in mcts # let all invalid actions be illegal in mcts
if self.mask is None: if not hasattr(simulator, 'simulate_get_mask'):
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num)) pass
self.ucb[self.mask] = -float("Inf") else:
if self.mask is None:
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
self.ucb[self.mask] = -float("Inf")
class TSNode(MCTSNode): class TSNode(MCTSNode):