Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,23 @@ namespace shammodels::basegodunov::modules {
*/
template<class UserAcc, class... T>
void gen_refine_block_changes(
shambase::DistributedData<sham::DeviceBuffer<u32>> &refine_list,
shambase::DistributedData<sham::DeviceBuffer<u32>> &derefine_list,
shambase::DistributedData<sham::DeviceBuffer<u32>> &refine_flags,
shambase::DistributedData<sham::DeviceBuffer<u32>> &derefine_flags,
T &&...args);

/**
* @brief Enforces the 2:1 refinement ratio for blocks.
*
* This function iterates through blocks marked for refinement and ensures that
* adjacent, coarser blocks are also marked for refinement to maintain the 2:1
* grid balance. This is done iteratively to propagate the refinement as needed.
* @param refine_flags refinement flags
* @param refine_list refinement maps
*/
void enforce_two_to_one_for_refinement(
shambase::DistributedData<sham::DeviceBuffer<u32>> &&refine_flags,
shambase::DistributedData<sham::DeviceBuffer<u32>> &refine_list);

template<class UserAcc>
bool internal_refine_grid(shambase::DistributedData<sham::DeviceBuffer<u32>> &&refine_list);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ namespace shammodels::basegodunov {
using Tscal = shambase::VecComponent<Tvec>;
using Tgridscal = shambase::VecComponent<TgridVec>;
static constexpr u32 dim = shambase::VectorProperties<Tvec>::dimension;
using TgridUint = typename std::make_unsigned<shambase::VecComponent<TgridVec>>::type;

using RTree = RadixTree<Tmorton, TgridVec>;

Expand Down Expand Up @@ -141,6 +142,9 @@ namespace shammodels::basegodunov {

std::shared_ptr<shamrock::solvergraph::DDSharedBuffers<u32>> idx_in_ghost;

std::shared_ptr<shamrock::solvergraph::ScalarsEdge<TgridVec>> level0_size;
std::shared_ptr<shamrock::solvergraph::Field<TgridUint>> amr_block_levels;

std::shared_ptr<solvergraph::NeighGraphLinkFieldEdge<std::array<Tscal, 2>>> rho_face_xp;
std::shared_ptr<solvergraph::NeighGraphLinkFieldEdge<std::array<Tscal, 2>>> rho_face_xm;
std::shared_ptr<solvergraph::NeighGraphLinkFieldEdge<std::array<Tscal, 2>>> rho_face_yp;
Expand Down
224 changes: 153 additions & 71 deletions src/shammodels/ramses/src/modules/AMRGridRefinementHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,28 @@
*
*/

#include "shammodels/ramses/modules/AMRGridRefinementHandler.hpp"
#include "shambase/DistributedData.hpp"
#include "shambase/aliases_int.hpp"
#include "shambase/memory.hpp"
#include "shamalgs/details/algorithm/algorithm.hpp"
#include "shambackends/DeviceBuffer.hpp"
#include "shambackends/DeviceQueue.hpp"
#include "shambackends/EventList.hpp"
#include "shamcomm/logs.hpp"
#include "shammodels/ramses/modules/AMRGridRefinementHandler.hpp"
#include "shammodels/ramses/modules/AMRSortBlocks.hpp"
#include "shamsys/NodeInstance.hpp"
#include <shambackends/sycl.hpp>
#include <algorithm>
#include <stdexcept>
#include <utility>

template<class Tvec, class TgridVec>
template<class UserAcc, class... T>
void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>::
gen_refine_block_changes(
shambase::DistributedData<sham::DeviceBuffer<u32>> &refine_list,
shambase::DistributedData<sham::DeviceBuffer<u32>> &derefine_list,
shambase::DistributedData<sham::DeviceBuffer<u32>> &refine_flags,
shambase::DistributedData<sham::DeviceBuffer<u32>> &derefine_flags,
T &&...args) {

using namespace shamrock::patch;
Expand All @@ -42,16 +52,16 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>:
// create the refine and derefine flags buffers
u32 obj_cnt = pdat.get_obj_cnt();

sham::DeviceBuffer<u32> refine_flags(obj_cnt, dev_sched);
sham::DeviceBuffer<u32> derefine_flags(obj_cnt, dev_sched);
sham::DeviceBuffer<u32> refine_flag(obj_cnt, dev_sched);
sham::DeviceBuffer<u32> derefine_flag(obj_cnt, dev_sched);

{
sham::EventList depends_list;

UserAcc uacc(depends_list, id_patch, cur_p, pdat, args...);

auto refine_acc = refine_flags.get_write_access(depends_list);
auto derefine_acc = derefine_flags.get_write_access(depends_list);
auto refine_acc = refine_flag.get_write_access(depends_list);
auto derefine_acc = derefine_flag.get_write_access(depends_list);

// fill in the flags
auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
Expand All @@ -73,87 +83,145 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>:
sham::EventList resulting_events;
resulting_events.add_event(e);

refine_flags.complete_event_state(resulting_events);
derefine_flags.complete_event_state(resulting_events);
refine_flag.complete_event_state(resulting_events);
derefine_flag.complete_event_state(resulting_events);

uacc.finalize(resulting_events, id_patch, cur_p, pdat, args...);
}

sham::DeviceBuffer<TgridVec> &buf_cell_min = pdat.get_field_buf_ref<TgridVec>(0);
sham::DeviceBuffer<TgridVec> &buf_cell_max = pdat.get_field_buf_ref<TgridVec>(1);
refine_flags.add_obj(id_patch, std::move(refine_flag));
derefine_flags.add_obj(id_patch, std::move(derefine_flag));
});
}

sham::EventList depends_list;
auto acc_min = buf_cell_min.get_read_access(depends_list);
auto acc_max = buf_cell_max.get_read_access(depends_list);
auto acc_merge_flag = derefine_flags.get_write_access(depends_list);
/**
* @brief check and enforce 2:1 rule for refinement
* @tparam Tvec
* @tparam TgridVec
* @param refine_list refinement mask
* @param refine_idx_list refinement map
*/
template<class Tvec, class TgridVec>
void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>::
enforce_two_to_one_for_refinement(
shambase::DistributedData<sham::DeviceBuffer<u32>> &&refine_flags,
shambase::DistributedData<sham::DeviceBuffer<u32>> &refine_list) {

// keep only derefine flags on only if the eight cells want to merge and if they can
auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) {
u32 id = gid.get_linear_id();
using namespace shamrock::patch;
using AMRGraph = shammodels::basegodunov::modules::AMRGraph;
using Direction_ = shammodels::basegodunov::modules::Direction;
using AMRGraphLinkiterator = shammodels::basegodunov::modules::AMRGraph::ro_access;
using TgridUint = typename std::make_unsigned<shambase::VecComponent<TgridVec>>::type;

std::array<BlockCoord, split_count> blocks;
bool do_merge = true;
u64 tot_refine = 0;

// This avoid the case where we are in the last block of the buffer to avoid the
// out-of-bound read
if (id + split_count <= obj_cnt) {
bool all_want_to_merge = true;
sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue();
auto dev_sched = shamsys::instance::get_compute_scheduler_ptr();
scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) {
u64 id_patch = cur_p.id_patch;

for (u32 lid = 0; lid < split_count; lid++) {
blocks[lid] = BlockCoord{acc_min[gid + lid], acc_max[gid + lid]};
all_want_to_merge = all_want_to_merge && acc_merge_flag[gid + lid];
sham::DeviceBuffer<u32> &refine_flags_buf = refine_list.get(id_patch);
u32 obj_cnt = pdat.get_obj_cnt();

// blocks graph in each direction for the current patch
AMRGraph &block_graph_neighs_xp = shambase::get_check_ref(storage.block_graph_edge)
.get_refs_dir(Direction_::xp)
.get(id_patch);
AMRGraph &block_graph_neighs_xm = shambase::get_check_ref(storage.block_graph_edge)
.get_refs_dir(Direction_::xm)
.get(id_patch);
AMRGraph &block_graph_neighs_yp = shambase::get_check_ref(storage.block_graph_edge)
.get_refs_dir(Direction_::yp)
.get(id_patch);
AMRGraph &block_graph_neighs_ym = shambase::get_check_ref(storage.block_graph_edge)
.get_refs_dir(Direction_::ym)
.get(id_patch);
AMRGraph &block_graph_neighs_zp = shambase::get_check_ref(storage.block_graph_edge)
.get_refs_dir(Direction_::zp)
.get(id_patch);
AMRGraph &block_graph_neighs_zm = shambase::get_check_ref(storage.block_graph_edge)
.get_refs_dir(Direction_::zm)
.get(id_patch);

// get levels in the current patch
sham::DeviceBuffer<TgridUint> &buf_amr_block_levels
= shambase::get_check_ref(storage.amr_block_levels).get_buf(id_patch);

// propagate refinement until stability
for (auto pass = 0; pass < 3; pass++) {

sham::EventList depend_list;
AMRGraphLinkiterator block_graph_xp
= block_graph_neighs_xp.get_read_access(depend_list);
AMRGraphLinkiterator block_graph_xm
= block_graph_neighs_xm.get_read_access(depend_list);
AMRGraphLinkiterator block_graph_yp
= block_graph_neighs_yp.get_read_access(depend_list);
AMRGraphLinkiterator block_graph_ym
= block_graph_neighs_ym.get_read_access(depend_list);
AMRGraphLinkiterator block_graph_zp
= block_graph_neighs_zp.get_read_access(depend_list);
AMRGraphLinkiterator block_graph_zm
= block_graph_neighs_zm.get_read_access(depend_list);
auto acc_amr_levels = buf_amr_block_levels.get_read_access(depend_list);
auto acc_ref_flags = refine_flags_buf.get_write_access(depend_list);

auto e_all_dir = q.submit(depend_list, [&](sycl::handler &cgh) {
cgh.parallel_for(sycl::range<1>(obj_cnt), [=](sycl::item<1> gid) {
u32 block_id = gid.get_linear_id();

// get refinement flag and amr level of the current block
u32 cur_ref_flag = acc_ref_flags[block_id];
auto cur_block_level = acc_amr_levels[block_id];

if (cur_ref_flag) {
auto enforce_2_to_1_rule = [&](u32 neigh_block_id) {
if (0 <= neigh_block_id && neigh_block_id < obj_cnt) {
auto neigh_block_level = acc_amr_levels[neigh_block_id];
if (cur_block_level > neigh_block_level) {
sycl::atomic_ref<
u32,
sycl::memory_order::relaxed,
sycl::memory_scope::device>
atomic_flag(acc_ref_flags[neigh_block_id]);
atomic_flag.store(1);
}
}
};

block_graph_xp.for_each_object_link(block_id, enforce_2_to_1_rule);
block_graph_xm.for_each_object_link(block_id, enforce_2_to_1_rule);
block_graph_yp.for_each_object_link(block_id, enforce_2_to_1_rule);
block_graph_ym.for_each_object_link(block_id, enforce_2_to_1_rule);
block_graph_zp.for_each_object_link(block_id, enforce_2_to_1_rule);
block_graph_zm.for_each_object_link(block_id, enforce_2_to_1_rule);
}

do_merge = all_want_to_merge && BlockCoord::are_mergeable(blocks);

} else {
do_merge = false;
}

acc_merge_flag[gid] = do_merge;
});
});
});

buf_cell_min.complete_event_state(e);
buf_cell_max.complete_event_state(e);
derefine_flags.complete_event_state(e);
block_graph_neighs_xp.complete_event_state(e_all_dir);
block_graph_neighs_xm.complete_event_state(e_all_dir);
block_graph_neighs_yp.complete_event_state(e_all_dir);
block_graph_neighs_ym.complete_event_state(e_all_dir);
block_graph_neighs_zp.complete_event_state(e_all_dir);
block_graph_neighs_zm.complete_event_state(e_all_dir);
buf_amr_block_levels.complete_event_state(e_all_dir);
refine_flags_buf.complete_event_state(e_all_dir);
}

////////////////////////////////////////////////////////////////////////////////
// refinement
////////////////////////////////////////////////////////////////////////////////

// perform stream compactions on the refinement flags
auto buf_refine = shamalgs::numeric::stream_compact(dev_sched, refine_flags, obj_cnt);

auto buf_refine = shamalgs::numeric::stream_compact(dev_sched, refine_flags_buf, obj_cnt);
shamlog_debug_ln(
"AMRGrid", "patch ", id_patch, "refine block count = ", buf_refine.get_size());

"AMRGrid", "patch ", id_patch, buf_refine.get_size(), "marked for refinement + 2:1");
tot_refine += buf_refine.get_size();

