modify game.py for multi-player
This commit is contained in:
parent
eb0ce95919
commit
c2775df8e6
@ -12,6 +12,7 @@ import numpy as np
|
|||||||
import sys, os
|
import sys, os
|
||||||
import model
|
import model
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
sys.path.append(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
||||||
from tianshou.core.mcts.mcts import MCTS
|
from tianshou.core.mcts.mcts import MCTS
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ import go
|
|||||||
import reversi
|
import reversi
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
'''
|
'''
|
||||||
Load the real game and trained weights.
|
Load the real game and trained weights.
|
||||||
@ -26,10 +28,11 @@ class Game:
|
|||||||
TODO : Maybe merge with the engine class in future,
|
TODO : Maybe merge with the engine class in future,
|
||||||
currently leave it untouched for interacting with Go UI.
|
currently leave it untouched for interacting with Go UI.
|
||||||
'''
|
'''
|
||||||
def __init__(self, name=None, role=None, debug=False, checkpoint_path=None):
|
|
||||||
|
def __init__(self, name=None, role=None, debug=False, black_checkpoint_path=None, white_checkpoint_path=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
if role is None:
|
if role is None:
|
||||||
raise ValueError("Need a role!")
|
raise ValueError("Need a role!")
|
||||||
self.role = role
|
self.role = role
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
if self.name == "go":
|
if self.name == "go":
|
||||||
@ -49,8 +52,9 @@ class Game:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(name + " is an unknown game...")
|
raise ValueError(name + " is an unknown game...")
|
||||||
|
|
||||||
self.evaluator = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
self.model = model.ResNet(self.size, self.size ** 2 + 1, history_length=self.history_length,
|
||||||
checkpoint_path=checkpoint_path)
|
black_checkpoint_path=black_checkpoint_path,
|
||||||
|
white_checkpoint_path=white_checkpoint_path)
|
||||||
self.latest_boards = deque(maxlen=self.history_length)
|
self.latest_boards = deque(maxlen=self.history_length)
|
||||||
for _ in range(self.history_length):
|
for _ in range(self.history_length):
|
||||||
self.latest_boards.append(self.board)
|
self.latest_boards.append(self.board)
|
||||||
@ -72,7 +76,12 @@ class Game:
|
|||||||
self.komi = k
|
self.komi = k
|
||||||
|
|
||||||
def think(self, latest_boards, color):
|
def think(self, latest_boards, color):
|
||||||
mcts = MCTS(self.game_engine, self.evaluator, [latest_boards, color],
|
if color == +1:
|
||||||
|
role = 'black'
|
||||||
|
if color == -1:
|
||||||
|
role = 'white'
|
||||||
|
evaluator = lambda state:self.model(role, state)
|
||||||
|
mcts = MCTS(self.game_engine, evaluator, [latest_boards, color],
|
||||||
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
self.size ** 2 + 1, role=self.role, debug=self.debug, inverse=True)
|
||||||
mcts.search(max_step=100)
|
mcts.search(max_step=100)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
@ -98,7 +107,8 @@ class Game:
|
|||||||
if self.name == "reversi":
|
if self.name == "reversi":
|
||||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||||
if self.name == "go":
|
if self.name == "go":
|
||||||
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board, color, vertex)
|
res = self.game_engine.executor_do_move(self.history, self.history_set, self.latest_boards, self.board,
|
||||||
|
color, vertex)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def think_play_move(self, color):
|
def think_play_move(self, color):
|
||||||
@ -128,8 +138,8 @@ class Game:
|
|||||||
print('')
|
print('')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
game = Game(name="reversi", role="black", checkpoint_path=None)
|
game = Game(name="reversi", role="black", checkpoint_path=None)
|
||||||
game.debug = True
|
game.debug = True
|
||||||
game.think_play_move(utils.BLACK)
|
game.think_play_move(utils.BLACK)
|
||||||
|
|
||||||
|
@ -168,7 +168,7 @@ class ResNet(object):
|
|||||||
self.__setattr__(scope + '_saver',
|
self.__setattr__(scope + '_saver',
|
||||||
tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
|
tf.train.Saver(max_to_keep=0, var_list=self.__getattribute__(scope + '_var_list')))
|
||||||
|
|
||||||
def __call__(self, state):
|
def __call__(self, role, state):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param history: a list, the history
|
:param history: a list, the history
|
||||||
@ -184,10 +184,10 @@ class ResNet(object):
|
|||||||
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
'The length of history cannot meet the need of the model, given {}, need {}'.format(len(history),
|
||||||
self.history_length))
|
self.history_length))
|
||||||
eval_state = self._history2state(history, color)
|
eval_state = self._history2state(history, color)
|
||||||
if color == +1:
|
if role == 'black':
|
||||||
return self.sess.run([self.black_prob, self.black_v],
|
return self.sess.run([self.black_prob, self.black_v],
|
||||||
feed_dict={self.x: eval_state, self.is_training: False})
|
feed_dict={self.x: eval_state, self.is_training: False})
|
||||||
if color == -1:
|
if role == 'white':
|
||||||
return self.sess.run([self.white_prob, self.white_v],
|
return self.sess.run([self.white_prob, self.white_v],
|
||||||
feed_dict={self.x: eval_state, self.is_training: False})
|
feed_dict={self.x: eval_state, self.is_training: False})
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ class Data(object):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
self.__init__()
|
self.__init__()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
Starting two different players which load network weights to evaluate the winning ratio.
|
Starting two different players which load network weights to evaluate the winning ratio.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user