Skip to content

Commit 48e2a2e

Browse files
committed
bugfixes
1 parent 188f6af commit 48e2a2e

1 file changed

Lines changed: 42 additions & 3 deletions

File tree

src/framework/domain/metadomain_lb.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ namespace ntt {
7474
template <SimEngine::type S, class M>
7575
requires IsCompatibleWithMetadomain<M>
7676
void Metadomain<S, M>::BalanceLoad(const SimulationParams& params) {
77-
const auto lb_dims = params.template get<std::vector<int>>("simulation.domain.load_balancing.dimensions");
78-
const auto lb_max_iters = 1; //params.template get<int>("simulation.domain.load_balancing.max_iterations", 10);
79-
const auto lb_tol = 0.1; //params.template get<real_t>("simulation.domain.load_balancing.tolerance", 0.1);
77+
const auto lb_dims = params.template get<std::vector<int>>("simulation.domain.load_balancing.dimensions");
78+
const auto lb_max_iters = static_cast<int>(params.template get<unsigned int>("simulation.domain.load_balancing.max_iterations"));
79+
const auto lb_tol = static_cast<double>(params.template get<real_t>("simulation.domain.load_balancing.tolerance"));
8080

8181
if (lb_dims.empty()) return;
8282

@@ -150,6 +150,21 @@ namespace ntt {
150150
if (!moved_global) break;
151151
}
152152

153+
// Clamp bounds to ensure every domain is at least 2*N_GHOSTS+1 cells wide
154+
const int L_min = 2 * N_GHOSTS + 1;
155+
for (int i = 1; i < nx_domains; ++i) {
156+
// ensure domain i-1 (left) is at least L_min wide
157+
if (bounds[i] - bounds[i - 1] < L_min) {
158+
bounds[i] = bounds[i - 1] + L_min;
159+
}
160+
}
161+
// walk backwards from the rightmost domain so that each domain i (spanning bounds[i]..bounds[i + 1]) is at least L_min wide
162+
for (int i = nx_domains - 1; i >= 1; --i) {
163+
if (bounds[i + 1] - bounds[i] < L_min) {
164+
bounds[i] = bounds[i + 1] - L_min;
165+
}
166+
}
167+
153168
bool any_change = false;
154169
for (int i = 0; i <= nx_domains; ++i) {
155170
if (bounds[i] != old_bounds[i]) any_change = true;
@@ -161,6 +176,7 @@ namespace ntt {
161176

162177
// 4. Update domains if boundary changed
163178
std::vector<Domain<S, M>> new_subdomains;
179+
new_subdomains.reserve(g_ndomains);
164180
for (unsigned int idx = 0; idx < g_ndomains; ++idx) {
165181
auto& old_dom = g_subdomains[idx];
166182

@@ -274,15 +290,28 @@ namespace ntt {
274290
Kokkos::deep_copy(new_sp.ux2, old_sp.ux2);
275291
Kokkos::deep_copy(new_sp.ux3, old_sp.ux3);
276292
Kokkos::deep_copy(new_sp.weight, old_sp.weight);
293+
294+
// Reset all copied particle tags to 'alive': particles with
295+
// send-direction tags from the previous pusher step must not be
296+
// re-sent by CommunicateParticles; ShiftParticles below will
297+
// re-tag any particle that is now out of the new domain bounds.
298+
{
299+
auto tag_view = new_sp.tag;
300+
Kokkos::parallel_for("ResetTags_LB", new_sp.rangeActiveParticles(), KOKKOS_LAMBDA(int p) {
301+
tag_view(p) = ParticleTag::alive;
302+
});
303+
}
277304

278305
int offset_diff1 = old_offset_ncells[0] - new_offset_ncells[0];
279306
if constexpr (D == Dim::_1D) {
280307
if (offset_diff1 != 0) {
281308
auto i1_view = new_sp.i1;
309+
auto i1_prev_view = new_sp.i1_prev;
282310
auto tag_view = new_sp.tag;
283311
int ni1 = new_ncells[0];
284312
Kokkos::parallel_for("ShiftParticles_1D", new_sp.rangeActiveParticles(), KOKKOS_LAMBDA(int p) {
285313
i1_view(p) += offset_diff1;
314+
i1_prev_view(p) += offset_diff1;
286315
#if defined(MPI_ENABLED)
287316
tag_view(p) = mpi::SendTag(tag_view(p), i1_view(p) < 0, i1_view(p) >= ni1);
288317
#endif
@@ -292,13 +321,17 @@ namespace ntt {
292321
int offset_diff2 = old_offset_ncells[1] - new_offset_ncells[1];
293322
if (offset_diff1 != 0 || offset_diff2 != 0) {
294323
auto i1_view = new_sp.i1;
324+
auto i1_prev_view = new_sp.i1_prev;
295325
auto i2_view = new_sp.i2;
326+
auto i2_prev_view = new_sp.i2_prev;
296327
auto tag_view = new_sp.tag;
297328
int ni1 = new_ncells[0];
298329
int ni2 = new_ncells[1];
299330
Kokkos::parallel_for("ShiftParticles_2D", new_sp.rangeActiveParticles(), KOKKOS_LAMBDA(int p) {
300331
i1_view(p) += offset_diff1;
301332
i2_view(p) += offset_diff2;
333+
i1_prev_view(p) += offset_diff1;
334+
i2_prev_view(p) += offset_diff2;
302335
#if defined(MPI_ENABLED)
303336
tag_view(p) = mpi::SendTag(tag_view(p), i1_view(p) < 0, i1_view(p) >= ni1, i2_view(p) < 0, i2_view(p) >= ni2);
304337
#endif
@@ -311,6 +344,9 @@ namespace ntt {
311344
auto i1_view = new_sp.i1;
312345
auto i2_view = new_sp.i2;
313346
auto i3_view = new_sp.i3;
347+
auto i1_prev_view = new_sp.i1_prev;
348+
auto i2_prev_view = new_sp.i2_prev;
349+
auto i3_prev_view = new_sp.i3_prev;
314350
auto tag_view = new_sp.tag;
315351
int ni1 = new_ncells[0];
316352
int ni2 = new_ncells[1];
@@ -319,6 +355,9 @@ namespace ntt {
319355
i1_view(p) += offset_diff1;
320356
i2_view(p) += offset_diff2;
321357
i3_view(p) += offset_diff3;
358+
i1_prev_view(p) += offset_diff1;
359+
i2_prev_view(p) += offset_diff2;
360+
i3_prev_view(p) += offset_diff3;
322361
#if defined(MPI_ENABLED)
323362
tag_view(p) = mpi::SendTag(tag_view(p),
324363
i1_view(p) < 0, i1_view(p) >= ni1,

0 commit comments

Comments
 (0)