// add the results to the map
refine_list.add_obj(id_patch, std::move(buf_refine));

////////////////////////////////////////////////////////////////////////////////
// derefinement
////////////////////////////////////////////////////////////////////////////////

// perform stream compactions on the derefinement flags
auto buf_derefine = shamalgs::numeric::stream_compact(dev_sched, derefine_flags, obj_cnt);

shamlog_debug_ln(
"AMRGrid", "patch ", id_patch, "merge block count = ", buf_derefine.get_size());

tot_derefine += buf_derefine.get_size();

// add the results to the map
derefine_list.add_obj(id_patch, std::move(buf_derefine));
});

logger::info_ln("AMRGrid", "on this process", tot_refine, "blocks were refined");
logger::info_ln(
"AMRGrid", "on this process", tot_derefine * split_count, "blocks were derefined");
logger::info_ln("AMRGrid", "on this process", tot_refine, "blocks will be refined");
}

template<class Tvec, class TgridVec>
template<class UserAcc>
bool shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>::
Expand Down Expand Up @@ -368,10 +436,16 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>:
block_sorter.reorder_amr_blocks();

// get refine and derefine list
shambase::DistributedData<sham::DeviceBuffer<u32>> refine_flags;
shambase::DistributedData<sham::DeviceBuffer<u32>> derefine_flags;

