Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
b2b2d01d9c
@ -212,11 +212,14 @@ class Go:
|
||||
def simulate_step_forward(self, state, action):
|
||||
# initialize the simulate_board from state
|
||||
history_boards, color = state
|
||||
vertex = self._action2vertex(action)
|
||||
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
||||
history_boards.append(new_board)
|
||||
new_color = -color
|
||||
return [history_boards, new_color], 0
|
||||
if history_boards[-1] == history_boards[-2] and action is utils.PASS:
|
||||
return None, 2 * (float(self.executor_get_score(history_boards[-1]) > 0)-0.5) * color
|
||||
else:
|
||||
vertex = self._action2vertex(action)
|
||||
new_board = self._do_move(copy.copy(history_boards[-1]), color, vertex)
|
||||
history_boards.append(new_board)
|
||||
new_color = -color
|
||||
return [history_boards, new_color], 0
|
||||
|
||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||
if not self._rule_check(history, current_board, color, vertex):
|
||||
|
@ -1,7 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import sys
|
||||
import copy
|
||||
import cPickle
|
||||
from collections import deque
|
||||
|
||||
@ -224,11 +223,21 @@ class ResNet(object):
|
||||
else:
|
||||
start_time = time.time()
|
||||
for i in range(batch_size):
|
||||
game_num = random.randint(0, self.window_length-1)
|
||||
state_num = random.randint(0, self.training_data['length'][game_num]-1)
|
||||
training_data['states'].append(np.expand_dims(self.training_data['states'][game_num][state_num], 0))
|
||||
training_data['probs'].append(np.expand_dims(self.training_data['probs'][game_num][state_num], 0))
|
||||
training_data['winner'].append(np.expand_dims(self.training_data['winner'][game_num][state_num], 0))
|
||||
priority = self.training_data['length'] / sum(self.training_data['length'])
|
||||
game_num = np.random.choice(self.window_length, 1, p=priority)
|
||||
state_num = np.random.randint(self.training_data['length'][game_num])
|
||||
rotate_times = np.random.randint(4)
|
||||
reflect_times = np.random.randint(2)
|
||||
reflect_orientation = np.random.randint(2)
|
||||
training_data['states'].append(
|
||||
self._preprocession(self.training_data['states'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['probs'].append(
|
||||
self._preprocession(self.training_data['probs'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
training_data['winner'].append(
|
||||
self._preprocession(self.training_data['winner'][game_num][state_num], reflect_times,
|
||||
reflect_orientation, rotate_times))
|
||||
value_loss, policy_loss, reg, _ = self.sess.run(
|
||||
[self.value_loss, self.policy_loss, self.reg, self.train_op],
|
||||
feed_dict={self.x: np.concatenate(training_data['states'], axis=0),
|
||||
@ -280,6 +289,55 @@ class ResNet(object):
|
||||
winner = np.concatenate(winner, axis=0)
|
||||
return states, probs, winner
|
||||
|
||||
def _preprocession(self, board, reflect_times=0, reflect_orientation=0, rotate_times=0):
|
||||
"""
|
||||
preprocessing for augmentation
|
||||
|
||||
:param board: a ndarray, board to process
|
||||
:param reflect_times: an integer, how many times to reflect
|
||||
:param reflect_orientation: an integer, which orientation to reflect
|
||||
:param rotate_times: an integer, how many times to rotate
|
||||
:return:
|
||||
"""
|
||||
|
||||
new_board = copy.copy(board)
|
||||
if new_board.ndim == 3:
|
||||
np.expand_dims(new_board, axis=0)
|
||||
|
||||
new_board = self._board_reflection(new_board, reflect_times, reflect_orientation)
|
||||
new_board = self._board_rotation(new_board, rotate_times)
|
||||
|
||||
return new_board
|
||||
|
||||
def _board_rotation(self, board, times):
|
||||
"""
|
||||
rotate the board for augmentation
|
||||
note that board's shape should be [batch_size, board_size, board_size, channels]
|
||||
|
||||
:param board: a ndarray, shape [batch_size, board_size, board_size, channels]
|
||||
:param times: an integer, how many times to rotate
|
||||
:return:
|
||||
"""
|
||||
return np.rot90(board, times, (1, 2))
|
||||
|
||||
def _board_reflection(self, board, times, orientation):
|
||||
"""
|
||||
reflect the board for augmentation
|
||||
note that board's shape should be [batch_size, board_size, board_size, channels]
|
||||
|
||||
:param board: a ndarray, shape [batch_size, board_size, board_size, channels]
|
||||
:param times: an integer, how many times to reflect
|
||||
:param orientation: an integer, which orientation to reflect
|
||||
:return:
|
||||
"""
|
||||
new_board = copy.copy(board)
|
||||
for _ in range(times):
|
||||
if orientation == 0:
|
||||
new_board = new_board[:, ::-1]
|
||||
if orientation == 1:
|
||||
new_board = new_board[:, :, ::-1]
|
||||
return new_board
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = ResNet(board_size=9, action_num=82, history_length=8)
|
||||
|
@ -25,6 +25,7 @@ def find_correct_moves(own, enemy):
|
||||
mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom
|
||||
return mobility
|
||||
|
||||
|
||||
def calc_flip(pos, own, enemy):
|
||||
"""return flip stones of enemy by bitboard when I place stone at pos.
|
||||
|
||||
@ -123,8 +124,9 @@ class Reversi:
|
||||
self.board = None # 8 * 8 board with 1 for black, -1 for white and 0 for blank
|
||||
self.color = None # 1 for black and -1 for white
|
||||
self.action = None # number in 0~63
|
||||
# self.winner = None
|
||||
self.winner = None
|
||||
self.black_win = None
|
||||
self.size = 8
|
||||
|
||||
def get_board(self, black=None, white=None):
|
||||
self.black = black or (0b00001000 << 24 | 0b00010000 << 32)
|
||||
@ -132,22 +134,29 @@ class Reversi:
|
||||
self.board = self.bitboard2board()
|
||||
return self.board
|
||||
|
||||
def is_valid(self, is_next=False):
|
||||
self.board2bitboard()
|
||||
own, enemy = self.get_own_and_enemy(is_next)
|
||||
mobility = find_correct_moves(own, enemy)
|
||||
valid_moves = bit_to_array(mobility, 64)
|
||||
valid_moves = np.argwhere(valid_moves)
|
||||
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
|
||||
return valid_moves
|
||||
|
||||
def simulate_get_mask(self, state, action_set):
|
||||
history_boards, color = state
|
||||
board = history_boards[-1]
|
||||
self.board = board
|
||||
self.color = color
|
||||
self.board2bitboard()
|
||||
own, enemy = self.get_own_and_enemy()
|
||||
mobility = find_correct_moves(own, enemy)
|
||||
valid_moves = bit_to_array(mobility, 64)
|
||||
valid_moves = np.argwhere(valid_moves)
|
||||
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
|
||||
valid_moves = self.is_valid()
|
||||
# TODO it seems that the pass move is not considered
|
||||
invalid_action_mask = []
|
||||
for action in action_set:
|
||||
if action not in valid_moves:
|
||||
invalid_action_mask.append(action)
|
||||
if not len(valid_moves):
|
||||
invalid_action_mask = action_set[0:-1]
|
||||
else:
|
||||
invalid_action_mask = []
|
||||
for action in action_set:
|
||||
if action not in valid_moves:
|
||||
invalid_action_mask.append(action)
|
||||
return invalid_action_mask
|
||||
|
||||
def simulate_step_forward(self, state, action):
|
||||
@ -155,21 +164,34 @@ class Reversi:
|
||||
self.color = state[1]
|
||||
self.board2bitboard()
|
||||
self.action = action
|
||||
step_forward = self.step()
|
||||
if step_forward:
|
||||
new_board = self.bitboard2board()
|
||||
return [new_board, 0 - self.color], 0
|
||||
if self.action == 64:
|
||||
valid_moves = self.is_valid(is_next=True)
|
||||
if not len(valid_moves):
|
||||
self._game_over()
|
||||
return None, self.winner * self.color
|
||||
else:
|
||||
return [self.board, 0 - self.color], 0
|
||||
self.step()
|
||||
new_board = self.bitboard2board()
|
||||
return [new_board, 0 - self.color], 0
|
||||
|
||||
def executor_do_move(self, board, color, vertex):
|
||||
self.board = board
|
||||
self.color = color
|
||||
self.board2bitboard()
|
||||
self.vertex2action(vertex)
|
||||
step_forward = self.step()
|
||||
if step_forward:
|
||||
self.action = self._flatten(vertex)
|
||||
if self.action == 64:
|
||||
valid_moves = self.is_valid(is_next=True)
|
||||
if not len(valid_moves):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
self.step()
|
||||
new_board = self.bitboard2board()
|
||||
for i in range(64):
|
||||
board[i] = new_board[i]
|
||||
for i in range(64):
|
||||
board[i] = new_board[i]
|
||||
return True
|
||||
|
||||
def executor_get_score(self, board):
|
||||
self.board = board
|
||||
@ -191,13 +213,14 @@ class Reversi:
|
||||
elif self.board[i] == -1:
|
||||
self.white |= count
|
||||
count *= 2
|
||||
|
||||
'''
|
||||
def vertex2action(self, vertex):
|
||||
x, y = vertex
|
||||
if x == 0 and y == 0:
|
||||
self.action = None
|
||||
else:
|
||||
self.action = 8 * (x - 1) + y - 1
|
||||
'''
|
||||
|
||||
def bitboard2board(self):
|
||||
board = []
|
||||
@ -214,46 +237,45 @@ class Reversi:
|
||||
|
||||
def step(self):
|
||||
if self.action < 0 or self.action > 63:
|
||||
raise ValueError("Wrong action!")
|
||||
raise ValueError("Action not in the range of [0,63]!")
|
||||
if self.action is None:
|
||||
return False
|
||||
raise ValueError("Action is None!")
|
||||
|
||||
own, enemy = self.get_own_and_enemy()
|
||||
|
||||
flipped = calc_flip(self.action, own, enemy)
|
||||
if bit_count(flipped) == 0:
|
||||
self.illegal_move_to_lose(self.action)
|
||||
return False
|
||||
# self.illegal_move_to_lose(self.action)
|
||||
raise ValueError("Illegal action!")
|
||||
own ^= flipped
|
||||
own |= 1 << self.action
|
||||
enemy ^= flipped
|
||||
|
||||
self.set_own_and_enemy(own, enemy)
|
||||
return True
|
||||
|
||||
def _game_over(self):
|
||||
# self.done = True
|
||||
'''
|
||||
|
||||
if self.winner is None:
|
||||
black_num, white_num = self.number_of_black_and_white
|
||||
if black_num > white_num:
|
||||
self.black_win = black_num - white_num
|
||||
if self.black_win > 0:
|
||||
self.winner = 1
|
||||
elif black_num < white_num:
|
||||
elif self.black_win < 0:
|
||||
self.winner = -1
|
||||
else:
|
||||
self.winner = 0
|
||||
'''
|
||||
if self.black_win is None:
|
||||
black_num, white_num = self.number_of_black_and_white
|
||||
self.black_win = black_num - white_num
|
||||
|
||||
def illegal_move_to_lose(self, action):
|
||||
self._game_over()
|
||||
|
||||
def get_own_and_enemy(self):
|
||||
if self.color == 1:
|
||||
def get_own_and_enemy(self, is_next=False):
|
||||
if is_next:
|
||||
color = 0 - self.color
|
||||
else:
|
||||
color = self.color
|
||||
if color == 1:
|
||||
own, enemy = self.black, self.white
|
||||
elif self.color == -1:
|
||||
elif color == -1:
|
||||
own, enemy = self.white, self.black
|
||||
else:
|
||||
own, enemy = None, None
|
||||
@ -265,6 +287,17 @@ class Reversi:
|
||||
else:
|
||||
self.white, self.black = own, enemy
|
||||
|
||||
def _deflatten(self, idx):
|
||||
x = idx // self.size + 1
|
||||
y = idx % self.size + 1
|
||||
return (x, y)
|
||||
|
||||
def _flatten(self, vertex):
|
||||
x, y = vertex
|
||||
if (x == 0) and (y == 0):
|
||||
return 64
|
||||
return (x - 1) * self.size + (y - 1)
|
||||
|
||||
@property
|
||||
def number_of_black_and_white(self):
|
||||
return bit_count(self.black), bit_count(self.white)
|
||||
|
@ -38,6 +38,7 @@ class MCTSNode(object):
|
||||
def valid_mask(self, simulator):
|
||||
pass
|
||||
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
@ -71,10 +72,13 @@ class UCTNode(MCTSNode):
|
||||
self.parent.backpropagation(self.children[action].reward)
|
||||
|
||||
def valid_mask(self, simulator):
|
||||
# let all invalid actions be illeagel in mcts
|
||||
if self.mask is None:
|
||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||
self.ucb[self.mask] = -float("Inf")
|
||||
# let all invalid actions be illegal in mcts
|
||||
if not hasattr(simulator, 'simulate_get_mask'):
|
||||
pass
|
||||
else:
|
||||
if self.mask is None:
|
||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||
self.ucb[self.mask] = -float("Inf")
|
||||
|
||||
|
||||
class TSNode(MCTSNode):
|
||||
|
Loading…
x
Reference in New Issue
Block a user