diff --git a/core/ctree/cnode.cpp b/core/ctree/cnode.cpp index 6563c007..d24685d8 100644 --- a/core/ctree/cnode.cpp +++ b/core/ctree/cnode.cpp @@ -201,7 +201,7 @@ namespace tree{ CRoots::~CRoots(){} - void CRoots::prepare(float root_exploration_fraction, const std::vector> &noises, const std::vector &value_prefixs, const std::vector> &policies){ + void CRoots::prepare(float root_exploration_fraction, const std::vector > &noises, const std::vector &value_prefixs, const std::vector > &policies){ for(int i = 0; i < this->root_num; ++i){ this->roots[i].expand(0, 0, i, value_prefixs[i], policies[i]); this->roots[i].add_exploration_noise(root_exploration_fraction, noises[i]); @@ -210,7 +210,7 @@ namespace tree{ } } - void CRoots::prepare_no_noise(const std::vector &value_prefixs, const std::vector> &policies){ + void CRoots::prepare_no_noise(const std::vector &value_prefixs, const std::vector > &policies){ for(int i = 0; i < this->root_num; ++i){ this->roots[i].expand(0, 0, i, value_prefixs[i], policies[i]); @@ -223,8 +223,8 @@ namespace tree{ this->roots.clear(); } - std::vector> CRoots::get_trajectories(){ - std::vector> trajs; + std::vector > CRoots::get_trajectories(){ + std::vector > trajs; trajs.reserve(this->root_num); for(int i = 0; i < this->root_num; ++i){ @@ -233,8 +233,8 @@ namespace tree{ return trajs; } - std::vector> CRoots::get_distributions(){ - std::vector> distributions; + std::vector > CRoots::get_distributions(){ + std::vector > distributions; distributions.reserve(this->root_num); for(int i = 0; i < this->root_num; ++i){ @@ -314,7 +314,7 @@ namespace tree{ update_tree_q(root, min_max_stats, discount); } - void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector &value_prefixs, const std::vector &values, const std::vector> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_lst){ + void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_lst){ for(int i = 0; i < results.num; ++i){ results.nodes[i]->expand(0, hidden_state_index_x, i, value_prefixs[i], policies[i]); // reset diff --git a/core/ctree/cnode.h b/core/ctree/cnode.h index efc0e854..18a839ec 100644 --- a/core/ctree/cnode.h +++ b/core/ctree/cnode.h @@ -44,17 +44,17 @@ namespace tree { public: int root_num, action_num, pool_size; std::vector roots; - std::vector> node_pools; + std::vector > node_pools; CRoots(); CRoots(int root_num, int action_num, int pool_size); ~CRoots(); - void prepare(float root_exploration_fraction, const std::vector> &noises, const std::vector &value_prefixs, const std::vector> &policies); - void prepare_no_noise(const std::vector &value_prefixs, const std::vector> &policies); + void prepare(float root_exploration_fraction, const std::vector > &noises, const std::vector &value_prefixs, const std::vector > &policies); + void prepare_no_noise(const std::vector &value_prefixs, const std::vector > &policies); void clear(); - std::vector> get_trajectories(); - std::vector> get_distributions(); + std::vector > get_trajectories(); + std::vector > get_distributions(); std::vector get_values(); }; @@ -64,7 +64,7 @@ namespace tree { int num; std::vector hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens; std::vector nodes; - std::vector> search_paths; + std::vector > search_paths; CSearchResults(); CSearchResults(int num); @@ -76,7 +76,7 @@ namespace tree { //********************************************************* void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount); void cback_propagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount); - void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector &value_prefixs, const std::vector &values, const std::vector> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_lst); + void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_lst); int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount, float mean_q); float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount); void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results);