shambase::DistributedData<sham::DeviceBuffer<u32>> refine_list;
shambase::DistributedData<sham::DeviceBuffer<u32>> derefine_list;

gen_refine_block_changes<UserAccCrit>(refine_list, derefine_list);
gen_refine_block_changes<UserAccCrit>(refine_flags, derefine_flags);

///// enforce 2:1 for refinement ///////
enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list);

//////// apply refine ////////
// Note that this only add new blocks at the end of the patchdata
Expand All @@ -382,7 +456,8 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>:
// This is ok to call straight after the refine without edditing the index list in derefine_list
// since no permutations were applied in internal_refine_grid and no cells can be both refined
// and derefined in the same pass
internal_derefine_grid<UserAccMerge>(std::move(derefine_list));

// internal_derefine_grid<UserAccMerge>(std::move(derefine_list));
}

template<class Tvec, class TgridVec>
Expand Down Expand Up @@ -633,11 +708,17 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>:
Tscal dxfact(solver_config.grid_coord_to_pos_fact);

// get refine and derefine list
shambase::DistributedData<sham::DeviceBuffer<u32>> refine_flags;
shambase::DistributedData<sham::DeviceBuffer<u32>> derefine_flags;

shambase::DistributedData<sham::DeviceBuffer<u32>> refine_list;
shambase::DistributedData<sham::DeviceBuffer<u32>> derefine_list;
// shambase::DistributedData<sham::DeviceBuffer<u32>> derefine_list;

gen_refine_block_changes<RefineCritBlock>(
refine_list, derefine_list, dxfact, cfg->crit_mass);
refine_flags, derefine_flags, dxfact, cfg->crit_mass);

///// enforce 2:1 for refinement ///////
enforce_two_to_one_for_refinement(std::move(refine_flags), refine_list);

//////// apply refine ////////
// Note that this only add new blocks at the end of the patchdata
Expand All @@ -648,7 +729,8 @@ void shammodels::basegodunov::modules::AMRGridRefinementHandler<Tvec, TgridVec>:
// This is ok to call straight after the refine without edditing the index list in
// derefine_list since no permutations were applied in internal_refine_grid and no cells can
// be both refined and derefined in the same pass
bool change_derefine = internal_derefine_grid<RefineCellAccessor>(std::move(derefine_list));
bool change_derefine = false;
// internal_derefine_grid<RefineCellAccessor>(std::move(derefine_list));

has_cell_order_changed = has_cell_order_changed || (change_refine || change_derefine);
}
Expand Down
Loading