// Copyright (c) EVAR Lab, IIIS, Tsinghua University. // // This source code is licensed under the GNU License, Version 3.0 // found in the LICENSE file in the root directory of this source tree. #include #include "cnode.h" namespace tree{ CSearchResults::CSearchResults(){ this->num = 0; } CSearchResults::CSearchResults(int num){ this->num = num; for(int i = 0; i < num; ++i){ this->search_paths.push_back(std::vector()); } } CSearchResults::~CSearchResults(){} //********************************************************* std::vector softmax(std::vector logits){ std::vector policy(logits.size(), 0.0); float max_logit = max_float(logits); // printf("size=%d\n", logits.size()); for(long unsigned int a = 0; a < logits.size(); ++a){ policy[a] = exp(logits[a] - max_logit); } float policy_sum = sum(policy); for(long unsigned int a = 0; a < logits.size(); ++a){ policy[a] = policy[a] / policy_sum; } return policy; } CNode::CNode(){ this->num_actions = -1; this->action = -1; this->best_action = -1; this->reset_value_prefix = 0; this->depth = 0; this->visit_count = 0; this->hidden_state_index_x = 0; this->hidden_state_index_y = 0; this->value_prefix = 0.0; this->prior = 0.0; this->discount = 0.0; this->parent = nullptr; this->ptr_node_pool = nullptr; this->children_idx = std::vector(); this->selected_children_idx = std::vector(); this->estimated_value_lst = std::vector(); } CNode::CNode(float prior, int action, CNode* parent, std::vector *ptr_node_pool, float discount, int num_actions){ this->num_actions = num_actions; this->action = action; this->best_action = -1; this->reset_value_prefix = 0; if(parent == nullptr) this->depth = 0; else this->depth = parent->depth + 1; this->visit_count = 0; this->hidden_state_index_x = 0; this->hidden_state_index_y = 0; this->value_prefix = 0.0; this->prior = prior; this->discount = discount; this->parent = parent; this->ptr_node_pool = ptr_node_pool; this->children_idx = std::vector(); this->selected_children_idx = std::vector(); this->estimated_value_lst = std::vector(); } CNode::~CNode(){} void CNode::expand(int hidden_state_index_x, int hidden_state_index_y, float value_prefix, const std::vector &policy_logits, int reset_value_prefix, int leaf_action_num){ this->hidden_state_index_x = hidden_state_index_x; this->hidden_state_index_y = hidden_state_index_y; this->value_prefix = value_prefix; this->reset_value_prefix = reset_value_prefix; for(long unsigned int action = 0; action < policy_logits.size(); ++action){ float prior = policy_logits[action]; int index = this->ptr_node_pool->size(); this->children_idx.push_back(index); this->ptr_node_pool->push_back(CNode(prior, action, this, ptr_node_pool, this->discount, leaf_action_num)); } } std::vector CNode::get_policy(){ std::vector logits = this->get_children_priors(); std::vector policy = softmax(logits); return policy; } std::vector CNode::get_improved_policy(std::vector transformed_completed_Qs){ std::vector logits(this->num_actions, 0.0); for(int action = 0; action < this->num_actions; ++action){ CNode* child = this->get_child(action); logits[action] = child->prior + transformed_completed_Qs[action]; } std::vector policy = softmax(logits); return policy; } std::vector CNode::get_completed_Q(tools::CMinMaxStats &min_max_stats, int to_normalize){ std::vector completed_Qs(this->num_actions, 0.0); float v_mix = this->get_v_mix(); for(int action = 0; action < this->num_actions; ++action){ CNode* child = this->get_child(action); float Q = 0.0; if(child->is_expanded()) Q = this->get_qsa(action); else { Q = v_mix; // printf("use v_mix in continuous\n"); } if (to_normalize == 1) { completed_Qs[action] = min_max_stats.normalize(Q); if (completed_Qs[action] < 0.0) completed_Qs[action] = 0.0; if (completed_Qs[action] > 1.0) completed_Qs[action] = 1.0; } else completed_Qs[action] = Q; } if (to_normalize == 2){ printf("use final normalize\n"); float v_max = max_float(completed_Qs); float v_min = min_float(completed_Qs); // printf("here, %.3f, %.3f\n", v_max, v_min); for(int action = 0; action < this->num_actions; ++action){ completed_Qs[action] = (completed_Qs[action] - v_min) / (v_max - v_min); // printf("%.3f\n", completed_Qs[action]); } } return completed_Qs; } std::vector CNode::get_children_priors(){ std::vector priors(this->num_actions, 0.0); for(int action = 0; action < this->num_actions; ++action){ CNode* child = this->get_child(action); priors[action] = child->prior; } return priors; } std::vector CNode::get_children_visits(){ std::vector visits(this->num_actions, 0); for(int action = 0; action < this->num_actions; ++action){ CNode* child = this->get_child(action); visits[action] = child->visit_count; } return visits; } std::vector CNode::get_trajectory(){ std::vector traj; CNode* node = this; int best_action = node->best_action; while(best_action >= 0){ traj.push_back(best_action); node = node->get_child(best_action); best_action = node->best_action; } return traj; } int CNode::get_children_visit_sum(){ std::vector visit_lst = this->get_children_visits(); return sum(visit_lst); } float CNode::get_v_mix(){ std::vector pi_lst = this->get_policy(); float pi_sum = 0.0; float pi_qsa_sum = 0.0; float v_mix = 0.0; for(int action = 0; action < this->num_actions; ++action){ CNode* child = this->get_child(action); if(child->is_expanded()){ pi_sum += pi_lst[action]; pi_qsa_sum += pi_lst[action] * this->get_qsa(action); } } // if no child has been visited if(pi_sum < EPSILON) v_mix = this->get_value(); else{ v_mix = (1. / (1. + this->visit_count)) * (this->get_value() + this->visit_count * pi_qsa_sum / pi_sum); } return v_mix; } float CNode::get_reward(){ if(this->reset_value_prefix){ // printf("reset\n"); return this->value_prefix; } else return this->value_prefix - (this->parent)->value_prefix; } float CNode::get_value(){ if(this->is_expanded()) return sum(this->estimated_value_lst) / float(this->estimated_value_lst.size()); else return (this->parent)->get_v_mix(); } float CNode::get_qsa(int action){ CNode* child = this->get_child(action); float qsa = child->get_reward() + this->discount * child->get_value(); return qsa; } CNode* CNode::get_child(int action){ int index = this->children_idx[action]; return &((*(this->ptr_node_pool))[index]); } CNode* CNode::get_root(){ CNode* node = this; while(!this->is_root()){ node = node->parent; } return node; } std::vector CNode::get_expanded_children(){ std::vector children; // printf("num_actions=%d\n", this->num_actions); for(int action = 0; action < this->num_actions; ++action){ CNode* child = this->get_child(action); if(child->is_expanded()){ children.push_back(child); } } return children; } int CNode::is_root(){ return this->parent == nullptr; } int CNode::is_leaf(){ std::vector children = this->get_expanded_children(); return children.size() == 0; } int CNode::is_expanded(){ return this->children_idx.size() > 0; } int CNode::do_equal_visit(int num_simulations){ int min_visit_count = num_simulations + 1; int action = -1; // printf("selected_size=%d\n", this->selected_children_idx.size()); for(int selected_child_idx : this->selected_children_idx){ // printf("%d ", selected_child_idx); int visit_count = (this->get_child(selected_child_idx))->visit_count; if(visit_count < min_visit_count){ action = selected_child_idx; min_visit_count = visit_count; } } // printf("ywr_root_select=%d\n", action); return action; } void CNode::print_tree(std::vector &info){ // printf("expanded=%d\n", this->is_expanded()); if(!this->is_expanded()) return; for(int i = 0; i < this->depth; ++i){ std::cout << info[i]; } int is_leaf = this->is_leaf(); // printf("again\n"); if(is_leaf) std::cout << "└──"; else std::cout << "├──"; this->print(); std::vector expanded_children = this->get_expanded_children(); // printf("finish, %d\n", expanded_children.size()); for(CNode* child : expanded_children){ std::string str = "| "; if(is_leaf) str = " "; info.push_back(str); // printf("vc=%d\n", child->visit_count); child->print_tree(info); } } void CNode::print(){ std::string action_info = std::to_string(this->action); if(this->is_root()){ action_info = "["; for(int a : this->selected_children_idx) action_info = action_info + std::to_string(a) + ", "; action_info += "]"; } std::cout << std::setprecision(3) << "[a=" << action_info << " reset=" << this->reset_value_prefix << " (n=" << this->visit_count << ", vp=" << this->value_prefix << ", r=" << this->get_reward() << ", v=" << this->get_value() << ")]" << std::endl; // printf("here\n"); } //********************************************************* CRoots::CRoots(){ this->num_roots = 0; this->num_actions = 0; this->pool_size = 0; this->discount = 0.0; } CRoots::CRoots(int num_roots, int num_actions, int pool_size, float discount){ this->num_roots = num_roots; this->num_actions = num_actions; this->pool_size = pool_size; this->discount = discount; this->node_pools.reserve(num_roots); this->roots.reserve(num_roots); for(int i = 0; i < num_roots; ++i){ this->node_pools.push_back(std::vector()); this->node_pools[i].reserve(pool_size); this->roots.push_back(CNode(1, -1, nullptr, &this->node_pools[i], discount, num_actions)); } } CRoots::~CRoots(){} void CRoots::prepare(const std::vector &values, const std::vector> &policies, int leaf_action_num){ for(int i = 0; i < this->num_roots; ++i){ this->roots[i].expand(0, i, 0, policies[i], 1, leaf_action_num); this->roots[i].estimated_value_lst.push_back(values[i]); this->roots[i].visit_count += 1; } } void CRoots::clear(){ this->node_pools.clear(); this->roots.clear(); } std::vector> CRoots::get_trajectories(){ std::vector> trajs; trajs.reserve(this->num_roots); for(int i = 0; i < this->num_roots; ++i){ trajs.push_back(this->roots[i].get_trajectory()); } return trajs; } std::vector> CRoots::get_distributions(){ std::vector> distributions; distributions.reserve(this->num_roots); for(int i = 0; i < this->num_roots; ++i){ distributions.push_back(this->roots[i].get_children_visits()); } return distributions; } std::vector> CRoots::get_root_policies(tools::CMinMaxStatsList *min_max_stats_lst){ std::vector> policies; policies.reserve(this->num_roots); for(int i = 0; i < this->num_roots; ++i){ std::vector transformed_completed_Qs = get_transformed_completed_Qs(&(this->roots[i]), min_max_stats_lst->stats_lst[i], 0); for (int j = 0; j < this->roots[i].num_actions; j++){ float cq = transformed_completed_Qs[j]; if (isnan(cq)) printf("trans_Q NaN, %d, %d, %2f, %2f\n", i, j, min_max_stats_lst->stats_lst[i].maximum, min_max_stats_lst->stats_lst[i].minimum); } std::vector improved_policy = this->roots[i].get_improved_policy(transformed_completed_Qs); policies.push_back(improved_policy); } return policies; } std::vector CRoots::get_best_actions(){ std::vector best_actions(this->num_roots, -1); for(int i = 0; i < this->num_roots; ++i){ best_actions[i] = this->roots[i].selected_children_idx[0]; } return best_actions; } std::vector CRoots::get_values(){ std::vector values; for(int i = 0; i < this->num_roots; ++i){ values.push_back(this->roots[i].get_value()); } return values; } void CRoots::print_tree(){ for(int i = 0; i < this->num_roots; ++i){ std::vector info; this->roots[i].print_tree(info); } } //********************************************************* void print_arr(std::vector arr){ std::cout << "["; for(float a : arr){ std::cout << a << ","; } std::cout << "]" << std::endl; } void print_arr(std::vector arr){ std::cout << "["; for(int a : arr){ std::cout << a << ","; } std::cout << "]" << std::endl; } int argmax(std::vector arr){ int index = -3; float max_val = FLOAT_MIN; for(long unsigned int i = 0; i < arr.size(); ++i){ if(arr[i] > max_val){ max_val = arr[i]; index = i; } } return index; } int max_int(std::vector arr){ int max_val = int(FLOAT_MIN); for(int a : arr){ if(a > max_val){ max_val = a; } } return max_val; } float max_float(std::vector arr){ float max_val = FLOAT_MIN; for(float a : arr){ if(a > max_val){ max_val = a; } } return max_val; } float min_float(std::vector arr){ float min_val = 1000000.0; for (float a : arr){ if (a < min_val){ min_val = a; } } return min_val; } int sum(std::vector arr){ int res = 0.; for(int a : arr){ res += a; } return res; } float sum(std::vector arr){ float res = 0.; for(float a : arr){ res += a; } return res; } std::vector get_transformed_completed_Qs(CNode* node, tools::CMinMaxStats &min_max_stats, int final){ // get completed Q int to_normalize = 1; if (final) to_normalize = 2; std::vector completed_Qs = node->get_completed_Q(min_max_stats, to_normalize); // calculate the transformed Q values int max_child_visit_count = max_int(node->get_children_visits()); // sigma transform for(long unsigned int i = 0; i < completed_Qs.size(); ++i){ completed_Qs[i] = (min_max_stats.c_visit + max_child_visit_count) * min_max_stats.c_scale * completed_Qs[i]; } return completed_Qs; } std::vector c_batch_sequential_halving(CRoots *roots, const std::vector>& gumble_noises, tools::CMinMaxStatsList *min_max_stats_lst, int current_phase, int current_num_top_actions){ std::vector best_actions(roots->num_roots, -1); for(int i = 0; i < roots->num_roots; ++i){ int action = sequential_halving(&(roots->roots[i]), gumble_noises[i], min_max_stats_lst->stats_lst[i], current_phase, current_num_top_actions); best_actions[i] = action; } return best_actions; } int sequential_halving(CNode* root, const std::vector& gumble_noise, tools::CMinMaxStats &min_max_stats, int current_phase, int current_num_top_actions){ std::vector children_prior = root->get_children_priors(); std::vector children_scores; std::vector transformed_completed_Qs = get_transformed_completed_Qs(root, min_max_stats, 0); // the later phase: score = g + logits + sigma(hat_q) from the selected children std::vector selected_children_idx = root->selected_children_idx; for(int action : selected_children_idx){ children_scores.push_back(gumble_noise[action] + children_prior[action] + transformed_completed_Qs[action]); } std::vector idx(children_scores.size()); std::iota(idx.begin(), idx.end(), 0); std::sort(idx.begin(), idx.end(), [&children_scores](size_t index_1, size_t index_2) {return children_scores[index_1] > children_scores[index_2]; }); root->selected_children_idx.clear(); for(int i = 0; i < current_num_top_actions; ++i){ root->selected_children_idx.push_back(selected_children_idx[idx[i]]); } int best_action = root->selected_children_idx[0]; return best_action; } int select_action(CNode* node, tools::CMinMaxStats &min_max_stats, int num_simulations, int simulation_idx, const std::vector& gumble_noise, int current_num_top_actions){ int action = -1; if(node->is_root()){ if(simulation_idx == 0){ // the first phase: score = g + logits from all children std::vector children_prior = node->get_children_priors(); std::vector children_scores; for(int action = 0; action < node->num_actions; ++action){ children_scores.push_back(gumble_noise[action] + children_prior[action]); } std::vector idx(children_scores.size()); std::iota(idx.begin(), idx.end(), 0); std::sort(idx.begin(), idx.end(), [&children_scores](size_t index_1, size_t index_2) {return children_scores[index_1] > children_scores[index_2]; }); node->selected_children_idx.clear(); for(int action = 0; action < current_num_top_actions; ++action){ node->selected_children_idx.push_back(idx[action]); } } action = node->do_equal_visit(num_simulations); } else{ std::vector transformed_completed_Qs = get_transformed_completed_Qs(node, min_max_stats, 0); std::vector improved_policy = node->get_improved_policy(transformed_completed_Qs); std::vector ori_policy = node->get_policy(); std::vector children_visits = node->get_children_visits(); std::vector children_scores(node->num_actions, 0.0); for(int a = 0; a < node->num_actions; ++a){ float score = improved_policy[a] - children_visits[a] / (1. + float(node->visit_count)); children_scores[a] = score; } action = argmax(children_scores); } return action; } void c_batch_traverse(CRoots *roots, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, int num_simulations, int simulation_idx, const std::vector>& gumble_noise, int current_num_top_actions){ int last_action = -1; results.search_lens = std::vector(); for(int i = 0; i < results.num; ++i){ CNode *node = &(roots->roots[i]); int search_len = 0; results.search_paths[i].push_back(node); while(node->is_expanded()){ int action = select_action(node, min_max_stats_lst->stats_lst[i], num_simulations, simulation_idx, gumble_noise[i], current_num_top_actions); node->best_action = action; node = node->get_child(action); last_action = action; results.search_paths[i].push_back(node); search_len += 1; } CNode* parent = results.search_paths[i][results.search_paths[i].size() - 2]; results.hidden_state_index_x_lst.push_back(parent->hidden_state_index_x); results.hidden_state_index_y_lst.push_back(parent->hidden_state_index_y); results.last_actions.push_back(last_action); results.search_lens.push_back(search_len); results.nodes.push_back(node); } } void back_propagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, float value){ float bootstrap_value = value; int path_len = search_path.size(); std::vector info; for(int i = path_len - 1; i >= 0; --i){ CNode* node = search_path[i]; (node->estimated_value_lst).push_back(bootstrap_value); node->visit_count += 1; bootstrap_value = node->get_reward() + node->discount * bootstrap_value; min_max_stats.update(bootstrap_value); } } void c_batch_back_propagate(int hidden_state_index_x, const std::vector &value_prefixs, const std::vector &values, const std::vector> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector to_reset_lst, int leaf_action_num){ for(int i = 0; i < results.num; ++i){ results.nodes[i]->expand(hidden_state_index_x, i, value_prefixs[i], policies[i], to_reset_lst[i], leaf_action_num); back_propagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], values[i]); } } }