minor fixed
This commit is contained in:
parent
31beb46563
commit
e4e56d17d1
@ -17,11 +17,12 @@ class rollout_policy(evaluator):
|
|||||||
|
|
||||||
def __call__(self, state):
|
def __call__(self, state):
|
||||||
# TODO: prior for rollout policy
|
# TODO: prior for rollout policy
|
||||||
total_reward = 0
|
total_reward = 0.
|
||||||
action = np.random.randint(0, self.action_num)
|
action = np.random.randint(0, self.action_num)
|
||||||
state, reward = self.env.step_forward(state, action)
|
state, reward = self.env.step_forward(state, action)
|
||||||
|
total_reward += reward
|
||||||
while state is not None:
|
while state is not None:
|
||||||
action = np.random.randint(0, self.action_num)
|
action = np.random.randint(0, self.action_num)
|
||||||
state, reward = self.env.step_forward(state, action)
|
state, reward = self.env.step_forward(state, action)
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
return reward
|
return total_reward
|
||||||
|
@ -41,6 +41,7 @@ class UCTNode(MCTSNode):
|
|||||||
return self.children[action].selection(simulator)
|
return self.children[action].selection(simulator)
|
||||||
|
|
||||||
def backpropagation(self, action):
|
def backpropagation(self, action):
|
||||||
|
action = int(action)
|
||||||
self.N[action] += 1
|
self.N[action] += 1
|
||||||
self.W[action] += self.children[action].reward
|
self.W[action] += self.children[action].reward
|
||||||
for i in range(self.action_num):
|
for i in range(self.action_num):
|
||||||
@ -88,7 +89,7 @@ class ActionNode:
|
|||||||
# TODO: Let users/evaluator give the prior
|
# TODO: Let users/evaluator give the prior
|
||||||
if self.next_state is not None:
|
if self.next_state is not None:
|
||||||
prior = np.ones([action_num]) / action_num
|
prior = np.ones([action_num]) / action_num
|
||||||
self.children[self.next_state] = UCTNode(self.parent, self.action, self.next_state, action_num, prior)
|
self.children[self.next_state] = UCTNode(self, self.action, self.next_state, action_num, prior)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
@ -133,8 +134,7 @@ class MCTS:
|
|||||||
value = node.simulation(self.evaluator, node.children[new_action].next_state)
|
value = node.simulation(self.evaluator, node.children[new_action].next_state)
|
||||||
node.children[new_action].backpropagation(value + 0.)
|
node.children[new_action].backpropagation(value + 0.)
|
||||||
else:
|
else:
|
||||||
value = node.simulation(self.evaluator, node.state)
|
node.children[new_action].backpropagation(0.)
|
||||||
node.parent.children[node.action].backpropagation(value + 0.)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -7,10 +7,10 @@ class TestEnv:
|
|||||||
def __init__(self, max_step=5):
|
def __init__(self, max_step=5):
|
||||||
self.max_step = max_step
|
self.max_step = max_step
|
||||||
self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
|
self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
|
||||||
# self.reward = {0:0.8, 1:0.2, 2:0.4, 3:0.6}
|
# self.reward = {0:1, 1:0}
|
||||||
self.best = max(self.reward.items(), key=lambda x: x[1])
|
self.best = max(self.reward.items(), key=lambda x: x[1])
|
||||||
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
|
|
||||||
print(self.reward)
|
print(self.reward)
|
||||||
|
# print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
|
||||||
|
|
||||||
def step_forward(self, state, action):
|
def step_forward(self, state, action):
|
||||||
if action != 0 and action != 1:
|
if action != 0 and action != 1:
|
||||||
@ -26,14 +26,14 @@ class TestEnv:
|
|||||||
step = state[1] + 1
|
step = state[1] + 1
|
||||||
new_state = (num, step)
|
new_state = (num, step)
|
||||||
if step == self.max_step:
|
if step == self.max_step:
|
||||||
reward = int(np.random.uniform() < self.reward[state[0]])
|
reward = int(np.random.uniform() < self.reward[num])
|
||||||
else:
|
else:
|
||||||
reward = 0
|
reward = 0
|
||||||
return new_state, reward
|
return new_state, reward
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env = TestEnv(1)
|
env = TestEnv(2)
|
||||||
rollout = rollout_policy(env, 2)
|
rollout = rollout_policy(env, 2)
|
||||||
evaluator = lambda state: rollout(state)
|
evaluator = lambda state: rollout(state)
|
||||||
mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
|
mcts = MCTS(env, evaluator, [0, 0], 2, np.array([0.5, 0.5]), max_step=1e4)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user