From 01f39f40d3df481703401ebaf2d8305f232074b6 Mon Sep 17 00:00:00 2001 From: rtz19970824 Date: Thu, 28 Dec 2017 19:38:25 +0800 Subject: [PATCH] debug for unit test --- tianshou/core/mcts/unit_test/ZOGame.py | 17 +++++------------ tianshou/core/mcts/unit_test/agent.py | 4 ++-- tianshou/core/mcts/unit_test/game.py | 4 ++-- tianshou/core/mcts/unit_test/mcts.py | 2 +- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/tianshou/core/mcts/unit_test/ZOGame.py b/tianshou/core/mcts/unit_test/ZOGame.py index b598579..0b3d771 100644 --- a/tianshou/core/mcts/unit_test/ZOGame.py +++ b/tianshou/core/mcts/unit_test/ZOGame.py @@ -29,17 +29,10 @@ class ZOTree: length = len(seq) if length != self.depth: raise ValueError("The game is not terminated!") - ones = 0 - zeros = 0 - for i in range(len(seq)): - if seq[i] == 0: - zeros += 1 - if seq[i] == 1: - ones += 1 - result = ones - zeros - if result > 0: + result = np.sum(seq) + if result > self.size: winner = 1 - elif result < 0: + elif result < self.size: winner = -1 else: winner = 0 @@ -98,7 +91,7 @@ class ZOTree: if __name__ == "__main__": size = 2 game = ZOTree(size) - seq = [1, -1, 1, 1] + seq = [1, 0, 1, 1] result = game.executor_do_move([seq, 1], 1) print(result) - print(seq) \ No newline at end of file + print(seq) diff --git a/tianshou/core/mcts/unit_test/agent.py b/tianshou/core/mcts/unit_test/agent.py index ebe346e..6dd34aa 100644 --- a/tianshou/core/mcts/unit_test/agent.py +++ b/tianshou/core/mcts/unit_test/agent.py @@ -17,11 +17,11 @@ class Agent: def gen_move(self, seq): if len(seq) >= 2 * self.size: raise ValueError("Game is terminated.") - mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2) + mcts = MCTS(self.simulator, self.evaluator, [seq, self.color], 2, inverse=True) mcts.search(max_step=50) N = mcts.root.N N = np.power(N, 1.0 / temp) prob = N / np.sum(N) print("prob: {}".format(prob)) action = int(np.random.binomial(1, prob[1])) - return action \ No newline at end of file + return action diff --git a/tianshou/core/mcts/unit_test/game.py b/tianshou/core/mcts/unit_test/game.py index 14c2df5..6fb504b 100644 --- a/tianshou/core/mcts/unit_test/game.py +++ b/tianshou/core/mcts/unit_test/game.py @@ -6,7 +6,7 @@ if __name__ == '__main__': print("Our game has 2 players.") print("Player 1 has color 1 and plays first. Player 2 has color -1 and plays following player 1.") print("Both player choose 1 or 0 for an action.") - size = 1 + size = 2 print("This game has {} iterations".format(size)) print("If the final sequence has more 1 that 0, player 1 wins.") print("If the final sequence has less 1 that 0, player 2 wins.") @@ -34,4 +34,4 @@ if __name__ == '__main__': break print("The choice sequence is {}".format(seq)) - print("The game result is {}".format(winner)) \ No newline at end of file + print("The game result is {}".format(winner)) diff --git a/tianshou/core/mcts/unit_test/mcts.py b/tianshou/core/mcts/unit_test/mcts.py index dd89f57..49c9faf 100644 --- a/tianshou/core/mcts/unit_test/mcts.py +++ b/tianshou/core/mcts/unit_test/mcts.py @@ -187,7 +187,7 @@ class MCTS(object): prior, value = self.evaluator(next_action.next_state) next_action.expansion(prior, self.action_num) else: - value = 0 + value = 0. t2 = time.time() if self.inverse: next_action.backpropagation(-value + 0.)