Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions examples/benchmarks/sph_weak_scale_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
)
cfg.set_boundary_periodic()
cfg.set_eos_adiabatic(gamma)
cfg.set_max_neigh_cache_size(int(100e9))
cfg.print_status()
model.set_solver_config(cfg)
model.init_scheduler(scheduler_split_val, scheduler_merge_val)
Expand Down Expand Up @@ -102,7 +101,7 @@

model.set_value_in_a_box("uint", "f64", 0, bmin, bmax)

rinj = 8 * dr
rinj = 16 * dr
u_inj = 1
model.add_kernel_value("uint", "f64", u_inj, (0, 0, 0), rinj)

Expand All @@ -116,9 +115,6 @@
model.set_cfl_cour(0.1)
model.set_cfl_force(0.1)

model.set_cfl_multipler(1e-4)
model.set_cfl_mult_stiffness(1e6)

shamrock.backends.reset_mem_info_max()

# converge smoothing length and compute initial dt
Expand All @@ -129,8 +125,11 @@
res_cnts = []
res_system_metrics = []

for i in range(5):
for i in range(10):
shamrock.sys.mpi_barrier()

# To replay the same step
model.set_next_dt(0.0)
model.timestep()

tmp_res_rate, tmp_res_cnt, tmp_system_metrics = (
Expand Down
3 changes: 3 additions & 0 deletions src/shamalgs/src/collective/sparse_exchange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ namespace shamalgs::collective {
/// fetch u64_2 from global message data
std::vector<u64_2> fetch_global_message_data(
const std::vector<CommMessageInfo> &messages_send) {
__shamrock_stack_entry();

std::vector<u64_2> local_data = std::vector<u64_2>(messages_send.size());

Expand Down Expand Up @@ -84,6 +85,7 @@ namespace shamalgs::collective {

/// decode message to get message
std::vector<CommMessageInfo> decode_all_message(const std::vector<u64_2> &global_data) {
__shamrock_stack_entry();
std::vector<CommMessageInfo> message_all(global_data.size());
for (u64 i = 0; i < global_data.size(); i++) {
message_all[i] = unpack(global_data[i]);
Expand All @@ -94,6 +96,7 @@ namespace shamalgs::collective {

/// compute message tags
void compute_tags(std::vector<CommMessageInfo> &message_all) {
__shamrock_stack_entry();

std::vector<i32> tag_map(shamcomm::world_size(), 0);

Expand Down
5 changes: 5 additions & 0 deletions src/shamrock/include/shamrock/scheduler/SerialPatchTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,17 +269,22 @@ class SerialPatchTree {
sycl::queue &queue,
shamrock::patch::PatchField<T> pfield,
Func &&reducer) {
__shamrock_stack_entry();

shamrock::patch::PatchtreeField<T> ptfield;
ptfield.allocate(get_element_count());

{
__shamrock_stack_entry();
sycl::host_accessor lpid{
shambase::get_check_ref(linked_patch_ids_buf), sycl::read_only};
sycl::host_accessor tree_field{
shambase::get_check_ref(ptfield.internal_buf), sycl::write_only, sycl::no_init};

// init reduction
std::unordered_map<u64, u64> &idp_to_gid = sched.patch_list.id_patch_to_global_idx;

#pragma omp parallel for
for (u64 idx = 0; idx < get_element_count(); idx++) {
tree_field[idx] = (lpid[idx] != u64_max) ? pfield.get(lpid[idx]) : T();
}
Expand Down
Loading