Merge pull request #2 from sproblvem/mcts_virtual_loss
Mcts virtual loss
commit 9583a14856
@@ -1,225 +0,0 @@
import os
import time
import sys

import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers

import multi_gpu
import copy

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def residual_block(input, is_training):
    normalizer_params = {'is_training': is_training,
                         'updates_collections': tf.GraphKeys.UPDATE_OPS}
    h = layers.conv2d(input, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
                      normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
                      weights_regularizer=layers.l2_regularizer(1e-4))
    h = layers.conv2d(h, 256, kernel_size=3, stride=1, activation_fn=tf.identity,
                      normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
                      weights_regularizer=layers.l2_regularizer(1e-4))
    h = h + input
    return tf.nn.relu(h)


def policy_heads(input, is_training):
    normalizer_params = {'is_training': is_training,
                         'updates_collections': tf.GraphKeys.UPDATE_OPS}
    h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu,
                      normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
                      weights_regularizer=layers.l2_regularizer(1e-4))
    h = layers.flatten(h)
    h = layers.fully_connected(h, 82, activation_fn=tf.identity, weights_regularizer=layers.l2_regularizer(1e-4))
    return h


def value_heads(input, is_training):
    normalizer_params = {'is_training': is_training,
                         'updates_collections': tf.GraphKeys.UPDATE_OPS}
    h = layers.conv2d(input, 2, kernel_size=1, stride=1, activation_fn=tf.nn.relu,
                      normalizer_fn=layers.batch_norm, normalizer_params=normalizer_params,
                      weights_regularizer=layers.l2_regularizer(1e-4))
    h = layers.flatten(h)
    h = layers.fully_connected(h, 256, activation_fn=tf.nn.relu, weights_regularizer=layers.l2_regularizer(1e-4))
    h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4))
    return h


class Network(object):
    def __init__(self):
        self.x = tf.placeholder(tf.float32, shape=[None, 9, 9, 17])
        self.is_training = tf.placeholder(tf.bool, shape=[])
        self.z = tf.placeholder(tf.float32, shape=[None, 1])
        self.pi = tf.placeholder(tf.float32, shape=[None, 82])
        self.build_network()

    def build_network(self):
        h = layers.conv2d(self.x, 256, kernel_size=3, stride=1, activation_fn=tf.nn.relu,
                          normalizer_fn=layers.batch_norm,
                          normalizer_params={'is_training': self.is_training,
                                             'updates_collections': tf.GraphKeys.UPDATE_OPS},
                          weights_regularizer=layers.l2_regularizer(1e-4))
        for i in range(4):
            h = residual_block(h, self.is_training)
        self.v = value_heads(h, self.is_training)
        self.p = policy_heads(h, self.is_training)
        # loss = tf.reduce_mean(tf.square(z-v)) - tf.multiply(pi, tf.log(tf.clip_by_value(tf.nn.softmax(p), 1e-8, tf.reduce_max(tf.nn.softmax(p)))))
        self.value_loss = tf.reduce_mean(tf.square(self.z - self.v))
        self.policy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.pi, logits=self.p))

        self.reg = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        self.total_loss = self.value_loss + self.policy_loss + self.reg
        # train_op = tf.train.MomentumOptimizer(1e-4, momentum=0.9, use_nesterov=True).minimize(total_loss)
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.train_op = tf.train.RMSPropOptimizer(1e-4).minimize(self.total_loss)
        self.var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.saver = tf.train.Saver(max_to_keep=10, var_list=self.var_list)
        self.sess = multi_gpu.create_session()

    def train(self):
        data_path = "./training_data/"
        data_name = os.listdir(data_path)
        epochs = 100
        batch_size = 128

        result_path = "./checkpoints_origin/"
        with multi_gpu.create_session() as sess:
            sess.run(tf.global_variables_initializer())
            ckpt_file = tf.train.latest_checkpoint(result_path)
            if ckpt_file is not None:
                print('Restoring model from {}...'.format(ckpt_file))
                self.saver.restore(sess, ckpt_file)
            for epoch in range(epochs):
                for name in data_name:
                    data = np.load(data_path + name)
                    boards = data["boards"]
                    wins = data["wins"]
                    ps = data["ps"]
                    print(boards.shape)
                    print(wins.shape)
                    print(ps.shape)
                    batch_num = boards.shape[0] // batch_size
                    index = np.arange(boards.shape[0])
                    np.random.shuffle(index)
                    value_losses = []
                    policy_losses = []
                    regs = []
                    time_train = -time.time()
                    for iter in range(batch_num):
                        lv, lp, r, value, prob, _ = sess.run(
                            [self.value_loss, self.policy_loss, self.reg, self.v, tf.nn.softmax(self.p), self.train_op],
                            feed_dict={self.x: boards[index[iter * batch_size:(iter + 1) * batch_size]],
                                       self.z: wins[index[iter * batch_size:(iter + 1) * batch_size]],
                                       self.pi: ps[index[iter * batch_size:(iter + 1) * batch_size]],
                                       self.is_training: True})
                        value_losses.append(lv)
                        policy_losses.append(lp)
                        regs.append(r)
                        if iter % 1 == 0:
                            print(
                                "Epoch: {}, Part {}, Iteration: {}, Time: {}, Value Loss: {}, Policy Loss: {}, Reg: {}".format(
                                    epoch, name, iter, time.time() + time_train, np.mean(np.array(value_losses)),
                                    np.mean(np.array(policy_losses)), np.mean(np.array(regs))))
                            time_train = -time.time()
                            value_losses = []
                            policy_losses = []
                            regs = []
                        if iter % 20 == 0:
                            save_path = "Epoch{}.Part{}.Iteration{}.ckpt".format(epoch, name, iter)
                            self.saver.save(sess, result_path + save_path)
                    del data, boards, wins, ps

    # def forward(call_number):
    #     # checkpoint_path = "/home/yama/rl/tianshou/AlphaGo/checkpoints"
    #     checkpoint_path = "/home/jialian/stuGo/tianshou/stuGo/checkpoints/"
    #     board_file = np.genfromtxt("/home/jialian/stuGo/tianshou/leela-zero/src/mcts_nn_files/board_" + call_number,
    #                                dtype='str');
    #     human_board = np.zeros((17, 19, 19))
    #
    #     # TODO : is it ok to ignore the last channel?
    #     for i in range(17):
    #         human_board[i] = np.array(list(board_file[i])).reshape(19, 19)
    #     # print("============================")
    #     # print("human board sum : " + str(np.sum(human_board[-1])))
    #     # print("============================")
    #     # print(human_board)
    #     # print("============================")
    #     # print(human_board)
    #     feed_board = human_board.transpose(1, 2, 0).reshape(1, 19, 19, 17)
    #     # print(feed_board[:,:,:,-1])
    #     # print(feed_board.shape)
    #
    #     # npz_board = np.load("/home/yama/rl/tianshou/AlphaGo/data/7f83928932f64a79bc1efdea268698ae.npz")
    #     # print(npz_board["boards"].shape)
    #     # feed_board = npz_board["boards"][10].reshape(-1, 19, 19, 17)
    #     # print(feed_board)
    #     # show_board = feed_board[0].transpose(2, 0, 1)
    #     # print("board shape : ", show_board.shape)
    #     # print(show_board)
    #
    #     itflag = False
    #     with multi_gpu.create_session() as sess:
    #         sess.run(tf.global_variables_initializer())
    #         ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
    #         if ckpt_file is not None:
    #             # print('Restoring model from {}...'.format(ckpt_file))
    #             saver.restore(sess, ckpt_file)
    #         else:
    #             raise ValueError("No model loaded")
    #         res = sess.run([tf.nn.softmax(p), v], feed_dict={x: feed_board, is_training: itflag})
    #         # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][300].reshape(-1, 19, 19, 17), is_training:False})
    #         # res = sess.run([tf.nn.softmax(p),v], feed_dict={x:fix_board["boards"][50].reshape(-1, 19, 19, 17), is_training:True})
    #         # print(np.argmax(res[0]))
    #         np.savetxt(sys.stdout, res[0][0], fmt="%.6f", newline=" ")
    #         np.savetxt(sys.stdout, res[1][0], fmt="%.6f", newline=" ")
    #         pv_file = "/home/jialian/stuGotianshou/leela-zero/src/mcts_nn_files/policy_value"
    #         np.savetxt(pv_file, np.concatenate((res[0][0], res[1][0])), fmt="%.6f", newline=" ")
    #         # np.savetxt(pv_file, res[1][0], fmt="%.6f", newline=" ")
    #         return res

    def forward(self, checkpoint_path):
        # checkpoint_path = "/home/tongzheng/tianshou/AlphaGo/checkpoints/"
        # sess = multi_gpu.create_session()
        # sess.run(tf.global_variables_initializer())
        if checkpoint_path is None:
            self.sess.run(tf.global_variables_initializer())
        else:
            ckpt_file = tf.train.latest_checkpoint(checkpoint_path)
            if ckpt_file is not None:
                # print('Restoring model from {}...'.format(ckpt_file))
                self.saver.restore(self.sess, ckpt_file)
                # print('Successfully loaded')
            else:
                raise ValueError("No model loaded")
        # prior, value = sess.run([tf.nn.softmax(p), v], feed_dict={x: state, is_training: False})
        # return prior, value
        return self.sess


if __name__ == '__main__':
    # state = np.random.randint(0, 1, [256, 9, 9, 17])
    # net = Network()
    # net.train()
    # sess = net.forward()
    # start_time = time.time()
    # for i in range(100):
    #     sess.run([tf.nn.softmax(net.p), net.v], feed_dict={net.x: state, net.is_training: False})
    #     print("Step {}, use time {}".format(i, time.time() - start_time))
    #     start_time = time.time()
    net0 = Network()
    sess0 = net0.forward("./checkpoints/")
    print("Loaded")
    while True:
        pass
@@ -37,7 +37,8 @@ class UCTNode(MCTSNode):
         self.Q = np.zeros([action_num])
         self.W = np.zeros([action_num])
         self.N = np.zeros([action_num])
-        self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
+        self.c_puct = c_puct
+        self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
         self.mask = None
         self.elapse_time = 0
         self.mcts = mcts
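The hunk above stores c_puct on the node so that selection can recompute the UCB term later. For intuition, a minimal sketch of that PUCT rule with made-up numbers (all values here are illustrative, not from the repository):

import math
import numpy as np

Q = np.array([0.5, 0.2, 0.0])       # mean action values
N = np.array([10.0, 5.0, 0.0])      # visit counts
prior = np.array([0.2, 0.3, 0.5])   # network prior over actions
c_puct = 5                          # exploration constant, as in the hunk

ucb = Q + c_puct * prior * math.sqrt(np.sum(N)) / (N + 1)
print(int(np.argmax(ucb)))          # 2: the unvisited action wins on the exploration term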
@@ -12,6 +12,9 @@ class TestEnv:
         print(self.reward)
         # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))
 
+    def simulate_is_valid(self, state, act):
+        return True
+
     def step_forward(self, state, action):
         if action != 0 and action != 1:
             raise ValueError("Action must be 0 or 1! Your action is {}".format(action))
tianshou/core/mcts/mcts_virtual_loss.py (new file, +297 lines)
@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: mcts_virtual_loss.py
# $Date: Sun Dec 24 16:47:40 2017 +0800
# Original file: mcts.py
# $Author: renyong15 @ <mails.tsinghua.edu.cn>
#

"""
This is an implementation of MCTS with virtual loss.
Due to the limitations of Python's concurrency model, we implement the
virtual loss in a mini-batch manner rather than with parallel workers.
"""
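# Virtual loss, in brief: every selection pass in a minibatch that goes
# through (node, action) increments node.virtual_loss[action].  Both the
# exploitation term W / (N + virtual_loss) and the exploration term
# sqrt(sum(N + virtual_loss)) / (N + virtual_loss + 1) then see the extra
# visits, so later passes in the same minibatch are steered toward other
# paths.  The virtual losses are cleared again before evaluation and the
# real backpropagation (see remove_virtual_loss and _expand below).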
from __future__ import absolute_import

import numpy as np
import math
import time
from .utils import list2tuple, tuple2list


class MCTSNodeVirtualLoss(object):
    """
    Abstract MCTS node with virtual loss. Currently only the UCT node is supported.
    The role of the parameters is described in Readme.md.
    """
    def __init__(self,
                 parent,
                 action,
                 state,
                 action_num,
                 prior,
                 inverse=False):
        self.parent = parent
        self.action = action
        self.children = {}
        self.state = state
        self.action_num = action_num
        self.prior = np.array(prior).reshape(-1)
        self.inverse = inverse

    def selection(self, simulator):
        raise NotImplementedError("Need to implement function selection")

    def backpropagation(self, action):
        raise NotImplementedError("Need to implement function backpropagation")

    def valid_mask(self, simulator):
        pass


class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
    """
    UCT node (state node) with virtual loss.
    The role of the parameters is described in Readme.md.

    :param c_puct: balances exploration and exploitation.
    """
    def __init__(self,
                 parent,
                 action,
                 state,
                 action_num,
                 prior,
                 inverse=False,
                 c_puct=5):
        super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
        self.Q = np.zeros([action_num])
        self.W = np.zeros([action_num])
        self.N = np.zeros([action_num])
        self.virtual_loss = np.zeros([action_num])
        self.c_puct = c_puct
        # modified by adding virtual loss: the ucb can no longer be
        # precomputed here, so it is refreshed on every selection()
        # self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
        self.ucb = np.zeros([action_num])

        self.mask = None

    def selection(self, simulator):
        self.Q = np.zeros([self.action_num])
        N_not_zero = (self.N + self.virtual_loss) > 0
        self.Q[N_not_zero] = (self.W[N_not_zero] + 0.) / (self.virtual_loss[N_not_zero] + self.N[N_not_zero])
        self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) / \
                   (self.N + self.virtual_loss + 1)
        # apply the validity mask after the ucb is refreshed, so that masked
        # entries stay at -inf when the action is picked
        self.valid_mask(simulator)
        action = np.argmax(self.ucb)
        self.virtual_loss[action] += 1

        if action in self.children.keys():
            return self.children[action].selection(simulator)
        else:
            self.children[action] = ActionNodeVirtualLoss(self, action)
            return self.children[action].selection(simulator)

    def remove_virtual_loss(self):
        # only reset (and recurse upward) if some action still carries virtual loss
        if np.sum(self.virtual_loss > 0) > 0:
            self.virtual_loss = np.zeros([self.action_num])
            if self.parent:
                self.parent.remove_virtual_loss()

    def backpropagation(self, action):
        action = int(action)
        self.N[action] += 1
        self.W[action] += self.children[action].reward

        # no need to recompute Q and ucb immediately, since they would be
        # modified by virtual loss anyway; kept commented out for comparison
        # for i in range(self.action_num):
        #     if self.N[i] != 0:
        #         self.Q[i] = (self.W[i] + 0.) / self.N[i]
        # self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1.)

        if self.parent is not None:
            if self.inverse:
                self.parent.backpropagation(-self.children[action].reward)
            else:
                self.parent.backpropagation(self.children[action].reward)

    def valid_mask(self, simulator):
        if self.mask is None:
            self.mask = []
            for act in range(self.action_num - 1):
                if not simulator.simulate_is_valid(self.state, act):
                    self.mask.append(act)
                    self.ucb[act] = -float("Inf")
        else:
            self.ucb[self.mask] = -float("Inf")


class ActionNodeVirtualLoss(object):
    """
    Action node with virtual loss.
    """
    def __init__(self, parent, action):
        self.parent = parent
        self.action = action
        self.children = {}
        self.next_state = None
        self.origin_state = None
        self.state_type = None
        self.reward = 0

    def remove_virtual_loss(self):
        self.parent.remove_virtual_loss()

    def type_conversion_to_tuple(self):
        if type(self.next_state) is np.ndarray:
            self.next_state = self.next_state.tolist()
        if type(self.next_state) is list:
            self.next_state = list2tuple(self.next_state)

    def type_conversion_to_origin(self):
        if self.state_type is np.ndarray:
            self.next_state = np.array(self.next_state)
        if self.state_type is list:
            self.next_state = tuple2list(self.next_state)

    def selection(self, simulator):
        self.next_state, self.reward = simulator.step_forward(self.parent.state, self.action)
        self.origin_state = self.next_state
        self.state_type = type(self.next_state)
        self.type_conversion_to_tuple()
        if self.next_state is not None:
            if self.next_state in self.children.keys():
                return self.children[self.next_state].selection(simulator)
            else:
                return self.parent, self.action
        else:
            return self.parent, self.action

    def expansion(self, action, state, action_num, prior, inverse):
        if state is not None:
            self.children[state] = UCTNodeVirtualLoss(self, action, state, action_num, prior, inverse)

    def backpropagation(self, value):
        self.reward += value
        self.parent.backpropagation(self.action)


class MCTSVirtualLoss(object):
    """
    MCTS with virtual loss.
    """
    def __init__(self, simulator, evaluator, root, action_num, batch_size=1, method="UCT", inverse=False):
        self.simulator = simulator
        self.evaluator = evaluator
        prior, _ = self.evaluator(root)
        self.action_num = action_num
        self.batch_size = batch_size

        if method == "":
            # an already-constructed root node was passed in
            self.root = root
        elif method == "UCT":
            self.root = UCTNodeVirtualLoss(None, None, root, action_num, prior, inverse)
        elif method == "TS":
            self.root = TSNodeVirtualLoss(None, None, root, action_num, prior, inverse=inverse)
        else:
            raise ValueError("Need a root type")

        self.inverse = inverse

    def do_search(self, max_step=None, max_time=None):
        """
        Expand the MCTS tree until either the max_step or the max_time stop
        criterion is met.

        :param max_step: maximum number of minibatch rounds; ONE step is ONE minibatch.
        :param max_time: maximum search time in seconds.
        """
        if max_step is not None:
            self.step = 0
            self.max_step = max_step
        if max_time is not None:
            self.start_time = time.time()
            self.max_time = max_time
        if max_step is None and max_time is None:
            raise ValueError("Need a stop criterion!")

        self.select_time = []
        self.evaluate_time = []
        self.bp_time = []
        while (max_step is not None and self.step < self.max_step or max_step is None) \
                and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
            self._expand()
            if max_step is not None:
                self.step += 1

    def _expand(self):
        """
        Core logic for expanding nodes of the MCTS tree.

        Steps to expand a node:
        1. Select leaf action nodes with virtual loss and collect them into a
           minibatch (i.e. root->action->state->action...->action).
        2. Remove the virtual loss.
        3. Evaluate the whole minibatch with the evaluator.
        4. Expand new nodes and perform backpropagation.
        """
        # minibatch with virtual loss
        nodes = []
        new_actions = []
        next_states = []

        for i in range(self.batch_size):
            node, new_action = self.root.selection(self.simulator)
            nodes.append(node)
            new_actions.append(new_action)
            next_states.append(node.children[new_action].next_state)

        for node in nodes:
            node.remove_virtual_loss()

        assert np.sum(self.root.virtual_loss > 0) == 0
        # compute the values in a batch unless the evaluator does not support it
        try:
            priors, values = self.evaluator(next_states)
        except Exception:
            priors = []
            values = []
            for i in range(self.batch_size):
                if next_states[i] is not None:
                    prior, value = self.evaluator(next_states[i])
                    priors.append(prior)
                    values.append(value)
                else:
                    priors.append(0.)
                    values.append(0.)

        # for now next_state == origin_state
        # possible problem: what if we reach the same next_state from the same
        # (parent, action) pair twice within one minibatch?
        for i in range(self.batch_size):
            nodes[i].children[new_actions[i]].expansion(new_actions[i],
                                                        next_states[i],
                                                        self.action_num,
                                                        priors[i],
                                                        nodes[i].inverse)

        for i in range(self.batch_size):
            nodes[i].children[new_actions[i]].backpropagation(values[i] + 0.)


# TODO: Thompson sampling node (selection/backpropagation not implemented yet)
class TSNodeVirtualLoss(MCTSNodeVirtualLoss):
    def __init__(self, parent, action, state, action_num, prior, method="Gaussian", inverse=False):
        super(TSNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
        if method == "Beta":
            self.alpha = np.ones([action_num])
            self.beta = np.ones([action_num])
        if method == "Gaussian":
            self.mu = np.zeros([action_num])
            self.sigma = np.zeros([action_num])


if __name__ == "__main__":
    # quick smoke test of the constructor
    mcts_virtual_loss = MCTSNodeVirtualLoss(None, None, 10, 1, 'UCT')
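To make the effect concrete, here is a small self-contained sketch (toy numbers, not from the commit): repeated selections of the same node within one minibatch drift across actions as virtual loss accumulates, instead of hammering the current argmax every time.

import math
import numpy as np

prior = np.array([0.4, 0.3, 0.3])
W = np.array([3.0, 1.0, 0.0])        # accumulated returns
N = np.array([6.0, 3.0, 1.0])        # real visit counts
virtual_loss = np.zeros(3)
c_puct = 5

picked = []
for _ in range(4):                   # four selections in one minibatch
    nv = N + virtual_loss
    Q = np.zeros(3)
    nz = nv > 0
    Q[nz] = W[nz] / nv[nz]
    ucb = Q + c_puct * prior * math.sqrt(np.sum(nv)) / (nv + 1)
    a = int(np.argmax(ucb))
    picked.append(a)
    virtual_loss[a] += 1             # discourage re-selecting a in this batch

virtual_loss[:] = 0                  # cleared before the real backpropagation
print(picked)                        # [2, 2, 1, 0] with these numbers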
tianshou/core/mcts/mcts_virtual_loss_test.py (new file, +55 lines)
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: mcts_virtual_loss_test.py
# $Date: Sat Dec 23 02:21:39 2017 +0800
# Original file: mcts_test.py
# $Author: renyong15 @ <mails.tsinghua.edu.cn>
#


import numpy as np
from .mcts_virtual_loss import MCTSVirtualLoss
from .evaluator import rollout_policy


class TestEnv:
    def __init__(self, max_step=5):
        self.max_step = max_step
        self.reward = {i: np.random.uniform() for i in range(2 ** max_step)}
        # self.reward = {0:1, 1:0}
        self.best = max(self.reward.items(), key=lambda x: x[1])
        print(self.reward)
        # print("The best arm is {} with expected reward {}".format(self.best[0],self.best[1]))

    def simulate_is_valid(self, state, act):
        return True

    def step_forward(self, state, action):
        if action != 0 and action != 1:
            raise ValueError("Action must be 0 or 1! Your action is {}".format(action))
        if state[0] >= 2 ** state[1] or state[1] > self.max_step:
            raise ValueError("Invalid State! Your state is {}".format(state))
        # print("Operate action {} at state {}, timestep {}".format(action, state[0], state[1]))
        if state[1] == self.max_step:
            new_state = None
            reward = 0
        else:
            num = state[0] + 2 ** state[1] * action
            step = state[1] + 1
            new_state = [num, step]
            if step == self.max_step:
                reward = int(np.random.uniform() < self.reward[num])
            else:
                reward = 0.
        return new_state, reward


if __name__ == "__main__":
    env = TestEnv(2)
    rollout = rollout_policy(env, 2)
    evaluator = lambda state: rollout(state)
    mcts_virtual_loss = MCTSVirtualLoss(env, evaluator, [0, 0], 2, batch_size=10)
    for i in range(10):
        mcts_virtual_loss.do_search(max_step=100)
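Because of the relative imports, the test is meant to be run as a module from the repository root, e.g. `python -m tianshou.core.mcts.mcts_virtual_loss_test` (module path assumed from the file header above); it prints the sampled arm rewards and then runs ten searches of 100 minibatches each, with 10 selections per minibatch.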
tianshou/core/mcts/utils.py (new file, +21 lines)
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: utils.py
# $Date: Sat Dec 23 02:08:54 2017 +0800
# $Author: renyong15 @ <mails.tsinghua.edu.cn>
#


def list2tuple(lst):
    # recursively convert nested lists into nested tuples;
    # non-iterable leaves raise TypeError and are returned unchanged
    try:
        return tuple(list2tuple(sub) for sub in lst)
    except TypeError:
        return lst


def tuple2list(tpl):
    # inverse of list2tuple: recursively convert nested tuples back into lists
    try:
        return list(tuple2list(sub) for sub in tpl)
    except TypeError:
        return tpl
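These helpers exist because ActionNodeVirtualLoss keys its children dict by state, and lists or ndarrays are not hashable. A quick sketch (assuming the package path from the file header above):

import numpy as np
from tianshou.core.mcts.utils import list2tuple, tuple2list

state = np.array([[0, 1], [2, 3]])
key = list2tuple(state.tolist())     # ((0, 1), (2, 3)) -- hashable, usable as a dict key
children = {key: "child node"}
assert tuple2list(key) == [[0, 1], [2, 3]]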