add comments for mcts with virtual loss
This commit is contained in:
parent
8c6f44a015
commit
5aa5dcd191
@ -1,7 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# vim:fenc=utf-8
|
# vim:fenc=utf-8
|
||||||
# $File: mcts_virtual_loss.py
|
# $File: mcts_virtual_loss.py
|
||||||
# $Date: Sat Dec 23 02:48:50 2017 +0800
|
# $Date: Sun Dec 24 16:47:40 2017 +0800
|
||||||
# Original file: mcts.py
|
# Original file: mcts.py
|
||||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||||
#
|
#
|
||||||
@ -22,7 +22,17 @@ from .utils import list2tuple, tuple2list
|
|||||||
|
|
||||||
|
|
||||||
class MCTSNodeVirtualLoss(object):
|
class MCTSNodeVirtualLoss(object):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
"""
|
||||||
|
Abstract MCTS node class with virtual loss. Currently only the UCT node is supported.
|
||||||
|
The role of each parameter is described in Readme.md.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
parent,
|
||||||
|
action,
|
||||||
|
state,
|
||||||
|
action_num,
|
||||||
|
prior,
|
||||||
|
inverse = False):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.action = action
|
self.action = action
|
||||||
self.children = {}
|
self.children = {}
|
||||||
@ -41,7 +51,19 @@ class MCTSNodeVirtualLoss(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||||
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
|
"""
|
||||||
|
UCT node (state node) with virtual loss.
|
||||||
|
The role of each parameter is described in Readme.md.
|
||||||
|
:param c_puct: balance between exploration and exploitation.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
parent,
|
||||||
|
action,
|
||||||
|
state,
|
||||||
|
action_num,
|
||||||
|
prior,
|
||||||
|
inverse=False,
|
||||||
|
c_puct = 5):
|
||||||
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||||
self.Q = np.zeros([action_num])
|
self.Q = np.zeros([action_num])
|
||||||
self.W = np.zeros([action_num])
|
self.W = np.zeros([action_num])
|
||||||
@ -53,7 +75,8 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
|||||||
|
|
||||||
self.mask = None
|
self.mask = None
|
||||||
|
|
||||||
def selection(self, simulator):
|
def selection(self,
|
||||||
|
simulator):
|
||||||
self.valid_mask(simulator)
|
self.valid_mask(simulator)
|
||||||
self.Q = np.zeros([self.action_num])
|
self.Q = np.zeros([self.action_num])
|
||||||
N_not_zero = (self.N + self.virtual_loss) > 0
|
N_not_zero = (self.N + self.virtual_loss) > 0
|
||||||
@ -108,6 +131,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
|||||||
|
|
||||||
|
|
||||||
class ActionNodeVirtualLoss(object):
|
class ActionNodeVirtualLoss(object):
|
||||||
|
"""
|
||||||
|
Action node with virtual loss.
|
||||||
|
"""
|
||||||
def __init__(self, parent, action):
|
def __init__(self, parent, action):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.action = action
|
self.action = action
|
||||||
@ -156,6 +182,9 @@ class ActionNodeVirtualLoss(object):
|
|||||||
|
|
||||||
|
|
||||||
class MCTSVirtualLoss(object):
|
class MCTSVirtualLoss(object):
|
||||||
|
"""
|
||||||
|
MCTS class with virtual loss
|
||||||
|
"""
|
||||||
def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False):
|
def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False):
|
||||||
self.simulator = simulator
|
self.simulator = simulator
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
@ -196,13 +225,19 @@ class MCTSVirtualLoss(object):
|
|||||||
self.bp_time = []
|
self.bp_time = []
|
||||||
while (max_step is not None and self.step < self.max_step or max_step is None) \
|
while (max_step is not None and self.step < self.max_step or max_step is None) \
|
||||||
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
|
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
|
||||||
self.expand()
|
self._expand()
|
||||||
if max_step is not None:
|
if max_step is not None:
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
def expand(self):
|
def _expand(self):
|
||||||
"""
|
"""
|
||||||
Core logic method for MCTS tree to expand nodes.
|
Core logic method for MCTS tree to expand nodes.
|
||||||
|
Steps to expand node:
|
||||||
|
1. Select final action nodes with virtual loss and collect them into a minibatch.
|
||||||
|
(i.e. root->action->state->action...->action)
|
||||||
|
2. Remove the virtual loss
|
||||||
|
3. Evaluate the whole minibatch using evaluator
|
||||||
|
4. Expand new nodes and perform back propagation.
|
||||||
"""
|
"""
|
||||||
## minibatch with virtual loss
|
## minibatch with virtual loss
|
||||||
nodes = []
|
nodes = []
|
||||||
|
Loading…
x
Reference in New Issue
Block a user