From e4e56d17d13e0eb484cf0a1aef728c59acbb4e2a Mon Sep 17 00:00:00 2001
From: rtz19970824
Date: Tue, 21 Nov 2017 22:52:17 +0800
Subject: [PATCH] minor fixes

- evaluator.py: accumulate the reward of every rollout step, including
  the first, and return the total instead of only the last step's reward
- mcts.py: cast the backpropagated action to int before indexing, parent
  an expanded UCTNode to its ActionNode rather than the grandparent, and
  backpropagate a zero value when expansion hits a terminal state
- mcts_test.py: look up the arm reward with the new state's index, and
  run the test environment with two steps
---
 tianshou/core/mcts/evaluator.py | 5 +++--
 tianshou/core/mcts/mcts.py      | 6 +++---
 tianshou/core/mcts/mcts_test.py | 8 ++++----
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/tianshou/core/mcts/evaluator.py b/tianshou/core/mcts/evaluator.py
index aacdbe4..bef8d43 100644
--- a/tianshou/core/mcts/evaluator.py
+++ b/tianshou/core/mcts/evaluator.py
@@ -17,11 +17,12 @@ class rollout_policy(evaluator):
 
     def __call__(self, state):
         # TODO: prior for rollout policy
-        total_reward = 0
+        total_reward = 0.
         action = np.random.randint(0, self.action_num)
         state, reward = self.env.step_forward(state, action)
+        total_reward += reward
         while state is not None:
             action = np.random.randint(0, self.action_num)
             state, reward = self.env.step_forward(state, action)
             total_reward += reward
-        return reward
+        return total_reward
diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py
index 1bdc0ff..6292fd5 100644
--- a/tianshou/core/mcts/mcts.py
+++ b/tianshou/core/mcts/mcts.py
@@ -41,6 +41,7 @@ class UCTNode(MCTSNode):
         return self.children[action].selection(simulator)
 
     def backpropagation(self, action):
+        action = int(action)
         self.N[action] += 1
         self.W[action] += self.children[action].reward
         for i in range(self.action_num):
@@ -88,7 +89,7 @@ class ActionNode:
             # TODO: Let users/evaluator give the prior
             if self.next_state is not None:
                 prior = np.ones([action_num]) / action_num
-                self.children[self.next_state] = UCTNode(self.parent, self.action, self.next_state, action_num, prior)
+                self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior)
                 return True
             else:
                 return False
@@ -133,8 +134,7 @@ class MCTS:
             value = node.simulation(self.evaluator, node.children[new_action].next_state)
             node.children[new_action].backpropagation(value + 0.)
         else:
-            value = node.simulation(self.evaluator, node.state)
-            node.parent.children[node.action].backpropagation(value + 0.)
+            node.children[new_action].backpropagation(0.)
 
 
 if __name__ == "__main__":
diff --git a/tianshou/core/mcts/mcts_test.py b/tianshou/core/mcts/mcts_test.py
index a0425b8..1208054 100644
--- a/tianshou/core/mcts/mcts_test.py
+++ b/tianshou/core/mcts/mcts_test.py
@@ -7,10 +7,10 @@ class TestEnv:
     def __init__(self, max_step=5):
         self.max_step = max_step
         self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
-        # self.reward = {0:0.8, 1:0.2, 2:0.4, 3:0.6}
+        # self.reward = {0:1, 1:0}
         self.best = max(self.reward.items(), key=lambda x: x[1])
-        # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
         print(self.reward)
+        # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
 
     def step_forward(self, state, action):
         if action != 0 and action != 1:
@@ -26,14 +26,14 @@ class TestEnv:
             step = state[1] + 1
             new_state = (num, step)
             if step == self.max_step:
-                reward = int(np.random.uniform() < self.reward[state[0]])
+                reward = int(np.random.uniform() < self.reward[num])
             else:
                 reward = 0
             return new_state, reward
 
 
 if __name__ == "__main__":
-    env = TestEnv(1)
+    env = TestEnv(2)
     rollout = rollout_policy(env, 2)
     evaluator = lambda state: rollout(state)
     mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
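
Note for reviewers: the rollout change in evaluator.py is the core fix. Below is a
minimal self-contained sketch of the corrected rollout; ToyEnv and rollout_value are
illustrative stand-ins for the repo's TestEnv and rollout_policy, not tianshou API.

import numpy as np

class ToyEnv:
    """Toy stand-in for mcts_test.TestEnv: a state is (num, step), and
    step_forward returns (None, 0) once the episode has terminated."""
    def __init__(self, max_step=2):
        self.max_step = max_step
        # one Bernoulli arm per leaf of the binary decision tree
        self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}

    def step_forward(self, state, action):
        if state is None or state[1] >= self.max_step:
            return None, 0
        num = state[0] * 2 + action  # descend into the chosen subtree
        step = state[1] + 1
        # only the final step pays out, as in the repo's TestEnv
        reward = int(np.random.uniform() < self.reward[num]) if step == self.max_step else 0
        return (num, step), reward

def rollout_value(env, state, action_num=2):
    # Mirrors the patched rollout_policy.__call__: accumulate every step's
    # reward (the pre-patch code skipped the first step's reward and
    # returned only the last step's reward, not the total).
    total_reward = 0.
    action = np.random.randint(0, action_num)
    state, reward = env.step_forward(state, action)
    total_reward += reward
    while state is not None:
        action = np.random.randint(0, action_num)
        state, reward = env.step_forward(state, action)
        total_reward += reward
    return total_reward

if __name__ == "__main__":
    print(rollout_value(ToyEnv(max_step=2), (0, 0)))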