“Shengjiewang-Jason” 1367bca203 first commit
2024-06-07 16:02:01 +08:00

150 lines
7.3 KiB
Cython

# distutils: language=c++
from libcpp.vector cimport vector
# Include the C++ implementation file in the generated translation unit.
# Nothing is declared from it here; the declarations come from the header.
cdef extern from "cminimax.cpp":
    pass
# Cython-visible declarations for the C++ namespace `tools` in cminimax.h.
# These must mirror the header exactly; `except +` lets Cython translate a
# C++ exception thrown by a constructor into a Python exception.
cdef extern from "cminimax.h" namespace "tools":
    # Single min/max tracker.
    # NOTE(review): presumably used to normalize values into a bounded
    # range during search -- confirm against cminimax.h.
    cdef cppclass CMinMaxStats:
        CMinMaxStats() except +

        float maximum, minimum, value_delta_max

        void set_delta(float value_delta_max)
        void update(float value)
        void clear()
        float normalize(float value)

    # Container holding `num` CMinMaxStats entries in `stats_lst`.
    cdef cppclass CMinMaxStatsList:
        CMinMaxStatsList() except +
        CMinMaxStatsList(int num) except +

        int num
        vector[CMinMaxStats] stats_lst

        void set_delta(float value_delta_max)
        # Layout of the returned flat vector is defined by the C++ side
        # (TODO confirm ordering of min/max entries in cminimax.cpp).
        vector[float] get_min_max()
# cdef extern from "cnode.cpp":
# pass
#
#
# cdef extern from "cnode.h" namespace "tree":
# cdef cppclass CNode:
# CNode() except +
# CNode(float prior, int action_num, vector[CNode]* ptr_node_pool) except +
# int visit_count, to_play, action_num, hidden_state_index_x, hidden_state_index_y, best_action
# float reward_sums, prior, value_sum
# vector[int] children_index;
# vector[CNode]* ptr_node_pool;
#
# void expand(int to_play, int hidden_state_index_x, int hidden_state_index_y, float reward_sums, vector[float] policy_logits)
# void add_exploration_noise(float exploration_fraction, vector[float] noises)
# float get_mean_q(int isRoot, float parent_q, float discount)
#
# int expanded()
# float value()
# vector[int] get_trajectory()
# vector[int] get_children_distribution()
# CNode* get_child(int action)
#
# cdef cppclass CRoots:
# CRoots() except +
# CRoots(int root_num, int action_num, int pool_size) except +
# int root_num, action_num, pool_size
# vector[CNode] roots
# vector[vector[CNode]] node_pools
#
# void prepare(float root_exploration_fraction, const vector[vector[float]] &noises, const vector[float] &reward_sums, const vector[vector[float]] &policies)
# void prepare_no_noise(const vector[float] &reward_sums, const vector[vector[float]] &policies)
# void clear()
# vector[vector[int]] get_trajectories()
# vector[vector[int]] get_distributions()
# vector[float] get_values()
#
# cdef cppclass CSearchResults:
# CSearchResults() except +
# CSearchResults(int num) except +
# int num
# vector[int] hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens
# vector[CNode*] nodes
# # vector[vector[CNode*]] search_paths
#
# cdef void cback_propagate(vector[CNode*] &search_path, CMinMaxStats &min_max_stats, int to_play, float value, float discount)
# void cmulti_back_propagate(int hidden_state_index_x, float discount, vector[float] reward_sums, vector[float] values, vector[vector[float]] policies,
# CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] is_reset_lst, vector[float] similarities)
# # int cselect_child(CNode &root, CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init)
# # float cucb_score(CNode &parent, CNode &child, CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init)
# void cmulti_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, int use_mcgs)
# Include the C++ implementation file in the generated translation unit.
# Nothing is declared from it here; the declarations come from the header.
cdef extern from "gumbel_cnode.cpp":
    pass
