Tianshou/AlphaGo/reversi.py

199 lines
6.7 KiB
Python
Raw Normal View History

2017-12-24 00:42:59 +08:00
import numpy as np
2017-12-24 14:40:50 +08:00
import copy
2017-12-24 00:42:59 +08:00
'''
Settings of the Reversi game.
Vertices are 1-indexed (row, column) pairs:
(1, 1) is the upper-left corner of the board and
(size, 1) is the lower-left corner.
The special vertex (0, 0) denotes a pass.
'''
class Reversi:
    """Rules engine for Reversi on a ``size`` x ``size`` board.

    Boards are 2-D integer arrays indexed ``board[row, col]`` (0-based) in
    which ``1`` is a black stone, ``-1`` a white stone and ``0`` an empty
    square.  Actions are flat indices in ``[0, size ** 2 - 1]`` (row-major);
    the extra index ``size ** 2`` denotes a pass.  Vertices are 1-based
    ``(row, col)`` pairs, with ``(0, 0)`` denoting a pass.
    """

    # The eight compass directions a line of stones can run in.  The (0, 0)
    # "direction" is left out on purpose: it never moves off the start square.
    _DIRECTIONS = tuple(
        (dx, dy)
        for dx in (-1, 0, 1)
        for dy in (-1, 0, 1)
        if (dx, dy) != (0, 0)
    )

    def __init__(self, **kwargs):
        """Create the game.

        Keyword Args:
            size: side length of the board; defaults to 8 when omitted.
        """
        self.size = kwargs.get('size', 8)

    def _deflatten(self, idx):
        """Convert a flat action index into a 1-based ``(row, col)`` vertex."""
        x = idx // self.size + 1
        y = idx % self.size + 1
        return (x, y)

    def _flatten(self, vertex):
        """Convert a 1-based vertex into a flat action index.

        The vertex ``(0, 0)`` maps to the pass action ``size ** 2``.
        """
        x, y = vertex
        if (x == 0) and (y == 0):
            return self.size ** 2
        return (x - 1) * self.size + (y - 1)

    def get_board(self):
        """Return a fresh board holding the four starting stones."""
        board = np.zeros([self.size, self.size], dtype=np.int32)
        # Floor division: a float produced by ``/`` is not a valid array index.
        mid = self.size // 2
        board[mid - 1, mid - 1] = -1
        board[mid, mid] = -1
        board[mid - 1, mid] = 1
        board[mid, mid - 1] = 1
        return board

    def _find_correct_moves(self, board, color, is_next=False):
        """List the flat indices of every legal placement.

        Args:
            board: 2-D board array.
            color: colour of the current player (1 or -1).
            is_next: when True, list the moves of the opponent of ``color``.

        Returns:
            Legal action indices in ascending order (never includes the pass).
        """
        new_color = -color if is_next else color
        moves = []
        for i in range(self.size ** 2):
            x, y = self._deflatten(i)
            if self._is_valid(board, x - 1, y - 1, new_color):
                moves.append(i)
        return moves

    def _one_direction_valid(self, board, x, y, color):
        """Return True when ``(x, y)`` is on the board and holds ``color``."""
        on_board = (0 <= x < self.size) and (0 <= y < self.size)
        return on_board and bool(board[x, y] == color)

    def _captured_stones(self, board, x, y, x_direction, y_direction, color):
        """Collect the opponent stones bracketed along one direction.

        Walks from the square next to ``(x, y)`` over consecutive opponent
        stones.  The run counts only if it is non-empty and is immediately
        followed by a stone of ``color``; otherwise nothing is captured.

        Returns:
            List of 0-based ``(row, col)`` squares that would be flipped.
        """
        captured = []
        new_x = x + x_direction
        new_y = y + y_direction
        while self._one_direction_valid(board, new_x, new_y, -color):
            captured.append((new_x, new_y))
            new_x += x_direction
            new_y += y_direction
        if captured and self._one_direction_valid(board, new_x, new_y, color):
            return captured
        return []

    def _is_valid(self, board, x, y, color):
        """Return True when ``color`` may legally play on 0-based ``(x, y)``."""
        if board[x, y]:
            # Occupied squares can never be played.
            return False
        return any(
            self._captured_stones(board, x, y, dx, dy, color)
            for dx, dy in self._DIRECTIONS
        )

    def simulate_get_mask(self, state, action_set):
        """Return the actions in ``action_set`` that are currently illegal.

        Args:
            state: ``[history_boards, color]``; only the last board is read.
            action_set: candidate actions, with the pass action last.

        Returns:
            Every action except the last one when the player has no legal
            placement (only the pass stays allowed); otherwise every action
            that is not a legal placement (which includes the pass).
        """
        # Read-only use of the state, so no defensive copy is required.
        history_boards, color = state
        valid_moves = self._find_correct_moves(history_boards[-1], color)
        if not valid_moves:
            return action_set[0:-1]
        legal = set(valid_moves)
        return [action for action in action_set if action not in legal]

    def simulate_step_forward(self, state, action):
        """Apply ``action`` to a copy of ``state``.

        Args:
            state: ``[history_boards, color]``; it is not modified.
            action: flat action index, or ``size ** 2`` for a pass.

        Returns:
            ``(new_state, reward)``.  ``new_state`` is ``None`` when a pass is
            played while the opponent has no legal placement; the reward is
            then the winner multiplied by ``color``.  Otherwise the reward is
            0 and ``new_state`` is ``[history_boards, -color]``.

        Raises:
            ValueError: if the placement is illegal.
        """
        history_boards, color = copy.deepcopy(state)
        # Separate copy: the move is applied in place and must not alter the
        # previous entry of the history.
        board = copy.deepcopy(history_boards[-1])
        if action == self.size ** 2:
            if not self._find_correct_moves(board, color, is_next=True):
                winner = self._get_winner(board)
                return None, winner * color
            return [history_boards, -color], 0
        new_board = self._step(board, color, action)
        history_boards.append(new_board)
        return [history_boards, -color], 0

    def simulate_hashable_conversion(self, state):
        """Return a hashable key for ``state`` built from its latest board."""
        # Only the last board of the history is used for hashing.
        return tuple(state[0][-1].flatten().tolist())

    def _get_winner(self, board):
        """Return 1 if black has more stones, -1 if white has, 0 on a tie."""
        black_num, white_num = self._number_of_black_and_white(board)
        if black_num > white_num:
            return 1
        if black_num < white_num:
            return -1
        return 0

    def _number_of_black_and_white(self, board):
        """Return ``(black_count, white_count)`` for ``board``."""
        # Reshaping also checks that the board has exactly size ** 2 squares.
        board_list = np.reshape(board, self.size ** 2)
        black_num = int(np.count_nonzero(board_list == 1))
        white_num = int(np.count_nonzero(board_list == -1))
        return black_num, white_num

    def _step(self, board, color, action):
        """Play the placement ``action`` for ``color`` on ``board`` in place.

        Raises:
            ValueError: if ``action`` is None, out of range, or illegal.
        """
        # None must be rejected first: ordering comparisons against None
        # raise TypeError rather than the ValueError promised here.
        if action is None:
            raise ValueError("Action is None!")
        last_action = self.size ** 2 - 1
        if action < 0 or action > last_action:
            raise ValueError(
                "Action not in the range of [0,%d]!" % last_action)
        x, y = self._deflatten(action)
        return self._flip(board, x - 1, y - 1, color)

    def _flip(self, board, x, y, color):
        """Place ``color`` on 0-based ``(x, y)`` and flip captured stones.

        The board is modified in place and returned.  Validation happens
        before any write, so a rejected move leaves ``board`` untouched.

        Raises:
            ValueError: if the square is occupied or nothing is captured.
        """
        if board[x, y]:
            raise ValueError("Invalid action")
        captured = []
        for dx, dy in self._DIRECTIONS:
            captured.extend(self._captured_stones(board, x, y, dx, dy, color))
        if not captured:
            raise ValueError("Invalid action")
        board[x, y] = color
        for flip_x, flip_y in captured:
            board[flip_x, flip_y] = color
        return board

    def executor_do_move(self, history, latest_boards, board, color, vertex):
        """Play ``vertex`` for ``color`` and record the resulting board.

        Args:
            history: list that receives the new board after a placement.
            latest_boards: second list that receives the same new board.
            board: current board, in any shape reshapeable to
                ``(size, size)``.  NOTE(review): ``np.reshape`` may return a
                view of an ndarray argument, in which case the caller's
                array is updated in place - confirm callers expect this.
            color: colour to play (1 or -1).
            vertex: 1-based ``(row, col)``, or ``(0, 0)`` to pass.

        Returns:
            For a pass: True if the opponent has a legal placement, else
            False.  For a placement: True.

        Raises:
            ValueError: if the placement is illegal.
        """
        board = np.reshape(board, (self.size, self.size))
        action = self._flatten(vertex)
        if action == self.size ** 2:
            return bool(self._find_correct_moves(board, color, is_next=True))
        new_board = self._step(board, color, action)
        history.append(new_board)
        latest_boards.append(new_board)
        return True

    def executor_get_score(self, board):
        """Return the winner of ``board``: 1 (black), -1 (white) or 0 (tie)."""
        return self._get_winner(board)
2017-12-24 00:42:59 +08:00
if __name__ == "__main__":
    # Manual smoke test on the standard 8x8 board.  The board size is passed
    # explicitly because the constructor reads it from its keyword arguments.
    reversi = Reversi(size=8)
    # Example session (a state is [list_of_history_boards, color]):
    # board = reversi.get_board()
    # print(board)
    # state, value = reversi.simulate_step_forward([[board], -1], 20)
    # print(state[0][-1])
    # print("board")
    # print(board)
    # r = reversi.executor_get_score(board)
    # print(r)