diff --git a/tianshou/core/mcts/mcts_virtual_loss.py b/tianshou/core/mcts/mcts_virtual_loss.py index 9335464..f27d8a3 100644 --- a/tianshou/core/mcts/mcts_virtual_loss.py +++ b/tianshou/core/mcts/mcts_virtual_loss.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # vim:fenc=utf-8 # $File: mcts_virtual_loss.py -# $Date: Sat Dec 23 02:4850 2017 +0800 +# $Date: Sun Dec 24 16:4740 2017 +0800 # Original file: mcts.py # $Author: renyong15 © # @@ -22,7 +22,17 @@ from .utils import list2tuple, tuple2list class MCTSNodeVirtualLoss(object): - def __init__(self, parent, action, state, action_num, prior, inverse=False): + """ + MCTS abstract class with virtual loss. Currently we only support UCT node. + Role of the Parameters can be found in Readme.md. + """ + def __init__(self, + parent, + action, + state, + action_num, + prior, + inverse = False): self.parent = parent self.action = action self.children = {} @@ -41,7 +51,19 @@ class MCTSNodeVirtualLoss(object): pass class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): - def __init__(self, parent, action, state, action_num, prior, inverse=False, c_puct = 5): + """ + UCT node (state node) with virtual loss. + Role of the Parameters can be found in Readme.md. + :param c_puct: balance between exploration and exploitation. + """ + def __init__(self, + parent, + action, + state, + action_num, + prior, + inverse=False, + c_puct = 5): super(UCTNodeVirtualLoss, self).__init__(parent, action, state, action_num, prior, inverse) self.Q = np.zeros([action_num]) self.W = np.zeros([action_num]) @@ -53,7 +75,8 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): self.mask = None - def selection(self, simulator): + def selection(self, + simulator): self.valid_mask(simulator) self.Q = np.zeros([self.action_num]) N_not_zero = (self.N + self.virtual_loss) > 0 @@ -108,6 +131,9 @@ class UCTNodeVirtualLoss(MCTSNodeVirtualLoss): class ActionNodeVirtualLoss(object): + """ + Action node with virtual loss. 
+ """ def __init__(self, parent, action): self.parent = parent self.action = action @@ -156,6 +182,9 @@ class ActionNodeVirtualLoss(object): class MCTSVirtualLoss(object): + """ + MCTS class with virtual loss + """ def __init__(self, simulator, evaluator, root, action_num, batch_size = 1, method = "UCT", inverse = False): self.simulator = simulator self.evaluator = evaluator @@ -196,13 +225,19 @@ class MCTSVirtualLoss(object): self.bp_time = [] while (max_step is not None and self.step < self.max_step or max_step is None) \ and (max_time is not None and time.time() - self.start_time < self.max_time or max_time is None): - self.expand() + self._expand() if max_step is not None: self.step += 1 - def expand(self): + def _expand(self): """ Core logic method for MCTS tree to expand nodes. + Steps to expand node: + 1. Select final action nodes with virtual loss and collect them into a minibatch. + (i.e. root->action->state->action...->action) + 2. Remove the virtual loss + 3. Evaluate the whole minibatch using evaluator + 4. Expand new nodes and perform back propagation. """ ## minibatch with virtual loss nodes = []