fix virtual loss bug

This commit is contained in:
mcgrady00h 2017-12-23 02:48:53 +08:00
parent 1f011a44ef
commit 3b534064bd
4 changed files with 49 additions and 41 deletions

View File

@ -1,22 +1,9 @@
import numpy as np
import math
import time
import sys,os
from .utils import list2tuple, tuple2list
c_puct = 5
def list2tuple(list):
    """Recursively convert a nested iterable into nested tuples.

    Each iterable element is converted in turn; a value that cannot be
    iterated is returned unchanged.  Strings are treated as leaves, because
    iterating a one-character string yields that same string again and
    recursing into it would never terminate.

    :param list: the (possibly nested) iterable to convert, or a leaf value.
        The name shadows the builtin but is kept so keyword callers still work.
    :return: a tuple mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    if isinstance(list, str):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        # Not iterable: this is a leaf value.
        return list
def tuple2list(tuple):
    """Recursively convert a nested iterable into nested lists.

    Each iterable element is converted in turn; a value that cannot be
    iterated is returned unchanged.  Strings are treated as leaves, because
    iterating a one-character string yields that same string again and
    recursing into it would never terminate.

    :param tuple: the (possibly nested) iterable to convert, or a leaf value.
        The name shadows the builtin but is kept so keyword callers still work.
    :return: a list mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    if isinstance(tuple, str):
        return tuple
    try:
        return [tuple2list(sub) for sub in tuple]
    except TypeError:
        # Not iterable: this is a leaf value.
        return tuple
class MCTSNode(object):
@ -39,12 +26,13 @@ class MCTSNode(object):
pass
class UCTNode(MCTSNode):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
self.Q = np.zeros([action_num])
self.W = np.zeros([action_num])
self.N = np.zeros([action_num])
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
self.c_puct = c_puct
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
self.mask = None
def selection(self, simulator):

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: mcts_virtual_loss.py
# $Date: Tue Dec 19 17:0444 2017 +0800
# $Date: Sat Dec 23 02:4850 2017 +0800
# Original file: mcts.py
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
@ -12,25 +12,13 @@
manner.
"""
from __future__ import absolute_import
import numpy as np
import math
import time
c_puct = 5
def list2tuple(list):
    """Recursively convert a nested iterable into nested tuples.

    Each iterable element is converted in turn; a value that cannot be
    iterated is returned unchanged.  Strings are treated as leaves, because
    iterating a one-character string yields that same string again and
    recursing into it would never terminate.

    :param list: the (possibly nested) iterable to convert, or a leaf value.
        The name shadows the builtin but is kept so keyword callers still work.
    :return: a tuple mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    if isinstance(list, str):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        # Not iterable: this is a leaf value.
        return list
def tuple2list(tuple):
    """Recursively convert a nested iterable into nested lists.

    Each iterable element is converted in turn; a value that cannot be
    iterated is returned unchanged.  Strings are treated as leaves, because
    iterating a one-character string yields that same string again and
    recursing into it would never terminate.

    :param tuple: the (possibly nested) iterable to convert, or a leaf value.
        The name shadows the builtin but is kept so keyword callers still work.
    :return: a list mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    if isinstance(tuple, str):
        return tuple
    try:
        return [tuple2list(sub) for sub in tuple]
    except TypeError:
        # Not iterable: this is a leaf value.
        return tuple
import sys,os
from .utils import list2tuple, tuple2list
class MCTSNodeVirtualLoss(object):
@ -53,12 +41,13 @@ class MCTSNodeVirtualLoss(object):
pass
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
self.Q = np.zeros([action_num])
self.W = np.zeros([action_num])
self.N = np.zeros([action_num])
self.virtual_loss = np.zeros([action_num])
self.c_puct = c_puct
#### modified by adding virtual loss
#self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
@ -67,9 +56,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
def selection(self, simulator):
self.valid_mask(simulator)
self.Q = np.zeros([self.action_num])
N_not_zero = self.N > 0
self.Q[N_not_zero] = (self.W[N_not_zero] + self.virtual_loss[N_not_zero] + 0.) / self.N[N_not_zero]
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
N_not_zero = (self.N + self.virtual_loss) > 0
self.Q[N_not_zero] = (self.W[N_not_zero] + 0.)/ (self.virtual_loss[N_not_zero] + self.N[N_not_zero])
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
(self.N + self.virtual_loss + 1)
action = np.argmax(self.ucb)
self.virtual_loss[action] += 1
@ -93,6 +82,7 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
self.W[action] += self.children[action].reward
## Q and ucb do not need to be recomputed here, since selection recomputes them with virtual loss.
## The old computation is left commented out below for comparison.
#for i in range(self.action_num):
# if self.N[i] != 0:
# self.Q[i] = (self.W[i] + 0.) / self.N[i]
@ -186,6 +176,12 @@ class MCTSVirtualLoss(object):
def do_search(self, max_step=None, max_time=None):
"""
Expand the MCTS tree until a stopping criterion is met: either max_step or max_time.
:param max_step: maximum number of search rounds. ONE step is ONE minibatch
:param max_time: maximum search time in seconds
"""
if max_step is not None:
self.step = 0
self.max_step = max_step
@ -205,6 +201,9 @@ class MCTSVirtualLoss(object):
self.step += 1
def expand(self):
"""
Core logic method for MCTS tree to expand nodes.
"""
## minibatch with virtual loss
nodes = []
new_actions = []

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: mcts_virtual_loss_test.py
# $Date: Tue Dec 19 16:5459 2017 +0800
# $Date: Sat Dec 23 02:2139 2017 +0800
# Original file: mcts_test.py
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
@ -9,8 +9,8 @@
import numpy as np
from mcts_virtual_loss import MCTSVirtualLoss
from evaluator import rollout_policy
from .mcts_virtual_loss import MCTSVirtualLoss
from .evaluator import rollout_policy
class TestEnv:

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: utils.py
# $Date: Sat Dec 23 02:0854 2017 +0800
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
def list2tuple(list):
    """Recursively convert a nested iterable into nested tuples.

    Each iterable element is converted in turn; a value that cannot be
    iterated is returned unchanged.  Strings are treated as leaves, because
    iterating a one-character string yields that same string again and
    recursing into it would never terminate.

    :param list: the (possibly nested) iterable to convert, or a leaf value.
        The name shadows the builtin but is kept so keyword callers still work.
    :return: a tuple mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    if isinstance(list, str):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        # Not iterable: this is a leaf value.
        return list
def tuple2list(tuple):
    """Recursively convert a nested iterable into nested lists.

    Each iterable element is converted in turn; a value that cannot be
    iterated is returned unchanged.  Strings are treated as leaves, because
    iterating a one-character string yields that same string again and
    recursing into it would never terminate.

    :param tuple: the (possibly nested) iterable to convert, or a leaf value.
        The name shadows the builtin but is kept so keyword callers still work.
    :return: a list mirroring the nesting of the input, or the input itself
        when it is a leaf.
    """
    if isinstance(tuple, str):
        return tuple
    try:
        return [tuple2list(sub) for sub in tuple]
    except TypeError:
        # Not iterable: this is a leaf value.
        return tuple