From 89226b449a8d0a05ffd852805913fcf05efdca07 Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Sun, 24 Dec 2017 20:57:53 +0800 Subject: [PATCH] replace try catch by isinstance collections.Hashable --- AlphaGo/.gitignore | 1 + AlphaGo/game.py | 2 +- tianshou/core/mcts/mcts.py | 29 ++++++++--------------------- 3 files changed, 10 insertions(+), 22 deletions(-) diff --git a/AlphaGo/.gitignore b/AlphaGo/.gitignore index e578e5a..ff61326 100644 --- a/AlphaGo/.gitignore +++ b/AlphaGo/.gitignore @@ -2,3 +2,4 @@ data checkpoints checkpoints_origin *.log +*.txt diff --git a/AlphaGo/game.py b/AlphaGo/game.py index 8ffde93..a962f5c 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -33,8 +33,8 @@ class Game: if self.name == "go": self.size = 9 self.komi = 3.75 - self.history = [] self.history_length = 8 + self.history = [] self.game_engine = go.Go(size=self.size, komi=self.komi, role=self.role) self.board = [utils.EMPTY] * (self.size ** 2) elif self.name == "reversi": diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 1994284..bd21e09 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -1,23 +1,16 @@ import numpy as np import math import time +import sys +import collections c_puct = 5 - -def list2tuple(list): - try: - return tuple(list2tuple(sub) for sub in list) - except TypeError: - return list - - -def tuple2list(tuple): - try: - return list(tuple2list(sub) for sub in tuple) - except TypeError: - return tuple - +def list2tuple(obj): + if isinstance(obj, collections.Hashable): + return obj + else: + return tuple(list2tuple(sub) for sub in obj) class MCTSNode(object): def __init__(self, parent, action, state, action_num, prior, inverse=False): @@ -38,7 +31,6 @@ class MCTSNode(object): def valid_mask(self, simulator): pass - class UCTNode(MCTSNode): def __init__(self, parent, action, state, action_num, prior, mcts, inverse=False): super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) @@ -119,12 +111,7 @@ class ActionNode(object): t2 = time.time() self.mcts.ndarray2list_time += t1 - t0 self.mcts.list2tuple_time += t2 - t1 - - def type_conversion_to_origin(self): - if isinstance(self.state_type, np.ndarray): - self.next_state = np.array(self.next_state) - if isinstance(self.state_type, np.ndarray): - self.next_state = tuple2list(self.next_state) + self.mcts.check += sys.getsizeof(object) def selection(self, simulator): head = time.time()