// Copyright (c) EVAR Lab, IIIS, Tsinghua University. // // This source code is licensed under the GNU License, Version 3.0 // found in the LICENSE file in the root directory of this source tree. #ifndef CNODE_H #define CNODE_H #include "cminimax.h" #include #include #include #include #include #include #include #include const int DEBUG_MODE = 0; namespace tree { class CNode { public: int visit_count, to_play, action_num, hidden_state_index_x, hidden_state_index_y, best_action, is_reset; float value_prefix, prior, value_sum; std::vector children_index; std::vector* ptr_node_pool; CNode(); CNode(float prior, int action_num, std::vector *ptr_node_pool); ~CNode(); void expand(int to_play, int hidden_state_index_x, int hidden_state_index_y, float value_prefix, const std::vector &policy_logits); void add_exploration_noise(float exploration_fraction, const std::vector &noises); float get_mean_q(int isRoot, float parent_q, float discount); void print_out(); int expanded(); float value(); std::vector get_trajectory(); std::vector get_children_distribution(); CNode* get_child(int action); }; class CRoots{ public: int root_num, action_num, pool_size; std::vector roots; std::vector> node_pools; CRoots(); CRoots(int root_num, int action_num, int pool_size); ~CRoots(); void prepare(float root_exploration_fraction, const std::vector> &noises, const std::vector &value_prefixs, const std::vector> &policies); void prepare_no_noise(const std::vector &value_prefixs, const std::vector> &policies); void clear(); std::vector> get_trajectories(); std::vector> get_distributions(); std::vector get_values(); }; class CSearchResults{ public: int num; std::vector hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens; std::vector nodes; std::vector> search_paths; CSearchResults(); CSearchResults(int num); ~CSearchResults(); }; //********************************************************* void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount); void cback_propagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount); void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector &value_prefixs, const std::vector &values, const std::vector> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_lst); int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount, float mean_q); float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount); void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results); } #endif