fix virtual loss bug
This commit is contained in:
parent
1f011a44ef
commit
3b534064bd
@ -1,22 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
import sys,os
|
||||||
|
from .utils import list2tuple, tuple2list
|
||||||
|
|
||||||
c_puct = 5
|
|
||||||
|
|
||||||
|
|
||||||
def list2tuple(list):
|
|
||||||
try:
|
|
||||||
return tuple(list2tuple(sub) for sub in list)
|
|
||||||
except TypeError:
|
|
||||||
return list
|
|
||||||
|
|
||||||
|
|
||||||
def tuple2list(tuple):
|
|
||||||
try:
|
|
||||||
return list(tuple2list(sub) for sub in tuple)
|
|
||||||
except TypeError:
|
|
||||||
return tuple
|
|
||||||
|
|
||||||
|
|
||||||
class MCTSNode(object):
|
class MCTSNode(object):
|
||||||
@ -39,12 +26,13 @@ class MCTSNode(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class UCTNode(MCTSNode):
|
class UCTNode(MCTSNode):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
|
||||||
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNode, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
self.Q = np.zeros([action_num])
|
self.Q = np.zeros([action_num])
|
||||||
self.W = np.zeros([action_num])
|
self.W = np.zeros([action_num])
|
||||||
self.N = np.zeros([action_num])
|
self.N = np.zeros([action_num])
|
||||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
self.c_puct = c_puct
|
||||||
|
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||||
self.mask = None
|
self.mask = None
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# vim:fenc=utf-8
|
# vim:fenc=utf-8
|
||||||
# $File: mcts_virtual_loss.py
|
# $File: mcts_virtual_loss.py
|
||||||
# $Date: Tue Dec 19 17:0444 2017 +0800
|
# $Date: Sat Dec 23 02:4850 2017 +0800
|
||||||
# Original file: mcts.py
|
# Original file: mcts.py
|
||||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||||
#
|
#
|
||||||
@ -12,25 +12,13 @@
|
|||||||
manner.
|
manner.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
import sys,os
|
||||||
c_puct = 5
|
from .utils import list2tuple, tuple2list
|
||||||
|
|
||||||
|
|
||||||
def list2tuple(list):
|
|
||||||
try:
|
|
||||||
return tuple(list2tuple(sub) for sub in list)
|
|
||||||
except TypeError:
|
|
||||||
return list
|
|
||||||
|
|
||||||
|
|
||||||
def tuple2list(tuple):
|
|
||||||
try:
|
|
||||||
return list(tuple2list(sub) for sub in tuple)
|
|
||||||
except TypeError:
|
|
||||||
return tuple
|
|
||||||
|
|
||||||
|
|
||||||
class MCTSNodeVirtualLoss(object):
|
class MCTSNodeVirtualLoss(object):
|
||||||
@ -53,12 +41,13 @@ class MCTSNodeVirtualLoss(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
|
||||||
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
self.Q = np.zeros([action_num])
|
self.Q = np.zeros([action_num])
|
||||||
self.W = np.zeros([action_num])
|
self.W = np.zeros([action_num])
|
||||||
self.N = np.zeros([action_num])
|
self.N = np.zeros([action_num])
|
||||||
self.virtual_loss = np.zeros([action_num])
|
self.virtual_loss = np.zeros([action_num])
|
||||||
|
self.c_puct = c_puct
|
||||||
#### modified by adding virtual loss
|
#### modified by adding virtual loss
|
||||||
#self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
#self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N)) / (self.N + 1)
|
||||||
|
|
||||||
@ -67,9 +56,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
|||||||
def selection(self, simulator):
|
def selection(self, simulator):
|
||||||
self.valid_mask(simulator)
|
self.valid_mask(simulator)
|
||||||
self.Q = np.zeros([self.action_num])
|
self.Q = np.zeros([self.action_num])
|
||||||
N_not_zero = self.N > 0
|
N_not_zero = (self.N + self.virtual_loss) > 0
|
||||||
self.Q[N_not_zero] = (self.W[N_not_zero] + self.virtual_loss[N_not_zero] + 0.) / self.N[N_not_zero]
|
self.Q[N_not_zero] = (self.W[N_not_zero] + 0.)/ (self.virtual_loss[N_not_zero] + self.N[N_not_zero])
|
||||||
self.ucb = self.Q + c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
|
self.ucb = self.Q + self.c_puct * self.prior * math.sqrt(np.sum(self.N + self.virtual_loss)) /\
|
||||||
(self.N + self.virtual_loss + 1)
|
(self.N + self.virtual_loss + 1)
|
||||||
action = np.argmax(self.ucb)
|
action = np.argmax(self.ucb)
|
||||||
self.virtual_loss[action] += 1
|
self.virtual_loss[action] += 1
|
||||||
@ -93,6 +82,7 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
|||||||
self.W[action] += self.children[action].reward
|
self.W[action] += self.children[action].reward
|
||||||
|
|
||||||
## do not need to compute Q and ucb immediately since it will be modified by virtual loss
|
## do not need to compute Q and ucb immediately since it will be modified by virtual loss
|
||||||
|
## just comment out and leaving for comparision
|
||||||
#for i in range(self.action_num):
|
#for i in range(self.action_num):
|
||||||
# if self.N[i] != 0:
|
# if self.N[i] != 0:
|
||||||
# self.Q[i] = (self.W[i] + 0.) / self.N[i]
|
# self.Q[i] = (self.W[i] + 0.) / self.N[i]
|
||||||
@ -186,6 +176,12 @@ class MCTSVirtualLoss(object):
|
|||||||
|
|
||||||
|
|
||||||
def do_search(self, max_step=None, max_time=None):
|
def do_search(self, max_step=None, max_time=None):
|
||||||
|
"""
|
||||||
|
Expand the MCTS tree with stop crierion either by max_step or max_time
|
||||||
|
|
||||||
|
:param max_step search maximum minibath rounds. ONE step is ONE minibatch
|
||||||
|
:param max_time search maximum seconds
|
||||||
|
"""
|
||||||
if max_step is not None:
|
if max_step is not None:
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.max_step = max_step
|
self.max_step = max_step
|
||||||
@ -205,6 +201,9 @@ class MCTSVirtualLoss(object):
|
|||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
|
"""
|
||||||
|
Core logic method for MCTS tree to expand nodes.
|
||||||
|
"""
|
||||||
## minibatch with virtual loss
|
## minibatch with virtual loss
|
||||||
nodes = []
|
nodes = []
|
||||||
new_actions = []
|
new_actions = []
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# vim:fenc=utf-8
|
# vim:fenc=utf-8
|
||||||
# $File: mcts_virtual_loss_test.py
|
# $File: mcts_virtual_loss_test.py
|
||||||
# $Date: Tue Dec 19 16:5459 2017 +0800
|
# $Date: Sat Dec 23 02:2139 2017 +0800
|
||||||
# Original file: mcts_test.py
|
# Original file: mcts_test.py
|
||||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||||
#
|
#
|
||||||
@ -9,8 +9,8 @@
|
|||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mcts_virtual_loss import MCTSVirtualLoss
|
from .mcts_virtual_loss import MCTSVirtualLoss
|
||||||
from evaluator import rollout_policy
|
from .evaluator import rollout_policy
|
||||||
|
|
||||||
|
|
||||||
class TestEnv:
|
class TestEnv:
|
||||||
|
21
tianshou/core/mcts/utils.py
Normal file
21
tianshou/core/mcts/utils.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# vim:fenc=utf-8
|
||||||
|
# $File: utils.py
|
||||||
|
# $Date: Sat Dec 23 02:0854 2017 +0800
|
||||||
|
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||||
|
#
|
||||||
|
|
||||||
|
def list2tuple(list):
|
||||||
|
try:
|
||||||
|
return tuple(list2tuple(sub) for sub in list)
|
||||||
|
except TypeError:
|
||||||
|
return list
|
||||||
|
|
||||||
|
|
||||||
|
def tuple2list(tuple):
|
||||||
|
try:
|
||||||
|
return list(tuple2list(sub) for sub in tuple)
|
||||||
|
except TypeError:
|
||||||
|
return tuple
|
||||||
|
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user