connect reversi with game
This commit is contained in:
parent
5b044c9a0c
commit
e63338ab01
@ -183,7 +183,7 @@ class GTPEngine():
|
||||
return 'unknown player', False
|
||||
|
||||
def cmd_get_score(self, args, **kwargs):
|
||||
return self._game.game_engine.executor_get_score(self._game.board, True), True
|
||||
return self._game.game_engine.executor_get_score(self._game.board), True
|
||||
|
||||
def cmd_show_board(self, args, **kwargs):
|
||||
return self._game.board, True
|
||||
@ -194,4 +194,4 @@ class GTPEngine():
|
||||
|
||||
if __name__ == "main":
|
||||
game = Game()
|
||||
engine = GTPEngine(game_obj=Game)
|
||||
engine = GTPEngine(game_obj=game)
|
||||
|
@ -10,12 +10,14 @@ import copy
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import sys, os
|
||||
import go
|
||||
import model
|
||||
from collections import deque
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
||||
from tianshou.core.mcts.mcts import MCTS
|
||||
|
||||
import go
|
||||
import reversi
|
||||
|
||||
class Game:
|
||||
'''
|
||||
Load the real game and trained weights.
|
||||
@ -23,18 +25,26 @@ class Game:
|
||||
TODO : Maybe merge with the engine class in future,
|
||||
currently leave it untouched for interacting with Go UI.
|
||||
'''
|
||||
def __init__(self, size=9, komi=3.75, checkpoint_path=None):
|
||||
self.size = size
|
||||
self.komi = komi
|
||||
def __init__(self, name="go", checkpoint_path=None):
|
||||
self.name = name
|
||||
if "go" == name:
|
||||
self.size = 9
|
||||
self.komi = 3.75
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
self.history = []
|
||||
self.latest_boards = deque(maxlen=8)
|
||||
for _ in range(8):
|
||||
self.latest_boards.append(self.board)
|
||||
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8, checkpoint_path=checkpoint_path)
|
||||
# self.evaluator = lambda state: self.sess.run([tf.nn.softmax(self.net.p), self.net.v],
|
||||
# feed_dict={self.net.x: state, self.net.is_training: False})
|
||||
|
||||
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=8)
|
||||
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
||||
elif "reversi" == name:
|
||||
self.size = 8
|
||||
self.evaluator = model.ResNet(self.size, self.size**2 + 1, history_length=1)
|
||||
self.game_engine = reversi.Reversi()
|
||||
self.board = self.game_engine.get_board()
|
||||
else:
|
||||
print(name + " is an unknown game...")
|
||||
|
||||
def clear(self):
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
@ -65,7 +75,11 @@ class Game:
|
||||
# this function can be called directly to play the opponent's move
|
||||
if vertex == utils.PASS:
|
||||
return True
|
||||
# TODO this implementation is not very elegant
|
||||
if "go" == self.name:
|
||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||
elif "revsersi" == self.name:
|
||||
res = self.game_engine.executor_do_move(self.board, color, vertex)
|
||||
return res
|
||||
|
||||
def think_play_move(self, color):
|
||||
@ -96,7 +110,7 @@ class Game:
|
||||
sys.stdout.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
g = Game(checkpoint_path='./checkpoints/')
|
||||
g = Game()
|
||||
g.show_board()
|
||||
g.think_play_move(1)
|
||||
#file = open("debug.txt", "a")
|
||||
|
@ -157,7 +157,7 @@ class Go:
|
||||
vertex = self._deflatten(action)
|
||||
return vertex
|
||||
|
||||
def _is_valid(self, history_boards, current_board, color, vertex):
|
||||
def _rule_check(self, history_boards, current_board, color, vertex):
|
||||
### in board
|
||||
if not self._in_board(vertex):
|
||||
return False
|
||||
@ -176,30 +176,30 @@ class Go:
|
||||
|
||||
return True
|
||||
|
||||
def simulate_is_valid(self, state, action):
|
||||
def _is_valid(self, state, action):
|
||||
history_boards, color = state
|
||||
vertex = self._action2vertex(action)
|
||||
current_board = history_boards[-1]
|
||||
|
||||
if not self._is_valid(history_boards, current_board, color, vertex):
|
||||
if not self._rule_check(history_boards, current_board, color, vertex):
|
||||
return False
|
||||
|
||||
if not self._knowledge_prunning(current_board, color, vertex):
|
||||
return False
|
||||
return True
|
||||
|
||||
def simulate_is_valid_list(self, state, action_set):
|
||||
def simulate_get_mask(self, state, action_set):
|
||||
# find all the invalid actions
|
||||
invalid_action_list = []
|
||||
invalid_action_mask = []
|
||||
for action_candidate in action_set[:-1]:
|
||||
# go through all the actions excluding pass
|
||||
if not self.simulate_is_valid(state, action_candidate):
|
||||
invalid_action_list.append(action_candidate)
|
||||
if len(invalid_action_list) < len(action_set) - 1:
|
||||
invalid_action_list.append(action_set[-1])
|
||||
if not self._is_valid(state, action_candidate):
|
||||
invalid_action_mask.append(action_candidate)
|
||||
if len(invalid_action_mask) < len(action_set) - 1:
|
||||
invalid_action_mask.append(action_set[-1])
|
||||
# forbid pass, if we have other choices
|
||||
# TODO: In fact we should not do this. In some extreme cases, we should permit pass.
|
||||
return invalid_action_list
|
||||
return invalid_action_mask
|
||||
|
||||
def _do_move(self, board, color, vertex):
|
||||
if vertex == utils.PASS:
|
||||
@ -219,7 +219,7 @@ class Go:
|
||||
return [history_boards, new_color], 0
|
||||
|
||||
def executor_do_move(self, history, latest_boards, current_board, color, vertex):
|
||||
if not self._is_valid(history, current_board, color, vertex):
|
||||
if not self._rule_check(history, current_board, color, vertex):
|
||||
return False
|
||||
current_board[self._flatten(vertex)] = color
|
||||
self._process_board(current_board, color, vertex)
|
||||
@ -280,7 +280,7 @@ class Go:
|
||||
elif color_estimate < 0:
|
||||
return utils.WHITE
|
||||
|
||||
def executor_get_score(self, current_board, is_unknown_estimation=False):
|
||||
def executor_get_score(self, current_board):
|
||||
'''
|
||||
is_unknown_estimation: whether use nearby stone to predict the unknown
|
||||
return score from BLACK perspective.
|
||||
@ -294,10 +294,8 @@ class Go:
|
||||
_board[self._flatten(vertex)] = utils.BLACK
|
||||
elif boarder_color == {utils.WHITE}:
|
||||
_board[self._flatten(vertex)] = utils.WHITE
|
||||
elif is_unknown_estimation:
|
||||
_board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex)
|
||||
else:
|
||||
_board[self._flatten(vertex)] =utils.UNKNOWN
|
||||
_board[self._flatten(vertex)] = self._predict_from_nearby(_board, vertex)
|
||||
score = 0
|
||||
for i in _board:
|
||||
if i == utils.BLACK:
|
||||
|
@ -7,7 +7,6 @@ import time
|
||||
import os
|
||||
import cPickle
|
||||
|
||||
|
||||
class Data(object):
|
||||
def __init__(self):
|
||||
self.boards = []
|
||||
|
@ -25,7 +25,6 @@ def find_correct_moves(own, enemy):
|
||||
mobility |= search_offset_right(own, enemy, mask, 7) # Left bottom
|
||||
return mobility
|
||||
|
||||
|
||||
def calc_flip(pos, own, enemy):
|
||||
"""return flip stones of enemy by bitboard when I place stone at pos.
|
||||
|
||||
@ -133,7 +132,9 @@ class Reversi:
|
||||
self.board = self.bitboard2board()
|
||||
return self.board
|
||||
|
||||
def simulate_is_valid(self, board, color):
|
||||
def simulate_get_mask(self, state, action_set):
|
||||
history_boards, color = state
|
||||
board = history_boards[-1]
|
||||
self.board = board
|
||||
self.color = color
|
||||
self.board2bitboard()
|
||||
@ -142,13 +143,18 @@ class Reversi:
|
||||
valid_moves = bit_to_array(mobility, 64)
|
||||
valid_moves = np.argwhere(valid_moves)
|
||||
valid_moves = list(np.reshape(valid_moves, len(valid_moves)))
|
||||
return valid_moves
|
||||
# TODO it seems that the pass move is not considered
|
||||
invalid_action_mask = []
|
||||
for action in action_set:
|
||||
if action not in valid_moves:
|
||||
invalid_action_mask.append(action)
|
||||
return invalid_action_mask
|
||||
|
||||
def simulate_step_forward(self, state, vertex):
|
||||
def simulate_step_forward(self, state, action):
|
||||
self.board = state[0]
|
||||
self.color = state[1]
|
||||
self.board2bitboard()
|
||||
self.vertex2action(vertex)
|
||||
self.action = action
|
||||
step_forward = self.step()
|
||||
if step_forward:
|
||||
new_board = self.bitboard2board()
|
||||
|
@ -79,7 +79,7 @@ while True:
|
||||
prob.append(np.array(game.prob).reshape(-1, game.size ** 2 + 1))
|
||||
print("Finished")
|
||||
print("\n")
|
||||
score = game.game_engine.executor_get_score(game.board, True)
|
||||
score = game.game_engine.executor_get_score(game.board)
|
||||
if score > 0:
|
||||
winner = utils.BLACK
|
||||
else:
|
||||
|
@ -73,7 +73,7 @@ class UCTNode(MCTSNode):
|
||||
def valid_mask(self, simulator):
|
||||
# let all invalid actions be illeagel in mcts
|
||||
if self.mask is None:
|
||||
self.mask = simulator.simulate_is_valid_list(self.state, range(self.action_num))
|
||||
self.mask = simulator.simulate_get_mask(self.state, range(self.action_num))
|
||||
self.ucb[self.mask] = -float("Inf")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user