From 4140d8c9d28fd2164ebb1f1dce902b77e0d9c5b5 Mon Sep 17 00:00:00 2001 From: JialianLee Date: Thu, 28 Dec 2017 17:10:25 +0800 Subject: [PATCH] Modification on unit test --- tianshou/core/mcts/unit_test/Evaluator.py | 3 +- tianshou/core/mcts/unit_test/ZOGame.py | 36 +++++++++++++++++++++-- tianshou/core/mcts/unit_test/agent.py | 2 +- tianshou/core/mcts/unit_test/game.py | 6 ++-- tianshou/core/mcts/unit_test/mcts.py | 2 ++ 5 files changed, 41 insertions(+), 8 deletions(-) diff --git a/tianshou/core/mcts/unit_test/Evaluator.py b/tianshou/core/mcts/unit_test/Evaluator.py index a1f9456..f78da95 100644 --- a/tianshou/core/mcts/unit_test/Evaluator.py +++ b/tianshou/core/mcts/unit_test/Evaluator.py @@ -18,6 +18,7 @@ class rollout_policy(evaluator): def __call__(self, state): # TODO: prior for rollout policy total_reward = 0. + color = state[1] action = np.random.randint(0, self.action_num) state, reward = self.env.simulate_step_forward(state, action) total_reward += reward @@ -25,4 +26,4 @@ class rollout_policy(evaluator): action = np.random.randint(0, self.action_num) state, reward = self.env.simulate_step_forward(state, action) total_reward += reward - return np.ones([self.action_num])/self.action_num, total_reward + return np.ones([self.action_num])/self.action_num, total_reward * color diff --git a/tianshou/core/mcts/unit_test/ZOGame.py b/tianshou/core/mcts/unit_test/ZOGame.py index acad284..b598579 100644 --- a/tianshou/core/mcts/unit_test/ZOGame.py +++ b/tianshou/core/mcts/unit_test/ZOGame.py @@ -9,6 +9,7 @@ class ZOTree: self.depth = self.size * 2 def simulate_step_forward(self, state, action): + self._check_state(state) seq, color = copy.deepcopy(state) if len(seq) == self.depth: winner = self.executor_get_reward(state) @@ -18,15 +19,24 @@ class ZOTree: return [seq, 0 - color], 0 def simulate_hashable_conversion(self, state): + self._check_state(state) # since go is MDP, we only need the last board for hashing return tuple(state[0]) - + def executor_get_reward(self, state): + self._check_state(state) seq = np.array(state[0], dtype='int16') length = len(seq) if length != self.depth: raise ValueError("The game is not terminated!") - result = np.sum(seq) + ones = 0 + zeros = 0 + for i in range(len(seq)): + if seq[i] == 0: + zeros += 1 + if seq[i] == 1: + ones += 1 + result = ones - zeros if result > 0: winner = 1 elif result < 0: @@ -36,6 +46,7 @@ class ZOTree: return winner def executor_do_move(self, state, action): + self._check_state(state) seq, color = state if len(seq) == self.depth: return False @@ -46,8 +57,16 @@ class ZOTree: return True def v_value(self, state): + self._check_state(state) seq, color = state - choosen_result = np.sum(np.array(seq, dtype='int16')) + ones = 0 + zeros = 0 + for i in range(len(seq)): + if seq[i] == 0: + zeros += 1 + if seq[i] == 1: + ones += 1 + choosen_result = ones - zeros if color == 1: if choosen_result > 0: return 1 @@ -65,6 +84,17 @@ class ZOTree: else: raise ValueError("Wrong color") + def _check_state(self, state): + seq, color = state + if color == 1: + if len(seq) % 2: + raise ValueError("Color is 1 but the length of seq is odd!") + elif color == -1: + if not len(seq) % 2: + raise ValueError("Color is -1 but the length of seq is even!") + else: + raise ValueError("Wrong color!") + if __name__ == "__main__": size = 2 game = ZOTree(size) diff --git a/tianshou/core/mcts/unit_test/agent.py b/tianshou/core/mcts/unit_test/agent.py index 1bffdd0..ebe346e 100644 --- a/tianshou/core/mcts/unit_test/agent.py +++ b/tianshou/core/mcts/unit_test/agent.py @@ -23,5 +23,5 @@ class Agent: N = np.power(N, 1.0 / temp) prob = N / np.sum(N) print("prob: {}".format(prob)) - action = int(np.random.binomial(1, prob[1]) * 2 - 1) + action = int(np.random.binomial(1, prob[1])) return action \ No newline at end of file diff --git a/tianshou/core/mcts/unit_test/game.py b/tianshou/core/mcts/unit_test/game.py index 7ac044c..14c2df5 100644 --- a/tianshou/core/mcts/unit_test/game.py +++ b/tianshou/core/mcts/unit_test/game.py @@ -5,11 +5,11 @@ import agent if __name__ == '__main__': print("Our game has 2 players.") print("Player 1 has color 1 and plays first. Player 2 has color -1 and plays following player 1.") - print("Both player choose 1 or -1 for an action.") + print("Both player choose 1 or 0 for an action.") size = 1 print("This game has {} iterations".format(size)) - print("If the final sequence has more 1 that -1, player 1 wins.") - print("If the final sequence has less 1 that -1, player 2 wins.") + print("If the final sequence has more 1 that 0, player 1 wins.") + print("If the final sequence has less 1 that 0, player 2 wins.") print("Otherwise, both players get 0.\n") game = ZOGame.ZOTree(size) player1 = agent.Agent(size, 1) diff --git a/tianshou/core/mcts/unit_test/mcts.py b/tianshou/core/mcts/unit_test/mcts.py index 1251d05..dd89f57 100644 --- a/tianshou/core/mcts/unit_test/mcts.py +++ b/tianshou/core/mcts/unit_test/mcts.py @@ -162,6 +162,8 @@ class MCTS(object): self.expansion_time += exp_time self.backpropagation_time += back_time step += 1 + print("Q = {}".format(self.root.Q)) + print("N = {}".format(self.root.N)) if self.debug: file = open("mcts_profiling.log", "a") file.write("[" + str(self.role) + "]"