replace try catch by isinstance collections.Hashable
This commit is contained in:
parent
f0074aa7ca
commit
89226b449a
1
AlphaGo/.gitignore
vendored
1
AlphaGo/.gitignore
vendored
@ -2,3 +2,4 @@ data
|
||||
checkpoints
|
||||
checkpoints_origin
|
||||
*.log
|
||||
*.txt
|
||||
|
@ -33,8 +33,8 @@ class Game:
|
||||
if self.name == "go":
|
||||
self.size = 9
|
||||
self.komi = 3.75
|
||||
self.history = []
|
||||
self.history_length = 8
|
||||
self.history = []
|
||||
self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role)
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
elif self.name == "reversi":
|
||||
|
@ -1,23 +1,16 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import sys
|
||||
import collections
|
||||
|
||||
c_puct = 5
|
||||
|
||||
|
||||
def list2tuple(list):
|
||||
try:
|
||||
return tuple(list2tuple(sub) for sub in list)
|
||||
except TypeError:
|
||||
return list
|
||||
|
||||
|
||||
def tuple2list(tuple):
|
||||
try:
|
||||
return list(tuple2list(sub) for sub in tuple)
|
||||
except TypeError:
|
||||
return tuple
|
||||
|
||||
def list2tuple(obj):
|
||||
if isinstance(obj, collections.Hashable):
|
||||
return obj
|
||||
else:
|
||||
return tuple(list2tuple(sub) for sub in obj)
|
||||
|
||||
class MCTSNode(object):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
@ -38,7 +31,6 @@ class MCTSNode(object):
|
||||
def valid_mask(self, simulator):
|
||||
pass
|
||||
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
@ -119,12 +111,7 @@ class ActionNode(object):
|
||||
t2 = time.time()
|
||||
self.mcts.ndarray2list_time += t1 - t0
|
||||
self.mcts.list2tuple_time += t2 - t1
|
||||
|
||||
def type_conversion_to_origin(self):
|
||||
if isinstance(self.state_type, np.ndarray):
|
||||
self.next_state = np.array(self.next_state)
|
||||
if isinstance(self.state_type, np.ndarray):
|
||||
self.next_state = tuple2list(self.next_state)
|
||||
self.mcts.check += sys.getsizeof(object)
|
||||
|
||||
def selection(self, simulator):
|
||||
head = time.time()
|
||||
|
Loading…
x
Reference in New Issue
Block a user