From 664e879a4596b0cc3e53dbc91f9b23fc77b00c18 Mon Sep 17 00:00:00 2001 From: Leodasce Sewanou Date: Fri, 13 Feb 2026 14:38:24 +0100 Subject: [PATCH 1/5] [Ramses][AMR][1/] Refactor AMRGridRefinementHandler.hpp This commit moves the geometrical validity checks for block coarsening into a separate module, which will be introduced in a forthcoming commit. As a result, gen_refine_block_changes is now responsible only for computing refinement and derefinement flags based on the user-defined refinement criteria. --- .../modules/AMRGridRefinementHandler.hpp | 4 +- .../src/modules/AMRGridRefinementHandler.cpp | 93 +++---------------- 2 files changed, 16 insertions(+), 81 deletions(-) diff --git a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp index b68e9a96cd..1c75cbd613 100644 --- a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp +++ b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp @@ -74,8 +74,8 @@ namespace shammodels::basegodunov::modules { */ template void gen_refine_block_changes( - shambase::DistributedData &refine_list, - shambase::DistributedData &derefine_list, + shambase::DistributedData> &refine_flags, + shambase::DistributedData> &derefine_flags, T &&...args); template diff --git a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp index e95b85190c..f8eed09b98 100644 --- a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp +++ b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp @@ -14,25 +14,27 @@ * */ -#include "shammodels/ramses/modules/AMRGridRefinementHandler.hpp" +#include "shambase/DistributedData.hpp" +#include "shambase/aliases_int.hpp" #include "shamalgs/details/algorithm/algorithm.hpp" #include "shamcomm/logs.hpp" +#include 
"shammodels/ramses/modules/AMRGridRefinementHandler.hpp" #include "shammodels/ramses/modules/AMRSortBlocks.hpp" +#include +#include #include +#include template template void shammodels::basegodunov::modules::AMRGridRefinementHandler:: gen_refine_block_changes( - shambase::DistributedData &refine_list, - shambase::DistributedData &derefine_list, + shambase::DistributedData> &refn_flags, + shambase::DistributedData> &derfn_flags, T &&...args) { using namespace shamrock::patch; - u64 tot_refine = 0; - u64 tot_derefine = 0; - scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); @@ -73,81 +75,12 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: uacc.finalize(resulting_events, id_patch, cur_p, pdat, args...); } - sham::DeviceBuffer &buf_cell_min = pdat.get_field_buf_ref(0); - sham::DeviceBuffer &buf_cell_max = pdat.get_field_buf_ref(1); - - sham::EventList depends_list; - auto acc_min = buf_cell_min.get_read_access(depends_list); - auto acc_max = buf_cell_max.get_read_access(depends_list); - - // keep only derefine flags on only if the eight cells want to merge and if they can - auto e = q.submit(depends_list, [&](sycl::handler &cgh) { - sycl::accessor acc_merge_flag{derefine_flags, cgh, sycl::read_write}; - - cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) { - u32 id = gid.get_linear_id(); - - std::array blocks; - bool do_merge = true; - // This avoid the case where we are in the last block of the buffer to avoid the - // out-of-bound read - if (id + split_count <= obj_cnt) { - bool all_want_to_merge = true; - - for (u32 lid = 0; lid < split_count; lid++) { - blocks[lid] = BlockCoord{acc_min[gid + lid], acc_max[gid + lid]}; - all_want_to_merge = all_want_to_merge && acc_merge_flag[gid + lid]; - } - - do_merge = all_want_to_merge && BlockCoord::are_mergeable(blocks); - - } else { - do_merge = false; - } - - acc_merge_flag[gid] = 
do_merge; - }); - }); - - buf_cell_min.complete_event_state(e); - buf_cell_max.complete_event_state(e); - - //////////////////////////////////////////////////////////////////////////////// - // refinement - //////////////////////////////////////////////////////////////////////////////// - - // perform stream compactions on the refinement flags - auto [buf_refine, len_refine] - = shamalgs::numeric::stream_compact(q.q, refine_flags, obj_cnt); - - shamlog_debug_ln("AMRGrid", "patch ", id_patch, "refine block count = ", len_refine); - - tot_refine += len_refine; - - // add the results to the map - refine_list.add_obj(id_patch, OptIndexList{std::move(buf_refine), len_refine}); - - //////////////////////////////////////////////////////////////////////////////// - // derefinement - //////////////////////////////////////////////////////////////////////////////// - - // perform stream compactions on the derefinement flags - auto [buf_derefine, len_derefine] - = shamalgs::numeric::stream_compact(q.q, derefine_flags, obj_cnt); - - shamlog_debug_ln("AMRGrid", "patch ", id_patch, "merge block count = ", len_derefine); - - tot_derefine += len_derefine; - - // add the results to the map - derefine_list.add_obj(id_patch, OptIndexList{std::move(buf_derefine), len_derefine}); + refn_flags.add_obj(id_patch, std::move(refine_flags)); + derfn_flags.add_obj(id_patch, std::move(derefine_flags)); }); - - logger::info_ln("AMRGrid", "on this process", tot_refine, "blocks were refined"); - logger::info_ln( - "AMRGrid", "on this process", tot_derefine * split_count, "blocks were derefined"); } + template template bool shammodels::basegodunov::modules::AMRGridRefinementHandler:: @@ -620,11 +553,13 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: Tscal dxfact(solver_config.grid_coord_to_pos_fact); // get refine and derefine list + shambase::DistributedData> refine_flags; + shambase::DistributedData> derefine_flags; shambase::DistributedData refine_list; 
shambase::DistributedData derefine_list; gen_refine_block_changes( - refine_list, derefine_list, dxfact, cfg->crit_mass); + refine_flags, derefine_flags, dxfact, cfg->crit_mass); //////// apply refine //////// // Note that this only add new blocks at the end of the patchdata From dd3403a87a2ac48c02e0dcc7d8d825ef083cbf58 Mon Sep 17 00:00:00 2001 From: Leodasce Sewanou Date: Fri, 13 Feb 2026 15:30:06 +0100 Subject: [PATCH 2/5] [Ramses][AMR] Add 2:1 refinement consistency check This commit adds a function to check and enforce the 2:1 refinement rule. In the current implementation, kernels are launched over all blocks rather than only over blocks flagged for refinement. This is because only face neighbors (in the 6 Cartesian directions) are currently available, whereas enforcing the 2:1 rule also requires information about edge and corner neighbors. For robustness, the procedure is iterated three times to ensure convergence. A more efficient level-by-level implementation will be introduced in a future update.
--- .../modules/AMRGridRefinementHandler.hpp | 7 + .../ramses/modules/SolverStorage.hpp | 4 + .../src/modules/AMRGridRefinementHandler.cpp | 221 ++++++++++++++++++ 3 files changed, 232 insertions(+) diff --git a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp index 1c75cbd613..9fd7d9665f 100644 --- a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp +++ b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp @@ -78,6 +78,13 @@ namespace shammodels::basegodunov::modules { shambase::DistributedData> &derefine_flags, T &&...args); + /** + * @brief + */ + void enforce_two_to_one_for_refinement( + shambase::DistributedData> &&refine_flags, + shambase::DistributedData &refine_idx_list); + template bool internal_refine_grid(shambase::DistributedData &&refine_list); diff --git a/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp b/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp index ca78cde26a..baab0113c5 100644 --- a/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp +++ b/src/shammodels/ramses/include/shammodels/ramses/modules/SolverStorage.hpp @@ -60,6 +60,7 @@ namespace shammodels::basegodunov { using Tscal = shambase::VecComponent; using Tgridscal = shambase::VecComponent; static constexpr u32 dim = shambase::VectorProperties::dimension; + using TgridUint = typename std::make_unsigned>::type; using RTree = RadixTree; @@ -141,6 +142,9 @@ namespace shammodels::basegodunov { std::shared_ptr> idx_in_ghost; + std::shared_ptr> level0_size; + std::shared_ptr> amr_block_levels; + std::shared_ptr>> rho_face_xp; std::shared_ptr>> rho_face_xm; std::shared_ptr>> rho_face_yp; diff --git a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp index 
f8eed09b98..b571ca2249 100644 --- a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp +++ b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp @@ -16,10 +16,14 @@ #include "shambase/DistributedData.hpp" #include "shambase/aliases_int.hpp" +#include "shambase/memory.hpp" #include "shamalgs/details/algorithm/algorithm.hpp" +#include "shambackends/DeviceQueue.hpp" +#include "shambackends/EventList.hpp" #include "shamcomm/logs.hpp" #include "shammodels/ramses/modules/AMRGridRefinementHandler.hpp" #include "shammodels/ramses/modules/AMRSortBlocks.hpp" +#include "shamsys/NodeInstance.hpp" #include #include #include @@ -81,6 +85,220 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: }); } +/** + * @brief check and enforce 2:1 rule for refinement + * @tparam Tvec + * @tparam TgridVec + * @param refine_list refinement mask + * @param refine_idx_list refinement map + */ +template +void shammodels::basegodunov::modules::AMRGridRefinementHandler:: + enforce_two_to_one_for_refinement( + shambase::DistributedData> &&refine_flags, + shambase::DistributedData &refine_list) { + + using namespace shamrock::patch; + using AMRGraph = shammodels::basegodunov::modules::AMRGraph; + using Direction_ = shammodels::basegodunov::modules::Direction; + using AMRGraphLinkiterator = shammodels::basegodunov::modules::AMRGraph::ro_access; + using TgridUint = typename std::make_unsigned>::type; + + u64 tot_refine = 0; + + scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { + sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); + u64 id_patch = cur_p.id_patch; + + sycl::buffer &refn_flags = refine_flags.get(id_patch); + u32 obj_cnt = pdat.get_obj_cnt(); + + // blocks graph in each direction for the current patch + AMRGraph &block_graph_neighs_xp = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::xp) + .get(id_patch); + AMRGraph &block_graph_neighs_xm = 
shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::xm) + .get(id_patch); + AMRGraph &block_graph_neighs_yp = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::yp) + .get(id_patch); + AMRGraph &block_graph_neighs_ym = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::ym) + .get(id_patch); + AMRGraph &block_graph_neighs_zp = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::zp) + .get(id_patch); + AMRGraph &block_graph_neighs_zm = shambase::get_check_ref(storage.block_graph_edge) + .get_refs_dir(Direction_::zm) + .get(id_patch); + + // get levels in the current patch + sham::DeviceBuffer &buf_amr_block_levels + = shambase::get_check_ref(storage.amr_block_levels).get_buf(id_patch); + + // propagate refinement until stability + for (auto pass = 0; pass < 3; pass++) { + + sham::EventList depend_list; + AMRGraphLinkiterator block_graph_xp + = block_graph_neighs_xp.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_xm + = block_graph_neighs_xm.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_yp + = block_graph_neighs_yp.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_ym + = block_graph_neighs_ym.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_zp + = block_graph_neighs_zp.get_read_access(depend_list); + AMRGraphLinkiterator block_graph_zm + = block_graph_neighs_zm.get_read_access(depend_list); + auto acc_amr_levels = buf_amr_block_levels.get_read_access(depend_list); + + auto e_all_dir = q.submit(depend_list, [&](sycl::handler &cgh) { + sycl::accessor acc_ref_flags{refn_flags, cgh, sycl::read_write}; + cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) { + u32 block_id = gid.get_linear_id(); + + // get refinement flag and amr level of the current block + u32 cur_ref_flag = acc_ref_flags[block_id]; + auto cur_block_level = acc_amr_levels[block_id]; + + if (cur_ref_flag) { + 
///////////////////////////////////////////////////////////// + /// xp + //////////////////////////////////////////////////////////// + block_graph_xp.for_each_object_link(block_id, [&](u32 neigh_block_id) { + // get refinement flag and amr level of the neighborh block + u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + + if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) + && (cur_block_level > neigh_block_level)) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } + }); + + ///////////////////////////////////////////////////////////// + /// xm + //////////////////////////////////////////////////////////// + block_graph_xm.for_each_object_link(block_id, [&](u32 neigh_block_id) { + // get refinement flag and amr level of the neighborh block + u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + + if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) + && (cur_block_level > neigh_block_level)) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } + }); + + ///////////////////////////////////////////////////////////// + /// yp + //////////////////////////////////////////////////////////// + block_graph_yp.for_each_object_link(block_id, [&](u32 neigh_block_id) { + // get refinement flag and amr level of the neighborh block + u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) + && (cur_block_level > neigh_block_level)) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } + }); + 
///////////////////////////////////////////////////////////// + /// ym + //////////////////////////////////////////////////////////// + block_graph_ym.for_each_object_link(block_id, [&](u32 neigh_block_id) { + // get refinement flag and amr level of the neighborh block + u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + + if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) + && (cur_block_level > neigh_block_level)) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } + }); + ///////////////////////////////////////////////////////////// + /// zp + //////////////////////////////////////////////////////////// + block_graph_zp.for_each_object_link(block_id, [&](u32 neigh_block_id) { + // get refinement flag and amr level of the neighborh block + u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + + if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) + && (cur_block_level > neigh_block_level)) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } + }); + ///////////////////////////////////////////////////////////// + /// zm + //////////////////////////////////////////////////////////// + block_graph_zm.for_each_object_link(block_id, [&](u32 neigh_block_id) { + // get refinement flag and amr level of the neighborh block + u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + + if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) + && (cur_block_level > neigh_block_level)) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } + }); + } + }); + }); + 
block_graph_neighs_xp.complete_event_state(e_all_dir); + block_graph_neighs_xm.complete_event_state(e_all_dir); + block_graph_neighs_yp.complete_event_state(e_all_dir); + block_graph_neighs_ym.complete_event_state(e_all_dir); + block_graph_neighs_zp.complete_event_state(e_all_dir); + block_graph_neighs_zm.complete_event_state(e_all_dir); + buf_amr_block_levels.complete_event_state(e_all_dir); + } + //////////////////////////////////////////////////////////////////////////////// + // refinement + //////////////////////////////////////////////////////////////////////////////// + + // perform stream compactions on the refinement flags + auto [buf_refine, len_refine] = shamalgs::numeric::stream_compact(q.q, refn_flags, obj_cnt); + shamlog_debug_ln("AMRGrid", "patch ", id_patch, len_refine, "marked for refinement + 2:1"); + tot_refine += len_refine; + // add the results to the map + refine_list.add_obj(id_patch, OptIndexList{std::move(buf_refine), len_refine}); + }); + logger::info_ln("AMRGrid", "on this process", tot_refine, "blocks will be refined"); +} + template template bool shammodels::basegodunov::modules::AMRGridRefinementHandler:: @@ -565,6 +783,9 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: // Note that this only add new blocks at the end of the patchdata bool change_refine = internal_refine_grid(std::move(refine_list)); + ///// enforce 2:1 for refinement /////// + enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list); + //////// apply derefine //////// // Note that this will perform the merge then remove the old blocks // This is ok to call straight after the refine without edditing the index list in From c1c605930b576c44dec3262b75d5ae3f7ee9604d Mon Sep 17 00:00:00 2001 From: Leodasce Sewanou Date: Fri, 13 Feb 2026 15:43:02 +0100 Subject: [PATCH 3/5] deactivate derefinement, to allow the CI test for the amr to pass on the modifications --- .../ramses/src/modules/AMRGridRefinementHandler.cpp | 9 +++++---- 1 file 
changed, 5 insertions(+), 4 deletions(-) diff --git a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp index b571ca2249..9595b1c141 100644 --- a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp +++ b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp @@ -779,19 +779,20 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: gen_refine_block_changes( refine_flags, derefine_flags, dxfact, cfg->crit_mass); + ///// enforce 2:1 for refinement /////// + enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list); + //////// apply refine //////// // Note that this only add new blocks at the end of the patchdata bool change_refine = internal_refine_grid(std::move(refine_list)); - ///// enforce 2:1 for refinement /////// - enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list); - //////// apply derefine //////// // Note that this will perform the merge then remove the old blocks // This is ok to call straight after the refine without edditing the index list in // derefine_list since no permutations were applied in internal_refine_grid and no cells can // be both refined and derefined in the same pass - bool change_derefine = internal_derefine_grid(std::move(derefine_list)); + bool change_derefine = false; + // internal_derefine_grid(std::move(derefine_list)); has_cell_order_changed = has_cell_order_changed || (change_refine || change_derefine); } From 7c27f3b613dbe78158d8c2c6a90b0c1865c8ad72 Mon Sep 17 00:00:00 2001 From: Leodasce Sewanou Date: Fri, 13 Feb 2026 16:41:41 +0100 Subject: [PATCH 4/5] remove code duplication --- .../modules/AMRGridRefinementHandler.hpp | 10 +- .../src/modules/AMRGridRefinementHandler.cpp | 127 +++--------------- 2 files changed, 27 insertions(+), 110 deletions(-) diff --git a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp 
b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp index 9fd7d9665f..91e5d79af8 100644 --- a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp +++ b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp @@ -79,11 +79,17 @@ namespace shammodels::basegodunov::modules { T &&...args); /** - * @brief + * @brief Enforces the 2:1 refinement ratio for blocks. + * + * This function iterates through blocks marked for refinement and ensures that + * adjacent, coarser blocks are also marked for refinement to maintain the 2:1 + * grid balance. This is done iteratively to propagate the refinement as needed. + * @param refine_flags refinement flags + * @param refine_list refinement maps */ void enforce_two_to_one_for_refinement( shambase::DistributedData> &&refine_flags, - shambase::DistributedData &refine_idx_list); + shambase::DistributedData &refine_list); template bool internal_refine_grid(shambase::DistributedData &&refine_list); diff --git a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp index 9595b1c141..c7d59067a7 100644 --- a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp +++ b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp @@ -165,115 +165,26 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: auto cur_block_level = acc_amr_levels[block_id]; if (cur_ref_flag) { - ///////////////////////////////////////////////////////////// - /// xp - //////////////////////////////////////////////////////////// - block_graph_xp.for_each_object_link(block_id, [&](u32 neigh_block_id) { - // get refinement flag and amr level of the neighborh block - u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; - auto neigh_block_level = acc_amr_levels[neigh_block_id]; - - if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) - && (cur_block_level > neigh_block_level)) { 
- sycl::atomic_ref< - u32, - sycl::memory_order::relaxed, - sycl::memory_scope::device> - atomic_flag(acc_ref_flags[neigh_block_id]); - atomic_flag.store(1); + auto enforce_2_to_1_rule = [&](u32 neigh_block_id) { + if (0 <= neigh_block_id && neigh_block_id < obj_cnt) { + auto neigh_block_level = acc_amr_levels[neigh_block_id]; + if (cur_block_level > neigh_block_level) { + sycl::atomic_ref< + u32, + sycl::memory_order::relaxed, + sycl::memory_scope::device> + atomic_flag(acc_ref_flags[neigh_block_id]); + atomic_flag.store(1); + } } - }); - - ///////////////////////////////////////////////////////////// - /// xm - //////////////////////////////////////////////////////////// - block_graph_xm.for_each_object_link(block_id, [&](u32 neigh_block_id) { - // get refinement flag and amr level of the neighborh block - u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; - auto neigh_block_level = acc_amr_levels[neigh_block_id]; - - if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) - && (cur_block_level > neigh_block_level)) { - sycl::atomic_ref< - u32, - sycl::memory_order::relaxed, - sycl::memory_scope::device> - atomic_flag(acc_ref_flags[neigh_block_id]); - atomic_flag.store(1); - } - }); - - ///////////////////////////////////////////////////////////// - /// yp - //////////////////////////////////////////////////////////// - block_graph_yp.for_each_object_link(block_id, [&](u32 neigh_block_id) { - // get refinement flag and amr level of the neighborh block - u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; - auto neigh_block_level = acc_amr_levels[neigh_block_id]; - if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) - && (cur_block_level > neigh_block_level)) { - sycl::atomic_ref< - u32, - sycl::memory_order::relaxed, - sycl::memory_scope::device> - atomic_flag(acc_ref_flags[neigh_block_id]); - atomic_flag.store(1); - } - }); - ///////////////////////////////////////////////////////////// - /// ym - //////////////////////////////////////////////////////////// 
- block_graph_ym.for_each_object_link(block_id, [&](u32 neigh_block_id) { - // get refinement flag and amr level of the neighborh block - u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; - auto neigh_block_level = acc_amr_levels[neigh_block_id]; - - if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) - && (cur_block_level > neigh_block_level)) { - sycl::atomic_ref< - u32, - sycl::memory_order::relaxed, - sycl::memory_scope::device> - atomic_flag(acc_ref_flags[neigh_block_id]); - atomic_flag.store(1); - } - }); - ///////////////////////////////////////////////////////////// - /// zp - //////////////////////////////////////////////////////////// - block_graph_zp.for_each_object_link(block_id, [&](u32 neigh_block_id) { - // get refinement flag and amr level of the neighborh block - u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; - auto neigh_block_level = acc_amr_levels[neigh_block_id]; - - if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) - && (cur_block_level > neigh_block_level)) { - sycl::atomic_ref< - u32, - sycl::memory_order::relaxed, - sycl::memory_scope::device> - atomic_flag(acc_ref_flags[neigh_block_id]); - atomic_flag.store(1); - } - }); - ///////////////////////////////////////////////////////////// - /// zm - //////////////////////////////////////////////////////////// - block_graph_zm.for_each_object_link(block_id, [&](u32 neigh_block_id) { - // get refinement flag and amr level of the neighborh block - u32 neigh_ref_flag = acc_ref_flags[neigh_block_id]; - auto neigh_block_level = acc_amr_levels[neigh_block_id]; - - if ((0 <= neigh_block_id) && (neigh_block_id < obj_cnt) - && (cur_block_level > neigh_block_level)) { - sycl::atomic_ref< - u32, - sycl::memory_order::relaxed, - sycl::memory_scope::device> - atomic_flag(acc_ref_flags[neigh_block_id]); - atomic_flag.store(1); - } - }); + }; + + block_graph_xp.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_xm.for_each_object_link(block_id, enforce_2_to_1_rule); + 
block_graph_yp.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_ym.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_zp.for_each_object_link(block_id, enforce_2_to_1_rule); + block_graph_zm.for_each_object_link(block_id, enforce_2_to_1_rule); } }); }); From 83fd0385e9a9fa80827be1c086a52c4a62f6626b Mon Sep 17 00:00:00 2001 From: Leodasce Sewanou Date: Mon, 16 Feb 2026 11:48:57 +0100 Subject: [PATCH 5/5] updates --- .../modules/AMRGridRefinementHandler.hpp | 4 +- .../src/modules/AMRGridRefinementHandler.cpp | 101 ++++++------------ 2 files changed, 35 insertions(+), 70 deletions(-) diff --git a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp index 27ed8c6571..a2ae7d85c5 100644 --- a/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp +++ b/src/shammodels/ramses/include/shammodels/ramses/modules/AMRGridRefinementHandler.hpp @@ -83,8 +83,8 @@ namespace shammodels::basegodunov::modules { * @param refine_list refinement maps */ void enforce_two_to_one_for_refinement( - shambase::DistributedData> &&refine_flags, - shambase::DistributedData &refine_list); + shambase::DistributedData> &&refine_flags, + shambase::DistributedData> &refine_list); template bool internal_refine_grid(shambase::DistributedData> &&refine_list); diff --git a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp index ac46eacec8..accee2eb9d 100644 --- a/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp +++ b/src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp @@ -18,6 +18,7 @@ #include "shambase/aliases_int.hpp" #include "shambase/memory.hpp" #include "shamalgs/details/algorithm/algorithm.hpp" +#include "shambackends/DeviceBuffer.hpp" #include "shambackends/DeviceQueue.hpp" #include "shambackends/EventList.hpp" 
#include "shamcomm/logs.hpp" @@ -39,8 +40,6 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: using namespace shamrock::patch; - scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { - sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); u64 tot_refine = 0; u64 tot_derefine = 0; @@ -53,16 +52,16 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: // create the refine and derefine flags buffers u32 obj_cnt = pdat.get_obj_cnt(); - sham::DeviceBuffer refine_flags(obj_cnt, dev_sched); - sham::DeviceBuffer derefine_flags(obj_cnt, dev_sched); + sham::DeviceBuffer refine_flag(obj_cnt, dev_sched); + sham::DeviceBuffer derefine_flag(obj_cnt, dev_sched); { sham::EventList depends_list; UserAcc uacc(depends_list, id_patch, cur_p, pdat, args...); - auto refine_acc = refine_flags.get_write_access(depends_list); - auto derefine_acc = derefine_flags.get_write_access(depends_list); + auto refine_acc = refine_flag.get_write_access(depends_list); + auto derefine_acc = derefine_flag.get_write_access(depends_list); // fill in the flags auto e = q.submit(depends_list, [&](sycl::handler &cgh) { @@ -84,14 +83,14 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: sham::EventList resulting_events; resulting_events.add_event(e); - refine_flags.complete_event_state(resulting_events); - derefine_flags.complete_event_state(resulting_events); + refine_flag.complete_event_state(resulting_events); + derefine_flag.complete_event_state(resulting_events); uacc.finalize(resulting_events, id_patch, cur_p, pdat, args...); } - refn_flags.add_obj(id_patch, std::move(refine_flags)); - derfn_flags.add_obj(id_patch, std::move(derefine_flags)); + refine_flags.add_obj(id_patch, std::move(refine_flag)); + derefine_flags.add_obj(id_patch, std::move(derefine_flag)); }); } @@ -105,21 +104,8 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: template void 
shammodels::basegodunov::modules::AMRGridRefinementHandler:: enforce_two_to_one_for_refinement( - shambase::DistributedData> &&refine_flags, - shambase::DistributedData &refine_list) { - - sham::DeviceBuffer &buf_cell_min = pdat.get_field_buf_ref(0); - sham::DeviceBuffer &buf_cell_max = pdat.get_field_buf_ref(1); - - sham::EventList depends_list; - auto acc_min = buf_cell_min.get_read_access(depends_list); - auto acc_max = buf_cell_max.get_read_access(depends_list); - auto acc_merge_flag = derefine_flags.get_write_access(depends_list); - - // keep only derefine flags on only if the eight cells want to merge and if they can - auto e = q.submit(depends_list, [&](sycl::handler &cgh) { - cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) { - u32 id = gid.get_linear_id(); + shambase::DistributedData> &&refine_flags, + shambase::DistributedData> &refine_list) { using namespace shamrock::patch; using AMRGraph = shammodels::basegodunov::modules::AMRGraph; @@ -129,12 +115,13 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: u64 tot_refine = 0; + sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { - sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); - u64 id_patch = cur_p.id_patch; + u64 id_patch = cur_p.id_patch; - sycl::buffer &refn_flags = refine_flags.get(id_patch); - u32 obj_cnt = pdat.get_obj_cnt(); + sham::DeviceBuffer &refine_flags_buf = refine_list.get(id_patch); + u32 obj_cnt = pdat.get_obj_cnt(); // blocks graph in each direction for the current patch AMRGraph &block_graph_neighs_xp = shambase::get_check_ref(storage.block_graph_edge) @@ -177,9 +164,9 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: AMRGraphLinkiterator block_graph_zm = block_graph_neighs_zm.get_read_access(depend_list); auto acc_amr_levels = 
buf_amr_block_levels.get_read_access(depend_list); + auto acc_ref_flags = refine_flags_buf.get_write_access(depend_list); auto e_all_dir = q.submit(depend_list, [&](sycl::handler &cgh) { - sycl::accessor acc_ref_flags{refn_flags, cgh, sycl::read_write}; cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) { u32 block_id = gid.get_linear_id(); @@ -218,47 +205,19 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: block_graph_neighs_zp.complete_event_state(e_all_dir); block_graph_neighs_zm.complete_event_state(e_all_dir); buf_amr_block_levels.complete_event_state(e_all_dir); + refine_flags_buf.complete_event_state(e_all_dir); } - }); - - buf_cell_min.complete_event_state(e); - buf_cell_max.complete_event_state(e); - derefine_flags.complete_event_state(e); //////////////////////////////////////////////////////////////////////////////// // refinement //////////////////////////////////////////////////////////////////////////////// // perform stream compactions on the refinement flags - auto [buf_refine, len_refine] = shamalgs::numeric::stream_compact(q.q, refn_flags, obj_cnt); - shamlog_debug_ln("AMRGrid", "patch ", id_patch, len_refine, "marked for refinement + 2:1"); - tot_refine += len_refine; - // add the results to the map - refine_list.add_obj(id_patch, OptIndexList{std::move(buf_refine), len_refine}); - auto buf_refine = shamalgs::numeric::stream_compact(dev_sched, refine_flags, obj_cnt); - + auto buf_refine = shamalgs::numeric::stream_compact(dev_sched, refine_flags_buf, obj_cnt); shamlog_debug_ln( - "AMRGrid", "patch ", id_patch, "refine block count = ", buf_refine.get_size()); - + "AMRGrid", "patch ", id_patch, buf_refine.get_size(), "marked for refinement + 2:1"); tot_refine += buf_refine.get_size(); - - // add the results to the map refine_list.add_obj(id_patch, std::move(buf_refine)); - - //////////////////////////////////////////////////////////////////////////////// - // derefinement - 
//////////////////////////////////////////////////////////////////////////////// - - // perform stream compactions on the derefinement flags - auto buf_derefine = shamalgs::numeric::stream_compact(dev_sched, derefine_flags, obj_cnt); - - shamlog_debug_ln( - "AMRGrid", "patch ", id_patch, "merge block count = ", buf_derefine.get_size()); - - tot_derefine += buf_derefine.get_size(); - - // add the results to the map - derefine_list.add_obj(id_patch, std::move(buf_derefine)); }); logger::info_ln("AMRGrid", "on this process", tot_refine, "blocks will be refined"); } @@ -477,10 +436,16 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: block_sorter.reorder_amr_blocks(); // get refine and derefine list + shambase::DistributedData> refine_flags; + shambase::DistributedData> derefine_flags; + shambase::DistributedData> refine_list; shambase::DistributedData> derefine_list; - gen_refine_block_changes(refine_list, derefine_list); + gen_refine_block_changes(refine_flags, derefine_flags); + + ///// enforce 2:1 for refinement /////// + enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list); //////// apply refine //////// // Note that this only add new blocks at the end of the patchdata @@ -491,7 +456,8 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: // This is ok to call straight after the refine without edditing the index list in derefine_list // since no permutations were applied in internal_refine_grid and no cells can be both refined // and derefined in the same pass - internal_derefine_grid(std::move(derefine_list)); + + // internal_derefine_grid(std::move(derefine_list)); } template @@ -742,12 +708,11 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler: Tscal dxfact(solver_config.grid_coord_to_pos_fact); // get refine and derefine list - shambase::DistributedData> refine_flags; - shambase::DistributedData> derefine_flags; - shambase::DistributedData refine_list; - shambase::DistributedData derefine_list; 
+ shambase::DistributedData> refine_flags; + shambase::DistributedData> derefine_flags; + shambase::DistributedData> refine_list; - shambase::DistributedData> derefine_list; + // shambase::DistributedData> derefine_list; gen_refine_block_changes( refine_flags, derefine_flags, dxfact, cfg->crit_mass);