# Cython-visible declarations for the C++ namespace `tree` in gumbel_cnode.h.
# These must mirror the header exactly; `except +` lets Cython translate a
# C++ exception thrown by a constructor into a Python exception.
cdef extern from "gumbel_cnode.h" namespace "tree":
    # One search-tree node. Children are referenced by index
    # (`children_index`) and nodes live in the pool pointed to by
    # `ptr_node_pool`.
    cdef cppclass CNode:
        CNode() except +
        CNode(float prior, int action_num, vector[CNode]* ptr_node_pool) except +

        int visit_count, to_play, action_num, hidden_state_index_x, hidden_state_index_y, best_action
        # NOTE(review): the phase_* / m / simulation_num fields look like
        # simulation-budget bookkeeping -- confirm against gumbel_cnode.h.
        int phase_added_flag, current_phase, phase_num, phase_to_visit_num, m, simulation_num
        int is_root
        float reward_sums, prior, value_sum, value_mix
        vector[int] children_index
        vector[CNode]* ptr_node_pool
        CNode* parent

        void expand(int to_play, int hidden_state_index_x, int hidden_state_index_y, float reward_sums, vector[float] policy_logits, int simulation_num)
        # void expand_q_init(int to_play, int hidden_state_index_x, int hidden_state_index_y, float reward_sums, vector[float] policy_logits, vector[float] q_inits)
        int expanded()
        # Takes the parent node by value, as declared here.
        float value(CNode parent)
        vector[int] get_trajectory()
        CNode* get_child(int action)

    # A batch of `root_num` root nodes, each with its own node pool
    # (`roots` and `node_pools` are parallel vectors by declaration).
    cdef cppclass CRoots:
        CRoots() except +
        CRoots(int root_num, int action_num, int pool_size) except +

        int root_num, action_num, pool_size
        vector[CNode] roots
        vector[vector[CNode]] node_pools

        void prepare(const vector[float] &reward_sums, const vector[vector[float]] &policies, int m, int simulation_num, const vector[float] &values)
        # void prepare_q_init(const vector[float] &reward_sums, const vector[vector[float]] &policies, int m, int simulation_num, const vector[float] &values, const vector[vector[float]] &q_inits)
        void clear()
        vector[vector[int]] get_trajectories()
        vector[vector[float]] get_advantages(float discount)
        vector[vector[float]] get_pi_primes(CMinMaxStatsList *min_max_stats_lst, float c_visit, float c_scale, float discount)
        vector[float] get_values()
        vector[vector[float]] get_child_values(float discount)
        vector[vector[float]] get_priors()
        vector[int] get_actions(CMinMaxStatsList *min_max_stats_lst, float c_visit, float c_scale, const vector[vector[float]] gumbels, float discount)

    # Per-batch output buffers passed by reference to the traverse and
    # back-propagate functions below.
    cdef cppclass CSearchResults:
        CSearchResults() except +
        CSearchResults(int num) except +

        int num
        vector[int] hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens
        vector[CNode*] nodes
        # vector[vector[CNode*]] search_paths
        vector[vector[int]] search_path_index_x_lst, search_path_index_y_lst, search_path_actions

    # Free functions of namespace `tree`. No `cdef` prefix is needed inside
    # a `cdef extern` block, so all of them are declared the same way.
    void cback_propagate(vector[CNode*] &search_path, CMinMaxStats &min_max_stats, int to_play, float value, float discount)
    void cmulti_back_propagate(int hidden_state_index_x, float discount, vector[float] reward_sums, vector[float] values, vector[vector[float]] policies,
                               CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] is_reset_lst, int simulation_idx, vector[vector[float]] gumbels, float c_visit, float c_scale, int simulation_num)
    void cmulti_traverse(CRoots *roots, float c_visit, float c_scale, float discount, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, int simulation_idx, vector[vector[float]] gumbels)
    void cmulti_traverse_return_path(CRoots *roots, float c_visit, float c_scale, float discount, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, int simulation_idx, vector[vector[float]] gumbels)
# cdef extern from "cresults.cpp":
# pass
#
#
# cdef extern from "cresults.h" namespace "search":
# cdef cppclass CSearchResults:
# CSearchResults() except +
# vector[int] hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions
# vector[CNode] nodes
# vector[vector[CNode]] search_paths
#
# void cmulti_traverse(vector[CNode] roots, vector[CMinMaxStats] min_max_stats_lst, int num, vector[int] histories_len, vector[vector[int]] action_histories, int pb_c_base, float pb_c_init, CSearchResults &results)