replace try catch by isinstance collections.Hashable
This commit is contained in:
parent
f0074aa7ca
commit
89226b449a
1
AlphaGo/.gitignore
vendored
1
AlphaGo/.gitignore
vendored
@ -2,3 +2,4 @@ data
|
|||||||
checkpoints
|
checkpoints
|
||||||
checkpoints_origin
|
checkpoints_origin
|
||||||
*.log
|
*.log
|
||||||
|
*.txt
|
||||||
|
@ -33,8 +33,8 @@ class Game:
|
|||||||
if self.name == "go":
|
if self.name == "go":
|
||||||
self.size = 9
|
self.size = 9
|
||||||
self.komi = 3.75
|
self.komi = 3.75
|
||||||
self.history = []
|
|
||||||
self.history_length = 8
|
self.history_length = 8
|
||||||
|
self.history = []
|
||||||
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
|
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
|
||||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||||
elif self.name == "reversi":
|
elif self.name == "reversi":
|
||||||
|
@ -1,23 +1,16 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
import sys
|
||||||
|
import collections
|
||||||
|
|
||||||
c_puct = 5
|
c_puct = 5
|
||||||
|
|
||||||
|
def list2tuple(obj):
|
||||||
def list2tuple(list):
|
if isinstance(obj, collections.Hashable):
|
||||||
try:
|
return obj
|
||||||
return tuple(list2tuple(sub) for sub in list)
|
else:
|
||||||
except TypeError:
|
return tuple(list2tuple(sub) for sub in obj)
|
||||||
return list
|
|
||||||
|
|
||||||
|
|
||||||
def tuple2list(tuple):
|
|
||||||
try:
|
|
||||||
return list(tuple2list(sub) for sub in tuple)
|
|
||||||
except TypeError:
|
|
||||||
return tuple
|
|
||||||
|
|
||||||
|
|
||||||
class MCTSNode(object):
|
class MCTSNode(object):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||||
@ -38,7 +31,6 @@ class MCTSNode(object):
|
|||||||
def valid_mask(self, simulator):
|
def valid_mask(self, simulator):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
@ -119,12 +111,7 @@ class ActionNode(object):
|
|||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
self.mcts.ndarray2list_time += t1 - t0
|
self.mcts.ndarray2list_time += t1 - t0
|
||||||
self.mcts.list2tuple_time += t2 - t1
|
self.mcts.list2tuple_time += t2 - t1
|
||||||
|
self.mcts.check += sys.getsizeof(object)
|
||||||
def type_conversion_to_origin(self):
|
|
||||||
if isinstance(self.state_type, np.ndarray):
|
|
||||||
self.next_state = np.array(self.next_state)
|
|
||||||
if isinstance(self.state_type, np.ndarray):
|
|
||||||
self.next_state = tuple2list(self.next_state)
|
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
head = time.time()
|
head = time.time()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user