From 3b534064bd6c92c972883d448c7c77fa0884e356 Mon Sep 17 00:00:00 2001 From: mcgrady00h <281130306@qq.com> Date: Sat, 23 Dec 2017 02:48:53 +0800 Subject: [PATCH] fix virtual loss bug --- tianshou/core/mcts/mcts.py | 22 +++-------- tianshou/core/mcts/mcts_virtual_loss.py | 41 ++++++++++---------- tianshou/core/mcts/mcts_virtual_loss_test.py | 6 +-- tianshou/core/mcts/utils.py | 21 ++++++++++ 4 files changed, 49 insertions(+), 41 deletions(-) create mode 100644 tianshou/core/mcts/utils.py diff --git a/tianshou/core/mcts/mcts.py b/tianshou/core/mcts/mcts.py index 979e994..16d13d5 100644 --- a/tianshou/core/mcts/mcts.py +++ b/tianshou/core/mcts/mcts.py @@ -1,22 +1,9 @@ import numpy as np import math import time +import sys,os +from .utils import list2tuple, tuple2list -c_puct = 5 - - -def list2tuple(list): - try: - return tuple(list2tuple(sub) for sub in list) - except TypeError: - return list - - -def tuple2list(tuple): - try: - return list(tuple2list(sub) for sub in tuple) - except TypeError: - return tuple class MCTSNode(object): @@ -39,12 +26,13 @@ class MCTSNode(object): pass class UCTNode(MCTSNode): - def __init__(self, parent, action, state, action_num, prior, inverse=False): + def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5): super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) self.Q = np.zeros([action_num]) self.W = np.zeros([action_num]) self.N = np.zeros([action_num]) - self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) + self.c_puct = c_puct + self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) self.mask = None def selection(self, simulator): diff --git a/tianshou/core/mcts/mcts_virtual_loss.py b/tianshou/core/mcts/mcts_virtual_loss.py index 9d20b5a..9335464 100644 --- a/tianshou/core/mcts/mcts_virtual_loss.py +++ b/tianshou/core/mcts/mcts_virtual_loss.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # vim:fenc=utf-8 # $File: 
mcts_virtual_loss.py -# $Date: Tue Dec 19 17:0444 2017 +0800 +# $Date: Sat Dec 23 02:4850 2017 +0800 # Original file: mcts.py # $Author: renyong15 © # @@ -12,25 +12,13 @@ manner. """ +from __future__ import absolute_import + import numpy as np import math import time - -c_puct = 5 - - -def list2tuple(list): - try: - return tuple(list2tuple(sub) for sub in list) - except TypeError: - return list - - -def tuple2list(tuple): - try: - return list(tuple2list(sub) for sub in tuple) - except TypeError: - return tuple +import sys,os +from .utils import list2tuple, tuple2list class MCTSNodeVirtualLoss(object): @@ -53,12 +41,13 @@ class MCTSNodeVirtualLoss(object): pass class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): - def __init__(self, parent, action, state, action_num, prior, inverse=False): + def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5): super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse) self.Q = np.zeros([action_num]) self.W = np.zeros([action_num]) self.N = np.zeros([action_num]) self.virtual_loss = np.zeros([action_num]) + self.c_puct = c_puct #### modified by adding virtual loss #self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) @@ -67,9 +56,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): def selection(self, simulator): self.valid_mask(simulator) self.Q = np.zeros([self.action_num]) - N_not_zero = self.N > 0 - self.Q[N_not_zero] = (self.W[N_not_zero] + self.virtual_loss[N_not_zero] + 0.) 
/ self.N[N_not_zero] - self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\ + N_not_zero = (self.N + self.virtual_loss) > 0 + self.Q[N_not_zero] = (self.W[N_not_zero] + 0.)/ (self.virtual_loss[N_not_zero] + self.N[N_not_zero]) + self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\ (self.N + self.virtual_loss + 1) action = np.argmax(self.ucb) self.virtual_loss[action] += 1 @@ -93,6 +82,7 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): self.W[action] += self.children[action].reward ## do not need to compute Q and ucb immediately since it will be modified by virtual loss + ## just commented out and left for comparison #for i in range(self.action_num): # if self.N[i] != 0: # self.Q[i] = (self.W[i] + 0.) / self.N[i] @@ -186,6 +176,12 @@ class MCTSVirtualLoss(object): def do_search(self, max_step=None, max_time=None): + """ + Expand the MCTS tree with stop criterion either by max_step or max_time + + :param max_step search maximum minibatch rounds. ONE step is ONE minibatch + :param max_time search maximum seconds + """ if max_step is not None: self.step = 0 self.max_step = max_step @@ -205,6 +201,9 @@ class MCTSVirtualLoss(object): self.step += 1 def expand(self): + """ + Core logic method for MCTS tree to expand nodes. 
+ """ ## minibatch with virtual loss nodes = [] new_actions = [] diff --git a/tianshou/core/mcts/mcts_virtual_loss_test.py b/tianshou/core/mcts/mcts_virtual_loss_test.py index d2d6c81..e4666f3 100644 --- a/tianshou/core/mcts/mcts_virtual_loss_test.py +++ b/tianshou/core/mcts/mcts_virtual_loss_test.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # vim:fenc=utf-8 # $File: mcts_virtual_loss_test.py -# $Date: Tue Dec 19 16:5459 2017 +0800 +# $Date: Sat Dec 23 02:2139 2017 +0800 # Original file: mcts_test.py # $Author: renyong15 © # @@ -9,8 +9,8 @@ import numpy as np -from mcts_virtual_loss import MCTSVirtualLoss -from evaluator import rollout_policy +from .mcts_virtual_loss import MCTSVirtualLoss +from .evaluator import rollout_policy class TestEnv: diff --git a/tianshou/core/mcts/utils.py b/tianshou/core/mcts/utils.py new file mode 100644 index 0000000..de518a0 --- /dev/null +++ b/tianshou/core/mcts/utils.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# vim:fenc=utf-8 +# $File: utils.py +# $Date: Sat Dec 23 02:0854 2017 +0800 +# $Author: renyong15 © +# + +def list2tuple(list): + try: + return tuple(list2tuple(sub) for sub in list) + except TypeError: + return list + + +def tuple2list(tuple): + try: + return list(tuple2list(sub) for sub in tuple) + except TypeError: + return tuple + +