From 1f011a44ef12ca6a8651a6870cc37670a1c96dec Mon Sep 17 00:00:00 2001 From: mcgrady00h <281130306@qq.com> Date: Tue, 19 Dec 2017 17:04:55 +0800 Subject: [PATCH 1/4] add mcts virtual loss version (may have bugs) --- tianshou/core/mcts/mcts_test.py | 3 + tianshou/core/mcts/mcts_virtual_loss.py | 263 +++++++++++++++++++ tianshou/core/mcts/mcts_virtual_loss_test.py | 55 ++++ 3 files changed, 321 insertions(+) create mode 100644 tianshou/core/mcts/mcts_virtual_loss.py create mode 100644 tianshou/core/mcts/mcts_virtual_loss_test.py diff --git a/tianshou/core/mcts/mcts_test.py b/tianshou/core/mcts/mcts_test.py index da404ca..49b85be 100644 --- a/tianshou/core/mcts/mcts_test.py +++ b/tianshou/core/mcts/mcts_test.py @@ -12,6 +12,9 @@ class TestEnv: print(self.reward) # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1])) + def simulate_is_valid(self, state, act): + return True + def step_forward(self, state, action): if action != 0 and action != 1: raise ValueError("Action must be 0 or 1! Your action is {}".format(action)) diff --git a/tianshou/core/mcts/mcts_virtual_loss.py b/tianshou/core/mcts/mcts_virtual_loss.py new file mode 100644 index 0000000..9d20b5a --- /dev/null +++ b/tianshou/core/mcts/mcts_virtual_loss.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# $File: mcts_virtual_loss.py +# $Date: Tue Dec 19 17:0444 2017 +0800 +# Original file: mcts.py +# $Author: renyong15 © +# + +""" + This is an implementation of the MCTS with virtual loss. + Due to the limitation of Python design mechanism, we implements the virtual loss in a mini-batch + manner. +""" + +import numpy as np +import math +import time + +c_puct = 5 + + +def list2tuple(list): + try: + return tuple(list2tuple(sub) for sub in list) + except TypeError: + return list + + +def tuple2list(tuple): + try: + return list(tuple2list(sub) for sub in tuple) + except TypeError: + return tuple + + +class MCTSNodeVirtualLoss(object): + def __init__(self, parent, action, state, action_num, prior, inverse=False): + self.parent = parent + self.action = action + self.children = {} + self.state = state + self.action_num = action_num + self.prior = np.array(prior).reshape(-1) + self.inverse = inverse + + def selection(self, simulator): + raise NotImplementedError("Need to implement function selection") + + def backpropagation(self, action): + raise NotImplementedError("Need to implement function backpropagation") + + def valid_mask(self, simulator): + pass + +class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): + def __init__(self, parent, action, state, action_num, prior, inverse=False): + super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse) + self.Q = np.zeros([action_num]) + self.W = np.zeros([action_num]) + self.N = np.zeros([action_num]) + self.virtual_loss = np.zeros([action_num]) + #### modified by adding virtual loss + #self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) + + self.mask = None + + def selection(self, simulator): + self.valid_mask(simulator) + self.Q = np.zeros([self.action_num]) + N_not_zero = self.N > 0 + self.Q[N_not_zero] = (self.W[N_not_zero] + self.virtual_loss[N_not_zero] + 0.) / self.N[N_not_zero] + self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\ + (self.N + self.virtual_loss + 1) + action = np.argmax(self.ucb) + self.virtual_loss[action] += 1 + + if action in self.children.keys(): + return self.children[action].selection(simulator) + else: + self.children[action] = ActionNodeVirtualLoss(self, action) + return self.children[action].selection(simulator) + + def remove_virtual_loss(self): + ### if not virtual_loss for every action is zero + if np.sum(self.virtual_loss > 0) > 0: + self.virtual_loss = np.zeros([self.action_num]) + if self.parent: + self.parent.remove_virtual_loss() + + def backpropagation(self, action): + action = int(action) + self.N[action] += 1 + self.W[action] += self.children[action].reward + + ## do not need to compute Q and ucb immediately since it will be modified by virtual loss + #for i in range(self.action_num): + # if self.N[i] != 0: + # self.Q[i] = (self.W[i] + 0.) / self.N[i] + #self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.) + + if self.parent is not None: + if self.inverse: + self.parent.backpropagation(-self.children[action].reward) + else: + self.parent.backpropagation(self.children[action].reward) + + def valid_mask(self, simulator): + if self.mask is None: + start_time = time.time() + self.mask = [] + for act in range(self.action_num - 1): + if not simulator.simulate_is_valid(self.state, act): + self.mask.append(act) + self.ucb[act] = -float("Inf") + else: + self.ucb[self.mask] = -float("Inf") + + + +class ActionNodeVirtualLoss(object): + def __init__(self, parent, action): + self.parent = parent + self.action = action + self.children = {} + self.next_state = None + self.origin_state = None + self.state_type = None + self.reward = 0 + + def remove_virtual_loss(self): + self.parent.remove_virtual_loss() + + def type_conversion_to_tuple(self): + if type(self.next_state) is np.ndarray: + self.next_state = self.next_state.tolist() + if type(self.next_state) is list: + self.next_state = list2tuple(self.next_state) + + def type_conversion_to_origin(self): + if self.state_type is np.ndarray: + self.next_state = np.array(self.next_state) + if self.state_type is list: + self.next_state = tuple2list(self.next_state) + + def selection(self, simulator): + self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action) + self.origin_state = self.next_state + self.state_type = type(self.next_state) + self.type_conversion_to_tuple() + if self.next_state is not None: + if self.next_state in self.children.keys(): + return self.children[self.next_state].selection(simulator) + else: + return self.parent, self.action + else: + return self.parent, self.action + + def expansion(self, action, state, action_num, prior, inverse ): + if state is not None: + self.children[state] = UCTNodeVirtualLoss(self, action, state, action_num, prior, inverse) + + + def backpropagation(self, value): + self.reward += value + self.parent.backpropagation(self.action) + + +class MCTSVirtualLoss(object): + def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False): + self.simulator = simulator + self.evaluator = evaluator + prior, _ = self.evaluator(root) + self.action_num = action_num + self.batch_size = batch_size + + if method == "": + self.root = root + elif method == "UCT": + self.root = UCTNodeVirtualLoss(None, None, root, action_num, prior, inverse) + elif method == "TS": + self.root = TSNodeVirtualLoss(None, None, root, action_num, prior, inverse=inverse) + else: + raise ValueError("Need a root type") + + self.inverse = inverse + + + def do_search(self, max_step=None, max_time=None): + if max_step is not None: + self.step = 0 + self.max_step = max_step + if max_time is not None: + self.start_time = time.time() + self.max_time = max_time + if max_step is None and max_time is None: + raise ValueError("Need a stop criteria!") + + self.select_time = [] + self.evaluate_time = [] + self.bp_time = [] + while (max_step is not None and self.step < self.max_step or max_step is None) \ + and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None): + self.expand() + if max_step is not None: + self.step += 1 + + def expand(self): + ## minibatch with virtual loss + nodes = [] + new_actions = [] + next_states = [] + + for i in range(self.batch_size): + node, new_action = self.root.selection(self.simulator) + nodes.append(node) + new_actions.append(new_action) + next_states.append(node.children[new_action].next_state) + + for node in nodes: + node.remove_virtual_loss() + + assert(np.sum(self.root.virtual_loss > 0) == 0) + #### compute value in batch manner unless the evaluator do not support it + try: + priors, values = self.evaluator(next_states) + except: + priors = [] + values = [] + for i in range(self.batch_size): + if next_states[i] is not None: + prior, value = self.evaluator(next_states[i]) + priors.append(prior) + values.append(value) + else: + priors.append(0.) + values.append(0.) + + #### for now next_state == origin_state + #### may have problem here. What if we reached the same next_state with same parent and action pair + for i in range(self.batch_size): + nodes[i].children[new_actions[i]].expansion(new_actions[i], + next_states[i], + self.action_num, + priors[i], + nodes[i].inverse) + + for i in range(self.batch_size): + nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.) + + +##### TODO +class TSNodeVirtualLoss(MCTSNodeVirtualLoss): + def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False): + super(TSNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse) + if method == "Beta": + self.alpha = np.ones([action_num]) + self.beta = np.ones([action_num]) + if method == "Gaussian": + self.mu = np.zeros([action_num]) + self.sigma = np.zeros([action_num]) + +if __name__ == "__main__": + mcts_virtual_loss = MCTSNodeVirtualLoss(None, None, 10, 1, 'UCT') diff --git a/tianshou/core/mcts/mcts_virtual_loss_test.py b/tianshou/core/mcts/mcts_virtual_loss_test.py new file mode 100644 index 0000000..d2d6c81 --- /dev/null +++ b/tianshou/core/mcts/mcts_virtual_loss_test.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# $File: mcts_virtual_loss_test.py +# $Date: Tue Dec 19 16:5459 2017 +0800 +# Original file: mcts_test.py +# $Author: renyong15 © +# + + + +import numpy as np +from mcts_virtual_loss import MCTSVirtualLoss +from evaluator import rollout_policy + + +class TestEnv: + def __init__(self, max_step=5): + self.max_step = max_step + self.reward = {i: np.random.uniform() for i in range(2 ** max_step)} + # self.reward = {0:1, 1:0} + self.best = max(self.reward.items(), key=lambda x: x[1]) + print(self.reward) + # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1])) + + def simulate_is_valid(self, state, act): + return True + + def step_forward(self, state, action): + if action != 0 and action != 1: + raise ValueError("Action must be 0 or 1! Your action is {}".format(action)) + if state[0] >= 2 ** state[1] or state[1] > self.max_step: + raise ValueError("Invalid State! Your state is {}".format(state)) + # print("Operate action {} at state {}, timestep {}".format(action, state[0], state[1])) + if state[1] == self.max_step: + new_state = None + reward = 0 + else: + num = state[0] + 2 ** state[1] * action + step = state[1] + 1 + new_state = [num, step] + if step == self.max_step: + reward = int(np.random.uniform() < self.reward[num]) + else: + reward = 0. + return new_state, reward + + +if __name__ == "__main__": + env = TestEnv(2) + rollout = rollout_policy(env, 2) + evaluator = lambda state: rollout(state) + mcts_virtual_loss = MCTSVirtualLoss(env, evaluator, [0, 0], 2, batch_size = 10) + for i in range(10): + mcts_virtual_loss.do_search(max_step = 100) + From 3b534064bd6c92c972883d448c7c77fa0884e356 Mon Sep 17 00:00:00 2001 From: mcgrady00h <281130306@qq.com> Date: Sat, 23 Dec 2017 02:48:53 +0800 Subject: [PATCH 2/4] fix virtual loss bug --- tianshou/core/mcts/mcts.py | 22 +++-------- tianshou/core/mcts/mcts_virtual_loss.py | 41 ++++++++++---------- tianshou/core/mcts/mcts_virtual_loss_test.py | 6 +-- tianshou/core/mcts/utils.py | 21 ++++++++++ 4 files changed, 49 insertions(+), 41 deletions(-) create mode 100644 tianshou/core/mcts/utils.py diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 979e994..16d13d5 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -1,22 +1,9 @@ import numpy as np import math import time +import sys,os +from .utils import list2tuple, tuple2list -c_puct = 5 - - -def list2tuple(list): - try: - return tuple(list2tuple(sub) for sub in list) - except TypeError: - return list - - -def tuple2list(tuple): - try: - return list(tuple2list(sub) for sub in tuple) - except TypeError: - return tuple class MCTSNode(object): @@ -39,12 +26,13 @@ class MCTSNode(object): pass class UCTNode(MCTSNode): - def __init__(self, parent, action, state, action_num, prior, inverse=False): + def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5): super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) self.Q = np.zeros([action_num]) self.W = np.zeros([action_num]) self.N = np.zeros([action_num]) - self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) + self.c_puct = c_puct + self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) self.mask = None def selection(self, simulator): diff --git a/tianshou/core/mcts/mcts_virtual_loss.py b/tianshou/core/mcts/mcts_virtual_loss.py index 9d20b5a..9335464 100644 --- a/tianshou/core/mcts/mcts_virtual_loss.py +++ b/tianshou/core/mcts/mcts_virtual_loss.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # vim:fenc=utf-8 # $File: mcts_virtual_loss.py -# $Date: Tue Dec 19 17:0444 2017 +0800 +# $Date: Sat Dec 23 02:4850 2017 +0800 # Original file: mcts.py # $Author: renyong15 © # @@ -12,25 +12,13 @@ manner. """ +from __future__ import absolute_import + import numpy as np import math import time - -c_puct = 5 - - -def list2tuple(list): - try: - return tuple(list2tuple(sub) for sub in list) - except TypeError: - return list - - -def tuple2list(tuple): - try: - return list(tuple2list(sub) for sub in tuple) - except TypeError: - return tuple +import sys,os +from .utils import list2tuple, tuple2list class MCTSNodeVirtualLoss(object): @@ -53,12 +41,13 @@ class MCTSNodeVirtualLoss(object): pass class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): - def __init__(self, parent, action, state, action_num, prior, inverse=False): + def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5): super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse) self.Q = np.zeros([action_num]) self.W = np.zeros([action_num]) self.N = np.zeros([action_num]) self.virtual_loss = np.zeros([action_num]) + self.c_puct = c_puct #### modified by adding virtual loss #self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) @@ -67,9 +56,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): def selection(self, simulator): self.valid_mask(simulator) self.Q = np.zeros([self.action_num]) - N_not_zero = self.N > 0 - self.Q[N_not_zero] = (self.W[N_not_zero] + self.virtual_loss[N_not_zero] + 0.) / self.N[N_not_zero] - self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\ + N_not_zero = (self.N + self.virtual_loss) > 0 + self.Q[N_not_zero] = (self.W[N_not_zero] + 0.)/ (self.virtual_loss[N_not_zero] + self.N[N_not_zero]) + self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\ (self.N + self.virtual_loss + 1) action = np.argmax(self.ucb) self.virtual_loss[action] += 1 @@ -93,6 +82,7 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): self.W[action] += self.children[action].reward ## do not need to compute Q and ucb immediately since it will be modified by virtual loss + ## just comment out and leaving for comparision #for i in range(self.action_num): # if self.N[i] != 0: # self.Q[i] = (self.W[i] + 0.) / self.N[i] @@ -186,6 +176,12 @@ class MCTSVirtualLoss(object): def do_search(self, max_step=None, max_time=None): + """ + Expand the MCTS tree with stop crierion either by max_step or max_time + + :param max_step search maximum minibath rounds. ONE step is ONE minibatch + :param max_time search maximum seconds + """ if max_step is not None: self.step = 0 self.max_step = max_step @@ -205,6 +201,9 @@ class MCTSVirtualLoss(object): self.step += 1 def expand(self): + """ + Core logic method for MCTS tree to expand nodes. + """ ## minibatch with virtual loss nodes = [] new_actions = [] diff --git a/tianshou/core/mcts/mcts_virtual_loss_test.py b/tianshou/core/mcts/mcts_virtual_loss_test.py index d2d6c81..e4666f3 100644 --- a/tianshou/core/mcts/mcts_virtual_loss_test.py +++ b/tianshou/core/mcts/mcts_virtual_loss_test.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # vim:fenc=utf-8 # $File: mcts_virtual_loss_test.py -# $Date: Tue Dec 19 16:5459 2017 +0800 +# $Date: Sat Dec 23 02:2139 2017 +0800 # Original file: mcts_test.py # $Author: renyong15 © # @@ -9,8 +9,8 @@ import numpy as np -from mcts_virtual_loss import MCTSVirtualLoss -from evaluator import rollout_policy +from .mcts_virtual_loss import MCTSVirtualLoss +from .evaluator import rollout_policy class TestEnv: diff --git a/tianshou/core/mcts/utils.py b/tianshou/core/mcts/utils.py new file mode 100644 index 0000000..de518a0 --- /dev/null +++ b/tianshou/core/mcts/utils.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# $File: utils.py +# $Date: Sat Dec 23 02:0854 2017 +0800 +# $Author: renyong15 © +# + +def list2tuple(list): + try: + return tuple(list2tuple(sub) for sub in list) + except TypeError: + return list + + +def tuple2list(tuple): + try: + return list(tuple2list(sub) for sub in tuple) + except TypeError: + return tuple + + From cf57144ce994dc57588c1473fc05e85bbac92587 Mon Sep 17 00:00:00 2001 From: mcgrady00h <281130306@qq.com> Date: Sun, 24 Dec 2017 15:47:11 +0800 Subject: [PATCH 3/4] merge master --- AlphaGo/network.py | 225 --------------------------------------------- 1 file changed, 225 deletions(-) delete mode 100644 AlphaGo/network.py diff --git a/AlphaGo/network.py b/AlphaGo/network.py deleted file mode 100644 index cfff6f3..0000000 --- a/AlphaGo/network.py +++ /dev/null @@ -1,225 +0,0 @@ -import os -import time -import sys - -import numpy as np -import time -import tensorflow as tf -import tensorflow.contrib.layers as layers - -import multi_gpu -import time -import copy - -# os.environ["CUDA_VISIBLE_DEVICES"] = "1" -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - - -def residual_block(input, is_training): - normalizer_params = {'is_training': is_training, - 'updates_collections': tf.GraphKeys.UPDATE_OPS} - h = layers.conv2d(input, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, - normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, - weights_regularizer=layers.l2_regularizer(1e-4)) - h = layers.conv2d(h, 256, kernel_size=3, stride=1, activation_fn=tf.identity, - normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, - weights_regularizer=layers.l2_regularizer(1e-4)) - h = h + input - return tf.nn.relu(h) - - -def policy_heads(input, is_training): - normalizer_params = {'is_training': is_training, - 'updates_collections': tf.GraphKeys.UPDATE_OPS} - h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu, - normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, - weights_regularizer=layers.l2_regularizer(1e-4)) - h = layers.flatten(h) - h = layers.fully_connected(h, 82, activation_fn=tf.identity, weights_regularizer=layers.l2_regularizer(1e-4)) - return h - - -def value_heads(input, is_training): - normalizer_params = {'is_training': is_training, - 'updates_collections': tf.GraphKeys.UPDATE_OPS} - h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu, - normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params, - weights_regularizer=layers.l2_regularizer(1e-4)) - h = layers.flatten(h) - h = layers.fully_connected(h, 256, activation_fn=tf.nn.relu, weights_regularizer=layers.l2_regularizer(1e-4)) - h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4)) - return h - - -class Network(object): - def __init__(self): - self.x = tf.placeholder(tf.float32, shape=[None, 9, 9, 17]) - self.is_training = tf.placeholder(tf.bool, shape=[]) - self.z = tf.placeholder(tf.float32, shape=[None, 1]) - self.pi = tf.placeholder(tf.float32, shape=[None, 82]) - self.build_network() - - def build_network(self): - h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu, - normalizer_fn=layers.batch_norm, - normalizer_params={'is_training': self.is_training, - 'updates_collections': tf.GraphKeys.UPDATE_OPS}, - weights_regularizer=layers.l2_regularizer(1e-4)) - for i in range(4): - h = residual_block(h, self.is_training) - self.v = value_heads(h, self.is_training) - self.p = policy_heads(h, self.is_training) - # loss = tf.reduce_mean(tf.square(z-v)) - tf.multiply(pi, tf.log(tf.clip_by_value(tf.nn.softmax(p), 1e-8, tf.reduce_max(tf.nn.softmax(p))))) - self.value_loss = tf.reduce_mean(tf.square(self.z - self.v)) - self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p)) - - self.reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - self.total_loss = self.value_loss + self.policy_loss + self.reg - # train_op = tf.train.MomentumOptimizer(1e-4, momentum=0.9, use_nesterov=True).minimize(total_loss) - self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) - with tf.control_dependencies(self.update_ops): - self.train_op = tf.train.RMSPropOptimizer(1e-4).minimize(self.total_loss) - self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) - self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list) - self.sess = multi_gpu.create_session() - - def train(self): - data_path = "./training_data/" - data_name = os.listdir(data_path) - epochs = 100 - batch_size = 128 - - result_path = "./checkpoints_origin/" - with multi_gpu.create_session() as sess: - sess.run(tf.global_variables_initializer()) - ckpt_file = tf.train.latest_checkpoint(result_path) - if ckpt_file is not None: - print('Restoring model from {}...'.format(ckpt_file)) - self.saver.restore(sess, ckpt_file) - for epoch in range(epochs): - for name in data_name: - data = np.load(data_path + name) - boards = data["boards"] - wins = data["wins"] - ps = data["ps"] - print (boards.shape) - print (wins.shape) - print (ps.shape) - batch_num = boards.shape[0] // batch_size - index = np.arange(boards.shape[0]) - np.random.shuffle(index) - value_losses = [] - policy_losses = [] - regs = [] - time_train = -time.time() - for iter in range(batch_num): - lv, lp, r, value, prob, _ = sess.run( - [self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(self.p), self.train_op], - feed_dict={self.x: boards[ - index[iter * batch_size:(iter + 1) * batch_size]], - self.z: wins[index[ - iter * batch_size:(iter + 1) * batch_size]], - self.pi: ps[index[ - iter * batch_size:(iter + 1) * batch_size]], - self.is_training: True}) - value_losses.append(lv) - policy_losses.append(lp) - regs.append(r) - if iter % 1 == 0: - print( - "Epoch: {}, Part {}, Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format( - epoch, name, iter, time.time() + time_train, np.mean(np.array(value_losses)), - np.mean(np.array(policy_losses)), np.mean(np.array(regs)))) - time_train = -time.time() - value_losses = [] - policy_losses = [] - regs = [] - if iter % 20 == 0: - save_path = "Epoch{}.Part{}.Iteration{}.ckpt".format(epoch, name, iter) - self.saver.save(sess, result_path + save_path) - del data, boards, wins, ps - - - # def forward(call_number): - # # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints" - # checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/" - # board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number, - # dtype='str'); - # human_board = np.zeros((17, 19, 19)) - # - # # TODO : is it ok to ignore the last channel? - # for i in range(17): - # human_board[i] = np.array(list(board_file[i])).reshape(19, 19) - # # print("============================") - # # print("human board sum : " + str(np.sum(human_board[-1]))) - # # print("============================") - # # print(human_board) - # # print("============================") - # # rint(human_board) - # feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17) - # # print(feed_board[:,:,:,-1]) - # # print(feed_board.shape) - # - # # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz") - # # print(npz_board["boards"].shape) - # # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17) - # ##print(feed_board) - # # show_board = feed_board[0].transpose(2, 0, 1) - # # print("board shape : ", show_board.shape) - # # print(show_board) - # - # itflag = False - # with multi_gpu.create_session() as sess: - # sess.run(tf.global_variables_initializer()) - # ckpt_file = tf.train.latest_checkpoint(checkpoint_path) - # if ckpt_file is not None: - # # print('Restoring model from {}...'.format(ckpt_file)) - # saver.restore(sess, ckpt_file) - # else: - # raise ValueError("No model loaded") - # res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag}) - # # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False}) - # # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True}) - # # print(np.argmax(res[0])) - # np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ") - # np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ") - # pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value" - # np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ") - # # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ") - # return res - - def forward(self, checkpoint_path): - # checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/" - # sess = multi_gpu.create_session() - # sess.run(tf.global_variables_initializer()) - if checkpoint_path is None: - self.sess.run(tf.global_variables_initializer()) - else: - ckpt_file = tf.train.latest_checkpoint(checkpoint_path) - if ckpt_file is not None: - # print('Restoring model from {}...'.format(ckpt_file)) - self.saver.restore(self.sess, ckpt_file) - # print('Successfully loaded') - else: - raise ValueError("No model loaded") - # prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False}) - # return prior, value - return self.sess - - -if __name__ == '__main__': - # state = np.random.randint(0, 1, [256, 9, 9, 17]) - # net = Network() - # net.train() - # sess = net.forward() - # start_time = time.time() - # for i in range(100): - # sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False}) - # print("Step {}, use time {}".format(i, time.time() - start_time)) - # start_time = time.time() - net0 = Network() - sess0 = net0.forward("./checkpoints/") - print("Loaded") - while True: - pass - From 5aa5dcd191a266aca637574ff8aaab46ee1c58ae Mon Sep 17 00:00:00 2001 From: mcgrady00h <281130306@qq.com> Date: Sun, 24 Dec 2017 16:47:43 +0800 Subject: [PATCH 4/4] add comments for mcts with virtual loss --- tianshou/core/mcts/mcts_virtual_loss.py | 47 +++++++++++++++++++++---- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/tianshou/core/mcts/mcts_virtual_loss.py b/tianshou/core/mcts/mcts_virtual_loss.py index 9335464..f27d8a3 100644 --- a/tianshou/core/mcts/mcts_virtual_loss.py +++ b/tianshou/core/mcts/mcts_virtual_loss.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # vim:fenc=utf-8 # $File: mcts_virtual_loss.py -# $Date: Sat Dec 23 02:4850 2017 +0800 +# $Date: Sun Dec 24 16:4740 2017 +0800 # Original file: mcts.py # $Author: renyong15 © # @@ -22,7 +22,17 @@ from .utils import list2tuple, tuple2list class MCTSNodeVirtualLoss(object): - def __init__(self, parent, action, state, action_num, prior, inverse=False): + """ + MCTS abstract class with virtual loss. Currently we only support UCT node. + Role of the Parameters can be found in Readme.md. + """ + def __init__(self, + parent, + action, + state, + action_num, + prior, + inverse = False): self.parent = parent self.action = action self.children = {} @@ -41,7 +51,19 @@ class MCTSNodeVirtualLoss(object): pass class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): - def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5): + """ + UCT node (state node) with virtual loss. + Role of the Parameters can be found in Readme.md. + :param c_puct balance between exploration and exploition, + """ + def __init__(self, + parent, + action, + state, + action_num, + prior, + inverse=False, + c_puct = 5): super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse) self.Q = np.zeros([action_num]) self.W = np.zeros([action_num]) @@ -53,7 +75,8 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): self.mask = None - def selection(self, simulator): + def selection(self, + simulator): self.valid_mask(simulator) self.Q = np.zeros([self.action_num]) N_not_zero = (self.N + self.virtual_loss) > 0 @@ -108,6 +131,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): class ActionNodeVirtualLoss(object): + """ + Action node with virtual loss. + """ def __init__(self, parent, action): self.parent = parent self.action = action @@ -156,6 +182,9 @@ class ActionNodeVirtualLoss(object): class MCTSVirtualLoss(object): + """ + MCTS class with virtual loss + """ def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False): self.simulator = simulator self.evaluator = evaluator @@ -196,13 +225,19 @@ class MCTSVirtualLoss(object): self.bp_time = [] while (max_step is not None and self.step < self.max_step or max_step is None) \ and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None): - self.expand() + self._expand() if max_step is not None: self.step += 1 - def expand(self): + def _expand(self): """ Core logic method for MCTS tree to expand nodes. + Steps to expand node: + 1. Select final action node with virtual loss and collect them in to a minibatch. + (i.e. root->action->state->action...->action) + 2. Remove the virtual loss + 3. Evaluate the whole minibatch using evaluator + 4. Expand new nodes and perform back propogation. """ ## minibatch with virtual loss nodes = []