minor fixed
This commit is contained in:
parent
31beb46563
commit
e4e56d17d1
@ -17,11 +17,12 @@ class rollout_policy(evaluator):
|
||||
|
||||
def __call__(self, state):
|
||||
# TODO: prior for rollout policy
|
||||
total_reward = 0
|
||||
total_reward = 0.
|
||||
action = np.random.randint(0, self.action_num)
|
||||
state, reward = self.env.step_forward(state, action)
|
||||
total_reward += reward
|
||||
while state is not None:
|
||||
action = np.random.randint(0, self.action_num)
|
||||
state, reward = self.env.step_forward(state, action)
|
||||
total_reward += reward
|
||||
return reward
|
||||
return total_reward
|
||||
|
@ -41,6 +41,7 @@ class UCTNode(MCTSNode):
|
||||
return self.children[action].selection(simulator)
|
||||
|
||||
def backpropagation(self, action):
|
||||
action = int(action)
|
||||
self.N[action] += 1
|
||||
self.W[action] += self.children[action].reward
|
||||
for i in range(self.action_num):
|
||||
@ -88,7 +89,7 @@ class ActionNode:
|
||||
# TODO: Let users/evaluator give the prior
|
||||
if self.next_state is not None:
|
||||
prior = np.ones([action_num]) / action_num
|
||||
self.children[self.next_state] = UCTNode(self.parent, self.action, self.next_state, action_num, prior)
|
||||
self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -133,8 +134,7 @@ class MCTS:
|
||||
value = node.simulation(self.evaluator, node.children[new_action].next_state)
|
||||
node.children[new_action].backpropagation(value + 0.)
|
||||
else:
|
||||
value = node.simulation(self.evaluator, node.state)
|
||||
node.parent.children[node.action].backpropagation(value + 0.)
|
||||
node.children[new_action].backpropagation(0.)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -7,10 +7,10 @@ class TestEnv:
|
||||
def __init__(self, max_step=5):
|
||||
self.max_step = max_step
|
||||
self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
|
||||
# self.reward = {0:0.8, 1:0.2, 2:0.4, 3:0.6}
|
||||
# self.reward = {0:1, 1:0}
|
||||
self.best = max(self.reward.items(), key=lambda x: x[1])
|
||||
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
|
||||
print(self.reward)
|
||||
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
|
||||
|
||||
def step_forward(self, state, action):
|
||||
if action != 0 and action != 1:
|
||||
@ -26,14 +26,14 @@ class TestEnv:
|
||||
step = state[1] + 1
|
||||
new_state = (num, step)
|
||||
if step == self.max_step:
|
||||
reward = int(np.random.uniform() < self.reward[state[0]])
|
||||
reward = int(np.random.uniform() < self.reward[num])
|
||||
else:
|
||||
reward = 0
|
||||
return new_state, reward
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = TestEnv(1)
|
||||
env = TestEnv(2)
|
||||
rollout = rollout_policy(env, 2)
|
||||
evaluator = lambda state: rollout(state)
|
||||
mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user