
Gfx1250 moe #402

Merged
coderfeli merged 164 commits into main from gfx1250_moe_new
Apr 16, 2026

Conversation

@XingerZhu
Collaborator

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

sjfeng1999 and others added 30 commits March 3, 2026 08:49
- Fix Python version compatibility in meta.py: add support for Python < 3.11
  by checking for positions attribute availability
- Replace hardcoded MLIR library paths in executor.py with environment variable
  MLIR_PATH, with clear error message when not set
- Update LLVM commit hash and enable ROCM runner in build script
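The MLIR_PATH change described above can be illustrated with a minimal sketch. The helper name `resolve_mlir_path` and the wording of the error message are assumptions for illustration; the actual code in executor.py is not shown in this conversation.

```python
import os
from pathlib import Path


def resolve_mlir_path(env=None):
    """Return the directory named by MLIR_PATH, or fail with a clear message.

    `resolve_mlir_path` is a hypothetical helper; only the MLIR_PATH
    variable name comes from the commit message.
    """
    env = os.environ if env is None else env
    value = env.get("MLIR_PATH")
    if not value:
        # Fail early and say exactly what to set, instead of falling back
        # to a hardcoded library location.
        raise RuntimeError(
            "MLIR_PATH is not set. Point it at your MLIR build or install "
            "directory, for example: export MLIR_PATH=/path/to/mlir"
        )
    return Path(value)
```

Passing the environment as an argument keeps the lookup easy to test without mutating `os.environ`.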
* [FLYDSL]:add copy_atom right_inverse

* [FLYDSL]: right_inverse dynamic process bugfix

* [FLYDSL]:Python refactoring and adaptation

* [FLYDSL]:rm example 05
* Migrate Python bindings to PyConcreteType<> and fix TypeID ODR violation

- FlyExtension.cpp / FlyROCDLExtension.cpp: migrate from legacy
  mlir_type_subclass() to PyConcreteType<> CRTP pattern (required by
  new MLIR Python binding API). Types are defined inside
  namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly, using
  ::mlir:: global qualifiers to avoid the mlir::python::mlir namespace
  collision when NB_DOMAIN=mlir.

- CMakeLists.txt: remove MLIRFlyDialect / MLIRFlyROCDLDialect from
  _fly.so / _fly_rocdl.so PRIVATE_LINK_LIBS. These static archives
  were being linked into both the extension modules AND FlyPythonCAPI.so
  (via EMBED_CAPI_LINK_LIBS → MLIRCPIFly), creating duplicate TypeID
  static variables. The dialect registered under FlyPythonCAPI.so's
  TypeIDs but _fly.so looked up types with its own copy, causing
  "storage uniquer isn't initialized" at runtime. Now all symbols are
  resolved from FlyPythonCAPI.so.

- FlyToROCDL.cpp: use string-based type matching for MmaAtomCDNA3_MFMA
  to work around the same TypeID ODR issue in the conversion pass, and
  fix ROCDL MFMA intrinsic call to use I32Attr attributes instead of
  Value operands for cbsz/abid/blgp control parameters.

* Fix pass registry ODR violation: register Fly passes via CAPI

- PRIVATE_LINK_LIBS MLIRFlyToROCDL in _mlirRegisterEverything pulled in a
local copy of MLIRPass, causing registerFlyPasses() to register into a
LOCAL pass registry inside _mlirRegisterEverything.so while
PassManager.parse() queried the GLOBAL registry in FlyPythonCAPI.so.

- Fix by introducing CAPI functions (mlirRegisterFlyPasses,
mlirRegisterFlyToROCDLConversionPass) in the CAPI libraries so pass
registration happens inside FlyPythonCAPI.so's single global registry.

- Update cmake/llvm-hash.txt to match the LLVM hash used by Triton.

* Sync build_llvm.sh with pre_bumpupllvm and add ROCM runner

Align script with pre_bumpupllvm branch: full clone, buildmlir dir,
NVPTX target, NB_DOMAIN=mlir, package install by default. Keep
reading LLVM commit from cmake/llvm-hash.txt. Add
MLIR_ENABLE_ROCM_RUNNER=ON for GPU kernel execution support.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
- Rename C++ binding structs with Py prefix (e.g. IntTupleType -> PyIntTupleType) for consistency
- Add __all__ exports to typing, primitive, and gpu modules
- Add Int4 numeric type
- Fix frameInfo.positions compatibility for older Python versions
- Fix dialect import order to ensure _Dialect is properly exported
- Add fly_rocdl ops/enum gen copy rules in CMake
- Improve build_llvm.sh with configurable parallel jobs and --no-install flag
- Clean up redundant comments and formatting

Co-authored-by: Cursor <cursoragent@cursor.com>
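The `positions` compatibility fix mentioned in this commit could look roughly like the sketch below. `inspect.FrameInfo` only carries a `positions` attribute from Python 3.11 onward, so the lookup goes through `getattr`. The function name `caller_span` is hypothetical and is not taken from meta.py.

```python
import inspect


def caller_span():
    """Return (lineno, end_lineno) for the line that called this function.

    On Python >= 3.11 the FrameInfo has a `positions` attribute; on older
    interpreters it is absent and the span collapses to a single line.
    """
    info = inspect.stack(0)[1]  # index 0 is this function, 1 is the caller
    positions = getattr(info, "positions", None)
    if positions is not None and positions.lineno is not None:
        end = positions.end_lineno or positions.lineno
        return positions.lineno, end
    # Fallback for Python < 3.11: only the starting line number is known.
    return info.lineno, info.lineno
```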
gemm test ready
* [FLYDSL]: add recast_layout op

* [FLYDSL]: refactor

* [FLYDSL]: add detail namespace

* [FLYDSL]: add upcast assert

* [FLYDSL]: rm bits number

* [FLYDSL]: rm redundant code

* [FLYDSL]: bits number only support static value

* [FLYDSL]: change APIntAttr to I32Attr

* [FLYDSL]: rm notes
* fix run error

* port all gemm from main

* fix cudagraph hack

* add int4 version

* change flymemref convert

* test ok

* add build script

* fix graph2

* add files

* fix flops

* fix path

* fix local test

* fix

* clean

* update readme
* add compile only and dumpir
- Add fly-opt tool (tools/fly-opt/) for MLIR IR transformations,
  registering Fly/FlyROCDL dialects and all custom passes
- Add lit.cfg.py with fly-opt/FileCheck configuration
- Run 'lit -v tests/' to exercise the basic lowering tests
- Add LayoutAlgebra tests: construction, size/cosize, coordinate,
  composition, product, divide, int_tuple operations
- Add Transforms tests: canonicalize, layout_lowering
- Add Conversion tests for convert-fly-to-rocdl pass, split by category:
  type_conversion, memref_alloca, memref_ops, pointer_ops,
  mma_atom, gpu_ops
…gration

- Enable LLVM_BUILD_TOOLS so fly-opt is built with the default ninja target
- Add MLIR lit test section to scripts/run_tests.sh
- Update test/lit.cfg.py to use FLY_BUILD_DIR env var (default: build-fly)
aoli26 and others added 15 commits April 9, 2026 05:20
…1250.py

Add --bench mode that sweeps model configs (DeepSeek-TP/EP, GPToss) ×
dtypes (fp4/fp8/a8w4/fp16/bf16) × token counts with tabular TFLOPS/BW
output. Reuses existing run_moe_stage1/stage2 runners. Original test
mode is unaffected.

Made-with: Cursor
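A sweep of this shape is a Cartesian product over three axes. The sketch below shows one way to enumerate it and print a table; the config names, token counts, and the `run_case` callback are placeholders, not the values or runners used by the real `--bench` mode.

```python
from itertools import product

# Hypothetical placeholders: the real model configs and token counts live
# in the test file and are not reproduced here.
MODEL_CONFIGS = ["model_a", "model_b"]
DTYPES = ["fp4", "fp8", "a8w4", "fp16", "bf16"]
TOKEN_COUNTS = [1, 64, 1024]


def sweep(run_case):
    """Call run_case(model, dtype, tokens) over the full grid, collect rows."""
    rows = []
    for model, dtype, tokens in product(MODEL_CONFIGS, DTYPES, TOKEN_COUNTS):
        tflops, gbps = run_case(model, dtype, tokens)
        rows.append((model, dtype, tokens, tflops, gbps))
    return rows


def format_table(rows):
    """Render the collected rows as a fixed-width text table."""
    lines = [f"{'model':<10}{'dtype':<8}{'tokens':>8}{'TFLOPS':>10}{'GB/s':>10}"]
    for model, dtype, tokens, tflops, gbps in rows:
        lines.append(
            f"{model:<10}{dtype:<8}{tokens:>8}{tflops:>10.2f}{gbps:>10.1f}"
        )
    return "\n".join(lines)
```

Keeping the runner behind a callback mirrors the commit's note that the existing stage1/stage2 runners are reused rather than rewritten.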
…non-aligned dimensions

- Fix TypeError in stage2 mxscale non-wave-specialized pipeline loop:
  when n_accs==1, scf_yield_ returns a single ArithValue instead of a
  list, causing _res[:n_accs] and _st[:n_accs] to fail. Normalize with
  isinstance check before slicing.

- Add automatic K-dimension zero-padding in run_moe_stage1 (model_dim)
  and run_moe_stage2 (inter_dim) for mxscale dtypes (fp4/fp8/a8w4) when
  the dimension is not divisible by tile_k. This enables GPToss
  (dim=2880, 2880%128=64) to run without manual dimension adjustment.

- Use original (unpadded) dimensions for FLOPS/bandwidth accounting.

Made-with: Cursor
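The two fixes in this commit can be sketched in plain Python. The helper names below are hypothetical and the data is modeled as lists rather than tensors; the FLOP count is the generic dense-GEMM formula 2*m*n*k, which is an assumption about how the accounting is done.

```python
def as_list(value):
    """Wrap a lone result in a list so that slicing works for any count."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def padded_k(k, tile_k):
    """Round k up to the next multiple of tile_k."""
    return ((k + tile_k - 1) // tile_k) * tile_k


def pad_rows_with_zeros(rows, tile_k):
    """Zero-pad each row along K; zeros add nothing to a dot product."""
    if not rows:
        return []
    target = padded_k(len(rows[0]), tile_k)
    return [list(row) + [0] * (target - len(row)) for row in rows]


def gemm_flops(m, n, k_original):
    """Count FLOPs from the unpadded K so padding does not inflate results."""
    return 2 * m * n * k_original


# The GPToss case from the commit message: 2880 is not a multiple of 128.
assert 2880 % 128 == 64
assert padded_k(2880, 128) == 2944
```

With `as_list` applied to the loop results, an expression such as `results[:n_accs]` behaves the same whether one or several accumulators are yielded.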
Made-with: Cursor

# Conflicts:
#	kernels/gemm_common_gfx1250.py
#	tests/kernels/test_gemm_fp8fp4_gfx1250.py
#	tests/kernels/test_wmma_gemm_gfx1250.py
- Split monolithic moe_gemm_2stage_gfx1250.py into:
  - moe_gemm_2stage_common_gfx1250.py: shared utilities
  - moe_gemm_2stage_wmma_gfx1250.py: fp16/bf16 WMMA kernels with own public API
  - moe_gemm_2stage_mxscale_gfx1250.py: fp4/fp8/a8w4 MXScale kernels with own public API
- Each module has self-contained compile_moe_gemm1/2/2_ex entry points
- Unsupported dtypes raise ValueError instead of fallback
- Split test_moe_gemm_gfx1250.py into test_moe_gemm_wmma_gfx1250.py and
  test_moe_gemm_mxscale_gfx1250.py with updated imports

Made-with: Cursor
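Raising on an unsupported dtype, rather than silently falling back to another kernel family, might be sketched as follows. The dtype groupings follow the module split listed above; the helper name `require_supported` and the message format are assumptions.

```python
# Dtype groups as described in the commit message.
WMMA_DTYPES = ("fp16", "bf16")
MXSCALE_DTYPES = ("fp4", "fp8", "a8w4")


def require_supported(dtype, supported, module_name):
    """Return dtype unchanged, or raise ValueError naming the valid choices.

    Hypothetical helper: each kernel module would call it with its own
    list of supported dtypes.
    """
    if dtype not in supported:
        raise ValueError(
            f"{module_name}: unsupported dtype {dtype!r}; "
            f"expected one of {', '.join(supported)}"
        )
    return dtype
```

An explicit error makes a mismatched import (for example, asking the WMMA module for fp4) visible at compile time instead of producing a kernel for the wrong format.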
Copilot AI review requested due to automatic review settings April 14, 2026 18:01
Contributor

Copilot AI left a comment


Pull request overview

Adds gfx1250-focused Mixture-of-Experts (MoE) 2-stage GEMM coverage and supporting kernel/runtime utilities, including new WMMA fp16 kernels and TDM gather descriptor support.

Changes:

  • Introduces a comprehensive gfx1250 MXScale/int-quant MoE 2-stage test harness with routing, correctness, and perf/benchmark options.
  • Adds shared gfx1250 MoE kernel helpers plus new fp16/bf16 WMMA stage1/stage2 kernel implementations.
  • Extends ROCDL TDM APIs with gather-mode descriptors/loads/stores, and adds MoE-oriented benchmarking helpers.

Reviewed changes

Copilot reviewed 5 out of 8 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| tests/kernels/test_moe_gemm_mxscale_gfx1250.py | New gfx1250 MoE 2-stage test harness for fp4/fp8/a8w4 + int quant variants. |
| tests/kernels/benchmark_common.py | Adds reusable MoE benchmarking utilities (tile resolution, bytes moved, timing). |
| python/flydsl/expr/rocdl/tdm_ops.py | Adds TDM gather descriptor + gather load/store APIs for row-indexed transfers. |
| kernels/moe_gemm_2stage_wmma_gfx1250.py | New gfx1250 WMMA fp16/bf16 MoE stage1/stage2 kernel compilation entry points. |
| kernels/moe_gemm_2stage_common_gfx1250.py | New shared helpers used by gfx1250 MoE kernels (tiling, epilogues, wrappers). |
| kernels/gemm_common_gfx1250.py | Extends pipeline/barrier helpers and adds wave-specialized TDM load helper (but currently has duplicated definitions). |
Comments suppressed due to low confidence (1)

kernels/gemm_common_gfx1250.py:178

  • WGP_BARRIER_ID, pipeline_fence_signal, pipeline_fence_wait, and issue_tdm_loads are defined twice in this module (see the second block starting here). In Python the later definitions override the earlier ones, which defeats the new scf.IfOp-based implementation above and reintroduces the older issue_tdm_loads that uses a Python if arith.cmpi(...) (invalid for MLIR values). Please remove the duplicated older definitions (or merge the logic) so there is exactly one set of fence/load helpers, and ensure issue_tdm_loads uses IR control flow (scf.IfOp) rather than Python conditionals.
WGP_BARRIER_ID = -1


def pipeline_fence_signal(outstanding=0, use_cluster=False):
    """Signal half of a split barrier fence.

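The shadowing behavior described in this review comment is easy to reproduce, and duplicates of this kind can be detected mechanically. The checker below is an illustrative sketch built on the standard `ast` module; it is not part of the repository.

```python
import ast
from collections import Counter


def helper():
    return "first definition"


def helper():  # same name later in the module: rebinds `helper`
    return "second definition"


# Only the later definition is reachable under this name.
result = helper()


def duplicate_top_level_names(source):
    """Return sorted names that are bound more than once at module level."""
    names = []
    for node in ast.parse(source).body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            names.append(node.name)
        elif isinstance(node, ast.Assign):
            names.extend(t.id for t in node.targets if isinstance(t, ast.Name))
    return sorted(name for name, count in Counter(names).items() if count > 1)
```

Running `duplicate_top_level_names` over a module's source text would list names such as a constant or function that appear in two separate blocks.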

Comment thread tests/kernels/test_moe_gemm_mxscale_gfx1250.py
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 8 changed files in this pull request and generated 3 comments.



Comment thread tests/kernels/test_moe_gemm_mxscale_gfx1250.py Outdated
Comment thread tests/kernels/test_moe_gemm_mxscale_gfx1250.py Outdated
Comment thread kernels/gemm_common_gfx1250.py
Comment thread kernels/gemm_common_gfx1250.py
… store

Extract shared logic between _compile_stage1_mxscale_kernel_impl and
_compile_stage2_mxscale_kernel_impl into four new helpers in the common
module:

- _compute_mxscale_tiling(): format config, WMMA constants, tiling math,
  parameter validation
- _make_mxscale_data_loaders(): factory for 9 identical LDS data-loading
  adapter closures
- _compute_pipeline_plan(): pre-load / tail plan computation
- _compute_tdm_store_layout(): TDM store D output LDS layout

Net effect: mxscale file shrinks by ~330 lines (3220 -> 2889) while
common grows by ~326 lines with reusable infrastructure.

Made-with: Cursor
- Import _bf16_to_f16_wrapper from common instead of duplicating locally
- Merge _compile_moe_stage1/2_wmma_kernel into unified _compile_moe_wmma_gemm
- Simplify compile_moe_gemm1/2/2_ex to thin wrappers using **kw forwarding
- Reduces file from 1101 to 912 lines (-17%)

Made-with: Cursor
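The "thin wrappers using **kw forwarding" pattern can be sketched as below. All names and parameters here are hypothetical stand-ins; the real `_compile_moe_wmma_gemm` signature is not shown in this conversation.

```python
def _compile_unified(*, stage, tile_m=64, tile_n=128, dtype="fp16"):
    """Stand-in for a unified compile routine shared by both stages.

    It returns a plain dict so the forwarding behavior is easy to inspect;
    the parameters are illustrative only.
    """
    return {"stage": stage, "tile_m": tile_m, "tile_n": tile_n, "dtype": dtype}


def compile_gemm1(**kw):
    """Stage-1 entry point: pins `stage` and forwards everything else."""
    return _compile_unified(stage=1, **kw)


def compile_gemm2(**kw):
    """Stage-2 entry point: same forwarding, different stage."""
    return _compile_unified(stage=2, **kw)
```

Because the unified function declares its parameters explicitly, a misspelled keyword passed to a wrapper still fails with `TypeError` rather than being silently ignored.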
…common

- Deduplicate routing utilities (moe_sorting_torch_native, build_routing_buffers,
  get_topk_valid_mask, RoutingBuffers) by importing from test_moe_gemm.py
- Remove aiter CK comparison blocks (dead code on gfx1250)
- Remove unused w2 allocation/quantization from stage1 runners
- Clean up commented-out debug lines and unused imports
- Extract generic MoE benchmark sweep system (add_moe_bench_args,
  moe_bench_config, moe_bench_main) into benchmark_common.py

Made-with: Cursor
WGP_BARRIER_ID, pipeline_fence_signal, pipeline_fence_wait, and
issue_tdm_loads were each defined twice. Keep the first (correct)
versions that use scf.IfOp for proper MLIR IR generation.

Made-with: Cursor
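Why a Python `if` on a traced value is a problem can be shown with a toy tracer. Everything below is a stand-in written for illustration: `Sym`, `Trace`, and `if_op` are not FlyDSL or MLIR APIs, and real bindings may behave differently (for example by raising instead of evaluating as truthy).

```python
class Sym:
    """Stand-in for a traced value: it names a result and holds no data."""

    def __init__(self, name):
        self.name = name


class Trace:
    """Toy recorder that appends one string per emitted operation."""

    def __init__(self):
        self.ops = []

    def cmp_eq(self, lhs, rhs):
        self.ops.append(f"cmp_eq {lhs}, {rhs}")
        return Sym("cond")

    def load(self):
        self.ops.append("load")

    def if_op(self, cond, then_body):
        # Record the condition as part of the program, then trace the body.
        self.ops.append(f"if {cond.name} {{")
        then_body()
        self.ops.append("}")


def python_conditional(trace):
    cond = trace.cmp_eq("wave_id", 0)
    if cond:  # a plain object is always truthy: decided at trace time
        trace.load()


def ir_conditional(trace):
    cond = trace.cmp_eq("wave_id", 0)
    trace.if_op(cond, trace.load)  # the branch becomes part of the program
```

In the first variant the recorded program contains an unconditional `load`; in the second the `load` is nested under a recorded condition, which is the property the commit keeps by retaining the scf.IfOp-based helpers.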
@XingerZhu XingerZhu requested a review from coderfeli April 16, 2026 02:44
@coderfeli coderfeli merged commit f65e930 into main Apr 16, 2026
9 checks passed
@coderfeli coderfeli deleted the gfx1250_moe_new branch April 16, 2026 03:18


8 participants