fix virtual loss bug

This commit is contained in:
mcgrady00h 2017-12-23 02:48:53 +08:00
parent 1f011a44ef
commit 3b534064bd
4 changed files with 49 additions and 41 deletions

View File

@ -1,22 +1,9 @@
import numpy as np import numpy as np
import math import math
import time import time
import sys,os
from .utils import list2tuple, tuple2list
c_puct = 5
def list2tuple(list):
try:
return tuple(list2tuple(sub) for sub in list)
except TypeError:
return list
def tuple2list(tuple):
try:
return list(tuple2list(sub) for sub in tuple)
except TypeError:
return tuple
class MCTSNode(object): class MCTSNode(object):
@ -39,12 +26,13 @@ class MCTSNode(object):
pass pass
class UCTNode(MCTSNode): class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, inverse=False): def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse) super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
self.Q = np.zeros([action_num]) self.Q = np.zeros([action_num])
self.W = np.zeros([action_num]) self.W = np.zeros([action_num])
self.N = np.zeros([action_num]) self.N = np.zeros([action_num])
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) self.c_puct = c_puct
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
self.mask = None self.mask = None
def selection(self, simulator): def selection(self, simulator):

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# vim:fenc=utf-8 # vim:fenc=utf-8
# $File: mcts_virtual_loss.py # $File: mcts_virtual_loss.py
# $Date: Tue Dec 19 17:0444 2017 +0800 # $Date: Sat Dec 23 02:4850 2017 +0800
# Original file: mcts.py # Original file: mcts.py
# $Author: renyong15 © <mails.tsinghua.edu.cn> # $Author: renyong15 © <mails.tsinghua.edu.cn>
# #
@ -12,25 +12,13 @@
manner. manner.
""" """
from __future__ import absolute_import
import numpy as np import numpy as np
import math import math
import time import time
import sys,os
c_puct = 5 from .utils import list2tuple, tuple2list
def list2tuple(list):
try:
return tuple(list2tuple(sub) for sub in list)
except TypeError:
return list
def tuple2list(tuple):
try:
return list(tuple2list(sub) for sub in tuple)
except TypeError:
return tuple
class MCTSNodeVirtualLoss(object): class MCTSNodeVirtualLoss(object):
@ -53,12 +41,13 @@ class MCTSNodeVirtualLoss(object):
pass pass
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
def __init__(self, parent, action, state, action_num, prior, inverse=False): def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse) super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
self.Q = np.zeros([action_num]) self.Q = np.zeros([action_num])
self.W = np.zeros([action_num]) self.W = np.zeros([action_num])
self.N = np.zeros([action_num]) self.N = np.zeros([action_num])
self.virtual_loss = np.zeros([action_num]) self.virtual_loss = np.zeros([action_num])
self.c_puct = c_puct
#### modified by adding virtual loss #### modified by adding virtual loss
#self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1) #self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
@ -67,9 +56,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
def selection(self, simulator): def selection(self, simulator):
self.valid_mask(simulator) self.valid_mask(simulator)
self.Q = np.zeros([self.action_num]) self.Q = np.zeros([self.action_num])
N_not_zero = self.N > 0 N_not_zero = (self.N + self.virtual_loss) > 0
self.Q[N_not_zero] = (self.W[N_not_zero] + self.virtual_loss[N_not_zero] + 0.) / self.N[N_not_zero] self.Q[N_not_zero] = (self.W[N_not_zero] + 0.)/ (self.virtual_loss[N_not_zero] + self.N[N_not_zero])
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\ self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
(self.N + self.virtual_loss + 1) (self.N + self.virtual_loss + 1)
action = np.argmax(self.ucb) action = np.argmax(self.ucb)
self.virtual_loss[action] += 1 self.virtual_loss[action] += 1
@ -93,6 +82,7 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
self.W[action] += self.children[action].reward self.W[action] += self.children[action].reward
## do not need to compute Q and ucb immediately since it will be modified by virtual loss ## do not need to compute Q and ucb immediately since it will be modified by virtual loss
## just commented out and left for comparison
#for i in range(self.action_num): #for i in range(self.action_num):
# if self.N[i] != 0: # if self.N[i] != 0:
# self.Q[i] = (self.W[i] + 0.) / self.N[i] # self.Q[i] = (self.W[i] + 0.) / self.N[i]
@ -186,6 +176,12 @@ class MCTSVirtualLoss(object):
def do_search(self, max_step=None, max_time=None): def do_search(self, max_step=None, max_time=None):
"""
Expand the MCTS tree until a stopping criterion is met: either max_step or max_time.
:param max_step: maximum number of minibatch rounds to search. ONE step is ONE minibatch
:param max_time: maximum number of seconds to search
"""
if max_step is not None: if max_step is not None:
self.step = 0 self.step = 0
self.max_step = max_step self.max_step = max_step
@ -205,6 +201,9 @@ class MCTSVirtualLoss(object):
self.step += 1 self.step += 1
def expand(self): def expand(self):
"""
Core logic method for MCTS tree to expand nodes.
"""
## minibatch with virtual loss ## minibatch with virtual loss
nodes = [] nodes = []
new_actions = [] new_actions = []

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# vim:fenc=utf-8 # vim:fenc=utf-8
# $File: mcts_virtual_loss_test.py # $File: mcts_virtual_loss_test.py
# $Date: Tue Dec 19 16:5459 2017 +0800 # $Date: Sat Dec 23 02:2139 2017 +0800
# Original file: mcts_test.py # Original file: mcts_test.py
# $Author: renyong15 © <mails.tsinghua.edu.cn> # $Author: renyong15 © <mails.tsinghua.edu.cn>
# #
@ -9,8 +9,8 @@
import numpy as np import numpy as np
from mcts_virtual_loss import MCTSVirtualLoss from .mcts_virtual_loss import MCTSVirtualLoss
from evaluator import rollout_policy from .evaluator import rollout_policy
class TestEnv: class TestEnv:

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: utils.py
# $Date: Sat Dec 23 02:0854 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
def list2tuple(list):
    """Recursively convert a nested iterable into nested tuples.

    Every iterable level of the input (not only ``list`` instances) is
    replaced by a ``tuple`` of its converted elements. Anything that cannot
    be iterated is treated as a leaf and returned unchanged.

    :param list: the (possibly nested) structure to convert. The name
        shadows the builtin; it is kept so existing keyword callers still work.
    :return: a tuple mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    # Iterating a one-character string yields that same string again, so
    # recursing into strings never bottoms out and ends in a RecursionError
    # (which the TypeError handler below does not catch). Treat them as leaves.
    if isinstance(list, str):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        # Not iterable: this is a leaf value.
        return list
def tuple2list(tuple):
    """Recursively convert a nested iterable into nested lists.

    Every iterable level of the input (not only ``tuple`` instances) is
    replaced by a ``list`` of its converted elements. Anything that cannot
    be iterated is treated as a leaf and returned unchanged.

    :param tuple: the (possibly nested) structure to convert. The name
        shadows the builtin; it is kept so existing keyword callers still work.
    :return: a list mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    # Iterating a one-character string yields that same string again, so
    # recursing into strings never bottoms out and ends in a RecursionError
    # (which the TypeError handler below does not catch). Treat them as leaves.
    if isinstance(tuple, str):
        return tuple
    try:
        return list(tuple2list(sub) for sub in tuple)
    except TypeError:
        # Not iterable: this is a leaf value.
        return tuple