Support Replicate Parallel Operator on CPUs for Realm backend #1640
seemamirch wants to merge 2 commits into flexflow:master
Conversation
- Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
- Add perform_shard_expansion_for_replicate and _bwd for shard expansion
- Add build_replicate_invocation in make_dynamic_open_dataflow_graph
- Add is_replicate_attrs helper and guard replicate in copy_insertion
- Add ReplicateAttrs to TrainingOperationAttrs
- Add SumReductionFloat/Double for backward replicate reduce operation
- Add issue_replicate_bwd in spawn_dynamic_node_invocation
- Fix per_device_op_state init race condition with direct write
- Fix .value() calls on optional per_device_op_state across op impls
- Update issue_copy to support optional reduction op
- Add testcase for replicate op
@lockshaw @elliottslaughter - Please review
Force-pushed from b18c75d to 3405621
lockshaw left a comment
@lockshaw reviewed 9 files and all commit messages, and made 12 comments.
Reviewable status: 9 of 17 files reviewed, 12 unresolved discussions (waiting on seemamirch).
lib/op-attrs/test/src/op-attrs/ops/element_unary.cc line 65 at r1 (raw file):
```cpp
  }
  SUBCASE("discard copy degree > 1") {
```
Minor: Ideally add a test case for the correct behavior rather than removing it. I'm also happy to contribute this if you'd prefer
lib/realm-execution/include/realm-execution/realm_context.h line 69 at r1 (raw file):
```cpp
///\{
Realm::Event issue_copy(ParallelTensorShape const &src_shape,
```
Minor: It would be good to get a docstring with an explanation of all these parameters at some point
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 1 at r1 (raw file):
```cpp
#pragma once
```
For consistency with the rest of the codebase
Suggestion:
```cpp
#ifndef ...
```
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 8 at r1 (raw file):
```cpp
// Sum reduction for float
struct SumReductionFloat {
```
Minor: It looks(?) like this API comes from Realm. Is there a link to some docs somewhere that we could include in the docstrings for this file, for people not as familiar with the API?
lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 9 at r1 (raw file):
```cpp
// Sum reduction for float
struct SumReductionFloat {
  using LHS = float;
```
Why use aliases here? Since they're just constants (i.e., not dependent on template params or anything), it seems like it would be cleaner to omit them and use the type directly.
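For context on the discussion above: Realm-style reduction operators are structs exposing `LHS`/`RHS` type members, an `identity` value, and static `apply`/`fold` methods. The sketch below is illustrative only (it mirrors the naming convention, simplifies away Realm's template details, and is not the PR's actual definition), showing how such an operator would sum replica gradients in the backward pass of replicate:

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch of a Realm-style reduction operator. The LHS/RHS
// aliases, identity, apply, and fold follow the convention Realm expects;
// names and details here are illustrative, not the actual PR code.
struct SumReductionFloat {
  using LHS = float;  // type of the accumulator element
  using RHS = float;  // type of the value being folded in
  static const RHS identity;

  // apply: combine an incoming value into an accumulator element
  static void apply(LHS &lhs, RHS rhs) { lhs += rhs; }
  // fold: combine two partial reduction values
  static void fold(RHS &rhs1, RHS rhs2) { rhs1 += rhs2; }
};
const SumReductionFloat::RHS SumReductionFloat::identity = 0.0f;

// Sum a vector of partial results, as the backward pass of replicate
// would sum gradients contributed by each replica.
float reduce_partials(std::vector<float> const &partials) {
  float acc = SumReductionFloat::identity;
  for (float p : partials) {
    SumReductionFloat::apply(acc, p);
  }
  return acc;
}
```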
lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 34 at r1 (raw file):
```cpp
    DeviceSpecificPtr<PerDeviceOpState> *> device_state_map;
std::vector<Realm::Event> completion_events;
```
Minor: It seems like the pattern of pair<T, Realm::Event> is becoming pretty common in realm-execution; maybe we can generalize this into a more structured future type to avoid some of the low-level manipulations? I don't love juggling these separate data structures; it feels like it creates opportunities for one to get out of sync with the other and cause a bunch of bugs.
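A minimal sketch of the "structured future" idea suggested above, assuming a simple stand-in for Realm::Event to keep it self-contained (none of these names are from the actual codebase):

```cpp
#include <string>
#include <unordered_map>

// Stand-in for Realm::Event, which is an opaque handle; an int suffices
// for this sketch.
struct RealmEvent {
  int id = 0;
};

// Bundle a value with the event that guards its readiness, so the value
// and its completion event cannot drift out of sync by construction.
template <typename T>
struct RealmFuture {
  T value;                 // the produced value (e.g., a device-state pointer)
  RealmEvent ready_event;  // completion event guarding `value`
};

// One map of futures replaces the separate device_state_map plus the
// parallel completion_events vector from the quoted code.
using DeviceStateFutures = std::unordered_map<std::string, RealmFuture<int>>;
```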
lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml line 30 at r1 (raw file):
```toml
[[values]]
type = "::FlexFlow::ReplicateAttrs"
```
Isn't ReplicateAttrs already part of PCGOperatorAttrs?
lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc line 36 at r1 (raw file):
```cpp
    return true;
  }
  return false;
```
In general we avoid assigning to refs unless it has a provable (i.e., profiling has been run) performance benefit as it creates more opportunities for lifetime/memory issues
Suggestion:
```cpp
TrainingOperationAttrs op_attrs = i.node_attrs.op_attrs.value();
return op_attrs.is_replicate();
```
lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 25 at r1 (raw file):
```cpp
// find the layer that produces this tensor
for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {
```
I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h
lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 41 at r1 (raw file):
```cpp
    parallel_tensor_guid_t const &tensor) {
  std::unordered_map<parallel_layer_guid_t, TensorSlotName> result;
  for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {
```
I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h
lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 79 at r1 (raw file):
```cpp
static DynamicNodeInvocation
    build_replicate_invocation(parallel_layer_guid_t const &layer,
                               ParallelLayerAttrs const &attrs,
```
If this should only handle ReplicateAttrs, change the type of the attrs parameter to ReplicateAttrs
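A sketch of the narrowing being suggested, using simplified stand-in types (not the real ParallelLayerAttrs/ReplicateAttrs): taking the concrete attrs type turns a would-be runtime variant check into a compile-time guarantee.

```cpp
// Simplified stand-in for the real ReplicateAttrs; illustrative only.
struct ReplicateAttrsSketch {
  int replicate_degree;
};

// Accepting the concrete attrs type (rather than the full variant of all
// operator attrs) means callers must already have proven this is a
// replicate layer before calling, so no runtime check is needed here.
static int replicate_output_degree(ReplicateAttrsSketch const &attrs) {
  return attrs.replicate_degree;
}
```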
lib/task-spec/src/task-spec/ops/impl/element_binary.cc line 39 at r1 (raw file):
```cpp
ProfilingSettings profiling = acc.get_profiling_settings();
DeviceType kernel_device_type = acc.get_kernel_device_type();
std::optional<ElementBinaryPerDeviceState> per_device_state =
```
Am I correct in understanding that this is because CPU implementations for these ops don't need the per device state?
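If that reading is right, the call-site pattern would look roughly like the sketch below (illustrative stand-ins, not the PR's code): CPU kernels compute directly with no initialized per-device state, so the state becomes `std::optional` and GPU paths check `has_value()` rather than calling `.value()` unconditionally.

```cpp
#include <cassert>
#include <optional>

enum class DeviceType { CPU, GPU };

// Stand-in for a per-device state struct; a real one would hold e.g.
// GPU library handles.
struct ElementBinaryPerDeviceStateSketch {
  int handle;
};

int element_binary_add(
    DeviceType device,
    std::optional<ElementBinaryPerDeviceStateSketch> const &state,
    int lhs,
    int rhs) {
  if (device == DeviceType::GPU) {
    // GPU path requires initialized per-device state.
    assert(state.has_value());
  }
  // CPU path: state may be nullopt; compute directly.
  return lhs + rhs;
}
```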
Description of changes:
Add support for replicate op in distributed training & Realm backend