minor fixed

This commit is contained in:
rtz19970824 2017-11-21 22:52:17 +08:00
parent 31beb46563
commit e4e56d17d1
3 changed files with 10 additions and 9 deletions

View File

@ -17,11 +17,12 @@ class rollout_policy(evaluator):
def __call__(self, state):
# TODO: prior for rollout policy
total_reward = 0
total_reward = 0.
action = np.random.randint(0, self.action_num)
state, reward = self.env.step_forward(state, action)
total_reward += reward
while state is not None:
action = np.random.randint(0, self.action_num)
state, reward = self.env.step_forward(state, action)
total_reward += reward
return reward
return total_reward

View File

@ -41,6 +41,7 @@ class UCTNode(MCTSNode):
return self.children[action].selection(simulator)
def backpropagation(self, action):
action = int(action)
self.N[action] += 1
self.W[action] += self.children[action].reward
for i in range(self.action_num):
@ -88,7 +89,7 @@ class ActionNode:
# TODO: Let users/evaluator give the prior
if self.next_state is not None:
prior = np.ones([action_num]) / action_num
self.children[self.next_state] = UCTNode(self.parent, self.action, self.next_state, action_num, prior)
self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior)
return True
else:
return False
@ -133,8 +134,7 @@ class MCTS:
value = node.simulation(self.evaluator, node.children[new_action].next_state)
node.children[new_action].backpropagation(value + 0.)
else:
value = node.simulation(self.evaluator, node.state)
node.parent.children[node.action].backpropagation(value + 0.)
node.children[new_action].backpropagation(0.)
if __name__ == "__main__":

View File

@ -7,10 +7,10 @@ class TestEnv:
def __init__(self, max_step=5):
self.max_step = max_step
self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
# self.reward = {0:0.8, 1:0.2, 2:0.4, 3:0.6}
# self.reward = {0:1, 1:0}
self.best = max(self.reward.items(), key=lambda x: x[1])
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
print(self.reward)
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
def step_forward(self, state, action):
if action != 0 and action != 1:
@ -26,14 +26,14 @@ class TestEnv:
step = state[1] + 1
new_state = (num, step)
if step == self.max_step:
reward = int(np.random.uniform() < self.reward[state[0]])
reward = int(np.random.uniform() < self.reward[num])
else:
reward = 0
return new_state, reward
if __name__ == "__main__":
env = TestEnv(1)
env = TestEnv(2)
rollout = rollout_policy(env, 2)
evaluator = lambda state: rollout(state)
mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)