74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
import numpy as np
|
|
import math
|
|
import json
|
|
|
|
# Load the state mask once at import time.
# json.load() reads from a file object (it calls .read()); handing it the
# path string raises AttributeError, so the file is opened first. Binary mode
# lets the json module detect the encoding itself instead of relying on the
# platform's default text encoding.
with open("state_mask.json", "rb") as mask_file:
    js = json.load(mask_file)

# Module-level action count. The node classes in this file take their own
# `action_num` constructor argument rather than reading this global.
action_num = 2

# Multiplier on the prior/visit-count term of the UCB score in UCTNode.
c_puct = 5.
|
|
|
|
class MCTSNode:
    """Base node of the search tree.

    Stores the links and bookkeeping shared by every node type. The search
    operations themselves (`select_leaf`, `backup_value`, `expand`,
    `iteration`) are placeholders that subclasses are expected to override;
    calling them on this class raises NotImplementedError.
    """

    def __init__(self, parent, action, state, action_num, prior):
        """Record the node's position in the tree and its search inputs.

        Args:
            parent: Node this one hangs under (the code elsewhere in this
                file treats it as another node object).
            action: Action stored alongside `parent`; presumably the action
                that led from the parent to this node -- confirm with callers.
            state: State associated with this node.
            action_num: Number of actions; subclasses use it as the length of
                their per-action arrays.
            prior: Prior stored for use by subclasses.
        """
        # Tree links.
        self.parent, self.action = parent, action
        # Child nodes, keyed by action; starts empty.
        self.children = dict()
        # Search inputs kept for subclasses.
        self.state, self.action_num, self.prior = state, action_num, prior

    def select_leaf(self):
        """Placeholder; subclasses must provide the selection step."""
        raise NotImplementedError("Need to implement function select_leaf")

    def backup_value(self, action, value):
        """Placeholder; subclasses must provide the backup step."""
        raise NotImplementedError("Need to implement function backup_value")

    def expand(self, action):
        """Placeholder; subclasses must provide the expansion step."""
        raise NotImplementedError("Need to implement function expand")

    def iteration(self):
        """Placeholder; subclasses must provide one search iteration."""
        raise NotImplementedError("Need to implement function iteration")
|
|
|
|
|
|
class UCTNode(MCTSNode):
    """Search node that selects actions by a prior-weighted UCB score.

    Per-action statistics (arrays of length `action_num`):
        N   -- visit counts
        W   -- accumulated backed-up value
        Q   -- mean value, W / N (0 for actions that were never visited)
        ucb -- Q + c_puct * prior * sqrt(sum(N)) / (N + 1)

    `expand` is not implemented here; it is inherited from MCTSNode and must
    be supplied by a subclass before `select_leaf` can grow the tree.
    """

    def __init__(self, parent, action, state, action_num, prior):
        """Create a node with all per-action statistics set to zero.

        Args:
            parent: Parent node, or None for the root.
            action: Action in `parent` that leads to this node.
            state: State associated with this node.
            action_num: Number of actions (length of the statistic arrays).
            prior: Prior used in the UCB score; it is multiplied with arrays
                of length `action_num`, so it must broadcast against them.
        """
        super(UCTNode, self).__init__(parent, action, state, action_num, prior)
        self.Q = np.zeros([action_num])
        self.W = np.zeros([action_num])
        self.N = np.zeros([action_num])
        self.ucb = self._compute_ucb()

    def _compute_ucb(self):
        """Return the UCB score of every action from the current statistics."""
        exploration = c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
        return self.Q + exploration

    def select_leaf(self):
        """Descend along the highest-UCB actions and expand the first new one.

        If the best action already has a child, selection continues in that
        child. Otherwise the action is expanded through `self.expand(action)`,
        which is expected to return `(state, value)`; a child node is created
        for the new state and `value` is backed up from this node to the root.

        Raises:
            NotImplementedError: If `expand` has not been overridden.
        """
        # Plain int so the children dict is keyed consistently.
        action = int(np.argmax(self.ucb))
        child = self.children.get(action)
        if child is not None:
            child.select_leaf()
            return

        state, value = self.expand(action)
        # Children are built with this node's own class so that a subclass
        # which implements `expand` also gets expandable children.
        # NOTE(review): `expand` yields no prior for the new state, so the
        # child reuses this node's prior -- replace once the evaluator
        # supplies one.
        self.children[action] = type(self)(
            self, action, state, self.action_num, self.prior
        )
        # Without the backup the visit counts never change and every call
        # would select the same action again.
        self.backup_value(action, value)

    def backup_value(self, action, value):
        """Record one visit of `action` worth `value` and propagate upward.

        Args:
            action: Index of the action taken from this node.
            value: Value obtained for that action; added to W[action].
        """
        self.N[action] += 1
        self.W[action] += value
        # Dividing by max(N, 1) keeps Q at 0 for unvisited actions instead of
        # producing NaN from 0 / 0, which would corrupt the argmax over ucb.
        self.Q = self.W / np.maximum(self.N, 1)
        self.ucb = self._compute_ucb()
        # The root has no parent. Otherwise the parent is updated for the
        # action that leads to this node, which is `self.action`.
        if self.parent is not None:
            self.parent.backup_value(self.action, value)
|
|
|
|
class TSNode(MCTSNode):
    """Node that keeps per-action distribution parameters.

    The name suggests Thompson sampling, but no sampling code is visible
    here -- only the parameter arrays are initialised.
    """

    def __init__(self, parent, action, state, action_num, prior, method="Gaussian"):
        """Initialise the base node and the parameters selected by `method`.

        Args:
            parent, action, state, action_num, prior: Forwarded unchanged to
                MCTSNode.__init__.
            method: "Beta" creates `alpha` and `beta`, both all ones;
                "Gaussian" creates `mu` and `sigma`, both all zeros. Any other
                value creates no parameter arrays.
        """
        super(TSNode, self).__init__(parent, action, state, action_num, prior)

        # One entry per action in every parameter array.
        shape = [action_num]

        if method == "Beta":
            self.alpha, self.beta = np.ones(shape), np.ones(shape)

        if method == "Gaussian":
            self.mu, self.sigma = np.zeros(shape), np.zeros(shape)
|
|
|
|
class ActionNode:
    """Tree node identified by an action under a parent node."""

    def __init__(self, parent, action):
        """Store the parent link and the action; start with no children.

        Args:
            parent: Node this action node is attached to.
            action: Action associated with this node.
        """
        self.parent, self.action = parent, action
        # Child nodes of this action node; starts empty.
        self.children = dict()
|
|
|