add comments for mcts with virtual loss
This commit is contained in:
parent
8c6f44a015
commit
5aa5dcd191
@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:fenc=utf-8
|
||||
# $File: mcts_virtual_loss.py
|
||||
# $Date: Sat Dec 23 02:4850 2017 +0800
|
||||
# $Date: Sun Dec 24 16:4740 2017 +0800
|
||||
# Original file: mcts.py
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
@ -22,7 +22,17 @@ from .utils import list2tuple, tuple2list
|
||||
|
||||
|
||||
class MCTSNodeVirtualLoss(object):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False):
|
||||
"""
|
||||
MCTS abstract class with virtual loss. Currently we only support UCT node.
|
||||
Role of the Parameters can be found in Readme.md.
|
||||
"""
|
||||
def __init__(self,
|
||||
parent,
|
||||
action,
|
||||
state,
|
||||
action_num,
|
||||
prior,
|
||||
inverse = False):
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
self.children = {}
|
||||
@ -41,7 +51,19 @@ class MCTSNodeVirtualLoss(object):
|
||||
pass
|
||||
|
||||
class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5):
|
||||
"""
|
||||
UCT node (state node) with virtual loss.
|
||||
Role of the Parameters can be found in Readme.md.
|
||||
:param c_puct: balance between exploration and exploitation.
|
||||
"""
|
||||
def __init__(self,
|
||||
parent,
|
||||
action,
|
||||
state,
|
||||
action_num,
|
||||
prior,
|
||||
inverse=False,
|
||||
c_puct = 5):
|
||||
super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse)
|
||||
self.Q = np.zeros([action_num])
|
||||
self.W = np.zeros([action_num])
|
||||
@ -53,7 +75,8 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
|
||||
self.mask = None
|
||||
|
||||
def selection(self, simulator):
|
||||
def selection(self,
|
||||
simulator):
|
||||
self.valid_mask(simulator)
|
||||
self.Q = np.zeros([self.action_num])
|
||||
N_not_zero = (self.N + self.virtual_loss) > 0
|
||||
@ -108,6 +131,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss):
|
||||
|
||||
|
||||
class ActionNodeVirtualLoss(object):
|
||||
"""
|
||||
Action node with virtual loss.
|
||||
"""
|
||||
def __init__(self, parent, action):
|
||||
self.parent = parent
|
||||
self.action = action
|
||||
@ -156,6 +182,9 @@ class ActionNodeVirtualLoss(object):
|
||||
|
||||
|
||||
class MCTSVirtualLoss(object):
|
||||
"""
|
||||
MCTS class with virtual loss
|
||||
"""
|
||||
def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False):
|
||||
self.simulator = simulator
|
||||
self.evaluator = evaluator
|
||||
@ -196,13 +225,19 @@ class MCTSVirtualLoss(object):
|
||||
self.bp_time = []
|
||||
while (max_step is not None and self.step < self.max_step or max_step is None) \
|
||||
and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None):
|
||||
self.expand()
|
||||
self._expand()
|
||||
if max_step is not None:
|
||||
self.step += 1
|
||||
|
||||
def expand(self):
|
||||
def _expand(self):
|
||||
"""
|
||||
Core logic method for MCTS tree to expand nodes.
|
||||
Steps to expand node:
|
||||
1. Select final action nodes with virtual loss and collect them into a minibatch.
|
||||
(i.e. root->action->state->action...->action)
|
||||
2. Remove the virtual loss
|
||||
3. Evaluate the whole minibatch using evaluator
|
||||
4. Expand new nodes and perform back propagation.
|
||||
"""
|
||||
## minibatch with virtual loss
|
||||
nodes = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user