From ce709a7f3f2163296280496777bedc4cc13bb79b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 24 Dec 2025 15:13:50 +0100 Subject: [PATCH 01/54] [Algs] Implement a inplace variant of sparse comm --- .../shamalgs/collective/RequestList.hpp | 74 +++++++ .../shamalgs/collective/sparseXchg.hpp | 8 + .../shamalgs/collective/sparse_exchange.hpp | 54 +++++ src/shamalgs/src/collective/sparseXchg.cpp | 100 +++------- .../src/collective/sparse_exchange.cpp | 185 ++++++++++++++++++ 5 files changed, 351 insertions(+), 70 deletions(-) create mode 100644 src/shamalgs/include/shamalgs/collective/RequestList.hpp create mode 100644 src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp create mode 100644 src/shamalgs/src/collective/sparse_exchange.cpp diff --git a/src/shamalgs/include/shamalgs/collective/RequestList.hpp b/src/shamalgs/include/shamalgs/collective/RequestList.hpp new file mode 100644 index 0000000000..784ac5a2b8 --- /dev/null +++ b/src/shamalgs/include/shamalgs/collective/RequestList.hpp @@ -0,0 +1,74 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file RequestList.hpp + * @author Timothée David--Cléris (tim.shamrock@proton.me) + * @brief + * + */ + +#include "shambase/narrowing.hpp" +#include "shambase/time.hpp" +#include "shamcomm/wrapper.hpp" +#include + +namespace shamalgs::collective { + + class RequestList { + + std::vector rqs; + std::vector is_ready; + + size_t ready_count = 0; + + public: + MPI_Request &new_request() { + rqs.push_back(MPI_Request{}); + size_t rq_index = rqs.size() - 1; + auto &rq = rqs[rq_index]; + is_ready.push_back(false); + return rq; + } + + size_t size() { return rqs.size(); } + bool is_event_ready(size_t i) { return is_ready[i]; } + std::vector &requests() { return rqs; } + + void test_ready() { + for (u32 i = 0; i < rqs.size(); i++) { + if (!is_ready[i]) { + MPI_Status st; + int ready; + shamcomm::mpi::Test(&rqs[i], &ready, MPI_STATUS_IGNORE); + if (ready) { + is_ready[i] = true; + ready_count++; + } + } + } + } + + bool all_ready() { return ready_count == rqs.size(); } + + void wait_all() { + std::vector st_lst(rqs.size()); + shamcomm::mpi::Waitall( + shambase::narrow_or_throw(rqs.size()), rqs.data(), st_lst.data()); + } + + size_t remain_count() { + test_ready(); + return rqs.size() - ready_count; + } + }; + +} // namespace shamalgs::collective diff --git a/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp b/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp index 29ab8b1029..c901472ea1 100644 --- a/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp +++ b/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp @@ -43,6 +43,14 @@ namespace shamalgs::collective { + struct CommMessage { + sham::DeviceBuffer &bytebuffer; + size_t message_offset; + size_t message_size; + i32 sender_rank; + i32 receiver_rank; + }; + struct SendPayload { i32 receiver_rank; std::unique_ptr payload; diff --git a/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp b/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp new file mode 100644 index 0000000000..6f98a15cf2 --- /dev/null +++ b/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp @@ -0,0 +1,54 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file sparse_exchange.hpp + * @author Timothée David--Cléris (tim.shamrock@proton.me) + * @brief + * + */ + +#include "shambackends/DeviceBuffer.hpp" +#include +#include + +namespace shamalgs::collective { + + struct CommMessageInfo { + size_t message_size; ///< Size of the MPI message + i32 rank_sender; ///< Rank of the sender + i32 rank_receiver; ///< Rank of the receiver + std::optional message_tag; ///< Tag of the MPI message + std::optional + message_bytebuf_offset_send; ///< Offset of the MPI message in the send buffer + std::optional + message_bytebuf_offset_recv; ///< Offset of the MPI message in the recv buffer + }; + + struct CommTable { + std::vector messages_send; ///< Messages to send + std::vector message_all; ///< All messages = (allgatherv of messages_send) + std::vector messages_recv; ///< Messages to recv + std::vector send_message_global_ids; ///< ids of messages_send in message_all + std::vector recv_message_global_ids; ///< ids of messages_recv in message_all + size_t send_total_size; ///< Total size of the send buffer + size_t recv_total_size; ///< Total size of the recv buffer + }; + + CommTable build_sparse_exchange_table(const std::vector &messages_send); + + void sparse_exchange( + std::shared_ptr dev_sched, + sham::DeviceBuffer &bytebuffer_send, + sham::DeviceBuffer &bytebuffer_recv, + const CommTable &comm_table); + +} // namespace shamalgs::collective diff --git a/src/shamalgs/src/collective/sparseXchg.cpp b/src/shamalgs/src/collective/sparseXchg.cpp index cc279a0327..31d3bf4b61 100644 --- a/src/shamalgs/src/collective/sparseXchg.cpp +++ b/src/shamalgs/src/collective/sparseXchg.cpp @@ -18,6 +18,7 @@ #include "shambase/exception.hpp" #include "shambase/string.hpp" #include "shambase/time.hpp" +#include "shamalgs/collective/RequestList.hpp" #include "shamcmdopt/env.hpp" #include "shamcomm/logs.hpp" #include "shamcomm/worldInfo.hpp" @@ -87,75 +88,34 @@ namespace { bool is_recv; }; - struct RequestList { - - std::vector rqs; - std::vector is_ready; - - u32 ready_count = 0; - - MPI_Request &new_request() { - rqs.push_back(MPI_Request{}); - u32 rq_index = rqs.size() - 1; - auto &rq = rqs[rq_index]; - is_ready.push_back(false); - return rq; - } - - void test_ready() { - for (u32 i = 0; i < rqs.size(); i++) { - if (!is_ready[i]) { - MPI_Status st; - int ready; - shamcomm::mpi::Test(&rqs[i], &ready, MPI_STATUS_IGNORE); - if (ready) { - is_ready[i] = true; - ready_count++; - } - } - } - } - - bool all_ready() { return ready_count == rqs.size(); } - - void wait_all() { - std::vector st_lst(rqs.size()); - shamcomm::mpi::Waitall(rqs.size(), rqs.data(), st_lst.data()); - } - - u32 remain_count() { - test_ready(); - return rqs.size() - ready_count; - } - }; - - auto report_unfinished_requests = [](RequestList &rqs, std::vector &rqs_infos) { - std::string err_msg = ""; - for (u32 i = 0; i < rqs.rqs.size(); i++) { - if (!rqs.is_ready[i]) { - - if (rqs_infos[i].is_send) { - err_msg += shambase::format( - "communication timeout : send {} -> {} tag {} size {}\n", - rqs_infos[i].sender, - rqs_infos[i].receiver, - rqs_infos[i].tag, - rqs_infos[i].size); - } else { - err_msg += shambase::format( - "communication timeout : recv {} -> {} tag {} size {}\n", - rqs_infos[i].sender, - rqs_infos[i].receiver, - rqs_infos[i].tag, - rqs_infos[i].size); - } - } - } - std::string msg = shambase::format("communication timeout : \n{}", err_msg); - logger::err_ln("Sparse comm", msg); - std::this_thread::sleep_for(std::chrono::seconds(2)); - shambase::throw_with_loc(msg); - }; + auto report_unfinished_requests + = [](shamalgs::collective::RequestList &rqs, std::vector &rqs_infos) { + std::string err_msg = ""; + for (u32 i = 0; i < rqs.size(); i++) { + if (!rqs.is_event_ready(i)) { + + if (rqs_infos[i].is_send) { + err_msg += shambase::format( + "communication timeout : send {} -> {} tag {} size {}\n", + rqs_infos[i].sender, + rqs_infos[i].receiver, + rqs_infos[i].tag, + rqs_infos[i].size); + } else { + err_msg += shambase::format( + "communication timeout : recv {} -> {} tag {} size {}\n", + rqs_infos[i].sender, + rqs_infos[i].receiver, + rqs_infos[i].tag, + rqs_infos[i].size); + } + } + } + std::string msg = shambase::format("communication timeout : \n{}", err_msg); + logger::err_ln("Sparse comm", msg); + std::this_thread::sleep_for(std::chrono::seconds(2)); + shambase::throw_with_loc(msg); + }; auto test_event_completions = [](std::vector &rqs, std::vector &rqs_infos) { @@ -583,7 +543,7 @@ namespace shamalgs::collective { } } - test_event_completions(rqs.rqs, rqs_infos); + test_event_completions(rqs.requests(), rqs_infos); rqs.wait_all(); diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp new file mode 100644 index 0000000000..8c1a4a72b7 --- /dev/null +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -0,0 +1,185 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file sparse_exchange.cpp + * @author Timothée David--Cléris (tim.shamrock@proton.me) + * @brief + * + */ + +#include "shamalgs/collective/sparse_exchange.hpp" +#include "shambase/exception.hpp" +#include "shambase/memory.hpp" +#include "shambase/stacktrace.hpp" +#include "shamalgs/collective/exchanges.hpp" +#include "shambackends/math.hpp" +#include "shamcomm/mpi.hpp" +#include "shamcomm/worldInfo.hpp" +#include + +namespace shamalgs::collective { + + CommMessageInfo unpack(u64_2 comm_info) { + u64 comm_vec = comm_info.x(); + size_t message_size = comm_info.y(); + u32_2 comm_ranks = sham::unpack32(comm_vec); + u32 sender = comm_ranks.x(); + u32 receiver = comm_ranks.y(); + + if (message_size == 0) { + throw shambase::make_except_with_loc(shambase::format( + "Message size is 0 for rank {}, sender = {}, receiver = {}", + shamcomm::world_rank(), + sender, + receiver)); + } + + return CommMessageInfo{ + message_size, + static_cast(sender), + static_cast(receiver), + std::nullopt, + std::nullopt, + std::nullopt}; + }; + + CommTable build_sparse_exchange_table(const std::vector &messages_send) { + __shamrock_stack_entry(); + + //////////////////////////////////////////////////////////// + // Pack the local data then allgatherv to get the global data + //////////////////////////////////////////////////////////// + + std::vector local_data = std::vector(messages_send.size()); + + for (size_t i = 0; i < messages_send.size(); i++) { + u32 sender = static_cast(messages_send[i].rank_sender); + u32 receiver = static_cast(messages_send[i].rank_receiver); + size_t message_size = messages_send[i].message_size; + + if (sender != shamcomm::world_rank()) { + throw shambase::make_except_with_loc(shambase::format( + "You are trying to send a message from a rank that does not posses it\n" + " sender = {}, receiver = {}, world_rank = {}", + sender, + receiver, + shamcomm::world_rank())); + } + + local_data[i] = u64_2{sham::pack32(sender, receiver), message_size}; + } + + std::vector global_data; + vector_allgatherv(local_data, global_data, MPI_COMM_WORLD); + + //////////////////////////////////////////////////////////// + // Unpack the global data and build the global message list + //////////////////////////////////////////////////////////// + + std::vector message_all(global_data.size()); + + std::vector tag_map(shamcomm::world_size(), 0); + + u32 send_idx = 0; + u32 recv_idx = 0; + + size_t recv_offset = 0; + size_t send_offset = 0; + for (u64 i = 0; i < global_data.size(); i++) { + auto message_info = unpack(global_data[i]); + + auto sender = message_info.rank_sender; + auto receiver = message_info.rank_receiver; + + i32 &tag_map_ref = tag_map[static_cast(sender)]; + + i32 tag = tag_map_ref; + tag_map_ref++; + + message_info.message_tag = tag; + + if (sender == shamcomm::world_rank()) { + message_info.message_bytebuf_offset_send = send_offset; + send_offset += message_info.message_size; + send_idx++; + } + + if (receiver == shamcomm::world_rank()) { + message_info.message_bytebuf_offset_recv = recv_offset; + recv_offset += message_info.message_size; + recv_idx++; + } + + message_all[i] = message_info; + } + + //////////////////////////////////////////////////////////// + // now that all comm were computed we can build the send and recv message lists + //////////////////////////////////////////////////////////// + + std::vector ret_message_send(send_idx); + std::vector ret_message_recv(recv_idx); + + std::vector send_message_global_ids(send_idx); + std::vector recv_message_global_ids(recv_idx); + + send_idx = 0; + recv_idx = 0; + + for (size_t i = 0; i < message_all.size(); i++) { + auto message_info = message_all[i]; + if (message_info.rank_sender == shamcomm::world_rank()) { + + // the sender shoudl have set the offset for all messages, otherwise throw + auto expected_offset = shambase::get_check_ref( + messages_send.at(send_idx).message_bytebuf_offset_send); + + // check that the send offset match for good measure + if (message_info.message_bytebuf_offset_send != expected_offset) { + throw shambase::make_except_with_loc(shambase::format( + "The sender has not set the offset for all messages, otherwise throw\n" + " expected_offset = {}, actual_offset = {}", + expected_offset, + message_info.message_bytebuf_offset_send)); + } + + ret_message_send[send_idx] = message_info; + send_message_global_ids[send_idx] = i; + send_idx++; + } + if (message_info.rank_receiver == shamcomm::world_rank()) { + ret_message_recv[recv_idx] = message_info; + recv_message_global_ids[recv_idx] = i; + recv_idx++; + } + } + + return CommTable{ + ret_message_send, + message_all, + ret_message_recv, + send_message_global_ids, + recv_message_global_ids, + send_offset, + recv_offset}; + } + + void sparse_exchange( + std::shared_ptr dev_sched, + sham::DeviceBuffer &bytebuffer_send, + sham::DeviceBuffer &bytebuffer_recv, + const CommTable &comm_table) { + + __shamrock_stack_entry(); + } + +} // namespace shamalgs::collective From df9a29d2e1eade06650d4bd006969b85f4cba30f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 24 Dec 2025 14:17:39 +0000 Subject: [PATCH 02/54] [gh-action] trigger CI with empty commit From 67c05c25bf4a00cd69355dca8163db3a546ab645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 24 Dec 2025 15:43:17 +0100 Subject: [PATCH 03/54] update from stack diff --- .../shamalgs/collective/RequestList.hpp | 21 +++++++++---------- .../shamalgs/collective/sparseXchg.hpp | 8 ------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/RequestList.hpp b/src/shamalgs/include/shamalgs/collective/RequestList.hpp index 784ac5a2b8..ff5a8efcb2 100644 --- a/src/shamalgs/include/shamalgs/collective/RequestList.hpp +++ b/src/shamalgs/include/shamalgs/collective/RequestList.hpp @@ -12,12 +12,11 @@ /** * @file RequestList.hpp * @author Timothée David--Cléris (tim.shamrock@proton.me) - * @brief + * @brief Provides a helper class to manage a list of MPI requests. * */ #include "shambase/narrowing.hpp" -#include "shambase/time.hpp" #include "shamcomm/wrapper.hpp" #include @@ -32,21 +31,18 @@ namespace shamalgs::collective { public: MPI_Request &new_request() { - rqs.push_back(MPI_Request{}); - size_t rq_index = rqs.size() - 1; - auto &rq = rqs[rq_index]; + rqs.emplace_back(); is_ready.push_back(false); - return rq; + return rqs.back(); } - size_t size() { return rqs.size(); } - bool is_event_ready(size_t i) { return is_ready[i]; } + size_t size() const { return rqs.size(); } + bool is_event_ready(size_t i) const { return is_ready[i]; } std::vector &requests() { return rqs; } void test_ready() { - for (u32 i = 0; i < rqs.size(); i++) { + for (size_t i = 0; i < rqs.size(); i++) { if (!is_ready[i]) { - MPI_Status st; int ready; shamcomm::mpi::Test(&rqs[i], &ready, MPI_STATUS_IGNORE); if (ready) { @@ -57,9 +53,12 @@ namespace shamalgs::collective { } } - bool all_ready() { return ready_count == rqs.size(); } + bool all_ready() const { return ready_count == rqs.size(); } void wait_all() { + if (ready_count == rqs.size()) { + return; + } std::vector st_lst(rqs.size()); shamcomm::mpi::Waitall( shambase::narrow_or_throw(rqs.size()), rqs.data(), st_lst.data()); diff --git a/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp b/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp index c901472ea1..29ab8b1029 100644 --- a/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp +++ b/src/shamalgs/include/shamalgs/collective/sparseXchg.hpp @@ -43,14 +43,6 @@ namespace shamalgs::collective { - struct CommMessage { - sham::DeviceBuffer &bytebuffer; - size_t message_offset; - size_t message_size; - i32 sender_rank; - i32 receiver_rank; - }; - struct SendPayload { i32 receiver_rank; std::unique_ptr payload; From d1f7b3d3db6c39ed94a253a6127fc38ef9711b61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 24 Dec 2025 16:17:40 +0100 Subject: [PATCH 04/54] implement more of ti --- .../shamalgs/collective/RequestList.hpp | 50 +++++++++++++++ .../src/collective/sparse_exchange.cpp | 64 +++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/src/shamalgs/include/shamalgs/collective/RequestList.hpp b/src/shamalgs/include/shamalgs/collective/RequestList.hpp index ff5a8efcb2..2e47f1245f 100644 --- a/src/shamalgs/include/shamalgs/collective/RequestList.hpp +++ b/src/shamalgs/include/shamalgs/collective/RequestList.hpp @@ -17,6 +17,8 @@ */ #include "shambase/narrowing.hpp" +#include "shambase/time.hpp" +#include "shamcomm/logs.hpp" #include "shamcomm/wrapper.hpp" #include @@ -64,10 +66,58 @@ namespace shamalgs::collective { shambase::narrow_or_throw(rqs.size()), rqs.data(), st_lst.data()); } + size_t remain_count_no_test() { return rqs.size() - ready_count; } + size_t remain_count() { test_ready(); return rqs.size() - ready_count; } + + void report_timeout() const { + std::string err_msg = ""; + for (size_t i = 0; i < rqs.size(); i++) { + if (!is_ready[i]) { + err_msg += shambase::format("request {} is not ready\n", i); + } + } + std::string msg = shambase::format("timeout : \n{}", err_msg); + throw shambase::make_except_with_loc(msg); + } + + // spin lock until the number of in-flight requests is less than max_in_flight + void spin_lock_partial_wait(size_t max_in_flight, f64 timeout, f64 print_freq) { + + if (rqs.size() < max_in_flight) { + return; + } + + f64 last_print_time = 0; + size_t in_flight = remain_count(); + + if (in_flight < max_in_flight) { + return; + } + + shambase::Timer twait; + twait.start(); + do { + twait.end(); + if (twait.elasped_sec() > timeout) { + report_timeout(); + } + + if (twait.elasped_sec() - last_print_time > print_freq) { + logger::warn_ln( + "SparseComm", + "too many messages in flight :", + in_flight, + "/", + max_in_flight); + last_print_time = twait.elasped_sec(); + } + in_flight = remain_count(); + } while (in_flight >= max_in_flight); + } }; } // namespace shamalgs::collective diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index 8c1a4a72b7..6622bac50c 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -19,8 +19,11 @@ #include "shamalgs/collective/sparse_exchange.hpp" #include "shambase/exception.hpp" #include "shambase/memory.hpp" +#include "shambase/narrowing.hpp" #include "shambase/stacktrace.hpp" +#include "shamalgs/collective/RequestList.hpp" #include "shamalgs/collective/exchanges.hpp" +#include "shambackends/USMPtrHolder.hpp" #include "shambackends/math.hpp" #include "shamcomm/mpi.hpp" #include "shamcomm/worldInfo.hpp" @@ -180,6 +183,67 @@ namespace shamalgs::collective { const CommTable &comm_table) { __shamrock_stack_entry(); + + bool direct_gpu_capable = dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable; + + // TODO: check device loc depending on dgpu support + + u32 SHAM_SPARSE_COMM_INFLIGHT_LIM = 128; // TODO: use the env variable + + if (comm_table.send_total_size < bytebuffer_send.get_size()) { + throw shambase::make_except_with_loc(shambase::format( + "The send total size is greater than the send buffer size\n" + " send_total_size = {}, send_buffer_size = {}", + comm_table.send_total_size, + bytebuffer_send.get_size())); + } + + if (comm_table.recv_total_size < bytebuffer_recv.get_size()) { + throw shambase::make_except_with_loc(shambase::format( + "The recv total size is greater than the recv buffer size\n" + " recv_total_size = {}, recv_buffer_size = {}", + comm_table.recv_total_size, + bytebuffer_recv.get_size())); + } + + sham::EventList depends_list; + const u8 *send_ptr = bytebuffer_send.get_read_access(depends_list); + u8 *recv_ptr = bytebuffer_recv.get_write_access(depends_list); + depends_list.wait(); + bytebuffer_send.complete_event_state(sycl::event{}); + bytebuffer_recv.complete_event_state(sycl::event{}); + + RequestList rqs; + for (u32 i = 0; i < comm_table.message_all.size(); i++) { + + auto message_info = comm_table.message_all[i]; + + if (message_info.rank_sender == shamcomm::world_rank()) { + auto &rq = rqs.new_request(); + shamcomm::mpi::Isend( + send_ptr + shambase::get_check_ref(message_info.message_bytebuf_offset_send), + shambase::narrow_or_throw(message_info.message_size), + MPI_BYTE, + message_info.rank_receiver, + shambase::get_check_ref(message_info.message_tag), + MPI_COMM_WORLD, + &rq); + } + + if (message_info.rank_receiver == shamcomm::world_rank()) { + auto &rq = rqs.new_request(); + shamcomm::mpi::Irecv( + recv_ptr + shambase::get_check_ref(message_info.message_bytebuf_offset_recv), + shambase::narrow_or_throw(message_info.message_size), + MPI_BYTE, + message_info.rank_sender, + shambase::get_check_ref(message_info.message_tag), + MPI_COMM_WORLD, + &rq); + } + + rqs.spin_lock_partial_wait(SHAM_SPARSE_COMM_INFLIGHT_LIM, 120, 10); + } } } // namespace shamalgs::collective From bda7f5e6a4318f59568fefd9f13c1cdd71bce7ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 24 Dec 2025 16:31:10 +0100 Subject: [PATCH 05/54] noice --- .../shamalgs/collective/sparse_exchange.hpp | 5 +- .../src/collective/sparse_exchange.cpp | 81 ++++++++++++------- 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp b/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp index 6f98a15cf2..08fc6da761 100644 --- a/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp +++ b/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp @@ -45,10 +45,11 @@ namespace shamalgs::collective { CommTable build_sparse_exchange_table(const std::vector &messages_send); + template void sparse_exchange( std::shared_ptr dev_sched, - sham::DeviceBuffer &bytebuffer_send, - sham::DeviceBuffer &bytebuffer_recv, + sham::DeviceBuffer &bytebuffer_send, + sham::DeviceBuffer &bytebuffer_recv, const CommTable &comm_table); } // namespace shamalgs::collective diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index 6622bac50c..f88fe4a757 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -178,41 +178,14 @@ namespace shamalgs::collective { void sparse_exchange( std::shared_ptr dev_sched, - sham::DeviceBuffer &bytebuffer_send, - sham::DeviceBuffer &bytebuffer_recv, + const u8 *bytebuffer_send, + u8 *bytebuffer_recv, const CommTable &comm_table) { __shamrock_stack_entry(); - bool direct_gpu_capable = dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable; - - // TODO: check device loc depending on dgpu support - u32 SHAM_SPARSE_COMM_INFLIGHT_LIM = 128; // TODO: use the env variable - if (comm_table.send_total_size < bytebuffer_send.get_size()) { - throw shambase::make_except_with_loc(shambase::format( - "The send total size is greater than the send buffer size\n" - " send_total_size = {}, send_buffer_size = {}", - comm_table.send_total_size, - bytebuffer_send.get_size())); - } - - if (comm_table.recv_total_size < bytebuffer_recv.get_size()) { - throw shambase::make_except_with_loc(shambase::format( - "The recv total size is greater than the recv buffer size\n" - " recv_total_size = {}, recv_buffer_size = {}", - comm_table.recv_total_size, - bytebuffer_recv.get_size())); - } - - sham::EventList depends_list; - const u8 *send_ptr = bytebuffer_send.get_read_access(depends_list); - u8 *recv_ptr = bytebuffer_recv.get_write_access(depends_list); - depends_list.wait(); - bytebuffer_send.complete_event_state(sycl::event{}); - bytebuffer_recv.complete_event_state(sycl::event{}); - RequestList rqs; for (u32 i = 0; i < comm_table.message_all.size(); i++) { @@ -221,7 +194,8 @@ namespace shamalgs::collective { if (message_info.rank_sender == shamcomm::world_rank()) { auto &rq = rqs.new_request(); shamcomm::mpi::Isend( - send_ptr + shambase::get_check_ref(message_info.message_bytebuf_offset_send), + bytebuffer_send + + shambase::get_check_ref(message_info.message_bytebuf_offset_send), shambase::narrow_or_throw(message_info.message_size), MPI_BYTE, message_info.rank_receiver, @@ -233,7 +207,8 @@ namespace shamalgs::collective { if (message_info.rank_receiver == shamcomm::world_rank()) { auto &rq = rqs.new_request(); shamcomm::mpi::Irecv( - recv_ptr + shambase::get_check_ref(message_info.message_bytebuf_offset_recv), + bytebuffer_recv + + shambase::get_check_ref(message_info.message_bytebuf_offset_recv), shambase::narrow_or_throw(message_info.message_size), MPI_BYTE, message_info.rank_sender, @@ -246,4 +221,48 @@ namespace shamalgs::collective { } } + template + void sparse_exchange( + std::shared_ptr dev_sched, + sham::DeviceBuffer &bytebuffer_send, + sham::DeviceBuffer &bytebuffer_recv, + const CommTable &comm_table) { + + __shamrock_stack_entry(); + + if (comm_table.send_total_size < bytebuffer_send.get_size()) { + throw shambase::make_except_with_loc(shambase::format( + "The send total size is greater than the send buffer size\n" + " send_total_size = {}, send_buffer_size = {}", + comm_table.send_total_size, + bytebuffer_send.get_size())); + } + + if (comm_table.recv_total_size < bytebuffer_recv.get_size()) { + throw shambase::make_except_with_loc(shambase::format( + "The recv total size is greater than the recv buffer size\n" + " recv_total_size = {}, recv_buffer_size = {}", + comm_table.recv_total_size, + bytebuffer_recv.get_size())); + } + + bool direct_gpu_capable = dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable; + + if (!direct_gpu_capable && target == sham::device) { + throw shambase::make_except_with_loc( + "You are trying to use a device buffer on the device but the device is not direct " + "GPU capable"); + } + + sham::EventList depends_list; + const u8 *send_ptr = bytebuffer_send.get_read_access(depends_list); + u8 *recv_ptr = bytebuffer_recv.get_write_access(depends_list); + depends_list.wait(); + + sparse_exchange(dev_sched, send_ptr, recv_ptr, comm_table); + + bytebuffer_send.complete_event_state(sycl::event{}); + bytebuffer_recv.complete_event_state(sycl::event{}); + } + } // namespace shamalgs::collective From 97dc8857d149452352b595aafe69b4559aeabaaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 25 Dec 2025 17:54:33 +0100 Subject: [PATCH 06/54] it works --- src/shamalgs/CMakeLists.txt | 1 + .../src/collective/sparse_exchange.cpp | 18 +- .../collective/sparse_exchange_tests.cpp | 206 ++++++++++++++++++ 3 files changed, 222 insertions(+), 3 deletions(-) create mode 100644 src/tests/shamalgs/collective/sparse_exchange_tests.cpp diff --git a/src/shamalgs/CMakeLists.txt b/src/shamalgs/CMakeLists.txt index 22a57a8c4b..372a27f587 100644 --- a/src/shamalgs/CMakeLists.txt +++ b/src/shamalgs/CMakeLists.txt @@ -42,6 +42,7 @@ set(Sources src/primitives/gen_buffer_index.cpp src/primitives/segmented_sort_in_place.cpp src/primitives/append_subset_to.cpp + src/collective/sparse_exchange.cpp ) if(SHAMROCK_USE_SHARED_LIB) diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index f88fe4a757..b0f1a06fe5 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -7,8 +7,6 @@ // // -------------------------------------------------------// -#pragma once - /** * @file sparse_exchange.cpp * @author Timothée David--Cléris (tim.shamrock@proton.me) @@ -152,7 +150,7 @@ namespace shamalgs::collective { "The sender has not set the offset for all messages, otherwise throw\n" " expected_offset = {}, actual_offset = {}", expected_offset, - message_info.message_bytebuf_offset_send)); + message_info.message_bytebuf_offset_send.value())); } ret_message_send[send_idx] = message_info; @@ -219,6 +217,7 @@ namespace shamalgs::collective { rqs.spin_lock_partial_wait(SHAM_SPARSE_COMM_INFLIGHT_LIM, 120, 10); } + rqs.wait_all(); } template @@ -265,4 +264,17 @@ namespace shamalgs::collective { bytebuffer_recv.complete_event_state(sycl::event{}); } + // template instantiations + template void sparse_exchange( + std::shared_ptr dev_sched, + sham::DeviceBuffer &bytebuffer_send, + sham::DeviceBuffer &bytebuffer_recv, + const CommTable &comm_table); + + template void sparse_exchange( + std::shared_ptr dev_sched, + sham::DeviceBuffer &bytebuffer_send, + sham::DeviceBuffer &bytebuffer_recv, + const CommTable &comm_table); + } // namespace shamalgs::collective diff --git a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp new file mode 100644 index 0000000000..1222fd76f3 --- /dev/null +++ b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp @@ -0,0 +1,206 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#include "shamalgs/collective/sparse_exchange.hpp" +#include "shamalgs/details/random/random.hpp" +#include "shamalgs/primitives/equals.hpp" +#include "shambackends/DeviceBuffer.hpp" +#include "shamcomm/logs.hpp" +#include "shamsys/NodeInstance.hpp" +#include "shamtest/shamtest.hpp" +#include + +namespace { + + struct TestElement { + i32 sender, receiver; + u32 size; + }; + +} // namespace + +void reorder_msg(std::vector &test_elements) { + std::sort(test_elements.begin(), test_elements.end(), [](const auto &lhs, const auto &rhs) { + return lhs.sender + < rhs.sender; //|| (lhs.sender == rhs.sender && lhs.receiver < rhs.receiver); + }); +} + +void test_sparse_exchange(std::vector test_elements) { + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + reorder_msg(test_elements); + + std::vector> all_bufs; + + std::mt19937 eng(0x123); + for (const auto &test_element : test_elements) { + all_bufs.push_back( + shamalgs::random::mock_buffer_usm(dev_sched, eng(), test_element.size)); + } + + sham::DeviceBuffer send_buf(0, dev_sched); + std::vector messages_send; + + size_t total_recv_size = 0; + size_t total_recv_count = 0; + size_t sender_offset = 0; + size_t sender_count = 0; + for (u32 i = 0; i < test_elements.size(); i++) { + if (test_elements[i].sender == shamcomm::world_rank()) { + messages_send.push_back(shamalgs::collective::CommMessageInfo{ + test_elements[i].size, + test_elements[i].sender, + test_elements[i].receiver, + std::nullopt, + sender_offset, + std::nullopt, + }); + + logger::info_ln("sparse exchange test", + "rank :", + shamcomm::world_rank(), + "send message : (", + test_elements[i].sender, + "->", + test_elements[i].receiver, + ") data :", + all_bufs[i].copy_to_stdvec()); + + send_buf.append(all_bufs[i]); + sender_offset += test_elements[i].size; + sender_count++; + } + if (test_elements[i].receiver == shamcomm::world_rank()) { + total_recv_size += test_elements[i].size; + total_recv_count++; + } + } + + shamalgs::collective::CommTable comm_table + = shamalgs::collective::build_sparse_exchange_table(messages_send); + + REQUIRE_EQUAL(comm_table.send_total_size, sender_offset); + REQUIRE_EQUAL(comm_table.recv_total_size, total_recv_size); + + // allocate recv buffer + sham::DeviceBuffer recv_buf(comm_table.recv_total_size, dev_sched); + + // do the comm + if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { + shamalgs::collective::sparse_exchange(dev_sched, send_buf, recv_buf, comm_table); + } else { + auto send_buf_host = send_buf.copy_to(); + auto recv_buf_host = recv_buf.copy_to(); + shamalgs::collective::sparse_exchange(dev_sched, send_buf_host, recv_buf_host, comm_table); + recv_buf.copy_from(recv_buf_host); + } + + // time to check + + size_t send_msg_idx = 0; + size_t recv_msg_idx = 0; + for (u32 i = 0; i < test_elements.size(); i++) { + if (test_elements[i].sender == shamcomm::world_rank()) { + REQUIRE_EQUAL( + comm_table.messages_send[send_msg_idx].message_size, test_elements[i].size); + REQUIRE_EQUAL( + comm_table.messages_send[send_msg_idx].rank_sender, test_elements[i].sender); + REQUIRE_EQUAL( + comm_table.messages_send[send_msg_idx].rank_receiver, test_elements[i].receiver); + + send_msg_idx++; + } + if (test_elements[i].receiver == shamcomm::world_rank()) { + REQUIRE_EQUAL( + comm_table.messages_recv[recv_msg_idx].message_size, test_elements[i].size); + REQUIRE_EQUAL( + comm_table.messages_recv[recv_msg_idx].rank_sender, test_elements[i].sender); + REQUIRE_EQUAL( + comm_table.messages_recv[recv_msg_idx].rank_receiver, test_elements[i].receiver); + + auto &ref_buf = all_bufs[i]; + sham::DeviceBuffer recov(test_elements[i].size, dev_sched); + size_t begin = shambase::get_check_ref( + comm_table.messages_recv[recv_msg_idx].message_bytebuf_offset_recv); + size_t end = begin + test_elements[i].size; + recv_buf.copy_range(begin, end, recov); + + logger::info_ln("sparse exchange test","rank :", shamcomm::world_rank(), "recv message : (", test_elements[i].sender, "->", test_elements[i].receiver, ") data :", recov.copy_to_stdvec()); + + REQUIRE_EQUAL(recov.copy_to_stdvec(), ref_buf.copy_to_stdvec()); + + recv_msg_idx++; + } + REQUIRE_EQUAL(comm_table.message_all[i].message_size, test_elements[i].size); + REQUIRE_EQUAL(comm_table.message_all[i].rank_sender, test_elements[i].sender); + REQUIRE_EQUAL(comm_table.message_all[i].rank_receiver, test_elements[i].receiver); + } +} + +TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2, -1) { + + if(shamcomm::world_rank() == 0){ + logger::info_ln("sparse exchange test","empty comm"); + } + + test_sparse_exchange({}); + + if(shamcomm::world_rank() == 0){ + logger::info_ln("sparse exchange test","send to self"); + } + + { + // everyone send to itself + std::mt19937 eng(0x123); + std::vector test_elements; + for (i32 i = 0; i < shamcomm::world_size(); i++) { + test_elements.push_back(TestElement{ + i, + i, + shamalgs::primitives::mock_value(eng, 1, 10)}); + } + test_sparse_exchange(test_elements); + } + + + if(shamcomm::world_rank() == 0){ + logger::info_ln("sparse exchange test","send to next"); + } + + { + // everyone send to next one + std::mt19937 eng(0x123); + std::vector test_elements; + for (i32 i = 0; i < shamcomm::world_size(); i++) { + test_elements.push_back(TestElement{ + i, + (i + 1) % shamcomm::world_size(), + shamalgs::primitives::mock_value(eng, 1, 10)}); + } + test_sparse_exchange(test_elements); + } + + if(shamcomm::world_rank() == 0){ + logger::info_ln("sparse exchange test","random test"); + } + + { + // random test + std::mt19937 eng(0x123); + std::vector test_elements; + for (u32 i = 0; i < 3*shamcomm::world_size(); i++) { + test_elements.push_back(TestElement{ + shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), + shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), + shamalgs::primitives::mock_value(eng, 1, 10)}); + } + test_sparse_exchange(test_elements); + } +} From 1fddeb55bed094b4d109e5dd43405d1b3dddfe15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 25 Dec 2025 18:50:50 +0100 Subject: [PATCH 07/54] swap comm in the code --- .../src/collective/distributedDataComm.cpp | 61 ++++++++++++++ .../collective/sparse_exchange_tests.cpp | 80 +++++++++++-------- 2 files changed, 106 insertions(+), 35 deletions(-) diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index aa1a9ac50f..a84249a987 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -18,6 +18,7 @@ #include "shambase/exception.hpp" #include "shambase/memory.hpp" #include "shambase/stacktrace.hpp" +#include "shamalgs/collective/sparse_exchange.hpp" #include "shamalgs/serialize.hpp" #include "shambackends/DeviceBuffer.hpp" #include "shambackends/DeviceScheduler.hpp" @@ -106,6 +107,47 @@ namespace shamalgs::collective { } } + sham::DeviceBuffer send_buf(0, dev_sched); + std::vector messages_send; + + size_t sender_offset = 0; + for (auto &[key, buf] : send_bufs) { + + auto [sender, receiver] = key; + u64 size = buf->get_size(); + + messages_send.push_back( + shamalgs::collective::CommMessageInfo{ + size, + sender, + receiver, + std::nullopt, + sender_offset, + std::nullopt, + }); + + send_buf.append(*buf); + sender_offset += size; + } + + shamalgs::collective::CommTable comm_table2 + = shamalgs::collective::build_sparse_exchange_table(messages_send); + + sham::DeviceBuffer recv_buf(comm_table2.recv_total_size, dev_sched); + + if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { + shamalgs::collective::sparse_exchange( + dev_sched, send_buf, recv_buf, comm_table2); + } else { + auto send_buf_host = send_buf.copy_to(); + sham::DeviceBuffer recv_buf_host( + comm_table2.recv_total_size, dev_sched); + shamalgs::collective::sparse_exchange( + dev_sched, send_buf_host, recv_buf_host, comm_table2); + recv_buf.copy_from(recv_buf_host); + } + +#ifdef false // prepare payload std::vector send_payoad; { @@ -127,6 +169,8 @@ namespace shamalgs::collective { base_sparse_comm(dev_sched, send_payoad, recv_payload); } +#endif + // make serializers from recv buffs struct RecvPayloadSer { i32 sender_ranks; @@ -135,6 +179,7 @@ namespace shamalgs::collective { std::vector recv_payload_bufs; +#ifdef false { NamedStackEntry stack_loc2{"move payloads"}; for (RecvPayload &payload : recv_payload) { @@ -149,6 +194,22 @@ namespace shamalgs::collective { payload.sender_ranks, SerializeHelper(dev_sched, std::move(buf))}); } } +#endif + + for (auto &msg : comm_table2.messages_recv) { + + u64 size = msg.message_size; + i32 sender = msg.rank_sender; + i32 receiver = msg.rank_receiver; + size_t begin = shambase::get_check_ref(msg.message_bytebuf_offset_recv); + size_t end = begin + size; + + sham::DeviceBuffer recov(size, dev_sched); + + recv_buf.copy_range(begin, end, recov); + recv_payload_bufs.push_back( + RecvPayloadSer{sender, SerializeHelper(dev_sched, std::move(recov))}); + } { NamedStackEntry stack_loc2{"split recv comms"}; diff --git a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp index 1222fd76f3..c6fc383d5e 100644 --- a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp +++ b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp @@ -54,16 +54,18 @@ void test_sparse_exchange(std::vector test_elements) { size_t sender_count = 0; for (u32 i = 0; i < test_elements.size(); i++) { if (test_elements[i].sender == shamcomm::world_rank()) { - messages_send.push_back(shamalgs::collective::CommMessageInfo{ - test_elements[i].size, - test_elements[i].sender, - test_elements[i].receiver, - std::nullopt, - sender_offset, - std::nullopt, - }); - - logger::info_ln("sparse exchange test", + messages_send.push_back( + shamalgs::collective::CommMessageInfo{ + test_elements[i].size, + test_elements[i].sender, + test_elements[i].receiver, + std::nullopt, + sender_offset, + std::nullopt, + }); + + logger::info_ln( + "sparse exchange test", "rank :", shamcomm::world_rank(), "send message : (", @@ -132,7 +134,16 @@ void test_sparse_exchange(std::vector test_elements) { size_t end = begin + test_elements[i].size; recv_buf.copy_range(begin, end, recov); - logger::info_ln("sparse exchange test","rank :", shamcomm::world_rank(), "recv message : (", test_elements[i].sender, "->", test_elements[i].receiver, ") data :", recov.copy_to_stdvec()); + logger::info_ln( + "sparse exchange test", + "rank :", + shamcomm::world_rank(), + "recv message : (", + test_elements[i].sender, + "->", + test_elements[i].receiver, + ") data :", + recov.copy_to_stdvec()); REQUIRE_EQUAL(recov.copy_to_stdvec(), ref_buf.copy_to_stdvec()); @@ -145,15 +156,15 @@ void test_sparse_exchange(std::vector test_elements) { } TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2, -1) { - - if(shamcomm::world_rank() == 0){ - logger::info_ln("sparse exchange test","empty comm"); + + if (shamcomm::world_rank() == 0) { + logger::info_ln("sparse exchange test", "empty comm"); } - + test_sparse_exchange({}); - if(shamcomm::world_rank() == 0){ - logger::info_ln("sparse exchange test","send to self"); + if (shamcomm::world_rank() == 0) { + logger::info_ln("sparse exchange test", "send to self"); } { @@ -161,17 +172,14 @@ TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2 std::mt19937 eng(0x123); std::vector test_elements; for (i32 i = 0; i < shamcomm::world_size(); i++) { - test_elements.push_back(TestElement{ - i, - i, - shamalgs::primitives::mock_value(eng, 1, 10)}); + test_elements.push_back( + TestElement{i, i, shamalgs::primitives::mock_value(eng, 1, 10)}); } test_sparse_exchange(test_elements); } - - if(shamcomm::world_rank() == 0){ - logger::info_ln("sparse exchange test","send to next"); + if (shamcomm::world_rank() == 0) { + logger::info_ln("sparse exchange test", "send to next"); } { @@ -179,27 +187,29 @@ TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2 std::mt19937 eng(0x123); std::vector test_elements; for (i32 i = 0; i < shamcomm::world_size(); i++) { - test_elements.push_back(TestElement{ - i, - (i + 1) % shamcomm::world_size(), - shamalgs::primitives::mock_value(eng, 1, 10)}); + test_elements.push_back( + TestElement{ + i, + (i + 1) % shamcomm::world_size(), + shamalgs::primitives::mock_value(eng, 1, 10)}); } test_sparse_exchange(test_elements); } - if(shamcomm::world_rank() == 0){ - logger::info_ln("sparse exchange test","random test"); + if (shamcomm::world_rank() == 0) { + logger::info_ln("sparse exchange test", "random test"); } { // random test std::mt19937 eng(0x123); std::vector test_elements; - for (u32 i = 0; i < 3*shamcomm::world_size(); i++) { - test_elements.push_back(TestElement{ - shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), - shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), - shamalgs::primitives::mock_value(eng, 1, 10)}); + for (u32 i = 0; i < 3 * shamcomm::world_size(); i++) { + test_elements.push_back( + TestElement{ + shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), + shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), + shamalgs::primitives::mock_value(eng, 1, 10)}); } test_sparse_exchange(test_elements); } From cedccc8071e0d311e21725f29e0b9b94b1500542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 25 Dec 2025 23:51:03 +0100 Subject: [PATCH 08/54] use an alloc cache in exchange --- .../collective/distributedDataComm.hpp | 9 +++++++- .../src/collective/distributedDataComm.cpp | 21 ++++++++++++++++--- .../zeus/src/modules/GhostZones.cpp | 9 ++++++-- .../scheduler/ReattributeDataUtility.hpp | 5 ++++- .../solvergraph/ExchangeGhostField.hpp | 3 +++ .../solvergraph/ExchangeGhostLayer.hpp | 2 ++ .../src/solvergraph/ExchangeGhostField.cpp | 3 ++- .../src/solvergraph/ExchangeGhostLayer.cpp | 3 ++- .../collective/distributedDataCommTests.cpp | 12 ++++++++--- 9 files changed, 55 insertions(+), 12 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp index 82e37493d0..4a487adbe1 100644 --- a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp +++ b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp @@ -34,11 +34,17 @@ namespace shamalgs::collective { using SerializedDDataComm = shambase::DistributedDataShared>; + struct DDSCommCache { + std::unique_ptr> cache1; + std::unique_ptr> cache2; + }; + void distributed_data_sparse_comm( std::shared_ptr dev_sched, SerializedDDataComm &send_ddistrib_data, SerializedDDataComm &recv_distrib_data, std::function rank_getter, + DDSCommCache &cache, std::optional comm_table = {}); template @@ -49,6 +55,7 @@ namespace shamalgs::collective { std::function rank_getter, std::function(T &)> serialize, std::function &&)> deserialize, + DDSCommCache &cache, std::optional comm_table = {}) { StackEntry stack_loc{}; @@ -68,7 +75,7 @@ namespace shamalgs::collective { SerializedDDataComm dcomm_recv; - distributed_data_sparse_comm(dev_sched, dcomm_send, dcomm_recv, rank_getter); + distributed_data_sparse_comm(dev_sched, dcomm_send, dcomm_recv, rank_getter, cache); recv_distrib_data = dcomm_recv.map([&](u64, u64, sham::DeviceBuffer &buf) { // exchange the buffer held by the distrib data and give it to the deserializer diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index a84249a987..4524717c7b 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -79,6 +79,7 @@ namespace shamalgs::collective { SerializedDDataComm &send_distrib_data, SerializedDDataComm &recv_distrib_data, std::function rank_getter, + DDSCommCache &cache, std::optional comm_table) { StackEntry stack_loc{}; @@ -139,9 +140,23 @@ namespace shamalgs::collective { shamalgs::collective::sparse_exchange( dev_sched, send_buf, recv_buf, comm_table2); } else { - auto send_buf_host = send_buf.copy_to(); - sham::DeviceBuffer recv_buf_host( - comm_table2.recv_total_size, dev_sched); + if (!cache.cache1) { + cache.cache1 = std::make_unique>( + send_buf.get_size(), dev_sched); + } else { + cache.cache1->resize(send_buf.get_size()); + } + cache.cache1->copy_from(send_buf); + sham::DeviceBuffer &send_buf_host = *cache.cache1; + + if (!cache.cache2) { + cache.cache2 = std::make_unique>( + comm_table2.recv_total_size, dev_sched); + } else { + cache.cache2->resize(comm_table2.recv_total_size); + } + sham::DeviceBuffer &recv_buf_host = *cache.cache2; + shamalgs::collective::sparse_exchange( dev_sched, send_buf_host, recv_buf_host, comm_table2); recv_buf.copy_from(recv_buf_host); diff --git a/src/shammodels/zeus/src/modules/GhostZones.cpp b/src/shammodels/zeus/src/modules/GhostZones.cpp index 4cdf79f1cd..bc10ba344e 100644 --- a/src/shammodels/zeus/src/modules/GhostZones.cpp +++ b/src/shammodels/zeus/src/modules/GhostZones.cpp @@ -210,6 +210,7 @@ shambase::DistributedDataShared shammodels::zeu shambase::DistributedDataShared recv_dat; + shamalgs::collective::DDSCommCache cache; shamalgs::collective::serialize_sparse_comm( shamsys::instance::get_compute_scheduler_ptr(), std::forward>(interf), @@ -229,7 +230,8 @@ shambase::DistributedDataShared shammodels::zeu shamsys::instance::get_compute_scheduler_ptr(), std::forward>(buf)); return shamrock::patch::PatchDataLayer::deserialize_buf(ser, pdl_ptr); - }); + }, + cache); return recv_dat; } @@ -243,6 +245,8 @@ shambase::DistributedDataShared> shammodels::zeus::modules::Gh shambase::DistributedDataShared> recv_dat; + shamalgs::collective::DDSCommCache cache; + shamalgs::collective::serialize_sparse_comm>( shamsys::instance::get_compute_scheduler_ptr(), std::forward>>(interf), @@ -262,7 +266,8 @@ shambase::DistributedDataShared> shammodels::zeus::modules::Gh shamsys::instance::get_compute_scheduler_ptr(), std::forward>(buf)); return PatchDataField::deserialize_full(ser); - }); + }, + cache); return recv_dat; } diff --git a/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp b/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp index 11e044ba43..7d80af3f80 100644 --- a/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp +++ b/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp @@ -229,6 +229,8 @@ namespace shamrock { DistributedDataShared recv_dat; + shamalgs::collective::DDSCommCache cache; + shamalgs::collective::serialize_sparse_comm( shamsys::instance::get_compute_scheduler_ptr(), std::move(part_exchange), @@ -248,7 +250,8 @@ namespace shamrock { shamsys::instance::get_compute_scheduler_ptr(), std::forward>(buf)); return PatchDataLayer::deserialize_buf(ser, sched.get_layout_ptr()); - }); + }, + cache); recv_dat.for_each([&](u64 sender, u64 receiver, PatchDataLayer &pdat) { shamlog_debug_ln("Part Exchanges", format("send = {} recv = {}", sender, receiver)); diff --git a/src/shamrock/include/shamrock/solvergraph/ExchangeGhostField.hpp b/src/shamrock/include/shamrock/solvergraph/ExchangeGhostField.hpp index b92f227efd..ce19a86797 100644 --- a/src/shamrock/include/shamrock/solvergraph/ExchangeGhostField.hpp +++ b/src/shamrock/include/shamrock/solvergraph/ExchangeGhostField.hpp @@ -19,6 +19,7 @@ * domains in the Shamrock hydrodynamics framework. */ +#include "shamalgs/collective/distributedDataComm.hpp" #include "shamrock/solvergraph/INode.hpp" #include "shamrock/solvergraph/PatchDataFieldDDShared.hpp" #include "shamrock/solvergraph/ScalarsEdge.hpp" @@ -57,6 +58,8 @@ namespace shamrock::solvergraph { template class ExchangeGhostField : public shamrock::solvergraph::INode { + shamalgs::collective::DDSCommCache cache; + public: /** * @brief Default constructor for ExchangeGhostField node diff --git a/src/shamrock/include/shamrock/solvergraph/ExchangeGhostLayer.hpp b/src/shamrock/include/shamrock/solvergraph/ExchangeGhostLayer.hpp index c61b66cd23..332936d5fe 100644 --- a/src/shamrock/include/shamrock/solvergraph/ExchangeGhostLayer.hpp +++ b/src/shamrock/include/shamrock/solvergraph/ExchangeGhostLayer.hpp @@ -19,6 +19,7 @@ * domains in the Shamrock hydrodynamics framework. */ +#include "shamalgs/collective/distributedDataComm.hpp" #include "shamrock/solvergraph/INode.hpp" #include "shamrock/solvergraph/PatchDataLayerDDShared.hpp" #include "shamrock/solvergraph/ScalarsEdge.hpp" @@ -55,6 +56,7 @@ namespace shamrock::solvergraph { * @endcode */ class ExchangeGhostLayer : public shamrock::solvergraph::INode { + shamalgs::collective::DDSCommCache cache; std::shared_ptr ghost_layer_layout; public: diff --git a/src/shamrock/src/solvergraph/ExchangeGhostField.cpp b/src/shamrock/src/solvergraph/ExchangeGhostField.cpp index c7d10a5a64..a9061c2759 100644 --- a/src/shamrock/src/solvergraph/ExchangeGhostField.cpp +++ b/src/shamrock/src/solvergraph/ExchangeGhostField.cpp @@ -50,7 +50,8 @@ void shamrock::solvergraph::ExchangeGhostField::_impl_evaluate_internal() { shamsys::instance::get_compute_scheduler_ptr(), std::forward>(buf)); return PatchDataField::deserialize_full(ser); - }); + }, + cache); ghost_layer.patchdata_fields = std::move(recv_dat); } diff --git a/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp b/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp index 98c325e0f8..63271f8751 100644 --- a/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp +++ b/src/shamrock/src/solvergraph/ExchangeGhostLayer.cpp @@ -49,7 +49,8 @@ void shamrock::solvergraph::ExchangeGhostLayer::_impl_evaluate_internal() { shamsys::instance::get_compute_scheduler_ptr(), std::forward>(buf)); return shamrock::patch::PatchDataLayer::deserialize_buf(ser, ghost_layer_layout); - }); + }, + cache); ghost_layer.patchdatas = std::move(recv_dat); } diff --git a/src/tests/shamalgs/collective/distributedDataCommTests.cpp b/src/tests/shamalgs/collective/distributedDataCommTests.cpp index 0fc68ac143..c4c9aed5c9 100644 --- a/src/tests/shamalgs/collective/distributedDataCommTests.cpp +++ b/src/tests/shamalgs/collective/distributedDataCommTests.cpp @@ -71,9 +71,15 @@ void distribdata_sparse_comm_test(std::string prefix) { }); shamalgs::collective::SerializedDDataComm recv_data; - distributed_data_sparse_comm(get_compute_scheduler_ptr(), send_data, recv_data, [&](u64 id) { - return rank_owner[id]; - }); + shamalgs::collective::DDSCommCache cache; + distributed_data_sparse_comm( + get_compute_scheduler_ptr(), + send_data, + recv_data, + [&](u64 id) { + return rank_owner[id]; + }, + cache); shamalgs::collective::SerializedDDataComm recv_data_ref; dat_ref.for_each([&](u64 sender, u64 receiver, sham::DeviceBuffer &buf) { From 5f0c757e86ce0aa49385ab6e63e06e67e8cacf03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Fri, 26 Dec 2025 09:57:26 +0100 Subject: [PATCH 09/54] enable persistance --- .../include/shammodels/sph/BasicSPHGhosts.hpp | 22 ++++++++++++------- .../shammodels/sph/modules/SolverStorage.hpp | 3 +++ src/shammodels/sph/src/Solver.cpp | 12 +++++++--- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp b/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp index 4202dbbdf4..f9629c3474 100644 --- a/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp +++ b/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp @@ -82,6 +82,8 @@ namespace shammodels::sph { std::shared_ptr xyzh_ghost_layout; + std::shared_ptr exchange_gz_positions; + std::shared_ptr> patch_rank_owner; BasicSPHGhostHandler( @@ -92,6 +94,9 @@ namespace shammodels::sph { xyzh_ghost_layout = std::make_shared(); xyzh_ghost_layout->add_field("xyz", 1); xyzh_ghost_layout->add_field("hpart", 1); + + exchange_gz_positions + = std::make_shared(xyzh_ghost_layout); } /** @@ -320,7 +325,8 @@ namespace shammodels::sph { inline shambase::DistributedDataShared communicate_pdat( const std::shared_ptr &pdl_ptr, - shambase::DistributedDataShared &&interf) { + shambase::DistributedDataShared &&interf, + std::shared_ptr exchange_gz_node) { StackEntry stack_loc{}; // ---------------------------------------------------------------------------------------- @@ -332,8 +338,6 @@ namespace shammodels::sph { = std::forward>( interf); - std::shared_ptr exchange_gz_node - = std::make_shared(pdl_ptr); exchange_gz_node->set_edges(this->patch_rank_owner, exchange_gz_edge); exchange_gz_node->evaluate(); @@ -373,7 +377,9 @@ namespace shammodels::sph { template inline shambase::DistributedDataShared> communicate_pdatfield( - shambase::DistributedDataShared> &&interf, u32 nvar) { + shambase::DistributedDataShared> &&interf, + u32 nvar, + std::shared_ptr> exchange_gz_node) { StackEntry stack_loc{}; // ---------------------------------------------------------------------------------------- @@ -384,8 +390,6 @@ namespace shammodels::sph { exchange_gz_edge->patchdata_fields = std::forward>>(interf); - std::shared_ptr> exchange_gz_node - = std::make_shared>(); exchange_gz_node->set_edges(this->patch_rank_owner, exchange_gz_edge); exchange_gz_node->evaluate(); @@ -425,7 +429,8 @@ namespace shammodels::sph { inline shambase::DistributedDataShared build_communicate_positions(shambase::DistributedDataShared &builder) { auto pos_interf = build_position_interf_field(builder); - return communicate_pdat(xyzh_ghost_layout, std::move(pos_interf)); + return communicate_pdat( + xyzh_ghost_layout, std::move(pos_interf), exchange_gz_positions); } template @@ -486,7 +491,8 @@ namespace shammodels::sph { inline shambase::DistributedData build_comm_merge_positions(shambase::DistributedDataShared &builder) { auto pos_interf = build_position_interf_field(builder); - return merge_position_buf(communicate_pdat(xyzh_ghost_layout, std::move(pos_interf))); + return merge_position_buf( + communicate_pdat(xyzh_ghost_layout, std::move(pos_interf), exchange_gz_positions)); } }; diff --git a/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp b/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp index 70f1d436cf..e6d0a331ab 100644 --- a/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp +++ b/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp @@ -52,6 +52,9 @@ namespace shammodels::sph { using RTree = shamtree::CompressedLeafBVH; + std::shared_ptr exchange_gz_node; + std::shared_ptr> exchange_gz_alpha; + std::shared_ptr> part_counts; std::shared_ptr> part_counts_with_ghost; diff --git a/src/shammodels/sph/src/Solver.cpp b/src/shammodels/sph/src/Solver.cpp index bd32de1cb8..1f20f81f69 100644 --- a/src/shammodels/sph/src/Solver.cpp +++ b/src/shammodels/sph/src/Solver.cpp @@ -126,6 +126,11 @@ void shammodels::sph::Solver::init_solver_graph() { storage.pressure = std::make_shared>(1, "pressure", "P"); storage.soundspeed = std::make_shared>(1, "soundspeed", "c_s"); + + storage.exchange_gz_alpha + = std::make_shared>(); + storage.exchange_gz_node + = std::make_shared(storage.ghost_layout.get()); } template class Kern> @@ -977,8 +982,8 @@ void shammodels::sph::Solver::communicate_merge_ghosts_fields() { } }); - shambase::DistributedDataShared interf_pdat - = ghost_handle.communicate_pdat(ghost_layout_ptr, std::move(pdat_interf)); + shambase::DistributedDataShared interf_pdat = ghost_handle.communicate_pdat( + ghost_layout_ptr, std::move(pdat_interf), storage.exchange_gz_node); std::map sz_interf_map; interf_pdat.for_each([&](u64 s, u64 r, PatchDataLayer &pdat_interf) { @@ -1571,7 +1576,8 @@ shammodels::sph::TimestepLog shammodels::sph::Solver::evolve_once() }); shambase::DistributedDataShared> interf_pdat - = ghost_handle.communicate_pdatfield(std::move(field_interf), 1); + = ghost_handle.communicate_pdatfield( + std::move(field_interf), 1, storage.exchange_gz_alpha); shambase::DistributedData> merged_field = ghost_handle.template merge_native, PatchDataField>( From 6dad87c9cbd641d24ae5eeb46ef7db1aff0a0494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Sat, 27 Dec 2025 08:38:33 +0100 Subject: [PATCH 10/54] even less allocs --- .../src/collective/distributedDataComm.cpp | 18 +++++++++++++ .../src/details/internal_alloc.cpp | 2 ++ .../include/shammodels/sph/BasicSPHGhosts.hpp | 25 ++++++++----------- .../sph/include/shammodels/sph/Solver.hpp | 15 ++++++++--- .../shammodels/sph/modules/SolverStorage.hpp | 4 ++- src/shammodels/sph/src/Solver.cpp | 18 ++++++++----- 6 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index 4524717c7b..4980ea5855 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -141,18 +141,36 @@ namespace shamalgs::collective { dev_sched, send_buf, recv_buf, comm_table2); } else { if (!cache.cache1) { + // logger::info_ln("ddcomm", "alloc cache1", shambase::fmt_callstack()); cache.cache1 = std::make_unique>( send_buf.get_size(), dev_sched); } else { + // logger::info_ln( + // "ddcomm", + // shambase::format( + // "resize cache1 from {} to {} (cache1 ptr = {})", + // cache.cache1->get_size(), + // send_buf.get_size(), + // static_cast(cache.cache1.get())), + // shambase::fmt_callstack()); cache.cache1->resize(send_buf.get_size()); } cache.cache1->copy_from(send_buf); sham::DeviceBuffer &send_buf_host = *cache.cache1; if (!cache.cache2) { + // logger::info_ln("ddcomm", "alloc cache2", shambase::fmt_callstack()); cache.cache2 = std::make_unique>( comm_table2.recv_total_size, dev_sched); } else { + // logger::info_ln( + // "ddcomm", + // shambase::format( + // "resize cache2 from {} to {} (cache2 ptr = {})", + // cache.cache2->get_size(), + // comm_table2.recv_total_size, + // static_cast(cache.cache2.get())), + // shambase::fmt_callstack()); cache.cache2->resize(comm_table2.recv_total_size); } sham::DeviceBuffer &recv_buf_host = *cache.cache2; diff --git a/src/shambackends/src/details/internal_alloc.cpp b/src/shambackends/src/details/internal_alloc.cpp index a58efa31d8..072b704420 100644 --- a/src/shambackends/src/details/internal_alloc.cpp +++ b/src/shambackends/src/details/internal_alloc.cpp @@ -326,6 +326,8 @@ namespace sham::details { register_alloc_shared(sz, end_time - start_time); } else if constexpr (target == host) { register_alloc_host(sz, end_time - start_time); + // logger::info_ln("internal_alloc", "alloc host : sz =", sz, " | time =", end_time - + // start_time); } return usm_ptr; diff --git a/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp b/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp index f9629c3474..81aaa41b41 100644 --- a/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp +++ b/src/shammodels/sph/include/shammodels/sph/BasicSPHGhosts.hpp @@ -80,24 +80,17 @@ namespace shammodels::sph { using GeneratorMap = shambase::DistributedDataShared; - std::shared_ptr xyzh_ghost_layout; - - std::shared_ptr exchange_gz_positions; + std::shared_ptr &xyzh_ghost_layout; std::shared_ptr> patch_rank_owner; BasicSPHGhostHandler( PatchScheduler &sched, Config ghost_config, - std::shared_ptr> patch_rank_owner) - : sched(sched), ghost_config(ghost_config), patch_rank_owner(patch_rank_owner) { - xyzh_ghost_layout = std::make_shared(); - xyzh_ghost_layout->add_field("xyz", 1); - xyzh_ghost_layout->add_field("hpart", 1); - - exchange_gz_positions - = std::make_shared(xyzh_ghost_layout); - } + std::shared_ptr> patch_rank_owner, + std::shared_ptr &xyzh_ghost_layout) + : sched(sched), ghost_config(ghost_config), patch_rank_owner(patch_rank_owner), + xyzh_ghost_layout(xyzh_ghost_layout) {} /** * @brief Find interfaces and their metadata @@ -427,7 +420,9 @@ namespace shammodels::sph { } inline shambase::DistributedDataShared - build_communicate_positions(shambase::DistributedDataShared &builder) { + build_communicate_positions( + shambase::DistributedDataShared &builder, + std::shared_ptr &exchange_gz_positions) { auto pos_interf = build_position_interf_field(builder); return communicate_pdat( xyzh_ghost_layout, std::move(pos_interf), exchange_gz_positions); @@ -489,7 +484,9 @@ namespace shammodels::sph { } inline shambase::DistributedData - build_comm_merge_positions(shambase::DistributedDataShared &builder) { + build_comm_merge_positions( + shambase::DistributedDataShared &builder, + std::shared_ptr &exchange_gz_positions) { auto pos_interf = build_position_interf_field(builder); return merge_position_buf( communicate_pdat(xyzh_ghost_layout, std::move(pos_interf), exchange_gz_positions)); diff --git a/src/shammodels/sph/include/shammodels/sph/Solver.hpp b/src/shammodels/sph/include/shammodels/sph/Solver.hpp index 6916452158..e372e16715 100644 --- a/src/shammodels/sph/include/shammodels/sph/Solver.hpp +++ b/src/shammodels/sph/include/shammodels/sph/Solver.hpp @@ -103,12 +103,20 @@ namespace shammodels::sph { if (SolverBCFree *c = std::get_if(&solver_config.boundary_config.config)) { storage.ghost_handler.set( - GhostHandle{scheduler(), BCFree{}, storage.patch_rank_owner}); + GhostHandle{ + scheduler(), + BCFree{}, + storage.patch_rank_owner, + storage.xyzh_ghost_layout}); } else if ( SolverBCPeriodic *c = std::get_if(&solver_config.boundary_config.config)) { storage.ghost_handler.set( - GhostHandle{scheduler(), BCPeriodic{}, storage.patch_rank_owner}); + GhostHandle{ + scheduler(), + BCPeriodic{}, + storage.patch_rank_owner, + storage.xyzh_ghost_layout}); } else if ( SolverBCShearingPeriodic *c = std::get_if(&solver_config.boundary_config.config)) { @@ -117,7 +125,8 @@ namespace shammodels::sph { scheduler(), BCShearingPeriodic{ c->shear_base, c->shear_dir, c->shear_speed * time_val, c->shear_speed}, - storage.patch_rank_owner}); + storage.patch_rank_owner, + storage.xyzh_ghost_layout}); } } inline void reset_ghost_handler() { storage.ghost_handler.reset(); } diff --git a/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp b/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp index e6d0a331ab..6f3cfd9ccd 100644 --- a/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp +++ b/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp @@ -54,6 +54,7 @@ namespace shammodels::sph { std::shared_ptr exchange_gz_node; std::shared_ptr> exchange_gz_alpha; + std::shared_ptr exchange_gz_positions; std::shared_ptr> part_counts; std::shared_ptr> part_counts_with_ghost; @@ -80,7 +81,8 @@ namespace shammodels::sph { std::shared_ptr> omega; - Component> ghost_layout; + std::shared_ptr ghost_layout; + std::shared_ptr xyzh_ghost_layout; Component> merged_patchdata_ghost; diff --git a/src/shammodels/sph/src/Solver.cpp b/src/shammodels/sph/src/Solver.cpp index 1f20f81f69..05e0258c0d 100644 --- a/src/shammodels/sph/src/Solver.cpp +++ b/src/shammodels/sph/src/Solver.cpp @@ -130,7 +130,9 @@ void shammodels::sph::Solver::init_solver_graph() { storage.exchange_gz_alpha = std::make_shared>(); storage.exchange_gz_node - = std::make_shared(storage.ghost_layout.get()); + = std::make_shared(storage.ghost_layout); + storage.exchange_gz_positions + = std::make_shared(storage.xyzh_ghost_layout); } template class Kern> @@ -383,8 +385,8 @@ void shammodels::sph::Solver::merge_position_ghost() { StackEntry stack_loc{}; - storage.merged_xyzh.set( - storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get())); + storage.merged_xyzh.set(storage.ghost_handler.get().build_comm_merge_positions( + storage.ghost_patch_cache.get(), storage.exchange_gz_positions)); { // set element counts shambase::get_check_ref(storage.part_counts).indexes @@ -744,12 +746,16 @@ void shammodels::sph::Solver::sph_prestep(Tscal time_val, Tscal dt) template class Kern> void shammodels::sph::Solver::init_ghost_layout() { - storage.ghost_layout.set(std::make_shared()); + storage.ghost_layout = std::make_shared(); shamrock::patch::PatchDataLayerLayout &ghost_layout - = shambase::get_check_ref(storage.ghost_layout.get()); + = shambase::get_check_ref(storage.ghost_layout); solver_config.set_ghost_layout(ghost_layout); + + storage.xyzh_ghost_layout = std::make_shared(); + storage.xyzh_ghost_layout->template add_field("xyz", 1); + storage.xyzh_ghost_layout->template add_field("hpart", 1); } template class Kern> @@ -874,7 +880,7 @@ void shammodels::sph::Solver::communicate_merge_ghosts_fields() { const u32 iepsilon = (has_epsilon_field) ? pdl.get_field_idx("epsilon") : 0; const u32 ideltav = (has_deltav_field) ? pdl.get_field_idx("deltav") : 0; - auto ghost_layout_ptr = storage.ghost_layout.get(); + auto &ghost_layout_ptr = storage.ghost_layout; shamrock::patch::PatchDataLayerLayout &ghost_layout = shambase::get_check_ref(ghost_layout_ptr); u32 ihpart_interf = ghost_layout.get_field_idx("hpart"); u32 iuint_interf = ghost_layout.get_field_idx("uint"); From 229ac8b85ab859f80005252de938aba9d5c26810 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Sat, 27 Dec 2025 08:54:54 +0100 Subject: [PATCH 11/54] fix precommit --- .../sph/include/shammodels/sph/modules/SolverStorage.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp b/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp index 7e877c5756..a966683c50 100644 --- a/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp +++ b/src/shammodels/sph/include/shammodels/sph/modules/SolverStorage.hpp @@ -57,7 +57,7 @@ namespace shammodels::sph { std::shared_ptr exchange_gz_node; std::shared_ptr> exchange_gz_alpha; std::shared_ptr exchange_gz_positions; - + shamrock::solvergraph::SolverGraph solver_graph; std::shared_ptr solver_sequence; From 3c8c2d35c457b037a4a0907f3329b7b325a4584c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 29 Dec 2025 17:37:53 +0100 Subject: [PATCH 12/54] fix compile and improve setup perf --- .../shammodels/gsph/modules/SolverStorage.hpp | 5 ++- src/shammodels/gsph/src/Solver.cpp | 24 +++++++---- src/shammodels/sph/src/modules/SPHSetup.cpp | 42 +++++++++++++------ .../include/shamrock/patch/PatchDataField.hpp | 39 ++++++++++++----- 4 files changed, 79 insertions(+), 31 deletions(-) diff --git a/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp b/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp index ebf8fbc6b0..8260c916ac 100644 --- a/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp +++ b/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp @@ -70,6 +70,9 @@ namespace shammodels::gsph { using RTree = shamtree::CompressedLeafBVH; + std::shared_ptr exchange_gz_node; + std::shared_ptr exchange_gz_positions; + /// Particle counts per patch std::shared_ptr> part_counts; std::shared_ptr> part_counts_with_ghost; @@ -104,7 +107,7 @@ namespace shammodels::gsph { /// Ghost data layout and merged data std::shared_ptr xyzh_ghost_layout; - Component> ghost_layout; + std::shared_ptr ghost_layout; Component> merged_patchdata_ghost; diff --git a/src/shammodels/gsph/src/Solver.cpp b/src/shammodels/gsph/src/Solver.cpp index 0a47626de8..38146c1b4f 100644 --- a/src/shammodels/gsph/src/Solver.cpp +++ b/src/shammodels/gsph/src/Solver.cpp @@ -92,6 +92,11 @@ void shammodels::gsph::Solver::init_solver_graph() { storage.pressure = std::make_shared>(1, "pressure", "P"); storage.soundspeed = std::make_shared>(1, "soundspeed", "c_s"); + + storage.exchange_gz_node + = std::make_shared(storage.ghost_layout); + storage.exchange_gz_positions + = std::make_shared(storage.xyzh_ghost_layout); } template class Kern> @@ -177,8 +182,11 @@ template class Kern> void shammodels::gsph::Solver::merge_position_ghost() { StackEntry stack_loc{}; - storage.merged_xyzh.set( - storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get())); + std::shared_ptr exchange_gz_node + = std::make_shared(storage.xyzh_ghost_layout); + + storage.merged_xyzh.set(storage.ghost_handler.get().build_comm_merge_positions( + storage.ghost_patch_cache.get(), exchange_gz_node)); // Set element counts shambase::get_check_ref(storage.part_counts).indexes @@ -591,11 +599,10 @@ void shammodels::gsph::Solver::init_ghost_layout() { storage.xyzh_ghost_layout->template add_field("hpart", 1); // Reset first in case it was set from a previous timestep - storage.ghost_layout.reset(); - storage.ghost_layout.set(std::make_shared()); + storage.ghost_layout = std::make_shared(); shamrock::patch::PatchDataLayerLayout &ghost_layout - = shambase::get_check_ref(storage.ghost_layout.get()); + = shambase::get_check_ref(storage.ghost_layout); solver_config.set_ghost_layout(ghost_layout); } @@ -618,7 +625,7 @@ void shammodels::gsph::Solver::communicate_merge_ghosts_fields() { const bool has_uint = solver_config.has_field_uint(); const u32 iuint = has_uint ? pdl.get_field_idx("uint") : 0; - auto ghost_layout_ptr = storage.ghost_layout.get(); + auto ghost_layout_ptr = storage.ghost_layout; shamrock::patch::PatchDataLayerLayout &ghost_layout = shambase::get_check_ref(ghost_layout_ptr); u32 ihpart_interf = ghost_layout.get_field_idx("hpart"); u32 ivxyz_interf = ghost_layout.get_field_idx("vxyz"); @@ -683,9 +690,12 @@ void shammodels::gsph::Solver::communicate_merge_ghosts_fields() { } }); + std::shared_ptr exchange_gz_node + = std::make_shared(storage.ghost_layout); + // Communicate ghost data across MPI ranks shambase::DistributedDataShared interf_pdat - = ghost_handle.communicate_pdat(ghost_layout_ptr, std::move(pdat_interf)); + = ghost_handle.communicate_pdat(ghost_layout_ptr, std::move(pdat_interf), exchange_gz_node); // Count total ghost particles per patch std::map sz_interf_map; diff --git a/src/shammodels/sph/src/modules/SPHSetup.cpp b/src/shammodels/sph/src/modules/SPHSetup.cpp index 921d611159..6031ece3fd 100644 --- a/src/shammodels/sph/src/modules/SPHSetup.cpp +++ b/src/shammodels/sph/src/modules/SPHSetup.cpp @@ -417,12 +417,17 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( auto inject_in_local_domains = [&sched, &inserter, &compute_load, &insert_step, &log_inject_status]( shamrock::patch::PatchDataLayer &to_insert) { + __shamrock_stack_entry(); + bool has_been_limited = true; while (has_been_limited) { has_been_limited = false; using namespace shamrock::patch; + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + sham::DeviceBuffer mask_get_ids_where(0, dev_sched); + // inject in local domains first PatchCoordTransform ptransf = sched.get_sim_box().get_patch_transform(); sched.for_each_local_patchdata([&](const Patch p, PatchDataLayer &pdat) { @@ -430,7 +435,8 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( PatchDataField &xyz = to_insert.get_field(0); - auto ids = xyz.get_ids_where( + auto ids = xyz.get_ids_where_recycle_buffer( + mask_get_ids_where, [](auto access, u32 id, shammath::CoordRange patch_coord) { Tvec tmp = access[id]; return patch_coord.contain_pos(tmp); @@ -461,21 +467,12 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( } }; - u32 step_count = 0; - while (!shamalgs::collective::are_all_rank_true(to_insert.is_empty(), MPI_COMM_WORLD)) { - - // assume that the sched is synchronized and that there is at least a patch. - // TODO actually check that - - using namespace shamrock::patch; - - auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + auto get_index_per_ranks = [&]() { + __shamrock_stack_entry(); SerialPatchTree sptree = SerialPatchTree::build(sched); sptree.attach_buf(); - inject_in_local_domains(to_insert); - // find where each particle should be inserted PatchDataField &pos_field = to_insert.get_field(0); @@ -507,6 +504,24 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( "a new id could not be computed"); } + return index_per_ranks; + }; + + shamalgs::collective::DDSCommCache comm_cache; + u32 step_count = 0; + while (!shamalgs::collective::are_all_rank_true(to_insert.is_empty(), MPI_COMM_WORLD)) { + + // assume that the sched is synchronized and that there is at least a patch. + // TODO actually check that + + using namespace shamrock::patch; + + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + inject_in_local_domains(to_insert); + + std::unordered_map> index_per_ranks = get_index_per_ranks(); + // allgather the list of messages // format:(u32_2(sender_rank, receiver_rank), u64(indices_size)) std::vector send_msg; @@ -645,7 +660,8 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( // serializer shamalgs::SerializeHelper ser(dev_sched, std::forward>(buf)); return PatchDataLayer::deserialize_buf(ser, sched.get_layout_ptr()); - }); + }, + comm_cache); // insert the data into the data to be inserted recv_dat.for_each([&](u64 sender, u64 receiver, PatchDataLayer &pdat) { diff --git a/src/shamrock/include/shamrock/patch/PatchDataField.hpp b/src/shamrock/include/shamrock/patch/PatchDataField.hpp index 3bed0f3832..28fc2633b7 100644 --- a/src/shamrock/include/shamrock/patch/PatchDataField.hpp +++ b/src/shamrock/include/shamrock/patch/PatchDataField.hpp @@ -341,6 +341,34 @@ class PatchDataField { } } + template + inline sham::DeviceBuffer get_ids_where_recycle_buffer( + sham::DeviceBuffer &mask, Lambdacd &&cd_true, Args... args) const { + StackEntry stack_loc{}; + + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + sham::DeviceQueue &q = shambase::get_check_ref(dev_sched).get_queue(); + + auto obj_cnt = get_obj_cnt(); + if (obj_cnt > 0) { + // buffer of booleans to store result of the condition + mask.resize(obj_cnt); + + sham::kernel_call( + q, + sham::MultiRef{buf}, + sham::MultiRef{mask}, + obj_cnt, + [=, nvar_field = nvar](u32 id, const T *__restrict acc, u32 *__restrict acc_mask) { + acc_mask[id] = cd_true(acc, id * nvar_field, args...); + }); + + return shamalgs::stream_compact(dev_sched, mask, obj_cnt); + } else { + return sham::DeviceBuffer(0, dev_sched); + } + } + /** * @brief Same function as @see PatchDataField#get_ids_set_where but return a optional * sycl::buffer of the found index @@ -363,16 +391,7 @@ class PatchDataField { // buffer of booleans to store result of the condition sham::DeviceBuffer mask(obj_cnt, dev_sched); - sham::kernel_call( - q, - sham::MultiRef{buf}, - sham::MultiRef{mask}, - obj_cnt, - [=, nvar_field = nvar](u32 id, const T *__restrict acc, u32 *__restrict acc_mask) { - acc_mask[id] = cd_true(acc, id * nvar_field, args...); - }); - - return shamalgs::stream_compact(dev_sched, mask, obj_cnt); + return get_ids_where_recycle_buffer(mask, cd_true, args...); } else { return sham::DeviceBuffer(0, dev_sched); } From 33a933cb88e2c75a148ccc7a9246a2dd545f4338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David--Cl=C3=A9ris=20Timoth=C3=A9e?= Date: Thu, 1 Jan 2026 17:48:03 +0100 Subject: [PATCH 13/54] Update src/shamrock/include/shamrock/patch/PatchDataField.hpp --- .../include/shamrock/patch/PatchDataField.hpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/shamrock/include/shamrock/patch/PatchDataField.hpp b/src/shamrock/include/shamrock/patch/PatchDataField.hpp index 28fc2633b7..921910d8a8 100644 --- a/src/shamrock/include/shamrock/patch/PatchDataField.hpp +++ b/src/shamrock/include/shamrock/patch/PatchDataField.hpp @@ -391,7 +391,16 @@ class PatchDataField { // buffer of booleans to store result of the condition sham::DeviceBuffer mask(obj_cnt, dev_sched); - return get_ids_where_recycle_buffer(mask, cd_true, args...); +sham::kernel_call( + q, + sham::MultiRef{buf}, + sham::MultiRef{mask}, + obj_cnt, + [=, nvar_field = nvar](u32 id, const T *__restrict acc, u32 *__restrict acc_mask) { + acc_mask[id] = cd_true(acc, id * nvar_field, args...); + }); + + return shamalgs::stream_compact(dev_sched, mask, obj_cnt); } else { return sham::DeviceBuffer(0, dev_sched); } From f806918701ad56b3569fdc91a8405bc3ebb6ed01 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 10:08:23 +0000 Subject: [PATCH 14/54] [autofix.ci] automatic fix: pre-commit hooks --- src/shamrock/include/shamrock/patch/PatchDataField.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shamrock/include/shamrock/patch/PatchDataField.hpp b/src/shamrock/include/shamrock/patch/PatchDataField.hpp index 921910d8a8..b467fd3407 100644 --- a/src/shamrock/include/shamrock/patch/PatchDataField.hpp +++ b/src/shamrock/include/shamrock/patch/PatchDataField.hpp @@ -391,7 +391,7 @@ class PatchDataField { // buffer of booleans to store result of the condition sham::DeviceBuffer mask(obj_cnt, dev_sched); -sham::kernel_call( + sham::kernel_call( q, sham::MultiRef{buf}, sham::MultiRef{mask}, From 12a6ba51189001353d3cc2e80bd7decd3d8b92f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 14 Jan 2026 11:51:35 +0100 Subject: [PATCH 15/54] small fix --- src/shammodels/gsph/src/Solver.cpp | 4 ++-- src/shammodels/sph/src/modules/SPHSetup.cpp | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/shammodels/gsph/src/Solver.cpp b/src/shammodels/gsph/src/Solver.cpp index c347eafff5..1878e6eece 100644 --- a/src/shammodels/gsph/src/Solver.cpp +++ b/src/shammodels/gsph/src/Solver.cpp @@ -198,7 +198,7 @@ void shammodels::gsph::Solver::merge_position_ghost() { = std::make_shared(storage.xyzh_ghost_layout); storage.merged_xyzh.set(storage.ghost_handler.get().build_comm_merge_positions( - storage.ghost_patch_cache.get(), exchange_gz_node)); + storage.ghost_patch_cache.get())); // Get field indices from xyzh_ghost_layout const u32 ixyz_ghost = storage.xyzh_ghost_layout->template get_field_idx("xyz"); @@ -749,7 +749,7 @@ void shammodels::gsph::Solver::communicate_merge_ghosts_fields() { // Communicate ghost data across MPI ranks shambase::DistributedDataShared interf_pdat - = ghost_handle.communicate_pdat(ghost_layout_ptr, std::move(pdat_interf), exchange_gz_node); + = ghost_handle.communicate_pdat(ghost_layout_ptr, std::move(pdat_interf)); // Count total ghost particles per patch std::map sz_interf_map; diff --git a/src/shammodels/sph/src/modules/SPHSetup.cpp b/src/shammodels/sph/src/modules/SPHSetup.cpp index 792dcb7b8a..3a49fb3607 100644 --- a/src/shammodels/sph/src/modules/SPHSetup.cpp +++ b/src/shammodels/sph/src/modules/SPHSetup.cpp @@ -535,6 +535,7 @@ void shammodels::sph::modules::SPHSetup::apply_setup_new( f64 total_time_rank_getter = 0; f64 max_time_rank_getter = 0; + shamalgs::collective::DDSCommCache comm_cache; u32 step_count = 0; while (!shamalgs::collective::are_all_rank_true(to_insert.is_empty(), MPI_COMM_WORLD)) { From 95f5290906ba88292670c8be342ebc7d01b7eddb Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 10:52:09 +0000 Subject: [PATCH 16/54] [autofix.ci] automatic fix: pre-commit hooks --- src/shammodels/gsph/src/Solver.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shammodels/gsph/src/Solver.cpp b/src/shammodels/gsph/src/Solver.cpp index 1878e6eece..36e762e377 100644 --- a/src/shammodels/gsph/src/Solver.cpp +++ b/src/shammodels/gsph/src/Solver.cpp @@ -197,8 +197,8 @@ void shammodels::gsph::Solver::merge_position_ghost() { std::shared_ptr exchange_gz_node = std::make_shared(storage.xyzh_ghost_layout); - storage.merged_xyzh.set(storage.ghost_handler.get().build_comm_merge_positions( - storage.ghost_patch_cache.get())); + storage.merged_xyzh.set( + storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get())); // Get field indices from xyzh_ghost_layout const u32 ixyz_ghost = storage.xyzh_ghost_layout->template get_field_idx("xyz"); From c04fabde56ef36f286985cfd84f7f589ae433e9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David--Cl=C3=A9ris=20Timoth=C3=A9e?= Date: Sun, 18 Jan 2026 18:28:01 +0100 Subject: [PATCH 17/54] Update src/shamalgs/include/shamalgs/collective/RequestList.hpp --- src/shamalgs/include/shamalgs/collective/RequestList.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/RequestList.hpp b/src/shamalgs/include/shamalgs/collective/RequestList.hpp index c36e658a27..5b333d5f3b 100644 --- a/src/shamalgs/include/shamalgs/collective/RequestList.hpp +++ b/src/shamalgs/include/shamalgs/collective/RequestList.hpp @@ -62,8 +62,10 @@ namespace shamalgs::collective { return; } std::vector st_lst(rqs.size()); - shamcomm::mpi::Waitall( - shambase::narrow_or_throw(rqs.size()), rqs.data(), st_lst.data()); +shamcomm::mpi::Waitall( +shambase::narrow_or_throw(rqs.size()), rqs.data(), st_lst.data()); +ready_count = rqs.size(); +is_ready.assign(rqs.size(), true); } size_t remain_count_no_test() { return rqs.size() - ready_count; } From 461735643653426660d765357a326e782b810b70 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 17:28:36 +0000 Subject: [PATCH 18/54] [autofix.ci] automatic fix: pre-commit hooks --- src/shamalgs/include/shamalgs/collective/RequestList.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/RequestList.hpp b/src/shamalgs/include/shamalgs/collective/RequestList.hpp index 5b333d5f3b..61f90b78f3 100644 --- a/src/shamalgs/include/shamalgs/collective/RequestList.hpp +++ b/src/shamalgs/include/shamalgs/collective/RequestList.hpp @@ -62,10 +62,10 @@ namespace shamalgs::collective { return; } std::vector st_lst(rqs.size()); -shamcomm::mpi::Waitall( -shambase::narrow_or_throw(rqs.size()), rqs.data(), st_lst.data()); -ready_count = rqs.size(); -is_ready.assign(rqs.size(), true); + shamcomm::mpi::Waitall( + shambase::narrow_or_throw(rqs.size()), rqs.data(), st_lst.data()); + ready_count = rqs.size(); + is_ready.assign(rqs.size(), true); } size_t remain_count_no_test() { return rqs.size() - ready_count; } From 67c58c08929ae47368a81bed51b7bd91b56c70fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 12:27:36 +0100 Subject: [PATCH 19/54] tmp --- .../shamalgs/collective/sparse_exchange.hpp | 29 ++- .../src/collective/sparse_exchange.cpp | 189 ++++++++++++------ .../include/shambackends/DeviceBuffer.hpp | 58 ++++++ .../include/shambackends/DeviceBufferRef.hpp | 26 +++ .../collective/sparse_exchange_tests.cpp | 173 +++++++++++----- 5 files changed, 360 insertions(+), 115 deletions(-) create mode 100644 src/shambackends/include/shambackends/DeviceBufferRef.hpp diff --git a/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp b/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp index 92abef3284..47b50872ea 100644 --- a/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp +++ b/src/shamalgs/include/shamalgs/collective/sparse_exchange.hpp @@ -22,14 +22,27 @@ namespace shamalgs::collective { + struct CommMessageBufOffset { + size_t buf_id; + size_t data_offset; + + friend bool operator==(const CommMessageBufOffset &a, const CommMessageBufOffset &b) { + return a.buf_id == b.buf_id && a.data_offset == b.data_offset; + } + friend bool operator!=(const CommMessageBufOffset &a, const CommMessageBufOffset &b) { + return !(a == b); + } + }; + struct CommMessageInfo { size_t message_size; ///< Size of the MPI message i32 rank_sender; ///< Rank of the sender i32 rank_receiver; ///< Rank of the receiver std::optional message_tag; ///< Tag of the MPI message - std::optional + + std::optional message_bytebuf_offset_send; ///< Offset of the MPI message in the send buffer - std::optional + std::optional message_bytebuf_offset_recv; ///< Offset of the MPI message in the recv buffer }; @@ -39,17 +52,19 @@ namespace shamalgs::collective { std::vector messages_recv; ///< Messages to recv std::vector send_message_global_ids; ///< ids of messages_send in message_all std::vector recv_message_global_ids; ///< ids of messages_recv in message_all - size_t send_total_size; ///< Total size of the send buffer - size_t recv_total_size; ///< Total size of the recv buffer + + std::vector send_total_sizes; ///< Total size of the send buffer + std::vector recv_total_sizes; ///< Total size of the recv buffer }; - CommTable build_sparse_exchange_table(const std::vector &messages_send); + CommTable build_sparse_exchange_table( + const std::vector &messages_send, size_t max_alloc_size); template void sparse_exchange( std::shared_ptr dev_sched, - sham::DeviceBuffer &bytebuffer_send, - sham::DeviceBuffer &bytebuffer_recv, + std::vector>> &bytebuffer_send, + std::vector>> &bytebuffer_recv, const CommTable &comm_table); } // namespace shamalgs::collective diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index f50b214c67..4789e36dca 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -53,7 +53,8 @@ namespace shamalgs::collective { std::nullopt}; }; - CommTable build_sparse_exchange_table(const std::vector &messages_send) { + CommTable build_sparse_exchange_table( + const std::vector &messages_send, size_t max_alloc_size) { __shamrock_stack_entry(); //////////////////////////////////////////////////////////// @@ -89,38 +90,67 @@ namespace shamalgs::collective { std::vector message_all(global_data.size()); std::vector tag_map(shamcomm::world_size(), 0); + std::vector send_buf_sizes{0}; + std::vector recv_buf_sizes{0}; u32 send_idx = 0; u32 recv_idx = 0; + { + size_t tmp_recv_offset = 0; + size_t tmp_send_offset = 0; + size_t send_buf_id = 0; + size_t recv_buf_id = 0; + for (u64 i = 0; i < global_data.size(); i++) { + auto message_info = unpack(global_data[i]); + + auto sender = message_info.rank_sender; + auto receiver = message_info.rank_receiver; + + // tagging logic + i32 &tag_map_ref = tag_map[static_cast(sender)]; + i32 tag = tag_map_ref; + tag_map_ref++; + + message_info.message_tag = tag; + + // offset logic (& buffer selection) + if (sender == shamcomm::world_rank()) { + if (message_info.message_size > max_alloc_size) { + throw ""; // TODO + } + + if (tmp_send_offset + message_info.message_size > max_alloc_size) { + send_buf_id++; + tmp_send_offset = 0; + } + + message_info.message_bytebuf_offset_send = {send_buf_id, tmp_send_offset}; + tmp_send_offset += message_info.message_size; + send_buf_sizes.at(send_buf_id) += message_info.message_size; + + send_idx++; + } - size_t recv_offset = 0; - size_t send_offset = 0; - for (u64 i = 0; i < global_data.size(); i++) { - auto message_info = unpack(global_data[i]); - - auto sender = message_info.rank_sender; - auto receiver = message_info.rank_receiver; + if (receiver == shamcomm::world_rank()) { - i32 &tag_map_ref = tag_map[static_cast(sender)]; + if (message_info.message_size > max_alloc_size) { + throw ""; // TODO + } - i32 tag = tag_map_ref; - tag_map_ref++; + if (tmp_recv_offset + message_info.message_size > max_alloc_size) { + recv_buf_id++; + tmp_recv_offset = 0; + } - message_info.message_tag = tag; + message_info.message_bytebuf_offset_recv = {recv_buf_id, tmp_recv_offset}; + tmp_recv_offset += message_info.message_size; + recv_buf_sizes.at(recv_buf_id) += message_info.message_size; - if (sender == shamcomm::world_rank()) { - message_info.message_bytebuf_offset_send = send_offset; - send_offset += message_info.message_size; - send_idx++; - } + recv_idx++; + } - if (receiver == shamcomm::world_rank()) { - message_info.message_bytebuf_offset_recv = recv_offset; - recv_offset += message_info.message_size; - recv_idx++; + message_all[i] = message_info; } - - message_all[i] = message_info; } //////////////////////////////////////////////////////////// @@ -144,13 +174,18 @@ namespace shamalgs::collective { auto expected_offset = shambase::get_check_ref( messages_send.at(send_idx).message_bytebuf_offset_send); + auto actual_offset + = shambase::get_check_ref(message_info.message_bytebuf_offset_send); + // check that the send offset match for good measure - if (message_info.message_bytebuf_offset_send != expected_offset) { + if (actual_offset != expected_offset) { throw shambase::make_except_with_loc(shambase::format( "The sender has not set the offset for all messages, otherwise throw\n" - " expected_offset = {}, actual_offset = {}", - expected_offset, - message_info.message_bytebuf_offset_send.value())); + " expected_offset = ({}, {}), actual_offset = ({}, {})", + expected_offset.buf_id, + expected_offset.data_offset, + actual_offset.buf_id, + actual_offset.data_offset)); } ret_message_send[send_idx] = message_info; @@ -170,14 +205,14 @@ namespace shamalgs::collective { ret_message_recv, send_message_global_ids, recv_message_global_ids, - send_offset, - recv_offset}; + send_buf_sizes, + recv_buf_sizes}; } void sparse_exchange( std::shared_ptr dev_sched, - const u8 *bytebuffer_send, - u8 *bytebuffer_recv, + const std::vector &bytebuffer_send, + const std::vector &bytebuffer_recv, const CommTable &comm_table) { __shamrock_stack_entry(); @@ -190,10 +225,11 @@ namespace shamalgs::collective { auto message_info = comm_table.message_all[i]; if (message_info.rank_sender == shamcomm::world_rank()) { - auto &rq = rqs.new_request(); + auto off_info = shambase::get_check_ref(message_info.message_bytebuf_offset_send); + auto ptr = bytebuffer_send.at(off_info.buf_id) + off_info.data_offset; + auto &rq = rqs.new_request(); shamcomm::mpi::Isend( - bytebuffer_send - + shambase::get_check_ref(message_info.message_bytebuf_offset_send), + ptr, shambase::narrow_or_throw(message_info.message_size), MPI_BYTE, message_info.rank_receiver, @@ -203,10 +239,11 @@ namespace shamalgs::collective { } if (message_info.rank_receiver == shamcomm::world_rank()) { - auto &rq = rqs.new_request(); + auto off_info = shambase::get_check_ref(message_info.message_bytebuf_offset_recv); + auto ptr = bytebuffer_recv.at(off_info.buf_id) + off_info.data_offset; + auto &rq = rqs.new_request(); shamcomm::mpi::Irecv( - bytebuffer_recv - + shambase::get_check_ref(message_info.message_bytebuf_offset_recv), + ptr, shambase::narrow_or_throw(message_info.message_size), MPI_BYTE, message_info.rank_sender, @@ -223,8 +260,8 @@ namespace shamalgs::collective { template void sparse_exchange( std::shared_ptr dev_sched, - sham::DeviceBuffer &bytebuffer_send, - sham::DeviceBuffer &bytebuffer_recv, + std::vector>> &bytebuffer_send, + std::vector>> &bytebuffer_recv, const CommTable &comm_table) { __shamrock_stack_entry(); @@ -235,52 +272,90 @@ namespace shamalgs::collective { "distinct."); } - if (comm_table.send_total_size > bytebuffer_send.get_size()) { + if (comm_table.send_total_sizes.size() != bytebuffer_send.size()) { throw shambase::make_except_with_loc(shambase::format( "The send total size is greater than the send buffer size\n" - " send_total_size = {}, send_buffer_size = {}", - comm_table.send_total_size, - bytebuffer_send.get_size())); + " send_total_sizes = {}, send_buffer_size = {}", + comm_table.send_total_sizes.size(), + bytebuffer_send.size())); } - if (comm_table.recv_total_size > bytebuffer_recv.get_size()) { + if (comm_table.recv_total_sizes.size() != bytebuffer_recv.size()) { throw shambase::make_except_with_loc(shambase::format( "The recv total size is greater than the recv buffer size\n" - " recv_total_size = {}, recv_buffer_size = {}", - comm_table.recv_total_size, - bytebuffer_recv.get_size())); + " recv_total_sizes = {}, recv_buffer_size = {}", + comm_table.recv_total_sizes.size(), + bytebuffer_recv.size())); + } + + for (size_t i = 0; i < comm_table.send_total_sizes.size(); i++) { + if (comm_table.send_total_sizes[i] > bytebuffer_send[i]->get_size()) { + throw shambase::make_except_with_loc(shambase::format( + "The send total size is greater than the send buffer size\n" + " send_total_sizes = {}, send_buffer_size = {}, buf_id = {}", + comm_table.send_total_sizes[i], + bytebuffer_send[i]->get_size(), + i)); + } + } + + for (size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) { + if (comm_table.recv_total_sizes[i] > bytebuffer_recv[i]->get_size()) { + throw shambase::make_except_with_loc(shambase::format( + "The recv total size is greater than the recv buffer size\n" + " recv_total_sizes = {}, recv_buffer_size = {}, buf_id = {}", + comm_table.recv_total_sizes[i], + bytebuffer_recv[i]->get_size(), + i)); + } } bool direct_gpu_capable = dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable; if (!direct_gpu_capable && target == sham::device) { throw shambase::make_except_with_loc( - "You are trying to use a device buffer on the device but the device is not direct " + "You are trying to use a device buffer on the device but the device is not " + "direct " "GPU capable"); } + std::vector send_ptrs(bytebuffer_send.size()); + std::vector recv_ptrs(bytebuffer_recv.size()); + sham::EventList depends_list; - const u8 *send_ptr = bytebuffer_send.get_read_access(depends_list); - u8 *recv_ptr = bytebuffer_recv.get_write_access(depends_list); + for (size_t i = 0; i < bytebuffer_send.size(); i++) { + send_ptrs[i] + = shambase::get_check_ref(bytebuffer_send[i]).get_read_access(depends_list); + } + + for (size_t i = 0; i < bytebuffer_recv.size(); i++) { + recv_ptrs[i] + = shambase::get_check_ref(bytebuffer_recv[i]).get_write_access(depends_list); + } depends_list.wait(); - sparse_exchange(dev_sched, send_ptr, recv_ptr, comm_table); + sparse_exchange(dev_sched, send_ptrs, recv_ptrs, comm_table); - bytebuffer_send.complete_event_state(sycl::event{}); - bytebuffer_recv.complete_event_state(sycl::event{}); + for (size_t i = 0; i < bytebuffer_send.size(); i++) { + shambase::get_check_ref(bytebuffer_send[i]).complete_event_state(sycl::event{}); + } + + for (size_t i = 0; i < bytebuffer_recv.size(); i++) { + shambase::get_check_ref(bytebuffer_recv[i]).complete_event_state(sycl::event{}); + } } // template instantiations template void sparse_exchange( std::shared_ptr dev_sched, - sham::DeviceBuffer &bytebuffer_send, - sham::DeviceBuffer &bytebuffer_recv, + std::vector>> &bytebuffer_send, + std::vector>> &bytebuffer_recv, const CommTable &comm_table); template void sparse_exchange( std::shared_ptr dev_sched, - sham::DeviceBuffer &bytebuffer_send, - sham::DeviceBuffer &bytebuffer_recv, + std::vector>> &bytebuffer_send, + std::vector>> &bytebuffer_recv, const CommTable &comm_table); } // namespace shamalgs::collective diff --git a/src/shambackends/include/shambackends/DeviceBuffer.hpp b/src/shambackends/include/shambackends/DeviceBuffer.hpp index 833f69c554..36acc9c394 100644 --- a/src/shambackends/include/shambackends/DeviceBuffer.hpp +++ b/src/shambackends/include/shambackends/DeviceBuffer.hpp @@ -548,6 +548,64 @@ namespace sham { dest.complete_event_state(e); } + /** + * @brief Copy a range of elements from the buffer to another buffer + * + * This function copies a range of elements from the buffer to another buffer. + * The range is specified by the begin and end indices. + * + * @param begin The starting index of the range to copy, inclusive. + * @param end The ending index of the range to copy, exclusive. + * @param dest The destination buffer to copy to. + */ + template + inline void copy_range_offset( + size_t begin, + size_t end, + sham::DeviceBuffer &dest, + size_t dest_offset) const { + + if (begin > end) { + shambase::throw_with_loc(shambase::format( + "copy_range: begin > end\n begin = {},\n end = {}", begin, end)); + } + + if (dest_offset > dest.get_size()) { + shambase::throw_with_loc(shambase::format( + "copy_range_offset: dest_offset > dest.get_size()\n dest_offset = {},\n " + "dest.get_size() = {}", + dest_offset, + dest.get_size())); + } + + if (end - begin > (dest.get_size() - dest_offset)) { + shambase::throw_with_loc(shambase::format( + "copy_range_offset: end - begin > dest.get_size() - dest_offset\n end - begin " + "= {},\n " + "dest.get_size() - dest_offset = {},\n dest_offset = {}", + end - begin, + dest.get_size() - dest_offset, + dest_offset)); + } + + if (begin == end) { + return; + } + + size_t len = end - begin; + + sham::EventList depends_list; + const T *ptr_src = get_read_access(depends_list) + begin; + T *ptr_dest = dest.get_write_access(depends_list) + dest_offset; + + sycl::event e = get_queue().submit(depends_list, [&](sycl::handler &cgh) { + cgh.copy(ptr_src, ptr_dest, len); + }); + + complete_event_state(e); + dest.complete_event_state(e); + } + /** * @brief Copy the content of a std::vector into the buffer * diff --git a/src/shambackends/include/shambackends/DeviceBufferRef.hpp b/src/shambackends/include/shambackends/DeviceBufferRef.hpp new file mode 100644 index 0000000000..5fd1d3e65a --- /dev/null +++ b/src/shambackends/include/shambackends/DeviceBufferRef.hpp @@ -0,0 +1,26 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2026 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file DeviceBufferRef.hpp + * @author Timothée David--Cléris (tim.shamrock@proton.me) + * @brief + * + */ + +#include "shambackends/DeviceBuffer.hpp" + +namespace sham { + + template + using DeviceBufferRef = std::reference_wrapper>; + +} diff --git a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp index 9b1eb1c1d1..5587a3b866 100644 --- a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp +++ b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp @@ -32,7 +32,7 @@ void reorder_msg(std::vector &test_elements) { }); } -void test_sparse_exchange(std::vector test_elements) { +void test_sparse_exchange(std::vector test_elements, size_t max_alloc_size) { auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); reorder_msg(test_elements); @@ -45,63 +45,115 @@ void test_sparse_exchange(std::vector test_elements) { shamalgs::random::mock_buffer_usm(dev_sched, eng(), test_element.size)); } - sham::DeviceBuffer send_buf(0, dev_sched); std::vector messages_send; - size_t total_recv_size = 0; - size_t total_recv_count = 0; - size_t sender_offset = 0; - size_t sender_count = 0; - for (u32 i = 0; i < test_elements.size(); i++) { - if (test_elements[i].sender == shamcomm::world_rank()) { - messages_send.push_back( - shamalgs::collective::CommMessageInfo{ - test_elements[i].size, + std::vector total_send_sizes = {0}; + std::vector total_recv_sizes = {0}; + { + u32 send_buf_id = 0; + size_t send_offset = 0; + u32 recv_buf_id = 0; + size_t recv_offset = 0; + for (u32 i = 0; i < test_elements.size(); i++) { + if (test_elements[i].sender == shamcomm::world_rank()) { + messages_send.push_back( + shamalgs::collective::CommMessageInfo{ + test_elements[i].size, + test_elements[i].sender, + test_elements[i].receiver, + std::nullopt, + std::nullopt, + std::nullopt, + }); + + logger::info_ln( + "sparse exchange test", + "rank :", + shamcomm::world_rank(), + "send message : (", test_elements[i].sender, + "->", test_elements[i].receiver, - std::nullopt, - sender_offset, - std::nullopt, - }); - - logger::info_ln( - "sparse exchange test", - "rank :", - shamcomm::world_rank(), - "send message : (", - test_elements[i].sender, - "->", - test_elements[i].receiver, - ") data :", - all_bufs[i].copy_to_stdvec()); - - send_buf.append(all_bufs[i]); - sender_offset += test_elements[i].size; - sender_count++; - } - if (test_elements[i].receiver == shamcomm::world_rank()) { - total_recv_size += test_elements[i].size; - total_recv_count++; + ") data :", + all_bufs[i].copy_to_stdvec()); + + if (send_offset + test_elements[i].size > max_alloc_size) { + send_buf_id++; + send_offset = 0; + total_send_sizes.push_back(0); + } + + total_send_sizes.at(send_buf_id) += test_elements[i].size; + } + if (test_elements[i].receiver == shamcomm::world_rank()) { + if (recv_offset + test_elements[i].size > max_alloc_size) { + recv_buf_id++; + recv_offset = 0; + total_recv_sizes.push_back(0); + } + + total_recv_sizes.at(recv_buf_id) += test_elements[i].size; + } } } shamalgs::collective::CommTable comm_table - = shamalgs::collective::build_sparse_exchange_table(messages_send); + = shamalgs::collective::build_sparse_exchange_table(messages_send, max_alloc_size); + + REQUIRE_EQUAL(comm_table.send_total_sizes, total_send_sizes); + REQUIRE_EQUAL(comm_table.recv_total_sizes, total_recv_sizes); + + // allocate send and receive bufs + std::vector>> send_bufs{}; + + for (size_t i = 0; i < comm_table.send_total_sizes.size(); i++) { + send_bufs.push_back( + std::make_unique>(comm_table.send_total_sizes[i], dev_sched)); + } + + std::vector>> recv_bufs{}; + + for (size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) { + recv_bufs.push_back( + std::make_unique>(comm_table.recv_total_sizes[i], dev_sched)); + } + + // push data to the comm buf + for (size_t i = 0; i < comm_table.messages_send.size(); i++) { + auto msg_info = comm_table.messages_send[i]; + size_t global_msg_id = comm_table.send_message_global_ids[i]; - REQUIRE_EQUAL(comm_table.send_total_size, sender_offset); - REQUIRE_EQUAL(comm_table.recv_total_size, total_recv_size); + auto off_info = shambase::get_check_ref(msg_info.message_bytebuf_offset_send); - // allocate recv buffer - sham::DeviceBuffer recv_buf(comm_table.recv_total_size, dev_sched); + auto &source = all_bufs.at(global_msg_id); + auto &dest = shambase::get_check_ref(send_bufs.at(off_info.buf_id)); + + source.copy_range_offset(0, source.get_size(), dest, off_info.data_offset); + } // do the comm if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { - shamalgs::collective::sparse_exchange(dev_sched, send_buf, recv_buf, comm_table); + shamalgs::collective::sparse_exchange(dev_sched, send_bufs, recv_bufs, comm_table); } else { - auto send_buf_host = send_buf.copy_to(); - auto recv_buf_host = recv_buf.copy_to(); - shamalgs::collective::sparse_exchange(dev_sched, send_buf_host, recv_buf_host, comm_table); - recv_buf.copy_from(recv_buf_host); + std::vector>> send_bufs_host{}; + std::vector>> recv_bufs_host{}; + + for (size_t i = 0; i < comm_table.send_total_sizes.size(); i++) { + send_bufs_host.push_back( + std::make_unique>( + send_bufs[i]->copy_to())); + } + for (size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) { + recv_bufs_host.push_back( + std::make_unique>( + comm_table.recv_total_sizes[i], dev_sched)); + } + + shamalgs::collective::sparse_exchange( + dev_sched, send_bufs_host, recv_bufs_host, comm_table); + for (size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) { + recv_bufs[i]->copy_from(*recv_bufs_host[i]); + } } // time to check @@ -129,10 +181,11 @@ void test_sparse_exchange(std::vector test_elements) { auto &ref_buf = all_bufs[i]; sham::DeviceBuffer recov(test_elements[i].size, dev_sched); - size_t begin = shambase::get_check_ref( + auto off_info = shambase::get_check_ref( comm_table.messages_recv[recv_msg_idx].message_bytebuf_offset_recv); - size_t end = begin + test_elements[i].size; - recv_buf.copy_range(begin, end, recov); + size_t begin = off_info.data_offset; + size_t end = begin + test_elements[i].size; + shambase::get_check_ref(recv_bufs.at(off_info.buf_id)).copy_range(begin, end, recov); logger::info_ln( "sparse exchange test", @@ -161,7 +214,7 @@ TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2 logger::info_ln("sparse exchange test", "empty comm"); } - test_sparse_exchange({}); + test_sparse_exchange({}, i32_max); if (shamcomm::world_rank() == 0) { logger::info_ln("sparse exchange test", "send to self"); @@ -175,7 +228,7 @@ TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2 test_elements.push_back( TestElement{i, i, shamalgs::primitives::mock_value(eng, 1, 10)}); } - test_sparse_exchange(test_elements); + test_sparse_exchange(test_elements, i32_max); } if (shamcomm::world_rank() == 0) { @@ -193,7 +246,7 @@ TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2 (i + 1) % shamcomm::world_size(), shamalgs::primitives::mock_value(eng, 1, 10)}); } - test_sparse_exchange(test_elements); + test_sparse_exchange(test_elements, i32_max); } if (shamcomm::world_rank() == 0) { @@ -211,6 +264,24 @@ TestStart(Unittest, "shamalgs/collective/test_sparse_exchange", testsparsexchg_2 shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), shamalgs::primitives::mock_value(eng, 1, 10)}); } - test_sparse_exchange(test_elements); + test_sparse_exchange(test_elements, i32_max); + } + + if (shamcomm::world_rank() == 0) { + logger::info_ln("sparse exchange test", "random test (force multiple bufs)"); + } + + { + // random test + std::mt19937 eng(0x123); + std::vector test_elements; + for (u32 i = 0; i < 3 * shamcomm::world_size(); i++) { + test_elements.push_back( + TestElement{ + shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), + shamalgs::primitives::mock_value(eng, 0, shamcomm::world_size() - 1), + shamalgs::primitives::mock_value(eng, 1, 10)}); + } + test_sparse_exchange(test_elements, 20); } } From 80ad01c04e95ff64061c0c949f0a074b85283286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 18:01:23 +0100 Subject: [PATCH 20/54] merge working changes --- .../src/collective/sparse_exchange.cpp | 102 +++++++++++------- .../src/details/internal_alloc.cpp | 2 - 2 files changed, 61 insertions(+), 43 deletions(-) diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index 4789e36dca..be37674d97 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -22,11 +22,11 @@ #include "shamalgs/collective/RequestList.hpp" #include "shamalgs/collective/exchanges.hpp" #include "shambackends/USMPtrHolder.hpp" +#include "shambackends/fmt_bindings/fmt_defs.hpp" #include "shambackends/math.hpp" #include "shamcomm/mpi.hpp" #include "shamcomm/worldInfo.hpp" #include - namespace shamalgs::collective { CommMessageInfo unpack(u64_2 comm_info) { @@ -53,13 +53,9 @@ namespace shamalgs::collective { std::nullopt}; }; - CommTable build_sparse_exchange_table( - const std::vector &messages_send, size_t max_alloc_size) { - __shamrock_stack_entry(); - - //////////////////////////////////////////////////////////// - // Pack the local data then allgatherv to get the global data - //////////////////////////////////////////////////////////// + /// fetch u64_2 from global message data + std::vector fetch_global_message_data( + const std::vector &messages_send) { std::vector local_data = std::vector(messages_send.size()); @@ -83,13 +79,51 @@ namespace shamalgs::collective { std::vector global_data; vector_allgatherv(local_data, global_data, MPI_COMM_WORLD); - //////////////////////////////////////////////////////////// - // Unpack the global data and build the global message list - //////////////////////////////////////////////////////////// + return global_data; // there should be return value optimisation here + } + /// decode message to get message + std::vector decode_all_message(const std::vector &global_data) { std::vector message_all(global_data.size()); + for (u64 i = 0; i < global_data.size(); i++) { + message_all[i] = unpack(global_data[i]); + } + + return message_all; + } + + /// compute message tags + void compute_tags(std::vector &message_all) { std::vector tag_map(shamcomm::world_size(), 0); + + for (u64 i = 0; i < message_all.size(); i++) { + auto &message_info = message_all[i]; + auto sender = message_info.rank_sender; + + // tagging logic + i32 &tag_map_ref = tag_map[static_cast(sender)]; + i32 tag = tag_map_ref; + tag_map_ref++; + + message_info.message_tag = tag; + } + } + + CommTable build_sparse_exchange_table( + const std::vector &messages_send, size_t max_alloc_size) { + __shamrock_stack_entry(); + + std::vector global_data = fetch_global_message_data(messages_send); + + std::vector message_all = decode_all_message(global_data); + + compute_tags(message_all); + + //////////////////////////////////////////////////////////// + // Compute offsets + //////////////////////////////////////////////////////////// + std::vector send_buf_sizes{0}; std::vector recv_buf_sizes{0}; @@ -100,28 +134,27 @@ namespace shamalgs::collective { size_t tmp_send_offset = 0; size_t send_buf_id = 0; size_t recv_buf_id = 0; - for (u64 i = 0; i < global_data.size(); i++) { - auto message_info = unpack(global_data[i]); + for (u64 i = 0; i < message_all.size(); i++) { + auto &message_info = message_all[i]; auto sender = message_info.rank_sender; auto receiver = message_info.rank_receiver; - // tagging logic - i32 &tag_map_ref = tag_map[static_cast(sender)]; - i32 tag = tag_map_ref; - tag_map_ref++; - - message_info.message_tag = tag; - // offset logic (& buffer selection) if (sender == shamcomm::world_rank()) { if (message_info.message_size > max_alloc_size) { - throw ""; // TODO + throw shambase::make_except_with_loc( + shambase::format( + "Message size is greater than the max alloc size\n" + " message_size = {}, max_alloc_size = {}", + message_info.message_size, + max_alloc_size)); } if (tmp_send_offset + message_info.message_size > max_alloc_size) { send_buf_id++; tmp_send_offset = 0; + send_buf_sizes.push_back(0); } message_info.message_bytebuf_offset_send = {send_buf_id, tmp_send_offset}; @@ -134,12 +167,18 @@ namespace shamalgs::collective { if (receiver == shamcomm::world_rank()) { if (message_info.message_size > max_alloc_size) { - throw ""; // TODO + throw shambase::make_except_with_loc( + shambase::format( + "Message size is greater than the max alloc size\n" + " message_size = {}, max_alloc_size = {}", + message_info.message_size, + max_alloc_size)); } if (tmp_recv_offset + message_info.message_size > max_alloc_size) { recv_buf_id++; tmp_recv_offset = 0; + recv_buf_sizes.push_back(0); } message_info.message_bytebuf_offset_recv = {recv_buf_id, tmp_recv_offset}; @@ -169,25 +208,6 @@ namespace shamalgs::collective { for (size_t i = 0; i < message_all.size(); i++) { auto message_info = message_all[i]; if (message_info.rank_sender == shamcomm::world_rank()) { - - // the sender should have set the offset for all messages, otherwise throw - auto expected_offset = shambase::get_check_ref( - messages_send.at(send_idx).message_bytebuf_offset_send); - - auto actual_offset - = shambase::get_check_ref(message_info.message_bytebuf_offset_send); - - // check that the send offset match for good measure - if (actual_offset != expected_offset) { - throw shambase::make_except_with_loc(shambase::format( - "The sender has not set the offset for all messages, otherwise throw\n" - " expected_offset = ({}, {}), actual_offset = ({}, {})", - expected_offset.buf_id, - expected_offset.data_offset, - actual_offset.buf_id, - actual_offset.data_offset)); - } - ret_message_send[send_idx] = message_info; send_message_global_ids[send_idx] = i; send_idx++; diff --git a/src/shambackends/src/details/internal_alloc.cpp b/src/shambackends/src/details/internal_alloc.cpp index 7306a140fe..904aa77917 100644 --- a/src/shambackends/src/details/internal_alloc.cpp +++ b/src/shambackends/src/details/internal_alloc.cpp @@ -356,8 +356,6 @@ namespace sham::details { register_alloc_shared(sz, end_time - start_time); } else if constexpr (target == host) { register_alloc_host(sz, end_time - start_time); - // logger::info_ln("internal_alloc", "alloc host : sz =", sz, " | time =", end_time - - // start_time); } return usm_ptr; From e503c0e71888dbb8fef9033c1bf083f89332b1f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 18:01:53 +0100 Subject: [PATCH 21/54] merge working changes --- .../include/shamrock/scheduler/ReattributeDataUtility.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp b/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp index 9ac7dd1e51..56c5e4b6a3 100644 --- a/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp +++ b/src/shamrock/include/shamrock/scheduler/ReattributeDataUtility.hpp @@ -16,6 +16,7 @@ */ #include "shambase/string.hpp" +#include "shamalgs/collective/distributedDataComm.hpp" #include "shamalgs/memory.hpp" #include "shambackends/comm/details/CommunicationBufferImpl.hpp" #include "shamrock/patch/PatchDataLayer.hpp" From e65c34d84ecd1f2f6623306b33e751c55b20be81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 18:33:34 +0100 Subject: [PATCH 22/54] it compiles --- .../collective/distributedDataComm.hpp | 89 ++++++++++++++++++- .../src/collective/distributedDataComm.cpp | 87 ++++++++---------- 2 files changed, 126 insertions(+), 50 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp index 3ac373b8a6..41c0fed953 100644 --- a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp +++ b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp @@ -34,9 +34,94 @@ namespace shamalgs::collective { using SerializedDDataComm = shambase::DistributedDataShared>; + template + struct DDSCommCacheTarget { + std::vector>> cache1; + std::vector>> cache2; + + void set_sizes( + sham::DeviceScheduler_ptr dev_sched, + std::vector sizes_cache1, + std::vector sizes_cache2) { + // ensure correct length + cache1.resize(sizes_cache1.size()); + cache2.resize(sizes_cache2.size()); + + // if size is different, resize + for (size_t i = 0; i < sizes_cache1.size(); i++) { + if (cache1[i]) { + cache1[i]->resize(sizes_cache1[i]); + } else { + cache1[i] = std::make_unique>( + sizes_cache1[i], dev_sched); + } + } + for (size_t i = 0; i < sizes_cache2.size(); i++) { + if (cache2[i]) { + cache2[i]->resize(sizes_cache2[i]); + } else { + cache2[i] = std::make_unique>( + sizes_cache2[i], dev_sched); + } + } + } + + inline void write_buf_at(size_t buf_id, size_t offset, const sham::DeviceBuffer &buf) { + buf.copy_range_offset( + 0, buf.get_size(), shambase::get_check_ref(cache1[buf_id]), offset); + } + + inline void read_buf_at( + size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer &buf) { + buf.resize(size); + shambase::get_check_ref(cache1[buf_id]) + .copy_range(offset, offset + size, buf); + } + }; + struct DDSCommCache { - std::unique_ptr> cache1; - std::unique_ptr> cache2; + std::variant, DDSCommCacheTarget> cache; + + template + std::vector>> & get_cache1() { + return shambase::get_check_ref(std::get_if>(&cache)).cache1; + } + + template + std::vector>> & get_cache2() { + return shambase::get_check_ref(std::get_if>(&cache)).cache2; + } + + template + void set_sizes( + sham::DeviceScheduler_ptr dev_sched, + std::vector sizes_cache1, + std::vector sizes_cache2) { + // init if not there + if (std::get_if>(&cache) == nullptr) { + cache = DDSCommCacheTarget{}; + } + + std::get>(cache).set_sizes( + dev_sched, sizes_cache1, sizes_cache2); + } + + inline void write_buf_at(size_t buf_id, size_t offset, sham::DeviceBuffer &buf) { + std::visit( + [&](auto &cache) { + cache.write_buf_at(buf_id, offset, buf); + }, + cache); + } + + inline void read_buf_at( + size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer &buf) { + std::visit( + [&](auto &cache) { + cache.read_buf_at(buf_id, offset, size, buf); + }, + cache); + } }; void distributed_data_sparse_comm( diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index ef1c9c9332..da956efce6 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -108,10 +108,9 @@ namespace shamalgs::collective { } } - sham::DeviceBuffer send_buf(0, dev_sched); std::vector messages_send; + std::vector>> data_send; - size_t sender_offset = 0; for (auto &[key, buf] : send_bufs) { auto [sender, receiver] = key; @@ -123,61 +122,53 @@ namespace shamalgs::collective { sender, receiver, std::nullopt, - sender_offset, + std::nullopt, std::nullopt, }); - send_buf.append(*buf); - sender_offset += size; + data_send.push_back(std::move(buf)); + } + + size_t max_alloc_size; + if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { + max_alloc_size = dev_sched->ctx->device->prop.max_mem_alloc_size_dev; + } else { + max_alloc_size = dev_sched->ctx->device->prop.max_mem_alloc_size_host; } shamalgs::collective::CommTable comm_table2 - = shamalgs::collective::build_sparse_exchange_table(messages_send); + = shamalgs::collective::build_sparse_exchange_table(messages_send, max_alloc_size); + + if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { + cache.set_sizes( + dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes); + } else { + cache.set_sizes( + dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes); + } + + for (size_t i = 0; i < comm_table2.messages_send.size(); i++) { + auto &msg_info = comm_table2.messages_send[i]; + auto offset_info = shambase::get_check_ref(msg_info.message_bytebuf_offset_send); + auto &buf_src = shambase::get_check_ref(data_send.at(i)); - sham::DeviceBuffer recv_buf(comm_table2.recv_total_size, dev_sched); + SHAM_ASSERT(buf_src.get_size() == msg_info.message_size); + + cache.write_buf_at(offset_info.buf_id, offset_info.data_offset, buf_src); + } if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { shamalgs::collective::sparse_exchange( - dev_sched, send_buf, recv_buf, comm_table2); + dev_sched, + cache.get_cache1(), + cache.get_cache2(), + comm_table2); } else { - if (!cache.cache1) { - // logger::info_ln("ddcomm", "alloc cache1", shambase::fmt_callstack()); - cache.cache1 = std::make_unique>( - send_buf.get_size(), dev_sched); - } else { - // logger::info_ln( - // "ddcomm", - // shambase::format( - // "resize cache1 from {} to {} (cache1 ptr = {})", - // cache.cache1->get_size(), - // send_buf.get_size(), - // static_cast(cache.cache1.get())), - // shambase::fmt_callstack()); - cache.cache1->resize(send_buf.get_size()); - } - cache.cache1->copy_from(send_buf); - sham::DeviceBuffer &send_buf_host = *cache.cache1; - - if (!cache.cache2) { - // logger::info_ln("ddcomm", "alloc cache2", shambase::fmt_callstack()); - cache.cache2 = std::make_unique>( - comm_table2.recv_total_size, dev_sched); - } else { - // logger::info_ln( - // "ddcomm", - // shambase::format( - // "resize cache2 from {} to {} (cache2 ptr = {})", - // cache.cache2->get_size(), - // comm_table2.recv_total_size, - // static_cast(cache.cache2.get())), - // shambase::fmt_callstack()); - cache.cache2->resize(comm_table2.recv_total_size); - } - sham::DeviceBuffer &recv_buf_host = *cache.cache2; - shamalgs::collective::sparse_exchange( - dev_sched, send_buf_host, recv_buf_host, comm_table2); - recv_buf.copy_from(recv_buf_host); + dev_sched, + cache.get_cache1(), + cache.get_cache2(), + comm_table2); } #ifdef false @@ -234,12 +225,12 @@ namespace shamalgs::collective { u64 size = msg.message_size; i32 sender = msg.rank_sender; i32 receiver = msg.rank_receiver; - size_t begin = shambase::get_check_ref(msg.message_bytebuf_offset_recv); - size_t end = begin + size; + + auto offset_info = shambase::get_check_ref(msg.message_bytebuf_offset_recv); sham::DeviceBuffer recov(size, dev_sched); + cache.read_buf_at(offset_info.buf_id, offset_info.data_offset, size, recov); - recv_buf.copy_range(begin, end, recov); recv_payload_bufs.push_back( RecvPayloadSer{sender, SerializeHelper(dev_sched, std::move(recov))}); } From f822afcb89bd7552c2dbac4afc842ef4926883ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 18:33:42 +0100 Subject: [PATCH 23/54] it compiles --- .../include/shamalgs/collective/distributedDataComm.hpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp index 41c0fed953..96ca07e065 100644 --- a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp +++ b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp @@ -74,8 +74,7 @@ namespace shamalgs::collective { inline void read_buf_at( size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer &buf) { buf.resize(size); - shambase::get_check_ref(cache1[buf_id]) - .copy_range(offset, offset + size, buf); + shambase::get_check_ref(cache1[buf_id]).copy_range(offset, offset + size, buf); } }; @@ -83,14 +82,14 @@ namespace shamalgs::collective { std::variant, DDSCommCacheTarget> cache; template - std::vector>> & get_cache1() { + std::vector>> &get_cache1() { return shambase::get_check_ref(std::get_if>(&cache)).cache1; } template - std::vector>> & get_cache2() { + std::vector>> &get_cache2() { return shambase::get_check_ref(std::get_if>(&cache)).cache2; - } + } template void set_sizes( From dfaf17eef093cec415dd0b0b44ed365e0c30cb94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 19:10:10 +0100 Subject: [PATCH 24/54] fixed ? --- .../collective/distributedDataComm.hpp | 44 ++++++++++++++++--- .../src/collective/distributedDataComm.cpp | 44 +++++++++++++++++-- .../collective/sparse_exchange_tests.cpp | 2 + 3 files changed, 80 insertions(+), 10 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp index 96ca07e065..5efd93a5bb 100644 --- a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp +++ b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp @@ -66,16 +66,29 @@ namespace shamalgs::collective { } } - inline void write_buf_at(size_t buf_id, size_t offset, const sham::DeviceBuffer &buf) { + inline void send_cache_write_buf_at( + size_t buf_id, size_t offset, const sham::DeviceBuffer &buf) { buf.copy_range_offset( 0, buf.get_size(), shambase::get_check_ref(cache1[buf_id]), offset); } - inline void read_buf_at( + inline void send_cache_read_buf_at( size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer &buf) { buf.resize(size); shambase::get_check_ref(cache1[buf_id]).copy_range(offset, offset + size, buf); } + + inline void recv_cache_write_buf_at( + size_t buf_id, size_t offset, const sham::DeviceBuffer &buf) { + buf.copy_range_offset( + 0, buf.get_size(), shambase::get_check_ref(cache2[buf_id]), offset); + } + + inline void recv_cache_read_buf_at( + size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer &buf) { + buf.resize(size); + shambase::get_check_ref(cache2[buf_id]).copy_range(offset, offset + size, buf); + } }; struct DDSCommCache { @@ -105,19 +118,38 @@ namespace shamalgs::collective { dev_sched, sizes_cache1, sizes_cache2); } - inline void write_buf_at(size_t buf_id, size_t offset, sham::DeviceBuffer &buf) { + inline void send_cache_write_buf_at( + size_t buf_id, size_t offset, sham::DeviceBuffer &buf) { + std::visit( + [&](auto &cache) { + cache.send_cache_write_buf_at(buf_id, offset, buf); + }, + cache); + } + + inline void send_cache_read_buf_at( + size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer &buf) { + std::visit( + [&](auto &cache) { + cache.send_cache_read_buf_at(buf_id, offset, size, buf); + }, + cache); + } + + inline void recv_cache_write_buf_at( + size_t buf_id, size_t offset, sham::DeviceBuffer &buf) { std::visit( [&](auto &cache) { - cache.write_buf_at(buf_id, offset, buf); + cache.recv_cache_write_buf_at(buf_id, offset, buf); }, cache); } - inline void read_buf_at( + inline void recv_cache_read_buf_at( size_t buf_id, size_t offset, size_t size, sham::DeviceBuffer &buf) { std::visit( [&](auto &cache) { - cache.read_buf_at(buf_id, offset, size, buf); + cache.recv_cache_read_buf_at(buf_id, offset, size, buf); }, cache); } diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index da956efce6..55d570df1c 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -147,6 +147,9 @@ namespace shamalgs::collective { dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes); } + SHAM_ASSERT(comm_table2.send_total_sizes.size() == data_send.size()); + SHAM_ASSERT(comm_table2.send_total_sizes.size() == messages_send.size()); + for (size_t i = 0; i < comm_table2.messages_send.size(); i++) { auto &msg_info = comm_table2.messages_send[i]; auto offset_info = shambase::get_check_ref(msg_info.message_bytebuf_offset_send); @@ -154,7 +157,18 @@ namespace shamalgs::collective { SHAM_ASSERT(buf_src.get_size() == msg_info.message_size); - cache.write_buf_at(offset_info.buf_id, offset_info.data_offset, buf_src); + cache.send_cache_write_buf_at(offset_info.buf_id, offset_info.data_offset, buf_src); + // logger::info_ln( + // "distributed data sparse comm", + // "rank :", + // shamcomm::world_rank(), + // "sender :", msg_info.rank_sender, + // "receiver :", msg_info.rank_receiver, + // "write buf at :", + // offset_info.buf_id, + // "offset :", + // offset_info.data_offset, + // "size :", buf_src.get_size()); } if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { @@ -229,7 +243,17 @@ namespace shamalgs::collective { auto offset_info = shambase::get_check_ref(msg.message_bytebuf_offset_recv); sham::DeviceBuffer recov(size, dev_sched); - cache.read_buf_at(offset_info.buf_id, offset_info.data_offset, size, recov); + cache.recv_cache_read_buf_at(offset_info.buf_id, offset_info.data_offset, size, recov); + + // logger::info_ln( + // "distributed data sparse comm", + // "rank :", + // shamcomm::world_rank(), + // "sender :", sender, + // "receiver :", receiver, + // "read buf at :", + // offset_info.buf_id, + // "offset :", offset_info.data_offset, "size :", size); recv_payload_bufs.push_back( RecvPayloadSer{sender, SerializeHelper(dev_sched, std::move(recov))}); @@ -248,12 +272,24 @@ namespace shamalgs::collective { recv.ser.load(receiver); recv.ser.load(length); + // logger::info_ln( + // "distributed data sparse comm", + // "rank :", + // shamcomm::world_rank(), + // "load obj :", + // sender, + // "->", + // receiver, + // "size :", length); + { // check correctness ranks i32 supposed_sender_rank = rank_getter(sender); i32 real_sender_rank = recv.sender_ranks; if (supposed_sender_rank != real_sender_rank) { - throw make_except_with_loc( - "the rank do not matches"); + throw make_except_with_loc(shambase::format( + "the rank do not matches {} != {}", + supposed_sender_rank, + real_sender_rank)); } } diff --git a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp index 5587a3b866..4cc47b0039 100644 --- a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp +++ b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp @@ -84,6 +84,7 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all } total_send_sizes.at(send_buf_id) += test_elements[i].size; + send_offset += test_elements[i].size; } if (test_elements[i].receiver == shamcomm::world_rank()) { if (recv_offset + test_elements[i].size > max_alloc_size) { @@ -93,6 +94,7 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all } total_recv_sizes.at(recv_buf_id) += test_elements[i].size; + recv_offset += test_elements[i].size; } } } From 99282cfba22be75fa3133db21e085922a7e48ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 20:38:29 +0100 Subject: [PATCH 25/54] fixed ? --- .../collective/sparse_exchange_tests.cpp | 264 +++++++++++++----- 1 file changed, 200 insertions(+), 64 deletions(-) diff --git a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp index 4cc47b0039..65610591dc 100644 --- a/src/tests/shamalgs/collective/sparse_exchange_tests.cpp +++ b/src/tests/shamalgs/collective/sparse_exchange_tests.cpp @@ -14,6 +14,7 @@ #include "shamcomm/logs.hpp" #include "shamsys/NodeInstance.hpp" #include "shamtest/shamtest.hpp" +#include #include namespace { @@ -32,27 +33,20 @@ void reorder_msg(std::vector &test_elements) { }); } -void test_sparse_exchange(std::vector test_elements, size_t max_alloc_size) { - auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); - - reorder_msg(test_elements); - - std::vector> all_bufs; - - std::mt19937 eng(0x123); - for (const auto &test_element : test_elements) { - all_bufs.push_back( - shamalgs::random::mock_buffer_usm(dev_sched, eng(), test_element.size)); - } +#if false +void validate_comm_table( + const std::vector &test_elements, + const shamalgs::collective::CommTable &comm_table, + size_t max_alloc_size) { std::vector messages_send; std::vector total_send_sizes = {0}; std::vector total_recv_sizes = {0}; - { + shamalgs::collective::sequentialize([&]() { u32 send_buf_id = 0; - size_t send_offset = 0; u32 recv_buf_id = 0; + size_t send_offset = 0; size_t recv_offset = 0; for (u32 i = 0; i < test_elements.size(); i++) { if (test_elements[i].sender == shamcomm::world_rank()) { @@ -74,8 +68,7 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all test_elements[i].sender, "->", test_elements[i].receiver, - ") data :", - all_bufs[i].copy_to_stdvec()); + ")"); if (send_offset + test_elements[i].size > max_alloc_size) { send_buf_id++; @@ -84,7 +77,6 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all } total_send_sizes.at(send_buf_id) += test_elements[i].size; - send_offset += test_elements[i].size; } if (test_elements[i].receiver == shamcomm::world_rank()) { if (recv_offset + test_elements[i].size > max_alloc_size) { @@ -94,18 +86,156 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all } total_recv_sizes.at(recv_buf_id) += test_elements[i].size; - recv_offset += test_elements[i].size; } } + }); + + REQUIRE_EQUAL(comm_table.send_total_sizes, total_send_sizes); + REQUIRE_EQUAL(comm_table.recv_total_sizes, total_recv_sizes); + + shamalgs::collective::sequentialize([&]() { + size_t send_msg_idx = 0; + size_t recv_msg_idx = 0; + for (u32 i = 0; i < test_elements.size(); i++) { + if (test_elements[i].sender == shamcomm::world_rank()) { + REQUIRE_EQUAL( + comm_table.messages_send[send_msg_idx].message_size, test_elements[i].size); + REQUIRE_EQUAL( + comm_table.messages_send[send_msg_idx].rank_sender, test_elements[i].sender); + REQUIRE_EQUAL( + comm_table.messages_send[send_msg_idx].rank_receiver, + test_elements[i].receiver); + + send_msg_idx++; + } + if (test_elements[i].receiver == shamcomm::world_rank()) { + REQUIRE_EQUAL( + comm_table.messages_recv[recv_msg_idx].message_size, test_elements[i].size); + REQUIRE_EQUAL( + comm_table.messages_recv[recv_msg_idx].rank_sender, test_elements[i].sender); + REQUIRE_EQUAL( + comm_table.messages_recv[recv_msg_idx].rank_receiver, + test_elements[i].receiver); + + auto &ref_buf = all_bufs[i]; + sham::DeviceBuffer recov(test_elements[i].size, dev_sched); + auto off_info = shambase::get_check_ref( + comm_table.messages_recv[recv_msg_idx].message_bytebuf_offset_recv); + size_t begin = off_info.data_offset; + size_t end = begin + test_elements[i].size; + shambase::get_check_ref(recv_bufs.at(off_info.buf_id)) + .copy_range(begin, end, recov); + + logger::info_ln( + "sparse exchange test", + "rank :", + shamcomm::world_rank(), + "recv message : (", + test_elements[i].sender, + "->", + test_elements[i].receiver, + ") data :", + recov.copy_to_stdvec()); + + REQUIRE_EQUAL(recov.copy_to_stdvec(), ref_buf.copy_to_stdvec()); + + recv_msg_idx++; + } + REQUIRE_EQUAL(comm_table.message_all[i].message_size, test_elements[i].size); + REQUIRE_EQUAL(comm_table.message_all[i].rank_sender, test_elements[i].sender); + REQUIRE_EQUAL(comm_table.message_all[i].rank_receiver, test_elements[i].receiver); + } + }); +} + +template<> +struct fmt::formatter { + + template + constexpr auto parse(ParseContext &ctx) { + return ctx.begin(); + } + + template + auto format(shamalgs::collective::CommMessageBufOffset c, FormatContext &ctx) const { + return fmt::format_to( + ctx.out(), "Offset(buf_id : {}, data_offset : {})", c.buf_id, c.data_offset); + } +}; + +template<> +struct fmt::formatter { + + template + constexpr auto parse(ParseContext &ctx) { + return ctx.begin(); + } + + template + auto format(shamalgs::collective::CommMessageInfo c, FormatContext &ctx) const { + return fmt::format_to( + ctx.out(), + "Info(size : {}, sender : {}, receiver : {}, offset_send : {}, offset_recv : {})", + c.message_size, + c.rank_sender, + c.rank_receiver, + c.message_bytebuf_offset_send, + c.message_bytebuf_offset_recv); + } +}; + +void print_comm_table(const shamalgs::collective::CommTable &comm_table) { + std::stringstream ss; + ss << shambase::format( + "messages_send : [\n {}\n]\n", fmt::join(comm_table.messages_send, "\n ")); + ss << shambase::format( + "messages_recv : [\n {}\n]\n", fmt::join(comm_table.messages_recv, "\n ")); + ss << shambase::format( + "message_all : [\n {}\n]\n", fmt::join(comm_table.message_all, "\n ")); + ss << shambase::format("send_message_global_ids : {}\n", comm_table.send_message_global_ids); + ss << shambase::format("recv_message_global_ids : {}\n", comm_table.recv_message_global_ids); + ss << shambase::format("send_total_sizes : {}\n", comm_table.send_total_sizes); + ss << shambase::format("recv_total_sizes : {}\n", comm_table.recv_total_sizes); + logger::info_ln( + "sparse exchange test", "rank :", shamcomm::world_rank(), "comm table :", "\n" + ss.str()); +} +#endif + +void test_sparse_exchange(std::vector test_elements, size_t max_alloc_size) { + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + reorder_msg(test_elements); + + std::vector> all_bufs; + + std::mt19937 eng(0x123); + for (const auto &test_element : test_elements) { + all_bufs.push_back( + shamalgs::random::mock_buffer_usm(dev_sched, eng(), test_element.size)); + } + + std::vector messages_send; + + for (u32 i = 0; i < test_elements.size(); i++) { + if (test_elements[i].sender == shamcomm::world_rank()) { + messages_send.push_back( + shamalgs::collective::CommMessageInfo{ + test_elements[i].size, + test_elements[i].sender, + test_elements[i].receiver, + std::nullopt, + std::nullopt, + std::nullopt, + }); + } } shamalgs::collective::CommTable comm_table = shamalgs::collective::build_sparse_exchange_table(messages_send, max_alloc_size); - REQUIRE_EQUAL(comm_table.send_total_sizes, total_send_sizes); - REQUIRE_EQUAL(comm_table.recv_total_sizes, total_recv_sizes); + // print_comm_table(comm_table); - // allocate send and receive bufs + // allocate send bufs std::vector>> send_bufs{}; for (size_t i = 0; i < comm_table.send_total_sizes.size(); i++) { @@ -113,13 +243,6 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all std::make_unique>(comm_table.send_total_sizes[i], dev_sched)); } - std::vector>> recv_bufs{}; - - for (size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) { - recv_bufs.push_back( - std::make_unique>(comm_table.recv_total_sizes[i], dev_sched)); - } - // push data to the comm buf for (size_t i = 0; i < comm_table.messages_send.size(); i++) { auto msg_info = comm_table.messages_send[i]; @@ -133,6 +256,14 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all source.copy_range_offset(0, source.get_size(), dest, off_info.data_offset); } + // allocate recv bufs + std::vector>> recv_bufs{}; + + for (size_t i = 0; i < comm_table.recv_total_sizes.size(); i++) { + recv_bufs.push_back( + std::make_unique>(comm_table.recv_total_sizes[i], dev_sched)); + } + // do the comm if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { shamalgs::collective::sparse_exchange(dev_sched, send_bufs, recv_bufs, comm_table); @@ -158,37 +289,44 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all } } + { + std::stringstream ss; + ss << "send bufs :\n"; + for (size_t i = 0; i < send_bufs.size(); i++) { + ss << "buf " << i << " : " << shambase::format("{}", send_bufs[i]->copy_to_stdvec()) + << "\n"; + } + ss << "recv bufs :\n"; + for (size_t i = 0; i < recv_bufs.size(); i++) { + ss << "buf " << i << " : " << shambase::format("{}", recv_bufs[i]->copy_to_stdvec()) + << "\n"; + } + logger::info_ln("sparse exchange test", "rank :", shamcomm::world_rank(), ss.str()); + } + // time to check + std::vector> recv_messages; - size_t send_msg_idx = 0; - size_t recv_msg_idx = 0; - for (u32 i = 0; i < test_elements.size(); i++) { - if (test_elements[i].sender == shamcomm::world_rank()) { - REQUIRE_EQUAL( - comm_table.messages_send[send_msg_idx].message_size, test_elements[i].size); - REQUIRE_EQUAL( - comm_table.messages_send[send_msg_idx].rank_sender, test_elements[i].sender); - REQUIRE_EQUAL( - comm_table.messages_send[send_msg_idx].rank_receiver, test_elements[i].receiver); - - send_msg_idx++; - } - if (test_elements[i].receiver == shamcomm::world_rank()) { - REQUIRE_EQUAL( - comm_table.messages_recv[recv_msg_idx].message_size, test_elements[i].size); - REQUIRE_EQUAL( - comm_table.messages_recv[recv_msg_idx].rank_sender, test_elements[i].sender); - REQUIRE_EQUAL( - comm_table.messages_recv[recv_msg_idx].rank_receiver, test_elements[i].receiver); - - auto &ref_buf = all_bufs[i]; - sham::DeviceBuffer recov(test_elements[i].size, dev_sched); - auto off_info = shambase::get_check_ref( - comm_table.messages_recv[recv_msg_idx].message_bytebuf_offset_recv); - size_t begin = off_info.data_offset; - size_t end = begin + test_elements[i].size; - shambase::get_check_ref(recv_bufs.at(off_info.buf_id)).copy_range(begin, end, recov); + for (size_t i = 0; i < comm_table.messages_recv.size(); i++) { + auto msg_info = comm_table.messages_recv[i]; + size_t global_msg_id = comm_table.recv_message_global_ids[i]; + + auto off_info + = shambase::get_check_ref(comm_table.messages_recv[i].message_bytebuf_offset_recv); + sham::DeviceBuffer recov(test_elements[global_msg_id].size, dev_sched); + + size_t begin = off_info.data_offset; + size_t end = begin + test_elements[global_msg_id].size; + shambase::get_check_ref(recv_bufs.at(off_info.buf_id)).copy_range(begin, end, recov); + recv_messages.push_back(std::move(recov)); + } + + // validate + u32 recv_idx = 0; + for (size_t i = 0; i < test_elements.size(); i++) { + if (test_elements[i].receiver == shamcomm::world_rank()) { + REQUIRE_EQUAL(recv_messages[recv_idx].copy_to_stdvec(), all_bufs[i].copy_to_stdvec()); logger::info_ln( "sparse exchange test", "rank :", @@ -198,15 +336,13 @@ void test_sparse_exchange(std::vector test_elements, size_t max_all "->", test_elements[i].receiver, ") data :", - recov.copy_to_stdvec()); - - REQUIRE_EQUAL(recov.copy_to_stdvec(), ref_buf.copy_to_stdvec()); - - recv_msg_idx++; + recv_messages[recv_idx].copy_to_stdvec(), + "data ref :", + all_bufs[i].copy_to_stdvec(), + "valid :", + recv_messages[recv_idx].copy_to_stdvec() == all_bufs[i].copy_to_stdvec()); + recv_idx++; } - REQUIRE_EQUAL(comm_table.message_all[i].message_size, test_elements[i].size); - REQUIRE_EQUAL(comm_table.message_all[i].rank_sender, test_elements[i].sender); - REQUIRE_EQUAL(comm_table.message_all[i].rank_receiver, test_elements[i].receiver); } } From f88783d5ec1200ebfabbdf0fc65489c77338a75f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 21:28:32 +0100 Subject: [PATCH 26/54] fixed ? --- src/shamalgs/src/collective/distributedDataComm.cpp | 11 +++++++++-- src/shamalgs/src/collective/sparse_exchange.cpp | 12 ++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index 55d570df1c..f16110fc04 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -147,8 +147,15 @@ namespace shamalgs::collective { dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes); } - SHAM_ASSERT(comm_table2.send_total_sizes.size() == data_send.size()); - SHAM_ASSERT(comm_table2.send_total_sizes.size() == messages_send.size()); + if(comm_table2.send_total_sizes.size() != data_send.size()) { + throw make_except_with_loc(shambase::format( + "send total sizes size : {} != data send size : {}", comm_table2.send_total_sizes.size(), data_send.size())); + } + + if(comm_table2.send_total_sizes.size() != messages_send.size()) { + throw make_except_with_loc(shambase::format( + "send total sizes size : {} != messages send size : {}", comm_table2.send_total_sizes.size(), messages_send.size())); + } for (size_t i = 0; i < comm_table2.messages_send.size(); i++) { auto &msg_info = comm_table2.messages_send[i]; diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index be37674d97..98ec94c58a 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -124,8 +124,8 @@ namespace shamalgs::collective { // Compute offsets //////////////////////////////////////////////////////////// - std::vector send_buf_sizes{0}; - std::vector recv_buf_sizes{0}; + std::vector send_buf_sizes{}; + std::vector recv_buf_sizes{}; u32 send_idx = 0; u32 recv_idx = 0; @@ -151,6 +151,10 @@ namespace shamalgs::collective { max_alloc_size)); } + if(send_buf_sizes.size() == 0){ + send_buf_sizes.push_back(0); + } + if (tmp_send_offset + message_info.message_size > max_alloc_size) { send_buf_id++; tmp_send_offset = 0; @@ -174,6 +178,10 @@ namespace shamalgs::collective { message_info.message_size, max_alloc_size)); } + + if(recv_buf_sizes.size() == 0){ + recv_buf_sizes.push_back(0); + } if (tmp_recv_offset + message_info.message_size > max_alloc_size) { recv_buf_id++; From 9bd9181e0c1e1df2d830f9086c9ed5350c855c47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 21:30:51 +0100 Subject: [PATCH 27/54] fixed ? --- src/shamalgs/src/collective/distributedDataComm.cpp | 12 ++++++++---- src/shamalgs/src/collective/sparse_exchange.cpp | 6 +++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index f16110fc04..bcfb4c0d09 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -147,14 +147,18 @@ namespace shamalgs::collective { dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes); } - if(comm_table2.send_total_sizes.size() != data_send.size()) { + if (comm_table2.send_total_sizes.size() != data_send.size()) { throw make_except_with_loc(shambase::format( - "send total sizes size : {} != data send size : {}", comm_table2.send_total_sizes.size(), data_send.size())); + "send total sizes size : {} != data send size : {}", + comm_table2.send_total_sizes.size(), + data_send.size())); } - if(comm_table2.send_total_sizes.size() != messages_send.size()) { + if (comm_table2.send_total_sizes.size() != messages_send.size()) { throw make_except_with_loc(shambase::format( - "send total sizes size : {} != messages send size : {}", comm_table2.send_total_sizes.size(), messages_send.size())); + "send total sizes size : {} != messages send size : {}", + comm_table2.send_total_sizes.size(), + messages_send.size())); } for (size_t i = 0; i < comm_table2.messages_send.size(); i++) { diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index 98ec94c58a..1ac3c59fc7 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -151,7 +151,7 @@ namespace shamalgs::collective { max_alloc_size)); } - if(send_buf_sizes.size() == 0){ + if (send_buf_sizes.size() == 0) { send_buf_sizes.push_back(0); } @@ -178,8 +178,8 @@ namespace shamalgs::collective { message_info.message_size, max_alloc_size)); } - - if(recv_buf_sizes.size() == 0){ + + if (recv_buf_sizes.size() == 0) { recv_buf_sizes.push_back(0); } From 4b9fd8e99d1368ffeb0a2f3a99f2dffaa6933e62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 22:21:40 +0100 Subject: [PATCH 28/54] fixed ? --- .../src/collective/distributedDataComm.cpp | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index bcfb4c0d09..516c9327f7 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -147,18 +147,35 @@ namespace shamalgs::collective { dev_sched, comm_table2.send_total_sizes, comm_table2.recv_total_sizes); } - if (comm_table2.send_total_sizes.size() != data_send.size()) { + if (comm_table2.messages_send.size() != data_send.size()) { + std::vector tmp1 {}; + for(size_t i = 0; i < data_send.size(); i++) { + tmp1.push_back(comm_table2.messages_send[i].message_size); + } + + std::vector tmp2 {}; + for(size_t i = 0; i < data_send.size(); i++) { + tmp2.push_back(data_send[i]->get_size()); + } + throw make_except_with_loc(shambase::format( - "send total sizes size : {} != data send size : {}", - comm_table2.send_total_sizes.size(), - data_send.size())); + "message send mismatch : {} != {}", + tmp1, tmp2)); } - if (comm_table2.send_total_sizes.size() != messages_send.size()) { + if (comm_table2.messages_send.size() != messages_send.size()) { + std::vector tmp1 {}; + for(size_t i = 0; i < comm_table2.messages_send.size(); i++) { + tmp1.push_back(comm_table2.messages_send[i].message_size); + } + + std::vector tmp2 {}; + for(size_t i = 0; i < messages_send.size(); i++) { + tmp2.push_back(messages_send[i].message_size); + } throw make_except_with_loc(shambase::format( - "send total sizes size : {} != messages send size : {}", - comm_table2.send_total_sizes.size(), - messages_send.size())); + "message send mismatch : {} != {}", + tmp1, tmp2)); } for (size_t i = 0; i < comm_table2.messages_send.size(); i++) { From 9f938bad171073fafbb3ab866f313d7ae566d5b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 22:21:49 +0100 Subject: [PATCH 29/54] fixed ? --- .../src/collective/distributedDataComm.cpp | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index 516c9327f7..004f002c92 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -148,34 +148,32 @@ namespace shamalgs::collective { } if (comm_table2.messages_send.size() != data_send.size()) { - std::vector tmp1 {}; - for(size_t i = 0; i < data_send.size(); i++) { + std::vector tmp1{}; + for (size_t i = 0; i < data_send.size(); i++) { tmp1.push_back(comm_table2.messages_send[i].message_size); } - std::vector tmp2 {}; - for(size_t i = 0; i < data_send.size(); i++) { + std::vector tmp2{}; + for (size_t i = 0; i < data_send.size(); i++) { tmp2.push_back(data_send[i]->get_size()); } - throw make_except_with_loc(shambase::format( - "message send mismatch : {} != {}", - tmp1, tmp2)); + throw make_except_with_loc( + shambase::format("message send mismatch : {} != {}", tmp1, tmp2)); } if (comm_table2.messages_send.size() != messages_send.size()) { - std::vector tmp1 {}; - for(size_t i = 0; i < comm_table2.messages_send.size(); i++) { + std::vector tmp1{}; + for (size_t i = 0; i < comm_table2.messages_send.size(); i++) { tmp1.push_back(comm_table2.messages_send[i].message_size); } - std::vector tmp2 {}; - for(size_t i = 0; i < messages_send.size(); i++) { + std::vector tmp2{}; + for (size_t i = 0; i < messages_send.size(); i++) { tmp2.push_back(messages_send[i].message_size); } - throw make_except_with_loc(shambase::format( - "message send mismatch : {} != {}", - tmp1, tmp2)); + throw make_except_with_loc( + shambase::format("message send mismatch : {} != {}", tmp1, tmp2)); } for (size_t i = 0; i < comm_table2.messages_send.size(); i++) { From 7598143de15ee54da786b9e466e42134abde076b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Mon, 19 Jan 2026 23:41:58 +0100 Subject: [PATCH 30/54] better --- .../src/collective/distributedDataComm.cpp | 214 ++++++++++++------ 1 file changed, 142 insertions(+), 72 deletions(-) diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index 004f002c92..edb5ca31b0 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -22,9 +22,34 @@ #include "shamalgs/serialize.hpp" #include "shambackends/DeviceBuffer.hpp" #include "shambackends/DeviceScheduler.hpp" +#include "shamcmdopt/env.hpp" #include #include +auto SPARSE_COMM_MODE = shamcmdopt::getenv_str_default_register( + "SPARSE_COMM_MODE", "new", "Sparse communication mode (new=with cache, old=without cache)"); + +namespace { + struct SparseCommMode { + enum Mode { NEW, OLD }; + }; + + constexpr auto parse_sparse_comm_mode = []() { + if (SPARSE_COMM_MODE == "new") { + return SparseCommMode::NEW; + } else if (SPARSE_COMM_MODE == "old") { + return SparseCommMode::OLD; + } else { + throw std::invalid_argument( + "Invalid sparse communication mode, valid modes are: new, old"); + } + }; + + bool use_old_sparse_comm_mode = parse_sparse_comm_mode() == SparseCommMode::OLD; + + bool warning_printed = false; +} // namespace + namespace shamalgs::collective { namespace details { @@ -74,6 +99,114 @@ namespace shamalgs::collective { } // namespace details + void distributed_data_sparse_comm_old( + sham::DeviceScheduler_ptr dev_sched, + SerializedDDataComm &send_distrib_data, + SerializedDDataComm &recv_distrib_data, + std::function rank_getter, + std::optional comm_table) { + + StackEntry stack_loc{}; + + using namespace shambase; + using DataTmp = details::DataTmp; + + // prepare map + std::map, std::vector> send_data; + send_distrib_data.for_each([&](u64 sender, u64 receiver, sham::DeviceBuffer &buf) { + std::pair key = {rank_getter(sender), rank_getter(receiver)}; + + send_data[key].push_back(DataTmp{sender, receiver, buf.get_size(), buf}); + }); + + // serialize together similar communications + std::map, SerializeHelper> serializers + = details::serialize_group_data(dev_sched, send_data); + + // recover bufs from serializers + std::map, std::unique_ptr>> send_bufs; + { + NamedStackEntry stack_loc2{"recover bufs"}; + for (auto &[key, ser] : serializers) { + send_bufs[key] = std::make_unique>(ser.finalize()); + } + } + + // prepare payload + std::vector send_payoad; + { + NamedStackEntry stack_loc2{"prepare payload"}; + for (auto &[key, buf] : send_bufs) { + send_payoad.push_back( + {key.second, + std::make_unique( + shambase::extract_pointer(buf), dev_sched)}); + } + } + + // sparse comm + std::vector recv_payload; + + if (comm_table) { + sparse_comm_c(dev_sched, send_payoad, recv_payload, *comm_table); + } else { + base_sparse_comm(dev_sched, send_payoad, recv_payload); + } + + // make serializers from recv buffs + struct RecvPayloadSer { + i32 sender_ranks; + SerializeHelper ser; + }; + + std::vector recv_payload_bufs; + + { + NamedStackEntry stack_loc2{"move payloads"}; + for (RecvPayload &payload : recv_payload) { + + shamcomm::CommunicationBuffer comm_buf = extract_pointer(payload.payload); + + sham::DeviceBuffer buf + = shamcomm::CommunicationBuffer::convert_usm(std::move(comm_buf)); + + recv_payload_bufs.push_back( + RecvPayloadSer{ + payload.sender_ranks, SerializeHelper(dev_sched, std::move(buf))}); + } + } + + { + NamedStackEntry stack_loc2{"split recv comms"}; + // deserialize into the shared distributed data + for (RecvPayloadSer &recv : recv_payload_bufs) { + u64 cnt_obj; + recv.ser.load(cnt_obj); + for (u32 i = 0; i < cnt_obj; i++) { + u64 sender, receiver, length; + + recv.ser.load(sender); + recv.ser.load(receiver); + recv.ser.load(length); + + { // check correctness ranks + i32 supposed_sender_rank = rank_getter(sender); + i32 real_sender_rank = recv.sender_ranks; + if (supposed_sender_rank != real_sender_rank) { + throw make_except_with_loc( + "the rank do not matches"); + } + } + + auto it = recv_distrib_data.add_obj( + sender, receiver, sham::DeviceBuffer(length, dev_sched)); + + recv.ser.load_buf(it->second, length); + } + } + } + } + void distributed_data_sparse_comm( sham::DeviceScheduler_ptr dev_sched, SerializedDDataComm &send_distrib_data, @@ -82,6 +215,15 @@ namespace shamalgs::collective { DDSCommCache &cache, std::optional comm_table) { + if (use_old_sparse_comm_mode) { + if (shamcomm::world_rank() == 0 && !warning_printed) { + logger::warn_ln("SparseComm", "using old sparse communication mode"); + warning_printed = true; + } + return distributed_data_sparse_comm_old( + dev_sched, send_distrib_data, recv_distrib_data, rank_getter, comm_table); + } + StackEntry stack_loc{}; using namespace shambase; @@ -184,17 +326,6 @@ namespace shamalgs::collective { SHAM_ASSERT(buf_src.get_size() == msg_info.message_size); cache.send_cache_write_buf_at(offset_info.buf_id, offset_info.data_offset, buf_src); - // logger::info_ln( - // "distributed data sparse comm", - // "rank :", - // shamcomm::world_rank(), - // "sender :", msg_info.rank_sender, - // "receiver :", msg_info.rank_receiver, - // "write buf at :", - // offset_info.buf_id, - // "offset :", - // offset_info.data_offset, - // "size :", buf_src.get_size()); } if (dev_sched->ctx->device->mpi_prop.is_mpi_direct_capable) { @@ -211,30 +342,6 @@ namespace shamalgs::collective { comm_table2); } -#ifdef false - // prepare payload - std::vector send_payoad; - { - NamedStackEntry stack_loc2{"prepare payload"}; - for (auto &[key, buf] : send_bufs) { - send_payoad.push_back( - {key.second, - std::make_unique( - shambase::extract_pointer(buf), dev_sched)}); - } - } - - // sparse comm - std::vector recv_payload; - - if (comm_table) { - sparse_comm_c(dev_sched, send_payoad, recv_payload, *comm_table); - } else { - base_sparse_comm(dev_sched, send_payoad, recv_payload); - } - -#endif - // make serializers from recv buffs struct RecvPayloadSer { i32 sender_ranks; @@ -243,23 +350,6 @@ namespace shamalgs::collective { std::vector recv_payload_bufs; -#ifdef false - { - NamedStackEntry stack_loc2{"move payloads"}; - for (RecvPayload &payload : recv_payload) { - - shamcomm::CommunicationBuffer comm_buf = extract_pointer(payload.payload); - - sham::DeviceBuffer buf - = shamcomm::CommunicationBuffer::convert_usm(std::move(comm_buf)); - - recv_payload_bufs.push_back( - RecvPayloadSer{ - payload.sender_ranks, SerializeHelper(dev_sched, std::move(buf))}); - } - } -#endif - for (auto &msg : comm_table2.messages_recv) { u64 size = msg.message_size; @@ -271,16 +361,6 @@ namespace shamalgs::collective { sham::DeviceBuffer recov(size, dev_sched); cache.recv_cache_read_buf_at(offset_info.buf_id, offset_info.data_offset, size, recov); - // logger::info_ln( - // "distributed data sparse comm", - // "rank :", - // shamcomm::world_rank(), - // "sender :", sender, - // "receiver :", receiver, - // "read buf at :", - // offset_info.buf_id, - // "offset :", offset_info.data_offset, "size :", size); - recv_payload_bufs.push_back( RecvPayloadSer{sender, SerializeHelper(dev_sched, std::move(recov))}); } @@ -298,16 +378,6 @@ namespace shamalgs::collective { recv.ser.load(receiver); recv.ser.load(length); - // logger::info_ln( - // "distributed data sparse comm", - // "rank :", - // shamcomm::world_rank(), - // "load obj :", - // sender, - // "->", - // receiver, - // "size :", length); - { // check correctness ranks i32 supposed_sender_rank = rank_getter(sender); i32 real_sender_rank = recv.sender_ranks; From 6d7f1b24cad4048962d69a37bbf85077a229de34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Tue, 20 Jan 2026 14:37:54 +0100 Subject: [PATCH 31/54] try a fix --- .../collective/distributedDataComm.hpp | 6 +++ .../src/collective/distributedDataComm.cpp | 3 +- .../src/collective/sparse_exchange.cpp | 10 ++++- .../include/shambackends/DeviceBuffer.hpp | 40 ++++++++++++++++++- .../include/shambackends/DeviceQueue.hpp | 2 +- 5 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp index 5efd93a5bb..8f77d7aab5 100644 --- a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp +++ b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp @@ -43,6 +43,9 @@ namespace shamalgs::collective { sham::DeviceScheduler_ptr dev_sched, std::vector sizes_cache1, std::vector sizes_cache2) { + + __shamrock_stack_entry(); + // ensure correct length cache1.resize(sizes_cache1.size()); cache2.resize(sizes_cache2.size()); @@ -109,6 +112,9 @@ namespace shamalgs::collective { sham::DeviceScheduler_ptr dev_sched, std::vector sizes_cache1, std::vector sizes_cache2) { + + __shamrock_stack_entry(); + // init if not there if (std::get_if>(&cache) == nullptr) { cache = DDSCommCacheTarget{}; diff --git a/src/shamalgs/src/collective/distributedDataComm.cpp b/src/shamalgs/src/collective/distributedDataComm.cpp index edb5ca31b0..f31044817b 100644 --- a/src/shamalgs/src/collective/distributedDataComm.cpp +++ b/src/shamalgs/src/collective/distributedDataComm.cpp @@ -224,7 +224,7 @@ namespace shamalgs::collective { dev_sched, send_distrib_data, recv_distrib_data, rank_getter, comm_table); } - StackEntry stack_loc{}; + __shamrock_stack_entry(); using namespace shambase; using DataTmp = details::DataTmp; @@ -277,6 +277,7 @@ namespace shamalgs::collective { } else { max_alloc_size = dev_sched->ctx->device->prop.max_mem_alloc_size_host; } + max_alloc_size *= 0.95; // keep 5% of the max alloc size for safety shamalgs::collective::CommTable comm_table2 = shamalgs::collective::build_sparse_exchange_table(messages_send, max_alloc_size); diff --git a/src/shamalgs/src/collective/sparse_exchange.cpp b/src/shamalgs/src/collective/sparse_exchange.cpp index 1ac3c59fc7..291187b952 100644 --- a/src/shamalgs/src/collective/sparse_exchange.cpp +++ b/src/shamalgs/src/collective/sparse_exchange.cpp @@ -155,10 +155,11 @@ namespace shamalgs::collective { send_buf_sizes.push_back(0); } - if (tmp_send_offset + message_info.message_size > max_alloc_size) { + if (tmp_send_offset + message_info.message_size >= max_alloc_size) { send_buf_id++; tmp_send_offset = 0; send_buf_sizes.push_back(0); + logger::info_ln("sparse comm", "is using multiple buffers (send) !"); } message_info.message_bytebuf_offset_send = {send_buf_id, tmp_send_offset}; @@ -183,10 +184,11 @@ namespace shamalgs::collective { recv_buf_sizes.push_back(0); } - if (tmp_recv_offset + message_info.message_size > max_alloc_size) { + if (tmp_recv_offset + message_info.message_size >= max_alloc_size) { recv_buf_id++; tmp_recv_offset = 0; recv_buf_sizes.push_back(0); + logger::info_ln("sparse comm", "is using multiple buffers (recv) !"); } message_info.message_bytebuf_offset_recv = {recv_buf_id, tmp_recv_offset}; @@ -200,6 +202,10 @@ namespace shamalgs::collective { } } + { + logger::info_ln("sparse comm", "send_buf_sizes :", send_buf_sizes); + logger::info_ln("sparse comm", "recv_buf_sizes :", recv_buf_sizes); + } //////////////////////////////////////////////////////////// // now that all comm were computed we can build the send and recv message lists //////////////////////////////////////////////////////////// diff --git a/src/shambackends/include/shambackends/DeviceBuffer.hpp b/src/shambackends/include/shambackends/DeviceBuffer.hpp index 36acc9c394..8fcd304017 100644 --- a/src/shambackends/include/shambackends/DeviceBuffer.hpp +++ b/src/shambackends/include/shambackends/DeviceBuffer.hpp @@ -18,6 +18,7 @@ #include "shambase/assert.hpp" #include "shambase/memory.hpp" +#include "shambase/type_traits.hpp" #include "shambackends/DeviceScheduler.hpp" #include "shambackends/USMPtrHolder.hpp" #include "shambackends/details/BufferEventHandler.hpp" @@ -1014,6 +1015,27 @@ namespace sham { // Size manipulation /////////////////////////////////////////////////////////////////////// + inline size_t get_max_alloc_size() const { + size_t max_alloc_size; + + auto &dev_prop = hold.get_dev_scheduler().get_queue().get_device_prop(); + + if constexpr (target == device) { + max_alloc_size = dev_prop.max_mem_alloc_size_dev; + } else if constexpr (target == host) { + max_alloc_size = dev_prop.max_mem_alloc_size_host; + } else if constexpr (target == shared) { + max_alloc_size + = sycl::min(dev_prop.max_mem_alloc_size_dev, dev_prop.max_mem_alloc_size_host); + } else { + static_assert( + shambase::always_false_v, + "get_max_alloc_size: invalid target"); + } + + return max_alloc_size; + } + /** * @brief Resizes the buffer to a given size. * @@ -1030,7 +1052,23 @@ namespace sham { if (alloc_request_size_fct(new_size, dev_sched) > hold.get_bytesize()) { // expand storage - size_t new_storage_size = alloc_request_size_fct(new_size * 1.5, dev_sched); + size_t min_size_new_alloc = alloc_request_size_fct(new_size, dev_sched); + size_t wanted_size_new_alloc = alloc_request_size_fct(new_size * 1.5, dev_sched); + + size_t new_storage_size = sycl::min(min_size_new_alloc, wanted_size_new_alloc); + + if (new_storage_size > get_max_alloc_size()) { + shambase::throw_with_loc(shambase::format( + "new_storage_size > get_max_alloc_size()\n" + " new_storage_size = {}\n" + " get_max_alloc_size() = {}\n" + " min_size_new_alloc = {}\n" + " wanted_size_new_alloc = {}", + new_storage_size, + get_max_alloc_size(), + min_size_new_alloc, + wanted_size_new_alloc)); + } DeviceBuffer new_buf( new_size, diff --git a/src/shambackends/include/shambackends/DeviceQueue.hpp b/src/shambackends/include/shambackends/DeviceQueue.hpp index 51bd82fbbe..1fc01c3c63 100644 --- a/src/shambackends/include/shambackends/DeviceQueue.hpp +++ b/src/shambackends/include/shambackends/DeviceQueue.hpp @@ -162,7 +162,7 @@ namespace sham { * * @return DeviceProperties The properties of the associated device */ - inline DeviceProperties get_device_prop() { + inline DeviceProperties &get_device_prop() { return shambase::get_check_ref(ctx).device->prop; } }; From 1d370ad5c9ac63c7587d831fccdab3fd0a1e2a5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 09:37:09 +0100 Subject: [PATCH 32/54] fixed ? --- src/shambackends/include/shambackends/DeviceBuffer.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/shambackends/include/shambackends/DeviceBuffer.hpp b/src/shambackends/include/shambackends/DeviceBuffer.hpp index 0323ed97f7..46c00b05ca 100644 --- a/src/shambackends/include/shambackends/DeviceBuffer.hpp +++ b/src/shambackends/include/shambackends/DeviceBuffer.hpp @@ -1040,7 +1040,10 @@ namespace sham { size_t min_size_new_alloc = alloc_request_size_fct(new_size, dev_sched); size_t wanted_size_new_alloc = alloc_request_size_fct(new_size * 1.5, dev_sched); - size_t new_storage_size = sycl::min(min_size_new_alloc, wanted_size_new_alloc); + size_t new_storage_size = wanted_size_new_alloc; + if (new_storage_size > get_max_alloc_size()) { + new_storage_size = sycl::max(get_max_alloc_size(), min_size_new_alloc); + } if (new_storage_size > get_max_alloc_size()) { shambase::throw_with_loc(shambase::format( From 55a875d6c6407d8b23d82fded92a35aef7cb6ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 10:04:26 +0100 Subject: [PATCH 33/54] fixed ? --- src/shambackends/include/shambackends/DeviceBuffer.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/shambackends/include/shambackends/DeviceBuffer.hpp b/src/shambackends/include/shambackends/DeviceBuffer.hpp index 46c00b05ca..8f805fda8e 100644 --- a/src/shambackends/include/shambackends/DeviceBuffer.hpp +++ b/src/shambackends/include/shambackends/DeviceBuffer.hpp @@ -1037,15 +1037,18 @@ namespace sham { if (alloc_request_size_fct(new_size, dev_sched) > hold.get_bytesize()) { // expand storage + size_t max_alloc_size = get_max_alloc_size(); + size_t alignment = get_alignment(dev_sched); size_t min_size_new_alloc = alloc_request_size_fct(new_size, dev_sched); size_t wanted_size_new_alloc = alloc_request_size_fct(new_size * 1.5, dev_sched); + size_t max_possible_alloc = max_alloc_size - (max_alloc_size % alignment); size_t new_storage_size = wanted_size_new_alloc; - if (new_storage_size > get_max_alloc_size()) { - new_storage_size = sycl::max(get_max_alloc_size(), min_size_new_alloc); + if (new_storage_size > max_alloc_size) { + new_storage_size = sycl::max(max_possible_alloc, min_size_new_alloc); } - if (new_storage_size > get_max_alloc_size()) { + if (new_storage_size > max_alloc_size) { shambase::throw_with_loc(shambase::format( "new_storage_size > get_max_alloc_size()\n" " new_storage_size = {}\n" From 24b3e470fc1e1ee3de02be87933c075967ed850d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 10:05:48 +0100 Subject: [PATCH 34/54] fixed ? --- .../include/shambackends/DeviceBuffer.hpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/shambackends/include/shambackends/DeviceBuffer.hpp b/src/shambackends/include/shambackends/DeviceBuffer.hpp index 8f805fda8e..fea4db40d3 100644 --- a/src/shambackends/include/shambackends/DeviceBuffer.hpp +++ b/src/shambackends/include/shambackends/DeviceBuffer.hpp @@ -1037,11 +1037,15 @@ namespace sham { if (alloc_request_size_fct(new_size, dev_sched) > hold.get_bytesize()) { // expand storage - size_t max_alloc_size = get_max_alloc_size(); - size_t alignment = get_alignment(dev_sched); - size_t min_size_new_alloc = alloc_request_size_fct(new_size, dev_sched); - size_t wanted_size_new_alloc = alloc_request_size_fct(new_size * 1.5, dev_sched); - size_t max_possible_alloc = max_alloc_size - (max_alloc_size % alignment); + size_t max_alloc_size = get_max_alloc_size(); + std::optional alignment = get_alignment(dev_sched); + size_t min_size_new_alloc = alloc_request_size_fct(new_size, dev_sched); + size_t wanted_size_new_alloc = alloc_request_size_fct(new_size * 1.5, dev_sched); + size_t max_possible_alloc = max_alloc_size; + + if (alignment) { + max_possible_alloc = max_possible_alloc - (max_possible_alloc % *alignment); + } size_t new_storage_size = wanted_size_new_alloc; if (new_storage_size > max_alloc_size) { From 634900aa11cd4817889fe98847ca1363691e9e3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 10:23:02 +0100 Subject: [PATCH 35/54] avoid copy in cache resize --- .../include/shamalgs/collective/distributedDataComm.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp index 8f77d7aab5..b00f1fb76d 100644 --- a/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp +++ b/src/shamalgs/include/shamalgs/collective/distributedDataComm.hpp @@ -53,7 +53,7 @@ namespace shamalgs::collective { // if size is different, resize for (size_t i = 0; i < sizes_cache1.size(); i++) { if (cache1[i]) { - cache1[i]->resize(sizes_cache1[i]); + cache1[i]->resize(sizes_cache1[i], false); } else { cache1[i] = std::make_unique>( sizes_cache1[i], dev_sched); @@ -61,7 +61,7 @@ namespace shamalgs::collective { } for (size_t i = 0; i < sizes_cache2.size(); i++) { if (cache2[i]) { - cache2[i]->resize(sizes_cache2[i]); + cache2[i]->resize(sizes_cache2[i], false); } else { cache2[i] = std::make_unique>( sizes_cache2[i], dev_sched); From 245b0b9238b5de7d2bf9abeae5b4d87581a67f6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 12:44:32 +0100 Subject: [PATCH 36/54] better --- src/shambackends/src/Device.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/shambackends/src/Device.cpp b/src/shambackends/src/Device.cpp index 6de70fbbdd..599af9fc54 100644 --- a/src/shambackends/src/Device.cpp +++ b/src/shambackends/src/Device.cpp @@ -273,6 +273,13 @@ namespace sham { } } + {// PCI id infos + #if defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 5 + FETCH_PROP(pci_address, sycl::ext::intel::info::device::pci_address) + logger::raw_ln("pci address :", pci_address.value()); + #endif + } + return DeviceProperties{ Vendor::UNKNOWN, // We cannot determine the vendor get_device_backend(dev), // Query the backend based on the platform name From a4d18ea1429dedc063efb8fb4916059aafaff968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 12:45:22 +0100 Subject: [PATCH 37/54] better --- src/shambackends/src/Device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shambackends/src/Device.cpp b/src/shambackends/src/Device.cpp index 599af9fc54..a735fade66 100644 --- a/src/shambackends/src/Device.cpp +++ b/src/shambackends/src/Device.cpp @@ -275,7 +275,7 @@ namespace sham { {// PCI id infos #if defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 5 - FETCH_PROP(pci_address, sycl::ext::intel::info::device::pci_address) + FETCH_PROP(pci_address, sycl::ext::intel::info::pci_address) logger::raw_ln("pci address :", pci_address.value()); #endif } From 981eb4d3b3e0bf4e554cdf6f79ad02050799a8f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 12:48:20 +0100 Subject: [PATCH 38/54] better --- src/shambackends/src/Device.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/shambackends/src/Device.cpp b/src/shambackends/src/Device.cpp index a735fade66..a9681b4f28 100644 --- a/src/shambackends/src/Device.cpp +++ b/src/shambackends/src/Device.cpp @@ -275,11 +275,12 @@ namespace sham { {// PCI id infos #if defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 5 - FETCH_PROP(pci_address, sycl::ext::intel::info::pci_address) + FETCH_PROPN(sycl::ext::intel::info::device::pci_address, u32, pci_address) logger::raw_ln("pci address :", pci_address.value()); #endif } + return DeviceProperties{ Vendor::UNKNOWN, // We cannot determine the vendor get_device_backend(dev), // Query the backend based on the platform name From cb101fc43866b96ec1bf4919adace7f6d938569f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 12:49:13 +0100 Subject: [PATCH 39/54] better --- src/shambackends/src/Device.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/shambackends/src/Device.cpp b/src/shambackends/src/Device.cpp index a9681b4f28..d8eb495cc2 100644 --- a/src/shambackends/src/Device.cpp +++ b/src/shambackends/src/Device.cpp @@ -275,7 +275,7 @@ namespace sham { {// PCI id infos #if defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 5 - FETCH_PROPN(sycl::ext::intel::info::device::pci_address, u32, pci_address) + FETCH_PROPN(sycl::ext::intel::info::device::pci_address, std::string, pci_address) logger::raw_ln("pci address :", pci_address.value()); #endif } From 07d73af2ba6ae155ab4f70dbe440142cb2b9c0fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 12:50:17 +0100 Subject: [PATCH 40/54] better --- src/shambackends/src/Device.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/shambackends/src/Device.cpp b/src/shambackends/src/Device.cpp index d8eb495cc2..91f0a6bb8d 100644 --- a/src/shambackends/src/Device.cpp +++ b/src/shambackends/src/Device.cpp @@ -110,6 +110,20 @@ namespace sham { } \ }(); + /// Fetches a property of a SYCL device (for cases where multiple prop would have the same name) +#define FETCH_PROPN_FULL(info_, info_type, n) \ +std::optional n = [&]() -> std::optional { \ + try { \ + return {dev.get_info()}; \ + } catch (...) { \ + logger::warn_ln( \ + "Device", \ + "dev.get_info<" #info_ ">() raised an exception for device", \ + name); \ + return {}; \ + } \ +}(); + /** * @brief Fetches the properties of a SYCL device. * @@ -275,7 +289,7 @@ namespace sham { {// PCI id infos #if defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 5 - FETCH_PROPN(sycl::ext::intel::info::device::pci_address, std::string, pci_address) + FETCH_PROPN_FULL(sycl::ext::intel::info::device::pci_address, std::string, pci_address) logger::raw_ln("pci address :", pci_address.value()); #endif } From 12f7449ac0ee00cb2ee5100540b7b4587d99e17b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Wed, 21 Jan 2026 13:03:44 +0100 Subject: [PATCH 41/54] better --- .../include/shambackends/Device.hpp | 3 + src/shambackends/src/Device.cpp | 65 ++++++++++--------- src/shamsys/src/shamrock_smi.cpp | 6 +- 3 files changed, 40 insertions(+), 34 deletions(-) diff --git a/src/shambackends/include/shambackends/Device.hpp b/src/shambackends/include/shambackends/Device.hpp index 2169290a39..8a9d89f5aa 100644 --- a/src/shambackends/include/shambackends/Device.hpp +++ b/src/shambackends/include/shambackends/Device.hpp @@ -120,6 +120,9 @@ namespace sham { /// Default work group size uint32_t default_work_group_size; + + /// PCI address of the device + std::optional pci_address; }; struct DeviceMPIProperties { diff --git a/src/shambackends/src/Device.cpp b/src/shambackends/src/Device.cpp index 91f0a6bb8d..31cd6f452a 100644 --- a/src/shambackends/src/Device.cpp +++ b/src/shambackends/src/Device.cpp @@ -110,19 +110,17 @@ namespace sham { } \ }(); - /// Fetches a property of a SYCL device (for cases where multiple prop would have the same name) -#define FETCH_PROPN_FULL(info_, info_type, n) \ -std::optional n = [&]() -> std::optional { \ - try { \ - return {dev.get_info()}; \ - } catch (...) { \ - logger::warn_ln( \ - "Device", \ - "dev.get_info<" #info_ ">() raised an exception for device", \ - name); \ - return {}; \ - } \ -}(); + /// Fetches a property of a SYCL device (for cases where multiple prop would have the same name) +#define FETCH_PROPN_FULL(info_, info_type, n) \ + std::optional n = [&]() -> std::optional { \ + try { \ + return {dev.get_info()}; \ + } catch (...) { \ + logger::warn_ln( \ + "Device", "dev.get_info<" #info_ ">() raised an exception for device", name); \ + return {}; \ + } \ + }(); /** * @brief Fetches the properties of a SYCL device. @@ -287,28 +285,31 @@ std::optional n = [&]() -> std::optional { } } - {// PCI id infos - #if defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 5 + DeviceProperties ret + = {Vendor::UNKNOWN, // We cannot determine the vendor + get_device_backend(dev), // Query the backend based on the platform name + get_device_type(dev), + shambase::get_check_ref(global_mem_size), + shambase::get_check_ref(global_mem_cache_line_size), + shambase::get_check_ref(global_mem_cache_size), + shambase::get_check_ref(local_mem_size), + shambase::get_check_ref(max_compute_units), + max_alloc_dev, + max_alloc_host, + shambase::get_check_ref(mem_base_addr_align), + shambase::get_check_ref(sub_group_sizes), + default_work_group_size}; + + { // PCI id infos +#if defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 5 FETCH_PROPN_FULL(sycl::ext::intel::info::device::pci_address, std::string, pci_address) - logger::raw_ln("pci address :", pci_address.value()); - #endif + if (pci_address) { + ret.pci_address = *pci_address; + } +#endif } - - return DeviceProperties{ - Vendor::UNKNOWN, // We cannot determine the vendor - get_device_backend(dev), // Query the backend based on the platform name - get_device_type(dev), - shambase::get_check_ref(global_mem_size), - shambase::get_check_ref(global_mem_cache_line_size), - shambase::get_check_ref(global_mem_cache_size), - shambase::get_check_ref(local_mem_size), - shambase::get_check_ref(max_compute_units), - max_alloc_dev, - max_alloc_host, - shambase::get_check_ref(mem_base_addr_align), - shambase::get_check_ref(sub_group_sizes), - default_work_group_size}; + return ret; } /** diff --git a/src/shamsys/src/shamrock_smi.cpp b/src/shamsys/src/shamrock_smi.cpp index a2ec787db8..00085c80d3 100644 --- a/src/shamsys/src/shamrock_smi.cpp +++ b/src/shamsys/src/shamrock_smi.cpp @@ -180,7 +180,8 @@ namespace shamsys { - local_mem_size = {} - mem_base_addr_align = {}, - max_mem_alloc_size_dev = {}, - - max_mem_alloc_size_host = {})", + - max_mem_alloc_size_host = {}, + - pci_address = {})", DeviceName, dev.device_id, dev.prop.default_work_group_size, @@ -188,7 +189,8 @@ namespace shamsys { nolimit_if_too_large(dev.prop.local_mem_size), dev.prop.mem_base_addr_align, shambase::readable_sizeof(dev.prop.max_mem_alloc_size_dev), - shambase::readable_sizeof(dev.prop.max_mem_alloc_size_host)); + shambase::readable_sizeof(dev.prop.max_mem_alloc_size_host), + dev.prop.pci_address ? *dev.prop.pci_address : "N/A"); std::unordered_map devicename_histogram = shamcomm::string_histogram({dev_with_id}, "xxx\nxxx"); From 46a5b5c4b2374af954116bcdc9ca652e40c335f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 16:52:11 +0100 Subject: [PATCH 42/54] [Pylib] provide a phantom simulation load utility --- src/pylib/shamrock/utils/__init__.py | 2 +- src/pylib/shamrock/utils/phantom/__init__.py | 133 ++++++++++++++++++ .../sph/include/shammodels/sph/Model.hpp | 2 +- src/shammodels/sph/src/Model.cpp | 5 +- src/shammodels/sph/src/pySPHModel.cpp | 8 +- 5 files changed, 143 insertions(+), 7 deletions(-) create mode 100644 src/pylib/shamrock/utils/phantom/__init__.py diff --git a/src/pylib/shamrock/utils/__init__.py b/src/pylib/shamrock/utils/__init__.py index 1ae54e2b49..3cb4eef9fe 100644 --- a/src/pylib/shamrock/utils/__init__.py +++ b/src/pylib/shamrock/utils/__init__.py @@ -2,4 +2,4 @@ Shamrock utility library. """ -from . import plot +from . import phantom, plot diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py new file mode 100644 index 0000000000..b1ea2acad7 --- /dev/null +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -0,0 +1,133 @@ +""" +Phantom related utilities. +""" + +import os + +import shamrock.sys + + +def parse_in_file(in_file): + """ + Parse a Phantom .in file and return a dictionary of the parameters. + """ + with open(in_file, "r") as f: + lines = f.readlines() + + params = {} + + for line in lines: + # Skip empty lines and comment lines + stripped_line = line.strip() + if not stripped_line or stripped_line.startswith("#"): + continue + + # Check if line contains an equals sign + if "=" in line: + # Split by '=' to get variable name and value part + parts = line.split("=", 1) + var_name = parts[0].strip() + + # Get value part (everything after =) + value_part = parts[1] + + # Remove comment if present (text after !) + if "!" in value_part: + value_part = value_part.split("!")[0] + + # Strip whitespace from value + value = value_part.strip() + + # Try to convert to appropriate type + # Check for boolean + if value == "T": + value = True + elif value == "F": + value = False + else: + # Try to convert to number + try: + # Try integer first + if "." not in value and "E" not in value and "e" not in value: + value = int(value) + else: + # Try float + value = float(value) + except ValueError: + # Keep as string if conversion fails + pass + + params[var_name] = value + + return params + + +def in_dump_to_config(model, simulation_path, in_file_name, dump_file_name, do_print=True): + """ + Convert a Phantom .in file and a Phantom dump file to a Shamrock config. + """ + in_params = parse_in_file(os.path.join(simulation_path, in_file_name)) + + dump = shamrock.load_phantom_dump(os.path.join(simulation_path, dump_file_name)) + + cfg = model.gen_config_from_phantom_dump(dump) + # Set the solver config to be the one stored in cfg + model.set_solver_config(cfg) + + # Print the solver config + if do_print and shamrock.sys.world_rank() == 0: + print("Solver config:") + model.get_current_config().print_status() + print("In file parameters:") + for key, value in in_params.items(): + print(f"{key}: {value}") + print("Dump state:") + dump.print_state() + + return cfg, dump + + +def load_simulation(simulation_path, in_file_name, dump_file_name, do_print=True): + """ + Load a Phantom simulation into a Shamrock model. + """ + + dump_path = os.path.join(simulation_path, dump_file_name) + in_file_path = os.path.join(simulation_path, in_file_name) + + in_params = parse_in_file(os.path.join(simulation_path, in_file_name)) + + # setup = dump finish with .tmp + is_setup_file = dump_file_name.endswith(".tmp") + + # Open the phantom dump + dump = shamrock.load_phantom_dump(dump_path) + + # Start a SPH simulation from the phantom dump + ctx = shamrock.Context() + ctx.pdata_layout_new() + model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4") + + cfg = model.gen_config_from_phantom_dump(dump) + # Set the solver config to be the one stored in cfg + model.set_solver_config(cfg) + + # Print infos + if do_print and shamrock.sys.world_rank() == 0: + print(f"Is setup file: {is_setup_file}") + print("Solver config:") + model.get_current_config().print_status() + print("In file parameters:") + for key, value in in_params.items(): + print(f"{key}: {value}") + print("Dump state:") + dump.print_state() + + model.init_scheduler(int(1e8), 1) + + if is_setup_file: + model.init_from_phantom_dump(dump, 0.5) + else: + model.init_from_phantom_dump(dump, 1.0) + + return ctx, model diff --git a/src/shammodels/sph/include/shammodels/sph/Model.hpp b/src/shammodels/sph/include/shammodels/sph/Model.hpp index e8f70d625b..bbc610db33 100644 --- a/src/shammodels/sph/include/shammodels/sph/Model.hpp +++ b/src/shammodels/sph/include/shammodels/sph/Model.hpp @@ -100,7 +100,7 @@ namespace shammodels::sph { } SolverConfig gen_config_from_phantom_dump(PhantomDump &phdump, bool bypass_error); - void init_from_phantom_dump(PhantomDump &phdump); + void init_from_phantom_dump(PhantomDump &phdump, Tscal hpart_fact_load = 1.0); PhantomDump make_phantom_dump(); void do_vtk_dump(std::string filename, bool add_patch_world_id) { diff --git a/src/shammodels/sph/src/Model.cpp b/src/shammodels/sph/src/Model.cpp index b93ea032fd..bc59cb6fa8 100644 --- a/src/shammodels/sph/src/Model.cpp +++ b/src/shammodels/sph/src/Model.cpp @@ -1224,7 +1224,8 @@ auto shammodels::sph::Model::gen_config_from_phantom_dump( } template class SPHKernel> -void shammodels::sph::Model::init_from_phantom_dump(PhantomDump &phdump) { +void shammodels::sph::Model::init_from_phantom_dump( + PhantomDump &phdump, Tscal hpart_fact_load) { StackEntry stack_loc{}; bool has_coord_in_header = true; @@ -1353,7 +1354,7 @@ void shammodels::sph::Model::init_from_phantom_dump(PhantomDump ins_vxyz.push_back(vxyz[i]); } for (u64 i : sel_index) { - ins_h.push_back(h[i]); + ins_h.push_back(h[i] * hpart_fact_load); } if (u.size() > 0) { for (u64 i : sel_index) { diff --git a/src/shammodels/sph/src/pySPHModel.cpp b/src/shammodels/sph/src/pySPHModel.cpp index 7c30926f90..d859b57ea2 100644 --- a/src/shammodels/sph/src/pySPHModel.cpp +++ b/src/shammodels/sph/src/pySPHModel.cpp @@ -951,9 +951,11 @@ void add_instance(py::module &m, std::string name_config, std::string name_model )==") .def( "init_from_phantom_dump", - [](T &self, PhantomDump &dump) { - self.init_from_phantom_dump(dump); - }) + [](T &self, PhantomDump &dump, Tscal hpart_fact_load) { + self.init_from_phantom_dump(dump, hpart_fact_load); + }, + py::arg("dump"), + py::arg("hpart_fact_load") = 1.0) .def( "make_phantom_dump", [](T &self) { From f1ce2ed466b0d1656cc430cc74a217d24a148d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 17:11:04 +0100 Subject: [PATCH 43/54] use it in tests --- .../sph/run_start_sph_from_phantom_dump.py | 23 +++---------------- src/pylib/shamrock/utils/phantom/__init__.py | 18 ++++++++++----- 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py b/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py index f7eb2582d2..0a5dcfd0e8 100644 --- a/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py +++ b/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py @@ -30,26 +30,9 @@ shamrock.change_loglevel(1) shamrock.sys.init("0:0") -# %% -# Open the phantom dump -dump = shamrock.load_phantom_dump(filename) -dump.print_state() - -# %% -# Start a SPH simulation from the phantom dump -ctx = shamrock.Context() -ctx.pdata_layout_new() -model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4") - -cfg = model.gen_config_from_phantom_dump(dump) -# Set the solver config to be the one stored in cfg -model.set_solver_config(cfg) -# Print the solver config -model.get_current_config().print_status() - -model.init_scheduler(int(1e8), 1) - -model.init_from_phantom_dump(dump) +ctx, model = shamrock.utils.phantom.load_simulation( + dump_folder, dump_file_name="blast_00010", in_file_name=None +) # %% # Run a simple timestep just for wasting some computing time :) diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py index b1ea2acad7..61d3fa310c 100644 --- a/src/pylib/shamrock/utils/phantom/__init__.py +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -87,15 +87,18 @@ def in_dump_to_config(model, simulation_path, in_file_name, dump_file_name, do_p return cfg, dump -def load_simulation(simulation_path, in_file_name, dump_file_name, do_print=True): +def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print=True): """ Load a Phantom simulation into a Shamrock model. """ dump_path = os.path.join(simulation_path, dump_file_name) - in_file_path = os.path.join(simulation_path, in_file_name) - in_params = parse_in_file(os.path.join(simulation_path, in_file_name)) + if in_file_name is not None: + in_file_path = os.path.join(simulation_path, in_file_name) + in_params = parse_in_file(in_file_path) + else: + in_params = None # setup = dump finish with .tmp is_setup_file = dump_file_name.endswith(".tmp") @@ -117,9 +120,12 @@ def load_simulation(simulation_path, in_file_name, dump_file_name, do_print=True print(f"Is setup file: {is_setup_file}") print("Solver config:") model.get_current_config().print_status() - print("In file parameters:") - for key, value in in_params.items(): - print(f"{key}: {value}") + + if in_params is not None: + print("In file parameters:") + for key, value in in_params.items(): + print(f"{key}: {value}") + print("Dump state:") dump.print_state() From f3114acb5731203558f22677ba944023533a6642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 17:13:34 +0100 Subject: [PATCH 44/54] better --- src/pylib/shamrock/utils/phantom/__init__.py | 25 -------------------- 1 file changed, 25 deletions(-) diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py index 61d3fa310c..b488936d8d 100644 --- a/src/pylib/shamrock/utils/phantom/__init__.py +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -62,31 +62,6 @@ def parse_in_file(in_file): return params -def in_dump_to_config(model, simulation_path, in_file_name, dump_file_name, do_print=True): - """ - Convert a Phantom .in file and a Phantom dump file to a Shamrock config. - """ - in_params = parse_in_file(os.path.join(simulation_path, in_file_name)) - - dump = shamrock.load_phantom_dump(os.path.join(simulation_path, dump_file_name)) - - cfg = model.gen_config_from_phantom_dump(dump) - # Set the solver config to be the one stored in cfg - model.set_solver_config(cfg) - - # Print the solver config - if do_print and shamrock.sys.world_rank() == 0: - print("Solver config:") - model.get_current_config().print_status() - print("In file parameters:") - for key, value in in_params.items(): - print(f"{key}: {value}") - print("Dump state:") - dump.print_state() - - return cfg, dump - - def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print=True): """ Load a Phantom simulation into a Shamrock model. From ad0e27a1f0699327db8609810c6602b85b8c7205 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:14:20 +0000 Subject: [PATCH 45/54] [gh-action] trigger CI with empty commit From f64b3f89639b116e475264ece49a88b01b53e4ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 20:20:34 +0100 Subject: [PATCH 46/54] cleaner logs --- src/pylib/shamrock/utils/phantom/__init__.py | 45 ++++++++++++++------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py index b488936d8d..496683fd8d 100644 --- a/src/pylib/shamrock/utils/phantom/__init__.py +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -67,6 +67,14 @@ def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print Load a Phantom simulation into a Shamrock model. """ + if do_print and shamrock.sys.world_rank() == 0: + print("-----------------------------------------------------------") + print("---------------- Phantom dump loading -----------------") + print("-----------------------------------------------------------") + + # setup = dump finish with .tmp + is_setup_file = dump_file_name.endswith(".tmp") + dump_path = os.path.join(simulation_path, dump_file_name) if in_file_name is not None: @@ -75,8 +83,8 @@ def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print else: in_params = None - # setup = dump finish with .tmp - is_setup_file = dump_file_name.endswith(".tmp") + if do_print and shamrock.sys.world_rank() == 0: + print(" - Loading phantom dump from: ", dump_path) # Open the phantom dump dump = shamrock.load_phantom_dump(dump_path) @@ -86,29 +94,38 @@ def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print ctx.pdata_layout_new() model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4") + if do_print and shamrock.sys.world_rank() == 0: + print(" - Generating Shamrock solver config from phantom dump") cfg = model.gen_config_from_phantom_dump(dump) + if do_print and shamrock.sys.world_rank() == 0: + print(" - Setting Shamrock solver config") # Set the solver config to be the one stored in cfg model.set_solver_config(cfg) - # Print infos if do_print and shamrock.sys.world_rank() == 0: - print(f"Is setup file: {is_setup_file}") - print("Solver config:") - model.get_current_config().print_status() - - if in_params is not None: - print("In file parameters:") - for key, value in in_params.items(): - print(f"{key}: {value}") - - print("Dump state:") - dump.print_state() + print(" - Initializing domain scheduler") model.init_scheduler(int(1e8), 1) + if do_print and shamrock.sys.world_rank() == 0: + print(f" - Initializing from phantom dump (setup file: {is_setup_file})") + if is_setup_file: model.init_from_phantom_dump(dump, 0.5) else: model.init_from_phantom_dump(dump, 1.0) + # Print infos + if do_print and shamrock.sys.world_rank() == 0: + print(" - Shamrock solver config:") + model.get_current_config().print_status() + + if in_params is not None: + print(" - Phantom input file parameters:") + for key, value in in_params.items(): + print(f"{key}: {value}") + + # print("Dump state:") + # dump.print_state() + return ctx, model From b6b972fec29a840c40aad5f11473d993ecd3410e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David--Cl=C3=A9ris=20Timoth=C3=A9e?= Date: Thu, 22 Jan 2026 21:28:06 +0100 Subject: [PATCH 47/54] Adjust phantom dump initialization parameter --- src/pylib/shamrock/utils/phantom/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py index 496683fd8d..21a7a9994b 100644 --- a/src/pylib/shamrock/utils/phantom/__init__.py +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -111,7 +111,7 @@ def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print print(f" - Initializing from phantom dump (setup file: {is_setup_file})") if is_setup_file: - model.init_from_phantom_dump(dump, 0.5) + model.init_from_phantom_dump(dump, 0.05) else: model.init_from_phantom_dump(dump, 1.0) From 0dc5c01811fba2c8a5286e13eff77dd81e2a3788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 21:30:30 +0100 Subject: [PATCH 48/54] its a start --- .../examples/sph/run_phantom_run_emulation.py | 41 +++++++++++++++ src/pylib/shamrock/utils/__init__.py | 1 + src/pylib/shamrock/utils/url/__init__.py | 52 +++++++++++++++++++ .../include/shampylib/pyNodeInstance.hpp | 10 ++++ 4 files changed, 104 insertions(+) create mode 100644 doc/sphinx/examples/sph/run_phantom_run_emulation.py create mode 100644 src/pylib/shamrock/utils/url/__init__.py diff --git a/doc/sphinx/examples/sph/run_phantom_run_emulation.py b/doc/sphinx/examples/sph/run_phantom_run_emulation.py new file mode 100644 index 0000000000..60f9ba0cf7 --- /dev/null +++ b/doc/sphinx/examples/sph/run_phantom_run_emulation.py @@ -0,0 +1,41 @@ +""" +Perform a Phantom run in Shamrock +========================================== + +Setup from a phantom dump and run according to the input file +""" + +import os +from urllib.request import urlretrieve + +import shamrock + +# If we use the shamrock executable to run this script instead of the python interpreter, +# we should not initialize the system as the shamrock executable needs to handle specific MPI logic +if not shamrock.sys.is_initialized(): + shamrock.change_loglevel(1) + shamrock.sys.init("0:0") + +dump_folder = "_to_trash/phantom_test_sim" +if shamrock.sys.world_rank() == 0: + os.makedirs(dump_folder, exist_ok=True) +shamrock.sys.mpi_barrier() + +input_file_name = "disc.in" +dump_file_name = "disc_00000.tmp" + +input_file_path = os.path.join(dump_folder, input_file_name) +dump_file_path = os.path.join(dump_folder, dump_file_name) + +input_file_url = "https://raw.githubusercontent.com/Shamrock-code/reference-files/refs/heads/main/phantom_disc_simulation/disc.in" +dump_file_url = "https://raw.githubusercontent.com/Shamrock-code/reference-files/refs/heads/main/phantom_disc_simulation/disc_00000.tmp" + +shamrock.utils.url.download_file(input_file_url, input_file_path) +shamrock.utils.url.download_file(dump_file_url, dump_file_path) + + +ctx, model = shamrock.utils.phantom.load_simulation( + dump_folder, dump_file_name=dump_file_name, in_file_name=input_file_name +) + +model.timestep() diff --git a/src/pylib/shamrock/utils/__init__.py b/src/pylib/shamrock/utils/__init__.py index 3cb4eef9fe..74ddd1b7b1 100644 --- a/src/pylib/shamrock/utils/__init__.py +++ b/src/pylib/shamrock/utils/__init__.py @@ -3,3 +3,4 @@ """ from . import phantom, plot +from .url import download_file diff --git a/src/pylib/shamrock/utils/url/__init__.py b/src/pylib/shamrock/utils/url/__init__.py new file mode 100644 index 0000000000..20bfe20c4d --- /dev/null +++ b/src/pylib/shamrock/utils/url/__init__.py @@ -0,0 +1,52 @@ +""" +Utility functions to download files +""" + +import os +import sys +from urllib.request import urlretrieve + +import shamrock.sys + + +def fmt(n): + for u in ("B", "KB", "MB", "GB", "TB"): + if n < 1024: + return f"{n:.1f}{u}" + n /= 1024 + + +def reporthook(block_num, block_size, total_size): + if total_size <= 0 or block_size <= 0: + return + freq_report = int(int(total_size / block_size) / 10) + if freq_report <= 0: + return + if block_num % freq_report == 0: + BAR_WIDTH = 60 + downloaded = block_num * block_size + percent = int((downloaded / total_size) * 100) + filled = int(BAR_WIDTH * percent / 100) + bar = "#" * filled + "-" * (BAR_WIDTH - filled) + sys.stdout.write(f"[{bar}] {percent:3d}% | {fmt(downloaded)}/{fmt(total_size)}\n") + sys.stdout.flush() + + +def download_file(url, filename): + """ + Download a file from an URL + """ + + if shamrock.sys.world_rank() == 0: + print(f" - Downloading {filename} from {url}") + # create the directory if it does not exist + os.makedirs(os.path.dirname(filename), exist_ok=True) + urlretrieve(url, filename, reporthook=reporthook) + + shamrock.sys.mpi_barrier() + + # check that the file exists + if not os.path.exists(filename): + raise FileNotFoundError( + f"File {filename} should have been downloaded but is not present on rank {shamrock.sys.world_rank()}" + ) diff --git a/src/shampylib/include/shampylib/pyNodeInstance.hpp b/src/shampylib/include/shampylib/pyNodeInstance.hpp index 31676d0437..234e3ded9c 100644 --- a/src/shampylib/include/shampylib/pyNodeInstance.hpp +++ b/src/shampylib/include/shampylib/pyNodeInstance.hpp @@ -18,6 +18,7 @@ #include "shambindings/pybindaliases.hpp" #include "shamcmdopt/cmdopt.hpp" #include "shamcomm/mpiInfo.hpp" +#include "shamcomm/wrapper.hpp" #include "shamsys/NodeInstance.hpp" namespace shamsys::instance { @@ -79,5 +80,14 @@ namespace shamsys::instance { R"pbdoc( Return true if the node instance is initialized )pbdoc"); + + m.def( + "mpi_barrier", + []() { + shamcomm::mpi::Barrier(MPI_COMM_WORLD); + }, + R"pbdoc( + Call the MPI barrier + )pbdoc"); } } // namespace shamsys::instance From 80223ee1b4493e2d5210e10634c7209f053ee658 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 22:08:50 +0100 Subject: [PATCH 49/54] better --- .../examples/sph/run_phantom_run_emulation.py | 12 +-- src/pylib/shamrock/utils/phantom/__init__.py | 95 +++++++++++++++++-- 2 files changed, 94 insertions(+), 13 deletions(-) diff --git a/doc/sphinx/examples/sph/run_phantom_run_emulation.py b/doc/sphinx/examples/sph/run_phantom_run_emulation.py index 60f9ba0cf7..89a1bdfcce 100644 --- a/doc/sphinx/examples/sph/run_phantom_run_emulation.py +++ b/doc/sphinx/examples/sph/run_phantom_run_emulation.py @@ -6,6 +6,7 @@ """ import os +import shutil from urllib.request import urlretrieve import shamrock @@ -18,7 +19,11 @@ dump_folder = "_to_trash/phantom_test_sim" if shamrock.sys.world_rank() == 0: + # remove the folder if it exists (ok if it does not exist) + if os.path.exists(dump_folder): + shutil.rmtree(dump_folder) os.makedirs(dump_folder, exist_ok=True) + shamrock.sys.mpi_barrier() input_file_name = "disc.in" @@ -33,9 +38,4 @@ shamrock.utils.url.download_file(input_file_url, input_file_path) shamrock.utils.url.download_file(dump_file_url, dump_file_path) - -ctx, model = shamrock.utils.phantom.load_simulation( - dump_folder, dump_file_name=dump_file_name, in_file_name=input_file_name -) - -model.timestep() +shamrock.utils.phantom.run_phantom_simulation(dump_folder, "disc") diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py index 21a7a9994b..79c7e23738 100644 --- a/src/pylib/shamrock/utils/phantom/__init__.py +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -62,7 +62,7 @@ def parse_in_file(in_file): return params -def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print=True): +def load_simulation(simulation_path, dump_file_name=None, in_file_name=None, do_print=True): """ Load a Phantom simulation into a Shamrock model. """ @@ -72,20 +72,27 @@ def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print print("---------------- Phantom dump loading -----------------") print("-----------------------------------------------------------") - # setup = dump finish with .tmp - is_setup_file = dump_file_name.endswith(".tmp") - - dump_path = os.path.join(simulation_path, dump_file_name) - if in_file_name is not None: in_file_path = os.path.join(simulation_path, in_file_name) in_params = parse_in_file(in_file_path) else: in_params = None + if dump_file_name is None: + if in_file_name is not None: + dump_file_name = in_params["dumpfile"] + + else: + raise ValueError("Either dump_file_name or in_file_name must be provided") + + dump_path = os.path.join(simulation_path, dump_file_name) + if do_print and shamrock.sys.world_rank() == 0: print(" - Loading phantom dump from: ", dump_path) + # setup = dump finish with .tmp + is_setup_file = dump_file_name.endswith(".tmp") + # Open the phantom dump dump = shamrock.load_phantom_dump(dump_path) @@ -128,4 +135,78 @@ def load_simulation(simulation_path, dump_file_name, in_file_name=None, do_print # print("Dump state:") # dump.print_state() - return ctx, model + return ctx, model, in_params + + +def run_phantom_simulation(simulation_folder, sim_name): + """ + Run a Phantom simulation in Shamrock. + """ + + input_file_name = sim_name + ".in" + + ctx, model, in_params = shamrock.utils.phantom.load_simulation( + simulation_folder, in_file_name=input_file_name + ) + + dump_file_name = in_params["dumpfile"] + + # phantom dumps are 00000.tmp if before start and then sim_name_{:05d} + # parse the dump number + dump_number = int(dump_file_name.split("_")[1].split(".")[0]) + print(f"Dump number: {dump_number}") + + dtmax = float(in_params["dtmax"]) + tmax = float(in_params["tmax"]) + + def get_ph_dump_file_name(dump_number): + return f"{sim_name}_{dump_number:05d}" + + def get_ph_dump_name(dump_number): + return os.path.join(simulation_folder, get_ph_dump_file_name(dump_number)) + + def do_dump(dump_number): + if shamrock.sys.world_rank() == 0: + print("-----------------------------------------------------------") + print("---------------- Phantom dump saving -----------------") + print("-----------------------------------------------------------") + print(f" - Saving dump {dump_number} to {get_ph_dump_name(dump_number)}") + dump = model.make_phantom_dump() + dump.save_dump(get_ph_dump_name(dump_number)) + + # replace dumpfile in the input file + lines = [] + with open(os.path.join(simulation_folder, f"{sim_name}.in"), "r") as f: + lines = f.readlines() + with open(os.path.join(simulation_folder, f"{sim_name}.in"), "w") as f: + for line in lines: + if "dumpfile" in line: + line = f" dumpfile = {get_ph_dump_file_name(dump_number)} ! dump file to start from\n" + f.write(line) + + # if .tmp is there remove it + if ( + os.path.exists(os.path.join(simulation_folder, f"{sim_name}_00000.tmp")) + and shamrock.sys.world_rank() == 0 + ): + os.remove(os.path.join(simulation_folder, f"{sim_name}_00000.tmp")) + + do_dump(dump_number) + + dump_number += 1 + + # evolve until tmax in increments of dtmax + last_step = False + while not last_step: + next_time = model.get_time() + dtmax + + if next_time > tmax: + next_time = tmax + last_step = True + + last_step = model.evolve_until(next_time) + + do_dump(dump_number) + dump_number += 1 + + return ctx, model, in_params From 594d65de411e318c4db5a827478951ac5ebb57c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 23:39:31 +0100 Subject: [PATCH 50/54] better --- .../examples/sph/run_phantom_run_emulation.py | 47 ++++++++++++++++++- src/pylib/shamrock/utils/phantom/__init__.py | 8 +++- .../sph/src/io/Phantom2Shamrock.cpp | 14 ++++++ 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/doc/sphinx/examples/sph/run_phantom_run_emulation.py b/doc/sphinx/examples/sph/run_phantom_run_emulation.py index 89a1bdfcce..96cb2a82fd 100644 --- a/doc/sphinx/examples/sph/run_phantom_run_emulation.py +++ b/doc/sphinx/examples/sph/run_phantom_run_emulation.py @@ -38,4 +38,49 @@ shamrock.utils.url.download_file(input_file_url, input_file_path) shamrock.utils.url.download_file(dump_file_url, dump_file_path) -shamrock.utils.phantom.run_phantom_simulation(dump_folder, "disc") + +def plot_that_rippa(ctx, model, dump_number): + pixel_x = 1920 + pixel_y = 1080 + radius = 5 + center = (0.0, 0.0, 0.0) + + aspect = pixel_x / pixel_y + pic_range = [-radius * aspect, radius * aspect, -radius, radius] + delta_x = (radius * 2 * aspect, 0.0, 0.0) + delta_y = (0.0, radius * 2, 0.0) + + arr_rho = model.render_cartesian_column_integ( + "rho", + "f64", + center=(0.0, 0.0, 0.0), + delta_x=delta_x, + delta_y=delta_y, + nx=pixel_x, + ny=pixel_y, + ) + + import copy + + import matplotlib + import matplotlib.pyplot as plt + + my_cmap = copy.copy(matplotlib.colormaps.get_cmap("gist_heat")) # copy the default cmap + my_cmap.set_bad(color="black") + + plt.figure(figsize=(16 / 2, 9 / 2)) + res = plt.imshow(arr_rho, cmap=my_cmap, origin="lower", extent=pic_range, norm="log", vmin=1e-9) + + cbar = plt.colorbar(res, extend="both") + cbar.set_label(r"$\int \rho \, \mathrm{d} z$ [code unit]") + # or r"$\rho$ [code unit]" for slices + + plt.title("t = {:0.3f} [code unit]".format(model.get_time())) + plt.xlabel("x") + plt.ylabel("z") + plt.show() + + +ctx, model, in_params = shamrock.utils.phantom.run_phantom_simulation( + dump_folder, "disc", callback=plot_that_rippa +) diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py index 79c7e23738..1053ffcbfb 100644 --- a/src/pylib/shamrock/utils/phantom/__init__.py +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -138,7 +138,7 @@ def load_simulation(simulation_path, dump_file_name=None, in_file_name=None, do_ return ctx, model, in_params -def run_phantom_simulation(simulation_folder, sim_name): +def run_phantom_simulation(simulation_folder, sim_name, callback=None): """ Run a Phantom simulation in Shamrock. """ @@ -191,6 +191,12 @@ def do_dump(dump_number): ): os.remove(os.path.join(simulation_folder, f"{sim_name}_00000.tmp")) + if callback is not None: + callback(ctx, model, dump_number) + + model.set_next_dt(0) + model.timestep() + do_dump(dump_number) dump_number += 1 diff --git a/src/shammodels/sph/src/io/Phantom2Shamrock.cpp b/src/shammodels/sph/src/io/Phantom2Shamrock.cpp index 7c9427e1ad..23388e8d29 100644 --- a/src/shammodels/sph/src/io/Phantom2Shamrock.cpp +++ b/src/shammodels/sph/src/io/Phantom2Shamrock.cpp @@ -122,6 +122,20 @@ namespace shammodels::sph { f64 utime = phdump.read_header_float("utime"); f64 umagfd = phdump.read_header_float("umagfd"); + if (udist == 1.0 || umass == 1.0 || utime == 1.0 || umagfd == 3.54491) { + logger::warn_ln("SPH", "phantom dump units are not set, defaulting to SI"); + logger::warn_ln("SPH", "udist =", udist); + logger::warn_ln("SPH", "umass =", umass); + logger::warn_ln("SPH", "utime =", utime); + logger::warn_ln("SPH", "umagfd =", umagfd); + + return shamunits::UnitSystem(); + } + + // convert from CGS to SI + udist /= 100.0; + umass /= 1000.0; + return shamunits::UnitSystem( utime, udist, umass // unit_current = 1 , From f006575d9f96f6109e40635e9a2f2c48ac419f7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Thu, 22 Jan 2026 23:55:38 +0100 Subject: [PATCH 51/54] better --- src/shammodels/sph/src/io/Phantom2Shamrock.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/shammodels/sph/src/io/Phantom2Shamrock.cpp b/src/shammodels/sph/src/io/Phantom2Shamrock.cpp index 23388e8d29..24db3629cf 100644 --- a/src/shammodels/sph/src/io/Phantom2Shamrock.cpp +++ b/src/shammodels/sph/src/io/Phantom2Shamrock.cpp @@ -154,9 +154,10 @@ namespace shammodels::sph { dump.table_header_f64.add("umass", units->kg_inv); dump.table_header_f64.add("utime", units->s_inv); - f64 umass = units->template to(); + // Back to freakin CGS (worst units system ever, well no ... there is imperial) + f64 umass = units->template to() / 1000.0; f64 utime = units->template to(); - f64 udist = units->template to(); + f64 udist = units->template to() / 100.0; shamunits::Constants ctes{*units}; f64 ccst = ctes.c(); From 6d826ec257c1395d0adfcc8dec0824d1cb8e835d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Fri, 23 Jan 2026 00:09:05 +0100 Subject: [PATCH 52/54] better --- src/pylib/shamrock/utils/phantom/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pylib/shamrock/utils/phantom/__init__.py b/src/pylib/shamrock/utils/phantom/__init__.py index 1053ffcbfb..ff50bb74d4 100644 --- a/src/pylib/shamrock/utils/phantom/__init__.py +++ b/src/pylib/shamrock/utils/phantom/__init__.py @@ -210,7 +210,7 @@ def do_dump(dump_number): next_time = tmax last_step = True - last_step = model.evolve_until(next_time) + model.evolve_until(next_time) do_dump(dump_number) dump_number += 1 From d187a235996174844f5e240c342bc12bbd93b40d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Fri, 23 Jan 2026 09:42:25 +0100 Subject: [PATCH 53/54] better --- doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py b/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py index 0a5dcfd0e8..2b6336ce6e 100644 --- a/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py +++ b/doc/sphinx/examples/sph/run_start_sph_from_phantom_dump.py @@ -30,7 +30,7 @@ shamrock.change_loglevel(1) shamrock.sys.init("0:0") -ctx, model = shamrock.utils.phantom.load_simulation( +ctx, model, in_params = shamrock.utils.phantom.load_simulation( dump_folder, dump_file_name="blast_00010", in_file_name=None ) From b5220b77d10dfc9651515b1a979951781532d33a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20David--Cl=C3=A9ris?= Date: Fri, 23 Jan 2026 23:49:20 +0100 Subject: [PATCH 54/54] [SPH] Add extended Farris et al 2014 EOS (Ragussa et al 2016) --- .../include/shammodels/common/EOSConfig.hpp | 14 +++ src/shammodels/common/src/EOSConfig.cpp | 46 ++++--- .../shammodels/sph/io/PhantomDumpEOSUtils.hpp | 40 +++++++ .../sph/src/io/Phantom2Shamrock.cpp | 29 +++++ .../sph/src/io/PhantomDumpEOSUtils.cpp | 58 +++++++++ src/shammodels/sph/src/modules/ComputeEos.cpp | 100 +++++++++++++++- src/shamphys/include/shamphys/eos_config.hpp | 112 ++++++++++++++---- 7 files changed, 363 insertions(+), 36 deletions(-) diff --git a/src/shammodels/common/include/shammodels/common/EOSConfig.hpp b/src/shammodels/common/include/shammodels/common/EOSConfig.hpp index 8c8429f2b2..1442eda5f5 100644 --- a/src/shammodels/common/include/shammodels/common/EOSConfig.hpp +++ b/src/shammodels/common/include/shammodels/common/EOSConfig.hpp @@ -67,6 +67,10 @@ namespace shammodels { using LocallyIsothermalFA2014 = shamphys::EOS_Config_LocallyIsothermalDisc_Farris2014; + /// Locally isothermal equation of state configuration from Farris 2014 extended to q != 1/2 + using LocallyIsothermalFA2014Extended + = shamphys::EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014; + /// Fermi equation of state configuration using Fermi = shamphys::EOS_Config_Fermi; @@ -78,6 +82,7 @@ namespace shammodels { LocallyIsothermal, LocallyIsothermalLP07, LocallyIsothermalFA2014, + LocallyIsothermalFA2014Extended, Fermi>; /// Current EOS configuration @@ -128,6 +133,11 @@ namespace shammodels { config = LocallyIsothermalFA2014{h_over_r}; } + inline void set_locally_isothermalFA2014Extended( + Tscal cs0, Tscal q, Tscal r0, u32 n_sinks) { + config = LocallyIsothermalFA2014Extended{cs0, q, r0, n_sinks}; + } + /** * @brief Set the EOS configuration to a Fermi equation of state * @@ -173,6 +183,10 @@ void shammodels::EOSConfig::print_status() { } else if ( LocallyIsothermalFA2014 *eos_config = std::get_if(&config)) { logger::raw_ln("locally isothermal (Farris 2014) : "); + } else if ( + LocallyIsothermalFA2014Extended *eos_config + = std::get_if(&config)) { + logger::raw_ln("locally isothermal (Farris 2014 extended) : "); } else if (Fermi *eos_config = std::get_if(&config)) { logger::raw_ln("Fermi : "); logger::raw_ln("mu_e", eos_config->mu_e); diff --git a/src/shammodels/common/src/EOSConfig.cpp b/src/shammodels/common/src/EOSConfig.cpp index e6667cf96f..45e338e99c 100644 --- a/src/shammodels/common/src/EOSConfig.cpp +++ b/src/shammodels/common/src/EOSConfig.cpp @@ -55,13 +55,14 @@ namespace shammodels { static_assert(shambase::always_false_v, "This Tvec type is not handled"); } - using Isothermal = typename EOSConfig::Isothermal; - using Adiabatic = typename EOSConfig::Adiabatic; - using Polytropic = typename EOSConfig::Polytropic; - using LocIsoT = typename EOSConfig::LocallyIsothermal; - using LocIsoTLP07 = typename EOSConfig::LocallyIsothermalLP07; - using LocIsoTFA2014 = typename EOSConfig::LocallyIsothermalFA2014; - using Fermi = typename EOSConfig::Fermi; + using Isothermal = typename EOSConfig::Isothermal; + using Adiabatic = typename EOSConfig::Adiabatic; + using Polytropic = typename EOSConfig::Polytropic; + using LocIsoT = typename EOSConfig::LocallyIsothermal; + using LocIsoTLP07 = typename EOSConfig::LocallyIsothermalLP07; + using LocIsoTFA2014 = typename EOSConfig::LocallyIsothermalFA2014; + using LocIsoTFA2014Extended = typename EOSConfig::LocallyIsothermalFA2014Extended; + using Fermi = typename EOSConfig::Fermi; if (const Isothermal *eos_config = std::get_if(&p.config)) { j = json{{"Tvec", type_id}, {"eos_type", "isothermal"}, {"cs", eos_config->cs}}; @@ -87,6 +88,16 @@ namespace shammodels { {"Tvec", type_id}, {"eos_type", "locally_isothermal_fa2014"}, {"h_over_r", eos_config->h_over_r}}; + } else if ( + const LocIsoTFA2014Extended *eos_config + = std::get_if(&p.config)) { + j = json{ + {"Tvec", type_id}, + {"eos_type", "locally_isothermal_fa2014_extended"}, + {"cs0", eos_config->cs0}, + {"q", eos_config->q}, + {"r0", eos_config->r0}, + {"n_sinks", eos_config->n_sinks}}; } else if (const Fermi *eos_config = std::get_if(&p.config)) { j = json{{"Tvec", type_id}, {"eos_type", "fermi"}, {"mu_e", eos_config->mu_e}}; } else { @@ -135,13 +146,14 @@ namespace shammodels { std::string eos_type; j.at("eos_type").get_to(eos_type); - using Isothermal = typename EOSConfig::Isothermal; - using Adiabatic = typename EOSConfig::Adiabatic; - using Polytropic = typename EOSConfig::Polytropic; - using LocIsoT = typename EOSConfig::LocallyIsothermal; - using LocIsoTLP07 = typename EOSConfig::LocallyIsothermalLP07; - using LocIsoTFA2014 = typename EOSConfig::LocallyIsothermalFA2014; - using Fermi = typename EOSConfig::Fermi; + using Isothermal = typename EOSConfig::Isothermal; + using Adiabatic = typename EOSConfig::Adiabatic; + using Polytropic = typename EOSConfig::Polytropic; + using LocIsoT = typename EOSConfig::LocallyIsothermal; + using LocIsoTLP07 = typename EOSConfig::LocallyIsothermalLP07; + using LocIsoTFA2014 = typename EOSConfig::LocallyIsothermalFA2014; + using LocIsoTFA2014Extended = typename EOSConfig::LocallyIsothermalFA2014Extended; + using Fermi = typename EOSConfig::Fermi; if (eos_type == "isothermal") { p.config = Isothermal{j.at("cs").get()}; @@ -156,6 +168,12 @@ namespace shammodels { j.at("cs0").get(), j.at("q").get(), j.at("r0").get()}; } else if (eos_type == "locally_isothermal_fa2014") { p.config = LocIsoTFA2014{j.at("h_over_r").get()}; + } else if (eos_type == "locally_isothermal_fa2014_extended") { + p.config = LocIsoTFA2014Extended{ + j.at("cs0").get(), + j.at("q").get(), + j.at("r0").get(), + j.at("n_sinks").get()}; } else if (eos_type == "fermi") { p.config = Fermi{j.at("mu_e").get()}; } else { diff --git a/src/shammodels/sph/include/shammodels/sph/io/PhantomDumpEOSUtils.hpp b/src/shammodels/sph/include/shammodels/sph/io/PhantomDumpEOSUtils.hpp index c00c6309ef..a8bbc2801f 100644 --- a/src/shammodels/sph/include/shammodels/sph/io/PhantomDumpEOSUtils.hpp +++ b/src/shammodels/sph/include/shammodels/sph/io/PhantomDumpEOSUtils.hpp @@ -98,4 +98,44 @@ namespace shammodels::sph::phdump { */ void eos3_write(PhantomDump &dump, const f64 &cs0, const f64 &q, const f64 &r0); + /** + * @brief Load the EOS13 from the phantom dump + * + * @param[in] dump Phantom dump file + * @param[out] cs0 Sound speed at the reference radius + * @param[out] q Power law index + * @param[out] r0 Reference radius + */ + void eos13_load(const PhantomDump &dump, f64 &cs0, f64 &q, f64 &r0); + + /** + * @brief Write the EOS13 to the phantom dump + * + * @param[out] dump Phantom dump file + * @param[in] cs0 Sound speed at the reference radius + * @param[in] q Power law index + * @param[in] r0 Reference radius + */ + void eos13_write(PhantomDump &dump, const f64 &cs0, const f64 &q, const f64 &r0); + + /** + * @brief Load the EOS14 from the phantom dump + * + * @param[in] dump Phantom dump file + * @param[out] cs0 Sound speed at the reference radius + * @param[out] q Power law index + * @param[out] r0 Reference radius + */ + void eos14_load(const PhantomDump &dump, f64 &cs0, f64 &q, f64 &r0); + + /** + * @brief Write the EOS14 to the phantom dump + * + * @param[out] dump Phantom dump file + * @param[in] cs0 Sound speed at the reference radius + * @param[in] q Power law index + * @param[in] r0 Reference radius + */ + void eos14_write(PhantomDump &dump, const f64 &cs0, const f64 &q, const f64 &r0); + } // namespace shammodels::sph::phdump diff --git a/src/shammodels/sph/src/io/Phantom2Shamrock.cpp b/src/shammodels/sph/src/io/Phantom2Shamrock.cpp index 7c9427e1ad..9f001b14a0 100644 --- a/src/shammodels/sph/src/io/Phantom2Shamrock.cpp +++ b/src/shammodels/sph/src/io/Phantom2Shamrock.cpp @@ -43,6 +43,15 @@ namespace shammodels::sph { f64 cs0, q, r0; phdump::eos3_load(phdump, cs0, q, r0); cfg.set_locally_isothermalLP07(cs0, q, r0); + } else if (ieos == 13) { + f64 cs0, q, r0; + phdump::eos13_load(phdump, cs0, q, r0); + // u32_max implies all sinks + cfg.set_locally_isothermalFA2014Extended(cs0, q, r0, u32_max); + } else if (ieos == 14) { + f64 cs0, q, r0; + phdump::eos14_load(phdump, cs0, q, r0); + cfg.set_locally_isothermalFA2014Extended(cs0, q, r0, 2); } else { const std::string msg = "loading phantom ieos=" + std::to_string(ieos) + " is not implemented in shamrock"; @@ -65,6 +74,8 @@ namespace shammodels::sph { using EOS_LocallyIsothermal = typename EOSConfig::LocallyIsothermal; using EOS_LocallyIsothermalLP07 = typename EOSConfig::LocallyIsothermalLP07; using EOS_LocallyIsothermalFA2014 = typename EOSConfig::LocallyIsothermalFA2014; + using EOS_LocallyIsothermalFA2014Extended = + typename EOSConfig::LocallyIsothermalFA2014Extended; if (EOS_Isothermal *eos_config = std::get_if(&cfg.config)) { phdump::eos1_write(dump, eos_config->cs); @@ -74,6 +85,23 @@ namespace shammodels::sph { EOS_LocallyIsothermalLP07 *eos_config = std::get_if(&cfg.config)) { phdump::eos3_write(dump, eos_config->cs0, eos_config->q, eos_config->r0); + } else if ( + EOS_LocallyIsothermalFA2014Extended *eos_config + = std::get_if(&cfg.config)) { + + if (eos_config->n_sinks == u32_max) { + phdump::eos13_write(dump, eos_config->cs0, eos_config->q, eos_config->r0); + } else if (eos_config->n_sinks == 2) { + phdump::eos14_write(dump, eos_config->cs0, eos_config->q, eos_config->r0); + } else { + const std::string msg + = "Phantom only support all or 2 sinks for this EOS configuration"; + if (bypass_error) { + logger::warn_ln("SPH", msg); + } else { + shambase::throw_unimplemented(msg); + } + } } else { const std::string msg = "The current shamrock EOS is not implemented in phantom dump conversion"; @@ -100,6 +128,7 @@ namespace shammodels::sph { } // namespace shammodels::sph namespace shammodels::sph { + template AVConfig get_shamrock_avconfig(PhantomDump &phdump) { AVConfig cfg{}; diff --git a/src/shammodels/sph/src/io/PhantomDumpEOSUtils.cpp b/src/shammodels/sph/src/io/PhantomDumpEOSUtils.cpp index 464c443ace..f57922378a 100644 --- a/src/shammodels/sph/src/io/PhantomDumpEOSUtils.cpp +++ b/src/shammodels/sph/src/io/PhantomDumpEOSUtils.cpp @@ -244,4 +244,62 @@ namespace shammodels::sph::phdump { dump.table_header_i32.add("ieos", 3); write_headeropts_eos(3, dump, eos); } + + /* + case(13) + ! + !--Locally isothermal eos for generic hierarchical system + ! + ! Assuming all sink particles are stars. + ! Generalisation of Farris et al. (2014; for binaries) to N stars. + ! For two sink particles this is identical to ieos=14 + ! + */ + + void eos13_load(const PhantomDump &dump, f64 &cs0, f64 &q, f64 &r0) { + assert_ieos_val(dump, 13); + EOSPhConfig eos = read_headeropts_eos(dump, 13); + + cs0 = sycl::sqrt(eos.polyk); + q = eos.qfacdisc; + r0 = 1; // the polyk in phantom include the 1/r0^2 ? + } + + void eos13_write(PhantomDump &dump, const f64 &cs0, const f64 &q, const f64 &r0) { + EOSPhConfig eos; + + eos.polyk = cs0 * cs0 / (r0 * r0); + eos.qfacdisc = q; + + dump.table_header_i32.add("ieos", 13); + write_headeropts_eos(13, dump, eos); + } + + /* + case(14) + ! + !--Locally isothermal eos from Farris et al. (2014) for binary system + ! + ! uses the locations of the first two sink particles + ! + */ + + void eos14_load(const PhantomDump &dump, f64 &cs0, f64 &q, f64 &r0) { + assert_ieos_val(dump, 14); + EOSPhConfig eos = read_headeropts_eos(dump, 14); + + cs0 = sycl::sqrt(eos.polyk); + q = eos.qfacdisc; + r0 = 1; // the polyk in phantom include the 1/r0^2 ? + } + + void eos14_write(PhantomDump &dump, const f64 &cs0, const f64 &q, const f64 &r0) { + EOSPhConfig eos; + + eos.polyk = cs0 * cs0 / (r0 * r0); + eos.qfacdisc = q; + + dump.table_header_i32.add("ieos", 14); + write_headeropts_eos(14, dump, eos); + } } // namespace shammodels::sph::phdump diff --git a/src/shammodels/sph/src/modules/ComputeEos.cpp b/src/shammodels/sph/src/modules/ComputeEos.cpp index 31131c9cf5..f3211a2ef9 100644 --- a/src/shammodels/sph/src/modules/ComputeEos.cpp +++ b/src/shammodels/sph/src/modules/ComputeEos.cpp @@ -108,7 +108,9 @@ void shammodels::sph::modules::ComputeEos::compute_eos_internal using SolverEOS_LocallyIsothermal = typename SolverConfigEOS::LocallyIsothermal; using SolverEOS_LocallyIsothermalLP07 = typename SolverConfigEOS::LocallyIsothermalLP07; using SolverEOS_LocallyIsothermalFA2014 = typename SolverConfigEOS::LocallyIsothermalFA2014; - using SolverEOS_Fermi = typename SolverConfigEOS::Fermi; + using SolverEOS_LocallyIsothermalFA2014Extended = + typename SolverConfigEOS::LocallyIsothermalFA2014Extended; + using SolverEOS_Fermi = typename SolverConfigEOS::Fermi; sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); @@ -390,6 +392,102 @@ void shammodels::sph::modules::ComputeEos::compute_eos_internal buf_xyz.complete_event_state(e); }); + } else if ( + SolverEOS_LocallyIsothermalFA2014Extended *eos_config + = std::get_if( + &solver_config.eos_config.config)) { + + Tscal _cs0 = eos_config->cs0; + Tscal _r0 = eos_config->r0; + Tscal _q = eos_config->q; + u32 n_sinks = eos_config->n_sinks; + + using EOS = shamphys::EOS_LocallyIsothermal; + + auto &sink_parts = storage.sinks.get(); + std::vector sink_pos; + std::vector sink_mass; + u32 sink_cnt = 0; + + for (auto &s : sink_parts) { + sink_pos.push_back(s.pos); + sink_mass.push_back(s.mass); + if (sink_pos.size() >= n_sinks) { // We only consider the first n_sinks sinks + break; + } + sink_cnt++; + } + + sycl::buffer sink_pos_buf{sink_pos}; + sycl::buffer sink_mass_buf{sink_mass}; + + storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) { + auto &mfield = storage.merged_xyzh.get().get(id); + + sham::DeviceBuffer &buf_xyz = mfield.template get_field_buf_ref(0); + + sham::DeviceBuffer &buf_P + = shambase::get_check_ref(storage.pressure).get_field(id).get_buf(); + sham::DeviceBuffer &buf_cs + = shambase::get_check_ref(storage.soundspeed).get_field(id).get_buf(); + sham::DeviceBuffer &buf_uint = mpdat.get_field_buf_ref(iuint_interf); + auto rho_getter = rho_getter_gen(mpdat); + + // TODO: Use the complex kernel call when implemented + + sham::EventList depends_list; + + auto P = buf_P.get_write_access(depends_list); + auto cs = buf_cs.get_write_access(depends_list); + auto rho = rho_getter.get_read_access(depends_list); + auto U = buf_uint.get_read_access(depends_list); + auto xyz = buf_xyz.get_read_access(depends_list); + + u32 total_elements + = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id); + + auto e = q.submit(depends_list, [&](sycl::handler &cgh) { + sycl::accessor spos{sink_pos_buf, cgh, sycl::read_only}; + sycl::accessor smass{sink_mass_buf, cgh, sycl::read_only}; + u32 scount = sink_cnt; + + Tscal cs0 = _cs0; + Tscal r0 = _r0; + Tscal q = _q; + + Tscal inv_r0_q = 1. / sycl::pow(r0, q); + + cgh.parallel_for(sycl::range<1>{total_elements}, [=](sycl::item<1> item) { + using namespace shamrock::sph; + + Tvec R = xyz[item]; + Tscal rho_a = rho(item.get_linear_id()); + + Tscal sink_mass_sum = 0; + Tscal pot_sum = 0; + for (u32 i = 0; i < scount; i++) { + Tvec s_r = spos[i] - R; + Tscal s_m = smass[i]; + Tscal s_r_abs = sycl::length(s_r); + sink_mass_sum += s_m; + pot_sum += s_m / s_r_abs; + } + + Tscal cs_out = cs0 * inv_r0_q * sycl::pow(pot_sum / sink_mass_sum, q); + Tscal P_a = EOS::pressure_from_cs(cs_out * cs_out, rho_a); + + P[item] = P_a; + cs[item] = cs_out; + }); + }); + + buf_P.complete_event_state(e); + buf_cs.complete_event_state(e); + rho_getter.complete_event_state(e); + buf_uint.complete_event_state(e); + buf_xyz.complete_event_state(e); + }); + } else if ( SolverEOS_Fermi *eos_config = std::get_if(&solver_config.eos_config.config)) { diff --git a/src/shamphys/include/shamphys/eos_config.hpp b/src/shamphys/include/shamphys/eos_config.hpp index c5423f50db..796786565a 100644 --- a/src/shamphys/include/shamphys/eos_config.hpp +++ b/src/shamphys/include/shamphys/eos_config.hpp @@ -94,6 +94,24 @@ namespace shamphys { Tscal gamma; }; + /** + * @brief Equal operator for the EOS_Config_Polytropic struct + * + * @tparam Tscal Scalar type + * @param lhs First EOS_Config_Polytropic struct to compare + * @param rhs Second EOS_Config_Polytropic struct to compare + * + * This function checks if two EOS_Config_Polytropic structs are equal by comparing their K and + * gamma values. + * + * @return true if the two structs have the same K and gamma values, false otherwise + */ + template + inline bool operator==( + const EOS_Config_Polytropic &lhs, const EOS_Config_Polytropic &rhs) { + return (lhs.K == rhs.K) && (lhs.gamma == rhs.gamma); + } + /** * @brief Configuration struct for Fermi equation of state * @@ -112,21 +130,20 @@ namespace shamphys { }; /** - * @brief Equal operator for the EOS_Config_Polytropic struct + * @brief Equal operator for the EOS_Config_Fermi struct * * @tparam Tscal Scalar type - * @param lhs First EOS_Config_Polytropic struct to compare - * @param rhs Second EOS_Config_Polytropic struct to compare + * @param lhs First EOS_Config_Fermi struct to compare + * @param rhs Second EOS_Config_Fermi struct to compare * - * This function checks if two EOS_Config_Polytropic structs are equal by comparing their K and - * gamma values. + * This function checks if two EOS_Config_Fermi structs are equal by comparing their mu_e + * values. * - * @return true if the two structs have the same K and gamma values, false otherwise + * @return true if the two structs have the same mu_e values, false otherwise */ template - inline bool operator==( - const EOS_Config_Polytropic &lhs, const EOS_Config_Polytropic &rhs) { - return (lhs.K == rhs.K) && (lhs.gamma == rhs.gamma); + inline bool operator==(const EOS_Config_Fermi &lhs, const EOS_Config_Fermi &rhs) { + return lhs.mu_e == rhs.mu_e; } /** @@ -141,13 +158,13 @@ namespace shamphys { template struct EOS_Config_LocallyIsothermal_LP07 { /// Soundspeed at the reference radius - Tscal cs0 = 0.005; + Tscal cs0; /// Power exponent of the soundspeed profile - Tscal q = 2; + Tscal q; /// Reference radius - Tscal r0 = 10; + Tscal r0; }; /** @@ -205,20 +222,73 @@ namespace shamphys { } /** - * @brief Equal operator for the EOS_Config_Fermi struct + * @brief Configuration struct for the locally isothermal equation of state extended from Farris + * 2014 to include for the q index of the disc. * - * @tparam Tscal Scalar type - * @param lhs First EOS_Config_Fermi struct to compare - * @param rhs Second EOS_Config_Fermi struct to compare + * This EOS should match with ieos 13 and 14 of phantom. * - * This function checks if two EOS_Config_Fermi structs are equal by comparing their mu_e - * values. + * The equation in phantom is a bit weird so re-derived it here. * - * @return true if the two structs have the same mu_e values, false otherwise + * Farris 2014 EOS which only corresponds to q=1/2: + * + * \f$ + * c_s = \frac{H_0}{r_0}\left(\frac{G M_1}{r_1} + \frac{G M_2}{r_2}\right) + * \f$ + * + * However the extension of that EOS to q != 1/2 was only introduced in Ragussa et al 2016, if + * I'm right with: + * + * \f$ + * c_s = \frac{H_0}{r_0} \left(\frac{G M_1}{r_1} + \frac{G M_2}{r_2}\right)^{q} + * \f$ + * + * But as is the units are broken if q is not 1/2 so you need to compensate with + * \f$r_0 \Omega_0\f$ + * + * \f$c_s = \frac{H_0}{r_0}\frac{1}{(r_0 \Omega_0)^{2q - 1}}\left(\frac{G M_1}{r_1} + * + \frac{G M_2}{r_2}\right)^{q} \f$ + * + * \f$= c_{s0} \frac{1}{(r_0 \Omega_0)^{q}}\left(\frac{G M_1}{r_1} + * + \frac{G M_2}{r_2}\right)^{q}\f$ + * + * \f$= c_{s0}\frac{1}{r_0^{q}}\left[\frac{1}{\sum_i M_i}\sum_i \frac{M_i}{r_i}\right]^{q}\f$ + * + * @tparam Tscal Scalar type */ template - inline bool operator==(const EOS_Config_Fermi &lhs, const EOS_Config_Fermi &rhs) { - return lhs.mu_e == rhs.mu_e; + struct EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014 { + /// Soundspeed at the reference radius + Tscal cs0; + + /// Power exponent of the soundspeed profile + Tscal q; + + /// Reference radius + Tscal r0; + + /// Number of sinks to consider for the equation of state + u32 n_sinks; + }; + + /** + * @brief Equal operator for the EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014 struct + * + * @tparam Tscal Scalar type + * @param lhs First EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014 struct to compare + * @param rhs Second EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014 struct to compare + * + * This function checks if two EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014 structs are + equal by + * comparing their cs0, q, r0, and n_sinks values. + + * @return true if the two structs have the same cs0, q, r0, and n_sinks values, false otherwise + */ + template + inline bool operator==( + const EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014 &lhs, + const EOS_Config_LocallyIsothermalDisc_ExtendedFarris2014 &rhs) { + return (lhs.cs0 == rhs.cs0) && (lhs.q == rhs.q) && (lhs.r0 == rhs.r0) + && (lhs.n_sinks == rhs.n_sinks); } } // namespace shamphys