Commit 3405621

Seema Mirchandaney committed
Add support for replicate op in distributed training
- Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
- Add perform_shard_expansion_for_replicate and _bwd for shard expansion
- Add build_replicate_invocation in make_dynamic_open_dataflow_graph
- Add is_replicate_attrs helper and guard replicate in copy_insertion
- Add ReplicateAttrs to TrainingOperationAttrs
- Add SumReductionFloat/Double for backward replicate reduce operation
- Add issue_replicate_bwd in spawn_dynamic_node_invocation
- Fix per_device_op_state init race condition with direct write
- Fix .value() calls on optional per_device_op_state across op impls
- Update issue_copy to support optional reduction op
- Add testcase for replicate op
1 parent b6d00f8 commit 3405621

18 files changed, 713 additions & 364 deletions


lib/op-attrs/src/op-attrs/ops/element_unary.cc

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees(
     ElementUnaryAttrs const &attrs,
     ParallelTensorDimDegrees const &input_degrees) {
   ASSERT(input_degrees.sum_degree.value == 1);
-  ASSERT(input_degrees.discard_copy_degree.value == 1);
 
   return input_degrees;
 }

lib/op-attrs/test/src/op-attrs/ops/element_unary.cc

Lines changed: 0 additions & 8 deletions
@@ -62,13 +62,5 @@ TEST_SUITE(FF_TEST_SUITE) {
              SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)));
     }
 
-    SUBCASE("discard copy degree > 1") {
-      positive_int degree = 2_p;
-
-      CHECK_THROWS(get_output_shape(
-          attrs,
-          make_input(
-              SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p)));
-    }
   }
 }
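Together with the header change above, this removes the requirement that element-unary inputs have discard-copy degree 1, which replicated inputs violate. Since get_output_parallel_dim_degrees now passes input_degrees through untouched, a positive test could presumably take the deleted SUBCASE's place. The sketch below assumes get_output_shape likewise returns the input shape unchanged, which this diff does not show:

SUBCASE("discard copy degree > 1") {
  // hypothetical replacement test: the shape should now pass through
  positive_int degree = 2_p;

  ParallelTensorShape input = make_input(
      SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p);

  CHECK(get_output_shape(attrs, input) == input);
}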

lib/realm-execution/include/realm-execution/realm_context.h

Lines changed: 11 additions & 8 deletions
@@ -63,15 +63,18 @@ struct RealmContext {
                     int priority = 0);
   ///\}
 
-  /** \name Data movement */
+  /** \name Data movement and reduction */
   ///\{
-  Realm::Event issue_copy(ParallelTensorShape const &src_shape,
-                          Realm::RegionInstance src_inst,
-                          ParallelTensorShape const &dst_shape,
-                          Realm::RegionInstance dst_inst,
-                          Realm::ProfilingRequestSet const &requests,
-                          Realm::Event wait_on = Realm::Event::NO_EVENT,
-                          int priority = 0);
+  Realm::Event
+      issue_copy(ParallelTensorShape const &src_shape,
+                 Realm::RegionInstance src_inst,
+                 ParallelTensorShape const &dst_shape,
+                 Realm::RegionInstance dst_inst,
+                 Realm::ProfilingRequestSet const &requests,
+                 Realm::Event wait_on = Realm::Event::NO_EVENT,
+                 int priority = 0,
+                 std::optional<Realm::ReductionOpID> redop_id = std::nullopt,
+                 bool exclusive = false);
   ///\}
 
   /** \name Instance management */
lib/realm-execution/include/realm-execution/sum_reduction.h

Lines changed: 0 additions & 99 deletions
This file was deleted.

lib/realm-execution/include/realm-execution/tasks/realm_reduction.h

Lines changed: 31 additions & 18 deletions
@@ -1,29 +1,32 @@
 #pragma once
-#include <realm.h>
 #include "op-attrs/datatype.dtg.h"
+#include <realm.h>
 
 namespace FlexFlow {
 
 // Sum reduction for float
 struct SumReductionFloat {
   using LHS = float;
   using RHS = float;
-  static constexpr RHS identity = 0.0f;   // ← inside struct, constexpr
+  static constexpr RHS identity = 0.0f; // ← inside struct, constexpr
 
   template <bool EXCLUSIVE>
   static void apply(LHS &lhs, RHS rhs) {
     if (EXCLUSIVE) {
       lhs += rhs;
     } else {
       // atomic add for non-exclusive
-      __sync_fetch_and_add((int*)&lhs, *(int*)&rhs);
+      __sync_fetch_and_add((int *)&lhs, *(int *)&rhs);
       // proper float atomic add — use union trick
-      union { float f; int i; } old_val, new_val;
+      union {
+        float f;
+        int i;
+      } old_val, new_val;
       do {
         old_val.f = lhs;
         new_val.f = old_val.f + rhs;
-      } while (!__sync_bool_compare_and_swap(
-          (int*)&lhs, old_val.i, new_val.i));
+      } while (
+          !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i));
     }
   }
 
@@ -32,34 +35,39 @@ struct SumReductionFloat {
     if (EXCLUSIVE) {
       rhs1 += rhs2;
     } else {
-      union { float f; int i; } old_val, new_val;
+      union {
+        float f;
+        int i;
+      } old_val, new_val;
       do {
         old_val.f = rhs1;
         new_val.f = old_val.f + rhs2;
-      } while (!__sync_bool_compare_and_swap(
-          (int*)&rhs1, old_val.i, new_val.i));
+      } while (
+          !__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i));
     }
   }
 };
 
-
 // Sum reduction for double
 struct SumReductionDouble {
   using LHS = double;
   using RHS = double;
-  static constexpr RHS identity = 0.0;   // ← inside struct, constexpr
+  static constexpr RHS identity = 0.0; // ← inside struct, constexpr
 
   template <bool EXCLUSIVE>
   static void apply(LHS &lhs, RHS rhs) {
     if (EXCLUSIVE) {
       lhs += rhs;
     } else {
-      union { double d; long long i; } old_val, new_val;
+      union {
+        double d;
+        long long i;
+      } old_val, new_val;
       do {
         old_val.d = lhs;
         new_val.d = old_val.d + rhs;
       } while (!__sync_bool_compare_and_swap(
-          (long long*)&lhs, old_val.i, new_val.i));
+          (long long *)&lhs, old_val.i, new_val.i));
     }
   }
 
@@ -68,26 +76,31 @@ struct SumReductionDouble {
     if (EXCLUSIVE) {
       rhs1 += rhs2;
     } else {
-      union { double d; long long i; } old_val, new_val;
+      union {
+        double d;
+        long long i;
+      } old_val, new_val;
       do {
         old_val.d = rhs1;
         new_val.d = old_val.d + rhs2;
       } while (!__sync_bool_compare_and_swap(
-          (long long*)&rhs1, old_val.i, new_val.i));
+          (long long *)&rhs1, old_val.i, new_val.i));
     }
   }
 };
 
 // Reduction op IDs — must not conflict with other registered redops
 enum SumReductionOpIDs {
-  REDOP_SUM_FLOAT  = 1,
+  REDOP_SUM_FLOAT = 1,
   REDOP_SUM_DOUBLE = 2,
 };
 
 inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) {
   switch (dtype) {
-    case DataType::FLOAT: return REDOP_SUM_FLOAT;
-    case DataType::DOUBLE: return REDOP_SUM_DOUBLE;
+  case DataType::FLOAT:
+    return REDOP_SUM_FLOAT;
+  case DataType::DOUBLE:
+    return REDOP_SUM_DOUBLE;
   default:
     PANIC("no sum reduction registered for datatype {}", dtype);
   }
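These structs follow the shape Realm expects of a reduction operator (LHS/RHS typedefs, an identity, and EXCLUSIVE-templated apply/fold), so the runtime can presumably wrap them via ReductionOpUntyped. The registration call site is not part of this diff; the sketch below is an assumption about how the IDs get bound:

// Hypothetical registration sketch: the IDs handed to issue_copy must be
// registered with Realm before any reducing copy is issued.
#include "realm-execution/tasks/realm_reduction.h"
#include <realm.h>

void register_sum_redops(Realm::Runtime runtime) {
  using namespace FlexFlow;
  runtime.register_reduction(
      REDOP_SUM_FLOAT,
      Realm::ReductionOpUntyped::create_reduction_op<SumReductionFloat>());
  runtime.register_reduction(
      REDOP_SUM_DOUBLE,
      Realm::ReductionOpUntyped::create_reduction_op<SumReductionDouble>());
}

Note that the non-exclusive float apply above performs a raw __sync_fetch_and_add on the reinterpreted bits and then runs the compare-and-swap loop; only the CAS loop (the "union trick" named in the comment) computes a correct float sum.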

lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc

Lines changed: 5 additions & 1 deletion
@@ -31,6 +31,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
   std::unordered_map<DynamicNodeInvocation,
                      DeviceSpecificPtr<PerDeviceOpState> *>
       device_state_map;
+  std::vector<Realm::Event> completion_events;
   for (DynamicNodeInvocation const &invocation : dg.invocations) {
     Realm::Processor target_proc = ctx.map_device_coord_to_processor(
         assert_unwrap(invocation.node_attrs.device_coord));
@@ -56,14 +57,17 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
         precondition);
 
     if (completion_event.has_value()) {
+      completion_events.push_back(completion_event.value());
       device_state_map.insert(std::pair{invocation, device_state_ptr});
     } else {
       // Task doesn't require initialization, clean up and don't store result
       delete device_state_ptr;
     }
   }
 
-  ctx.get_outstanding_events().wait();
+  // wait for all init tasks — direct write to *result_ptr happens
+  // before each init task event fires so result is ready after this
+  Realm::Event::merge_events(completion_events).wait();
 
   auto deref = [](DeviceSpecificPtr<PerDeviceOpState> *const &p) { return *p; };
   std::unordered_map<DynamicNodeInvocation, DeviceSpecificPtr<PerDeviceOpState>>
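The race fix narrows the barrier: instead of waiting on every event the context has outstanding, the function now collects the completion event of each initialization task it actually spawned and waits on their merge. The pattern in isolation (the task ID, argument struct, and helper below are hypothetical placeholders, not names from this commit):

// Collect one completion event per spawned task, then block on the merge.
std::vector<Realm::Event> events;
for (Realm::Processor proc : target_procs) {
  InitArgs args = make_init_args(proc); // hypothetical helper
  events.push_back(proc.spawn(INIT_TASK_ID, &args, sizeof(args)));
}
// merge_events yields an event that triggers once all inputs have
// triggered; waiting on it is strictly narrower than a global wait.
Realm::Event::merge_events(events).wait();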

lib/realm-execution/src/realm-execution/pcg_instance.cc

Lines changed: 54 additions & 0 deletions
@@ -6,6 +6,7 @@
 #include "realm-execution/instance_allocation.h"
 #include "realm-execution/realm_context.h"
 #include "realm-execution/tasks/impl/op_task.h"
+#include "realm-execution/tasks/realm_reduction.h"
 #include "realm-execution/tensor_instance_backing.h"
 #include "task-spec/dynamic_graph/copy_insertion.h"
 #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h"
@@ -215,18 +216,71 @@ static Realm::Event spawn_dynamic_node_invocation(
         precondition);
   };
 
+  // issue_replicate_bwd lambda
+  auto issue_replicate_bwd = [&]() {
+    std::optional<DynamicValueAttrs> output_grad_opt;
+    for (auto const &[slot, value] : invocation.inputs) {
+      if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) {
+        output_grad_opt = value;
+      }
+    }
+    DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt);
+    DynamicValueAttrs input_grad = get_only(invocation.outputs).second;
+    Realm::RegionInstance dst_inst =
+        tensor_instance_backing.backing.at(input_grad).first;
+
+    Realm::ReductionOpID redop_id = get_sum_reduction_op_id(
+        assert_unwrap(output_grad.parallel_tensor_shape).data_type);
+
+    // chain reductions sequentially to avoid write races on dst
+    Realm::Event e = precondition;
+    for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) {
+      DynamicValueAttrs replica_key = output_grad;
+      replica_key.mapping =
+          bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>{{p, m}};
+      replica_key.shard_coord = p;
+
+      Realm::RegionInstance src_inst =
+          tensor_instance_backing.backing.at(replica_key).first;
+
+      e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape),
+                         src_inst,
+                         assert_unwrap(input_grad.parallel_tensor_shape),
+                         dst_inst,
+                         Realm::ProfilingRequestSet{},
+                         e,
+                         0,
+                         redop_id,
+                         false);
+    }
+    return e;
+  };
+
   TrainingOperationAttrs op_attrs =
       assert_unwrap(invocation.node_attrs.op_attrs);
   return op_attrs.visit<Realm::Event>(overload{
       [&](PCGOperatorAttrs const &pcg_op_attrs) {
         return pcg_op_attrs.visit<Realm::Event>(overload{
             [&](InputAttrs const &) { return Realm::Event::NO_EVENT; },
             [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; },
+            [&](ReplicateAttrs const &) {
+              // this should never be reached since replicate
+              // goes through TrainingOperationAttrs::ReplicateAttrs
+              PANIC("unexpected replicate in PCGOperatorAttrs path");
+              return Realm::Event::NO_EVENT;
+            },
             [&](auto const &) { return spawn_task(); },
         });
       },
      [&](LossAttrs const &) { return spawn_task(); },
      [&](CopyAttrs const &) { return issue_copy(); },
+      [&](ReplicateAttrs const &) {
+        if (invocation.node_attrs.task_type.has_value() &&
+            invocation.node_attrs.task_type.value() == DynamicTaskType::BWD) {
+          return issue_replicate_bwd();
+        }
+        return issue_copy();
+      },
   });
 }
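Semantically, the backward pass of replicate sums the gradient from every replica into the one input gradient, which is why each hop in issue_replicate_bwd is a reducing issue_copy rather than a plain copy. A scalar model of what the chained loop computes:

// Scalar model of issue_replicate_bwd: the input gradient accumulates the
// output gradient of each replica, one reduction per mapping entry.
// (Illustrative only; the real code reduces Realm region instances.)
#include <numeric>
#include <vector>

float replicate_bwd_model(std::vector<float> const &replica_grads) {
  return std::accumulate(replica_grads.begin(), replica_grads.end(), 0.0f);
}

Because each issue_copy waits on the event of the previous one, the reductions never overlap in time; exclusive = false is therefore the conservative choice, and the exclusive apply path would also be safe under this chaining.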
