fix virtual loss bug
This commit is contained in:
parent
1f011a44ef
commit
3b534064bd
@ -1,22 +1,9 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import sys,os
|
||||
from .utils import list2tuple, tuple2list
|
||||
|
||||
c_puct = 5
|
||||
|
||||
|
||||
def list2tuple(list):
    """Recursively convert a nested iterable into nested tuples.

    :param list: the value to convert (the name shadows the builtin and is
        kept only for backward compatibility with existing callers)
    :return: a tuple whose elements were converted recursively, or the
        argument itself when it is a ``str`` or is not iterable
    """
    # A one-character string iterates to itself, so recursing into strings
    # never terminates; treat them as leaves.
    if isinstance(list, str):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        # Not iterable: a scalar leaf, returned unchanged.
        return list
|
||||
|
||||
|
||||
def tuple2list(tuple):
    """Recursively convert a nested iterable into nested lists.

    :param tuple: the value to convert (the name shadows the builtin and is
        kept only for backward compatibility with existing callers)
    :return: a list whose elements were converted recursively, or the
        argument itself when it is a ``str`` or is not iterable
    """
    # A one-character string iterates to itself, so recursing into strings
    # never terminates; treat them as leaves.
    if isinstance(tuple, str):
        return tuple
    try:
        return [tuple2list(sub) for sub in tuple]
    except TypeError:
        # Not iterable: a scalar leaf, returned unchanged.
        return tuple
|
||||
|
||||
|
||||
class MCTSNode(object):
|
||||
@ -39,12 +26,13 @@ class MCTSNode(object):
|
||||
pass
|
||||
|
||||
class UCTNode(MCTSNode):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
|
||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
self.Q = np.zeros([action_num])
|
||||
self.W = np.zeros([action_num])
|
||||
self.N = np.zeros([action_num])
|
||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||
self.c_puct = c_puct
|
||||
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||
self.mask = None
|
||||
|
||||
def selection(self, simulator):
|
||||
|
@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: mcts_virtual_loss.py
|
||||
# $Date: Tue Dec 19 17:0444 2017 +0800
|
||||
# $Date: Sat Dec 23 02:4850 2017 +0800
|
||||
# Original file: mcts.py
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
@ -12,25 +12,13 @@
|
||||
manner.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
|
||||
c_puct = 5
|
||||
|
||||
|
||||
def list2tuple(list):
    """Recursively convert a nested iterable into nested tuples.

    :param list: the value to convert (the name shadows the builtin and is
        kept only for backward compatibility with existing callers)
    :return: a tuple whose elements were converted recursively, or the
        argument itself when it is a ``str`` or is not iterable
    """
    # A one-character string iterates to itself, so recursing into strings
    # never terminates; treat them as leaves.
    if isinstance(list, str):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        # Not iterable: a scalar leaf, returned unchanged.
        return list
|
||||
|
||||
|
||||
def tuple2list(tuple):
    """Recursively convert a nested iterable into nested lists.

    :param tuple: the value to convert (the name shadows the builtin and is
        kept only for backward compatibility with existing callers)
    :return: a list whose elements were converted recursively, or the
        argument itself when it is a ``str`` or is not iterable
    """
    # A one-character string iterates to itself, so recursing into strings
    # never terminates; treat them as leaves.
    if isinstance(tuple, str):
        return tuple
    try:
        return [tuple2list(sub) for sub in tuple]
    except TypeError:
        # Not iterable: a scalar leaf, returned unchanged.
        return tuple
