Merged
@@ -41,6 +41,8 @@ namespace shamrock::scheduler::details {
u64 index;
i32 new_owner;

+ LoadBalancedTile() = default;
+
LoadBalancedTile(TileWithLoad<Torder, Tweight> in, u64 inindex)
: ordering_val(in.ordering_val), load_value(in.load_value), index(inindex) {}
};
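The defaulted constructor added above is what lets the result vector in the next hunk be sized up front: `std::vector<T>(n)` value-initializes `n` elements and therefore needs `T` to be default-constructible. A minimal standalone sketch of that requirement (the struct below is illustrative, not the real LoadBalancedTile):

```cpp
#include <cstdint>
#include <vector>

// Illustrative tile type: mirrors the idea, not the actual Shamrock class.
struct DemoTile {
    std::uint64_t ordering_val = 0;
    double load_value          = 0;
    std::uint64_t index        = 0;

    DemoTile() = default; // required by std::vector<DemoTile> res(n)
    DemoTile(std::uint64_t ord, double load, std::uint64_t idx)
        : ordering_val(ord), load_value(load), index(idx) {}
};

int main() {
    std::vector<DemoTile> res(4);  // default-constructs 4 slots up front
    res[2] = DemoTile{42, 1.5, 2}; // each slot can then be overwritten by index
    return 0;
}
```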
@@ -76,9 +78,10 @@ namespace shamrock::scheduler::details {
using LBTile = TileWithLoad<Torder, Tweight>;
using LBTileResult = details::LoadBalancedTile<Torder, Tweight>;

- std::vector<LBTileResult> res;
+ std::vector<LBTileResult> res(lb_vector.size());
+ #pragma omp parallel for
for (u64 i = 0; i < lb_vector.size(); i++) {
- res.push_back(LBTileResult{lb_vector[i], i});
+ res[i] = LBTileResult{lb_vector[i], i};
}

// apply the ordering
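Pre-sizing `res` is what makes the new `#pragma omp parallel for` safe: concurrent `push_back` calls mutate the vector's shared size and storage and can race or reallocate, while index writes into distinct pre-allocated slots are independent. A hedged sketch of the same pattern with simplified types (not the Shamrock ones):

```cpp
#include <cstdint>
#include <vector>

struct Entry {
    std::uint64_t value = 0;
    std::uint64_t index = 0;
};

std::vector<Entry> tag_with_index(const std::vector<std::uint64_t> &in) {
    // Unsafe to parallelize: push_back updates shared size/capacity on every call.
    // std::vector<Entry> out;
    // for (std::size_t i = 0; i < in.size(); i++) out.push_back(Entry{in[i], i});

    // Safe variant: allocate all slots once, then let each iteration write its own slot.
    std::vector<Entry> out(in.size());
#pragma omp parallel for
    for (std::size_t i = 0; i < in.size(); i++) {
        out[i] = Entry{in[i], i};
    }
    return out;
}
```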
@@ -94,15 +97,18 @@ namespace shamrock::scheduler::details {

double target_datacnt = double(res[res.size() - 1].accumulated_load_value) / wsize;

- for (LBTileResult &tile : res) {
+ #pragma omp parallel for
+ for (u64 i = 0; i < res.size(); i++) {
+ LBTileResult &tile = res[i];
tile.new_owner
= (target_datacnt == 0)
? 0
: sycl::clamp(
i32(tile.accumulated_load_value / target_datacnt), 0, wsize - 1);
}

- if (shamcomm::world_rank() == 0) {
+ if (shamcomm::world_rank() == 0
+ && shamcomm::logs::get_loglevel() >= shamcomm::logs::log_debug) {
for (LBTileResult t : res) {
shamlog_debug_ln(
"HilbertLoadBalance",
@@ -141,9 +147,10 @@ namespace shamrock::scheduler::details {
using LBTile = TileWithLoad<Torder, Tweight>;
using LBTileResult = details::LoadBalancedTile<Torder, Tweight>;

- std::vector<LBTileResult> res;
+ std::vector<LBTileResult> res(lb_vector.size());
+ #pragma omp parallel for
for (u64 i = 0; i < lb_vector.size(); i++) {
- res.push_back(LBTileResult{lb_vector[i], i});
+ res[i] = LBTileResult{lb_vector[i], i};
}

// apply the ordering
@@ -160,15 +167,18 @@ namespace shamrock::scheduler::details {

double target_datacnt = double(res[res.size() - 1].accumulated_load_value) / wsize;

- for (LBTileResult &tile : res) {
+ #pragma omp parallel for
+ for (u64 i = 0; i < res.size(); i++) {
+ LBTileResult &tile = res[i];
tile.new_owner
= (target_datacnt == 0)
? 0
: sycl::clamp(
i32(tile.accumulated_load_value / target_datacnt), 0, wsize - 1);
}

- if (shamcomm::world_rank() == 0) {
+ if (shamcomm::world_rank() == 0
+ && shamcomm::logs::get_loglevel() >= shamcomm::logs::log_debug) {
for (LBTileResult t : res) {
shamlog_debug_ln(
"HilbertLoadBalance",
src/shamrock/src/scheduler/HilbertLoadBalance.cpp (64 changes: 30 additions & 34 deletions)
@@ -25,28 +25,16 @@

inline void apply_node_patch_packing(
std::vector<shamrock::patch::Patch> &global_patch_list, std::vector<i32> &new_owner_table) {
- using namespace shamrock::patch;
- sycl::buffer<i32> new_owner(new_owner_table.data(), new_owner_table.size());
- sycl::buffer<Patch> patch_buf(global_patch_list.data(), global_patch_list.size());
-
- sycl::range<1> range{global_patch_list.size()};
-
- // pack nodes
- shamsys::instance::get_alt_queue()
- .submit([&](sycl::handler &cgh) {
- auto ptch = patch_buf.get_access<sycl::access::mode::read>(cgh);
- // auto pdt = dt_buf.get_access<sycl::access::mode::read>(cgh);
- auto chosen_node = new_owner.get_access<sycl::access::mode::write>(cgh);
-
- cgh.parallel_for(range, [=](sycl::item<1> item) {
- u64 i = (u64) item.get_id(0);
-
- if (ptch[i].pack_node_index != u64_max) {
- chosen_node[i] = chosen_node[ptch[i].pack_node_index];
- }
- });
- })
- .wait();
+ // Note that there seems to be a data race here.
+ // However this should never happen as the packing index will only point toward a patch without
+ // packing. As such the data we are accessing should never be modified during this loop.
+ #pragma omp parallel for
+ for (size_t i = 0; i < global_patch_list.size(); i++) {
+ if (global_patch_list[i].pack_node_index != u64_max) {
+ new_owner_table[i] = new_owner_table[global_patch_list[i].pack_node_index];
+ }
+ }
}

namespace shamrock::scheduler {
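The data-race comment above rests on an invariant the loop itself does not check: `pack_node_index` is either the `u64_max` sentinel or the index of a patch that is itself unpacked, so every entry the loop reads is one that no iteration writes. A minimal sketch of that pattern under the same assumption (simplified types; `u64_max` stands for the sentinel used in Shamrock):

```cpp
#include <cstdint>
#include <limits>
#include <vector>

constexpr std::uint64_t u64_max = std::numeric_limits<std::uint64_t>::max();

// Illustrative stand-in for the patch fields used by the packing pass.
struct PatchLite {
    std::uint64_t pack_node_index = u64_max;
};

// Packed patches inherit the owner of their packing target. Running this in parallel is
// only race-free if a packing target is never packed itself: then owner[tgt] is read but
// never written by any iteration.
void apply_packing(const std::vector<PatchLite> &patches, std::vector<int> &owner) {
#pragma omp parallel for
    for (std::size_t i = 0; i < patches.size(); i++) {
        std::uint64_t tgt = patches[i].pack_node_index;
        if (tgt != u64_max) {
            owner[i] = owner[tgt];
        }
    }
}
```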
@@ -102,17 +90,17 @@ namespace shamrock::scheduler {

// TODO add bool for optional print verbosity
// std::cout << i << " : " << old_owner << " -> " << new_owner << std::endl;
- if (new_owner != old_owner) {

- using ChangeOp = LoadBalancingChangeList::ChangeOp;
+ using ChangeOp = LoadBalancingChangeList::ChangeOp;

- ChangeOp op;
- op.patch_idx = i;
- op.patch_id = global_patch_list[i].id_patch;
- op.rank_owner_new = new_owner;
- op.rank_owner_old = old_owner;
- op.tag_comm = tags_it_node[old_owner];
+ ChangeOp op;
+ op.patch_idx = i;
+ op.patch_id = global_patch_list[i].id_patch;
+ op.rank_owner_new = new_owner;
+ op.rank_owner_old = old_owner;
+ op.tag_comm = tags_it_node[old_owner];

+ if (new_owner != old_owner) {
change_list.change_ops.push_back(op);
tags_it_node[old_owner]++;
}
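With the `ChangeOp` now built unconditionally, only the push into `change_ops` and the tag increment stay behind the owner-change test, so behaviour is unchanged: each old owner still hands out consecutive `tag_comm` values for the patches it sends away. A schematic sketch of that tag-allocation idea; the field names mirror the diff, but the surrounding function is illustrative:

```cpp
#include <cstdint>
#include <vector>

struct ChangeOp {
    std::uint64_t patch_idx;
    std::uint64_t patch_id;
    int rank_owner_new;
    int rank_owner_old;
    int tag_comm;
};

std::vector<ChangeOp> build_change_list(
    const std::vector<int> &old_owner, const std::vector<int> &new_owner, int world_size) {
    std::vector<ChangeOp> ops;
    std::vector<int> tags_it_node(world_size, 0); // next free comm tag per sending rank
    for (std::size_t i = 0; i < old_owner.size(); i++) {
        ChangeOp op{};
        op.patch_idx      = i;
        op.patch_id       = i; // stand-in: the real code uses the patch's id_patch
        op.rank_owner_new = new_owner[i];
        op.rank_owner_old = old_owner[i];
        op.tag_comm       = tags_it_node[old_owner[i]];
        if (new_owner[i] != old_owner[i]) {
            ops.push_back(op);
            tags_it_node[old_owner[i]]++; // consecutive tags per sender
        }
    }
    return ops;
}
```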
@@ -126,23 +114,31 @@ namespace shamrock::scheduler {
f64 avg = 0;
f64 var = 0;

- for (i32 nid = 0; nid < shamcomm::world_size(); nid++) {
+ i32 world_size = shamcomm::world_size();
+
+ #pragma omp parallel for reduction(min : min) reduction(max : max) reduction(+ : avg)
+ for (i32 nid = 0; nid < world_size; nid++) {
f64 val = load_per_node[nid];
min = sycl::fmin(min, val);
max = sycl::fmax(max, val);
avg += val;
}

- if (shamcomm::world_rank() == 0) {
+ if (shamcomm::world_rank() == 0
+ && shamcomm::logs::get_loglevel() >= shamcomm::logs::log_debug) {
+ for (i32 nid = 0; nid < world_size; nid++) {
shamlog_debug_ln(
"HilbertLoadBalance", "node :", nid, "load :", load_per_node[nid]);
}
}
- avg /= shamcomm::world_size();
- for (i32 nid = 0; nid < shamcomm::world_size(); nid++) {
+ avg /= world_size;
+
+ #pragma omp parallel for reduction(+ : var)
+ for (i32 nid = 0; nid < world_size; nid++) {
f64 val = load_per_node[nid];
var += (val - avg) * (val - avg);
}
- var /= shamcomm::world_size();
+ var /= world_size;

if (shamcomm::world_rank() == 0) {
std::string str = "Loadbalance stats : \n";
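The statistics pass now does the classic two reductions: one parallel pass for min, max and the sum (then avg = sum / N), and a second pass for var = (1/N) * sum_i (load_i - avg)^2 once the mean is known. A standalone sketch with plain std::min/std::max in place of the sycl:: calls, assuming a non-empty load vector:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct LoadStats {
    double min, max, avg, var;
};

// Two-pass mean/variance over per-rank loads with OpenMP reductions (sketch).
LoadStats load_stats(const std::vector<double> &load) {
    const std::int64_t n = std::int64_t(load.size());
    double mn = 1e300, mx = -1e300, avg = 0, var = 0;

#pragma omp parallel for reduction(min : mn) reduction(max : mx) reduction(+ : avg)
    for (std::int64_t i = 0; i < n; i++) {
        mn = std::min(mn, load[i]);
        mx = std::max(mx, load[i]);
        avg += load[i];
    }
    avg /= double(n);

#pragma omp parallel for reduction(+ : var)
    for (std::int64_t i = 0; i < n; i++) {
        var += (load[i] - avg) * (load[i] - avg);
    }
    var /= double(n); // population variance of the load distribution

    return LoadStats{mn, mx, avg, var};
}
```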