diff --git a/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp b/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp index 118dadc459..8393a2802b 100644 --- a/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp +++ b/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp @@ -35,7 +35,11 @@ #include "shamrock/solvergraph/FieldRefs.hpp" #include "shamrock/solvergraph/Indexes.hpp" #include "shamrock/solvergraph/ScalarsEdge.hpp" +#include "shamrock/solvergraph/SolverGraph.hpp" #include "shamsys/legacy/log.hpp" + +// GSPH solvergraph edges +#include "shammodels/gsph/solvergraph/MergedPatchDataEdge.hpp" #include "shamtree/CompressedLeafBVH.hpp" #include "shamtree/KarrasRadixTreeField.hpp" #include "shamtree/RadixTree.hpp" @@ -70,6 +74,8 @@ namespace shammodels::gsph { using RTree = shamtree::CompressedLeafBVH; + shamrock::solvergraph::SolverGraph solver_graph; + /// Particle counts per patch std::shared_ptr> part_counts; std::shared_ptr> part_counts_with_ghost; @@ -91,8 +97,8 @@ namespace shammodels::gsph { Component ghost_handler; Component ghost_patch_cache; - /// Merged position-h data for neighbor search - Component> merged_xyzh; + /// Merged position-h data for neighbor search - managed via SolverGraph + std::shared_ptr merged_xyzh; /// Radix trees for neighbor search Component> merged_pos_trees; @@ -105,8 +111,9 @@ namespace shammodels::gsph { /// Ghost data layout and merged data std::shared_ptr xyzh_ghost_layout; std::shared_ptr ghost_layout; - Component> - merged_patchdata_ghost; + + /// Merged patchdata including all ghost fields - managed via SolverGraph + std::shared_ptr merged_patchdata_ghost; /// Density field computed via SPH summation std::shared_ptr> density; diff --git a/src/shammodels/gsph/include/shammodels/gsph/solvergraph/MergedPatchDataEdge.hpp b/src/shammodels/gsph/include/shammodels/gsph/solvergraph/MergedPatchDataEdge.hpp new file mode 100644 index 0000000000..1cf31978cf --- /dev/null +++ b/src/shammodels/gsph/include/shammodels/gsph/solvergraph/MergedPatchDataEdge.hpp @@ -0,0 +1,43 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2026 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file MergedPatchDataEdge.hpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @brief SolverGraph edge for merged PatchDataLayer + */ + +#include "shambase/DistributedData.hpp" +#include "shambase/memory.hpp" +#include "shamrock/patch/PatchDataLayer.hpp" +#include "shamrock/solvergraph/IEdgeNamed.hpp" + +namespace shammodels::gsph::solvergraph { + + /// SolverGraph edge for merged PatchDataLayer storage (local + ghost particles) + class MergedPatchDataEdge : public shamrock::solvergraph::IEdgeNamed { + public: + using IEdgeNamed::IEdgeNamed; + + shambase::DistributedData data; + + shamrock::patch::PatchDataLayer &get(u64 id) { return data.get(id); } + const shamrock::patch::PatchDataLayer &get(u64 id) const { return data.get(id); } + + shambase::DistributedData &get_data() { return data; } + const shambase::DistributedData &get_data() const { + return data; + } + + inline virtual void free_alloc() override { data = {}; } + }; + +} // namespace shammodels::gsph::solvergraph diff --git a/src/shammodels/gsph/src/Solver.cpp b/src/shammodels/gsph/src/Solver.cpp index 6ff34dbc3f..6c9c09f8ab 100644 --- a/src/shammodels/gsph/src/Solver.cpp +++ b/src/shammodels/gsph/src/Solver.cpp @@ -87,6 +87,14 @@ void shammodels::gsph::Solver::init_solver_graph() { storage.neigh_cache = std::make_shared(edges::neigh_cache, "neigh"); + // Register merged patchdata edges for dependency tracking + storage.merged_xyzh = storage.solver_graph.register_edge( + "merged_xyzh", solvergraph::MergedPatchDataEdge("merged_xyzh", "\\mathbf{xyzh}_{\\rm m}")); + + storage.merged_patchdata_ghost = storage.solver_graph.register_edge( + "merged_patchdata_ghost", + solvergraph::MergedPatchDataEdge("merged_patchdata_ghost", "\\mathbb{U}_{\\rm ghost}")); + storage.omega = std::make_shared>(1, "omega", "\\Omega"); storage.density = std::make_shared>(1, "density", "\\rho"); storage.pressure = std::make_shared>(1, "pressure", "P"); @@ -190,8 +198,8 @@ template class Kern> void shammodels::gsph::Solver::merge_position_ghost() { StackEntry stack_loc{}; - storage.merged_xyzh.set( - storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get())); + shambase::get_check_ref(storage.merged_xyzh).data + = (storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get())); // Get field indices from xyzh_ghost_layout const u32 ixyz_ghost @@ -201,39 +209,45 @@ void shammodels::gsph::Solver::merge_position_ghost() { // Set element counts shambase::get_check_ref(storage.part_counts).indexes - = storage.merged_xyzh.get().template map( - [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) { - return scheduler().patch_data.get_pdat(id).get_obj_cnt(); - }); + = shambase::get_check_ref(storage.merged_xyzh) + .get_data() + .template map([&](u64 id, shamrock::patch::PatchDataLayer &mpdat) { + return scheduler().patch_data.get_pdat(id).get_obj_cnt(); + }); // Set element counts with ghost shambase::get_check_ref(storage.part_counts_with_ghost).indexes - = storage.merged_xyzh.get().template map( - [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) { - return mpdat.get_obj_cnt(); - }); + = shambase::get_check_ref(storage.merged_xyzh) + .get_data() + .template map([&](u64 id, shamrock::patch::PatchDataLayer &mpdat) { + return mpdat.get_obj_cnt(); + }); // Attach spans to block coords shambase::get_check_ref(storage.positions_with_ghosts) .set_refs( - storage.merged_xyzh.get().template map>>( - [&, ixyz_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) { - return std::ref(mpdat.get_field(ixyz_ghost)); - })); + shambase::get_check_ref(storage.merged_xyzh) + .get_data() + .template map>>( + [&, ixyz_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) { + return std::ref(mpdat.get_field(ixyz_ghost)); + })); shambase::get_check_ref(storage.hpart_with_ghosts) .set_refs( - storage.merged_xyzh.get().template map>>( - [&, ihpart_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) { - return std::ref(mpdat.get_field(ihpart_ghost)); - })); + shambase::get_check_ref(storage.merged_xyzh) + .get_data() + .template map>>( + [&, ihpart_ghost](u64 id, shamrock::patch::PatchDataLayer &mpdat) { + return std::ref(mpdat.get_field(ihpart_ghost)); + })); } template class Kern> void shammodels::gsph::Solver::build_merged_pos_trees() { StackEntry stack_loc{}; - auto &merged_xyzh = storage.merged_xyzh.get(); + auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data(); auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); // Get field index from xyzh_ghost_layout @@ -278,7 +292,7 @@ template class Kern> void shammodels::gsph::Solver::compute_presteps_rint() { StackEntry stack_loc{}; - auto &xyzh_merged = storage.merged_xyzh.get(); + auto &xyzh_merged = shambase::get_check_ref(storage.merged_xyzh).get_data(); auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); storage.rtree_rint_field.set( @@ -324,7 +338,7 @@ void shammodels::gsph::Solver::start_neighbors_cache() { // Build neighbor cache using tree traversal - same approach as SPH module auto build_neigh_cache = [&](u64 patch_id) -> shamrock::tree::ObjectCache { - auto &mfield = storage.merged_xyzh.get().get(patch_id); + auto &mfield = shambase::get_check_ref(storage.merged_xyzh).get_data().get(patch_id); sham::DeviceBuffer &buf_xyz = mfield.template get_field_buf_ref(0); sham::DeviceBuffer &buf_hpart = mfield.template get_field_buf_ref(1); @@ -757,8 +771,8 @@ void shammodels::gsph::Solver::communicate_merge_ghosts_fields() { }); // Merge local and ghost data - storage.merged_patchdata_ghost.set( - ghost_handle.template merge_native( + shambase::get_check_ref(storage.merged_patchdata_ghost).data + = (ghost_handle.template merge_native( std::move(interf_pdat), [&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) { PatchDataLayer pdat_new(ghost_layout_ptr); @@ -803,7 +817,7 @@ void shammodels::gsph::Solver::communicate_merge_ghosts_fields() { template class Kern> void shammodels::gsph::Solver::reset_merge_ghosts_fields() { - storage.merged_patchdata_ghost.reset(); + shambase::get_check_ref(storage.merged_patchdata_ghost).free_alloc(); } template class Kern> @@ -856,7 +870,7 @@ void shammodels::gsph::Solver::compute_omega() { // 3. If h grows beyond tolerance, signal for cache rebuild // ========================================================================= - auto &merged_xyzh = storage.merged_xyzh.get(); + auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data(); // Create field references for the iteration module // Position spans (from merged xyzh) @@ -1129,73 +1143,77 @@ void shammodels::gsph::Solver::compute_eos_fields() { soundspeed_field.ensure_sizes(counts_with_ghosts); // Iterate over merged_patchdata_ghost (includes local + ghost particles) - storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) { - u32 total_elements - = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id); - if (total_elements == 0) - return; - - // Use SPH-summation density from communicated ghost data - sham::DeviceBuffer &buf_density = mpdat.get_field_buf_ref(idensity_interf); - auto &pressure_buf = pressure_field.get_field(id).get_buf(); - auto &soundspeed_buf = soundspeed_field.get_field(id).get_buf(); + shambase::get_check_ref(storage.merged_patchdata_ghost) + .get_data() + .for_each([&](u64 id, PatchDataLayer &mpdat) { + u32 total_elements + = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id); + if (total_elements == 0) + return; - sham::DeviceQueue &q = dev_sched->get_queue(); - sham::EventList depends_list; + // Use SPH-summation density from communicated ghost data + sham::DeviceBuffer &buf_density + = mpdat.get_field_buf_ref(idensity_interf); + auto &pressure_buf = pressure_field.get_field(id).get_buf(); + auto &soundspeed_buf = soundspeed_field.get_field(id).get_buf(); - auto density = buf_density.get_read_access(depends_list); - auto pressure = pressure_buf.get_write_access(depends_list); - auto soundspeed = soundspeed_buf.get_write_access(depends_list); + sham::DeviceQueue &q = dev_sched->get_queue(); + sham::EventList depends_list; - const Tscal *uint_ptr = nullptr; - if (has_uint) { - uint_ptr = mpdat.get_field_buf_ref(iuint_interf).get_read_access(depends_list); - } + auto density = buf_density.get_read_access(depends_list); + auto pressure = pressure_buf.get_write_access(depends_list); + auto soundspeed = soundspeed_buf.get_write_access(depends_list); - auto e = q.submit(depends_list, [&](sycl::handler &cgh) { - shambase::parallel_for(cgh, total_elements, "compute_eos_gsph", [=](u64 gid) { - u32 i = (u32) gid; + const Tscal *uint_ptr = nullptr; + if (has_uint) { + uint_ptr + = mpdat.get_field_buf_ref(iuint_interf).get_read_access(depends_list); + } - // Use SPH-summation density (from compute_omega, communicated to ghosts) - Tscal rho = density[i]; - rho = sycl::max(rho, Tscal(1e-30)); - - if (has_uint && uint_ptr != nullptr) { - // Adiabatic EOS (reference: g_pre_interaction.cpp line 107) - // P = (\gamma - 1) * \rho * u - Tscal u = uint_ptr[i]; - u = sycl::max(u, Tscal(1e-30)); - Tscal P = (gamma - Tscal(1.0)) * rho * u; - - // Sound speed from internal energy (reference: solver.cpp line 2661) - // c = sqrt(\gamma * (\gamma - 1) * u) - Tscal cs = sycl::sqrt(gamma * (gamma - Tscal(1.0)) * u); - - // Clamp to reasonable values - P = sycl::clamp(P, Tscal(1e-30), Tscal(1e30)); - cs = sycl::clamp(cs, Tscal(1e-10), Tscal(1e10)); - - pressure[i] = P; - soundspeed[i] = cs; - } else { - // Isothermal case - Tscal cs = Tscal(1.0); - Tscal P = cs * cs * rho; - - pressure[i] = P; - soundspeed[i] = cs; - } + auto e = q.submit(depends_list, [&](sycl::handler &cgh) { + shambase::parallel_for(cgh, total_elements, "compute_eos_gsph", [=](u64 gid) { + u32 i = (u32) gid; + + // Use SPH-summation density (from compute_omega, communicated to ghosts) + Tscal rho = density[i]; + rho = sycl::max(rho, Tscal(1e-30)); + + if (has_uint && uint_ptr != nullptr) { + // Adiabatic EOS (reference: g_pre_interaction.cpp line 107) + // P = (\gamma - 1) * \rho * u + Tscal u = uint_ptr[i]; + u = sycl::max(u, Tscal(1e-30)); + Tscal P = (gamma - Tscal(1.0)) * rho * u; + + // Sound speed from internal energy (reference: solver.cpp line 2661) + // c = sqrt(\gamma * (\gamma - 1) * u) + Tscal cs = sycl::sqrt(gamma * (gamma - Tscal(1.0)) * u); + + // Clamp to reasonable values + P = sycl::clamp(P, Tscal(1e-30), Tscal(1e30)); + cs = sycl::clamp(cs, Tscal(1e-10), Tscal(1e10)); + + pressure[i] = P; + soundspeed[i] = cs; + } else { + // Isothermal case + Tscal cs = Tscal(1.0); + Tscal P = cs * cs * rho; + + pressure[i] = P; + soundspeed[i] = cs; + } + }); }); - }); - // Complete all buffer event states - buf_density.complete_event_state(e); - if (has_uint) { - mpdat.get_field_buf_ref(iuint_interf).complete_event_state(e); - } - pressure_buf.complete_event_state(e); - soundspeed_buf.complete_event_state(e); - }); + // Complete all buffer event states + buf_density.complete_event_state(e); + if (has_uint) { + mpdat.get_field_buf_ref(iuint_interf).complete_event_state(e); + } + pressure_buf.complete_event_state(e); + soundspeed_buf.complete_event_state(e); + }); } template class Kern> @@ -1309,7 +1327,7 @@ void shammodels::gsph::Solver::compute_gradients() { grad_vy_field.ensure_sizes(counts); grad_vz_field.ensure_sizes(counts); - auto &merged_xyzh = storage.merged_xyzh.get(); + auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data(); auto &neigh_cache = storage.neigh_cache->neigh_cache; static constexpr Tscal Rkern = Kernel::Rkern; @@ -1824,7 +1842,7 @@ shammodels::gsph::TimestepLog shammodels::gsph::Solver::evolve_once( reset_presteps_rint(); clear_merged_pos_trees(); reset_merge_ghosts_fields(); - storage.merged_xyzh.reset(); + shambase::get_check_ref(storage.merged_xyzh).free_alloc(); clear_ghost_cache(); reset_serial_patch_tree(); reset_ghost_handler(); diff --git a/src/shammodels/gsph/src/modules/UpdateDerivs.cpp b/src/shammodels/gsph/src/modules/UpdateDerivs.cpp index fe09dd66c1..c1efbb2bf1 100644 --- a/src/shammodels/gsph/src/modules/UpdateDerivs.cpp +++ b/src/shammodels/gsph/src/modules/UpdateDerivs.cpp @@ -84,9 +84,10 @@ void shammodels::gsph::modules::UpdateDerivs::update_derivs_ite = has_uint ? ghost_layout.get_field_idx(gsph::names::newtonian::uint) : 0; // Get merged data and caches from storage - auto &merged_xyzh = storage.merged_xyzh.get(); - shamrock::solvergraph::Field &omega_field = shambase::get_check_ref(storage.omega); - shambase::DistributedData &mpdats = storage.merged_patchdata_ghost.get(); + auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data(); + shamrock::solvergraph::Field &omega_field = shambase::get_check_ref(storage.omega); + shambase::DistributedData &mpdats + = shambase::get_check_ref(storage.merged_patchdata_ghost).get_data(); // Get pressure and soundspeed from storage (includes ghosts) shamrock::solvergraph::Field &pressure_field = shambase::get_check_ref(storage.pressure); @@ -300,9 +301,10 @@ void shammodels::gsph::modules::UpdateDerivs::update_derivs_hll u32 iuint_interf = has_uint ? ghost_layout.get_field_idx(gsph::names::newtonian::uint) : 0; - auto &merged_xyzh = storage.merged_xyzh.get(); - shamrock::solvergraph::Field &omega_field = shambase::get_check_ref(storage.omega); - shambase::DistributedData &mpdats = storage.merged_patchdata_ghost.get(); + auto &merged_xyzh = shambase::get_check_ref(storage.merged_xyzh).get_data(); + shamrock::solvergraph::Field &omega_field = shambase::get_check_ref(storage.omega); + shambase::DistributedData &mpdats + = shambase::get_check_ref(storage.merged_patchdata_ghost).get_data(); // Get pressure and soundspeed from storage (includes ghosts) shamrock::solvergraph::Field &pressure_field = shambase::get_check_ref(storage.pressure);