diff --git a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp index bab7476bb7..a2ae7d85c5 100644 --- a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp +++ b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp @@ -69,10 +69,23 @@ namespace shammodels::basegodunov::modules { */ template void gen_refine_block_changes( - shambase::DistributedData> &refine_list, - shambase::DistributedData> &derefine_list, + shambase::DistributedData> &refine_flags, + shambase::DistributedData> &derefine_flags, T &&...args); + /** + * @brief Enforces the 2:1 refinement ratio for blocks. + * + * This function iterates through blocks marked for refinement and ensures that + * adjacent, coarser blocks are also marked for refinement to maintain the 2:1 + * grid balance. This is done iteratively to propagate the refinement as needed. + * @param refine_flags refinement flags + * @param refine_list refinement maps + */ + void enforce_two_to_one_for_refinement( + shambase::DistributedData> &&refine_flags, + shambase::DistributedData> &refine_list); + template bool internal_refine_grid(shambase::DistributedData> &&refine_list); diff --git a/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp b/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp index ca78cde26a..baab0113c5 100644 --- a/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp +++ b/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp @@ -60,6 +60,7 @@ namespace shammodels::basegodunov { using Tscal = shambase::VecComponent; using Tgridscal = shambase::VecComponent; static constexpr u32 dim = shambase::VectorProperties::dimension; + using TgridUint = typename std::make_unsigned>::type; using RTree = RadixTree; @@ -141,6 +142,9 @@ namespace shammodels::basegodunov { std::shared_ptr> idx_in_ghost; + std::shared_ptr> level0_size; + std::shared_ptr> amr_block_levels; + std::shared_ptr>> rho_face_xp; std::shared_ptr>> rho_face_xm; std::shared_ptr>> rho_face_yp; diff --git a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp index d43f70b7c9..accee2eb9d 100644 --- a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp +++ b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp @@ -14,18 +14,28 @@ * */ -#include "shammodels/ramses/modules/AMRGridRefinementHandler.hpp" +#include "shambase/DistributedData.hpp" +#include "shambase/aliases_int.hpp" +#include "shambase/memory.hpp" #include "shamalgs/details/algorithm/algorithm.hpp" +#include "shambackends/DeviceBuffer.hpp" +#include "shambackends/DeviceQueue.hpp" +#include "shambackends/EventList.hpp" #include "shamcomm/logs.hpp" +#include "shammodels/ramses/modules/AMRGridRefinementHandler.hpp" #include "shammodels/ramses/modules/AMRSortBlocks.hpp" +#include "shamsys/NodeInstance.hpp" +#include +#include #include +#include template template void shammodels::basegodunov::modules::AMRGridRefinementHandler:: gen_refine_block_changes( - shambase::DistributedData> &refine_list, - shambase::DistributedData> &derefine_list, + shambase::DistributedData> &refine_flags, + shambase::DistributedData> &derefine_flags, T &&...args) { using namespace shamrock::patch; @@ -42,16 +52,16 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: // create the refine and derefine flags buffers u32 obj_cnt = pdat.get_obj_cnt(); - sham::DeviceBuffer refine_flags(obj_cnt, dev_sched); - sham::DeviceBuffer derefine_flags(obj_cnt, dev_sched); + sham::DeviceBuffer refine_flag(obj_cnt, dev_sched); + sham::DeviceBuffer derefine_flag(obj_cnt, dev_sched); { sham::EventList depends_list; UserAcc uacc(depends_list, id_patch, cur_p, pdat, args...); - auto refine_acc = refine_flags.get_write_access(depends_list); - auto derefine_acc = derefine_flags.get_write_access(depends_list); + auto refine_acc = refine_flag.get_write_access(depends_list); + auto derefine_acc = derefine_flag.get_write_access(depends_list); // fill in the flags auto e = q.submit(depends_list, [&](sycl::handler &cgh) { @@ -73,87 +83,145 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: sham::EventList resulting_events; resulting_events.add_event(e); - refine_flags.complete_event_state(resulting_events); - derefine_flags.complete_event_state(resulting_events); + refine_flag.complete_event_state(resulting_events); + derefine_flag.complete_event_state(resulting_events); uacc.finalize(resulting_events, id_patch, cur_p, pdat, args...); } - sham::DeviceBuffer &buf_cell_min = pdat.get_field_buf_ref(0); - sham::DeviceBuffer &buf_cell_max = pdat.get_field_buf_ref(1); + refine_flags.add_obj(id_patch, std::move(refine_flag)); + derefine_flags.add_obj(id_patch, std::move(derefine_flag)); + }); +} - sham::EventList depends_list; - auto acc_min = buf_cell_min.get_read_access(depends_list); - auto acc_max = buf_cell_max.get_read_access(depends_list); - auto acc_merge_flag = derefine_flags.get_write_access(depends_list); +/** + * @brief check and enforce 2:1 rule for refinement + * @tparam Tvec + * @tparam TgridVec + * @param refine_list refinement mask + * @param refine_idx_list refinement map + */ +template +void shammodels::basegodunov::modules::AMRGridRefinementHandler:: + enforce_two_to_one_for_refinement( + shambase::DistributedData> &&refine_flags, + shambase::DistributedData> &refine_list) { - // keep only derefine flags on only if the eight cells want to merge and if they can - auto e = q.submit(depends_list, [&](sycl::handler &cgh) { - cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) { - u32 id = gid.get_linear_id(); + using namespace shamrock::patch; + using AMRGraph = shammodels::basegodunov::modules::AMRGraph; + using Direction_ = shammodels::basegodunov::modules::Direction; + using AMRGraphLinkiterator = shammodels::basegodunov::modules::AMRGraph::ro_access; + using TgridUint = typename std::make_unsigned>::type; - std::array blocks; - bool do_merge = true; + u64 tot_refine = 0; - // This avoid the case where we are in the last block of the buffer to avoid the - // out-of-bound read - if (id + split_count <= obj_cnt) { - bool all_want_to_merge = true; + sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { + u64 id_patch = cur_p.id_patch; - for (u32 lid = 0; lid < split_count; lid++) { - blocks[lid] = BlockCoord{acc_min[gid + lid], acc_max[gid + lid]}; - all_want_to_merge = all_want_to_merge && acc_merge_flag[gid + lid]; + sham::DeviceBuffer &refine_flags_buf = refine_list.get(id_patch); + u32 obj_cnt = pdat.get_obj_cnt(); + + // blocks graph in each direction for the current patch + AMRGraph &block_graph_neighs_xp = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::xp) + .get(id_patch); + AMRGraph &block_graph_neighs_xm = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::xm) + .get(id_patch); + AMRGraph &block_graph_neighs_yp = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::yp) + .get(id_patch); + AMRGraph &block_graph_neighs_ym = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::ym) + .get(id_patch); + AMRGraph &block_graph_neighs_zp = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::zp) + .get(id_patch); + AMRGraph &block_graph_neighs_zm = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::zm) + .get(id_patch); + + // get levels in the current patch + sham::DeviceBuffer &buf_amr_block_levels + = shambase::get_check_ref(storage.amr_block_levels).get_buf(id_patch); + + // propagate refinement until stability + for (auto pass = 0; pass < 3; pass++) { + + sham::EventList depend_list; + AMRGraphLinkiterator block_graph_xp + = block_graph_neighs_xp.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_xm + = block_graph_neighs_xm.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_yp + = block_graph_neighs_yp.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_ym + = block_graph_neighs_ym.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_zp + = block_graph_neighs_zp.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_zm + = block_graph_neighs_zm.get_read_access(depend_list); + auto acc_amr_levels = buf_amr_block_levels.get_read_access(depend_list); + auto acc_ref_flags = refine_flags_buf.get_write_access(depend_list); + + auto e_all_dir = q.submit(depend_list, [&](sycl::handler &cgh) { + cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) { + u32 block_id = gid.get_linear_id(); + + // get refinement flag and amr level of the current block + u32 cur_ref_flag = acc_ref_flags[block_id]; + auto cur_block_level = acc_amr_levels[block_id]; + + if (cur_ref_flag) { + auto enforce_2_to_1_rule = [&](u32 neigh_block_id) { + if (0 <= neigh_block_id && neigh_block_id < obj_cnt) { + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + if (cur_block_level > neigh_block_level) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } + } + }; + + block_graph_xp.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_xm.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_yp.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_ym.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_zp.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_zm.for_each_object_link(block_id, enforce_2_to_1_rule); } - - do_merge = all_want_to_merge && BlockCoord::are_mergeable(blocks); - - } else { - do_merge = false; - } - - acc_merge_flag[gid] = do_merge; + }); }); - }); - - buf_cell_min.complete_event_state(e); - buf_cell_max.complete_event_state(e); - derefine_flags.complete_event_state(e); + block_graph_neighs_xp.complete_event_state(e_all_dir); + block_graph_neighs_xm.complete_event_state(e_all_dir); + block_graph_neighs_yp.complete_event_state(e_all_dir); + block_graph_neighs_ym.complete_event_state(e_all_dir); + block_graph_neighs_zp.complete_event_state(e_all_dir); + block_graph_neighs_zm.complete_event_state(e_all_dir); + buf_amr_block_levels.complete_event_state(e_all_dir); + refine_flags_buf.complete_event_state(e_all_dir); + } //////////////////////////////////////////////////////////////////////////////// // refinement //////////////////////////////////////////////////////////////////////////////// // perform stream compactions on the refinement flags - auto buf_refine = shamalgs::numeric::stream_compact(dev_sched, refine_flags, obj_cnt); - + auto buf_refine = shamalgs::numeric::stream_compact(dev_sched, refine_flags_buf, obj_cnt); shamlog_debug_ln( - "AMRGrid", "patch ", id_patch, "refine block count = ", buf_refine.get_size()); - + "AMRGrid", "patch ", id_patch, buf_refine.get_size(), "marked for refinement + 2:1"); tot_refine += buf_refine.get_size(); - - // add the results to the map refine_list.add_obj(id_patch, std::move(buf_refine)); - - //////////////////////////////////////////////////////////////////////////////// - // derefinement - //////////////////////////////////////////////////////////////////////////////// - - // perform stream compactions on the derefinement flags - auto buf_derefine = shamalgs::numeric::stream_compact(dev_sched, derefine_flags, obj_cnt); - - shamlog_debug_ln( - "AMRGrid", "patch ", id_patch, "merge block count = ", buf_derefine.get_size()); - - tot_derefine += buf_derefine.get_size(); - - // add the results to the map - derefine_list.add_obj(id_patch, std::move(buf_derefine)); }); - - logger::info_ln("AMRGrid", "on this process", tot_refine, "blocks were refined"); - logger::info_ln( - "AMRGrid", "on this process", tot_derefine * split_count, "blocks were derefined"); + logger::info_ln("AMRGrid", "on this process", tot_refine, "blocks will be refined"); } + template template bool shammodels::basegodunov::modules::AMRGridRefinementHandler:: @@ -368,10 +436,16 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: block_sorter.reorder_amr_blocks(); // get refine and derefine list + shambase::DistributedData> refine_flags; + shambase::DistributedData> derefine_flags; + shambase::DistributedData> refine_list; shambase::DistributedData> derefine_list; - gen_refine_block_changes(refine_list, derefine_list); + gen_refine_block_changes(refine_flags, derefine_flags); + + ///// enforce 2:1 for refinement /////// + enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list); //////// apply refine //////// // Note that this only add new blocks at the end of the patchdata @@ -382,7 +456,8 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: // This is ok to call straight after the refine without edditing the index list in derefine_list // since no permutations were applied in internal_refine_grid and no cells can be both refined // and derefined in the same pass - internal_derefine_grid(std::move(derefine_list)); + + // internal_derefine_grid(std::move(derefine_list)); } template @@ -633,11 +708,17 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: Tscal dxfact(solver_config.grid_coord_to_pos_fact); // get refine and derefine list + shambase::DistributedData> refine_flags; + shambase::DistributedData> derefine_flags; + shambase::DistributedData> refine_list; - shambase::DistributedData> derefine_list; + // shambase::DistributedData> derefine_list; gen_refine_block_changes( - refine_list, derefine_list, dxfact, cfg->crit_mass); + refine_flags, derefine_flags, dxfact, cfg->crit_mass); + + ///// enforce 2:1 for refinement /////// + enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list); //////// apply refine //////// // Note that this only add new blocks at the end of the patchdata @@ -648,7 +729,8 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: // This is ok to call straight after the refine without edditing the index list in // derefine_list since no permutations were applied in internal_refine_grid and no cells can // be both refined and derefined in the same pass - bool change_derefine = internal_derefine_grid(std::move(derefine_list)); + bool change_derefine = false; + // internal_derefine_grid(std::move(derefine_list)); has_cell_order_changed = has_cell_order_changed || (change_refine || change_derefine); }