diff --git a/tianshou/core/mcts.py b/tianshou/core/mcts.py
index 169a6c1..73bf695 100644
--- a/tianshou/core/mcts.py
+++ b/tianshou/core/mcts.py
@@ -1,12 +1,11 @@
 import numpy as np
 import math
-import json
+import time
 
-js = json.load("state_mask.json")
-action_num = 2
 c_puct = 5.
 
-class MCTSNode:
+
+class MCTSNode(object):
     def __init__(self, parent, action, state, action_num, prior):
         self.parent = parent
         self.action = action
@@ -15,17 +14,17 @@ class MCTSNode:
         self.action_num = action_num
         self.prior = prior
 
-    def select_leaf(self):
-        raise NotImplementedError("Need to implement function select_leaf")
+    def selection(self):
+        raise NotImplementedError("Need to implement function selection")
 
-    def backup_value(self, action, value):
-        raise NotImplementedError("Need to implement function backup_value")
+    def backpropagation(self, action, value):
+        raise NotImplementedError("Need to implement function backpropagation")
 
-    def expand(self, action):
-        raise NotImplementedError("Need to implement function expand")
+    def expansion(self, simulator, action):
+        raise NotImplementedError("Need to implement function expansion")
 
-    def iteration(self):
-        raise NotImplementedError("Need to implement function iteration")
+    def simulation(self, evaluator, state):
+        raise NotImplementedError("Need to implement function simulation")
 
 
 class UCTNode(MCTSNode):
@@ -36,25 +35,32 @@ class UCTNode(MCTSNode):
         self.N = np.zeros([action_num])
         self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
 
-    def select_leaf(self):
+    def selection(self):
         action = np.argmax(self.ucb)
         if action in self.children.keys():
-            self.children[action].select_leaf()
+            return self.children[action].selection()
         else:
-            # TODO: apply the action and evalate next state
-            # state, value = self.env.step_forward(self.state, action)
-            # self.children[action] = MCTSNode(self.env, self, action, state, prior)
-            # self.backup_value(action, value)
-            state, value = self.expand(action)
-            self.children[action] = UCTNode(self.env, self, action, state, prior)
+            return self, action
 
-    def backup_value(self, action, value):
+    def backpropagation(self, action, value):
         self.N[action] += 1
-        self.W[action] += 1
-        self.Q = self.W / self.N
+        self.W[action] += value
+        self.Q = self.W / np.maximum(self.N, 1)
         self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
-        self.parent.backup_value(self.parent.action, value)
+        if self.parent is not None:
+            self.parent.backpropagation(self.action, value)
 
+    def expansion(self, simulator, action):
+        next_state = simulator.step_forward(self.state, action)
+        # TODO: Let users/evaluator give the prior
+        prior = np.ones([self.action_num]) / self.action_num
+        self.children[action] = UCTNode(self, action, next_state, self.action_num, prior)
+
+    def simulation(self, evaluator, state):
+        value = evaluator(state)
+        return value
+
+
 class TSNode(MCTSNode):
     def __init__(self, parent, action, state, action_num, prior, method="Gaussian"):
         super(TSNode, self).__init__(parent, action, state, action_num, prior)
@@ -65,9 +71,43 @@ class TSNode(MCTSNode):
         self.mu = np.zeros([action_num])
         self.sigma = np.zeros([action_num])
 
+
 class ActionNode:
     def __init__(self, parent, action):
         self.parent = parent
         self.action = action
         self.children = {}
+        self.value = {}
 
+
+class MCTS:
+    def __init__(self, simulator, evaluator, root, action_num, prior, method="UCT", max_step=None, max_time=None):
+        self.simulator = simulator
+        self.evaluator = evaluator
+        if method == "UCT":
+            self.root = UCTNode(None, None, root, action_num, prior)
+        if method == "TS":
+            self.root = TSNode(None, None, root, action_num, prior)
+        if max_step is not None:
+            self.step = 0
+            self.max_step = max_step
+        if max_time is not None:
+            self.start_time = time.time()
+            self.max_time = max_time
+        if max_step is None and max_time is None:
+            raise ValueError("Need a stop criteria!")
+        while (max_step is not None and self.step < self.max_step or max_step is None) \
+            and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
+            self.expand()
+            if max_step is not None:
+                self.step += 1
+
+    def expand(self):
+        node, new_action = self.root.selection()
+        node.expansion(self.simulator, new_action)
+        value = node.simulation(self.evaluator, node.children[new_action].state)
+        node.backpropagation(new_action, value)
+
+
+if __name__=="__main__":
+    pass
\ No newline at end of file
diff --git a/tianshou/core/mcts_test.py b/tianshou/core/mcts_test.py
index 43a4fb0..587e7aa 100644
--- a/tianshou/core/mcts_test.py
+++ b/tianshou/core/mcts_test.py
@@ -1,25 +1,28 @@
 import numpy as np
+from mcts import MCTS
 
 class TestEnv:
     def __init__(self, max_step=5):
-        self.step = 0
-        self.state = 0
         self.max_step = max_step
         self.reward = {i:np.random.uniform() for i in range(2**max_step)}
         self.best = max(self.reward.items(), key=lambda x:x[1])
         print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
 
-    def step_forward(self, action):
-        print("Operate action {} at timestep {}".format(action, self.step))
-        self.state = self.state + 2**self.step*action
-        self.step = self.step + 1
-        if self.step == self.max_step:
-            reward = int(np.random.uniform() > self.reward[self.state])
+    def step_forward(self, state, action):
+        if action != 0 and action != 1:
+            raise ValueError("Action must be 0 or 1!")
+        if state[0] >= 2**state[1] or state[1] >= self.max_step:
+            raise ValueError("Invalid State!")
+        print("Operate action {} at state {}, timestep {}".format(action, state[0], state[1]))
+        state[0] = state[0] + 2**state[1]*action
+        state[1] = state[1] + 1
+        if state[1] == self.max_step:
+            reward = int(np.random.uniform() > self.reward[state[0]])
             print("Get reward {}".format(reward))
         else:
             reward = 0
-        return [self.state, reward]
+        return [state, reward]
 
 if __name__=="__main__":
     env = TestEnv(1)
-    env.step_forward(1)
\ No newline at end of file
+    env.step_forward([0,0],1)
\ No newline at end of file