replace try catch by isinstance collections.Hashable

This commit is contained in:
Dong Yan 2017-12-24 20:57:53 +08:00
parent f0074aa7ca
commit 89226b449a
3 changed files with 10 additions and 22 deletions

1
AlphaGo/.gitignore vendored
View File

@ -2,3 +2,4 @@ data
checkpoints checkpoints
checkpoints_origin checkpoints_origin
*.log *.log
*.txt

View File

@ -33,8 +33,8 @@ class Game:
if self.name == "go": if self.name == "go":
self.size = 9 self.size = 9
self.komi = 3.75 self.komi = 3.75
self.history = []
self.history_length = 8 self.history_length = 8
self.history = []
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role) self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
self.board = [utils.EMPTY] * (self.size ** 2) self.board = [utils.EMPTY] * (self.size ** 2)
elif self.name == "reversi": elif self.name == "reversi":

View File

@ -1,23 +1,16 @@
import numpy as np import numpy as np
import math import math
import time import time
import sys
import collections
c_puct = 5 c_puct = 5
def list2tuple(obj):
def list2tuple(list): if isinstance(obj, collections.Hashable):
try: return obj
return tuple(list2tuple(sub) for sub in list) else:
except TypeError: return tuple(list2tuple(sub) for sub in obj)
return list
def tuple2list(tuple):
try:
return list(tuple2list(sub) for sub in tuple)
except TypeError:
return tuple
class MCTSNode(object): class MCTSNode(object):
def __init__(self, parent, action, state, action_num, prior, inverse=False): def __init__(self, parent, action, state, action_num, prior, inverse=False):
@ -38,7 +31,6 @@ class MCTSNode(object):
def valid_mask(self, simulator): def valid_mask(self, simulator):
pass pass
class UCTNode(MCTSNode): class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False): def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
@ -119,12 +111,7 @@ class ActionNode(object):
t2 = time.time() t2 = time.time()
self.mcts.ndarray2list_time += t1 - t0 self.mcts.ndarray2list_time += t1 - t0
self.mcts.list2tuple_time += t2 - t1 self.mcts.list2tuple_time += t2 - t1
self.mcts.check += sys.getsizeof(object)
def type_conversion_to_origin(self):
if isinstance(self.state_type, np.ndarray):
self.next_state = np.array(self.next_state)
if isinstance(self.state_type, np.ndarray):
self.next_state = tuple2list(self.next_state)
def selection(self, simulator): def selection(self, simulator):
head = time.time() head = time.time()