add comments for mcts with virtual loss

This commit is contained in:
mcgrady00h 2017-12-24 16:47:43 +08:00
parent 8c6f44a015
commit 5aa5dcd191

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
# $File: mcts_virtual_loss.py
# $Date: Sat Dec 23 02:4850 2017 +0800
# $Date: Sun Dec 24 16:4740 2017 +0800
# Original file: mcts.py
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
@ -22,7 +22,17 @@ from .utils import list2tuple, tuple2list
class MCTSNodeVirtualLoss(object):
def __init__(self, parent, action, state, action_num, prior, inverse=False):
"""
MCTS abstract class with virtual loss. Currently we only support UCT node.
The role of each parameter is described in Readme.md.
"""
def __init__(self,
parent,
action,
state,
action_num,
prior,
inverse = False):
self.parent = parent
self.action = action
self.children = {}
@ -41,7 +51,19 @@ class MCTSNodeVirtualLoss(object):
pass
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
"""
UCT node (state node) with virtual loss.
The role of each parameter is described in Readme.md.
:param c_puct: balance between exploration and exploitation.
"""
def __init__(self,
parent,
action,
state,
action_num,
prior,
inverse=False,
c_puct = 5):
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
self.Q = np.zeros([action_num])
self.W = np.zeros([action_num])
@ -53,7 +75,8 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
self.mask = None
def selection(self, simulator):
def selection(self,
simulator):
self.valid_mask(simulator)
self.Q = np.zeros([self.action_num])
N_not_zero = (self.N + self.virtual_loss) > 0
@ -108,6 +131,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
class ActionNodeVirtualLoss(object):
"""
Action node with virtual loss.
"""
def __init__(self, parent, action):
self.parent = parent
self.action = action
@ -156,6 +182,9 @@ class ActionNodeVirtualLoss(object):
class MCTSVirtualLoss(object):
"""
MCTS class with virtual loss
"""
def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False):
self.simulator = simulator
self.evaluator = evaluator
@ -196,13 +225,19 @@ class MCTSVirtualLoss(object):
self.bp_time = []
while (max_step is not None and self.step < self.max_step or max_step is None) \
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
self.expand()
self._expand()
if max_step is not None:
self.step += 1
def expand(self):
def _expand(self):
"""
Core logic method for MCTS tree to expand nodes.
Steps to expand node:
1. Select final action nodes with virtual loss and collect them into a minibatch.
(i.e. root->action->state->action...->action)
2. Remove the virtual loss
3. Evaluate the whole minibatch using evaluator
4. Expand new nodes and perform backpropagation.
"""
## minibatch with virtual loss
nodes = []