Skip to content

Support Replicate Parallel Operator on CPUs for Realm backend#1640

Open
seemamirch wants to merge 2 commits intoflexflow:masterfrom
seemamirch:sm/realm-parallel-operators-replicate
Open

Support Replicate Parallel Operator on CPUs for Realm backend#1640
seemamirch wants to merge 2 commits intoflexflow:masterfrom
seemamirch:sm/realm-parallel-operators-replicate

Conversation

@seemamirch
Copy link
Copy Markdown

@seemamirch seemamirch commented Apr 9, 2026

Description of changes:

Add support for replicate op in distributed training & Realm backend

  • Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
  • Add perform_shard_expansion_for_replicate and _bwd for shard expansion
  • Add build_replicate_invocation in make_dynamic_open_dataflow_graph
  • Add is_replicate_attrs helper and guard replicate in copy_insertion
  • Add ReplicateAttrs to TrainingOperationAttrs
  • Add SumReductionFloat/Double for backward replicate reduce operation
  • Add issue_replicate_bwd in spawn_dynamic_node_invocation
  • Fix per_device_op_state init race condition with direct write
  • Fix .value() calls on optional per_device_op_state across op unary/binary impls
  • Update issue_copy to support optional reduction op
  • Fix Relu to allow discard_copy_degree > 1
  • Add testcase for Replicate Op

This change is Reviewable

Seema Mirchandaney added 2 commits April 9, 2026 15:49
- Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
- Add perform_shard_expansion_for_replicate and _bwd for shard expansion
- Add build_replicate_invocation in make_dynamic_open_dataflow_graph
- Add is_replicate_attrs helper and guard replicate in copy_insertion
- Add ReplicateAttrs to TrainingOperationAttrs
- Add SumReductionFloat/Double for backward replicate reduce operation
- Add issue_replicate_bwd in spawn_dynamic_node_invocation
- Fix per_device_op_state init race condition with direct write
- Fix .value() calls on optional per_device_op_state across op impls
- Update issue_copy to support optional reduction op
- Add testcase for replicate op
- Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion
- Add perform_shard_expansion_for_replicate and _bwd for shard expansion
- Add build_replicate_invocation in make_dynamic_open_dataflow_graph
- Add is_replicate_attrs helper and guard replicate in copy_insertion
- Add ReplicateAttrs to TrainingOperationAttrs
- Add SumReductionFloat/Double for backward replicate reduce operation
- Add issue_replicate_bwd in spawn_dynamic_node_invocation
- Fix per_device_op_state init race condition with direct write
- Fix .value() calls on optional per_device_op_state across op impls
- Update issue_copy to support optional reduction op
- Add testcase for replicate op
@seemamirch
Copy link
Copy Markdown
Author

@lockshaw @elliottslaughter - Please review

@seemamirch seemamirch force-pushed the sm/realm-parallel-operators-replicate branch from b18c75d to 3405621 Compare April 9, 2026 23:13
@lockshaw lockshaw self-requested a review April 14, 2026 19:12
Copy link
Copy Markdown
Collaborator

@lockshaw lockshaw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lockshaw reviewed 9 files and all commit messages, and made 12 comments.
Reviewable status: 9 of 17 files reviewed, 12 unresolved discussions (waiting on seemamirch).


lib/op-attrs/test/src/op-attrs/ops/element_unary.cc line 65 at r1 (raw file):

    }

    SUBCASE("discard copy degree > 1") {

Minor: Ideally add a test case for the correct behavior rather than removing it. I'm also happy to contribute this if you'd prefer


lib/realm-execution/include/realm-execution/realm_context.h line 69 at r1 (raw file):

  ///\{
  Realm::Event
      issue_copy(ParallelTensorShape const &src_shape,

Minor: It would be good to get a docstring with an explanation of all these parameters at some point


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 1 at r1 (raw file):

#pragma once

For consistency with the rest of the codebase

Suggestion:

#ifndef ...

lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 8 at r1 (raw file):

// Sum reduction for float
struct SumReductionFloat {

Minor: It looks(?) like this API comes from realm, is there a link to some docs somewhere that we could include in the docstrings for this file for people not as familiar with the API>


lib/realm-execution/include/realm-execution/tasks/realm_reduction.h line 9 at r1 (raw file):

// Sum reduction for float
struct SumReductionFloat {
  using LHS = float;

Why use aliases here? Since they're just constant (i.e., not dependent on template params or anything) it seems like it would be cleaner to just omit them and use the type directly


lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc line 34 at r1 (raw file):

                     DeviceSpecificPtr<PerDeviceOpState> *>
      device_state_map;
  std::vector<Realm::Event> completion_events;

Minor: It seems like the pattern of pair<T, Realm::Event> is becoming pretty common in realm-execution, maybe we can generalize this into a more structured future type to avoid some of the low-level manipulations? I don't love juggling these separate datastructures, it feels like creating opportunities for one to get out of sync with the other and create a bunch of bugs


lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml line 30 at r1 (raw file):

[[values]]
type = "::FlexFlow::ReplicateAttrs"

Isn't ReplicateAttrs already part of PCGOperatorAttrs?


lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc line 36 at r1 (raw file):

    return true;
  }
  return false;

In general we avoid assigning to refs unless it has a provable (i.e., profiling has been run) performance benefit as it creates more opportunities for lifetime/memory issues

Suggestion:

  TrainingOperationAttrs op_attrs = i.node_attrs.op_attrs.value();
  return op_attrs.is_replicate();

lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 25 at r1 (raw file):

  // find the layer that produces this tensor
  for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {

I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h


lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 41 at r1 (raw file):

                            parallel_tensor_guid_t const &tensor) {
  std::unordered_map<parallel_layer_guid_t, TensorSlotName> result;
  for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) {

I think you can replace a bunch of this with get_source_layer from parallel_computation_graph.h


lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc line 79 at r1 (raw file):

static DynamicNodeInvocation
    build_replicate_invocation(parallel_layer_guid_t const &layer,
                               ParallelLayerAttrs const &attrs,

If this should only handle ReplicateAttrs, change the type of the attrs parameter to ReplicateAttrs


lib/task-spec/src/task-spec/ops/impl/element_binary.cc line 39 at r1 (raw file):

  ProfilingSettings profiling = acc.get_profiling_settings();
  DeviceType kernel_device_type = acc.get_kernel_device_type();
  std::optional<ElementBinaryPerDeviceState> per_device_state =

Am I correct in understanding that this is because CPU implementations for these ops don't need the per device state?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants