“Shengjiewang-Jason” 1367bca203 first commit
2024-06-07 16:02:01 +08:00

75 lines
2.9 KiB
Cython

# distutils: language=c++
from libcpp.vector cimport vector
# Empty extern block: Cython emits `#include "cminimax.cpp"` into the
# generated C++ file, so the implementation is pulled into this module's
# translation unit. Nothing is declared to Cython here.
cdef extern from "cminimax.cpp":
    pass
# Declarations from cminimax.h (C++ namespace `tools`) exposed to Cython.
# Only the members listed here are visible to Cython code; their behavior is
# defined in the C++ sources.
cdef extern from "cminimax.h" namespace "tools":
    # NOTE(review): presumably tracks a running min/max used to normalize
    # values -- confirm against cminimax.h.
    cdef cppclass CMinMaxStats:
        # `except +` makes Cython translate C++ exceptions raised by the
        # constructor into Python exceptions.
        CMinMaxStats() except +

        int c_visit
        float c_scale
        float maximum, minimum, value_delta_max

        void set_static_val(float value_delta_max, int c_visit, float c_scale)
        void update(float value)
        void clear()
        float normalize(float value)

    # Container holding `num` CMinMaxStats entries in `stats_lst`.
    cdef cppclass CMinMaxStatsList:
        CMinMaxStatsList() except +
        CMinMaxStatsList(int num) except +

        int num
        vector[CMinMaxStats] stats_lst

        void set_static_val(float value_delta_max, int c_visit, float c_scale)
# Empty extern block: Cython emits `#include "cnode.cpp"` into the generated
# C++ file, so the implementation is pulled into this module's translation
# unit. Nothing is declared to Cython here.
cdef extern from "cnode.cpp":
    pass
# Declarations from cnode.h (C++ namespace `tree`) exposed to Cython.
# Only the members listed here are visible to Cython code; their behavior is
# defined in the C++ sources.
cdef extern from "cnode.h" namespace "tree":
    # A single tree node. Children are referenced by integer index
    # (`children_idx`) and the node carries a pointer to a vector[CNode] pool.
    # NOTE(review): presumably the indices refer into `ptr_node_pool` --
    # confirm against cnode.h.
    cdef cppclass CNode:
        # `except +` makes Cython translate C++ exceptions raised by the
        # constructors into Python exceptions.
        CNode() except +
        CNode(
            float prior,
            int action,
            CNode* parent,
            vector[CNode]* ptr_node_pool,
            float discount,
            int num_actions,
        ) except +

        int num_actions, action, best_action, reset_value_prefix, depth, visit_count
        int hidden_state_index_x, hidden_state_index_y
        float value_prefix, prior, discount
        CNode* parent
        vector[int] children_idx, selected_children_idx
        vector[float] estimated_value_lst
        vector[CNode]* ptr_node_pool

    # A batch of `num_roots` root nodes together with one node pool per entry
    # of `node_pools`.
    cdef cppclass CRoots:
        CRoots() except +
        CRoots(int num_roots, int num_actions, int pool_size, float discount) except +

        int num_roots, num_actions, pool_size
        float discount
        vector[CNode] roots
        vector[vector[CNode]] node_pools

        void prepare(
            vector[float] &values,
            const vector[vector[float]] &policies,
            int leaf_action_num,
        )
        void clear()
        vector[vector[int]] get_trajectories()
        vector[vector[int]] get_distributions()
        vector[vector[float]] get_root_policies(CMinMaxStatsList *min_max_stats_lst)
        vector[int] get_best_actions()
        vector[float] get_values()
        void print_tree()

    # Per-batch output buffers passed by reference to c_batch_traverse and
    # c_batch_back_propagate below.
    cdef cppclass CSearchResults:
        CSearchResults() except +
        CSearchResults(int num) except +

        int num
        vector[int] hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens
        vector[CNode*] nodes

    # Free functions of namespace `tree`.
    # NOTE(review): placed at namespace scope (not inside CSearchResults)
    # because they receive the roots/results objects as explicit arguments --
    # confirm against cnode.h.
    # NOTE(review): parameter is spelled `gumble_noises` here but
    # `gumbel_noises` in c_batch_traverse; spelling kept unchanged.
    vector[int] c_batch_sequential_halving(
        CRoots *roots,
        vector[vector[float]] gumble_noises,
        CMinMaxStatsList *min_max_stats_lst,
        int current_phase,
        int current_num_top_actions,
    )
    void c_batch_traverse(
        CRoots *roots,
        CMinMaxStatsList *min_max_stats_lst,
        CSearchResults &results,
        int num_simulations,
        int simulation_idx,
        const vector[vector[float]] &gumbel_noises,
        int current_num_top_actions,
    )
    void c_batch_back_propagate(
        int hidden_state_index_x,
        vector[float] value_prefixs,
        vector[float] values,
        vector[vector[float]] policies,
        CMinMaxStatsList *min_max_stats_lst,
        CSearchResults &results,
        vector[int] to_reset_lst,
        int leaf_action_num,
    )