modify game.py for multi-player
This commit is contained in:
parent
eb0ce95919
commit
c2775df8e6
@ -12,6 +12,7 @@ import numpy as np
|
||||
import sys, os
|
||||
import model
|
||||
from collections import deque
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
||||
from tianshou.core.mcts.mcts import MCTS
|
||||
|
||||
@ -19,6 +20,7 @@ import go
|
||||
import reversi
|
||||
import time
|
||||
|
||||
|
||||
class Game:
|
||||
'''
|
||||
Load the real game and trained weights.
|
||||
@ -26,10 +28,11 @@ class Game:
|
||||
TODO: Maybe merge with the engine class in the future;
|
||||
for now it is left untouched so that it can interact with the Go UI.
|
||||
'''
|
||||
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
|
||||
|
||||
def __init__(self, name=None, role=None, debug=False, black_checkpoint_path=None, white_checkpoint_path=None):
|
||||
self.name = name
|
||||
if role is None:
|
||||
raise ValueError("Need a role!")
|
||||
if role is None:
|
||||
raise ValueError("Need a role!")
|
||||
self.role = role
|
||||
self.debug = debug
|
||||
if self.name == "go":
|
||||
@ -49,8 +52,9 @@ class Game:
|
||||
else:
|
||||
raise ValueError(name + " is an unknown game...")
|
||||
|
||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
||||
checkpoint_path=checkpoint_path)
|
||||
self.model = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
||||
black_checkpoint_path=black_checkpoint_path,
|
||||
white_checkpoint_path=white_checkpoint_path)
|
||||
self.latest_boards = deque(maxlen=self.history_length)
|
||||
for _ in range(self.history_length):
|
||||
self.latest_boards.append(self.board)
|
||||
@ -72,7 +76,12 @@ class Game:
|
||||
self.komi = k
|
||||
|
||||
def think(self, latest_boards, color):
|
||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
||||
if color == +1:
|
||||
role = 'black'
|
||||
if color == -1:
|
||||
role = 'white'
|
||||
evaluator = lambda state:self.model(role, state)
|
||||
mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
|
||||
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
||||
mcts.search(max_step=100)
|
||||
if self.debug:
|
||||
@ -98,7 +107,8 @@ class Game:
|
||||
if self.name == "reversi":
|
||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||
if self.name == "go":
|
||||
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
|
||||
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board,
|
||||
color, vertex)
|
||||
return res
|
||||
|
||||
def think_play_move(self, color):
|
||||
@ -128,8 +138,8 @@ class Game:
|
||||
print('')
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: build a reversi game playing black and let it think/play one move.
    # Game.__init__ now takes one checkpoint path per colour
    # (black_checkpoint_path / white_checkpoint_path); the former single
    # `checkpoint_path` keyword is no longer accepted and would raise TypeError.
    game = Game(name="reversi", role="black", black_checkpoint_path=None, white_checkpoint_path=None)
    game.debug = True
    game.think_play_move(utils.BLACK)
|
||||
|
||||
|
@ -168,7 +168,7 @@ class ResNet(object):
|
||||
self.__setattr__(scope + '_saver',
|
||||
tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
|
||||
|
||||
def __call__(self, state):
|
||||
def __call__(self, role, state):
|
||||
"""
|
||||
|
||||
:param history: a list, the history
|
||||
@ -184,10 +184,10 @@ class ResNet(object):
|
||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||
self.history_length))
|
||||
eval_state = self._history2state(history, color)
|
||||
if color == +1:
|
||||
if role == 'black':
|
||||
return self.sess.run([self.black_prob, self.black_v],
|
||||
feed_dict={self.x: eval_state, self.is_training: False})
|
||||
if color == -1:
|
||||
if role == 'white':
|
||||
return self.sess.run([self.white_prob, self.white_v],
|
||||
feed_dict={self.x: eval_state, self.is_training: False})
|
||||
|
||||
|
@ -24,6 +24,7 @@ class Data(object):
|
||||
# Reset this instance in place by calling its own __init__ again with no arguments.
def reset(self):
|
||||
self.__init__()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
Start two different players that load network weights, in order to evaluate the winning ratio.
|
||||
|
Loading…
x
Reference in New Issue
Block a user