|
||||
import sys,os
|
||||
from .utils import list2tuple, tuple2list
|
||||
|
||||
|
||||
class MCTSNodeVirtualLoss(object):
|
||||
@ -53,12 +41,13 @@ class MCTSNodeVirtualLoss(object):
|
||||
pass
|
||||
|
||||
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
|
||||
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
self.Q = np.zeros([action_num])
|
||||
self.W = np.zeros([action_num])
|
||||
self.N = np.zeros([action_num])
|
||||
self.virtual_loss = np.zeros([action_num])
|
||||
self.c_puct = c_puct
|
||||
#### modified by adding virtual loss
|
||||
#self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||
|
||||
@ -67,9 +56,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
def selection(self, simulator):
|
||||
self.valid_mask(simulator)
|
||||
self.Q = np.zeros([self.action_num])
|
||||
N_not_zero = self.N > 0
|
||||
self.Q[N_not_zero] = (self.W[N_not_zero] + self.virtual_loss[N_not_zero] + 0.) / self.N[N_not_zero]
|
||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
|
||||
N_not_zero = (self.N + self.virtual_loss) > 0
|
||||
self.Q[N_not_zero] = (self.W[N_not_zero] + 0.)/ (self.virtual_loss[N_not_zero] + self.N[N_not_zero])
|
||||
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
|
||||
(self.N + self.virtual_loss + 1)
|
||||
action = np.argmax(self.ucb)
|
||||
self.virtual_loss[action] += 1
|
||||
@ -93,6 +82,7 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
self.W[action] += self.children[action].reward
|
||||
|
||||
## do not need to compute Q and ucb immediately since it will be modified by virtual loss
|
||||
## just commented out and left here for comparison
|
||||
#for i in range(self.action_num):
|
||||
# if self.N[i] != 0:
|
||||
# self.Q[i] = (self.W[i] + 0.) / self.N[i]
|
||||
@ -186,6 +176,12 @@ class MCTSVirtualLoss(object):
|
||||
|
||||
|
||||
def do_search(self, max_step=None, max_time=None):
|
||||
"""
|
||||
Expand the MCTS tree with a stop criterion given by either max_step or max_time
|
||||
|
||||
:param max_step: maximum number of minibatch search rounds. ONE step is ONE minibatch
|
||||
:param max_time: maximum search time in seconds
|
||||
"""
|
||||
if max_step is not None:
|
||||
self.step = 0
|
||||
self.max_step = max_step
|
||||
@ -205,6 +201,9 @@ class MCTSVirtualLoss(object):
|
||||
self.step += 1
|
||||
|
||||
def expand(self):
|
||||
"""
|
||||
Core logic method for MCTS tree to expand nodes.
|
||||
"""
|
||||
## minibatch with virtual loss
|
||||
nodes = []
|
||||
new_actions = []
|
||||
|
@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: mcts_virtual_loss_test.py
|
||||
# $Date: Tue Dec 19 16:5459 2017 +0800
|
||||
# $Date: Sat Dec 23 02:2139 2017 +0800
|
||||
# Original file: mcts_test.py
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
@ -9,8 +9,8 @@
|
||||
|
||||
|
||||
import numpy as np
|
||||
from mcts_virtual_loss import MCTSVirtualLoss
|
||||
from evaluator import rollout_policy
|
||||
from .mcts_virtual_loss import MCTSVirtualLoss
|
||||
from .evaluator import rollout_policy
|
||||
|
||||
|
||||
class TestEnv:
|
||||
|
21
tianshou/core/mcts/utils.py
Normal file
21
tianshou/core/mcts/utils.py
Normal file
@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: utils.py
|
||||
# $Date: Sat Dec 23 02:0854 2017 +0800
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
def list2tuple(list):
    """Recursively convert a nested iterable into nested tuples.

    :param list: the value to convert (the name shadows the builtin and is
        kept only for backward compatibility with existing callers)
    :return: a tuple whose elements were converted recursively, or the
        argument itself when it is a ``str`` or is not iterable
    """
    # A one-character string iterates to itself, so recursing into strings
    # never terminates; treat them as leaves.
    if isinstance(list, str):
        return list
    try:
        return tuple(list2tuple(sub) for sub in list)
    except TypeError:
        # Not iterable: a scalar leaf, returned unchanged.
        return list
|
||||
|
||||
|
||||
def tuple2list(tuple):
    """Recursively convert a nested iterable into nested lists.

    :param tuple: the value to convert (the name shadows the builtin and is
        kept only for backward compatibility with existing callers)
    :return: a list whose elements were converted recursively, or the
        argument itself when it is a ``str`` or is not iterable
    """
    # A one-character string iterates to itself, so recursing into strings
    # never terminates; treat them as leaves.
    if isinstance(tuple, str):
        return tuple
    try:
        return [tuple2list(sub) for sub in tuple]
    except TypeError:
        # Not iterable: a scalar leaf, returned unchanged.
        return tuple
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user