minor fixes

rtz19970824 2017-11-21 22:52:17 +08:00
parent 31beb46563
commit e4e56d17d1
3 changed files with 10 additions and 9 deletions

View File

@@ -17,11 +17,12 @@ class rollout_policy(evaluator):
     def __call__(self, state):
         # TODO: prior for rollout policy
-        total_reward = 0
+        total_reward = 0.
         action = np.random.randint(0, self.action_num)
         state, reward = self.env.step_forward(state, action)
+        total_reward += reward
         while state is not None:
             action = np.random.randint(0, self.action_num)
             state, reward = self.env.step_forward(state, action)
             total_reward += reward
-        return reward
+        return total_reward
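
The rollout change above fixes two bugs at once: the reward from the very first step_forward call is now accumulated, and the function returns the running total rather than only the last step's reward. A self-contained sketch of the corrected behavior, with a hypothetical three-step toy environment standing in for the real env:

import numpy as np

# toy stand-in for env.step_forward: each step pays 1.,
# and the episode ends by returning state None after three steps
def step_forward(state, action):
    return (None if state >= 2 else state + 1), 1.

def rollout(state, action_num=2):
    total_reward = 0.                     # float, as in the diff
    action = np.random.randint(0, action_num)
    state, reward = step_forward(state, action)
    total_reward += reward                # the first step now counts too
    while state is not None:
        action = np.random.randint(0, action_num)
        state, reward = step_forward(state, action)
        total_reward += reward
    return total_reward                   # was: return reward (last step only)

print(rollout(0))                         # 3.0 with the fix; the old code returned 1.0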

View File

@@ -41,6 +41,7 @@ class UCTNode(MCTSNode):
             return self.children[action].selection(simulator)

     def backpropagation(self, action):
+        action = int(action)
         self.N[action] += 1
         self.W[action] += self.children[action].reward
         for i in range(self.action_num):
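
The added action = int(action) guards the N/W updates: numpy arrays refuse non-integer indices, so an action that arrives as a numpy float (e.g. read out of a float array) would crash the update. The exact caller is not shown in this diff; the snippet below just reproduces the failure mode being defended against:

import numpy as np

N = np.zeros(2)
a = np.float64(1.0)   # an action that arrived as a float
# N[a] += 1           # IndexError: only integers ... are valid indices
N[int(a)] += 1        # the cast added in backpropagation avoids this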
@@ -88,7 +89,7 @@ class ActionNode:
         # TODO: Let users/evaluator give the prior
         if self.next_state is not None:
             prior = np.ones([action_num]) / action_num
-            self.children[self.next_state] = UCTNode(self.parent, self.action, self.next_state, action_num, prior)
+            self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior)
             return True
         else:
             return False
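
The UCTNode constructor change repairs the tree linkage: the ActionNode must pass itself, not its own parent, as the parent of the state node it creates, otherwise backpropagation skips a level when it climbs the tree. A minimal sketch of the invariant (class bodies are stripped down; only the parent wiring matters here):

class StateNode:
    def __init__(self, parent):
        self.parent = parent           # the ActionNode that led here

class ActionNode:
    def __init__(self, parent):
        self.parent = parent           # the state node that owns this action
    def expand(self):
        return StateNode(parent=self)  # was StateNode(parent=self.parent): skipped a level

edge = ActionNode(parent=None)
child = edge.expand()
assert child.parent is edge            # backpropagation can now climb child -> edge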
@@ -133,8 +134,7 @@ class MCTS:
             value = node.simulation(self.evaluator, node.children[new_action].next_state)
             node.children[new_action].backpropagation(value + 0.)
         else:
-            value = node.simulation(self.evaluator, node.state)
-            node.parent.children[node.action].backpropagation(value + 0.)
+            node.children[new_action].backpropagation(0.)


 if __name__ == "__main__":
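
The else branch handles a failed expansion, i.e. step_forward produced no successor state for the chosen action. Re-simulating from the current state and backing the value up through node.parent credited the wrong edge; the fix backs up a plain 0. along the edge that was actually tried. A toy sketch of the corrected control flow (the function and its arguments are hypothetical; the real MCTS class keeps this state on the node):

import numpy as np

def backup_value(expanded, simulate):
    if expanded:
        return simulate() + 0.   # successor exists: estimate it by simulation
    return 0.                    # terminal edge: nothing to simulate, credit 0.

print(backup_value(True, lambda: np.random.uniform()))   # a simulated value
print(backup_value(False, lambda: np.random.uniform()))  # 0.0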

View File

@@ -7,10 +7,10 @@ class TestEnv:
     def __init__(self, max_step=5):
         self.max_step = max_step
         self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
-        # self.reward = {0:0.8, 1:0.2, 2:0.4, 3:0.6}
+        # self.reward = {0:1, 1:0}
         self.best = max(self.reward.items(), key=lambda x: x[1])
-        # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
         print(self.reward)
+        # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))

     def step_forward(self, state, action):
         if action != 0 and action != 1:
@@ -26,14 +26,14 @@ class TestEnv:
             step = state[1] + 1
             new_state = (num, step)
             if step == self.max_step:
-                reward = int(np.random.uniform() < self.reward[state[0]])
+                reward = int(np.random.uniform() < self.reward[num])
             else:
                 reward = 0
             return new_state, reward


 if __name__ == "__main__":
-    env = TestEnv(1)
+    env = TestEnv(2)
     rollout = rollout_policy(env, 2)
     evaluator = lambda state: rollout(state)
     mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
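
The reward-table fix matters because TestEnv is a depth-max_step binary bandit: each action appends a bit to the arm id, so once step reaches max_step the payout must be keyed by the arm just reached (num), not by the pre-step prefix state[0]. A standalone sketch of the corrected step; the num = state[0] * 2 + action line is an assumption about the part of step_forward not shown in this diff:

import numpy as np

class TinyEnv:
    def __init__(self, max_step=2):
        self.max_step = max_step
        self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}

    def step_forward(self, state, action):
        num = state[0] * 2 + action      # append one bit to the arm id (assumed)
        step = state[1] + 1
        if step == self.max_step:
            # Bernoulli payout keyed by the arm we just reached, not state[0]
            return (num, step), int(np.random.uniform() < self.reward[num])
        return (num, step), 0

env = TinyEnv(2)
print(env.step_forward((1, 1), 0))       # arm 2 reached at depth 2: pays 0 or 1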