diff --git a/.github/workflows/TestBuild.yml b/.github/workflows/TestBuild.yml index 746ea03fe2..d6175f6b27 100644 --- a/.github/workflows/TestBuild.yml +++ b/.github/workflows/TestBuild.yml @@ -2,32 +2,70 @@ name: test build process on: push: - branches: [ main, ci ] + branches: [ main, ci, ci-test ] pull_request: - branches: [ main, ci ] + branches: [ main, ci, ci-test ] jobs: - build: - runs-on: self-hosted + test: + name: Build (${{ matrix.arch }}) + timeout-minutes: 360 + continue-on-error: ${{ matrix.experimental }} + + concurrency: + group: buddy-mlir-${{ matrix.arch }} + cancel-in-progress: true + + strategy: + fail-fast: false + matrix: + include: + - arch: x64 + experimental: false + - arch: riscv64 + experimental: true + + runs-on: + - self-hosted + - ${{ matrix.arch }} + steps: + - name: Print target + run: | + echo "Building on ${{ matrix.arch }}" + echo "------ uname ------" + uname -a + # 0. Install the Ninja build system. - name: Set up ninja + if: matrix.arch == 'x64' uses: seanmiddleditch/gha-setup-ninja@master - # 1. Checkout the main repository without fetching submodules initially + # 1A. Checkout the main repository without fetching submodules initially - name: Checkout main repository + if: matrix.arch == 'x64' uses: actions/checkout@v4 with: submodules: 'false' + # 1B. First clone from local Gitea, then update metadata from GitHub + - name: Checkout from local Gitea (riscv64) + if: matrix.arch == 'riscv64' + run: | + git clone --no-checkout https://community-ci.openruyi.cn/RuyiAI-Stack/buddy-mlir . + git remote add github https://github.com/buddy-compiler/buddy-mlir.git + git fetch github ${{ github.sha }} + git checkout ${{ github.sha }} + # 2. Retrieve the commit ID of the LLVM submodule. - name: Get LLVM submodule commit id: llvm-submodule-commit run: | echo "commit=$(git submodule status llvm | awk '{print $1;}')" >> $GITHUB_OUTPUT - # 3. Cache the LLVM submodule source code. - - name: Cache LLVM source + # 3. 
Cache the LLVM submodule source code (only for x86). + - name: Cache LLVM source (x86) + if: matrix.arch == 'x64' id: cache-llvm-source uses: actions/cache@v4 with: @@ -36,15 +74,16 @@ jobs: restore-keys: | llvm-source- - # 4. If the cache is not found, pull the LLVM submodule. + # 4. If the cache is not found, pull the LLVM submodule. (x86, cache miss) - name: Checkout LLVM submodule + if: matrix.arch == 'x64' && steps.cache-llvm-source.outputs.cache-hit != 'true' run: | rm -rf llvm git submodule update --init --recursive llvm - if: steps.cache-llvm-source.outputs.cache-hit != 'true' # 5. Cache the LLVM build directory. - name: Cache LLVM build directory + if: matrix.arch == 'x64' id: cache-llvm-build-dir uses: actions/cache@v4 with: @@ -54,9 +93,9 @@ jobs: llvm-build- # 6. Verify llvm-build when cached; build LLVM when no cache or verification fails. - - name: Check LLVM cache + - name: Check LLVM cache (x86) id: check-llvm-cache - if: steps.cache-llvm-build-dir.outputs.cache-hit == 'true' + if: matrix.arch == 'x64' && (steps.cache-llvm-build-dir.outputs.cache-hit == 'true') run: | for conda in ~/miniconda3/bin/activate ~/miniforge3/bin/activate; do [ -f "$conda" ] && source "$conda" buddy && break @@ -72,8 +111,9 @@ jobs: fi continue-on-error: true - - name: Configure and Build LLVM - if: steps.cache-llvm-build-dir.outputs.cache-hit != 'true' || steps.check-llvm-cache.outputs.need-rebuild == 'true' + - name: Configure and Build LLVM (x86) + if: matrix.arch == 'x64' && + (steps.cache-llvm-build-dir.outputs.cache-hit != 'true' || steps.check-llvm-cache.outputs.need-rebuild == 'true') run: | for conda in ~/miniconda3/bin/activate ~/miniforge3/bin/activate; do [ -f "$conda" ] && source "$conda" buddy && break @@ -94,8 +134,76 @@ jobs: -DPython3_EXECUTABLE=$(which python3) ninja check-clang check-mlir omp - # 7. Check buddy-mlir build. 
- - name: Check buddy-mlir build + - name: Enable ccache (riscv64) + if: matrix.arch == 'riscv64' + run: | + echo "CCACHE_DIR=/home/jenkins/.ccache" >> $GITHUB_ENV + echo "PATH=/usr/lib/ccache:$PATH" >> $GITHUB_ENV + ccache -M 50G + echo "CCACHE_MAXSIZE=50G" >> $GITHUB_ENV + + - name: Setup LLVM build (riscv64) + if: matrix.arch == 'riscv64' + id: prepare-llvm-build + run: | + LLVM_SRC=/home/jenkins/src/llvm-project + LLVM_BUILD_ROOT=/home/jenkins/llvm-cache + LLVM_COMMIT=$(git ls-tree HEAD llvm | awk '{print $3}') + LLVM_BUILD_DIR=$LLVM_BUILD_ROOT/$LLVM_COMMIT + + mkdir -p $LLVM_BUILD_ROOT + + # prepare source + if [ ! -d "$LLVM_SRC" ]; then + mkdir -p /home/jenkins/src + git clone https://community-ci.openruyi.cn/toolchain/llvm-project.git $LLVM_SRC + fi + + git -C $LLVM_SRC fetch --all + git -C $LLVM_SRC reset --hard + git -C $LLVM_SRC clean -fdx + git -C $LLVM_SRC checkout $LLVM_COMMIT + + echo "LLVM_SRC=$LLVM_SRC" >> $GITHUB_ENV + echo "LLVM_COMMIT=$LLVM_COMMIT" >> $GITHUB_ENV + echo "LLVM_BUILD_DIR=$LLVM_BUILD_DIR" >> $GITHUB_ENV + + if [ -f "$LLVM_BUILD_DIR/build/CMakeCache.txt" ] && + grep -q "$LLVM_SRC" "$LLVM_BUILD_DIR/build/CMakeCache.txt"; then + echo "need-rebuild=false" >> $GITHUB_OUTPUT + else + echo "need-rebuild=true" >> $GITHUB_OUTPUT + fi + + - name: Configure and Build LLVM (riscv64) + if: matrix.arch == 'riscv64' && + steps.prepare-llvm-build.outputs.need-rebuild == 'true' + run: | + source /home/jenkins/venv/bin/activate + pip install -U pip setuptools wheel packaging + sed -i \ + -e '1s/^/# /' \ + requirements.txt + pip install -r requirements.txt + rm -rf $LLVM_BUILD_DIR + cmake -G Ninja -S $LLVM_SRC/llvm -B $LLVM_BUILD_DIR/build \ + -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ + -DLLVM_TARGETS_TO_BUILD="host;RISCV" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_BUILD_TYPE=RELEASE \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE=$(which python3) + ninja -C $LLVM_BUILD_DIR/build check-clang check-mlir omp -j$(nproc) || true + 
deactivate + + - name: Show ccache stats (riscv64) + if: matrix.arch == 'riscv64' + run: | + ccache -s + + # 7A. Check buddy-mlir build (x86). + - name: Check buddy-mlir build (x86) + if: matrix.arch == 'x64' run: | for conda in ~/miniconda3/bin/activate ~/miniforge3/bin/activate; do [ -f "$conda" ] && source "$conda" buddy && break @@ -113,3 +221,38 @@ jobs: -DPython3_EXECUTABLE=$(which python3) ninja ninja check-buddy + + # 7B. Check buddy-mlir build (riscv64). + - name: Check buddy-mlir build (riscv64) + if: matrix.arch == 'riscv64' + run: | + source /home/jenkins/venv/bin/activate + echo "---- CHECK LOCAL SUBMODULE ----" + ls -l llvm || true + git submodule status || true + rm -rf build + mkdir build + cd build + + # It will fail to find source code if no -DLLVM_MAIN_SRC_DIR=$LLVM_SRC/llvm -DMLIR_MAIN_SRC_DIR=$LLVM_SRC/mlir + cmake -G Ninja .. \ + -DMLIR_DIR=$LLVM_BUILD_DIR/build/lib/cmake/mlir \ + -DLLVM_DIR=$LLVM_BUILD_DIR/build/lib/cmake/llvm \ + -DLLVM_MAIN_SRC_DIR=$LLVM_SRC/llvm \ + -DMLIR_MAIN_SRC_DIR=$LLVM_SRC/mlir \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_BUILD_TYPE=RELEASE \ + -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \ + -DPython3_EXECUTABLE=$(which python3) + ninja -j$(nproc) + ninja check-buddy -j$(nproc) || true + deactivate + + - name: Cleanup old LLVM build cache (riscv64) + if: matrix.arch == 'riscv64' + run: | + CACHE_DIR=/home/jenkins/llvm-cache + KEEP=3 + + cd $CACHE_DIR + ls -1dt */ | tail -n +$((KEEP+1)) | xargs -r rm -rf \ No newline at end of file diff --git a/.github/workflows/build-manylinux.yml b/.github/workflows/build-manylinux.yml new file mode 100644 index 0000000000..fe460936c3 --- /dev/null +++ b/.github/workflows/build-manylinux.yml @@ -0,0 +1,76 @@ +name: build-manylinux + +on: + workflow_dispatch: + inputs: + version: + description: "Release version tag (e.g., v0.1.0)" + required: true + type: string + +permissions: + contents: write + +jobs: + wheel: + runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + python: 
["3.10", "3.11", "3.12", "3.13"] + torch_version: ["2.8"] # placeholder; not used in build, keeps matrix extensible + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: true + + - name: Record LLVM submodule commit + id: llvm_rev + run: echo "sha=$(git -C llvm rev-parse HEAD)" >> "$GITHUB_OUTPUT" + + - name: Cache LLVM build tree + id: cache-llvm-build + uses: actions/cache@v4 + with: + path: llvm/build.docker + key: llvm-${{ runner.os }}-${{ matrix.python }}-${{ steps.llvm_rev.outputs.sha }} + + - name: Build manylinux wheel in container + run: | + set -euo pipefail + PY_VER="${{ matrix.python }}" + PY_NODOT="${PY_VER//./}" + PY_TAG="cp${PY_NODOT}-cp${PY_NODOT}" + export TORCH_VERSION="${{ matrix.torch_version }}" + export LLVM_CACHE_HIT="${{ steps.cache-llvm-build.outputs.cache-hit }}" + ./scripts/release_wheel_manylinux.sh "${PY_TAG}" + + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-py${{ matrix.python }}-torch${{ matrix.torch_version }} + path: build.docker/dist/*.whl + + release: + needs: wheel + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Download wheels + uses: actions/download-artifact@v4 + with: + pattern: wheels-* + path: dist + merge-multiple: true + + - name: Upload assets + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ inputs.version }} + overwrite_files: true + files: dist/*manylinux*.whl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000000..bcf8c19e30 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,21 @@ +name: pre-commit + +on: + push: + branches: [main] + pull_request: + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - uses: pre-commit/action@v3.0.1 + with: + 
extra_args: --show-diff-on-failure --color=always diff --git a/CMakeLists.txt b/CMakeLists.txt index 860a78dff5..6f73b55311 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,8 @@ project(buddy-mlir LANGUAGES CXX C) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED YES) include(ExternalProject) +include(GNUInstallDirs) +include(CMakePackageConfigHelpers) #------------------------------------------------------------------------------- # Options and settings @@ -86,6 +88,7 @@ set(BUDDY_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(BUDDY_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/bin) set(BUDDY_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/lib) set(BUDDY_EXAMPLES_DIR ${BUDDY_SOURCE_DIR}/examples) +set(BUDDY_MLIR_INTERFACE_DIR ${BUDDY_SOURCE_DIR}/frontend/Interfaces) set(BUDDY_MIDEND_INCLUDE_DIR ${BUDDY_SOURCE_DIR}/midend/include) set(BUDDY_THIRDPARTY_INCLUDE_DIR ${BUDDY_SOURCE_DIR}/thirdparty/include) set(BUDDY_MLIR_PYTHON_PACKAGES_DIR ${BUDDY_BUILD_DIR}/python_packages) @@ -216,23 +219,6 @@ if(BUDDY_MLIR_ENABLE_RISCV_GNU_TOOLCHAIN) ) endif() -#------------------------------------------------------------------------------- -# Initialize Python packages -#------------------------------------------------------------------------------- -if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES) - # Find the Python interpreter and development components, - # requiring a minimum version of 3.10 - find_package(Python3 3.10 REQUIRED COMPONENTS Interpreter Development) - # Create directories for the BUDDY-MLIR Python packages - file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy) - file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler) - # Create empty __init__.py files to make these directories Python packages - file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/__init__.py "") - file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/__init__.py "") - - install(DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy DESTINATION python_packages) -endif() - 
#------------------------------------------------------------------------------- # Directory setup #------------------------------------------------------------------------------- @@ -264,3 +250,23 @@ install(DIRECTORY buddy/Core buddy/DAP buddy/DIP buddy/LLM FILES_MATCHING PATTERN "*.h" ) + +set(BUDDY_MLIR_INSTALL_CMAKE_DIR "${CMAKE_INSTALL_LIBDIR}/cmake/BuddyMLIR") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/BuddyMLIRConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/BuddyMLIRConfig.cmake + INSTALL_DESTINATION ${BUDDY_MLIR_INSTALL_CMAKE_DIR} +) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/BuddyMLIRConfigVersion.cmake + VERSION 0.1.0 + COMPATIBILITY AnyNewerVersion +) + +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/BuddyMLIRConfig.cmake + ${CMAKE_CURRENT_BINARY_DIR}/BuddyMLIRConfigVersion.cmake + DESTINATION ${BUDDY_MLIR_INSTALL_CMAKE_DIR} +) diff --git a/README.md b/README.md index 97d10cf323..8e487c3cee 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ Set the `PYTHONPATH` environment variable to include both the LLVM/MLIR Python b ``` $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` If you want to test your model end-to-end conversion and inference, you can add the following configuration @@ -81,6 +81,23 @@ $ cmake -G Ninja .. -DBUDDY_ENABLE_E2E_TESTS=ON $ ninja check-e2e ``` +## Build Python Package + +We use `setuptools` to bundle CMake outputs (Python packages, `bin/`, and +`lib/`) into a single wheel. + +run `./scripts/release_wheel_manylinux.sh`. + +This script calls `docker run` internally to enter the manylinux container, builds LLVM and buddy_mlir, and writes the wheel to `./build.docker/dist`. 
+ +Install and test the wheel: + +```bash +pip install buddy-*.whl --no-deps +python -c "import buddy; import buddy_mlir; print('ok')" +buddy-opt --help +``` + ## Examples We provide examples to demonstrate how to use the passes and interfaces in `buddy-mlir`, including IR-level transformations, domain-specific applications, and testing demonstrations. diff --git a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td index b0e427ee5d..42981ccb5a 100644 --- a/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td +++ b/backend/include/llvm/IR/IntrinsicsRISCVBuddyExt.td @@ -112,6 +112,299 @@ def int_riscv_ime_vfmadot : Intrinsic<[llvm_anyvector_ty], [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], [IntrNoMem]>; +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot1us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot2su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in 
+def int_riscv_ime_vmadot2us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3u : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3su : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadot3us : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotn : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnsu : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vmadotnus : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot1 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot2 : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadot3 : Intrinsic<[llvm_anyvector_ty], + 
[LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty], + [IntrNoMem]>; + +let TargetPrefix = "riscv" in +def int_riscv_ime_vfmadotn : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i64_ty], + [IntrNoMem]>; + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) Intrinsics +//===----------------------------------------------------------------------===// + +// Matrix configuration intrinsics (with register) +let TargetPrefix = "riscv" in { + // msettype - Set matrix type configuration + def int_riscv_buddy_msettype : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilem - Set tile M dimension from register + def int_riscv_buddy_msettilem : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilen - Set tile N dimension from register + def int_riscv_buddy_msettilen : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; + + // msettilek - Set tile K dimension from register + def int_riscv_buddy_msettilek : Intrinsic<[llvm_i64_ty], [llvm_i64_ty], [IntrNoMem]>; +} + +// Matrix configuration intrinsics (with immediate) +let TargetPrefix = "riscv" in { + // msettilemi - Set tile M dimension with immediate + def int_riscv_buddy_msettilemi : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; + + // msettileni - Set tile N dimension with immediate + def int_riscv_buddy_msettileni : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; + + // msettileki - Set tile K dimension with immediate + def int_riscv_buddy_msettileki : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix load intrinsics (load to tile register) +// Format: md = tile register index, base = address, stride = row byte stride +let TargetPrefix = "riscv" in { + // mlae32.m - Load 32-bit left matrix A to tile register + def int_riscv_buddy_mlae32_m : Intrinsic<[], + [llvm_i64_ty, 
llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlae64.m - Load 64-bit left matrix A to tile register + def int_riscv_buddy_mlae64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlbe32.m - Load 32-bit right matrix B to tile register + def int_riscv_buddy_mlbe32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlbe64.m - Load 64-bit right matrix B to tile register + def int_riscv_buddy_mlbe64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlce32.m - Load 32-bit output matrix C to accumulator + def int_riscv_buddy_mlce32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; + + // mlce64.m - Load 64-bit output matrix C to accumulator + def int_riscv_buddy_mlce64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrReadMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix store intrinsics (store from tile register) +let TargetPrefix = "riscv" in { + // msce32.m - Store 32-bit output matrix C from accumulator + def int_riscv_buddy_msce32_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrWriteMem, IntrHasSideEffects, ImmArg>]>; + + // msce64.m - Store 64-bit output matrix C from accumulator + def int_riscv_buddy_msce64_m : Intrinsic<[], + [llvm_i64_ty, llvm_ptr_ty, llvm_i64_ty], + [IntrWriteMem, IntrHasSideEffects, ImmArg>]>; +} + +// Matrix zero intrinsic +let TargetPrefix = "riscv" in { + def int_riscv_buddy_mzero : Intrinsic<[], [llvm_i64_ty], + [IntrNoMem, IntrHasSideEffects, ImmArg>]>; +} + +// Tile register matrix multiplication intrinsics (operate on tile registers) +let TargetPrefix = "riscv" in { + // mma.w.mm - int32 tile matrix multiply: md = md + ms1 x ms2 + def int_riscv_buddy_mma_w_mm_tile : Intrinsic<[], + [llvm_i64_ty, llvm_i64_ty, 
llvm_i64_ty], + [IntrHasSideEffects, ImmArg>, + ImmArg>, ImmArg>]>; + + // mma.dw.mm - int64 tile matrix multiply: md = md + ms1 x ms2 + def int_riscv_buddy_mma_dw_mm_tile : Intrinsic<[], + [llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + [IntrHasSideEffects, ImmArg>, + ImmArg>, ImmArg>]>; +} + +// Legacy matrix load/store intrinsics (for backward compatibility) +let TargetPrefix = "riscv" in { + // mlae - Load matrix A with element width + def int_riscv_buddy_mlae : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // mlbe - Load matrix B with element width + def int_riscv_buddy_mlbe : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // mlce - Load matrix C (accumulator) + def int_riscv_buddy_mlce : Intrinsic<[llvm_anyvector_ty], + [llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrReadMem]>; + + // msce - Store matrix C (accumulator) + def int_riscv_buddy_msce : Intrinsic<[], + [llvm_anyvector_ty, llvm_ptr_ty, llvm_i64_ty, llvm_i64_ty], + [IntrWriteMem]>; +} + +// Signed integer matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mqma.b.mm - int8 quad-widen matrix multiply (int8 x int8 -> int32) + def int_riscv_buddy_mqma_b_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.h.mm - int16 matrix multiply + def int_riscv_buddy_mma_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.w.mm - int32 matrix multiply + def int_riscv_buddy_mma_w_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mma.dw.mm - int64 matrix multiply + def int_riscv_buddy_mma_dw_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mwma.h.mm - int16 double-widen matrix multiply (int16 x int16 -> int32) + def 
int_riscv_buddy_mwma_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} + +// Unsigned integer matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mqmau.b.mm - uint8 quad-widen matrix multiply + def int_riscv_buddy_mqmau_b_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mmau.h.mm - uint16 matrix multiply + def int_riscv_buddy_mmau_h_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} + +// Floating-point matrix multiplication intrinsics +let TargetPrefix = "riscv" in { + // mfma.f.mm - fp32 matrix multiply + def int_riscv_buddy_mfma_f_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mfma.hf.mm - fp16 matrix multiply + def int_riscv_buddy_mfma_hf_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; + + // mfwma.hf.mm - fp16 double-widen matrix multiply (fp16 x fp16 -> fp32) + def int_riscv_buddy_mfwma_hf_mm : Intrinsic<[], + [llvm_ptr_ty, llvm_ptr_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_i64_ty, llvm_i64_ty], + []>; +} //===----------------------------------------------------------------------===// // BB intrinsics //===----------------------------------------------------------------------===// diff --git a/backend/llvm/lib/Target/RISCV/RISCVBuddyExt.td b/backend/llvm/lib/Target/RISCV/RISCVBuddyExt.td index 5b081c1245..0660630f69 100644 --- a/backend/llvm/lib/Target/RISCV/RISCVBuddyExt.td +++ b/backend/llvm/lib/Target/RISCV/RISCVBuddyExt.td @@ -26,4 +26,108 @@ def HasBuddyExt : Predicate<"Subtarget->hasBuddyExt()">, AssemblerPredicate<(all_of FeatureBuddyExt), "'BuddyExt' (Buddy RISC-V Extension)">; +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) Register Definitions 
+//===----------------------------------------------------------------------===// +// Reference: RISC-V Matrix Extension Specification +// +// Matrix Registers: +// - 8 Tile Registers (tr0-tr7): For input matrices A and B +// Each tile register has MLEN bits of state +// - 8 Accumulation Registers (acc0-acc7): For output/accumulation matrix C +// Each accumulation register has MLEN × AMUL bits of state +// +// AMUL (Accumulation MULtiplier): +// - Can be fractional (1/8, 1/4, 1/2) or integer (1, 2, 4, 8) +// - Determines the width ratio between acc and tr registers +// - For mmi8i32 (int8→int32 quad-widen), AMUL ≥ 4 +// +// Data Flow: +// Memory → tr (via mlae/mlbe) → acc (via mma/mwma/mqma) → Memory (via msce) +//===----------------------------------------------------------------------===// + +let Namespace = "RISCV" in { + +//===----------------------------------------------------------------------===// +// AME Tile Registers (tr0-tr7) +// Used for input matrices A and B +// Size: MLEN bits per register (hardware-defined) +//===----------------------------------------------------------------------===// + +// Base class for Tile Registers +class AMETileReg Enc, string n> : Register { + let HWEncoding{2-0} = Enc; + let HWEncoding{4-3} = 0b00; // Distinguish from accumulation registers +} + +// Define 8 Tile Registers: tr0-tr7 +def TR0 : AMETileReg<0, "tr0">; +def TR1 : AMETileReg<1, "tr1">; +def TR2 : AMETileReg<2, "tr2">; +def TR3 : AMETileReg<3, "tr3">; +def TR4 : AMETileReg<4, "tr4">; +def TR5 : AMETileReg<5, "tr5">; +def TR6 : AMETileReg<6, "tr6">; +def TR7 : AMETileReg<7, "tr7">; + +//===----------------------------------------------------------------------===// +// AME Accumulation Registers (acc0-acc7) +// Used for output/accumulation matrix C +// Size: MLEN × AMUL bits per register (hardware-defined) +// +// Note: AMUL can be: +// - Fractional (1/8, 1/4, 1/2): For C = A × Bᵀ mode with large K +// - Integer (1, 2, 4, 8): For widening operations +// * AMUL=4: 
Required for mmi8i32 (int8→int32 quad-widen) +// * AMUL=2: Required for mmi16i32 (int16→int32 double-widen) +// * AMUL=8: Required for mmi4i32 (int4→int32 oct-widen) +//===----------------------------------------------------------------------===// + +// Base class for Accumulation Registers +class AMEAccReg Enc, string n> : Register { + let HWEncoding{2-0} = Enc; + let HWEncoding{4-3} = 0b01; // Distinguish from tile registers +} + +// Define 8 Accumulation Registers: acc0-acc7 +def ACC0 : AMEAccReg<0, "acc0">; +def ACC1 : AMEAccReg<1, "acc1">; +def ACC2 : AMEAccReg<2, "acc2">; +def ACC3 : AMEAccReg<3, "acc3">; +def ACC4 : AMEAccReg<4, "acc4">; +def ACC5 : AMEAccReg<5, "acc5">; +def ACC6 : AMEAccReg<6, "acc6">; +def ACC7 : AMEAccReg<7, "acc7">; + +} // End Namespace = "RISCV" + +//===----------------------------------------------------------------------===// +// AME Register Classes +//===----------------------------------------------------------------------===// +// These register classes define the operand types for AME instructions +// +// Usage in instructions: +// - TileReg: For ms1, ms2 (source operands in multiplication) +// - AccReg: For md (destination/accumulator in multiplication) +// - TileReg: For load/store of input matrices (A, B) +// - AccReg: For load/store of output/accumulator (C) +//===----------------------------------------------------------------------===// + +// Tile Register class (tr0-tr7) +// Used for input operands in matrix multiplication +// Note: Size is set to 256 as a placeholder; actual size depends on MLEN +def TileReg : RegisterClass<"RISCV", [untyped], 256, + (add TR0, TR1, TR2, TR3, TR4, TR5, TR6, TR7)> { + let Size = 256; // Placeholder: actual MLEN is hardware-defined +} + +// Accumulation Register class (acc0-acc7) +// Used for output/accumulator in matrix multiplication +// Note: Size can be 256×AMUL where AMUL ∈ {1/8, 1/4, 1/2, 1, 2, 4, 8} +// We use 1024 as a reasonable upper bound (256 × 4 for int8→int32) +def AccReg : 
RegisterClass<"RISCV", [untyped], 1024, + (add ACC0, ACC1, ACC2, ACC3, ACC4, ACC5, ACC6, ACC7)> { + let Size = 1024; // Placeholder: actual MLEN×AMUL is hardware-defined +} + include "RISCVInstrInfoBuddyExt.td" diff --git a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td index f06094fc30..5c8f15b693 100644 --- a/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td +++ b/backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td @@ -222,6 +222,134 @@ def IME_VFMADOT : RVInstIME<0b1110101, 0b000, let Constraints = "$vd = $vd_in"; } +//===----------------------------------------------------------------------===// +// IME Sliding-Window Instructions +//===----------------------------------------------------------------------===// + +// Integer slide-1 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1 : RVInstIME<0b1110010, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1U : RVInstIME<0b1110010, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1SU : RVInstIME<0b1110010, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT1US : RVInstIME<0b1110010, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot1us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Integer slide-2 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, 
mayLoad = 0, mayStore = 0 in +def IME_VMADOT2 : RVInstIME<0b1110011, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2U : RVInstIME<0b1110011, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2SU : RVInstIME<0b1110011, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT2US : RVInstIME<0b1110011, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot2us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Integer slide-3 instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3 : RVInstIME<0b1110100, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3U : RVInstIME<0b1110100, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3u", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3SU : RVInstIME<0b1110100, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3su", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOT3US : RVInstIME<0b1110100, 0b010, + (outs 
VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vmadot3us", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +// Floating-point slide instructions +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT1 : RVInstIME<0b1110110, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot1", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT2 : RVInstIME<0b1110111, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot2", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VFMADOT3 : RVInstIME<0b1111000, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2), + "vfmadot3", "$vd, $vs1, $vs2"> { + let Constraints = "$vd = $vd_in"; +} + let Predicates = [HasBuddyExt] in def : Pat<(int_riscv_mvin GPR:$rs1, GPR:$rs2), (MVIN GPR:$rs1, GPR:$rs2)>; @@ -297,6 +425,12 @@ let Predicates = [HasBuddyExt] in { (IME_VMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; } +// int16 vmadot patterns +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot nxv8i32:$vd, nxv16i16:$vs1, nxv16i16:$vs2)), + (IME_VMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; +} + let Predicates = [HasBuddyExt] in { def : Pat<(nxv8i32 (int_riscv_ime_vmadotu nxv8i32:$vd, nxv32i8:$vs1, nxv32i8:$vs2)), (IME_VMADOTU VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; @@ -317,6 +451,1045 @@ let Predicates = [HasBuddyExt] in { (IME_VFMADOT VRM4:$vd, VRM4:$vs1, VRM4:$vs2)>; } +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1U VRM4:$vd, 
VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot1us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT1US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot2us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT2US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3 nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3u nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3U VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3su nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3SU VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadot3us nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2)), + (IME_VMADOT3US VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot1 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT1 VRM4:$vd, 
VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot2 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT2 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadot3 nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2)), + (IME_VFMADOT3 VRM4:$vd, VRM8:$vs1, VRM4:$vs2)>; +} + +class RVInstIMEN funct7, bits<3> funct3, dag outs, dag ins, + string opcodestr, string argstr> + : RVInst { + bits<5> vs2; + bits<5> rs1; // GPR for dynamic slide value + bits<5> vd; + + let Inst{31-25} = funct7; + let Inst{24-20} = vs2; + let Inst{19-15} = rs1; + let Inst{14-12} = funct3; + let Inst{11-7} = vd; + let Inst{6-0} = OPC_CUSTOM_1.Value; + + let Uses = [VTYPE, VL]; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTN : RVInstIMEN<0b1111001, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotn", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNU : RVInstIMEN<0b1111001, 0b011, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnu", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNSU : RVInstIMEN<0b1111001, 0b001, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnsu", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def IME_VMADOTNUS : RVInstIMEN<0b1111001, 0b010, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vmadotnus", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore 
= 0 in +def IME_VFMADOTN : RVInstIMEN<0b1111010, 0b000, + (outs VRM4:$vd), + (ins VRM4:$vd_in, VRM8:$vs1, VRM4:$vs2, GPR:$rs1), + "vfmadotn", "$vd, $vs1, $vs2, $rs1"> { + let Constraints = "$vd = $vd_in"; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotn nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTN VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnu nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNU VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnsu nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNSU VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv8i32 (int_riscv_ime_vmadotnus nxv8i32:$vd, nxv64i8:$vs1, nxv32i8:$vs2, GPR:$rs1)), + (IME_VMADOTNUS VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +let Predicates = [HasBuddyExt] in { + def : Pat<(nxv16f16 (int_riscv_ime_vfmadotn nxv16f16:$vd, nxv32f16:$vs1, nxv16f16:$vs2, GPR:$rs1)), + (IME_VFMADOTN VRM4:$vd, VRM8:$vs1, VRM4:$vs2, GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// AME (RISC-V Matrix Extension) 64-bit Instructions +//===----------------------------------------------------------------------===// +// Reference: RISC-V Matrix Extension Specification +// 64-bit encoding format with prefix 0111111 at bits [6:0] +// +// Register Model: +// - TileReg (tr0-tr7): Input matrices A and B, MLEN bits each +// - AccReg (acc0-acc7): Accumulator matrix C, MLEN×AMUL bits each +// +// Data Flow: +// Memory --[mlae/mlbe]--> TileReg --[mma/mwma/mqma]--> AccReg --[msce]--> Memory +// +// Matrix Multiplication Instruction Format (64-bit): +// | 63:59 | 58 | 57:55 | 54:52 | 51:49 | 48:47 | 46:44 | 43:39 | 38:32 | +// | sps | sp | typ2 | typ1 | typd | bma | frm | funct5 | 
opcode | +// | 31:26 | 25 | 24:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | fp | ms2 | ms1 | funct3 | md | suffix | +// +// suffix = 0111111 (AME prefix) +// funct3 = 100 (matrix multiplication) +// +// Widening Instructions (required for AMUL > 1): +// - mqma.b.mm: int8 → int32 (quad-widen, AMUL ≥ 4, mmi8i32 MANDATORY) +// - mwma.h.mm: int16 → int32 (double-widen, AMUL ≥ 2) +// - mwma.w.mm: int32 → int64 (double-widen, AMUL ≥ 2) +//===----------------------------------------------------------------------===// + +// AME Opcode suffix (bits [6:0]) +def OPC_AME : RISCVOpcode<"OPC_AME", 0b0111111>; + +//===----------------------------------------------------------------------===// +// AME Custom Operand Types for Tile Register Indices +//===----------------------------------------------------------------------===// +// These operand types allow intrinsics to pass immediate indices (0-7) +// for tile and accumulator registers. The AsmString of pseudo instructions +// hardcodes "acc" and "tr" prefixes so that indices 0-7 are printed as +// acc0-acc7 and tr0-tr7 respectively, without modifying LLVM submodule. 
+ +// AsmOperandClass for tile register index (0-7) +def AMETileIndexAsmOperand : AsmOperandClass { + let Name = "AMETileIndex"; + let RenderMethod = "addImmOperands"; + let PredicateMethod = "isUImm3"; + let DiagnosticType = "InvalidAMETileIndex"; +} + +// AsmOperandClass for accumulator register index (0-7) +def AMEAccIndexAsmOperand : AsmOperandClass { + let Name = "AMEAccIndex"; + let RenderMethod = "addImmOperands"; + let PredicateMethod = "isUImm3"; + let DiagnosticType = "InvalidAMEAccIndex"; +} + +// Operand type for TileReg index (0-7), printed with "tr" prefix in AsmString +def AMETileIndex : RISCVOp { + let ParserMatchClass = AMETileIndexAsmOperand; + let DecoderMethod = "decodeUImmOperand<3>"; + let OperandType = "OPERAND_UIMM3"; +} + +// Operand type for AccReg index (0-7), printed with "acc" prefix in AsmString +def AMEAccIndex : RISCVOp { + let ParserMatchClass = AMEAccIndexAsmOperand; + let DecoderMethod = "decodeUImmOperand<3>"; + let OperandType = "OPERAND_UIMM3"; +} + +//===----------------------------------------------------------------------===// +// AME 64-bit Instruction Format Base Class +//===----------------------------------------------------------------------===// + +// Base class for AME 64-bit matrix multiplication instructions +// Uses TileReg for inputs (ms1, ms2) and AccReg for output (md) +class RVInstAME64 + : RVInst64 { + // Low 32 bits + bits<5> md; + bits<5> ms1; + bits<5> ms2; + bits<6> funct6; + bit fp; + bits<3> funct3; + + // High 32 bits + bits<5> funct5; + bits<3> frm; + bits<2> bma; + bits<3> typd; + bits<3> typ1; + bits<3> typ2; + bit sp; + bits<5> sps; + bits<7> opcode_hi; // [38:32] + + // Encode low 32 bits (suffix word) + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = md; + let Inst{14-12} = funct3; + let Inst{19-15} = ms1; + let Inst{24-20} = ms2; + let Inst{25} = fp; + let Inst{31-26} = funct6; + + // Encode high 32 bits (opcode word) + let Inst{38-32} = opcode_hi; + let Inst{43-39} = funct5; + let 
Inst{46-44} = frm; + let Inst{48-47} = bma; + let Inst{51-49} = typd; + let Inst{54-52} = typ1; + let Inst{57-55} = typ2; + let Inst{58} = sp; + let Inst{63-59} = sps; +} + +//===----------------------------------------------------------------------===// +// AME Matrix Multiplication Instructions +//===----------------------------------------------------------------------===// +// Format: mma.{h|w|dw}.mm acc, tr, tr +// Semantics: acc = acc + tr1 * tr2 +// +// Data Flow: TileReg × TileReg → AccReg (accumulate) +// - ms1: TileReg for matrix A +// - ms2: TileReg for matrix B +// - md: AccReg for accumulation result C +// +// typ1/typ2/typd encoding: +// 000 = int8 (b), 001 = int16 (h), 010 = int32 (w), 011 = int64 (dw) +// 100 = use mtype.msew, 111 = int4 +// +// funct5 encoding: +// 00001 = mma (signed, no saturation) +// 00010 = mwma (double-widen) +// 00100 = mqma (quad-widen) +// 10001 = msma (signed, saturated) +//===----------------------------------------------------------------------===// + +// No-widen matrix multiply-accumulate: acc = acc + tr1 * tr2 +// Input and output have the same element width +class AME_MMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; // Accumulator constraint + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; // Matrix multiplication + let opcode_hi = 0b0000011; // xxyyy11 where xx=00, yyy=001 + let funct5 = 0b00001; // mma (signed, no saturation) + let frm = 0b000; + let bma = 0b00; // Default: not agnostic + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; // No sparsity + let sps = 0b00000; +} + +// mma.h.mm - int16 × int16 → int16 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_H_MM : AME_MMA_MM<0b001, 0b001, 0b001, "mma.h.mm">; + +// mma.w.mm 
- int32 × int32 → int32 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_W_MM : AME_MMA_MM<0b010, 0b010, 0b010, "mma.w.mm">; + +// mma.dw.mm - int64 × int64 → int64 accumulate (no widen) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMA_DW_MM : AME_MMA_MM<0b011, 0b011, 0b011, "mma.dw.mm">; + +//===----------------------------------------------------------------------===// +// AME Widening Matrix Multiplication Instructions +//===----------------------------------------------------------------------===// +// Double-widen: output element is 2× width of input elements +// Quad-widen: output element is 4× width of input elements +// +// These require AMUL ≥ 2 (double) or AMUL ≥ 4 (quad) to ensure +// accumulator has sufficient width. +// +// Mandatory: mqma.b.mm (int8→int32) for mmi8i32 feature +//===----------------------------------------------------------------------===// + +// Double-widen: acc = acc + tr1 * tr2, output is 2× width +// Requires AMUL ≥ 2 +class AME_MWMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00010; // mwma (double-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mwma.h.mm - int16 × int16 → int32 accumulate (double-widen) +// Requires: AMUL ≥ 2, mmi16i32 feature +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MWMA_H_MM : AME_MWMA_MM<0b010, 0b001, 0b001, "mwma.h.mm">; + +// mwma.w.mm - int32 × int32 → int64 accumulate (double-widen) +// Requires: AMUL ≥ 2, mmi32i64 feature +let Predicates = [HasBuddyExt], 
hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MWMA_W_MM : AME_MWMA_MM<0b011, 0b010, 0b010, "mwma.w.mm">; + +// Quad-widen: acc = acc + tr1 * tr2, output is 4× width +// Requires AMUL ≥ 4 +class AME_MQMA_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00100; // mqma (quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqma.b.mm - int8 × int8 → int32 accumulate (quad-widen) +// MANDATORY for mmi8i32 feature (required by Spec) +// Requires: AMUL ≥ 4 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMA_B_MM : AME_MQMA_MM<0b010, 0b000, 0b000, "mqma.b.mm">; + +//===----------------------------------------------------------------------===// +// AME Configuration Instructions +//===----------------------------------------------------------------------===// +// Configuration instruction format (64-bit): +// | 63:43 | 42:39 | 38:32 | +// | imm[31:11] | funct4 | opcode | +// | 31:26 | 25:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | imm[10:5] | rs1 | funct3 | rd | suffix | +//===----------------------------------------------------------------------===// + +class RVInstAMEConfig64 funct6_val, string opcodestr> + : RVInst64<(outs GPR:$rd), (ins GPR:$rs1), + opcodestr, "$rd, $rs1", [], InstFormatOther> { + bits<5> rd; + bits<5> rs1; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = rd; + let Inst{14-12} = 0b000; // funct3 for config + let Inst{19-15} = rs1; + let Inst{25-20} = 0b000000; + let Inst{31-26} = funct6_val; + let Inst{38-32} = 0b0000011; // opcode + let Inst{42-39} = 0b0000; // funct4 + let Inst{63-43} = 0; // 
imm[31:11] = 0 +} + +class RVInstAMEConfigImm64 funct6_val, string opcodestr> + : RVInst64<(outs GPR:$rd), (ins uimm32:$imm), + opcodestr, "$rd, $imm", [], InstFormatOther> { + bits<5> rd; + bits<32> imm; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = rd; + let Inst{14-12} = 0b000; // funct3 for config + let Inst{19-15} = imm{4-0}; // imm[4:0] in rs1 field + let Inst{25-20} = imm{10-5}; // imm[10:5] + let Inst{31-26} = funct6_val; + let Inst{38-32} = 0b0000011; // opcode + let Inst{42-39} = 0b0000; // funct4 + let Inst{63-43} = imm{31-11}; // imm[31:11] +} + +// msettilem - set tile M dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEM : RVInstAMEConfig64<0b000100, "msettilem">; + +// msettilemi - set tile M dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEMI : RVInstAMEConfigImm64<0b000101, "msettilemi">; + +// msettilen - set tile N dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEN : RVInstAMEConfig64<0b001000, "msettilen">; + +// msettileni - set tile N dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILENI : RVInstAMEConfigImm64<0b001001, "msettileni">; + +// msettilek - set tile K dimension from register +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEK : RVInstAMEConfig64<0b001100, "msettilek">; + +// msettileki - set tile K dimension from immediate +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MSETTILEKI : RVInstAMEConfigImm64<0b001101, "msettileki">; + +//===----------------------------------------------------------------------===// +// AME Load/Store Instructions +//===----------------------------------------------------------------------===// +// 
Load/Store instruction format (64-bit): +// | 63:51 | 50:49 | 48:47 | 46:44 | 43:39 | 38:32 | +// | resv | mt | bma | eew | funct5 | opcode | +// | 31:26 | 25 | 24:20 | 19:15 | 14:12 | 11:7 | 6:0 | +// | funct6 | ls | rs2 | rs1 | funct3 | ms3/md | suffix | +// +// mt (matrix type): 00=accumulator(C), 01=left(A), 10=right(B), 11=result +// eew (element width): 000=8b, 001=16b, 010=32b, 011=64b +// +// Register Usage: +// - mt=01 (A) or mt=10 (B): Uses TileReg +// - mt=00 (C): Uses AccReg +// +// Data Flow Examples: +// mlae32.m tr0, (a0), a1 # Load matrix A into TileReg +// mlbe32.m tr1, (a0), a1 # Load matrix B into TileReg +// mqma.b.mm acc0, tr0, tr1 # Compute: AccReg = AccReg + TileReg × TileReg +// msce32.m acc0, (a0), a1 # Store AccReg to memory +//===----------------------------------------------------------------------===// + +// Load into TileReg (for matrix A and B) +class RVInstAMELoadTile64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs TileReg:$md), (ins GPR:$rs1, GPR:$rs2), + opcodestr, "$md, $rs1, $rs2", [], InstFormatOther> { + bits<5> md; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; // AME prefix + let Inst{11-7} = md; + let Inst{14-12} = 0b001; // funct3 for load/store + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 0; // ls = 0 for load + let Inst{31-26} = 0b000000; // funct6 + let Inst{38-32} = 0b0000011; // opcode + let Inst{43-39} = 0b00000; // funct5 + let Inst{46-44} = eew_val; // eew + let Inst{48-47} = 0b00; // bma + let Inst{50-49} = mt_val; // mt + let Inst{63-51} = 0; // reserved +} + +// Load into AccReg (for accumulator C) +class RVInstAMELoadAcc64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs AccReg:$md), (ins GPR:$rs1, GPR:$rs2), + opcodestr, "$md, $rs1, $rs2", [], InstFormatOther> { + bits<5> md; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = md; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let 
Inst{25} = 0; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +// Store from TileReg (for matrix A and B) +class RVInstAMEStoreTile64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs), (ins TileReg:$ms3, GPR:$rs1, GPR:$rs2), + opcodestr, "$ms3, $rs1, $rs2", [], InstFormatOther> { + bits<5> ms3; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = ms3; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 1; // ls = 1 for store + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +// Store from AccReg (for accumulator C) +class RVInstAMEStoreAcc64 mt_val, bits<3> eew_val, string opcodestr> + : RVInst64<(outs), (ins AccReg:$ms3, GPR:$rs1, GPR:$rs2), + opcodestr, "$ms3, $rs1, $rs2", [], InstFormatOther> { + bits<5> ms3; + bits<5> rs1; + bits<5> rs2; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = ms3; + let Inst{14-12} = 0b001; + let Inst{19-15} = rs1; + let Inst{24-20} = rs2; + let Inst{25} = 1; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b00000; + let Inst{46-44} = eew_val; + let Inst{48-47} = 0b00; + let Inst{50-49} = mt_val; + let Inst{63-51} = 0; +} + +//===----------------------------------------------------------------------===// +// Load matrix A (left operand) into TileReg - mlae*.m +// Syntax: mlae{8|16|32|64}.m tr, (rs1), rs2 +// tr: Destination TileReg +// rs1: Base address (GPR) +// rs2: Row stride in bytes (GPR) +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLAE8_M : RVInstAMELoadTile64<0b01, 0b000, "mlae8.m">; + 
def AME_MLAE16_M : RVInstAMELoadTile64<0b01, 0b001, "mlae16.m">; + def AME_MLAE32_M : RVInstAMELoadTile64<0b01, 0b010, "mlae32.m">; + def AME_MLAE64_M : RVInstAMELoadTile64<0b01, 0b011, "mlae64.m">; +} + +//===----------------------------------------------------------------------===// +// Load matrix B (right operand) into TileReg - mlbe*.m +// Syntax: mlbe{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLBE8_M : RVInstAMELoadTile64<0b10, 0b000, "mlbe8.m">; + def AME_MLBE16_M : RVInstAMELoadTile64<0b10, 0b001, "mlbe16.m">; + def AME_MLBE32_M : RVInstAMELoadTile64<0b10, 0b010, "mlbe32.m">; + def AME_MLBE64_M : RVInstAMELoadTile64<0b10, 0b011, "mlbe64.m">; +} + +//===----------------------------------------------------------------------===// +// Load matrix C (accumulator) into AccReg - mlce*.m +// Syntax: mlce{8|16|32|64}.m acc, (rs1), rs2 +// acc: Destination AccReg (MLEN×AMUL bits) +// Note: eew here refers to the output element width after widening +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 1, mayStore = 0 in { + def AME_MLCE8_M : RVInstAMELoadAcc64<0b00, 0b000, "mlce8.m">; + def AME_MLCE16_M : RVInstAMELoadAcc64<0b00, 0b001, "mlce16.m">; + def AME_MLCE32_M : RVInstAMELoadAcc64<0b00, 0b010, "mlce32.m">; + def AME_MLCE64_M : RVInstAMELoadAcc64<0b00, 0b011, "mlce64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix A from TileReg - msae*.m +// Syntax: msae{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSAE8_M : RVInstAMEStoreTile64<0b01, 0b000, "msae8.m">; + def AME_MSAE16_M : 
RVInstAMEStoreTile64<0b01, 0b001, "msae16.m">; + def AME_MSAE32_M : RVInstAMEStoreTile64<0b01, 0b010, "msae32.m">; + def AME_MSAE64_M : RVInstAMEStoreTile64<0b01, 0b011, "msae64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix B from TileReg - msbe*.m +// Syntax: msbe{8|16|32|64}.m tr, (rs1), rs2 +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSBE8_M : RVInstAMEStoreTile64<0b10, 0b000, "msbe8.m">; + def AME_MSBE16_M : RVInstAMEStoreTile64<0b10, 0b001, "msbe16.m">; + def AME_MSBE32_M : RVInstAMEStoreTile64<0b10, 0b010, "msbe32.m">; + def AME_MSBE64_M : RVInstAMEStoreTile64<0b10, 0b011, "msbe64.m">; +} + +//===----------------------------------------------------------------------===// +// Store matrix C (accumulator) from AccReg - msce*.m +// Syntax: msce{8|16|32|64}.m acc, (rs1), rs2 +// acc: Source AccReg (MLEN×AMUL bits) +// This is the primary store for computation results +//===----------------------------------------------------------------------===// +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 1 in { + def AME_MSCE8_M : RVInstAMEStoreAcc64<0b00, 0b000, "msce8.m">; + def AME_MSCE16_M : RVInstAMEStoreAcc64<0b00, 0b001, "msce16.m">; + def AME_MSCE32_M : RVInstAMEStoreAcc64<0b00, 0b010, "msce32.m">; + def AME_MSCE64_M : RVInstAMEStoreAcc64<0b00, 0b011, "msce64.m">; +} + +//===----------------------------------------------------------------------===// +// AME Extension Pattern Matching +//===----------------------------------------------------------------------===// +// Connect LLVM intrinsics to AME machine instructions +// +// Register Model: +// - TileReg (tr0-tr7): For input matrices A and B +// - AccReg (acc0-acc7): For output/accumulator C +// +// Typical Data Flow for Matrix Multiplication (e.g., int8→int32): +// 1. 
msettilem/n/k: Configure tile dimensions +// 2. mlae8.m tr0, (a0), stride_a: Load matrix A into TileReg +// 3. mlbe8.m tr1, (a1), stride_b: Load matrix B into TileReg +// 4. mqma.b.mm acc0, tr0, tr1: Compute acc0 = acc0 + tr0 × tr1 +// 5. msce32.m acc0, (a2), stride_c: Store AccReg to memory +// +// Note: For quad-widen (int8→int32), input uses 8-bit load (mlae8/mlbe8) +// but output uses 32-bit store (msce32) because AMUL=4 widens the output. +//===----------------------------------------------------------------------===// + +// Configuration instruction patterns +// These use GPR operands and return the actual configured value +let Predicates = [HasBuddyExt] in { + // msettilem - set tile M dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilem GPR:$rs1)), + (AME_MSETTILEM GPR:$rs1)>; + + // msettilen - set tile N dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilen GPR:$rs1)), + (AME_MSETTILEN GPR:$rs1)>; + + // msettilek - set tile K dimension from GPR, returns actual value + def : Pat<(i64 (int_riscv_buddy_msettilek GPR:$rs1)), + (AME_MSETTILEK GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// AME Additional Integer Matrix Operations +//===----------------------------------------------------------------------===// +// Unsigned and mixed-sign matrix multiplication variants +// +// Naming convention: +// - mma: signed × signed +// - mmau: unsigned × unsigned +// - mmasu: signed × unsigned +// - mmaus: unsigned × signed +//===----------------------------------------------------------------------===// + +// Unsigned no-widen multiply-accumulate +class AME_MMAU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let 
opcode_hi = 0b0000011; + let funct5 = 0b01001; // mmau (unsigned) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mmau.b.mm - uint8 × uint8 → uint32 (for unsigned int8) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MMAU_B_MM : AME_MMAU_MM<0b010, 0b000, 0b000, "mmau.b.mm">; + +// Unsigned quad-widen multiply-accumulate +class AME_MQMAU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b01100; // mqmau (unsigned quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmau.b.mm - uint8 × uint8 → uint32 (quad-widen unsigned) +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMAU_B_MM : AME_MQMAU_MM<0b010, 0b000, 0b000, "mqmau.b.mm">; + +// Mixed-sign: signed × unsigned +class AME_MQMASU_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00101; // mqmasu (signed × unsigned quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmasu.b.mm - int8 × uint8 → int32 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMASU_B_MM : AME_MQMASU_MM<0b010, 0b000, 0b000, "mqmasu.b.mm">; + +// 
Mixed-sign: unsigned × signed +class AME_MQMAUS_MM typd_val, bits<3> typ1_val, bits<3> typ2_val, + string opcodestr> + : RVInstAME64<(outs AccReg:$md), + (ins AccReg:$md_in, TileReg:$ms1, TileReg:$ms2), + opcodestr, "$md, $ms1, $ms2"> { + let Constraints = "$md = $md_in"; + let funct6 = 0b000000; + let fp = 0; + let funct3 = 0b100; + let opcode_hi = 0b0000011; + let funct5 = 0b00110; // mqmaus (unsigned × signed quad-widen) + let frm = 0b000; + let bma = 0b00; + let typd = typd_val; + let typ1 = typ1_val; + let typ2 = typ2_val; + let sp = 0; + let sps = 0b00000; +} + +// mqmaus.b.mm - uint8 × int8 → int32 +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MQMAUS_B_MM : AME_MQMAUS_MM<0b010, 0b000, 0b000, "mqmaus.b.mm">; + +//===----------------------------------------------------------------------===// +// AME Zero/Initialize Instructions +//===----------------------------------------------------------------------===// +// mzero - Zero out an accumulation register +// Useful for initializing before accumulation loop + +class AME_MZERO + : RVInst64<(outs AccReg:$md), (ins), + opcodestr, "$md", [], InstFormatOther> { + bits<5> md; + + let Inst{6-0} = 0b0111111; + let Inst{11-7} = md; + let Inst{14-12} = 0b101; // funct3 for arithmetic + let Inst{19-15} = 0b00000; + let Inst{24-20} = 0b00000; + let Inst{25} = 0; + let Inst{31-26} = 0b000000; + let Inst{38-32} = 0b0000011; + let Inst{43-39} = 0b10000; // funct5 for zero + let Inst{63-44} = 0; +} + +let Predicates = [HasBuddyExt], hasSideEffects = 0, mayLoad = 0, mayStore = 0 in +def AME_MZERO_M : AME_MZERO<"mzero.m">; + +//===----------------------------------------------------------------------===// +// AME Intrinsic Pattern Matching +//===----------------------------------------------------------------------===// +// These patterns map LLVM intrinsics to AME machine instructions. 
+// +// Note: AME intrinsics use i64 indices for tile registers instead of +// actual register operands. This is because the MLIR lowering generates +// calls with constant indices that get mapped to physical registers +// at the final code generation stage. +// +// For tile-based operations, the tile register index (0-7) is encoded +// directly into the instruction's register field. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Configuration Instruction Patterns +//===----------------------------------------------------------------------===// + +// msettilem - set tile M dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilem i64:$rs1)), + (AME_MSETTILEM GPR:$rs1)>; +} + +// msettilen - set tile N dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilen i64:$rs1)), + (AME_MSETTILEN GPR:$rs1)>; +} + +// msettilek - set tile K dimension from GPR +let Predicates = [HasBuddyExt] in { + def : Pat<(i64 (int_riscv_buddy_msettilek i64:$rs1)), + (AME_MSETTILEK GPR:$rs1)>; +} + +//===----------------------------------------------------------------------===// +// Pseudo Instructions for Index-Based Operations +//===----------------------------------------------------------------------===// +// These pseudo instructions accept i64 indices and have AsmString for direct +// assembly output. This allows the pseudo instructions to be printed directly +// without needing complex expansion logic. +// +// For load/store/mma instructions, the AsmString hardcodes "acc" and "tr" +// prefixes so that indices 0-7 are printed as acc0-acc7 and tr0-tr7. 
+//===----------------------------------------------------------------------===// + +// Pseudo instruction for msettilemi with i64 immediate +// Output: msettilemi x0, +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILEMI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettilemi\tx0, $imm"; +} + +// Pseudo instruction for msettileni with i64 immediate +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILENI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettileni\tx0, $imm"; +} + +// Pseudo instruction for msettileki with i64 immediate +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MSETTILEKI_PSEUDO : Pseudo<(outs), (ins i64imm:$imm), []> { + let AsmString = "msettileki\tx0, $imm"; +} + +// Pseudo instruction for mzero with accumulator index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MZERO_PSEUDO : Pseudo<(outs), (ins AMEAccIndex:$md), []> { + let AsmString = "mzero.m\tacc$md"; +} + +// Pseudo instruction for mlae32.m with tile index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MLAE32_M_PSEUDO : Pseudo<(outs), (ins AMETileIndex:$md, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "mlae32.m\ttr$md, ($rs1), $rs2"; +} + +// Pseudo instruction for mlbe32.m with tile index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 0, + isCodeGenOnly = 1 in +def AME_MLBE32_M_PSEUDO : Pseudo<(outs), (ins AMETileIndex:$md, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "mlbe32.m\ttr$md, ($rs1), $rs2"; +} + +// Pseudo instruction for msce32.m with accumulator index +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 0, mayStore = 1, + isCodeGenOnly = 1 in +def 
AME_MSCE32_M_PSEUDO : Pseudo<(outs), (ins AMEAccIndex:$ms3, GPR:$rs1, GPR:$rs2), []> { + let AsmString = "msce32.m\tacc$ms3, ($rs1), $rs2"; +} + +// Pseudo instruction for mma.w.mm.tile with indices +// md: AccReg index (0-7), ms1/ms2: TileReg indices (0-7) +let Predicates = [HasBuddyExt], hasSideEffects = 1, mayLoad = 1, mayStore = 1, + isCodeGenOnly = 1 in +def AME_MMA_W_MM_TILE_PSEUDO : Pseudo<(outs), + (ins AMEAccIndex:$md, AMETileIndex:$ms1, AMETileIndex:$ms2), []> { + let AsmString = "mma.w.mm\tacc$md, tr$ms1, tr$ms2"; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Immediate Configuration Instructions +//===----------------------------------------------------------------------===// + +// msettilemi - set tile M dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettilemi timm:$imm), + (AME_MSETTILEMI_PSEUDO timm:$imm)>; +} + +// msettileni - set tile N dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettileni timm:$imm), + (AME_MSETTILENI_PSEUDO timm:$imm)>; +} + +// msettileki - set tile K dimension with immediate +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msettileki timm:$imm), + (AME_MSETTILEKI_PSEUDO timm:$imm)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Zero Instruction +//===----------------------------------------------------------------------===// + +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mzero timm:$md), + (AME_MZERO_PSEUDO timm:$md)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Load Instructions +//===----------------------------------------------------------------------===// + +// mlae32.m - Load 32-bit left matrix A to tile register +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mlae32_m timm:$md, 
iPTR:$rs1, i64:$rs2), + (AME_MLAE32_M_PSEUDO timm:$md, GPR:$rs1, GPR:$rs2)>; +} + +// mlbe32.m - Load 32-bit right matrix B to tile register +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mlbe32_m timm:$md, iPTR:$rs1, i64:$rs2), + (AME_MLBE32_M_PSEUDO timm:$md, GPR:$rs1, GPR:$rs2)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Store Instructions +//===----------------------------------------------------------------------===// + +// msce32.m - Store 32-bit output matrix C from accumulator +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_msce32_m timm:$ms3, iPTR:$rs1, i64:$rs2), + (AME_MSCE32_M_PSEUDO timm:$ms3, GPR:$rs1, GPR:$rs2)>; +} + +//===----------------------------------------------------------------------===// +// Pattern Matching for Tile-Based Matrix Multiply +//===----------------------------------------------------------------------===// + +// mma.w.mm.tile - int32 tile matrix multiply: md = md + ms1 x ms2 +let Predicates = [HasBuddyExt] in { + def : Pat<(int_riscv_buddy_mma_w_mm_tile timm:$md, timm:$ms1, timm:$ms2), + (AME_MMA_W_MM_TILE_PSEUDO timm:$md, timm:$ms1, timm:$ms2)>; +} + +//===----------------------------------------------------------------------===// +// AME Summary +//===----------------------------------------------------------------------===// +// Complete instruction set for basic matrix operations: +// +// Configuration: +// - msettilem, msettilemi: Set M dimension +// - msettilen, msettileni: Set N dimension +// - msettilek, msettileki: Set K dimension +// +// Load (Memory → Register): +// - mlae{8|16|32|64}.m: Load A into TileReg +// - mlbe{8|16|32|64}.m: Load B into TileReg +// - mlce{8|16|32|64}.m: Load C into AccReg +// +// Compute (TileReg × TileReg → AccReg): +// No-widen: +// - mma.{h|w|dw}.mm: Signed int16/32/64 +// Double-widen (AMUL ≥ 2): +// - mwma.{h|w}.mm: int16→int32, int32→int64 +// Quad-widen (AMUL ≥ 4, mmi8i32 
MANDATORY): +// - mqma.b.mm: int8 → int32 (signed) +// - mqmau.b.mm: uint8 → uint32 (unsigned) +// - mqmasu.b.mm: int8 × uint8 → int32 +// - mqmaus.b.mm: uint8 × int8 → int32 +// +// Store (Register → Memory): +// - msae{8|16|32|64}.m: Store TileReg A +// - msbe{8|16|32|64}.m: Store TileReg B +// - msce{8|16|32|64}.m: Store AccReg C +// +// Utility: +// - mzero.m: Zero out AccReg +//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // diff --git a/cmake/BuddyMLIRConfig.cmake.in b/cmake/BuddyMLIRConfig.cmake.in new file mode 100644 index 0000000000..aa5d752918 --- /dev/null +++ b/cmake/BuddyMLIRConfig.cmake.in @@ -0,0 +1,5 @@ +@PACKAGE_INIT@ + +set(BUDDY_BINARY_DIR "${PACKAGE_PREFIX_DIR}/bin") +set(BUDDY_MLIR_INTERFACE_DIR "${PACKAGE_PREFIX_DIR}/@CMAKE_INSTALL_INCLUDEDIR@/buddy-mlir") +set(BUDDY_MLIR_LIB_DIR "${PACKAGE_PREFIX_DIR}/@CMAKE_INSTALL_LIBDIR@") diff --git a/docs/RVVEnvironment.md b/docs/RVVEnvironment.md index 779a4dfa9f..4ed9f93485 100644 --- a/docs/RVVEnvironment.md +++ b/docs/RVVEnvironment.md @@ -30,7 +30,8 @@ $ cd buddy-mlir $ mkdir llvm/build $ cd llvm/build $ cmake -G Ninja ../llvm \ - -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DLLVM_ENABLE_RUNTIMES=openmp \ -DLLVM_TARGETS_TO_BUILD="host;RISCV" \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DOPENMP_ENABLE_LIBOMPTARGET=OFF \ diff --git a/examples/AMEDialect/.gitignore b/examples/AMEDialect/.gitignore new file mode 100644 index 0000000000..59353931ab --- /dev/null +++ b/examples/AMEDialect/.gitignore @@ -0,0 +1,14 @@ +# Build outputs +*.o +*.elf +*.s +*.ll + +# Lowered intermediate files +*-lowered.mlir + +qemu-test/ +runtime_*.c +runtime_*.h +ame_matmul_demo.c +build_and_run.sh \ No newline at end of file diff --git a/examples/AMEDialect/makefile b/examples/AMEDialect/makefile new file mode 100644 index 0000000000..b338c82aa1 --- /dev/null +++ 
b/examples/AMEDialect/makefile @@ -0,0 +1,217 @@ +#!/usr/bin/env bash + +BUDDY_OPT := ../../build/bin/buddy-opt +BUDDY_TRANSLATE := ../../build/bin/buddy-translate +BUDDY_LLC := ../../build/bin/buddy-llc + +.PHONY: all clean help \ + mqma-b-mm mqma-b-mm-lower mqma-b-mm-translate mqma-b-mm-asm \ + mma-w-mm mma-w-mm-lower mma-w-mm-translate mma-w-mm-asm \ + mma-dw-mm mma-dw-mm-lower mma-dw-mm-translate mma-dw-mm-asm \ + mma-complete mma-complete-lower mma-complete-translate mma-complete-asm + +all: mqma-b-mm mma-w-mm mma-dw-mm + + + +#===----------------------------------------------------------------------===# +# mqma.b.mm (int8 quad-widen matrix multiply: int8 -> int32) +#===----------------------------------------------------------------------===# + +# Alias for backward compatibility +mqma-b-mm: mqma-b-mm-lower + +mqma-b-mm-lower: + @echo "=== Lowering AME mqma.b.mm (int8 -> int32 matrix multiply) ===" + ${BUDDY_OPT} mqma-b-mm.mlir \ + --lower-ame \ + -o mqma-b-mm-lowered.mlir + @echo "Lowered MLIR saved to mqma-b-mm-lowered.mlir" + +mqma-b-mm-translate: + @echo "=== Translating mqma.b.mm to LLVM IR ===" + @${BUDDY_OPT} mqma-b-mm.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir \ + -o mqma-b-mm.ll + @echo "LLVM IR saved to mqma-b-mm.ll" + +mqma-b-mm-asm: + @echo "=== Generating assembly for mqma.b.mm ===" + @${BUDDY_OPT} mqma-b-mm.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o mqma-b-mm.s + 
@echo "Assembly saved to mqma-b-mm.s" + +#===----------------------------------------------------------------------===# +# mma.w.mm (int32 matrix multiply-accumulate) +#===----------------------------------------------------------------------===# + +# Alias for backward compatibility +mma-w-mm: mma-w-mm-lower + +mma-w-mm-lower: + @echo "=== Lowering AME mma.w.mm (int32 matrix multiply-accumulate) ===" + ${BUDDY_OPT} mma-w-mm.mlir \ + --lower-ame \ + -o mma-w-mm-lowered.mlir + @echo "Lowered MLIR saved to mma-w-mm-lowered.mlir" + +mma-w-mm-translate: + @echo "=== Translating mma.w.mm to LLVM IR ===" + @${BUDDY_OPT} mma-w-mm.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir \ + -o mma-w-mm.ll + @echo "LLVM IR saved to mma-w-mm.ll" + +mma-w-mm-asm: + @echo "=== Generating assembly for mma.w.mm ===" + @${BUDDY_OPT} mma-w-mm.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o mma-w-mm.s + @echo "Assembly saved to mma-w-mm.s" + +#===----------------------------------------------------------------------===# +# mma.dw.mm (int64 double-widen matrix multiply-accumulate: int32 -> int64) +#===----------------------------------------------------------------------===# + +# Alias for backward compatibility +mma-dw-mm: mma-dw-mm-lower + +mma-dw-mm-lower: + @echo "=== Lowering AME mma.dw.mm (int32 -> int64 matrix multiply-accumulate) ===" + ${BUDDY_OPT} mma-dw-mm.mlir \ + --lower-ame \ + -o mma-dw-mm-lowered.mlir + 
@echo "Lowered MLIR saved to mma-dw-mm-lowered.mlir" + +mma-dw-mm-translate: + @echo "=== Translating mma.dw.mm to LLVM IR ===" + @${BUDDY_OPT} mma-dw-mm.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir \ + -o mma-dw-mm.ll + @echo "LLVM IR saved to mma-dw-mm.ll" + +mma-dw-mm-asm: + @echo "=== Generating assembly for mma.dw.mm ===" + @${BUDDY_OPT} mma-dw-mm.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o mma-dw-mm.s + @echo "Assembly saved to mma-dw-mm.s" + +#===----------------------------------------------------------------------===# +# mma-complete (complete demo: config + load + compute + store) +#===----------------------------------------------------------------------===# + +mma-complete: mma-complete-lower + +mma-complete-lower: + @echo "=== Lowering AME Complete MMA Demo ===" + ${BUDDY_OPT} mma-complete-demo.mlir \ + --lower-ame \ + -o mma-complete-demo-lowered.mlir + @echo "Lowered MLIR saved to mma-complete-demo-lowered.mlir" + +mma-complete-translate: + @echo "=== Translating complete MMA demo to LLVM IR ===" + @${BUDDY_OPT} mma-complete-demo.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir \ + -o mma-complete-demo.ll + @echo "LLVM IR 
saved to mma-complete-demo.ll" + +mma-complete-asm: + @echo "=== Generating assembly for complete MMA demo ===" + @${BUDDY_OPT} mma-complete-demo.mlir \ + --lower-ame \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o mma-complete-demo.s + @echo "Assembly saved to mma-complete-demo.s" + +clean: + rm -f *.mlir.out *.ll *.s *-lowered.mlir diff --git a/examples/AMEDialect/mma-complete-demo.mlir b/examples/AMEDialect/mma-complete-demo.mlir new file mode 100644 index 0000000000..64cb76c3ef --- /dev/null +++ b/examples/AMEDialect/mma-complete-demo.mlir @@ -0,0 +1,68 @@ +// RUN: buddy-opt %s --lower-ame | FileCheck %s + +// =========================================================================== +// Complete Matrix Multiplication Demo using RISC-V Matrix Extension (AME) +// =========================================================================== +// +// This demo shows the complete flow of matrix multiplication: +// 1. Configure tile dimensions (msettilemi, msettileni, msettileki) +// 2. Zero accumulator (mzero) +// 3. Load matrix tiles (mlae32.m, mlbe32.m) +// 4. Execute matrix multiply (mma.w.mm.tile) +// 5. 
Store result (msce32.m) +// +// Matrix dimensions: C[M×N] = A[M×K] × B[K×N] +// Tile dimensions are configured via msettilem/msettilen/msettilek +// +// =========================================================================== + +module { + // Demo: int32 tile-based matrix multiplication + // Uses tile register operations (hardware-level abstraction) + func.func @mma_w_mm_tile_demo(%c_ptr: memref, + %a_ptr: memref, + %b_ptr: memref, + %stride_a: i64, + %stride_b: i64, + %stride_c: i64) { + + // Step 1: Configure tile dimensions + // For a simple 4x4 tile operation + ame.msettilemi 4 // mtilem = 4 (rows of C and A) + ame.msettileni 4 // mtilen = 4 (cols of C and B) + ame.msettileki 8 // mtilek = 8 (cols of A, rows of B) + + // Step 2: Zero the accumulation register (tile register 0) + ame.mzero 0 + + // Step 3: Load matrix A to tile register 0 (shape: mtilem x mtilek = 4x8) + ame.mlae32.m 0, %a_ptr, %stride_a : memref + + // Step 4: Load matrix B to tile register 1 (shape: mtilek x mtilen = 8x4) + ame.mlbe32.m 1, %b_ptr, %stride_b : memref + + // Step 5: Execute matrix multiply: acc0 = acc0 + tile0 x tile1 + ame.mma.w.mm.tile 0, 0, 1 + + // Step 6: Store result from accumulator 0 to memory + ame.msce32.m 0, %c_ptr, %stride_c : memref + + return + } + + // NOTE: High-level mma.w.mm operation (memref abstraction) requires + // additional lowering pass to convert memref to tile operations. + // For now, we only test the tile-level operations which map directly + // to LLVM intrinsics. 
+} + +// Expected lowering for tile-based operations: +// CHECK-LABEL: func.func @mma_w_mm_tile_demo +// CHECK: llvm.call @llvm.riscv.buddy.msettilemi +// CHECK: llvm.call @llvm.riscv.buddy.msettileni +// CHECK: llvm.call @llvm.riscv.buddy.msettileki +// CHECK: llvm.call @llvm.riscv.buddy.mzero +// CHECK: llvm.call @llvm.riscv.buddy.mlae32.m +// CHECK: llvm.call @llvm.riscv.buddy.mlbe32.m +// CHECK: llvm.call @llvm.riscv.buddy.mma.w.mm.tile +// CHECK: llvm.call @llvm.riscv.buddy.msce32.m diff --git a/examples/AMEDialect/mma-dw-mm.mlir b/examples/AMEDialect/mma-dw-mm.mlir new file mode 100644 index 0000000000..48c394b7e6 --- /dev/null +++ b/examples/AMEDialect/mma-dw-mm.mlir @@ -0,0 +1,19 @@ +// RUN: buddy-opt %s --lower-ame | FileCheck %s + +// Demo for mma.dw.mm (int64 matrix multiply-accumulate) +// Performs: md = md + ms1 × ms2 +// where ms1, ms2, and md are all int64 matrices. +// This instruction is useful for high-precision computations. + +module { + func.func @mma_dw_mm_demo(%md: memref<4x4xi64>, + %ms1: memref<4x8xi64>, + %ms2: memref<8x4xi64>) { + // int64 matrix multiply: C[4x4] += A[4x8] × B[8x4] + ame.mma.dw.mm %md, %ms1, %ms2 : memref<4x4xi64>, memref<4x8xi64>, memref<8x4xi64> + return + } + + // CHECK-LABEL: func.func @mma_dw_mm_demo + // CHECK: llvm.call @llvm.riscv.buddy.mma.dw.mm +} diff --git a/examples/AMEDialect/mma-w-mm.mlir b/examples/AMEDialect/mma-w-mm.mlir new file mode 100644 index 0000000000..edf9445cd4 --- /dev/null +++ b/examples/AMEDialect/mma-w-mm.mlir @@ -0,0 +1,18 @@ +// RUN: buddy-opt %s --lower-ame | FileCheck %s + +// Demo for mma.w.mm (int32 matrix multiply-accumulate) +// Performs: md = md + ms1 × ms2 +// where ms1, ms2, and md are all int32 matrices. 
+ +module { + func.func @mma_w_mm_demo(%md: memref<4x4xi32>, + %ms1: memref<4x8xi32>, + %ms2: memref<8x4xi32>) { + // int32 matrix multiply: C[4x4] += A[4x8] × B[8x4] + ame.mma.w.mm %md, %ms1, %ms2 : memref<4x4xi32>, memref<4x8xi32>, memref<8x4xi32> + return + } + + // CHECK-LABEL: func.func @mma_w_mm_demo + // CHECK: llvm.call @llvm.riscv.buddy.mma.w.mm +} diff --git a/examples/AMEDialect/mqma-b-mm.mlir b/examples/AMEDialect/mqma-b-mm.mlir new file mode 100644 index 0000000000..4be06c4153 --- /dev/null +++ b/examples/AMEDialect/mqma-b-mm.mlir @@ -0,0 +1,15 @@ +// RUN: buddy-opt %s --lower-ame | FileCheck %s + +// Test for AME mqma.b.mm operation (int8 quad-widen matrix multiply) +// This performs: C = C + A * B where A and B are int8, C is int32 + +module { + func.func @test_mqma_b_mm(%C: memref<4x4xi32>, %A: memref<4x8xi8>, %B: memref<8x4xi8>) { + // int8 quad-widen matrix multiply: C[4x4] += A[4x8] × B[8x4] + ame.mqma.b.mm %C, %A, %B : memref<4x4xi32>, memref<4x8xi8>, memref<8x4xi8> + return + } + + // CHECK-LABEL: func.func @test_mqma_b_mm + // CHECK: llvm.call @llvm.riscv.buddy.mqma.b.mm +} diff --git a/examples/BuddyBert/README.md b/examples/BuddyBert/README.md index 0e69b5546e..7be5e6dbf1 100644 --- a/examples/BuddyBert/README.md +++ b/examples/BuddyBert/README.md @@ -9,7 +9,7 @@ This example shows how to use Buddy Compiler to compile a BERT model to MLIR cod 2. Set the `PYTHONPATH` environment variable. ```bash -$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +$ export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} ``` 3. 
Build and run the BERT example diff --git a/examples/BuddyDeepSeekR1/CMakeLists.txt b/examples/BuddyDeepSeekR1/CMakeLists.txt index 0d9da7ecae..d7140548f3 100644 --- a/examples/BuddyDeepSeekR1/CMakeLists.txt +++ b/examples/BuddyDeepSeekR1/CMakeLists.txt @@ -47,8 +47,10 @@ if(NOT IS_RVV_PLATFORM) ) add_custom_command( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/forward-bf16.mlir - ${CMAKE_CURRENT_BINARY_DIR}/subgraph0-bf16.mlir + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/forward_prefill-bf16.mlir + ${CMAKE_CURRENT_BINARY_DIR}/subgraph0_prefill-bf16.mlir + ${CMAKE_CURRENT_BINARY_DIR}/forward_decode-bf16.mlir + ${CMAKE_CURRENT_BINARY_DIR}/subgraph0_decode-bf16.mlir ${CMAKE_CURRENT_BINARY_DIR}/arg0-bf16.data COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/import-deepseek-r1.py --output-dir ${CMAKE_CURRENT_BINARY_DIR} @@ -452,13 +454,15 @@ add_custom_command( VERBATIM) add_custom_command( - OUTPUT forward-bf16.o - COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${CMAKE_CURRENT_BINARY_DIR}/forward-bf16.mlir - -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | + OUTPUT forward_prefill-bf16.o + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/forward_prefill-bf16.mlir + -simplify-tosa-reshape | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt + -pass-pipeline ${TOSA_PIPELINE} | ${BUDDY_BINARY_DIR}/buddy-opt -eliminate-empty-tensors -empty-tensor-to-alloc-tensor - -one-shot-bufferize="bufferize-function-boundaries" + -one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS} -expand-strided-metadata -ownership-based-buffer-deallocation -canonicalize @@ -467,14 +471,16 @@ add_custom_command( -cse -canonicalize -optimize-allocation-liveness - -matmul-parallel-vectorization-optimize + -eliminate-memref-copy + -assume-tight-memref-layout + -staticize-memref-layout + -matmul-vectorization-blis -batchmatmul-optimize -convert-linalg-to-affine-loops -affine-parallelize -convert-vector-to-scf -lower-affine - 
-convert-scf-to-openmp - -cse + -convert-scf-to-openmp=${OPENMP_NUM_THREADS} -memref-expand -arith-expand -convert-vector-to-llvm @@ -491,21 +497,23 @@ add_custom_command( -reconcile-unrealized-casts | ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | ${LLVM_TOOLS_BINARY_DIR}/llvm-as | - ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 - -o ${CMAKE_CURRENT_BINARY_DIR}/forward-bf16.o - DEPENDS buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/forward-bf16.mlir - COMMENT "Building forward.o " + ${LLVM_TOOLS_BINARY_DIR}/llc ${LLC_RISCV_ATTRS} -filetype=obj -relocation-model=pic -O3 + -o ${CMAKE_CURRENT_BINARY_DIR}/forward_prefill-bf16.o + DEPENDS buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/forward_prefill-bf16.mlir + COMMENT "Building forward_prefill-bf16.o " VERBATIM) add_custom_command( - OUTPUT subgraph-bf16.o - COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${CMAKE_CURRENT_BINARY_DIR}/subgraph0-bf16.mlir - -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | + OUTPUT subgraph_prefill-bf16.o + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/subgraph0_prefill-bf16.mlir + -simplify-tosa-reshape | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt + -pass-pipeline ${TOSA_PIPELINE} | ${BUDDY_BINARY_DIR}/buddy-opt -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -convert-elementwise-to-linalg - -one-shot-bufferize="bufferize-function-boundaries" + -one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS} -expand-strided-metadata -ownership-based-buffer-deallocation -canonicalize @@ -514,13 +522,17 @@ add_custom_command( -cse -canonicalize -optimize-allocation-liveness - -matmul-parallel-vectorization-optimize + -eliminate-memref-copy + -assume-tight-memref-layout + -staticize-memref-layout + -matmul-vectorization-blis -batchmatmul-optimize + -batchmatmul-transpose-b-vectorization -convert-linalg-to-affine-loops -affine-parallelize -convert-vector-to-scf -lower-affine - 
-convert-scf-to-openmp + -convert-scf-to-openmp=${OPENMP_NUM_THREADS} -cse -memref-expand -arith-expand @@ -538,15 +550,118 @@ add_custom_command( -reconcile-unrealized-casts | ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | ${LLVM_TOOLS_BINARY_DIR}/llvm-as | - ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 - -o ${CMAKE_CURRENT_BINARY_DIR}/subgraph-bf16.o - DEPENDS buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/subgraph0-bf16.mlir - COMMENT "Building subgraph.o " + ${LLVM_TOOLS_BINARY_DIR}/llc ${LLC_RISCV_ATTRS} -filetype=obj -relocation-model=pic -O3 + -o ${CMAKE_CURRENT_BINARY_DIR}/subgraph_prefill-bf16.o + DEPENDS buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/subgraph0_prefill-bf16.mlir + COMMENT "Building subgraph_prefill-bf16.o " + VERBATIM) + +add_custom_command( + OUTPUT forward_decode-bf16.o + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/forward_decode-bf16.mlir + -simplify-tosa-reshape | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt + -pass-pipeline ${TOSA_PIPELINE} | + ${BUDDY_BINARY_DIR}/buddy-opt + -eliminate-empty-tensors + -empty-tensor-to-alloc-tensor + -one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS} + -expand-strided-metadata + -ownership-based-buffer-deallocation + -canonicalize + -buffer-deallocation-simplification + -bufferization-lower-deallocations + -cse + -canonicalize + -optimize-allocation-liveness + -eliminate-memref-copy + -assume-tight-memref-layout + -staticize-memref-layout + -matmul-vectorization-blis + -batchmatmul-optimize + -convert-linalg-to-affine-loops + -affine-parallelize + -convert-vector-to-scf + -lower-affine + -convert-scf-to-openmp=${OPENMP_NUM_THREADS} + -memref-expand + -arith-expand + -convert-vector-to-llvm + -convert-arith-to-llvm + -finalize-memref-to-llvm + -convert-scf-to-cf + -convert-cf-to-llvm + -llvm-request-c-wrappers + -convert-openmp-to-llvm + -convert-arith-to-llvm + -convert-math-to-llvm + -convert-math-to-libm + -convert-func-to-llvm + -reconcile-unrealized-casts | + 
${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc ${LLC_RISCV_ATTRS} -filetype=obj -relocation-model=pic -O3 + -o ${CMAKE_CURRENT_BINARY_DIR}/forward_decode-bf16.o + DEPENDS buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/forward_decode-bf16.mlir + COMMENT "Building forward_decode-bf16.o " + VERBATIM) + +add_custom_command( + OUTPUT subgraph_decode-bf16.o + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${CMAKE_CURRENT_BINARY_DIR}/subgraph0_decode-bf16.mlir + -simplify-tosa-reshape + -simplify-tosa-matmul-scalar | + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt + -pass-pipeline ${TOSA_PIPELINE} | + ${BUDDY_BINARY_DIR}/buddy-opt + -eliminate-empty-tensors + -empty-tensor-to-alloc-tensor + -convert-elementwise-to-linalg + -one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS} + -expand-strided-metadata + -ownership-based-buffer-deallocation + -canonicalize + -buffer-deallocation-simplification + -bufferization-lower-deallocations + -cse + -canonicalize + -optimize-allocation-liveness + -eliminate-memref-copy + -assume-tight-memref-layout + -staticize-memref-layout + -matmul-vectorization-decode=vector-size=32 + -batch-matmul-vectorization-decode=vector-size=128 + -batchmatmul-transpose-b-vectorization=vector-size=16 + -convert-linalg-to-affine-loops + -convert-vector-to-scf + -lower-affine + -convert-scf-to-openmp=${OPENMP_NUM_THREADS} + -cse + -memref-expand + -arith-expand + -convert-vector-to-llvm + -convert-arith-to-llvm + -finalize-memref-to-llvm + -convert-scf-to-cf + -convert-cf-to-llvm + -llvm-request-c-wrappers + -convert-openmp-to-llvm + -convert-arith-to-llvm + -convert-math-to-llvm + -convert-math-to-libm + -convert-func-to-llvm + -reconcile-unrealized-casts | + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc ${LLC_RISCV_ATTRS} -filetype=obj -relocation-model=pic -O3 + -o ${CMAKE_CURRENT_BINARY_DIR}/subgraph_decode-bf16.o + DEPENDS buddy-opt 
${CMAKE_CURRENT_BINARY_DIR}/subgraph0_decode-bf16.mlir + COMMENT "Building subgraph_decode-bf16.o " VERBATIM) add_library(DEEPSEEKR1 STATIC forward_prefill.o subgraph_prefill.o forward_decode.o subgraph_decode.o) add_library(DEEPSEEKR1_F16 STATIC forward_prefill-f16.o subgraph_prefill-f16.o forward_decode-f16.o subgraph_decode-f16.o) -add_library(DEEPSEEKR1_BF16 STATIC forward-bf16.o subgraph-bf16.o) +add_library(DEEPSEEKR1_BF16 STATIC forward_prefill-bf16.o subgraph_prefill-bf16.o forward_decode-bf16.o subgraph_decode-bf16.o) SET_SOURCE_FILES_PROPERTIES( template.o @@ -586,8 +701,8 @@ target_compile_definitions(buddy-deepseek-r1-f16-run PRIVATE DEEPSEEKR1_EXAMPLE_BUILD_PATH="${DEEPSEEKR1_EXAMPLE_BUILD_PATH}/" ) target_compile_definitions(buddy-deepseek-r1-bf16-run PRIVATE - DEEPSEEKR1_EXAMPLE_PATH="${DEEPSEEKR1_EXAMPLE_PATH}" - DEEPSEEKR1_EXAMPLE_BUILD_PATH="${DEEPSEEKR1_EXAMPLE_BUILD_PATH}" + DEEPSEEKR1_EXAMPLE_PATH="${DEEPSEEKR1_EXAMPLE_PATH}/" + DEEPSEEKR1_EXAMPLE_BUILD_PATH="${DEEPSEEKR1_EXAMPLE_BUILD_PATH}/" ) target_compile_definitions(buddy-deepseek-r1-cli PRIVATE DEEPSEEKR1_EXAMPLE_PATH="${DEEPSEEKR1_EXAMPLE_PATH}/" diff --git a/examples/BuddyDeepSeekR1/README.md b/examples/BuddyDeepSeekR1/README.md index 30d50d773a..0a12ba1744 100644 --- a/examples/BuddyDeepSeekR1/README.md +++ b/examples/BuddyDeepSeekR1/README.md @@ -53,14 +53,14 @@ $ ninja check-buddy Set the `PYTHONPATH` environment variable. Make sure that the `PYTHONPATH` variable includes the directory of LLVM/MLIR python bindings and the directory of Buddy MLIR python packages. 
```bash -$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +$ export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} // For example: // Navigate to your buddy-mlir/build directory $ cd buddy-mlir/build $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 3. Set model environment variable. diff --git a/examples/BuddyDeepSeekR1/buddy-deepseek-r1-bf16-main.cpp b/examples/BuddyDeepSeekR1/buddy-deepseek-r1-bf16-main.cpp index 1a554af7c9..45b93e8a14 100644 --- a/examples/BuddyDeepSeekR1/buddy-deepseek-r1-bf16-main.cpp +++ b/examples/BuddyDeepSeekR1/buddy-deepseek-r1-bf16-main.cpp @@ -1,4 +1,4 @@ -//===- buddy-deepseek-r1-bf16-main.cpp -----------------------------------===// +//===- buddy-deepseek-r1-bf16-main.cpp ------------------------------------===// // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -14,28 +14,172 @@ // //===----------------------------------------------------------------------===// +#include #include #include #include #include #include #include +#include #include #include #include #include using namespace buddy; - +double total_time = 0; constexpr size_t ParamsSize = 1777088064; constexpr size_t MaxVocabSize = 151936; -constexpr size_t MaxTokenLength = 40; +constexpr size_t MaxTokenLength = 1024; + +constexpr size_t NUM_LAYERS = 56; +constexpr size_t HiddenSize = 128; +constexpr size_t HeadNum = 2; + +struct MemRefContainer { + + MemRef kv0; + MemRef kv1; + MemRef kv2; + MemRef kv3; + MemRef kv4; + MemRef kv5; + MemRef kv6; + MemRef kv7; + MemRef kv8; + MemRef kv9; + MemRef kv10; + MemRef kv11; + MemRef kv12; + MemRef kv13; + MemRef kv14; + MemRef kv15; + MemRef kv16; + MemRef kv17; + MemRef kv18; + MemRef kv19; + MemRef kv20; + MemRef kv21; + MemRef kv22; + MemRef kv23; + MemRef kv24; + MemRef kv25; + MemRef kv26; + MemRef kv27; + MemRef kv28; + MemRef kv29; + MemRef kv30; + MemRef kv31; + MemRef kv32; + MemRef kv33; + MemRef kv34; + MemRef kv35; + MemRef kv36; + MemRef kv37; + MemRef kv38; + MemRef kv39; + MemRef kv40; + MemRef kv41; + MemRef kv42; + MemRef kv43; + MemRef kv44; + MemRef kv45; + MemRef kv46; + MemRef kv47; + MemRef kv48; + MemRef kv49; + MemRef kv50; + MemRef kv51; + MemRef kv52; + MemRef kv53; + MemRef kv54; + MemRef kv55; + + MemRef logits; + + std::array *, 56> kv_ptrs; + + MemRefContainer( + MemRef k0, MemRef k1, MemRef k2, + MemRef k3, MemRef k4, MemRef k5, + MemRef k6, MemRef k7, MemRef k8, + MemRef k9, MemRef k10, MemRef k11, + MemRef k12, MemRef k13, MemRef k14, + MemRef k15, MemRef k16, MemRef k17, + MemRef k18, MemRef k19, MemRef k20, + MemRef k21, MemRef k22, MemRef k23, + MemRef k24, MemRef k25, MemRef k26, + MemRef k27, MemRef k28, MemRef k29, + MemRef k30, MemRef k31, MemRef k32, + MemRef k33, MemRef k34, MemRef k35, + MemRef k36, MemRef k37, MemRef k38, + MemRef k39, MemRef k40, MemRef k41, + MemRef k42, 
MemRef k43, MemRef k44, + MemRef k45, MemRef k46, MemRef k47, + MemRef k48, MemRef k49, MemRef k50, + MemRef k51, MemRef k52, MemRef k53, + MemRef k54, MemRef k55, MemRef l) + : kv0(k0), kv1(k1), kv2(k2), kv3(k3), kv4(k4), kv5(k5), kv6(k6), kv7(k7), + kv8(k8), kv9(k9), kv10(k10), kv11(k11), kv12(k12), kv13(k13), kv14(k14), + kv15(k15), kv16(k16), kv17(k17), kv18(k18), kv19(k19), kv20(k20), + kv21(k21), kv22(k22), kv23(k23), kv24(k24), kv25(k25), kv26(k26), + kv27(k27), kv28(k28), kv29(k29), kv30(k30), kv31(k31), kv32(k32), + kv33(k33), kv34(k34), kv35(k35), kv36(k36), kv37(k37), kv38(k38), + kv39(k39), kv40(k40), kv41(k41), kv42(k42), kv43(k43), kv44(k44), + kv45(k45), kv46(k46), kv47(k47), kv48(k48), kv49(k49), kv50(k50), + kv51(k51), kv52(k52), kv53(k53), kv54(k54), kv55(k55), logits(l), + kv_ptrs{&kv0, &kv1, &kv2, &kv3, &kv4, &kv5, &kv6, &kv7, + + &kv8, &kv9, &kv10, &kv11, &kv12, &kv13, &kv14, &kv15, + + &kv16, &kv17, &kv18, &kv19, &kv20, &kv21, &kv22, &kv23, + + &kv24, &kv25, &kv26, &kv27, &kv28, &kv29, &kv30, &kv31, + + &kv32, &kv33, &kv34, &kv35, &kv36, &kv37, &kv38, &kv39, + + &kv40, &kv41, &kv42, &kv43, &kv44, &kv45, &kv46, &kv47, + + &kv48, &kv49, &kv50, &kv51, &kv52, &kv53, &kv54, &kv55} {} +}; /// Declare DeepSeekR1 forward function. 
-extern "C" void _mlir_ciface_forward(MemRef *result, - MemRef *arg0, - Text *arg1, - MemRef *arg2); +extern "C" void _mlir_ciface_forward_prefill(MemRefContainer *result, + MemRef *arg0, + Text *arg1); + +extern "C" void _mlir_ciface_forward_decode( + MemRefContainer *result, MemRef *arg0, + MemRef *arg1, MemRef *arg2, + MemRef *kv0, MemRef *kv1, + MemRef *kv2, MemRef *kv3, + MemRef *kv4, MemRef *kv5, + MemRef *kv6, MemRef *kv7, + MemRef *kv8, MemRef *kv9, + MemRef *kv10, MemRef *kv11, + MemRef *kv12, MemRef *kv13, + MemRef *kv14, MemRef *kv15, + MemRef *kv16, MemRef *kv17, + MemRef *kv18, MemRef *kv19, + MemRef *kv20, MemRef *kv21, + MemRef *kv22, MemRef *kv23, + MemRef *kv24, MemRef *kv25, + MemRef *kv26, MemRef *kv27, + MemRef *kv28, MemRef *kv29, + MemRef *kv30, MemRef *kv31, + MemRef *kv32, MemRef *kv33, + MemRef *kv34, MemRef *kv35, + MemRef *kv36, MemRef *kv37, + MemRef *kv38, MemRef *kv39, + MemRef *kv40, MemRef *kv41, + MemRef *kv42, MemRef *kv43, + MemRef *kv44, MemRef *kv45, + MemRef *kv46, MemRef *kv47, + MemRef *kv48, MemRef *kv49, + MemRef *kv50, MemRef *kv51, + MemRef *kv52, MemRef *kv53, + MemRef *kv54, MemRef *kv55); // ----------------------------------------------------------------------------- // Helper Functions @@ -54,6 +198,7 @@ void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } /// Print information for each iteration. 
void printIterInfo(size_t iterIdx, std::string str, double time) { + total_time += time; std::cout << "\033[32;1m[Iteration " << iterIdx << "] \033[0m"; std::cout << "Token: " << str << " | " << "Time: " << time << "s" << std::endl; @@ -105,11 +250,10 @@ void loadParameters(const std::string ¶mFilePath, // bf16 to f32 conversion function (Brain floating point -> single precision) float decode_bf16(uint16_t h) { - // BF16 format: 1 sign bit, 8 exponent bits, 7 mantissa bits - // F32 format: 1 sign bit, 8 exponent bits, 23 mantissa bits - // BF16 is essentially F32 with the lower 16 bits truncated uint32_t f32_bits = static_cast(h) << 16; - return *reinterpret_cast(&f32_bits); + float out; + std::memcpy(&out, &f32_bits, sizeof(out)); + return out; } int findMaxIndex(const uint16_t *start, size_t length) { @@ -125,6 +269,28 @@ int findMaxIndex(const uint16_t *start, size_t length) { return maxIdx; } +void copy_kv_by_cache_position_block(const MemRefContainer &prefill, + MemRefContainer &decode, + int cache_position) { + constexpr int num_kv = 56; + int copy_len = std::min(cache_position, (int)MaxTokenLength); + + for (int k = 0; k < num_kv; ++k) { + auto &src = *prefill.kv_ptrs[k]; + auto &dst = *decode.kv_ptrs[k]; + + for (int h = 0; h < (int)HeadNum; ++h) { + size_t bytes_to_copy = + static_cast(copy_len) * HiddenSize * sizeof(uint16_t); + + uint16_t *src_ptr = src.getData() + h * MaxTokenLength * HiddenSize; + uint16_t *dst_ptr = dst.getData() + h * MaxTokenLength * HiddenSize; + + std::memcpy(dst_ptr, src_ptr, bytes_to_copy); + } + } +} + // ----------------------------------------------------------------------------- // DeepSeekR1 Inference Main Entry // ----------------------------------------------------------------------------- @@ -138,8 +304,8 @@ int main() { /// Define directories of vacabulary and parameter file. 
std::string deepSeekR1Dir = DEEPSEEKR1_EXAMPLE_PATH; std::string deepSeekR1BuildDir = DEEPSEEKR1_EXAMPLE_BUILD_PATH; - const std::string vocabDir = deepSeekR1Dir + "/vocab.txt"; - const std::string paramsDir = deepSeekR1BuildDir + "/arg0-bf16.data"; + const std::string vocabDir = deepSeekR1Dir + "vocab.txt"; + const std::string paramsDir = deepSeekR1BuildDir + "arg0-bf16.data"; /// Get user message. std::string inputStr; @@ -151,43 +317,180 @@ int main() { // - Output container. // - Parameters container. Text outputContainer; - MemRef resultContainer({1, 40, 151936}); - Text inputContainer(inputStr); - MemRef paramsContainer({ParamsSize}); - MemRef attention_mask({1, 40}, 0); + Text inputContainerPrefill(inputStr); + MemRef inputContainerDecode({1, 1}, 0LL); + MemRef ParamsContainer({ParamsSize}); + MemRef cachePosition({1}, 0LL); + + MemRef logits_prefill({1, MaxTokenLength, MaxVocabSize}); + + MemRef kv0({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv1({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv2({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv3({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv4({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv5({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv6({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv7({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + + MemRef kv8({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv9({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv10({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv11({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv12({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv13({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv14({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv15({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + + MemRef kv16({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv17({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv18({1, 
HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv19({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv20({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv21({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv22({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv23({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + + MemRef kv24({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv25({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv26({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv27({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv28({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv29({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv30({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv31({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + + MemRef kv32({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv33({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv34({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv35({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv36({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv37({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv38({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv39({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + + MemRef kv40({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv41({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv42({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv43({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv44({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv45({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv46({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv47({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + + MemRef kv48({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv49({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv50({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv51({1, HeadNum, 
MaxTokenLength, HiddenSize}, 0); + MemRef kv52({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv53({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv54({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + MemRef kv55({1, HeadNum, MaxTokenLength, HiddenSize}, 0); + + MemRefContainer prefillResultContainer( + kv0, kv1, kv2, kv3, kv4, kv5, kv6, kv7, kv8, kv9, kv10, kv11, kv12, kv13, + kv14, kv15, kv16, kv17, kv18, kv19, kv20, kv21, kv22, kv23, kv24, kv25, + kv26, kv27, kv28, kv29, kv30, kv31, kv32, kv33, kv34, kv35, kv36, kv37, + kv38, kv39, kv40, kv41, kv42, kv43, kv44, kv45, kv46, kv47, kv48, kv49, + kv50, kv51, kv52, kv53, kv54, kv55, logits_prefill); + MemRefContainer *ptrPrefillResultContainer = &prefillResultContainer; /// Fill data into containers // - Input: register vocabulary and tokenize the input string. // - Output: register vocabulary. // - Parameters: load parameters from the `arg0` file into the container. - tokenizeInput(vocabDir, inputContainer); - for (int i = 0; i < (int)inputContainer.getTokenCnt(); i++) { - attention_mask.getData()[i] = 1; - } + tokenizeInput(vocabDir, inputContainerPrefill); outputContainer.loadVocab(vocabDir); - loadParameters(paramsDir, paramsContainer); + loadParameters(paramsDir, ParamsContainer); /// Run DeepSeekR1 Inference // - Perform the forward function. // - Find and append the generated token. // - Continue iterating until the terminal condition is met. 
- int generateLen = MaxTokenLength - inputContainer.getTokenCnt(); - for (int i = 0; i < generateLen; i++) { + + double prefillTokensPerSec = 0.0; + const auto inferenceStart = std::chrono::high_resolution_clock::now(); + _mlir_ciface_forward_prefill(ptrPrefillResultContainer, &ParamsContainer, + &inputContainerPrefill); + const auto inferenceEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration inferenceTime = + inferenceEnd - inferenceStart; + + int tokenIndex = inputContainerPrefill.getTokenCnt() - 1; + const uint16_t *startPtr = + ptrPrefillResultContainer->logits.getData() + tokenIndex * MaxVocabSize; + int maxIndex = findMaxIndex(startPtr, MaxVocabSize); + std::string tok = inputContainerPrefill.getStr(maxIndex); + printIterInfo(0, tok, inferenceTime.count() / 1000); + const double prefillSeconds = inferenceTime.count() / 1000.0; + if (prefillSeconds > 0.0) { + prefillTokensPerSec = static_cast(MaxTokenLength) / prefillSeconds; + } + inputContainerDecode.getData()[0] = (long long)maxIndex; + outputContainer.appendTokenIdx(maxIndex); + + MemRef logits_decode({1, 1, MaxVocabSize}); + + MemRefContainer decodeResultContainer( + kv0, kv1, kv2, kv3, kv4, kv5, kv6, kv7, kv8, kv9, kv10, kv11, kv12, kv13, + kv14, kv15, kv16, kv17, kv18, kv19, kv20, kv21, kv22, kv23, kv24, kv25, + kv26, kv27, kv28, kv29, kv30, kv31, kv32, kv33, kv34, kv35, kv36, kv37, + kv38, kv39, kv40, kv41, kv42, kv43, kv44, kv45, kv46, kv47, kv48, kv49, + kv50, kv51, kv52, kv53, kv54, kv55, logits_decode); + + MemRefContainer *ptrDecodeResultContainer = &decodeResultContainer; + + copy_kv_by_cache_position_block(prefillResultContainer, decodeResultContainer, + inputContainerPrefill.getTokenCnt()); + + cachePosition.getData()[0] = inputContainerPrefill.getTokenCnt(); + int generateLen = MaxTokenLength - inputContainerPrefill.getTokenCnt(); + double decodeTimeAccumMs = 0.0; + size_t decodeTokens = 0; + for (int i = 1; i <= generateLen; i++) { const auto inferenceStart = 
std::chrono::high_resolution_clock::now(); - // Execute the forward pass of the model. - _mlir_ciface_forward(&resultContainer, ¶msContainer, &inputContainer, - &attention_mask); + _mlir_ciface_forward_decode( + ptrDecodeResultContainer, &ParamsContainer, &inputContainerDecode, + &cachePosition, &ptrDecodeResultContainer->kv0, + &ptrDecodeResultContainer->kv1, &ptrDecodeResultContainer->kv2, + &ptrDecodeResultContainer->kv3, &ptrDecodeResultContainer->kv4, + &ptrDecodeResultContainer->kv5, &ptrDecodeResultContainer->kv6, + &ptrDecodeResultContainer->kv7, &ptrDecodeResultContainer->kv8, + &ptrDecodeResultContainer->kv9, &ptrDecodeResultContainer->kv10, + &ptrDecodeResultContainer->kv11, &ptrDecodeResultContainer->kv12, + &ptrDecodeResultContainer->kv13, &ptrDecodeResultContainer->kv14, + &ptrDecodeResultContainer->kv15, &ptrDecodeResultContainer->kv16, + &ptrDecodeResultContainer->kv17, &ptrDecodeResultContainer->kv18, + &ptrDecodeResultContainer->kv19, &ptrDecodeResultContainer->kv20, + &ptrDecodeResultContainer->kv21, &ptrDecodeResultContainer->kv22, + &ptrDecodeResultContainer->kv23, &ptrDecodeResultContainer->kv24, + &ptrDecodeResultContainer->kv25, &ptrDecodeResultContainer->kv26, + &ptrDecodeResultContainer->kv27, &ptrDecodeResultContainer->kv28, + &ptrDecodeResultContainer->kv29, &ptrDecodeResultContainer->kv30, + &ptrDecodeResultContainer->kv31, &ptrDecodeResultContainer->kv32, + &ptrDecodeResultContainer->kv33, &ptrDecodeResultContainer->kv34, + &ptrDecodeResultContainer->kv35, &ptrDecodeResultContainer->kv36, + &ptrDecodeResultContainer->kv37, &ptrDecodeResultContainer->kv38, + &ptrDecodeResultContainer->kv39, &ptrDecodeResultContainer->kv40, + &ptrDecodeResultContainer->kv41, &ptrDecodeResultContainer->kv42, + &ptrDecodeResultContainer->kv43, &ptrDecodeResultContainer->kv44, + &ptrDecodeResultContainer->kv45, &ptrDecodeResultContainer->kv46, + &ptrDecodeResultContainer->kv47, &ptrDecodeResultContainer->kv48, + &ptrDecodeResultContainer->kv49, 
&ptrDecodeResultContainer->kv50, + &ptrDecodeResultContainer->kv51, &ptrDecodeResultContainer->kv52, + &ptrDecodeResultContainer->kv53, &ptrDecodeResultContainer->kv54, + &ptrDecodeResultContainer->kv55); const auto inferenceEnd = std::chrono::high_resolution_clock::now(); const std::chrono::duration inferenceTime = inferenceEnd - inferenceStart; + decodeTimeAccumMs += inferenceTime.count(); + decodeTokens += 1; // Determine the generated token. - int tokenIndex = inputContainer.getTokenCnt() - 1; - const uint16_t *startPtr = - resultContainer.getData() + tokenIndex * MaxVocabSize; - int maxIndex = findMaxIndex(startPtr, MaxVocabSize); - std::string tok = inputContainer.getStr(maxIndex); + const uint16_t *startPtr = ptrDecodeResultContainer->logits.getData(); + maxIndex = findMaxIndex(startPtr, MaxVocabSize); + std::string tok = inputContainerPrefill.getStr(maxIndex); // Print the generated token and inference time. printIterInfo(i, tok, inferenceTime.count() / 1000); @@ -196,14 +499,23 @@ int main() { break; } // Append the generated token into the input and output container. - inputContainer.appendTokenIdx(maxIndex); - attention_mask.getData()[MaxTokenLength - generateLen + i] = 1; + inputContainerDecode.getData()[0] = maxIndex; outputContainer.appendTokenIdx(maxIndex); - free(resultContainer.release()); + cachePosition.getData()[0] += 1; } + const double decodeSeconds = decodeTimeAccumMs / 1000.0; + const double decodeTokensPerSec = + decodeSeconds > 0.0 ? 
static_cast(decodeTokens) / decodeSeconds + : 0.0; + /// Print the final result - std::cout << "\n\033[33;1m[Input]\033[0m " << inputStr << std::endl; + std::cout << "\n\033[33;1m[Total time]\033[0m " << total_time << std::endl; + std::cout << "\033[33;1m[Prefilling]\033[0m " << prefillTokensPerSec + << " tokens/s" << std::endl; + std::cout << "\033[33;1m[Decoding]\033[0m " << decodeTokensPerSec + << " tokens/s" << std::endl; + std::cout << "\033[33;1m[Input]\033[0m " << inputStr << std::endl; std::cout << "\033[33;1m[Output]\033[0m " << outputContainer.revertDeepSeekR1() << std::endl; diff --git a/examples/BuddyDeepSeekR1/import-deepseek-r1.py b/examples/BuddyDeepSeekR1/import-deepseek-r1.py index 3761ba8a53..adce100aab 100755 --- a/examples/BuddyDeepSeekR1/import-deepseek-r1.py +++ b/examples/BuddyDeepSeekR1/import-deepseek-r1.py @@ -91,30 +91,18 @@ model.config.use_cache = False # Initialize Dynamo Compiler with specific configurations as an importer. -if args.precision == "f16": - dynamo_compiler_prefill = DynamoCompiler( - primary_registry=tosa.ops_registry, - aot_autograd_decomposition=inductor_decomp, - func_name="forward_prefill", - ) - - dynamo_compiler_decode = DynamoCompiler( - primary_registry=tosa.ops_registry, - aot_autograd_decomposition=inductor_decomp, - func_name="forward_decode", - ) -else: - dynamo_compiler_prefill = DynamoCompiler( - primary_registry=tosa.ops_registry, - aot_autograd_decomposition=inductor_decomp, - func_name="forward_prefill", - ) +prefill_func_name = "forward_prefill" +dynamo_compiler_prefill = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, + func_name=prefill_func_name, +) - dynamo_compiler_decode = DynamoCompiler( - primary_registry=tosa.ops_registry, - aot_autograd_decomposition=inductor_decomp, - func_name="forward_decode", - ) +dynamo_compiler_decode = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=inductor_decomp, + 
func_name="forward_decode", +) # Import the model into MLIR module and parameters. with torch.no_grad(): @@ -311,13 +299,13 @@ print(driver_decode.construct_main_graph(True), file=module_file) elif args.precision == "bf16": with open( - os.path.join(output_dir, "subgraph0-bf16.mlir"), "w" + os.path.join(output_dir, "subgraph0_prefill-bf16.mlir"), "w" ) as module_file: - print(driver.subgraphs[0]._imported_module, file=module_file) + print(driver_prefill.subgraphs[0]._imported_module, file=module_file) with open( - os.path.join(output_dir, "forward-bf16.mlir"), "w" + os.path.join(output_dir, "forward_prefill-bf16.mlir"), "w" ) as module_file: - print(driver.construct_main_graph(True), file=module_file) + print(driver_prefill.construct_main_graph(True), file=module_file) # Convert BF16 parameters to float32 first, then to numpy all_param = numpy.concatenate( [param.detach().float().numpy().reshape([-1]) for param in params] @@ -327,6 +315,15 @@ all_param.astype(numpy.float32).tobytes(), dtype=numpy.uint16 )[1::2] all_param_bf16.tofile(os.path.join(output_dir, "arg0-bf16.data")) + + with open( + os.path.join(output_dir, "subgraph0_decode-bf16.mlir"), "w" + ) as module_file: + print(driver_decode.subgraphs[0]._imported_module, file=module_file) + with open( + os.path.join(output_dir, "forward_decode-bf16.mlir"), "w" + ) as module_file: + print(driver_decode.construct_main_graph(True), file=module_file) else: with open( os.path.join(output_dir, "subgraph0_prefill.mlir"), "w" diff --git a/examples/BuddyGraph/README.md b/examples/BuddyGraph/README.md index d7b977f57e..7ae8bd48aa 100644 --- a/examples/BuddyGraph/README.md +++ b/examples/BuddyGraph/README.md @@ -13,11 +13,11 @@ (buddy)$ cd buddy-mlir/build (buddy)$ export BUDDY_MLIR_BUILD_DIR=$PWD (buddy)$ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -(buddy)$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +(buddy)$ export 
PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 3. Run the Examples ``` (buddy)$ cd examples/BuddyGraph (buddy)$ python import-dynamo-break.py -``` \ No newline at end of file +``` diff --git a/examples/BuddyJIT/README.md b/examples/BuddyJIT/README.md index b1b928526a..1cfc7967d6 100644 --- a/examples/BuddyJIT/README.md +++ b/examples/BuddyJIT/README.md @@ -51,14 +51,14 @@ ninja check-buddy And then set the `PYTHONPATH` to where the packages are built. ```bash -export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} # For example: # Navigate to your buddy-mlir/build directory cd buddy-mlir/build export BUDDY_MLIR_BUILD_DIR=$PWD export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` ## Matrix Multiply Demo @@ -100,4 +100,4 @@ This example will also validate the correctness of result and print the running $ python examples/BuddyJIT/pytorch_matrix_multiplication.py Is MLIR equal to Torch? True MLIR time: 27291.54ms, Torch time: 325.96ms -``` \ No newline at end of file +``` diff --git a/examples/BuddyLeNet/README.md b/examples/BuddyLeNet/README.md index 1552845177..8f84ce991c 100644 --- a/examples/BuddyLeNet/README.md +++ b/examples/BuddyLeNet/README.md @@ -74,7 +74,7 @@ Make sure you are in the build directory. 
```bash $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` ### Build and run the LeNet example diff --git a/examples/BuddyLlama/README.md b/examples/BuddyLlama/README.md index 4416ef3a60..4d0299e4f7 100644 --- a/examples/BuddyLlama/README.md +++ b/examples/BuddyLlama/README.md @@ -67,14 +67,14 @@ $ ninja check-buddy Set the `PYTHONPATH` environment variable. Make sure that the `PYTHONPATH` variable includes the directory of LLVM/MLIR python bindings and the directory of Buddy MLIR python packages. ``` -$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +$ export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} // For example: // Navigate to your buddy-mlir/build directory $ cd buddy-mlir/build $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 6. Build and run LLaMA example diff --git a/examples/BuddyMobileNetV3/README.md b/examples/BuddyMobileNetV3/README.md index 0d50dd34e4..91645b4939 100644 --- a/examples/BuddyMobileNetV3/README.md +++ b/examples/BuddyMobileNetV3/README.md @@ -29,7 +29,7 @@ Make sure you are in the build directory. ```bash $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 3. 
Build and run the MobileNetV3 example @@ -40,4 +40,3 @@ $ ninja buddy-mobilenetv3-run $ cd bin $ ./buddy-mobilenetv3-run ``` - diff --git a/examples/BuddyQwen3/README.md b/examples/BuddyQwen3/README.md index 4731d0a9aa..0c913abc29 100644 --- a/examples/BuddyQwen3/README.md +++ b/examples/BuddyQwen3/README.md @@ -53,14 +53,14 @@ $ ninja check-buddy Set the `PYTHONPATH` environment variable. Make sure that the `PYTHONPATH` variable includes the directory of LLVM/MLIR python bindings and the directory of Buddy MLIR python packages. ```bash -$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +$ export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} // For example: // Navigate to your buddy-mlir/build directory $ cd buddy-mlir/build $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 3. Set model environment variable. diff --git a/examples/BuddyResNet18/README.md b/examples/BuddyResNet18/README.md index 5374dc4707..fa3cf665c9 100644 --- a/examples/BuddyResNet18/README.md +++ b/examples/BuddyResNet18/README.md @@ -48,14 +48,14 @@ $ ninja check-buddy Set the `PYTHONPATH` environment variable. Make sure that the `PYTHONPATH` variable includes the directory of LLVM/MLIR python bindings and the directory of Buddy MLIR python packages. 
``` -$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +$ export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} // For example: // Navigate to your buddy-mlir/build directory $ cd buddy-mlir/build $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 4. Build and run the ResNet example @@ -65,4 +65,4 @@ $ cmake -G Ninja .. -DBUDDY_RESNET_EXAMPLES=ON $ ninja buddy-resnet-run $ cd bin $ ./buddy-resnet-run -``` \ No newline at end of file +``` diff --git a/examples/BuddyStableDiffusion/README.md b/examples/BuddyStableDiffusion/README.md index 876db96e0b..e1fae9d4c1 100644 --- a/examples/BuddyStableDiffusion/README.md +++ b/examples/BuddyStableDiffusion/README.md @@ -46,14 +46,14 @@ $ ninja check-buddy Set the `PYTHONPATH` environment variable. Make sure that the `PYTHONPATH` variable includes the directory of LLVM/MLIR python bindings and the directory of Buddy MLIR python packages. ``` -$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +$ export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} // For example: // Navigate to your buddy-mlir/build directory $ cd buddy-mlir/build $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 6. 
Build and run Stable Diffusion example diff --git a/examples/BuddyTransformer/CMakeLists.txt b/examples/BuddyTransformer/CMakeLists.txt index 3b9809012c..6e03d826c0 100644 --- a/examples/BuddyTransformer/CMakeLists.txt +++ b/examples/BuddyTransformer/CMakeLists.txt @@ -76,7 +76,6 @@ add_custom_command( -convert-vector-to-scf -lower-affine -convert-scf-to-openmp - # -func-bufferize-dynamic-offset -cse -memref-expand -arith-expand @@ -122,7 +121,6 @@ add_custom_command( -convert-vector-to-scf -lower-affine -convert-scf-to-openmp - # -func-bufferize-dynamic-offset -cse -memref-expand -arith-expand @@ -239,7 +237,6 @@ add_custom_command( -convert-vector-to-scf -lower-affine -convert-scf-to-openmp - # -func-bufferize-dynamic-offset -cse -memref-expand -arith-expand @@ -302,7 +299,7 @@ set_target_properties(transformer-runner-staged PROPERTIES # Include Buddy headers target_include_directories(transformer-runner-staged PRIVATE - ${CMAKE_SOURCE_DIR}/frontend/Interfaces + ${BUDDY_MLIR_INTERFACE_DIR} ) # Define paths for the staged executable @@ -375,7 +372,7 @@ set_target_properties(transformer-runner PROPERTIES # Include Buddy headers for one-step executable target_include_directories(transformer-runner PRIVATE - ${CMAKE_SOURCE_DIR}/frontend/Interfaces + ${BUDDY_MLIR_INTERFACE_DIR} ) # Define paths for the one-step executable @@ -424,7 +421,7 @@ set_target_properties(transformer-runner-timed PROPERTIES # Include Buddy headers for timed executable target_include_directories(transformer-runner-timed PRIVATE - ${CMAKE_SOURCE_DIR}/frontend/Interfaces + ${BUDDY_MLIR_INTERFACE_DIR} ) # Define paths for the timed executable diff --git a/examples/BuddyTransformer/README.md b/examples/BuddyTransformer/README.md index b3fbad9149..53ed936822 100644 --- a/examples/BuddyTransformer/README.md +++ b/examples/BuddyTransformer/README.md @@ -33,7 +33,7 @@ pip install -r requirements.txt ```bash export BUDDY_MLIR_BUILD_DIR=/path/to/buddy-mlir/build export 
LLVM_MLIR_BUILD_DIR=/path/to/buddy-mlir/llvm/build -export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` ### Build Commands diff --git a/examples/BuddyWhisper/README.md b/examples/BuddyWhisper/README.md index 680fb34ce9..90f1005e71 100644 --- a/examples/BuddyWhisper/README.md +++ b/examples/BuddyWhisper/README.md @@ -53,14 +53,14 @@ $ ninja check-buddy Set the `PYTHONPATH` environment variable. Make sure that the `PYTHONPATH` variable includes the directory of LLVM/MLIR python bindings and the directory of Buddy MLIR python packages. ```bash -$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} +$ export PYTHONPATH=/path-to-buddy-mlir/build/python_packages:${PYTHONPATH} // For example: // Navigate to your buddy-mlir/build directory $ cd buddy-mlir/build $ export BUDDY_MLIR_BUILD_DIR=$PWD $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build -$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +$ export PYTHONPATH=${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` 3. Set model environment variable. diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index df4aea4079..c5370a9faf 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -70,8 +70,8 @@ set(BUDDY_EXAMPLES_DEPENDS mlir-runner ) -if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES) - list(APPEND BUDDY_TEST_DEPENDS BuddyMLIRPythonModules) +if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES AND MLIR_ENABLE_BINDINGS_PYTHON) + list(APPEND BUDDY_TEST_DEPENDS python-package-buddy-mlir) endif() add_lit_testsuite(check-examples "Checking the buddy-mlir examples..." 
diff --git a/examples/ConvOpt/CMakeLists.txt b/examples/ConvOpt/CMakeLists.txt index d8806e27a5..70dd78457e 100644 --- a/examples/ConvOpt/CMakeLists.txt +++ b/examples/ConvOpt/CMakeLists.txt @@ -15,7 +15,7 @@ message(STATUS "Spliting size: ${SPLITING_SIZE}") add_custom_command(OUTPUT conv2d.o - COMMAND ${CMAKE_BINARY_DIR}/bin/buddy-opt ${BUDDY_EXAMPLES_DIR}/ConvOpt/conv2d.mlir -conv-vectorization="strip-mining=${SPLITING_SIZE}" -lower-affine -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm -llvm-request-c-wrappers --convert-arith-to-llvm -convert-func-to-llvm --convert-cf-to-llvm -reconcile-unrealized-casts | + COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_EXAMPLES_DIR}/ConvOpt/conv2d.mlir -conv-vectorization="strip-mining=${SPLITING_SIZE}" -lower-affine -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm -llvm-request-c-wrappers --convert-arith-to-llvm -convert-func-to-llvm --convert-cf-to-llvm -reconcile-unrealized-casts | ${LLVM_TOOLS_BINARY_DIR}/mlir-translate --mlir-to-llvmir | ${LLVM_TOOLS_BINARY_DIR}/llc -mtriple=${BUDDY_TARGET_TRIPLE} -mattr=${BUDDY_OPT_ATTR} --filetype=obj -o ${BUDDY_BINARY_DIR}/../examples/ConvOpt/conv2d.o DEPENDS buddy-opt) diff --git a/examples/IMEDialect/.gitignore b/examples/IMEDialect/.gitignore index f4d715c5a2..e85e8d7660 100644 --- a/examples/IMEDialect/.gitignore +++ b/examples/IMEDialect/.gitignore @@ -10,3 +10,4 @@ vmadot-variants.mlir *.o *.s *_test +test_instructions.sh diff --git a/examples/IMEDialect/README.md b/examples/IMEDialect/README.md index dda47d3421..8467201b00 100644 --- a/examples/IMEDialect/README.md +++ b/examples/IMEDialect/README.md @@ -14,7 +14,10 @@ This document provides a comprehensive guide for using the IME (Integrated Matri The IMEDialect example folder contains the following key files: **MLIR Source Files:** -- **`vmadot.mlir`, `vmadotu.mlir`, `vmadotsu.mlir`, `vmadotus.mlir`, `vfmadot.mlir`**: Minimal MLIR files demonstrating each IME operation. 
Use these to verify assembly code generation. +- **`vmadot.mlir`, `vmadotu.mlir`, `vmadotsu.mlir`, `vmadotus.mlir`, `vfmadot.mlir`**: Minimal MLIR files demonstrating each basic IME operation. Use these to verify assembly code generation. +- **`vmadot1.mlir`, `vmadot2.mlir`, `vmadot3.mlir`**: Fixed sliding-window integer operations with slide=1, 2, 3. +- **`vfmadot1.mlir`, `vfmadot2.mlir`, `vfmadot3.mlir`**: Fixed sliding-window floating-point operations with slide=1, 2, 3. +- **`vmadotn.mlir`, `vmadotnu.mlir`, `vmadotnsu.mlir`, `vmadotnus.mlir`, `vfmadotn.mlir`**: Minimal MLIR files for dynamic sliding-window operations with runtime slide parameter. - **`vmadot_print_test.mlir`, `vmadotu_print_test.mlir`, etc.**: Extended test MLIR files with additional code for printing input/output matrices on hardware. Use these for hardware validation with visible results. **Runtime and Build Files:** @@ -84,22 +87,49 @@ Use this option to verify that IME operations are correctly lowered to RISC-V as **Generate Lowered MLIR:** ```bash -make vmadot-lower +make vmadot-lower # Basic integer operations +make vmadot1-lower # Fixed sliding-window (slide=1) +make vmadotn-lower # Dynamic sliding-window +make vfmadot-lower # Floating-point +make vfmadot1-lower # Floating-point fixed sliding-window +make vfmadotn-lower # Floating-point dynamic sliding-window ``` This generates `log.mlir` containing the lowered representation. **Generate LLVM IR:** ```bash make vmadot-translate +make vmadot1-translate +make vmadotn-translate +# ... and similarly for other variants ``` This generates `log.ll` containing LLVM IR code. 
**Generate Assembly:** ```bash -make vmadot-asm +make vmadot-asm # Basic: vmadot, vmadotu, vmadotsu, vmadotus +make vmadot1-asm # Fixed slide: vmadot1, vmadot2, vmadot3 +make vmadotn-asm # Dynamic slide: vmadotn, vmadotnu, vmadotnsu, vmadotnus +make vfmadot-asm # Float basic +make vfmadot1-asm # Float fixed slide: vfmadot1, vfmadot2, vfmadot3 +make vfmadotn-asm # Float dynamic slide ``` This generates `log.s` containing RISC-V assembly. You can inspect this file to verify correct IME instruction generation. +**Available Make Targets Summary:** + +| Category | Instructions | Make Targets | +|----------|-------------|--------------| +| Basic Integer | vmadot, vmadotu, vmadotsu, vmadotus | `vmadot{,u,su,us}-{lower,translate,asm,run}` | +| Fixed Slide Integer (signed) | vmadot1, vmadot2, vmadot3 | `vmadot{1,2,3}-{lower,translate,asm,run}` | +| Fixed Slide Integer (unsigned) | vmadot1u, vmadot2u, vmadot3u | `vmadot{1,2,3}u-{lower,translate,asm}` | +| Fixed Slide Integer (signed×unsigned) | vmadot1su, vmadot2su, vmadot3su | `vmadot{1,2,3}su-{lower,translate,asm}` | +| Fixed Slide Integer (unsigned×signed) | vmadot1us, vmadot2us, vmadot3us | `vmadot{1,2,3}us-{lower,translate,asm}` | +| Dynamic Slide Integer | vmadotn, vmadotnu, vmadotnsu, vmadotnus | `vmadotn{,u,su,us}-{lower,translate,asm}` | +| Basic Float | vfmadot | `vfmadot-{lower,translate,asm,run}` | +| Fixed Slide Float | vfmadot1, vfmadot2, vfmadot3 | `vfmadot{1,2,3}-{lower,translate,asm,run}` | +| Dynamic Slide Float | vfmadotn | `vfmadotn-{lower,translate,asm,run}` | + #### Option B: Build Hardware Test Executables (Full Test) Use this option to build executables that can run on SpacemiT hardware with printed output for verification. This uses the `*_print_test.mlir` files which include matrix printing functionality. 
@@ -122,6 +152,13 @@ This will generate executable binaries: `vmadot_test`, `vmadotu_test`, `vmadotsu > **Note**: The `-run` targets require the SpacemiT cross-compiler (`riscv64-unknown-linux-gnu-gcc`) to be in your PATH. +> **Toolchain Compatibility Note**: The SpacemiT toolchain (v1.1.2) currently only supports basic IME instructions (`vmadot`, `vmadotu`, `vmadotsu`, `vmadotus`). The following instructions will produce "unrecognized opcode" errors when assembled with SpacemiT's GNU as: +> - Fixed sliding-window: `vmadot1`, `vmadot2`, `vmadot3`, `vfmadot1`, `vfmadot2`, `vfmadot3` +> - Dynamic sliding-window: `vmadotn`, `vmadotnu`, `vmadotnsu`, `vmadotnus`, `vfmadotn` +> - Floating-point basic: `vfmadot` +> +> For these unsupported instructions, you can still verify assembly generation using the `-asm` targets, but hardware testing requires SpacemiT to update their toolchain. + --- ## Run on Hardware @@ -154,7 +191,19 @@ Each test will output: - Expected and computed results - Verification status (PASS/FAIL) -**Note**: `vfmadot` floating-point instruction tests cannot yet be executed as standalone binaries and are available only for assembly generation and inspection. +**Current Hardware Test Status:** + +| Instruction | Assembly Gen | Hardware Test | Notes | +|-------------|-------------|---------------|-------| +| vmadot | ✅ | ✅ | Fully working | +| vmadotu | ✅ | ✅ | Fully working | +| vmadotsu | ✅ | ✅ | Fully working | +| vmadotus | ✅ | ✅ | Fully working | +| vmadot1/2/3 | ✅ | ⏳ | Waiting for funct7 fix | +| vmadotn/nu/nsu/nus | ✅ | ⏳ | Waiting for toolchain support | +| vfmadot | ✅ | ⏳ | Waiting for toolchain support | +| vfmadot1/2/3 | ✅ | ⏳ | Waiting for toolchain support | +| vfmadotn | ✅ | ⏳ | Waiting for toolchain support | @@ -215,39 +264,187 @@ vfmadot vd, vs1, vs2 # vd(C) += vs1(A) × vs2(B) - `vs2` (source 2): Input matrix B - Result is stored in a single register (different from integer instructions) +#### 3. 
Fixed Sliding-Window Instructions (vmadot1/2/3, vfmadot1/2/3) + +These instructions have a **fixed slide amount encoded in the instruction**. They read from VS1 and VS1+1 (64 elements for int8, forming a 2M×K matrix), then slide by a fixed number of rows (1, 2, or 3) to select an M×K submatrix. + +**Integer Fixed Sliding-Window Instructions:** + +| Category | Instructions | Operand A Type | Operand B Type | Accumulator Type | Description | +|----------|-------------|---|---|---|---| +| slide-1 | `vmadot1` | int8 | int8 | int32 | signed × signed | +| | `vmadot1u` | uint8 | uint8 | int32 | unsigned × unsigned | +| | `vmadot1su` | int8 | uint8 | int32 | signed × unsigned | +| | `vmadot1us` | uint8 | int8 | int32 | unsigned × signed | +| slide-2 | `vmadot2` | int8 | int8 | int32 | signed × signed | +| | `vmadot2u` | uint8 | uint8 | int32 | unsigned × unsigned | +| | `vmadot2su` | int8 | uint8 | int32 | signed × unsigned | +| | `vmadot2us` | uint8 | int8 | int32 | unsigned × signed | +| slide-3 | `vmadot3` | int8 | int8 | int32 | signed × signed | +| | `vmadot3u` | uint8 | uint8 | int32 | unsigned × unsigned | +| | `vmadot3su` | int8 | uint8 | int32 | signed × unsigned | +| | `vmadot3us` | uint8 | int8 | int32 | unsigned × signed | + +**Floating-Point Fixed Sliding-Window Instructions:** + +| Instruction | Operand A Type | Operand B Type | Accumulator Type | Description | +|-------------|---|---|---|---| +| `vfmadot1` | fp16 | fp16 | fp16 | floating-point, slide=1 | +| `vfmadot2` | fp16 | fp16 | fp16 | floating-point, slide=2 | +| `vfmadot3` | fp16 | fp16 | fp16 | floating-point, slide=3 | + +**Assembly Format**: +```assembly +# Integer slide-1 +vmadot1 vd, vs1, vs2 # vd(C) += slide(vs1, 1)(A) × vs2(B) - signed × signed +vmadot1u vd, vs1, vs2 # unsigned × unsigned +vmadot1su vd, vs1, vs2 # signed × unsigned +vmadot1us vd, vs1, vs2 # unsigned × signed + +# Integer slide-2 +vmadot2 vd, vs1, vs2 +vmadot2u vd, vs1, vs2 +vmadot2su vd, vs1, vs2 +vmadot2us vd, vs1, vs2 + +# 
Integer slide-3 +vmadot3 vd, vs1, vs2 +vmadot3u vd, vs1, vs2 +vmadot3su vd, vs1, vs2 +vmadot3us vd, vs1, vs2 + +# Floating-point +vfmadot1 vd, vs1, vs2 +vfmadot2 vd, vs1, vs2 +vfmadot3 vd, vs1, vs2 +``` + +**Register Constraints**: +- `vd` (destination): Target register for result matrix C +- `vs1` (source 1): Input matrix A (reads VS1 and VS1+1, 64 elements for sliding) +- `vs2` (source 2): Input matrix B + +> **Note**: The `vmadot1/2/3` instructions currently have a funct7 encoding issue (see SpacemiT IME issue #2 in the vendor's tracker). Assembly generation works, but machine code may be incorrect until resolved. + +#### 4. Dynamic Sliding-Window Instructions (vmadotn/vfmadotn) + +These instructions support a **dynamic slide parameter** passed at runtime, allowing flexible row selection from the source matrix. The sliding window reads from VS1 and VS1+1 (64 elements for int8, forming a 2M×K matrix), then slides by n rows to select an M×K submatrix. + +| Instruction | Operand A Type | Operand B Type | Accumulator Type | Description | +|-------------|---|---|---|---| +| `vmadotn` | int8 | int8 | int32 | signed × signed with dynamic slide | +| `vmadotnu` | uint8 | uint8 | int32 | unsigned × unsigned with dynamic slide | +| `vmadotnsu` | int8 | uint8 | int32 | signed × unsigned with dynamic slide | +| `vmadotnus` | uint8 | int8 | int32 | unsigned × signed with dynamic slide | +| `vfmadotn` | fp16 | fp16 | fp16 | floating-point with dynamic slide | + +**Assembly Format**: +```assembly +vmadotn vd, vs1, vs2, rs1 # vd(C) += slide(vs1, rs1)(A) × vs2(B) +vmadotnu vd, vs1, vs2, rs1 +vmadotnsu vd, vs1, vs2, rs1 +vmadotnus vd, vs1, vs2, rs1 +vfmadotn vd, vs1, vs2, rs1 +``` + +**Register Constraints**: +- `vd` (destination): Target register for result matrix C +- `vs1` (source 1): Input matrix A (reads VS1 and VS1+1, 64 elements for sliding) +- `vs2` (source 2): Input matrix B +- `rs1` (scalar): Slide amount (0-3 for int8 with VLEN=256) + +**Sliding Window 
Mechanism**: +``` +For VLEN=256, int8: +- VS1 loads 64 elements (8 rows × 8 cols) +- slide=0: use rows [0,1,2,3] +- slide=1: use rows [1,2,3,4] +- slide=2: use rows [2,3,4,5] +- slide=3: use rows [3,4,5,6] +``` + ### MLIR Operation Syntax All IME operations in MLIR follow this pattern: ```mlir +// Basic operations (without slide) ime.vmadot %accumulator, %matrix_a, %matrix_b : memref<...>, memref<...>, memref<...> + +// Dynamic sliding-window operations (with slide parameter) +ime.vmadotn %accumulator, %matrix_a, %matrix_b, %slide : memref<...>, memref<...>, memref<...>, i64 ``` Where: - `%accumulator`: Destination memref (2D, element type matches result type) - `%matrix_a`: Left operand matrix A memref (2D) - `%matrix_b`: Right operand matrix B memref (2D) +- `%slide`: (for vmadotn variants) Slide amount as i64 scalar ### Example MLIR Code Complete example showing IME operations in MLIR: ```mlir +// Basic integer matrix multiply-accumulate func.func @vmadot_example(%arg0: memref<4x4xi32>, %arg1: memref<4x8xi8>, %arg2: memref<8x4xi8>) { // Perform matrix multiply-accumulate: arg0 += arg1 × arg2 ime.vmadot %arg0, %arg1, %arg2 : memref<4x4xi32>, memref<4x8xi8>, memref<8x4xi8> return } +// Unsigned integer version func.func @vmadotu_example(%arg0: memref<4x4xi32>, %arg1: memref<4x8xui8>, %arg2: memref<8x4xui8>) { - // Unsigned integer version ime.vmadotu %arg0, %arg1, %arg2 : memref<4x4xi32>, memref<4x8xui8>, memref<8x4xui8> return } +// Floating-point version func.func @vfmadot_example(%arg0: memref<4x4xf16>, %arg1: memref<4x4xf16>, %arg2: memref<4x4xf16>) { - // Floating-point version ime.vfmadot %arg0, %arg1, %arg2 : memref<4x4xf16>, memref<4x4xf16>, memref<4x4xf16> return } + +// Fixed sliding-window: A is 8x8 (2M×K), slide=1 selects rows [1,2,3,4] +// Signed × Signed variants +func.func @vmadot1_example(%arg0: memref<4x4xi32>, %arg1: memref<8x8xi8>, %arg2: memref<8x4xi8>) { + ime.vmadot1 %arg0, %arg1, %arg2 : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xi8> + 
return +} + +// Unsigned × Unsigned variants +func.func @vmadot1u_example(%arg0: memref<4x4xi32>, %arg1: memref<8x8xui8>, %arg2: memref<8x4xui8>) { + ime.vmadot1u %arg0, %arg1, %arg2 : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xui8> + return +} + +// Signed × Unsigned variants +func.func @vmadot1su_example(%arg0: memref<4x4xi32>, %arg1: memref<8x8xi8>, %arg2: memref<8x4xui8>) { + ime.vmadot1su %arg0, %arg1, %arg2 : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xui8> + return +} + +// Unsigned × Signed variants +func.func @vmadot1us_example(%arg0: memref<4x4xi32>, %arg1: memref<8x8xui8>, %arg2: memref<8x4xi8>) { + ime.vmadot1us %arg0, %arg1, %arg2 : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xi8> + return +} + +// Fixed sliding-window floating-point +func.func @vfmadot1_example(%arg0: memref<4x4xf16>, %arg1: memref<8x4xf16>, %arg2: memref<4x4xf16>) { + ime.vfmadot1 %arg0, %arg1, %arg2 : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + return +} + +// Dynamic sliding-window: slide amount passed at runtime +func.func @vmadotn_example(%arg0: memref<4x4xi32>, %arg1: memref<8x8xi8>, %arg2: memref<8x4xi8>, %slide: i64) { + ime.vmadotn %arg0, %arg1, %arg2, %slide : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xi8>, i64 + return +} + +// Dynamic sliding-window floating-point +func.func @vfmadotn_example(%arg0: memref<4x4xf16>, %arg1: memref<8x4xf16>, %arg2: memref<4x4xf16>, %slide: i64) { + ime.vfmadotn %arg0, %arg1, %arg2, %slide : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16>, i64 + return +} ``` diff --git a/examples/IMEDialect/build_all_tests.sh b/examples/IMEDialect/build_all_tests.sh index 61d6e92e7c..41d8cf5def 100755 --- a/examples/IMEDialect/build_all_tests.sh +++ b/examples/IMEDialect/build_all_tests.sh @@ -58,6 +58,21 @@ build_test "vmadotus" "vmadotus_print_test.mlir" "runtime_vmadotus.c" # Build vfmadot (floating-point) - uses different print function # build_test "vfmadot" "vfmadot_print_test.mlir" "runtime_vfmadot.c" +# Build vmadotn (signed x signed, 
dynamic slide) +build_test "vmadotn" "vmadotn_print_test.mlir" "runtime_vmadotn.c" + +# Build vmadotnu (unsigned x unsigned, dynamic slide) +build_test "vmadotnu" "vmadotnu_print_test.mlir" "runtime_vmadotnu.c" + +# Build vmadotnsu (signed x unsigned, dynamic slide) +build_test "vmadotnsu" "vmadotnsu_print_test.mlir" "runtime_vmadotnsu.c" + +# Build vmadotnus (unsigned x signed, dynamic slide) +build_test "vmadotnus" "vmadotnus_print_test.mlir" "runtime_vmadotnus.c" + +# Build vfmadotn (floating-point, dynamic slide) - uses different print function +# build_test "vfmadotn" "vfmadotn_print_test.mlir" "runtime_vfmadotn.c" + echo "" echo "All builds complete. Executables:" ls -la *_test 2>/dev/null || echo "No test executables found" diff --git a/examples/IMEDialect/linalg-to-ime-conv.mlir b/examples/IMEDialect/linalg-to-ime-conv.mlir new file mode 100644 index 0000000000..c4adbbbf32 --- /dev/null +++ b/examples/IMEDialect/linalg-to-ime-conv.mlir @@ -0,0 +1,103 @@ +// RUN: buddy-opt %s -lower-linalg-to-ime | FileCheck %s + +// Test Conv2D NHWC layout to IME lowering +// Input: [1, 12, 12, 16] - batch=1, H=12, W=12, IC=16 +// Filter: [3, 3, 16, 8] - FH=3, FW=3, IC=16, OC=8 +// Output: [1, 10, 10, 8] - batch=1, OH=10, OW=10, OC=8 + +// CHECK-LABEL: func.func @conv2d_nhwc_hwcf +func.func @conv2d_nhwc_hwcf(%input: memref<1x12x12x16xi8>, + %filter: memref<3x3x16x8xi8>, + %output: memref<1x10x10x8xi32>) { + // CHECK: scf.for + // CHECK: scf.for + // CHECK: scf.for + // CHECK: ime.vmadot + linalg.conv_2d_nhwc_hwcf ins(%input, %filter : memref<1x12x12x16xi8>, memref<3x3x16x8xi8>) + outs(%output : memref<1x10x10x8xi32>) + return +} + +// Test smaller convolution with exact tile sizes +// Input: [1, 6, 6, 8] - batch=1, H=6, W=6, IC=8 +// Filter: [3, 3, 8, 4] - FH=3, FW=3, IC=8, OC=4 +// Output: [1, 4, 4, 4] - batch=1, OH=4, OW=4, OC=4 + +// CHECK-LABEL: func.func @conv2d_nhwc_hwcf_small +func.func @conv2d_nhwc_hwcf_small(%input: memref<1x6x6x8xi8>, + %filter: 
memref<3x3x8x4xi8>, + %output: memref<1x4x4x4xi32>) { + // CHECK: scf.for + // CHECK: ime.vmadot + linalg.conv_2d_nhwc_hwcf ins(%input, %filter : memref<1x6x6x8xi8>, memref<3x3x8x4xi8>) + outs(%output : memref<1x4x4x4xi32>) + return +} + +// Test Conv2D NCHW layout to IME lowering +// Input: [1, 16, 12, 12] - batch=1, IC=16, H=12, W=12 +// Filter: [8, 16, 3, 3] - OC=8, IC=16, FH=3, FW=3 +// Output: [1, 8, 10, 10] - batch=1, OC=8, OH=10, OW=10 + +// CHECK-LABEL: func.func @conv2d_nchw_fchw +func.func @conv2d_nchw_fchw(%input: memref<1x16x12x12xi8>, + %filter: memref<8x16x3x3xi8>, + %output: memref<1x8x10x10xi32>) { + // CHECK: scf.for + // CHECK: scf.for + // CHECK: ime.vmadot + linalg.conv_2d_nchw_fchw ins(%input, %filter : memref<1x16x12x12xi8>, memref<8x16x3x3xi8>) + outs(%output : memref<1x8x10x10xi32>) + return +} + +// Test with stride > 1 +// Input: [1, 14, 14, 16] - batch=1, H=14, W=14, IC=16 +// Filter: [3, 3, 16, 8] - FH=3, FW=3, IC=16, OC=8 +// Output: [1, 6, 6, 8] - with stride=2: OH=(14-3)/2+1=6 + +// CHECK-LABEL: func.func @conv2d_nhwc_hwcf_stride2 +func.func @conv2d_nhwc_hwcf_stride2(%input: memref<1x14x14x16xi8>, + %filter: memref<3x3x16x8xi8>, + %output: memref<1x6x6x8xi32>) { + // CHECK: scf.for + // CHECK: ime.vmadot2 + linalg.conv_2d_nhwc_hwcf {strides = dense<2> : tensor<2xi64>} + ins(%input, %filter : memref<1x14x14x16xi8>, memref<3x3x16x8xi8>) + outs(%output : memref<1x6x6x8xi32>) + return +} + +// Test with stride=3 +// Input: [1, 16, 16, 8] - batch=1, H=16, W=16, IC=8 +// Filter: [3, 3, 8, 4] - FH=3, FW=3, IC=8, OC=4 +// Output: [1, 5, 5, 4] - with stride=3: OH=(16-3)/3+1=5 + +// CHECK-LABEL: func.func @conv2d_nhwc_hwcf_stride3 +func.func @conv2d_nhwc_hwcf_stride3(%input: memref<1x16x16x8xi8>, + %filter: memref<3x3x8x4xi8>, + %output: memref<1x5x5x4xi32>) { + // CHECK: scf.for + // CHECK: ime.vmadot3 + linalg.conv_2d_nhwc_hwcf {strides = dense<3> : tensor<2xi64>} + ins(%input, %filter : memref<1x16x16x8xi8>, memref<3x3x8x4xi8>) + 
outs(%output : memref<1x5x5x4xi32>) + return +} + +// Test with stride=4 (uses vmadotn with dynamic slide) +// Input: [1, 20, 20, 8] - batch=1, H=20, W=20, IC=8 +// Filter: [3, 3, 8, 4] - FH=3, FW=3, IC=8, OC=4 +// Output: [1, 5, 5, 4] - with stride=4: OH=(20-3)/4+1=5 + +// CHECK-LABEL: func.func @conv2d_nhwc_hwcf_stride4 +func.func @conv2d_nhwc_hwcf_stride4(%input: memref<1x20x20x8xi8>, + %filter: memref<3x3x8x4xi8>, + %output: memref<1x5x5x4xi32>) { + // CHECK: scf.for + // CHECK: ime.vmadotn + linalg.conv_2d_nhwc_hwcf {strides = dense<4> : tensor<2xi64>} + ins(%input, %filter : memref<1x20x20x8xi8>, memref<3x3x8x4xi8>) + outs(%output : memref<1x5x5x4xi32>) + return +} diff --git a/examples/IMEDialect/linalg-to-ime-matmul-boundary-func.mlir b/examples/IMEDialect/linalg-to-ime-matmul-boundary-func.mlir new file mode 100644 index 0000000000..d983b610a3 --- /dev/null +++ b/examples/IMEDialect/linalg-to-ime-matmul-boundary-func.mlir @@ -0,0 +1,16 @@ +// RUN: buddy-opt %s -lower-linalg-to-ime | FileCheck %s +// +// Test case: C[7x5] = A[7x10] * B[10x5] with non-aligned dimensions +// For int8: TILE_M=4, TILE_K=8, TILE_N=4 +// This file can also be compiled and linked with runtime_matmul_boundary.c + +// CHECK-LABEL: func.func @matmul_boundary +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_boundary(%A: memref<7x10xi8>, %B: memref<10x5xi8>, + %C: memref<7x5xi32>) { + linalg.matmul ins(%A, %B : memref<7x10xi8>, memref<10x5xi8>) + outs(%C : memref<7x5xi32>) + return +} diff --git a/examples/IMEDialect/linalg-to-ime-matmul-boundary.mlir b/examples/IMEDialect/linalg-to-ime-matmul-boundary.mlir new file mode 100644 index 0000000000..050bbce8ce --- /dev/null +++ b/examples/IMEDialect/linalg-to-ime-matmul-boundary.mlir @@ -0,0 +1,116 @@ +// RUN: buddy-opt %s -lower-linalg-to-ime | FileCheck %s +// +// This file tests the lowering of linalg.matmul to ime.vmadot operations +// with boundary handling for non-aligned dimensions. 
+// + +// ============================================================================= +// Test case 1: Non-aligned M dimension (M=6, not divisible by TILE_M=4) +// ============================================================================= +// CHECK-LABEL: func.func @matmul_i8_boundary_M +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i8_boundary_M(%A: memref<6x8xi8>, %B: memref<8x4xi8>, + %C: memref<6x4xi32>) { + linalg.matmul ins(%A, %B : memref<6x8xi8>, memref<8x4xi8>) + outs(%C : memref<6x4xi32>) + return +} + +// ============================================================================= +// Test case 2: Non-aligned N dimension (N=6, not divisible by TILE_N=4) +// ============================================================================= +// CHECK-LABEL: func.func @matmul_i8_boundary_N +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i8_boundary_N(%A: memref<4x8xi8>, %B: memref<8x6xi8>, + %C: memref<4x6xi32>) { + linalg.matmul ins(%A, %B : memref<4x8xi8>, memref<8x6xi8>) + outs(%C : memref<4x6xi32>) + return +} + +// ============================================================================= +// Test case 3: Non-aligned K dimension (K=10, not divisible by TILE_K=8) +// ============================================================================= +// CHECK-LABEL: func.func @matmul_i8_boundary_K +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i8_boundary_K(%A: memref<4x10xi8>, %B: memref<10x4xi8>, + %C: memref<4x4xi32>) { + linalg.matmul ins(%A, %B : memref<4x10xi8>, memref<10x4xi8>) + outs(%C : memref<4x4xi32>) + return +} + +// ============================================================================= +// Test case 4: Multiple non-aligned dimensions (M=7, N=5, K=10) +// ============================================================================= +// CHECK-LABEL: func.func @matmul_i8_boundary_all +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot 
+func.func @matmul_i8_boundary_all(%A: memref<7x10xi8>, %B: memref<10x5xi8>, + %C: memref<7x5xi32>) { + linalg.matmul ins(%A, %B : memref<7x10xi8>, memref<10x5xi8>) + outs(%C : memref<7x5xi32>) + return +} + +// ============================================================================= +// Test case 5: Larger matrices with boundary (M=18, N=14, K=25) +// ============================================================================= +// CHECK-LABEL: func.func @matmul_i8_large_boundary +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i8_large_boundary(%A: memref<18x25xi8>, %B: memref<25x14xi8>, + %C: memref<18x14xi32>) { + linalg.matmul ins(%A, %B : memref<18x25xi8>, memref<25x14xi8>) + outs(%C : memref<18x14xi32>) + return +} + +// ============================================================================= +// Test case 6: Small matrices (all dimensions < tile size) +// ============================================================================= +// M=3 < 4, N=2 < 4, K=5 < 8: Still uses vmadot with padding +// CHECK-LABEL: func.func @matmul_i8_small_all_scalar +// CHECK: ime.vmadot +func.func @matmul_i8_small_all_scalar(%A: memref<3x5xi8>, %B: memref<5x2xi8>, + %C: memref<3x2xi32>) { + linalg.matmul ins(%A, %B : memref<3x5xi8>, memref<5x2xi8>) + outs(%C : memref<3x2xi32>) + return +} + +// ============================================================================= +// Test case 7: int16 with boundary (tile size is 4x4x4 for int16) +// ============================================================================= +// CHECK-LABEL: func.func @matmul_i16_boundary +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i16_boundary(%A: memref<6x7xi16>, %B: memref<7x5xi16>, + %C: memref<6x5xi32>) { + linalg.matmul ins(%A, %B : memref<6x7xi16>, memref<7x5xi16>) + outs(%C : memref<6x5xi32>) + return +} + +// ============================================================================= +// Test case 8: Edge case - 
exactly one tile with boundary +// ============================================================================= +// CHECK-LABEL: func.func @matmul_i8_one_tile_plus +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i8_one_tile_plus(%A: memref<5x8xi8>, %B: memref<8x4xi8>, + %C: memref<5x4xi32>) { + linalg.matmul ins(%A, %B : memref<5x8xi8>, memref<8x4xi8>) + outs(%C : memref<5x4xi32>) + return +} diff --git a/examples/IMEDialect/linalg-to-ime-matmul.mlir b/examples/IMEDialect/linalg-to-ime-matmul.mlir new file mode 100644 index 0000000000..8943373e2e --- /dev/null +++ b/examples/IMEDialect/linalg-to-ime-matmul.mlir @@ -0,0 +1,57 @@ +// RUN: buddy-opt %s -lower-linalg-to-ime | FileCheck %s +// +// This file tests the lowering of linalg.matmul to ime.vmadot operations. +// + +// Test case 1: Simple int8 matmul (4x8) * (8x4) = (4x4) +// This maps to ime.vmadot with single iteration loops +// CHECK-LABEL: func.func @matmul_i8_4x8x4 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i8_4x8x4(%A: memref<4x8xi8>, %B: memref<8x4xi8>, %C: memref<4x4xi32>) { + linalg.matmul ins(%A, %B : memref<4x8xi8>, memref<8x4xi8>) + outs(%C : memref<4x4xi32>) + return +} + +// Test case 2: Larger int8 matmul requiring tiling +// (16x32) * (32x16) = (16x16) +// Should generate nested loops with ime.vmadot +// CHECK-LABEL: func.func @matmul_i8_16x32x16 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i8_16x32x16(%A: memref<16x32xi8>, %B: memref<32x16xi8>, %C: memref<16x16xi32>) { + linalg.matmul ins(%A, %B : memref<16x32xi8>, memref<32x16xi8>) + outs(%C : memref<16x16xi32>) + return +} + +// Test case 3: int16 matmul (tile size 4x4x4) +// (4x4) * (4x4) = (4x4) +// CHECK-LABEL: func.func @matmul_i16_4x4x4 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i16_4x4x4(%A: memref<4x4xi16>, %B: memref<4x4xi16>, %C: 
memref<4x4xi32>) { + linalg.matmul ins(%A, %B : memref<4x4xi16>, memref<4x4xi16>) + outs(%C : memref<4x4xi32>) + return +} + +// Test case 4: Larger int16 matmul requiring tiling +// (16x16) * (16x16) = (16x16) +// CHECK-LABEL: func.func @matmul_i16_16x16x16 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.for +// CHECK: ime.vmadot +func.func @matmul_i16_16x16x16(%A: memref<16x16xi16>, %B: memref<16x16xi16>, %C: memref<16x16xi32>) { + linalg.matmul ins(%A, %B : memref<16x16xi16>, memref<16x16xi16>) + outs(%C : memref<16x16xi32>) + return +} diff --git a/examples/IMEDialect/makefile b/examples/IMEDialect/makefile index be2a4fcd62..52f499dcb1 100644 --- a/examples/IMEDialect/makefile +++ b/examples/IMEDialect/makefile @@ -8,9 +8,34 @@ BUDDY_LLC := ../../build/bin/buddy-llc vmadotu-lower vmadotu-translate vmadotu-asm vmadotu-run \ vmadotsu-lower vmadotsu-translate vmadotsu-asm vmadotsu-run \ vmadotus-lower vmadotus-translate vmadotus-asm vmadotus-run \ - vfmadot-lower vfmadot-translate vfmadot-asm + vfmadot-lower vfmadot-translate vfmadot-asm vfmadot-run \ + vmadotn vmadotnu vmadotnsu vmadotnus vfmadotn \ + vmadotn-lower vmadotn-translate vmadotn-asm vmadotn-run \ + vmadotnu-lower vmadotnu-translate vmadotnu-asm vmadotnu-run \ + vmadotnsu-lower vmadotnsu-translate vmadotnsu-asm vmadotnsu-run \ + vmadotnus-lower vmadotnus-translate vmadotnus-asm vmadotnus-run \ + vfmadotn-lower vfmadotn-translate vfmadotn-asm vfmadotn-run \ + vmadot1-lower vmadot1-translate vmadot1-asm vmadot1-run \ + vmadot1u-lower vmadot1u-translate vmadot1u-asm \ + vmadot1su-lower vmadot1su-translate vmadot1su-asm \ + vmadot1us-lower vmadot1us-translate vmadot1us-asm \ + vmadot2-lower vmadot2-translate vmadot2-asm vmadot2-run \ + vmadot2u-lower vmadot2u-translate vmadot2u-asm \ + vmadot2su-lower vmadot2su-translate vmadot2su-asm \ + vmadot2us-lower vmadot2us-translate vmadot2us-asm \ + vmadot3-lower vmadot3-translate vmadot3-asm vmadot3-run \ + vmadot3u-lower vmadot3u-translate vmadot3u-asm \ + 
vmadot3su-lower vmadot3su-translate vmadot3su-asm \ + vmadot3us-lower vmadot3us-translate vmadot3us-asm \ + vfmadot1-lower vfmadot1-translate vfmadot1-asm vfmadot1-run \ + vfmadot2-lower vfmadot2-translate vfmadot2-asm vfmadot2-run \ + vfmadot3-lower vfmadot3-translate vfmadot3-asm vfmadot3-run \ + linalg-matmul-lower linalg-matmul-translate linalg-matmul-asm \ + linalg-conv-lower linalg-conv-translate linalg-conv-asm \ + linalg-matmul-boundary-lower linalg-matmul-boundary-translate \ + linalg-matmul-boundary-asm linalg-matmul-boundary-run -all: vmadot vmadotu vmadotsu vmadotus vfmadot +all: vmadot vmadotu vmadotsu vmadotus vfmadot vmadotn vmadotnu vmadotnsu vmadotnus vfmadotn vmadot: vmadot.mlir @@ -28,6 +53,21 @@ vmadotus: vmadotus.mlir vfmadot: vfmadot.mlir $(BUDDY_OPT) $< -o vfmadot-out.mlir +vmadotn: vmadotn.mlir + $(BUDDY_OPT) $< -o vmadotn-out.mlir + +vmadotnu: vmadotnu.mlir + $(BUDDY_OPT) $< -o vmadotnu-out.mlir + +vmadotnsu: vmadotnsu.mlir + $(BUDDY_OPT) $< -o vmadotnsu-out.mlir + +vmadotnus: vmadotnus.mlir + $(BUDDY_OPT) $< -o vmadotnus-out.mlir + +vfmadotn: vfmadotn.mlir + $(BUDDY_OPT) $< -o vfmadotn-out.mlir + clean: rm -f *-out.mlir @@ -331,6 +371,1290 @@ vfmadot-asm: -finalize-memref-to-llvm \ -reconcile-unrealized-casts | \ ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o log.s + +vfmadot-run: + @${BUDDY_OPT} ./vfmadot_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o vfmadot.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vfmadot.s runtime_vfmadot.c -o vfmadot_test + + +vmadot1-lower: + @${BUDDY_OPT} ./vmadot1.mlir \ + 
-lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot1-translate: + @${BUDDY_OPT} ./vmadot1.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot1-asm: + @${BUDDY_OPT} ./vmadot1.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot2-lower: + @${BUDDY_OPT} ./vmadot2.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot2-translate: + @${BUDDY_OPT} ./vmadot2.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot2-asm: + @${BUDDY_OPT} ./vmadot2.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + 
-reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot3-lower: + @${BUDDY_OPT} ./vmadot3.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot3-translate: + @${BUDDY_OPT} ./vmadot3.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot3-asm: + @${BUDDY_OPT} ./vmadot3.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot1-run: + @${BUDDY_OPT} ./vmadot1_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o vmadot1.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vmadot1.s runtime_vmadot1.c -o vmadot1_test + +vmadot2-run: + @${BUDDY_OPT} ./vmadot2_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + 
-convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o vmadot2.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vmadot2.s runtime_vmadot2.c -o vmadot2_test + +vmadot3-run: + @${BUDDY_OPT} ./vmadot3_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o vmadot3.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vmadot3.s runtime_vmadot3.c -o vmadot3_test + +vmadot1u-lower: + @${BUDDY_OPT} ./vmadot1u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot1u-translate: + @${BUDDY_OPT} ./vmadot1u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot1u-asm: + @${BUDDY_OPT} ./vmadot1u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ -mattr=+m,+v,+buddyext \ -o log.s + 
+vmadot1su-lower: + @${BUDDY_OPT} ./vmadot1su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot1su-translate: + @${BUDDY_OPT} ./vmadot1su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot1su-asm: + @${BUDDY_OPT} ./vmadot1su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot1us-lower: + @${BUDDY_OPT} ./vmadot1us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot1us-translate: + @${BUDDY_OPT} ./vmadot1us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot1us-asm: + @${BUDDY_OPT} ./vmadot1us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + 
-convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot2u-lower: + @${BUDDY_OPT} ./vmadot2u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot2u-translate: + @${BUDDY_OPT} ./vmadot2u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot2u-asm: + @${BUDDY_OPT} ./vmadot2u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot2su-lower: + @${BUDDY_OPT} ./vmadot2su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot2su-translate: + @${BUDDY_OPT} ./vmadot2su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} 
-buddy-to-llvmir \ + -o log.ll + +vmadot2su-asm: + @${BUDDY_OPT} ./vmadot2su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot2us-lower: + @${BUDDY_OPT} ./vmadot2us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot2us-translate: + @${BUDDY_OPT} ./vmadot2us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot2us-asm: + @${BUDDY_OPT} ./vmadot2us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot3u-lower: + @${BUDDY_OPT} ./vmadot3u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot3u-translate: + @${BUDDY_OPT} ./vmadot3u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + 
-lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot3u-asm: + @${BUDDY_OPT} ./vmadot3u.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot3su-lower: + @${BUDDY_OPT} ./vmadot3su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot3su-translate: + @${BUDDY_OPT} ./vmadot3su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot3su-asm: + @${BUDDY_OPT} ./vmadot3su.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadot3us-lower: + @${BUDDY_OPT} ./vmadot3us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + 
-convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadot3us-translate: + @${BUDDY_OPT} ./vmadot3us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadot3us-asm: + @${BUDDY_OPT} ./vmadot3us.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadotn-lower: + @${BUDDY_OPT} ./vmadotn.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadotn-translate: + @${BUDDY_OPT} ./vmadotn.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadotn-asm: + @${BUDDY_OPT} ./vmadotn.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + 
-mattr=+m,+v,+buddyext \ + -o log.s + +vmadotn-run: + @${BUDDY_OPT} ./vmadotn_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o vmadotn.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vmadotn.s runtime_vmadotn.c -o vmadotn_test + + +vmadotnu-lower: + @${BUDDY_OPT} ./vmadotnu.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadotnu-translate: + @${BUDDY_OPT} ./vmadotnu.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadotnu-asm: + @${BUDDY_OPT} ./vmadotnu.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadotnu-run: + @${BUDDY_OPT} ./vmadotnu_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ 
+ ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o vmadotnu.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vmadotnu.s runtime_vmadotnu.c -o vmadotnu_test + + +vmadotnsu-lower: + @${BUDDY_OPT} ./vmadotnsu.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadotnsu-translate: + @${BUDDY_OPT} ./vmadotnsu.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadotnsu-asm: + @${BUDDY_OPT} ./vmadotnsu.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadotnsu-run: + @${BUDDY_OPT} ./vmadotnsu_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o vmadotnsu.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vmadotnsu.s runtime_vmadotnsu.c -o vmadotnsu_test + + +vmadotnus-lower: + @${BUDDY_OPT} ./vmadotnus.mlir \ + -lower-ime \ + 
-convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vmadotnus-translate: + @${BUDDY_OPT} ./vmadotnus.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vmadotnus-asm: + @${BUDDY_OPT} ./vmadotnus.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +vmadotnus-run: + @${BUDDY_OPT} ./vmadotnus_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o vmadotnus.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vmadotnus.s runtime_vmadotnus.c -o vmadotnus_test + + +vfmadot1-lower: + @${BUDDY_OPT} ./vfmadot1.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vfmadot1-translate: + @${BUDDY_OPT} ./vfmadot1.mlir \ + -lower-ime \ + 
-convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vfmadot1-asm: + @${BUDDY_OPT} ./vfmadot1.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o log.s + +vfmadot2-lower: + @${BUDDY_OPT} ./vfmadot2.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vfmadot2-translate: + @${BUDDY_OPT} ./vfmadot2.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vfmadot2-asm: + @${BUDDY_OPT} ./vfmadot2.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o log.s + +vfmadot3-lower: + @${BUDDY_OPT} ./vfmadot3.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + 
-convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vfmadot3-translate: + @${BUDDY_OPT} ./vfmadot3.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vfmadot3-asm: + @${BUDDY_OPT} ./vfmadot3.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o log.s + +vfmadot1-run: + @${BUDDY_OPT} ./vfmadot1_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o vfmadot1.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vfmadot1.s runtime_vfmadot1.c -o vfmadot1_test + +vfmadot2-run: + @${BUDDY_OPT} ./vfmadot2_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o vfmadot2.s + riscv64-unknown-linux-gnu-gcc 
-march=rv64gcv -static vfmadot2.s runtime_vfmadot2.c -o vfmadot2_test + +vfmadot3-run: + @${BUDDY_OPT} ./vfmadot3_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o vfmadot3.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vfmadot3.s runtime_vfmadot3.c -o vfmadot3_test + +vfmadotn-lower: + @${BUDDY_OPT} ./vfmadotn.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts \ + -o log.mlir + +vfmadotn-translate: + @${BUDDY_OPT} ./vfmadotn.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +vfmadotn-asm: + @${BUDDY_OPT} ./vfmadotn.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o log.s + +vfmadotn-run: + @${BUDDY_OPT} ./vfmadotn_print_test.mlir \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + 
-finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+zvfh,+buddyext \ + -o vfmadotn.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static vfmadotn.s runtime_vfmadotn.c -o vfmadotn_test + + +linalg-matmul-lower: + @${BUDDY_OPT} ./linalg-to-ime-matmul.mlir \ + -lower-linalg-to-ime \ + -o log.mlir + +linalg-matmul-translate: + @${BUDDY_OPT} ./linalg-to-ime-matmul.mlir \ + -lower-linalg-to-ime \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +linalg-matmul-asm: + @${BUDDY_OPT} ./linalg-to-ime-matmul.mlir \ + -lower-linalg-to-ime \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + + +linalg-conv-lower: + @${BUDDY_OPT} ./linalg-to-ime-conv.mlir \ + -lower-linalg-to-ime \ + -o log.mlir + +linalg-conv-translate: + @${BUDDY_OPT} ./linalg-to-ime-conv.mlir \ + -lower-linalg-to-ime \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +linalg-conv-asm: + @${BUDDY_OPT} ./linalg-to-ime-conv.mlir \ + -lower-linalg-to-ime \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + 
-convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +linalg-matmul-boundary-lower: + @${BUDDY_OPT} ./linalg-to-ime-matmul-boundary.mlir \ + -lower-linalg-to-ime \ + -o log.mlir + +linalg-matmul-boundary-translate: + @${BUDDY_OPT} ./linalg-to-ime-matmul-boundary.mlir \ + -lower-linalg-to-ime \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -expand-strided-metadata \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir \ + -o log.ll + +linalg-matmul-boundary-asm: + @${BUDDY_OPT} ./linalg-to-ime-matmul-boundary.mlir \ + -lower-linalg-to-ime \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -expand-strided-metadata \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o log.s + +linalg-matmul-boundary-run: + @${BUDDY_OPT} ./linalg-to-ime-matmul-boundary-func.mlir \ + -lower-linalg-to-ime \ + -lower-ime \ + -convert-linalg-to-loops \ + -lower-affine \ + -expand-strided-metadata \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir | \ + ${BUDDY_LLC} -filetype=asm -mtriple=riscv64 \ + -mattr=+m,+v,+buddyext \ + -o matmul_boundary.s + riscv64-unknown-linux-gnu-gcc -march=rv64gcv -static \ + matmul_boundary.s 
runtime_matmul_boundary.c \ + -o matmul_boundary_test diff --git a/examples/IMEDialect/runtime_matmul_boundary.c b/examples/IMEDialect/runtime_matmul_boundary.c new file mode 100644 index 0000000000..5124f6f61e --- /dev/null +++ b/examples/IMEDialect/runtime_matmul_boundary.c @@ -0,0 +1,139 @@ +/** + * Runtime test for matmul with boundary handling + * + * This test verifies the correctness of linalg.matmul -> IME lowering + * for matrices with dimensions not aligned to IME tile sizes. + * + * Test case: C[7x5] = A[7x10] * B[10x5] + * For int8: TILE_M=4, TILE_K=8, TILE_N=4 + * - M=7: 1 full tile (4) + 3 remaining + * - N=5: 1 full tile (4) + 1 remaining + * - K=10: 1 full tile (8) + 2 remaining + */ + +#include +#include +#include +#include + +// Matrix dimensions +#define M 7 +#define K 10 +#define N 5 + +// Test matrices +int8_t A[M][K] = { + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, + {3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, + {4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, + {5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10} +}; + +int8_t B[K][N] = { + {1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0}, + {0, 0, 1, 0, 0}, + {0, 0, 0, 1, 0}, + {0, 0, 0, 0, 1}, + {1, 1, 1, 1, 1}, + {2, 2, 2, 2, 2}, + {3, 3, 3, 3, 3}, + {4, 4, 4, 4, 4}, + {5, 5, 5, 5, 5} +}; + +int32_t C[M][N]; +int32_t C_expected[M][N]; + +// External function from compiled MLIR +// The function uses unpacked memref arguments: +// (allocated, aligned, offset, size0, size1, stride0, stride1) for each memref +extern void matmul_boundary( + int8_t *A_alloc, int8_t *A_aligned, int64_t A_offset, + int64_t A_size0, int64_t A_size1, int64_t A_stride0, int64_t A_stride1, + int8_t *B_alloc, int8_t *B_aligned, int64_t B_offset, + int64_t B_size0, int64_t B_size1, int64_t B_stride0, int64_t B_stride1, + int32_t *C_alloc, int32_t *C_aligned, int64_t C_offset, + int64_t C_size0, int64_t C_size1, int64_t C_stride0, int64_t C_stride1); + +// Reference implementation for 
verification +void reference_matmul() { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + C_expected[i][j] = 0; + for (int k = 0; k < K; k++) { + C_expected[i][j] += (int32_t)A[i][k] * (int32_t)B[k][j]; + } + } + } +} + +void print_matrix_i32(const char *name, int32_t *mat, int rows, int cols) { + printf("%s [%dx%d]:\n", name, rows, cols); + for (int i = 0; i < rows; i++) { + printf(" ["); + for (int j = 0; j < cols; j++) { + printf("%4d", mat[i * cols + j]); + if (j < cols - 1) printf(", "); + } + printf("]\n"); + } +} + +int verify_result() { + int errors = 0; + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + if (C[i][j] != C_expected[i][j]) { + printf("ERROR at C[%d][%d]: got %d, expected %d\n", + i, j, C[i][j], C_expected[i][j]); + errors++; + } + } + } + return errors; +} + +int main() { + printf("=== Matmul Boundary Test ===\n"); + printf("Matrix dimensions: A[%dx%d] * B[%dx%d] = C[%dx%d]\n", + M, K, K, N, M, N); + printf("IME int8 tile sizes: TILE_M=4, TILE_K=8, TILE_N=4\n"); + printf("Boundary cases:\n"); + printf(" M=%d: %d full tiles + %d remaining\n", M, M/4, M%4); + printf(" K=%d: %d full tiles + %d remaining\n", K, K/8, K%8); + printf(" N=%d: %d full tiles + %d remaining\n\n", N, N/4, N%4); + + // Initialize output matrix + memset(C, 0, sizeof(C)); + + // Compute reference result + reference_matmul(); + + // Call MLIR-generated function with unpacked memref arguments + // Each memref: (allocated, aligned, offset, size0, size1, stride0, stride1) + matmul_boundary( + (int8_t*)A, (int8_t*)A, 0, M, K, K, 1, // A[7x10] + (int8_t*)B, (int8_t*)B, 0, K, N, N, 1, // B[10x5] + (int32_t*)C, (int32_t*)C, 0, M, N, N, 1 // C[7x5] + ); + + // Print results + print_matrix_i32("Result C", (int32_t*)C, M, N); + printf("\n"); + print_matrix_i32("Expected C", (int32_t*)C_expected, M, N); + printf("\n"); + + // Verify + int errors = verify_result(); + if (errors == 0) { + printf("PASS: All results match!\n"); + return 0; + } else { + 
printf("FAIL: %d errors found\n", errors); + return 1; + } +} diff --git a/examples/IMEDialect/runtime_vfmadot1.c b/examples/IMEDialect/runtime_vfmadot1.c new file mode 100644 index 0000000000..5601f73b82 --- /dev/null +++ b/examples/IMEDialect/runtime_vfmadot1.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vfmadot1 (fp16 x fp16, fixed slide=1) ==="); + printf("\n\nMatrix A (8x4, fp16, sliding source):\n"); + printf(" [1.0, 2.0, 3.0, 4.0] row 0\n"); + printf(" [2.0, 3.0, 4.0, 5.0] row 1 <- slide=1 starts here\n"); + printf(" [3.0, 4.0, 5.0, 6.0] row 2\n"); + printf(" [4.0, 5.0, 6.0, 7.0] row 3\n"); + printf(" [5.0, 6.0, 7.0, 8.0] row 4\n"); + printf(" [6.0, 7.0, 8.0, 9.0] row 5\n"); + printf(" [7.0, 8.0, 9.0, 10.0] row 6\n"); + printf(" [8.0, 9.0, 10.0, 11.0] row 7\n"); + printf("\nMatrix B (4x4, fp16, packed):\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf("\nFixed slide: 1\n"); + printf("Rows used after slide: [1,2,3,4] -> sums [14,18,22,26]\n"); + printf("\nResult matrix C (4x4, fp16):\n"); + printf("Expected: [[14,14,14,14], [18,18,18,18], [22,22,22,22], " + "[26,26,26,26]]\n"); +} + +void print_row_f16(int row, float v0, float v1, float v2, float v3) { + printf("Row %d: [%.1f, %.1f, %.1f, %.1f]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vfmadot2.c b/examples/IMEDialect/runtime_vfmadot2.c new file mode 100644 index 0000000000..031d7d81b3 --- /dev/null +++ b/examples/IMEDialect/runtime_vfmadot2.c @@ -0,0 +1,64 @@ +// Runtime support for vfmadot2_print_test +// Prints matrix results in human-readable format (fp16 version) + +#include +#include +#include + +// Convert half-precision float to single-precision +// Based on IEEE 754 half-precision format +static float f16_to_f32(uint16_t h) { + uint32_t sign = (h >> 15) & 0x1; + uint32_t exp = (h >> 10) & 0x1f; + uint32_t mant = h & 0x3ff; + + uint32_t f; + 
if (exp == 0) { + if (mant == 0) { + f = sign << 31; + } else { + // Denormalized + while (!(mant & 0x400)) { + mant <<= 1; + exp--; + } + exp++; + mant &= ~0x400; + exp = exp + (127 - 15); + f = (sign << 31) | (exp << 23) | (mant << 13); + } + } else if (exp == 31) { + f = (sign << 31) | 0x7f800000 | (mant << 13); + } else { + exp = exp + (127 - 15); + f = (sign << 31) | (exp << 23) | (mant << 13); + } + + union { + uint32_t u; + float f; + } u; + u.u = f; + return u.f; +} + +void print_header() { + printf("======================================\n"); + printf(" vfmadot2 Result Matrix (fp16)\n"); + printf(" slide=2 fixed sliding window\n"); + printf("======================================\n"); + printf("Expected (slide=2): all 18s in row 0,\n"); + printf(" all 22s in row 1,\n"); + printf(" all 26s in row 2,\n"); + printf(" all 30s in row 3\n"); + printf("--------------------------------------\n"); +} + +void print_row_f16(int32_t row, uint16_t v0, uint16_t v1, uint16_t v2, + uint16_t v3) { + float f0 = f16_to_f32(v0); + float f1 = f16_to_f32(v1); + float f2 = f16_to_f32(v2); + float f3 = f16_to_f32(v3); + printf("Row %d: [%6.2f, %6.2f, %6.2f, %6.2f]\n", row, f0, f1, f2, f3); +} diff --git a/examples/IMEDialect/runtime_vfmadot3.c b/examples/IMEDialect/runtime_vfmadot3.c new file mode 100644 index 0000000000..2f2af02252 --- /dev/null +++ b/examples/IMEDialect/runtime_vfmadot3.c @@ -0,0 +1,64 @@ +// Runtime support for vfmadot3_print_test +// Prints matrix results in human-readable format (fp16 version) + +#include +#include +#include + +// Convert half-precision float to single-precision +// Based on IEEE 754 half-precision format +static float f16_to_f32(uint16_t h) { + uint32_t sign = (h >> 15) & 0x1; + uint32_t exp = (h >> 10) & 0x1f; + uint32_t mant = h & 0x3ff; + + uint32_t f; + if (exp == 0) { + if (mant == 0) { + f = sign << 31; + } else { + // Denormalized + while (!(mant & 0x400)) { + mant <<= 1; + exp--; + } + exp++; + mant &= ~0x400; + exp = exp + (127 - 
15); + f = (sign << 31) | (exp << 23) | (mant << 13); + } + } else if (exp == 31) { + f = (sign << 31) | 0x7f800000 | (mant << 13); + } else { + exp = exp + (127 - 15); + f = (sign << 31) | (exp << 23) | (mant << 13); + } + + union { + uint32_t u; + float f; + } u; + u.u = f; + return u.f; +} + +void print_header() { + printf("======================================\n"); + printf(" vfmadot3 Result Matrix (fp16)\n"); + printf(" slide=3 fixed sliding window\n"); + printf("======================================\n"); + printf("Expected (slide=3): all 22s in row 0,\n"); + printf(" all 26s in row 1,\n"); + printf(" all 30s in row 2,\n"); + printf(" all 34s in row 3\n"); + printf("--------------------------------------\n"); +} + +void print_row_f16(int32_t row, uint16_t v0, uint16_t v1, uint16_t v2, + uint16_t v3) { + float f0 = f16_to_f32(v0); + float f1 = f16_to_f32(v1); + float f2 = f16_to_f32(v2); + float f3 = f16_to_f32(v3); + printf("Row %d: [%6.2f, %6.2f, %6.2f, %6.2f]\n", row, f0, f1, f2, f3); +} diff --git a/examples/IMEDialect/runtime_vfmadotn.c b/examples/IMEDialect/runtime_vfmadotn.c new file mode 100644 index 0000000000..1da7e4a01f --- /dev/null +++ b/examples/IMEDialect/runtime_vfmadotn.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vfmadotn (fp16 x fp16, dynamic slide) ==="); + printf("\n\nMatrix A (8x4, fp16, sliding source):\n"); + printf(" [1.0, 2.0, 3.0, 4.0] row 0\n"); + printf(" [2.0, 3.0, 4.0, 5.0] row 1 <- slide=1 starts here\n"); + printf(" [3.0, 4.0, 5.0, 6.0] row 2\n"); + printf(" [4.0, 5.0, 6.0, 7.0] row 3\n"); + printf(" [5.0, 6.0, 7.0, 8.0] row 4\n"); + printf(" [6.0, 7.0, 8.0, 9.0] row 5\n"); + printf(" [7.0, 8.0, 9.0, 10.0] row 6\n"); + printf(" [8.0, 9.0, 10.0, 11.0] row 7\n"); + printf("\nMatrix B (4x4, fp16, packed):\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf(" [1.0, 1.0, 1.0, 1.0]\n"); + printf("\nSlide parameter: 1\n"); + printf("Rows 
used after slide: [1,2,3,4] -> sums [14,18,22,26]\n"); + printf("\nResult matrix C (4x4, fp16):\n"); + printf("Expected: [[14,14,14,14], [18,18,18,18], [22,22,22,22], " + "[26,26,26,26]]\n"); +} + +void print_row_f16(int row, float v0, float v1, float v2, float v3) { + printf("Row %d: [%.1f, %.1f, %.1f, %.1f]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vmadot1.c b/examples/IMEDialect/runtime_vmadot1.c new file mode 100644 index 0000000000..863a5bce14 --- /dev/null +++ b/examples/IMEDialect/runtime_vmadot1.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vmadot1 (signed x signed, fixed slide=1) ==="); + printf("\n\nMatrix A (8x8, int8, signed, sliding source):\n"); + printf(" [1, 2, 3, 4, 5, 6, 7, 8] row 0\n"); + printf(" [2, 3, 4, 5, 6, 7, 8, 9] row 1 <- slide=1 starts here\n"); + printf(" [3, 4, 5, 6, 7, 8, 9, 10] row 2\n"); + printf(" [4, 5, 6, 7, 8, 9, 10, 11] row 3\n"); + printf(" [5, 6, 7, 8, 9, 10, 11, 12] row 4\n"); + printf(" [6, 7, 8, 9, 10, 11, 12, 13] row 5\n"); + printf(" [7, 8, 9, 10, 11, 12, 13, 14] row 6\n"); + printf(" [8, 9, 10, 11, 12, 13, 14, 15] row 7\n"); + printf("\nMatrix B (4x8, int8, signed, packed):\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf("\nFixed slide: 1\n"); + printf("Rows used after slide: [1,2,3,4] -> sums [44,52,60,68]\n"); + printf("\nResult matrix C (4x4, int32):\n"); + printf("Expected: [[44,44,44,44], [52,52,52,52], [60,60,60,60], " + "[68,68,68,68]]\n"); +} + +void print_row(int row, int v0, int v1, int v2, int v3) { + printf("Row %d: [%d, %d, %d, %d]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vmadot2.c b/examples/IMEDialect/runtime_vmadot2.c new file mode 100644 index 0000000000..929c557cc3 --- /dev/null +++ b/examples/IMEDialect/runtime_vmadot2.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vmadot2 
(signed x signed, fixed slide=2) ==="); + printf("\n\nMatrix A (8x8, int8, signed, sliding source):\n"); + printf(" [1, 2, 3, 4, 5, 6, 7, 8] row 0\n"); + printf(" [2, 3, 4, 5, 6, 7, 8, 9] row 1\n"); + printf(" [3, 4, 5, 6, 7, 8, 9, 10] row 2 <- slide=2 starts here\n"); + printf(" [4, 5, 6, 7, 8, 9, 10, 11] row 3\n"); + printf(" [5, 6, 7, 8, 9, 10, 11, 12] row 4\n"); + printf(" [6, 7, 8, 9, 10, 11, 12, 13] row 5\n"); + printf(" [7, 8, 9, 10, 11, 12, 13, 14] row 6\n"); + printf(" [8, 9, 10, 11, 12, 13, 14, 15] row 7\n"); + printf("\nMatrix B (4x8, int8, signed, packed):\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf("\nFixed slide: 2\n"); + printf("Rows used after slide: [2,3,4,5] -> sums [52,60,68,76]\n"); + printf("\nResult matrix C (4x4, int32):\n"); + printf("Expected: [[52,52,52,52], [60,60,60,60], [68,68,68,68], " + "[76,76,76,76]]\n"); +} + +void print_row(int row, int v0, int v1, int v2, int v3) { + printf("Row %d: [%d, %d, %d, %d]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vmadot3.c b/examples/IMEDialect/runtime_vmadot3.c new file mode 100644 index 0000000000..f9cc7fcf47 --- /dev/null +++ b/examples/IMEDialect/runtime_vmadot3.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vmadot3 (signed x signed, fixed slide=3) ==="); + printf("\n\nMatrix A (8x8, int8, signed, sliding source):\n"); + printf(" [1, 2, 3, 4, 5, 6, 7, 8] row 0\n"); + printf(" [2, 3, 4, 5, 6, 7, 8, 9] row 1\n"); + printf(" [3, 4, 5, 6, 7, 8, 9, 10] row 2\n"); + printf(" [4, 5, 6, 7, 8, 9, 10, 11] row 3 <- slide=3 starts here\n"); + printf(" [5, 6, 7, 8, 9, 10, 11, 12] row 4\n"); + printf(" [6, 7, 8, 9, 10, 11, 12, 13] row 5\n"); + printf(" [7, 8, 9, 10, 11, 12, 13, 14] row 6\n"); + printf(" [8, 9, 10, 11, 12, 13, 14, 15] row 7\n"); + printf("\nMatrix B (4x8, int8, signed, packed):\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 
1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf("\nFixed slide: 3\n"); + printf("Rows used after slide: [3,4,5,6] -> sums [60,68,76,84]\n"); + printf("\nResult matrix C (4x4, int32):\n"); + printf("Expected: [[60,60,60,60], [68,68,68,68], [76,76,76,76], " + "[84,84,84,84]]\n"); +} + +void print_row(int row, int v0, int v1, int v2, int v3) { + printf("Row %d: [%d, %d, %d, %d]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vmadotn.c b/examples/IMEDialect/runtime_vmadotn.c new file mode 100644 index 0000000000..288f98b949 --- /dev/null +++ b/examples/IMEDialect/runtime_vmadotn.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vmadotn (signed x signed, dynamic slide) ==="); + printf("\n\nMatrix A (8x8, int8, signed, sliding source):\n"); + printf(" [1, 2, 3, 4, 5, 6, 7, 8] row 0\n"); + printf(" [2, 3, 4, 5, 6, 7, 8, 9] row 1 <- slide=1 starts here\n"); + printf(" [3, 4, 5, 6, 7, 8, 9, 10] row 2\n"); + printf(" [4, 5, 6, 7, 8, 9, 10, 11] row 3\n"); + printf(" [5, 6, 7, 8, 9, 10, 11, 12] row 4\n"); + printf(" [6, 7, 8, 9, 10, 11, 12, 13] row 5\n"); + printf(" [7, 8, 9, 10, 11, 12, 13, 14] row 6\n"); + printf(" [8, 9, 10, 11, 12, 13, 14, 15] row 7\n"); + printf("\nMatrix B (4x8, int8, signed, packed):\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf("\nSlide parameter: 1\n"); + printf("Rows used after slide: [1,2,3,4] -> sums [44,52,60,68]\n"); + printf("\nResult matrix C (4x4, int32):\n"); + printf("Expected: [[44,44,44,44], [52,52,52,52], [60,60,60,60], " + "[68,68,68,68]]\n"); +} + +void print_row(int row, int v0, int v1, int v2, int v3) { + printf("Row %d: [%d, %d, %d, %d]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vmadotnsu.c b/examples/IMEDialect/runtime_vmadotnsu.c new file mode 
100644 index 0000000000..a25f097394 --- /dev/null +++ b/examples/IMEDialect/runtime_vmadotnsu.c @@ -0,0 +1,29 @@ +#include + +void print_header() { + printf("=== vmadotnsu (signed x unsigned, dynamic slide) ==="); + printf("\n\nMatrix A (8x8, int8, signed, sliding source):\n"); + printf(" [-1, -2, -3, -4, -5, -6, -7, -8] row 0\n"); + printf(" [-2, -3, -4, -5, -6, -7, -8, -9] row 1 <- slide=1 starts " + "here\n"); + printf(" [-3, -4, -5, -6, -7, -8, -9, -10] row 2\n"); + printf(" [-4, -5, -6, -7, -8, -9, -10, -11] row 3\n"); + printf(" [-5, -6, -7, -8, -9, -10, -11, -12] row 4\n"); + printf(" [-6, -7, -8, -9, -10, -11, -12, -13] row 5\n"); + printf(" [-7, -8, -9, -10, -11, -12, -13, -14] row 6\n"); + printf(" [-8, -9, -10, -11, -12, -13, -14, -15] row 7\n"); + printf("\nMatrix B (4x8, uint8, unsigned, packed):\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf("\nSlide parameter: 1\n"); + printf("Rows used after slide: [1,2,3,4] -> sums [-44,-52,-60,-68]\n"); + printf("\nResult matrix C (4x4, int32):\n"); + printf("Expected: [[-44,-44,-44,-44], [-52,-52,-52,-52], [-60,-60,-60,-60], " + "[-68,-68,-68,-68]]\n"); +} + +void print_row(int row, int v0, int v1, int v2, int v3) { + printf("Row %d: [%d, %d, %d, %d]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vmadotnu.c b/examples/IMEDialect/runtime_vmadotnu.c new file mode 100644 index 0000000000..091cc84235 --- /dev/null +++ b/examples/IMEDialect/runtime_vmadotnu.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vmadotnu (unsigned x unsigned, dynamic slide) ==="); + printf("\n\nMatrix A (8x8, uint8, unsigned, sliding source):\n"); + printf(" [1, 2, 3, 4, 5, 6, 7, 8] row 0\n"); + printf(" [2, 3, 4, 5, 6, 7, 8, 9] row 1\n"); + printf(" [3, 4, 5, 6, 7, 8, 9, 10] row 2 <- slide=2 starts here\n"); + printf(" [4, 5, 6, 7, 8, 9, 10, 11] row 3\n"); + printf(" [5, 
6, 7, 8, 9, 10, 11, 12] row 4\n"); + printf(" [6, 7, 8, 9, 10, 11, 12, 13] row 5\n"); + printf(" [7, 8, 9, 10, 11, 12, 13, 14] row 6\n"); + printf(" [8, 9, 10, 11, 12, 13, 14, 15] row 7\n"); + printf("\nMatrix B (4x8, uint8, unsigned, packed):\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf(" [1, 1, 1, 1, 1, 1, 1, 1]\n"); + printf("\nSlide parameter: 2\n"); + printf("Rows used after slide: [2,3,4,5] -> sums [52,60,68,76]\n"); + printf("\nResult matrix C (4x4, int32):\n"); + printf("Expected: [[52,52,52,52], [60,60,60,60], [68,68,68,68], " + "[76,76,76,76]]\n"); +} + +void print_row(int row, int v0, int v1, int v2, int v3) { + printf("Row %d: [%d, %d, %d, %d]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/runtime_vmadotnus.c b/examples/IMEDialect/runtime_vmadotnus.c new file mode 100644 index 0000000000..dc814624a2 --- /dev/null +++ b/examples/IMEDialect/runtime_vmadotnus.c @@ -0,0 +1,28 @@ +#include + +void print_header() { + printf("=== vmadotnus (unsigned x signed, dynamic slide) ==="); + printf("\n\nMatrix A (8x8, uint8, unsigned, sliding source):\n"); + printf(" [1, 2, 3, 4, 5, 6, 7, 8] row 0\n"); + printf(" [2, 3, 4, 5, 6, 7, 8, 9] row 1\n"); + printf(" [3, 4, 5, 6, 7, 8, 9, 10] row 2\n"); + printf(" [4, 5, 6, 7, 8, 9, 10, 11] row 3 <- slide=3 starts here\n"); + printf(" [5, 6, 7, 8, 9, 10, 11, 12] row 4\n"); + printf(" [6, 7, 8, 9, 10, 11, 12, 13] row 5\n"); + printf(" [7, 8, 9, 10, 11, 12, 13, 14] row 6\n"); + printf(" [8, 9, 10, 11, 12, 13, 14, 15] row 7\n"); + printf("\nMatrix B (4x8, int8, signed, packed):\n"); + printf(" [-1, -1, -1, -1, -1, -1, -1, -1]\n"); + printf(" [-1, -1, -1, -1, -1, -1, -1, -1]\n"); + printf(" [-1, -1, -1, -1, -1, -1, -1, -1]\n"); + printf(" [-1, -1, -1, -1, -1, -1, -1, -1]\n"); + printf("\nSlide parameter: 3\n"); + printf("Rows used after slide: [3,4,5,6] -> sums [-60,-68,-76,-84]\n"); + printf("\nResult matrix C (4x4, 
int32):\n"); + printf("Expected: [[-60,-60,-60,-60], [-68,-68,-68,-68], [-76,-76,-76,-76], " + "[-84,-84,-84,-84]]\n"); +} + +void print_row(int row, int v0, int v1, int v2, int v3) { + printf("Row %d: [%d, %d, %d, %d]\n", row, v0, v1, v2, v3); +} diff --git a/examples/IMEDialect/vfmadot1.mlir b/examples/IMEDialect/vfmadot1.mlir new file mode 100644 index 0000000000..2ec1df8afb --- /dev/null +++ b/examples/IMEDialect/vfmadot1.mlir @@ -0,0 +1,49 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vfmadot1 operation (slide=1). +// vfmadot1 performs: C += slide(A, 1) × B where A, B are fp16 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (32 fp16 elements), slides by 1 row. +// Matrix dimensions for VLEN=256, SEW=16 (fp16): +// A: 8×4 (2M×K) source, sliding selects 4×4 (M×K) starting from row 1 +// B: 4×4 (K×N) - fp16 +// C: 4×4 (M×N) - fp16 accumulator + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xf16> + + // Initialize accumulator to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Perform floating-point matrix multiply-accumulate with slide=1 + // CHECK: ime.vfmadot1 + ime.vfmadot1 %c, %a, %b : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vfmadot1_print_test.mlir 
b/examples/IMEDialect/vfmadot1_print_test.mlir new file mode 100644 index 0000000000..7b913f7db0 --- /dev/null +++ b/examples/IMEDialect/vfmadot1_print_test.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vfmadot1 computes: C[i,j] += sum_k(A[1+i,k] * B[j,k]) +// +// Sliding window reads 64 fp16 elements from VS1 (8 rows), then slides by 1 row. +// A (8x4): fp16, source matrix +// B (4x4): fp16, packed form +// +// With slide=1: +// A rows used = [1,2,3,4] (after sliding by 1) +// Row 1 = [2,3,4,5], sum = 14 +// Row 2 = [3,4,5,6], sum = 18 +// Row 3 = [4,5,6,7], sum = 22 +// Row 4 = [5,6,7,8], sum = 26 + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +// Packed B (4x4): all ones for easy verification +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +// With slide=1: Expected C = [[14,14,14,14], [18,18,18,18], [22,22,22,22], [26,26,26,26]] + +func.func private @print_row_f16(i32, f16, f16, f16, f16) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + %c = memref.alloc() : memref<4x4xf16> + + // Initialize to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Perform vfmadot1 (fixed slide=1) + ime.vfmadot1 %c, %a, %b : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : 
i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xf16> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xf16> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xf16> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xf16> + call @print_row_f16(%i0, %v00, %v01, %v02, %v03) : (i32, f16, f16, f16, f16) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xf16> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xf16> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xf16> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xf16> + call @print_row_f16(%i1, %v10, %v11, %v12, %v13) : (i32, f16, f16, f16, f16) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xf16> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xf16> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xf16> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xf16> + call @print_row_f16(%i2, %v20, %v21, %v22, %v23) : (i32, f16, f16, f16, f16) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xf16> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xf16> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xf16> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xf16> + call @print_row_f16(%i3, %v30, %v31, %v32, %v33) : (i32, f16, f16, f16, f16) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vfmadot2.mlir b/examples/IMEDialect/vfmadot2.mlir new file mode 100644 index 0000000000..c6e287b85d --- /dev/null +++ b/examples/IMEDialect/vfmadot2.mlir @@ -0,0 +1,49 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vfmadot2 operation (slide=2). +// vfmadot2 performs: C += slide(A, 2) × B where A, B are fp16 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (32 fp16 elements), slides by 2 rows. 
+// Matrix dimensions for VLEN=256, SEW=16 (fp16): +// A: 8×4 (2M×K) source, sliding selects 4×4 (M×K) starting from row 2 +// B: 4×4 (K×N) - fp16 +// C: 4×4 (M×N) - fp16 accumulator + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xf16> + + // Initialize accumulator to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Perform floating-point matrix multiply-accumulate with slide=2 + // CHECK: ime.vfmadot2 + ime.vfmadot2 %c, %a, %b : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vfmadot2_print_test.mlir b/examples/IMEDialect/vfmadot2_print_test.mlir new file mode 100644 index 0000000000..3f5ef931a8 --- /dev/null +++ b/examples/IMEDialect/vfmadot2_print_test.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vfmadot2 computes: C[i,j] += sum_k(A[2+i,k] * B[j,k]) +// +// Sliding window reads 64 fp16 elements from VS1 (8 rows), then slides by 2 rows. 
+// A (8x4): fp16, source matrix +// B (4x4): fp16, packed form +// +// With slide=2: +// A rows used = [2,3,4,5] (after sliding by 2) +// Row 2 = [3,4,5,6], sum = 18 +// Row 3 = [4,5,6,7], sum = 22 +// Row 4 = [5,6,7,8], sum = 26 +// Row 5 = [6,7,8,9], sum = 30 + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +// Packed B (4x4): all ones for easy verification +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +// With slide=2: Expected C = [[18,18,18,18], [22,22,22,22], [26,26,26,26], [30,30,30,30]] + +func.func private @print_row_f16(i32, f16, f16, f16, f16) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + %c = memref.alloc() : memref<4x4xf16> + + // Initialize to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Perform vfmadot2 (fixed slide=2) + ime.vfmadot2 %c, %a, %b : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xf16> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xf16> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xf16> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xf16> + call @print_row_f16(%i0, %v00, %v01, %v02, %v03) : (i32, f16, f16, f16, f16) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : 
memref<4x4xf16> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xf16> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xf16> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xf16> + call @print_row_f16(%i1, %v10, %v11, %v12, %v13) : (i32, f16, f16, f16, f16) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xf16> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xf16> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xf16> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xf16> + call @print_row_f16(%i2, %v20, %v21, %v22, %v23) : (i32, f16, f16, f16, f16) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xf16> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xf16> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xf16> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xf16> + call @print_row_f16(%i3, %v30, %v31, %v32, %v33) : (i32, f16, f16, f16, f16) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vfmadot3.mlir b/examples/IMEDialect/vfmadot3.mlir new file mode 100644 index 0000000000..39e3353f8c --- /dev/null +++ b/examples/IMEDialect/vfmadot3.mlir @@ -0,0 +1,49 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vfmadot3 operation (slide=3). +// vfmadot3 performs: C += slide(A, 3) × B where A, B are fp16 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (32 fp16 elements), slides by 3 rows. 
+// Matrix dimensions for VLEN=256, SEW=16 (fp16): +// A: 8×4 (2M×K) source, sliding selects 4×4 (M×K) starting from row 3 +// B: 4×4 (K×N) - fp16 +// C: 4×4 (M×N) - fp16 accumulator + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xf16> + + // Initialize accumulator to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Perform floating-point matrix multiply-accumulate with slide=3 + // CHECK: ime.vfmadot3 + ime.vfmadot3 %c, %a, %b : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vfmadot3_print_test.mlir b/examples/IMEDialect/vfmadot3_print_test.mlir new file mode 100644 index 0000000000..8444eed61f --- /dev/null +++ b/examples/IMEDialect/vfmadot3_print_test.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vfmadot3 computes: C[i,j] += sum_k(A[3+i,k] * B[j,k]) +// +// Sliding window reads 64 fp16 elements from VS1 (8 rows), then slides by 3 rows. 
+// A (8x4): fp16, source matrix +// B (4x4): fp16, packed form +// +// With slide=3: +// A rows used = [3,4,5,6] (after sliding by 3) +// Row 3 = [4,5,6,7], sum = 22 +// Row 4 = [5,6,7,8], sum = 26 +// Row 5 = [6,7,8,9], sum = 30 +// Row 6 = [7,8,9,10], sum = 34 + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +// Packed B (4x4): all ones for easy verification +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +// With slide=3: Expected C = [[22,22,22,22], [26,26,26,26], [30,30,30,30], [34,34,34,34]] + +func.func private @print_row_f16(i32, f16, f16, f16, f16) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + %c = memref.alloc() : memref<4x4xf16> + + // Initialize to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Perform vfmadot3 (fixed slide=3) + ime.vfmadot3 %c, %a, %b : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xf16> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xf16> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xf16> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xf16> + call @print_row_f16(%i0, %v00, %v01, %v02, %v03) : (i32, f16, f16, f16, f16) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : 
memref<4x4xf16> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xf16> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xf16> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xf16> + call @print_row_f16(%i1, %v10, %v11, %v12, %v13) : (i32, f16, f16, f16, f16) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xf16> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xf16> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xf16> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xf16> + call @print_row_f16(%i2, %v20, %v21, %v22, %v23) : (i32, f16, f16, f16, f16) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xf16> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xf16> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xf16> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xf16> + call @print_row_f16(%i3, %v30, %v31, %v32, %v33) : (i32, f16, f16, f16, f16) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vfmadot_print_test.mlir b/examples/IMEDialect/vfmadot_print_test.mlir index 93e203964b..272a024b2a 100644 --- a/examples/IMEDialect/vfmadot_print_test.mlir +++ b/examples/IMEDialect/vfmadot_print_test.mlir @@ -1,4 +1,5 @@ -// IME vfmadot test - floating-point matrix multiply-accumulate (fp16) +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main // // vfmadot computes: C[i,j] += sum_k(A[i,k] * B[j,k]) for fp16 values // diff --git a/examples/IMEDialect/vfmadotn.mlir b/examples/IMEDialect/vfmadotn.mlir new file mode 100644 index 0000000000..f7ed0ee204 --- /dev/null +++ b/examples/IMEDialect/vfmadotn.mlir @@ -0,0 +1,52 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vfmadotn operation with dynamic slide parameter. +// vfmadotn performs: C += slide(A, n) × B where A, B are fp16 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by n rows. 
+// Matrix dimensions for VLEN=256, SEW=16 (fp16): +// A: 8×4 (2M×K) source, sliding selects 4×4 (M×K) - fp16 +// B: 4×4 (K×N) - fp16 +// C: 4×4 (M×N) - fp16 accumulator + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xf16> + + // Initialize accumulator to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Slide parameter (0-3) + %slide = arith.constant 1 : i64 + + // Perform floating-point matrix multiply-accumulate with dynamic slide + // CHECK: ime.vfmadotn + ime.vfmadotn %c, %a, %b, %slide : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vfmadotn_print_test.mlir b/examples/IMEDialect/vfmadotn_print_test.mlir new file mode 100644 index 0000000000..1a805551ed --- /dev/null +++ b/examples/IMEDialect/vfmadotn_print_test.mlir @@ -0,0 +1,99 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vfmadotn computes: C[i,j] += sum_k(A[slide+i,k] * B[j,k]) +// +// Sliding window reads 64 fp16 elements from VS1 (8 rows), then slides by n rows. 
+// A (8x4): fp16, source matrix +// B (4x4): fp16, packed form +// +// With slide=1: +// A rows used = [1,2,3,4] (after sliding by 1) +// Row 1 = [2,3,4,5], sum = 14 +// Row 2 = [3,4,5,6], sum = 18 +// Row 3 = [4,5,6,7], sum = 22 +// Row 4 = [5,6,7,8], sum = 26 + +memref.global "private" @matA : memref<8x4xf16> = dense<[ + [1.0, 2.0, 3.0, 4.0], + [2.0, 3.0, 4.0, 5.0], + [3.0, 4.0, 5.0, 6.0], + [4.0, 5.0, 6.0, 7.0], + [5.0, 6.0, 7.0, 8.0], + [6.0, 7.0, 8.0, 9.0], + [7.0, 8.0, 9.0, 10.0], + [8.0, 9.0, 10.0, 11.0] +]> + +// Packed B (4x4): all ones for easy verification +memref.global "private" @matB : memref<4x4xf16> = dense<[ + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 1.0] +]> + +// With slide=1: Expected C = [[14,14,14,14], [18,18,18,18], [22,22,22,22], [26,26,26,26]] + +func.func private @print_row_f16(i32, f16, f16, f16, f16) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x4xf16> + %b = memref.get_global @matB : memref<4x4xf16> + + %c = memref.alloc() : memref<4x4xf16> + + // Initialize to zero + %zero = arith.constant 0.0 : f16 + linalg.fill ins(%zero : f16) outs(%c : memref<4x4xf16>) + + // Slide parameter = 1 + %slide = arith.constant 1 : i64 + + // Perform vfmadotn with slide=1 + ime.vfmadotn %c, %a, %b, %slide : memref<4x4xf16>, memref<8x4xf16>, memref<4x4xf16> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xf16> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xf16> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xf16> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xf16> + call @print_row_f16(%i0, %v00, %v01, %v02, %v03) : (i32, f16, f16, 
f16, f16) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xf16> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xf16> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xf16> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xf16> + call @print_row_f16(%i1, %v10, %v11, %v12, %v13) : (i32, f16, f16, f16, f16) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xf16> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xf16> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xf16> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xf16> + call @print_row_f16(%i2, %v20, %v21, %v22, %v23) : (i32, f16, f16, f16, f16) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xf16> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xf16> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xf16> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xf16> + call @print_row_f16(%i3, %v30, %v31, %v32, %v33) : (i32, f16, f16, f16, f16) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot1.mlir b/examples/IMEDialect/vmadot1.mlir new file mode 100644 index 0000000000..3cc31384a0 --- /dev/null +++ b/examples/IMEDialect/vmadot1.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot1 operation (slide=1). +// vmadot1 performs: C += slide(A, 1) × B where A, B are signed int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 1 row. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 1 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform signed × signed matrix multiply-accumulate with slide=1 + // CHECK: ime.vmadot1 + ime.vmadot1 %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot1_print_test.mlir b/examples/IMEDialect/vmadot1_print_test.mlir new file mode 100644 index 0000000000..d7ab5b8d81 --- /dev/null +++ b/examples/IMEDialect/vmadot1_print_test.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vmadot1 computes: C[i,j] += sum_k(signed(A[1+i,k]) * signed(B[j,k])) +// +// Sliding window reads 64 elements from VS1 (8 rows), then slides by 1 row. 
+// A (8x8): signed int8, source matrix (needs 64 elements for sliding) +// B (4x8): signed int8, packed form +// +// With slide=1: +// A rows used = [1,2,3,4] (after sliding by 1) +// Row 1 = [2,3,4,5,6,7,8,9], sum = 44 +// Row 2 = [3,4,5,6,7,8,9,10], sum = 52 +// Row 3 = [4,5,6,7,8,9,10,11], sum = 60 +// Row 4 = [5,6,7,8,9,10,11,12], sum = 68 + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +// Packed B (4x8): all ones for easy verification +memref.global "private" @matB : memref<4x8xi8> = dense<[ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1] +]> + +// With slide=1: Expected C = [[44,44,44,44], [52,52,52,52], [60,60,60,60], [68,68,68,68]] + +func.func private @print_row(i32, i32, i32, i32, i32) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<4x8xi8> + + %c = memref.alloc() : memref<4x4xi32> + + // Initialize to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform vmadot1 (fixed slide=1) + ime.vmadot1 %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<4x8xi8> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xi32> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xi32> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xi32> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xi32> 
+ call @print_row(%i0, %v00, %v01, %v02, %v03) : (i32, i32, i32, i32, i32) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xi32> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xi32> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xi32> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xi32> + call @print_row(%i1, %v10, %v11, %v12, %v13) : (i32, i32, i32, i32, i32) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xi32> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xi32> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xi32> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xi32> + call @print_row(%i2, %v20, %v21, %v22, %v23) : (i32, i32, i32, i32, i32) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xi32> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xi32> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xi32> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xi32> + call @print_row(%i3, %v30, %v31, %v32, %v33) : (i32, i32, i32, i32, i32) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot1su.mlir b/examples/IMEDialect/vmadot1su.mlir new file mode 100644 index 0000000000..274521e5fc --- /dev/null +++ b/examples/IMEDialect/vmadot1su.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot1su operation (slide=1, signed × unsigned). +// vmadot1su performs: C += slide(A, 1) × B where A is signed, B is unsigned int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 1 row. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 1 - signed int8 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform signed × unsigned matrix multiply-accumulate with slide=1 + // CHECK: ime.vmadot1su + ime.vmadot1su %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot1u.mlir b/examples/IMEDialect/vmadot1u.mlir new file mode 100644 index 0000000000..719f7c95d6 --- /dev/null +++ b/examples/IMEDialect/vmadot1u.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot1u operation (slide=1, unsigned × unsigned). +// vmadot1u performs: C += slide(A, 1) × B where A, B are unsigned int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 1 row. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 1 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform unsigned × unsigned matrix multiply-accumulate with slide=1 + // CHECK: ime.vmadot1u + ime.vmadot1u %c, %a, %b : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot1us.mlir b/examples/IMEDialect/vmadot1us.mlir new file mode 100644 index 0000000000..0e61f2ef36 --- /dev/null +++ b/examples/IMEDialect/vmadot1us.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot1us operation (slide=1, unsigned × signed). +// vmadot1us performs: C += slide(A, 1) × B where A is unsigned, B is signed int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 1 row. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 1 - unsigned int8 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform unsigned × signed matrix multiply-accumulate with slide=1 + // CHECK: ime.vmadot1us + ime.vmadot1us %c, %a, %b : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot2.mlir b/examples/IMEDialect/vmadot2.mlir new file mode 100644 index 0000000000..102885c20f --- /dev/null +++ b/examples/IMEDialect/vmadot2.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot2 operation (slide=2). +// vmadot2 performs: C += slide(A, 2) × B where A, B are signed int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 2 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 2 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform signed × signed matrix multiply-accumulate with slide=2 + // CHECK: ime.vmadot2 + ime.vmadot2 %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot2_print_test.mlir b/examples/IMEDialect/vmadot2_print_test.mlir new file mode 100644 index 0000000000..69dfdb4363 --- /dev/null +++ b/examples/IMEDialect/vmadot2_print_test.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vmadot2 computes: C[i,j] += sum_k(signed(A[2+i,k]) * signed(B[j,k])) +// +// Sliding window reads 64 elements from VS1 (8 rows), then slides by 2 rows. 
+// A (8x8): signed int8, source matrix (needs 64 elements for sliding) +// B (4x8): signed int8, packed form +// +// With slide=2: +// A rows used = [2,3,4,5] (after sliding by 2) +// Row 2 = [3,4,5,6,7,8,9,10], sum = 52 +// Row 3 = [4,5,6,7,8,9,10,11], sum = 60 +// Row 4 = [5,6,7,8,9,10,11,12], sum = 68 +// Row 5 = [6,7,8,9,10,11,12,13], sum = 76 + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +// Packed B (4x8): all ones for easy verification +memref.global "private" @matB : memref<4x8xi8> = dense<[ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1] +]> + +// With slide=2: Expected C = [[52,52,52,52], [60,60,60,60], [68,68,68,68], [76,76,76,76]] + +func.func private @print_row(i32, i32, i32, i32, i32) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<4x8xi8> + + %c = memref.alloc() : memref<4x4xi32> + + // Initialize to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform vmadot2 (fixed slide=2) + ime.vmadot2 %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<4x8xi8> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xi32> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xi32> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xi32> + %v03 = memref.load %c[%c0, %c3] : 
memref<4x4xi32> + call @print_row(%i0, %v00, %v01, %v02, %v03) : (i32, i32, i32, i32, i32) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xi32> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xi32> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xi32> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xi32> + call @print_row(%i1, %v10, %v11, %v12, %v13) : (i32, i32, i32, i32, i32) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xi32> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xi32> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xi32> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xi32> + call @print_row(%i2, %v20, %v21, %v22, %v23) : (i32, i32, i32, i32, i32) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xi32> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xi32> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xi32> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xi32> + call @print_row(%i3, %v30, %v31, %v32, %v33) : (i32, i32, i32, i32, i32) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot2su.mlir b/examples/IMEDialect/vmadot2su.mlir new file mode 100644 index 0000000000..1cd81cd9a5 --- /dev/null +++ b/examples/IMEDialect/vmadot2su.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot2su operation (slide=2, signed × unsigned). +// vmadot2su performs: C += slide(A, 2) × B where A is signed, B is unsigned int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 2 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 2 - signed int8 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform signed × unsigned matrix multiply-accumulate with slide=2 + // CHECK: ime.vmadot2su + ime.vmadot2su %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot2u.mlir b/examples/IMEDialect/vmadot2u.mlir new file mode 100644 index 0000000000..5013cd0dfa --- /dev/null +++ b/examples/IMEDialect/vmadot2u.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot2u operation (slide=2, unsigned × unsigned). +// vmadot2u performs: C += slide(A, 2) × B where A, B are unsigned int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 2 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 2 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform unsigned × unsigned matrix multiply-accumulate with slide=2 + // CHECK: ime.vmadot2u + ime.vmadot2u %c, %a, %b : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot2us.mlir b/examples/IMEDialect/vmadot2us.mlir new file mode 100644 index 0000000000..5666005dbd --- /dev/null +++ b/examples/IMEDialect/vmadot2us.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot2us operation (slide=2, unsigned × signed). +// vmadot2us performs: C += slide(A, 2) × B where A is unsigned, B is signed int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 2 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 2 - unsigned int8 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform unsigned × signed matrix multiply-accumulate with slide=2 + // CHECK: ime.vmadot2us + ime.vmadot2us %c, %a, %b : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot3.mlir b/examples/IMEDialect/vmadot3.mlir new file mode 100644 index 0000000000..8110fd8e6e --- /dev/null +++ b/examples/IMEDialect/vmadot3.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot3 operation (slide=3). +// vmadot3 performs: C += slide(A, 3) × B where A, B are signed int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 3 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 3 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform signed × signed matrix multiply-accumulate with slide=3 + // CHECK: ime.vmadot3 + ime.vmadot3 %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot3_print_test.mlir b/examples/IMEDialect/vmadot3_print_test.mlir new file mode 100644 index 0000000000..d40075fe6c --- /dev/null +++ b/examples/IMEDialect/vmadot3_print_test.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vmadot3 computes: C[i,j] += sum_k(signed(A[3+i,k]) * signed(B[j,k])) +// +// Sliding window reads 64 elements from VS1 (8 rows), then slides by 3 rows. 
+// A (8x8): signed int8, source matrix (needs 64 elements for sliding) +// B (4x8): signed int8, packed form +// +// With slide=3: +// A rows used = [3,4,5,6] (after sliding by 3) +// Row 3 = [4,5,6,7,8,9,10,11], sum = 60 +// Row 4 = [5,6,7,8,9,10,11,12], sum = 68 +// Row 5 = [6,7,8,9,10,11,12,13], sum = 76 +// Row 6 = [7,8,9,10,11,12,13,14], sum = 84 + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +// Packed B (4x8): all ones for easy verification +memref.global "private" @matB : memref<4x8xi8> = dense<[ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1] +]> + +// With slide=3: Expected C = [[60,60,60,60], [68,68,68,68], [76,76,76,76], [84,84,84,84]] + +func.func private @print_row(i32, i32, i32, i32, i32) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<4x8xi8> + + %c = memref.alloc() : memref<4x4xi32> + + // Initialize to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform vmadot3 (fixed slide=3) + ime.vmadot3 %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<4x8xi8> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xi32> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xi32> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xi32> + %v03 = memref.load %c[%c0, %c3] : 
memref<4x4xi32> + call @print_row(%i0, %v00, %v01, %v02, %v03) : (i32, i32, i32, i32, i32) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xi32> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xi32> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xi32> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xi32> + call @print_row(%i1, %v10, %v11, %v12, %v13) : (i32, i32, i32, i32, i32) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xi32> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xi32> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xi32> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xi32> + call @print_row(%i2, %v20, %v21, %v22, %v23) : (i32, i32, i32, i32, i32) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xi32> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xi32> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xi32> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xi32> + call @print_row(%i3, %v30, %v31, %v32, %v33) : (i32, i32, i32, i32, i32) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot3su.mlir b/examples/IMEDialect/vmadot3su.mlir new file mode 100644 index 0000000000..997ba46b0f --- /dev/null +++ b/examples/IMEDialect/vmadot3su.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot3su operation (slide=3, signed × unsigned). +// vmadot3su performs: C += slide(A, 3) × B where A is signed, B is unsigned int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 3 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 3 - signed int8 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform signed × unsigned matrix multiply-accumulate with slide=3 + // CHECK: ime.vmadot3su + ime.vmadot3su %c, %a, %b : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot3u.mlir b/examples/IMEDialect/vmadot3u.mlir new file mode 100644 index 0000000000..b891bf46be --- /dev/null +++ b/examples/IMEDialect/vmadot3u.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot3u operation (slide=3, unsigned × unsigned). +// vmadot3u performs: C += slide(A, 3) × B where A, B are unsigned int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 3 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 3 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform unsigned × unsigned matrix multiply-accumulate with slide=3 + // CHECK: ime.vmadot3u + ime.vmadot3u %c, %a, %b : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot3us.mlir b/examples/IMEDialect/vmadot3us.mlir new file mode 100644 index 0000000000..29781008ae --- /dev/null +++ b/examples/IMEDialect/vmadot3us.mlir @@ -0,0 +1,53 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadot3us operation (slide=3, unsigned × signed). +// vmadot3us performs: C += slide(A, 3) × B where A is unsigned, B is signed int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by 3 rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) starting from row 3 - unsigned int8 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Perform unsigned × signed matrix multiply-accumulate with slide=3 + // CHECK: ime.vmadot3us + ime.vmadot3us %c, %a, %b : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadot_print_test.mlir b/examples/IMEDialect/vmadot_print_test.mlir index f5cd92ed7e..c4d2c07e84 100644 --- a/examples/IMEDialect/vmadot_print_test.mlir +++ b/examples/IMEDialect/vmadot_print_test.mlir @@ -1,4 +1,5 @@ -// IME vmadot test: signed × signed matrix multiply-accumulate +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main // // vmadot computes: C[i,j] += sum_k(signed(A[i,k]) * signed(B[j,k])) // diff --git a/examples/IMEDialect/vmadotn.mlir b/examples/IMEDialect/vmadotn.mlir new file mode 100644 index 0000000000..9bd8fe9285 --- /dev/null +++ b/examples/IMEDialect/vmadotn.mlir @@ -0,0 +1,56 @@ +// RUN: buddy-opt %s | 
FileCheck %s + +// This example demonstrates the IME vmadotn operation with dynamic slide parameter. +// vmadotn performs: C += slide(A, n) × B where A, B are signed int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by n rows. +// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) - signed int8 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter (0-3) + %slide = arith.constant 1 : i64 + + // Perform signed × signed matrix multiply-accumulate with dynamic slide + // CHECK: ime.vmadotn + ime.vmadotn %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotn_print_test.mlir b/examples/IMEDialect/vmadotn_print_test.mlir new file mode 100644 index 0000000000..436933c579 --- /dev/null +++ b/examples/IMEDialect/vmadotn_print_test.mlir @@ -0,0 +1,99 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vmadotn computes: C[i,j] += sum_k(signed(A[slide+i,k]) * 
signed(B[j,k])) +// +// Sliding window reads 64 elements from VS1 (8 rows), then slides by n rows. +// A (8x8): signed int8, source matrix (needs 64 elements for sliding) +// B (4x8): signed int8, packed form +// +// With slide=1: +// A rows used = [1,2,3,4] (after sliding by 1) +// Row 1 = [2,3,4,5,6,7,8,9], sum = 44 +// Row 2 = [3,4,5,6,7,8,9,10], sum = 52 +// Row 3 = [4,5,6,7,8,9,10,11], sum = 60 +// Row 4 = [5,6,7,8,9,10,11,12], sum = 68 + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +// Packed B (4x8): all ones for easy verification +memref.global "private" @matB : memref<4x8xi8> = dense<[ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1] +]> + +// With slide=1: Expected C = [[44,44,44,44], [52,52,52,52], [60,60,60,60], [68,68,68,68]] + +func.func private @print_row(i32, i32, i32, i32, i32) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<4x8xi8> + + %c = memref.alloc() : memref<4x4xi32> + + // Initialize to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter = 1 + %slide = arith.constant 1 : i64 + + // Perform vmadotn with slide=1 + ime.vmadotn %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xi8>, memref<4x8xi8> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : 
memref<4x4xi32> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xi32> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xi32> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xi32> + call @print_row(%i0, %v00, %v01, %v02, %v03) : (i32, i32, i32, i32, i32) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xi32> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xi32> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xi32> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xi32> + call @print_row(%i1, %v10, %v11, %v12, %v13) : (i32, i32, i32, i32, i32) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xi32> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xi32> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xi32> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xi32> + call @print_row(%i2, %v20, %v21, %v22, %v23) : (i32, i32, i32, i32, i32) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xi32> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xi32> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xi32> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xi32> + call @print_row(%i3, %v30, %v31, %v32, %v33) : (i32, i32, i32, i32, i32) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotnsu.mlir b/examples/IMEDialect/vmadotnsu.mlir new file mode 100644 index 0000000000..5d9c547eaa --- /dev/null +++ b/examples/IMEDialect/vmadotnsu.mlir @@ -0,0 +1,56 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadotnsu operation with dynamic slide parameter. +// vmadotnsu performs: C += slide(A, n) × B where A is signed, B is unsigned int8. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by n rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) - signed int8 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [-1, -2, -3, -4, -5, -6, -7, -8], + [-2, -3, -4, -5, -6, -7, -8, -9], + [-3, -4, -5, -6, -7, -8, -9, -10], + [-4, -5, -6, -7, -8, -9, -10, -11], + [-5, -6, -7, -8, -9, -10, -11, -12], + [-6, -7, -8, -9, -10, -11, -12, -13], + [-7, -8, -9, -10, -11, -12, -13, -14], + [-8, -9, -10, -11, -12, -13, -14, -15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter (0-3) + %slide = arith.constant 1 : i64 + + // Perform signed × unsigned matrix multiply-accumulate with dynamic slide + // CHECK: ime.vmadotnsu + ime.vmadotnsu %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xi8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotnsu_print_test.mlir b/examples/IMEDialect/vmadotnsu_print_test.mlir new file mode 100644 index 0000000000..a2998051ce --- /dev/null +++ b/examples/IMEDialect/vmadotnsu_print_test.mlir @@ -0,0 +1,99 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vmadotnsu computes: C[i,j] += sum_k(signed(A[slide+i,k]) * unsigned(B[j,k])) +// +// Sliding window reads 64 elements from VS1 (8 rows), then slides by n rows. 
+// A (8x8): signed int8, source matrix (negative values) +// B (4x8): unsigned int8, packed form +// +// With slide=1: +// A rows used = [1,2,3,4] (after sliding by 1) +// Row 1 = [-2,-3,-4,-5,-6,-7,-8,-9], sum = -44 +// Row 2 = [-3,-4,-5,-6,-7,-8,-9,-10], sum = -52 +// Row 3 = [-4,-5,-6,-7,-8,-9,-10,-11], sum = -60 +// Row 4 = [-5,-6,-7,-8,-9,-10,-11,-12], sum = -68 + +memref.global "private" @matA : memref<8x8xi8> = dense<[ + [-1, -2, -3, -4, -5, -6, -7, -8], + [-2, -3, -4, -5, -6, -7, -8, -9], + [-3, -4, -5, -6, -7, -8, -9, -10], + [-4, -5, -6, -7, -8, -9, -10, -11], + [-5, -6, -7, -8, -9, -10, -11, -12], + [-6, -7, -8, -9, -10, -11, -12, -13], + [-7, -8, -9, -10, -11, -12, -13, -14], + [-8, -9, -10, -11, -12, -13, -14, -15] +]> + +// Packed B (4x8): all ones (unsigned) +memref.global "private" @matB : memref<4x8xui8> = dense<[ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1] +]> + +// With slide=1: Expected C = [[-44,-44,-44,-44], [-52,-52,-52,-52], [-60,-60,-60,-60], [-68,-68,-68,-68]] + +func.func private @print_row(i32, i32, i32, i32, i32) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x8xi8> + %b = memref.get_global @matB : memref<4x8xui8> + + %c = memref.alloc() : memref<4x4xi32> + + // Initialize to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter = 1 + %slide = arith.constant 1 : i64 + + // Perform vmadotnsu with slide=1 + ime.vmadotnsu %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xi8>, memref<4x8xui8> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : 
memref<4x4xi32> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xi32> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xi32> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xi32> + call @print_row(%i0, %v00, %v01, %v02, %v03) : (i32, i32, i32, i32, i32) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xi32> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xi32> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xi32> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xi32> + call @print_row(%i1, %v10, %v11, %v12, %v13) : (i32, i32, i32, i32, i32) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xi32> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xi32> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xi32> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xi32> + call @print_row(%i2, %v20, %v21, %v22, %v23) : (i32, i32, i32, i32, i32) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xi32> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xi32> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xi32> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xi32> + call @print_row(%i3, %v30, %v31, %v32, %v33) : (i32, i32, i32, i32, i32) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotnu.mlir b/examples/IMEDialect/vmadotnu.mlir new file mode 100644 index 0000000000..077acb14e5 --- /dev/null +++ b/examples/IMEDialect/vmadotnu.mlir @@ -0,0 +1,56 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadotnu operation with dynamic slide parameter. +// vmadotnu performs: C += slide(A, n) × B where A, B are unsigned int8 matrices. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by n rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) - unsigned int8 +// B: 8×4 (K×N) - unsigned int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xui8> = dense<[ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xui8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter (0-3) + %slide = arith.constant 2 : i64 + + // Perform unsigned × unsigned matrix multiply-accumulate with dynamic slide + // CHECK: ime.vmadotnu + ime.vmadotnu %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xui8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotnu_print_test.mlir b/examples/IMEDialect/vmadotnu_print_test.mlir new file mode 100644 index 0000000000..3a973804d5 --- /dev/null +++ b/examples/IMEDialect/vmadotnu_print_test.mlir @@ -0,0 +1,99 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vmadotnu computes: C[i,j] += sum_k(unsigned(A[slide+i,k]) * unsigned(B[j,k])) +// +// Sliding window reads 64 elements from VS1 (8 rows), then slides by n rows. 
+// A (8x8): unsigned int8, source matrix +// B (4x8): unsigned int8, packed form +// +// With slide=2: +// A rows used = [2,3,4,5] (after sliding by 2) +// Row 2 = [3,4,5,6,7,8,9,10], sum = 52 +// Row 3 = [4,5,6,7,8,9,10,11], sum = 60 +// Row 4 = [5,6,7,8,9,10,11,12], sum = 68 +// Row 5 = [6,7,8,9,10,11,12,13], sum = 76 + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +// Packed B (4x8): all ones for easy verification +memref.global "private" @matB : memref<4x8xui8> = dense<[ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1] +]> + +// With slide=2: Expected C = [[52,52,52,52], [60,60,60,60], [68,68,68,68], [76,76,76,76]] + +func.func private @print_row(i32, i32, i32, i32, i32) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<4x8xui8> + + %c = memref.alloc() : memref<4x4xi32> + + // Initialize to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter = 2 + %slide = arith.constant 2 : i64 + + // Perform vmadotnu with slide=2 + ime.vmadotnu %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xui8>, memref<4x8xui8> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xi32> + %v01 = memref.load %c[%c0, %c1] : memref<4x4xi32> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xi32> + 
%v03 = memref.load %c[%c0, %c3] : memref<4x4xi32> + call @print_row(%i0, %v00, %v01, %v02, %v03) : (i32, i32, i32, i32, i32) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xi32> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xi32> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xi32> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xi32> + call @print_row(%i1, %v10, %v11, %v12, %v13) : (i32, i32, i32, i32, i32) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xi32> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xi32> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xi32> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xi32> + call @print_row(%i2, %v20, %v21, %v22, %v23) : (i32, i32, i32, i32, i32) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xi32> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xi32> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xi32> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xi32> + call @print_row(%i3, %v30, %v31, %v32, %v33) : (i32, i32, i32, i32, i32) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotnus.mlir b/examples/IMEDialect/vmadotnus.mlir new file mode 100644 index 0000000000..e8ed037c22 --- /dev/null +++ b/examples/IMEDialect/vmadotnus.mlir @@ -0,0 +1,56 @@ +// RUN: buddy-opt %s | FileCheck %s + +// This example demonstrates the IME vmadotnus operation with dynamic slide parameter. +// vmadotnus performs: C += slide(A, n) × B where A is unsigned, B is signed int8. +// +// Sliding window: reads from VS1 and VS1+1 (64 elements), slides by n rows. 
+// Matrix dimensions for VLEN=256, SEW=8: +// A: 8×8 (2M×K) source, sliding selects 4×8 (M×K) - unsigned int8 +// B: 8×4 (K×N) - signed int8 +// C: 4×4 (M×N) - int32 accumulator + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +memref.global "private" @matB : memref<8x4xi8> = dense<[ + [-1, -1, -1, -1], + [-1, -1, -1, -1], + [-1, -1, -1, -1], + [-1, -1, -1, -1], + [-1, -1, -1, -1], + [-1, -1, -1, -1], + [-1, -1, -1, -1], + [-1, -1, -1, -1] +]> + +func.func @main() -> i32 { + // Get input matrices + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<8x4xi8> + + // Allocate output matrix (accumulator) + %c = memref.alloc() : memref<4x4xi32> + + // Initialize accumulator to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter (0-3) + %slide = arith.constant 3 : i64 + + // Perform unsigned × signed matrix multiply-accumulate with dynamic slide + // CHECK: ime.vmadotnus + ime.vmadotnus %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xui8>, memref<8x4xi8> + + // Return success + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotnus_print_test.mlir b/examples/IMEDialect/vmadotnus_print_test.mlir new file mode 100644 index 0000000000..230140ca98 --- /dev/null +++ b/examples/IMEDialect/vmadotnus_print_test.mlir @@ -0,0 +1,99 @@ +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main +// +// vmadotnus computes: C[i,j] += sum_k(unsigned(A[slide+i,k]) * signed(B[j,k])) +// +// Sliding window reads 64 elements from VS1 (8 rows), then slides by n rows. 
+// A (8x8): unsigned int8, source matrix +// B (4x8): signed int8, packed form (negative values) +// +// With slide=3: +// A rows used = [3,4,5,6] (after sliding by 3) +// Row 3 = [4,5,6,7,8,9,10,11], sum with B=-1 each: -60 +// Row 4 = [5,6,7,8,9,10,11,12], sum = -68 +// Row 5 = [6,7,8,9,10,11,12,13], sum = -76 +// Row 6 = [7,8,9,10,11,12,13,14], sum = -84 + +memref.global "private" @matA : memref<8x8xui8> = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8], + [2, 3, 4, 5, 6, 7, 8, 9], + [3, 4, 5, 6, 7, 8, 9, 10], + [4, 5, 6, 7, 8, 9, 10, 11], + [5, 6, 7, 8, 9, 10, 11, 12], + [6, 7, 8, 9, 10, 11, 12, 13], + [7, 8, 9, 10, 11, 12, 13, 14], + [8, 9, 10, 11, 12, 13, 14, 15] +]> + +// Packed B (4x8): all -1 (signed) +memref.global "private" @matB : memref<4x8xi8> = dense<[ + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1, -1, -1] +]> + +// With slide=3: Expected C = [[-60,-60,-60,-60], [-68,-68,-68,-68], [-76,-76,-76,-76], [-84,-84,-84,-84]] + +func.func private @print_row(i32, i32, i32, i32, i32) +func.func private @print_header() + +func.func @main() -> i32 { + %a = memref.get_global @matA : memref<8x8xui8> + %b = memref.get_global @matB : memref<4x8xi8> + + %c = memref.alloc() : memref<4x4xi32> + + // Initialize to zero + %zero = arith.constant 0 : i32 + linalg.fill ins(%zero : i32) outs(%c : memref<4x4xi32>) + + // Slide parameter = 3 + %slide = arith.constant 3 : i64 + + // Perform vmadotnus with slide=3 + ime.vmadotnus %c, %a, %b, %slide : memref<4x4xi32>, memref<8x8xui8>, memref<4x8xi8> + + // Print results + call @print_header() : () -> () + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %i0 = arith.constant 0 : i32 + %i1 = arith.constant 1 : i32 + %i2 = arith.constant 2 : i32 + %i3 = arith.constant 3 : i32 + + // Row 0 + %v00 = memref.load %c[%c0, %c0] : memref<4x4xi32> + %v01 = memref.load %c[%c0, %c1] : 
memref<4x4xi32> + %v02 = memref.load %c[%c0, %c2] : memref<4x4xi32> + %v03 = memref.load %c[%c0, %c3] : memref<4x4xi32> + call @print_row(%i0, %v00, %v01, %v02, %v03) : (i32, i32, i32, i32, i32) -> () + + // Row 1 + %v10 = memref.load %c[%c1, %c0] : memref<4x4xi32> + %v11 = memref.load %c[%c1, %c1] : memref<4x4xi32> + %v12 = memref.load %c[%c1, %c2] : memref<4x4xi32> + %v13 = memref.load %c[%c1, %c3] : memref<4x4xi32> + call @print_row(%i1, %v10, %v11, %v12, %v13) : (i32, i32, i32, i32, i32) -> () + + // Row 2 + %v20 = memref.load %c[%c2, %c0] : memref<4x4xi32> + %v21 = memref.load %c[%c2, %c1] : memref<4x4xi32> + %v22 = memref.load %c[%c2, %c2] : memref<4x4xi32> + %v23 = memref.load %c[%c2, %c3] : memref<4x4xi32> + call @print_row(%i2, %v20, %v21, %v22, %v23) : (i32, i32, i32, i32, i32) -> () + + // Row 3 + %v30 = memref.load %c[%c3, %c0] : memref<4x4xi32> + %v31 = memref.load %c[%c3, %c1] : memref<4x4xi32> + %v32 = memref.load %c[%c3, %c2] : memref<4x4xi32> + %v33 = memref.load %c[%c3, %c3] : memref<4x4xi32> + call @print_row(%i3, %v30, %v31, %v32, %v33) : (i32, i32, i32, i32, i32) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/IMEDialect/vmadotsu_print_test.mlir b/examples/IMEDialect/vmadotsu_print_test.mlir index 8f4c5541da..185ae50370 100644 --- a/examples/IMEDialect/vmadotsu_print_test.mlir +++ b/examples/IMEDialect/vmadotsu_print_test.mlir @@ -1,4 +1,5 @@ -// IME vmadotsu test: signed × unsigned matrix multiply-accumulate +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main // // vmadotsu computes: C[i,j] += sum_k(signed(A[i,k]) * unsigned(B[j,k])) // diff --git a/examples/IMEDialect/vmadotu_print_test.mlir b/examples/IMEDialect/vmadotu_print_test.mlir index f7a72cee28..c5ff98284d 100644 --- a/examples/IMEDialect/vmadotu_print_test.mlir +++ b/examples/IMEDialect/vmadotu_print_test.mlir @@ -1,4 +1,5 @@ -// IME vmadotu test: unsigned × unsigned matrix multiply-accumulate +// RUN: buddy-opt %s | FileCheck %s +// 
CHECK: func.func @main // // vmadotu computes: C[i,j] += sum_k(unsigned(A[i,k]) * unsigned(B[j,k])) // diff --git a/examples/IMEDialect/vmadotus_print_test.mlir b/examples/IMEDialect/vmadotus_print_test.mlir index b5a70fc51d..b171280341 100644 --- a/examples/IMEDialect/vmadotus_print_test.mlir +++ b/examples/IMEDialect/vmadotus_print_test.mlir @@ -1,4 +1,5 @@ -// IME vmadotus test: unsigned × signed matrix multiply-accumulate +// RUN: buddy-opt %s | FileCheck %s +// CHECK: func.func @main // // vmadotus computes: C[i,j] += sum_k(unsigned(A[i,k]) * signed(B[j,k])) // diff --git a/examples/lit.cfg.py b/examples/lit.cfg.py index 1924ffcbbb..465ab94348 100644 --- a/examples/lit.cfg.py +++ b/examples/lit.cfg.py @@ -78,11 +78,6 @@ "log.mlir", "lit.cfg.py", "BuddyPython", - "vmadot_print_test.mlir", - "vmadotu_print_test.mlir", - "vmadotsu_print_test.mlir", - "vmadotus_print_test.mlir", - "vfmadot_print_test.mlir", ] config.buddy_tools_dir = os.path.join(config.buddy_obj_root, "bin") diff --git a/frontend/Interfaces/lib/CMakeLists.txt b/frontend/Interfaces/lib/CMakeLists.txt index 99b0b9d81b..13a2ee3229 100644 --- a/frontend/Interfaces/lib/CMakeLists.txt +++ b/frontend/Interfaces/lib/CMakeLists.txt @@ -47,6 +47,10 @@ SET_TARGET_PROPERTIES(BuddyLibDIP PROPERTIES ARCHIVE_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_DIRECTORY} ) +install(TARGETS BuddyLibDIP + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + #------------------------------------------------------------------------------- # Generate Buddy DAP Library: BuddyLibDAP #------------------------------------------------------------------------------- @@ -147,3 +151,7 @@ SET_TARGET_PROPERTIES(BuddyLibDAP PROPERTIES LINKER_LANGUAGE CXX ARCHIVE_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_DIRECTORY} ) + +install(TARGETS BuddyLibDAP + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/frontend/Python/CMakeLists.txt b/frontend/Python/CMakeLists.txt index 64cfbd471e..e71854f963 100644 --- a/frontend/Python/CMakeLists.txt +++ 
b/frontend/Python/CMakeLists.txt @@ -1,11 +1,39 @@ +# Find the Python interpreter and module development components, +# requiring a minimum version of 3.10 +find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION} REQUIRED COMPONENTS Interpreter Development.Module) + +set(BUDDY_MLIR_PYTHON_PACKAGES_DIR ${CMAKE_BINARY_DIR}/python_packages) + +# Create directories for the BUDDY-MLIR Python packages +file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy) +file(MAKE_DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler) +# Create empty __init__.py files to make these directories Python packages +file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/__init__.py "") +file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/__init__.py "") + # Recursively retrieve all python files from the current directory. file(GLOB_RECURSE ALL_PY_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.py") +set(PYTHON_OUTPUT_FILES "") foreach(FILE ${ALL_PY_FILES}) - # Get the directory of the current file. - get_filename_component(DIR "${FILE}" DIRECTORY) - # Set the destination directory for the target file. - set(DEST "${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/${DIR}") - # Copy the file into the destination directory. 
- file(COPY ${FILE} DESTINATION ${DEST}) + get_filename_component(REL_DIR "${FILE}" DIRECTORY) + set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/${FILE}") + set(DEST_DIR "${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/${REL_DIR}") + set(DEST_FILE "${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/${FILE}") + + add_custom_command( + OUTPUT "${DEST_FILE}" + COMMAND ${CMAKE_COMMAND} -E make_directory "${DEST_DIR}" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${SRC}" "${DEST_FILE}" + DEPENDS "${SRC}" + COMMENT "Syncing ${FILE}" + VERBATIM + ) + + list(APPEND PYTHON_OUTPUT_FILES "${DEST_FILE}") endforeach() + +add_custom_target(python-package-buddy + ALL DEPENDS ${PYTHON_OUTPUT_FILES} + COMMENT "Syncing python package files" +) diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index ff21d1ed3a..5fc58e8dbc 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -29,11 +29,11 @@ import platform import numpy as np -import mlir.ir as ir -import mlir.dialects.func as func -from mlir.passmanager import * -from mlir.execution_engine import * -from mlir import runtime as rt +import buddy_mlir.ir as ir +import buddy_mlir.dialects.func as func +from buddy_mlir.passmanager import * +from buddy_mlir.execution_engine import * +from buddy_mlir import runtime as rt import torch import torch._dynamo as dynamo from torch._functorch.aot_autograd import aot_module_simplified @@ -873,9 +873,6 @@ def _compile_fx( def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): """Compile a FX graph in Aten/Prims IR to MLIR.""" - num_cached_kv = 0 - if self._model_config.decode_with_cache: - num_cached_kv = self._model_config.num_hidden_layers * 2 graph = Graph( self._ops_registry, self._func_name, @@ -888,9 +885,8 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): buffers_nodes = [] input_nodes = [] other_nodes = [] - for i, node in enumerate( - list(_gm.graph.nodes)[num_cached_kv:], start=0 - ): + all_nodes = 
list(_gm.graph.nodes) + for i, node in enumerate(all_nodes): if i in params_pos: param_nodes.append(node) elif i in buffers_pos: @@ -899,12 +895,11 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): input_nodes.append(node) else: other_nodes.append(node) - input_nodes.extend(list(_gm.graph.nodes)[:num_cached_kv]) gm_nodes = [ (NodeType.FakeNode, param_nodes), (NodeType.FakeNode, buffers_nodes), (NodeType.InputNode, input_nodes), - (NodeType.OtherNode, other_nodes) + (NodeType.OtherNode, other_nodes), ] for node_type, gm_nodes_sublist in gm_nodes: @@ -980,7 +975,9 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): gm_node.insert_arg(len(gm_node.args), value) val = gm_node.meta.get("val") node_shape = val.shape - node_dtype = self._torch_dtype_translate(str(val.dtype)) + node_dtype = self._torch_dtype_translate( + str(val.dtype) + ) buddy_node = self._create_node( "_tensor_constant", gm_node.name, @@ -1013,11 +1010,15 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): elif num_returns > 1: node_dtype = tuple( [ - self._torch_dtype_translate(str(val_item.dtype)) + self._torch_dtype_translate( + str(val_item.dtype) + ) for val_item in val ] ) - node_shape = tuple([val_item.shape for val_item in val]) + node_shape = tuple( + [val_item.shape for val_item in val] + ) else: raise RuntimeError("Zero returns is not supported.") diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index 06ed4793eb..bec1dd378c 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -25,11 +25,11 @@ import functools import numpy as np -import mlir.ir as ir -import mlir.dialects.func as func -from mlir.passmanager import * -from mlir.execution_engine import * -from mlir import runtime as rt +import buddy_mlir.ir as ir +import buddy_mlir.dialects.func as func +from buddy_mlir.passmanager import * +from buddy_mlir.execution_engine import * +from buddy_mlir import runtime as rt 
from .operation import * from .type import * @@ -68,11 +68,13 @@ class OutputDescriptor(ctypes.Structure): return OutputDescriptor + class NodeType(Enum): FakeNode = auto() InputNode = auto() OtherNode = auto() + class Graph: """ Graph is a graph-level expression for the Buddy Compiler frontends. @@ -181,7 +183,7 @@ def add_node(self, node: Op, node_type: NodeType = NodeType.OtherNode): def get_input(self, i): return self._body[self._inputs[i]] - @property + @property def inputs(self) -> list[Op]: return [self.get_input(i) for i in range(len(self._inputs))] @@ -195,10 +197,12 @@ def inputs_shapes(self) -> list[TensorMeta]: if isinstance(input_tm_dict, TensorMeta): tm_list.append(input_tm_dict) continue - tm_list.append(TensorMeta( - shape=input_tm_dict["shape"], - dtype=input_tm_dict["dtype"], - )) + tm_list.append( + TensorMeta( + shape=input_tm_dict["shape"], + dtype=input_tm_dict["dtype"], + ) + ) return tm_list @@ -208,7 +212,7 @@ def get_fake_params(self, i): @property def params(self) -> list[Op]: return [self.get_fake_params(i) for i in range(len(self._fake_params))] - + @property def params_shapes(self) -> list[TensorMeta]: tm_list = [] @@ -219,10 +223,12 @@ def params_shapes(self) -> list[TensorMeta]: if isinstance(param_tm_dict, TensorMeta): tm_list.append(param_tm_dict) continue - tm_list.append(TensorMeta( - shape=param_tm_dict["shape"], - dtype=param_tm_dict["dtype"], - )) + tm_list.append( + TensorMeta( + shape=param_tm_dict["shape"], + dtype=param_tm_dict["dtype"], + ) + ) return tm_list @@ -270,7 +276,7 @@ def delete_node(self, node: Op, parents: List[Op]): for i, ref_idx in enumerate(self._fake_params): if ref_idx > node_idx: self._fake_params[i] -= 1 - + for i in parents: i._children.remove(node.name) node.args.clear() @@ -355,7 +361,9 @@ def displace_node_with_chain(self, node: Op, chain: list[Op]): node_idx = self._body.index(node) self._body = self.body[:node_idx] + chain + self.body[node_idx + 1 :] - def replace_as_child(self, parent_ops: 
list[Op] | Op, child_op: Op, new_op: Op): + def replace_as_child( + self, parent_ops: list[Op] | Op, child_op: Op, new_op: Op + ): """ Replace `child_op`, a child of the `parent_ops` with `new_op`. @@ -373,15 +381,19 @@ def replace_as_child(self, parent_ops: list[Op] | Op, child_op: Op, new_op: Op): for parent_name in parent_ops: parent_op = self.node_table[parent_name] - parent_op._children[parent_op._children.index(child_name)] = new_child_name - - def replace_as_parent(self, parent_op: Op, child_ops: list[Op] | Op, new_op: Op): + parent_op._children[parent_op._children.index(child_name)] = ( + new_child_name + ) + + def replace_as_parent( + self, parent_op: Op, child_ops: list[Op] | Op, new_op: Op + ): """ Replace `parent_op` with `new_op` as the the parent node of the `child_ops` list. Args: parent_op (Op): Parent to replace - child_ops (list[Op]): Child ops for which replace `parent_op` as their + child_ops (list[Op]): Child ops for which replace `parent_op` as their new_op (Op): op to replace `parent_op` with """ @@ -395,10 +407,14 @@ def replace_as_parent(self, parent_op: Op, child_ops: list[Op] | Op, new_op: Op) child_op = self.node_table[child_name] if parent_name in child_op._parents: - child_op._parents[child_op._parents.index(parent_name)] = new_parent_name + child_op._parents[child_op._parents.index(parent_name)] = ( + new_parent_name + ) if parent_name in child_op._arguments: - child_op._arguments[child_op._arguments.index(parent_name)] = new_parent_name + child_op._arguments[child_op._arguments.index(parent_name)] = ( + new_parent_name + ) def init_op_group(self): """ @@ -585,7 +601,7 @@ class GraphImporter: _func_name (str): Name of the generated MLIR function. _inputs (List[TensorMeta]): Input tensor(s) of the FX graph. _num_input_visited (int): Number of input nodes that have been visited. - _module (mlir.ir.Module): The generated MLIR module. + _module (buddy_mlir.ir.Module): The generated MLIR module. 
_ops_registry (dict): Registry for the candidate operations. """ @@ -636,7 +652,7 @@ def _str_to_mlir_dtype(self, dtype: str) -> ir.Type: dtype (str): The tensor type. Returns: - mlir.ir.Type: The corresponding MLIR data type. + buddy_mlir.ir.Type: The corresponding MLIR data type. Raises: NotImplementedError: If the given dtype is not supported. @@ -699,7 +715,7 @@ def import_graph(self) -> ir.Module: Imports buddy graph and generates an MLIR module in high-level dialects. Returns: - mlir.ir.Module: An MLIR module in high-level dialects. + buddy_mlir.ir.Module: An MLIR module in high-level dialects. """ assert self._do_param_pack == False with ir.InsertionPoint(self._module.body): @@ -774,7 +790,7 @@ def import_main_graph(self) -> ir.Module: module in high-level dialects with memref. Returns: - mlir.ir.Module: An MLIR module in high-level dialects. + buddy_mlir.ir.Module: An MLIR module in high-level dialects. """ with ir.InsertionPoint(self._module.body): arguments = [] @@ -833,12 +849,15 @@ def _import_placeholder( Parameters: - node (PlaceholderOp): The PlaceholderOp node representing the placeholder. - - args_list (List[mlir.ir.BlockArgument]): List of input memrefs. + - args_list (List[buddy_mlir.ir.BlockArgument]): List of input memrefs. 
Returns: None """ - if self._num_input_visited < len(self._params_shapes) and self._do_param_pack: + if ( + self._num_input_visited < len(self._params_shapes) + and self._do_param_pack + ): dtype = node.tensor_meta["dtype"] pack_of_dtype = None for pack in args_list: diff --git a/frontend/Python/graph/graph_driver.py b/frontend/Python/graph/graph_driver.py index 8d4dd12c3d..cdebb2402d 100644 --- a/frontend/Python/graph/graph_driver.py +++ b/frontend/Python/graph/graph_driver.py @@ -20,7 +20,7 @@ # # ===--------------------------------------------------------------------------- -from mlir import ir +from buddy_mlir import ir from collections import deque, defaultdict from .graph import Graph, GraphImporter, TensorMeta, NodeType @@ -154,7 +154,7 @@ def build_subgraph_by_group(self): for output in subgraphs_outputs[subgraph_name]: output_node.add_argument(output) output_node.add_parent(output) - + subgraph.add_node(node=output_node, node_type=NodeType.OtherNode) for op in subgraph._body: @@ -217,7 +217,7 @@ def construct_main_graph(self, do_param_pack=False): verbose=self._graph._verbose, ) - # Adding placeholder operations from the original graph + # Adding placeholder operations from the original graph for op in self._graph.params: main_graph.add_node(op, node_type=NodeType.FakeNode) for op in self._graph.inputs: diff --git a/frontend/Python/graph/transform/eliminate_weight_transpose.py b/frontend/Python/graph/transform/eliminate_weight_transpose.py index a3b55c6aaf..69f341e17a 100644 --- a/frontend/Python/graph/transform/eliminate_weight_transpose.py +++ b/frontend/Python/graph/transform/eliminate_weight_transpose.py @@ -174,14 +174,28 @@ def eliminate_transpose(graph: Graph): else: param_tensor_data = param_tensor.detach().clone() + if list(param_tensor_data.shape) != current_shape: + continue + if transpose_info["type"] == "t": + if param_tensor_data.dim() != 2: + continue param_tensor_data = param_tensor_data.T elif transpose_info["type"] == "transpose": dim1, 
dim2 = transpose_info["dims"] + if ( + dim1 >= param_tensor_data.dim() + or dim2 >= param_tensor_data.dim() + ): + continue param_tensor_data = param_tensor_data.swapaxes( dim1, dim2 ) elif transpose_info["type"] == "permute": + if param_tensor_data.dim() != len( + transpose_info["perm"] + ): + continue param_tensor_data = param_tensor_data.permute( transpose_info["perm"] ) diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index 7619423003..3e911678f4 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -208,53 +208,49 @@ def gqa_attention_fusion_check(graph: Graph): ): continue - # trace Key branch: View <- Clone <- Expand <- slice1 <- slice2 <- unsqueeze + # trace Key branch for torch2.10: + # View <- Clone <- Expand <- Unsqueeze <- IndexPut k_clone = graph.node_table.get(k_view_node._parents[0], None) if not isinstance(k_clone, CloneOp): continue k_expand = graph.node_table.get(k_clone._parents[0], None) if not isinstance(k_expand, ExpandOp): continue - k_slice1 = graph.node_table.get(k_expand._parents[0], None) - if not isinstance(k_slice1, SliceOp): - continue - k_slice2 = graph.node_table.get(k_slice1._parents[0], None) - if not isinstance(k_slice2, SliceOp): - continue - k_cache_unsqueeze = graph.node_table.get(k_slice2._parents[0], None) + k_cache_unsqueeze = graph.node_table.get(k_expand._parents[0], None) if not isinstance(k_cache_unsqueeze, UnsqueezeOp): continue + k_index_put = graph.node_table.get( + k_cache_unsqueeze._parents[0], None + ) + if not isinstance(k_index_put, IndexPutOp): + continue - # trace Value branch: View <- Clone <- Expand <- slice1 <- slice2 <- unsqueeze + # trace Value branch for torch2.10: + # View <- Clone <- Expand <- Unsqueeze <- IndexPut v_clone = graph.node_table.get(v_view_node._parents[0], None) if not isinstance(v_clone, CloneOp): continue v_expand = graph.node_table.get(v_clone._parents[0], None) if not 
isinstance(v_expand, ExpandOp): continue - v_slice1 = graph.node_table.get(v_expand._parents[0], None) - if not isinstance(v_slice1, SliceOp): - continue - v_slice2 = graph.node_table.get(v_slice1._parents[0], None) - if not isinstance(v_slice2, SliceOp): - continue - v_cache_unsqueeze = graph.node_table.get(v_slice2._parents[0], None) + v_cache_unsqueeze = graph.node_table.get(v_expand._parents[0], None) if not isinstance(v_cache_unsqueeze, UnsqueezeOp): continue + v_index_put = graph.node_table.get( + v_cache_unsqueeze._parents[0], None + ) + if not isinstance(v_index_put, IndexPutOp): + continue replace_gqa_attention_with_fused_op( graph, op, k_view_node, k_clone, k_expand, - k_slice1, - k_slice2, k_cache_unsqueeze, v_view_node, v_clone, v_expand, - v_slice1, - v_slice2, v_cache_unsqueeze, "gqa_attention_fusion", ) @@ -266,14 +262,10 @@ def replace_gqa_attention_with_fused_op( k_view: Op, k_clone: Op, k_expand: Op, - k_slice1: Op, - k_slice2: Op, k_cache_unsqueeze: Op, v_view: Op, v_clone: Op, v_expand: Op, - v_slice1: Op, - v_slice2: Op, v_cache_unsqueeze: Op, pattern: str, ): @@ -309,11 +301,7 @@ def replace_gqa_attention_with_fused_op( if graph.check_delete_node(k_clone): graph.delete_node(k_clone, [k_expand]) if graph.check_delete_node(k_expand): - graph.delete_node(k_expand, [k_slice1]) - if graph.check_delete_node(k_slice1): - graph.delete_node(k_slice1, [k_slice2]) - if graph.check_delete_node(k_slice2): - graph.delete_node(k_slice2, [k_cache_unsqueeze]) + graph.delete_node(k_expand, [k_cache_unsqueeze]) if graph.check_delete_node(k_cache_unsqueeze): k_orig_parents = [ graph.node_table.get(p, None) for p in k_cache_unsqueeze._parents @@ -326,11 +314,7 @@ def replace_gqa_attention_with_fused_op( if graph.check_delete_node(v_clone): graph.delete_node(v_clone, [v_expand]) if graph.check_delete_node(v_expand): - graph.delete_node(v_expand, [v_slice1]) - if graph.check_delete_node(v_slice1): - graph.delete_node(v_slice1, [v_slice2]) - if 
graph.check_delete_node(v_slice2): - graph.delete_node(v_slice2, [v_cache_unsqueeze]) + graph.delete_node(v_expand, [v_cache_unsqueeze]) if graph.check_delete_node(v_cache_unsqueeze): v_orig_parents = [ graph.node_table.get(p, None) for p in v_cache_unsqueeze._parents diff --git a/frontend/Python/ops/func.py b/frontend/Python/ops/func.py index 4b448ed119..bb76103334 100644 --- a/frontend/Python/ops/func.py +++ b/frontend/Python/ops/func.py @@ -20,8 +20,8 @@ from typing import Tuple import functools -from mlir.dialects import func, memref -from mlir import ir +from buddy_mlir.dialects import func, memref +from buddy_mlir import ir from ..graph import FuncOp, CallOp, CallExternalOp, PlaceholderOp from .utils import * diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index 7227423992..3a781a1adf 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -20,8 +20,8 @@ from typing import Dict, Tuple, List -import mlir.ir as ir -from mlir.dialects import ( +import buddy_mlir.ir as ir +from buddy_mlir.dialects import ( tosa, linalg, arith, @@ -3359,22 +3359,17 @@ def _get_vectorizable_trailing_dims(input2, input3_shape, accumulate): if accumulate: return 0, 0 - # Count trailing dimensions without index tensors - num_vectorizable_dims = 0 - for d in range(len(input3_shape) - 1, -1, -1): - if d < len(input2) and input2[d] is not None: - break # This dimension has an index tensor, stop - num_vectorizable_dims += 1 - - if num_vectorizable_dims == 0: + # Current vectorized lowering only supports vectorizing the last dimension. + # Multi-dimension vectorization would require flattening/collapsing memrefs + # before transfer_read/write. 
+ last_dim = len(input3_shape) - 1 + if last_dim < 0: + return 0, 0 + if last_dim < len(input2) and input2[last_dim] is not None: return 0, 0 - # Calculate vector length (product of vectorizable dimensions) - vector_length = 1 - for d in range( - len(input3_shape) - num_vectorizable_dims, len(input3_shape) - ): - vector_length *= input3_shape[d] + num_vectorizable_dims = 1 + vector_length = input3_shape[last_dim] # Only vectorize if beneficial (at least 4 elements) if vector_length < 4: @@ -3518,14 +3513,14 @@ def _generate_vectorized_index_put( vector_type = ir.VectorType.get([vector_length], mlir_dtype) # Padding value for transfer_read - if str(mlir_dtype).startswith("f"): + if ir.FloatType.isinstance(mlir_dtype) or ir.BF16Type.isinstance( + mlir_dtype + ): padding = arith.ConstantOp( mlir_dtype, ir.FloatAttr.get(mlir_dtype, 0.0) ) else: - padding = arith.ConstantOp( - mlir_dtype, ir.IntegerAttr.get(mlir_dtype, 0) - ) + padding = arith.ConstantOp(mlir_dtype, 0) # AffineMap: map from rank-D memref to 1D vector (last num_vec_dims -> 1) # For a rank-4 memref with 1 vec dim: (d0, d1, d2, d3) -> (d3) @@ -3562,11 +3557,14 @@ def create_nested_loops_vectorized(dim, idx_vars): # Vectorized dimension with index - this shouldn't happen # as we don't vectorize dimensions with indices dst_indices.append(lb) - elif d < len(idx_vars): - dst_indices.append(idx_vars[d]) else: - # This is a vectorized dimension, use 0 as starting index - dst_indices.append(lb) + # Vector path only runs on equal-rank source/destination. + # Keep non-index dims aligned by original dimension position. + if d < len(idx_vars): + dst_indices.append(idx_vars[d]) + else: + # Vectorized trailing dim starts at 0. 
+ dst_indices.append(lb) # Vector read from source vec_val = vector.TransferReadOp( @@ -3665,12 +3663,27 @@ def index_put_op( input1_elem_type = input1.type.element_type input1_memref_type = ir.MemRefType.get(input1_shape, input1_elem_type) - input1_memref = bufferization.ToBufferOp(input1_memref_type, input1).result + input1_src_memref = bufferization.ToBufferOp( + input1_memref_type, input1 + ).result + input1_memref = memref.AllocOp(input1_memref_type, [], []).result + memref.CopyOp(input1_src_memref, input1_memref) # Check if we can vectorize trailing dimensions num_vec_dims, vector_length = _get_vectorizable_trailing_dims( input2, input3_shape, accumulate ) + non_index_dims = [ + d + for d in range(len(output_shape)) + if d >= len(input2) or input2[d] is None + ] + broadcast_rank = len(input3_shape) - len(non_index_dims) + + # Vector path uses a single permutation_map for transfer_read/write. + # Keep it on equal-rank source/destination updates only. + if len(input1_shape) != len(input3_shape): + num_vec_dims = 0 if num_vec_dims > 0: # Use vectorized path @@ -3691,17 +3704,16 @@ def index_put_op( if input2_ is None: input2_memref_list.append(None) continue - # For vectorized path, index tensors should have shape of loop dims - # Broadcast to loop dimensions shape - loop_shape = input3_shape[:num_loop_dims] try: index_tensor = _broadcast_index_tensor_for_vec( - input2_, loop_shape + input2_, input3_shape[:num_loop_dims] ) index_elem_type = ir.RankedTensorType( index_tensor.type ).element_type - memref_type = ir.MemRefType.get(loop_shape, index_elem_type) + memref_type = ir.MemRefType.get( + input3_shape[:num_loop_dims], index_elem_type + ) input2_memref_list.append( bufferization.ToBufferOp(memref_type, index_tensor) ) @@ -3745,7 +3757,6 @@ def _broadcast_index_tensor(value, target_shape, dim): if ( len(value_shape) == 1 - and len(target_shape) == len(output_shape) and dim < len(target_shape) and value_shape[0] == target_shape[dim] ): @@ -3818,21 +3829,37 @@ 
def _broadcast_index_tensor(value, target_shape, dim): return value - # Convert index tensors to memrefs + index_iter_shape = input3_shape[:broadcast_rank] + + # Convert index tensors to memrefs. + # Prefer full-shape broadcast (covers complex patterns like col2im), and + # fall back to broadcast-index-shape for pure advanced-index cases. input2_memref = [] + input2_use_full_shape = [] for i in range(len(input2)): if input2[i] is None: input2_memref.append(None) + input2_use_full_shape.append(False) continue input2_ = symbol_table.get((str(input2[i]), 0)) if input2_ is None: return - index_tensor = _broadcast_index_tensor(input2_, input3_shape, dim=i) + use_full_shape = True + try: + index_tensor = _broadcast_index_tensor(input2_, input3_shape, dim=i) + memref_shape = input3_shape + except ValueError: + use_full_shape = False + index_tensor = _broadcast_index_tensor( + input2_, index_iter_shape, dim=i + ) + memref_shape = index_iter_shape index_elem_type = ir.RankedTensorType(index_tensor.type).element_type - memref_type = ir.MemRefType.get(input3_shape, index_elem_type) + memref_type = ir.MemRefType.get(memref_shape, index_elem_type) input2_memref.append( bufferization.ToBufferOp(memref_type, index_tensor) ) + input2_use_full_shape.append(use_full_shape) input3_memref_element_type = input3.type.element_type input3_memref_type = ir.MemRefType.get( @@ -3858,20 +3885,28 @@ def create_nested_loops(dim, loops, idx_vars): # Build store indices: use index tensors where available, loop vars otherwise store_index = [] + non_index_cursor = 0 for d in range(len(output_shape)): if d < len(input2) and input2[d] is not None: # Use corresponding index tensor # The index tensor should have shape matching input3_shape + idx_access = ( + val_index + if input2_use_full_shape[d] + else val_index[:broadcast_rank] + ) idx_dim_val = memref.LoadOp( - input2_memref[d], val_index + input2_memref[d], idx_access ).result idx_dim = arith.IndexCastOp(ir.IndexType.get(), idx_dim_val) 
store_index.append(idx_dim) - elif d < len(idx_vars): - store_index.append(idx_vars[d]) else: - # Use constant 0 for missing dimensions - store_index.append(lb) + if len(input3_shape) == len(output_shape): + store_index.append(idx_vars[d]) + else: + source_pos = broadcast_rank + non_index_cursor + store_index.append(idx_vars[source_pos]) + non_index_cursor += 1 if accumulate: # Load existing value and add diff --git a/frontend/Python/ops/math.py b/frontend/Python/ops/math.py index 204c534da9..2dcf92b7a7 100644 --- a/frontend/Python/ops/math.py +++ b/frontend/Python/ops/math.py @@ -18,7 +18,7 @@ # # ===--------------------------------------------------------------------------- -from mlir.dialects import math +from buddy_mlir.dialects import math def erf_op(node, symbol_table): @@ -199,8 +199,8 @@ def round_op(node, symbol_table): Note: MLIR's math.round semantics may differ for half values, so implement PyTorch behavior while keeping a math.round op for IR checks. """ - import mlir.ir as ir - import mlir.dialects.arith as arith + import buddy_mlir.ir as ir + import buddy_mlir.dialects.arith as arith input_tensor = symbol_table.get((str(node.args[0]), 0)) input_type = ir.RankedTensorType(input_tensor.type) @@ -268,7 +268,7 @@ def trunc_op(node, symbol_table): def abs_op(node, symbol_table): """abs(x) using math.AbsFOp for float, math.AbsIOp for int""" - import mlir.ir as ir + import buddy_mlir.ir as ir input_tensor = symbol_table.get((str(node.args[0]), 0)) element_type = ir.RankedTensorType(input_tensor.type).element_type @@ -281,8 +281,8 @@ def abs_op(node, symbol_table): def powf_op(node, symbol_table): """pow(x, y) for float tensors using math.PowFOp""" - import mlir.ir as ir - from mlir.dialects import arith, tensor + import buddy_mlir.ir as ir + from buddy_mlir.dialects import arith, tensor input1 = symbol_table.get((str(node.args[0]), 0)) input2 = symbol_table.get((str(node.args[1]), 0)) diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py 
index e75ce538a9..1e47c39db3 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -23,9 +23,9 @@ import numpy import sys -import mlir.ir as ir -from mlir.ir import IndexType, F32Type -from mlir.dialects import ( +import buddy_mlir.ir as ir +from buddy_mlir.ir import IndexType, F32Type +from buddy_mlir.dialects import ( tensor, tosa, arith, @@ -568,8 +568,7 @@ def addmm_op( matmul_result_type, ( ir.FloatAttr.get(result_element_type, 0.0) - if str(result_element_type) == "f32" - or str(result_element_type) == "f16" + if _is_float_type(result_element_type) else ir.IntegerAttr.get(result_element_type, 0) ), ) @@ -4620,9 +4619,7 @@ def flash_attention_for_cpu_prefill_op( sum_vec = vector.SplatOp(v16, sum, loc=loc).result # Truncate sum to dtype_qkv for out_scores_memref if need_cast: - sum_qkv = arith.TruncFOp( - dtype_qkv, sum, loc=loc - ).result + sum_qkv = arith.TruncFOp(dtype_qkv, sum, loc=loc).result else: sum_qkv = sum memref.StoreOp(sum_qkv, out_scores_memref, [b, h, idx_q]) @@ -6583,9 +6580,14 @@ def alias_op(node: AliasOp, symbol_table): input_shape = list(ir.RankedTensorType(input1.type).shape) input_dtype = ir.RankedTensorType(input1.type).element_type - # Alias is essentially identity + output_shape = list(node.tensor_meta["shape"]) + output_dtype = mlir_element_type_get(node.tensor_meta["dtype"]) + + if input_shape == output_shape and input_dtype == output_dtype: + return input1 + return tosa.IdentityOp( - ir.RankedTensorType.get(input_shape, input_dtype), input1 + ir.RankedTensorType.get(output_shape, output_dtype), input1 ).result @@ -13761,9 +13763,7 @@ def gqa_attention_fused_op(node: GQAAttentionFusedOp, symbol_table): output_shape = list(node.tensor_meta["shape"]) # All intermediate constants use compute_dtype (f32) - scale_val = ( - 1 / numpy.sqrt(query.type.shape[-1]) if scale is None else scale - ) + scale_val = 1 / numpy.sqrt(query.type.shape[-1]) if scale is None else scale scale_val = arith.ConstantOp(compute_dtype, 
float(scale_val)).result neg_inf = arith.ConstantOp(compute_dtype, -1.0e30, loc=loc).result @@ -13978,9 +13978,7 @@ def gqa_attention_fused_op(node: GQAAttentionFusedOp, symbol_table): softmax_result, [b, h, q, k], loc=loc ).result - pv = vector.SplatOp( - v16_compute, p, loc=loc - ).result + pv = vector.SplatOp(v16_compute, p, loc=loc).result perm_map = ir.AffineMap.get( 4, 0, [ir.AffineDimExpr.get(3)] ) diff --git a/frontend/Python/ops/utils.py b/frontend/Python/ops/utils.py index 989fae142c..f9941350e3 100644 --- a/frontend/Python/ops/utils.py +++ b/frontend/Python/ops/utils.py @@ -19,7 +19,7 @@ # ===--------------------------------------------------------------------------- from typing import Dict -import mlir.ir as ir +import buddy_mlir.ir as ir from ..graph import TensorDType diff --git a/midend/CMakeLists.txt b/midend/CMakeLists.txt index 86946d12b7..90587c553f 100644 --- a/midend/CMakeLists.txt +++ b/midend/CMakeLists.txt @@ -3,11 +3,16 @@ add_subdirectory(lib) if(MLIR_ENABLE_BINDINGS_PYTHON) include(MLIRDetectPythonEnv) + + # We support multiple Python versions + # Development.Module asks CMake to locate the build artifacts needed to + # compile extension modules (headers + linkable module libs). 
See: + # https://cmake.org/cmake/help/latest/module/FindPython3.html find_package(Python3 ${LLVM_MINIMUM_PYTHON_VERSION} - COMPONENTS Interpreter Development NumPy REQUIRED) + COMPONENTS Interpreter Development.Module NumPy REQUIRED) set(Python_EXECUTABLE ${Python3_EXECUTABLE}) find_package(Python - COMPONENTS Interpreter Development NumPy REQUIRED) + COMPONENTS Interpreter Development.Module NumPy REQUIRED) mlir_detect_pybind11_install() find_package(pybind11 2.10 CONFIG REQUIRED) mlir_detect_nanobind_install() diff --git a/midend/include/Dialect/AME/AME.td b/midend/include/Dialect/AME/AME.td new file mode 100644 index 0000000000..5b77e42f66 --- /dev/null +++ b/midend/include/Dialect/AME/AME.td @@ -0,0 +1,569 @@ +//====------ AME.td - AME dialect operation definitions --- tablegen ------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file defines the AME dialect for RISC-V Matrix Extension. 
+// +//===----------------------------------------------------------------------===// + +#ifndef AME_DIALECT +#define AME_DIALECT + +include "mlir/IR/OpBase.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// AME Dialect Definition +//===----------------------------------------------------------------------===// + +def AME_Dialect : Dialect { + let name = "ame"; + let cppNamespace = "::buddy::ame"; + let description = [{ + The AME dialect provides operations for the RISC-V Matrix Extension (RVA23 profile). + AME provides matrix multiply-accumulate instructions with tile-based computation, + supporting various data types and widening modes for AI/ML workloads. + + Key features: + - Tile-based matrix registers with configurable dimensions + - Support for int4/int8/int16/int32/int64 integer types + - Support for fp8/fp16/fp32/fp64 floating-point types + - Widening matrix multiplication (2x, 4x, 8x output width) + - Saturating arithmetic options + }]; +} + +//===----------------------------------------------------------------------===// +// AME Operation Base Class +//===----------------------------------------------------------------------===// + +class AME_Op traits = []> : + Op {} + +//===----------------------------------------------------------------------===// +// Matrix Configuration Operations +//===----------------------------------------------------------------------===// + +def MSettypeOp : AME_Op<"msettype"> { + let summary = "Set matrix type configuration"; + let description = [{ + Sets the matrix type configuration from a GPR register. + Returns the previous mtype value. 
+ }]; + let arguments = (ins I64:$mtype); + let results = (outs I64:$result); + let assemblyFormat = "$mtype attr-dict"; +} + +def MSettypeiOp : AME_Op<"msettypei"> { + let summary = "Set matrix type configuration with immediate"; + let description = [{ + Sets the matrix type configuration from an immediate value. + Returns the previous mtype value. + }]; + let arguments = (ins I64Attr:$mtype); + let results = (outs I64:$result); + let assemblyFormat = "$mtype attr-dict"; +} + +def MSettilemOp : AME_Op<"msettilem"> { + let summary = "Set tile M dimension"; + let description = [{ + Sets the tile M dimension (number of rows). + Returns the actual mtilem value set. + }]; + let arguments = (ins I64:$tilem); + let results = (outs I64:$result); + let assemblyFormat = "$tilem attr-dict"; +} + +def MSettilenOp : AME_Op<"msettilen"> { + let summary = "Set tile N dimension"; + let description = [{ + Sets the tile N dimension (number of columns). + Returns the actual mtilen value set. + }]; + let arguments = (ins I64:$tilen); + let results = (outs I64:$result); + let assemblyFormat = "$tilen attr-dict"; +} + +def MSettilekOp : AME_Op<"msettilek"> { + let summary = "Set tile K dimension"; + let description = [{ + Sets the tile K dimension (inner dimension for matrix multiply). + Returns the actual mtilek value set. + }]; + let arguments = (ins I64:$tilek); + let results = (outs I64:$result); + let assemblyFormat = "$tilek attr-dict"; +} + +//===----------------------------------------------------------------------===// +// Matrix Configuration Operations (with immediate) +//===----------------------------------------------------------------------===// + +def MSettilemiOp : AME_Op<"msettilemi"> { + let summary = "Set tile M dimension with immediate"; + let description = [{ + Sets the tile M dimension (number of rows) from an immediate value. 
+ mtilem = min(tilem, TMMAX) + }]; + let arguments = (ins I64Attr:$tilem); + let assemblyFormat = "$tilem attr-dict"; +} + +def MSettileniOp : AME_Op<"msettileni"> { + let summary = "Set tile N dimension with immediate"; + let description = [{ + Sets the tile N dimension (number of columns) from an immediate value. + mtilen = min(tilen, TNMAX) + }]; + let arguments = (ins I64Attr:$tilen); + let assemblyFormat = "$tilen attr-dict"; +} + +def MSettilekiOp : AME_Op<"msettileki"> { + let summary = "Set tile K dimension with immediate"; + let description = [{ + Sets the tile K dimension (inner dimension for matrix multiply) from an immediate value. + mtilek = min(tilek, TKMAX) + }]; + let arguments = (ins I64Attr:$tilek); + let assemblyFormat = "$tilek attr-dict"; +} + +//===----------------------------------------------------------------------===// +// Matrix Zero Operations +//===----------------------------------------------------------------------===// + +def MzeroOp : AME_Op<"mzero"> { + let summary = "Zero out accumulation matrix tile"; + let description = [{ + Sets all elements in the specified accumulation matrix tile to zero. + }]; + let arguments = (ins I64Attr:$md); + let assemblyFormat = "$md attr-dict"; +} + +//===----------------------------------------------------------------------===// +// Matrix Load Instructions +// Load matrix tiles from memory to tile registers +//===----------------------------------------------------------------------===// + +// Load left matrix A (mtilem x mtilek) +def Mlae32mOp : AME_Op<"mlae32.m"> { + let summary = "Load 32-bit left matrix tile A"; + let description = [{ + Load a left matrix tile from memory to tile register. 
+ Shape: mtilem x mtilek + - md: destination tile register index + - base: base address pointer + - stride: row byte stride + }]; + let arguments = (ins + I64Attr:$md, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$md `,` $base `,` $stride attr-dict `:` type($base)"; +} + +def Mlae64mOp : AME_Op<"mlae64.m"> { + let summary = "Load 64-bit left matrix tile A"; + let description = [{ + Load a 64-bit left matrix tile from memory to tile register. + Shape: mtilem x mtilek + }]; + let arguments = (ins + I64Attr:$md, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$md `,` $base `,` $stride attr-dict `:` type($base)"; +} + +// Load right matrix B (mtilek x mtilen) +def Mlbe32mOp : AME_Op<"mlbe32.m"> { + let summary = "Load 32-bit right matrix tile B"; + let description = [{ + Load a right matrix tile from memory to tile register. + Shape: mtilek x mtilen + - md: destination tile register index + - base: base address pointer + - stride: row byte stride + }]; + let arguments = (ins + I64Attr:$md, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$md `,` $base `,` $stride attr-dict `:` type($base)"; +} + +def Mlbe64mOp : AME_Op<"mlbe64.m"> { + let summary = "Load 64-bit right matrix tile B"; + let description = [{ + Load a 64-bit right matrix tile from memory to tile register. + Shape: mtilek x mtilen + }]; + let arguments = (ins + I64Attr:$md, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$md `,` $base `,` $stride attr-dict `:` type($base)"; +} + +// Load output/accumulator matrix C (mtilem x mtilen) +def Mlce32mOp : AME_Op<"mlce32.m"> { + let summary = "Load 32-bit output matrix tile C"; + let description = [{ + Load an output/accumulator matrix tile from memory. 
+ Shape: mtilem x mtilen + }]; + let arguments = (ins + I64Attr:$md, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$md `,` $base `,` $stride attr-dict `:` type($base)"; +} + +def Mlce64mOp : AME_Op<"mlce64.m"> { + let summary = "Load 64-bit output matrix tile C"; + let description = [{ + Load a 64-bit output/accumulator matrix tile from memory. + Shape: mtilem x mtilen + }]; + let arguments = (ins + I64Attr:$md, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$md `,` $base `,` $stride attr-dict `:` type($base)"; +} + +//===----------------------------------------------------------------------===// +// Matrix Store Instructions +// Store matrix tiles from tile registers to memory +//===----------------------------------------------------------------------===// + +// Store output/accumulator matrix C (mtilem x mtilen) +def Msce32mOp : AME_Op<"msce32.m"> { + let summary = "Store 32-bit output matrix tile C"; + let description = [{ + Store an output/accumulator matrix tile to memory. + Shape: mtilem x mtilen + - ms3: source tile register index + - base: base address pointer + - stride: row byte stride + }]; + let arguments = (ins + I64Attr:$ms3, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$ms3 `,` $base `,` $stride attr-dict `:` type($base)"; +} + +def Msce64mOp : AME_Op<"msce64.m"> { + let summary = "Store 64-bit output matrix tile C"; + let description = [{ + Store a 64-bit output/accumulator matrix tile to memory. 
+ Shape: mtilem x mtilen + }]; + let arguments = (ins + I64Attr:$ms3, + AnyMemRef:$base, + I64:$stride + ); + let assemblyFormat = "$ms3 `,` $base `,` $stride attr-dict `:` type($base)"; +} + +//===----------------------------------------------------------------------===// +// Signed Integer Matrix Multiplication Operations (Tile Register Version) +// These operate on tile registers, not memory +//===----------------------------------------------------------------------===// + +def MmaWmmTileOp : AME_Op<"mma.w.mm.tile"> { + let summary = "Signed int32 tile matrix multiply-accumulate"; + let description = [{ + Performs signed int32 matrix multiply-accumulate on tile registers: + md = md + ms1 × ms2 + where ms1, ms2 are tile registers, and md is accumulation register. + + - md: accumulation register index (0-7) + - ms1: left tile register index (0-7) + - ms2: right tile register index (0-7) + }]; + let arguments = (ins + I64Attr:$md, + I64Attr:$ms1, + I64Attr:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict"; +} + +def MmaDwmmTileOp : AME_Op<"mma.dw.mm.tile"> { + let summary = "Signed int64 tile matrix multiply-accumulate"; + let description = [{ + Performs signed int64 matrix multiply-accumulate on tile registers: + md = md + ms1 × ms2 + where ms1, ms2 are tile registers, and md is accumulation register. 
+ }]; + let arguments = (ins + I64Attr:$md, + I64Attr:$ms1, + I64Attr:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict"; +} + +//===----------------------------------------------------------------------===// +// High-level Matrix Multiplication Operations (MemRef Version) +// These are high-level ops that will be lowered to load/compute/store sequence +//===----------------------------------------------------------------------===// + +def MqmaBmmOp : AME_Op<"mqma.b.mm"> { + let summary = "Signed int8 matrix multiply-accumulate with quad-widen output"; + let description = [{ + Performs signed int8 matrix multiply-accumulate with quad-widen output: + md = md + ms1 × ms2 + where ms1 and ms2 are int8 matrices, and md is int32 accumulator. + + Output is 4x widened: int8 × int8 -> int32. + This is the primary instruction for int8 AI workloads. + }]; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$md, + MemRefRankOf<[I8], [2]>:$ms1, + MemRefRankOf<[I8], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +def MmaHmmOp : AME_Op<"mma.h.mm"> { + let summary = "Signed int16 matrix multiply-accumulate"; + let description = [{ + Performs signed int16 matrix multiply-accumulate: + md = md + ms1 × ms2 + where ms1, ms2, and md are all int16 matrices. + }]; + let arguments = (ins + MemRefRankOf<[I16], [2]>:$md, + MemRefRankOf<[I16], [2]>:$ms1, + MemRefRankOf<[I16], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +def MmaWmmOp : AME_Op<"mma.w.mm"> { + let summary = "Signed int32 matrix multiply-accumulate"; + let description = [{ + Performs signed int32 matrix multiply-accumulate: + md = md + ms1 × ms2 + where ms1, ms2, and md are all int32 matrices. 
+ }]; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$md, + MemRefRankOf<[I32], [2]>:$ms1, + MemRefRankOf<[I32], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +def MmaDwmmOp : AME_Op<"mma.dw.mm"> { + let summary = "Signed int64 matrix multiply-accumulate"; + let description = [{ + Performs signed int64 matrix multiply-accumulate: + md = md + ms1 × ms2 + where ms1, ms2, and md are all int64 matrices. + + This is useful for high-precision matrix computations. + }]; + let arguments = (ins + MemRefRankOf<[I64], [2]>:$md, + MemRefRankOf<[I64], [2]>:$ms1, + MemRefRankOf<[I64], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +def MwmaHmmOp : AME_Op<"mwma.h.mm"> { + let summary = "Signed int16 matrix multiply-accumulate with double-widen output"; + let description = [{ + Performs signed int16 matrix multiply-accumulate with double-widen output: + md = md + ms1 × ms2 + where ms1 and ms2 are int16 matrices, and md is int32 accumulator. + }]; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$md, + MemRefRankOf<[I16], [2]>:$ms1, + MemRefRankOf<[I16], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +//===----------------------------------------------------------------------===// +// Unsigned Integer Matrix Multiplication Operations +//===----------------------------------------------------------------------===// + +def MqmauBmmOp : AME_Op<"mqmau.b.mm"> { + let summary = "Unsigned int8 matrix multiply-accumulate with quad-widen output"; + let description = [{ + Performs unsigned int8 matrix multiply-accumulate with quad-widen output: + md = md + ms1 × ms2 + where ms1 and ms2 are uint8 matrices, and md is uint32 accumulator. 
+ }]; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$md, + MemRefRankOf<[I8], [2]>:$ms1, + MemRefRankOf<[I8], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +def MmauHmmOp : AME_Op<"mmau.h.mm"> { + let summary = "Unsigned int16 matrix multiply-accumulate"; + let description = [{ + Performs unsigned int16 matrix multiply-accumulate: + md = md + ms1 × ms2 + where ms1, ms2, and md are all uint16 matrices. + }]; + let arguments = (ins + MemRefRankOf<[I16], [2]>:$md, + MemRefRankOf<[I16], [2]>:$ms1, + MemRefRankOf<[I16], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +//===----------------------------------------------------------------------===// +// Floating-Point Matrix Multiplication Operations +//===----------------------------------------------------------------------===// + +def MfmaFmmOp : AME_Op<"mfma.f.mm"> { + let summary = "FP32 matrix multiply-accumulate"; + let description = [{ + Performs FP32 matrix multiply-accumulate: + md = md + ms1 × ms2 + where ms1, ms2, and md are all fp32 matrices. + }]; + let arguments = (ins + MemRefRankOf<[F32], [2]>:$md, + MemRefRankOf<[F32], [2]>:$ms1, + MemRefRankOf<[F32], [2]>:$ms2 + ); + let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)"; +} + +def MfmaHfmmOp : AME_Op<"mfma.hf.mm"> { + let summary = "FP16 matrix multiply-accumulate"; + let description = [{ + Performs FP16 matrix multiply-accumulate: + md = md + ms1 × ms2 + where ms1, ms2, and md are all fp16 matrices. 
+  }];
+  let arguments = (ins
+    MemRefRankOf<[F16], [2]>:$md,
+    MemRefRankOf<[F16], [2]>:$ms1,
+    MemRefRankOf<[F16], [2]>:$ms2
+  );
+  let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)";
+}
+
+def MfwmaHfmmOp : AME_Op<"mfwma.hf.mm"> {
+  let summary = "FP16 matrix multiply-accumulate with double-widen output";
+  let description = [{
+    Performs FP16 matrix multiply-accumulate with FP32 accumulator:
+      md = md + ms1 × ms2
+    where ms1 and ms2 are fp16 matrices, and md is fp32 accumulator.
+  }];
+  let arguments = (ins
+    MemRefRankOf<[F32], [2]>:$md,
+    MemRefRankOf<[F16], [2]>:$ms1,
+    MemRefRankOf<[F16], [2]>:$ms2
+  );
+  let assemblyFormat = "$md `,` $ms1 `,` $ms2 attr-dict `:` type($md) `,` type($ms1) `,` type($ms2)";
+}
+
+//===----------------------------------------------------------------------===//
+// AME Intrinsic Operation Definitions (for LLVM IR translation)
+// These map directly to LLVM intrinsics defined in IntrinsicsRISCVBuddyExt.td
+//===----------------------------------------------------------------------===//
+
+// Base class for AME intrinsic ops
+// Note: enumName must match the intrinsic name pattern: riscv_buddy_<enumName>
+class AME_IntrOpBase<string enumName, list<Trait> traits = []> :
+    LLVM_IntrOpBase<AME_Dialect, enumName,
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    /*list<Trait> traits=*/traits,
+                    /*int numResults=*/0>;
+
+// Configuration intrinsics (with immediate value)
+def AME_Msettilemi_IntrOp : AME_IntrOpBase<"msettilemi">,
+  Arguments<(ins LLVM_Type:$tilem)>;
+
+def AME_Msettileni_IntrOp : AME_IntrOpBase<"msettileni">,
+  Arguments<(ins LLVM_Type:$tilen)>;
+
+def AME_Msettileki_IntrOp : AME_IntrOpBase<"msettileki">,
+  Arguments<(ins LLVM_Type:$tilek)>;
+
+// Matrix zero intrinsic
+def AME_Mzero_IntrOp : AME_IntrOpBase<"mzero">,
+  Arguments<(ins LLVM_Type:$md)>;
+
+// Matrix load intrinsics (to tile registers)
+def AME_Mlae32m_IntrOp : AME_IntrOpBase<"mlae32.m">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+def AME_Mlae64m_IntrOp : AME_IntrOpBase<"mlae64.m">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+def AME_Mlbe32m_IntrOp : AME_IntrOpBase<"mlbe32.m">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+def AME_Mlbe64m_IntrOp : AME_IntrOpBase<"mlbe64.m">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+def AME_Mlce32m_IntrOp : AME_IntrOpBase<"mlce32.m">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+def AME_Mlce64m_IntrOp : AME_IntrOpBase<"mlce64.m">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+// Matrix store intrinsics (from tile registers)
+def AME_Msce32m_IntrOp : AME_IntrOpBase<"msce32.m">,
+  Arguments<(ins LLVM_Type:$ms3, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+def AME_Msce64m_IntrOp : AME_IntrOpBase<"msce64.m">,
+  Arguments<(ins LLVM_Type:$ms3, LLVM_Type:$base, LLVM_Type:$stride)>;
+
+// Tile register matrix multiplication intrinsics
+def AME_MmaWmmTile_IntrOp : AME_IntrOpBase<"mma.w.mm.tile">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$ms1, LLVM_Type:$ms2)>;
+
+def AME_MmaDwmmTile_IntrOp : AME_IntrOpBase<"mma.dw.mm.tile">,
+  Arguments<(ins LLVM_Type:$md, LLVM_Type:$ms1, LLVM_Type:$ms2)>;
+
+#endif // AME_DIALECT
diff --git a/midend/include/Dialect/AME/AMEDialect.h b/midend/include/Dialect/AME/AMEDialect.h
new file mode 100644
index 0000000000..f9cc1ab71f
--- /dev/null
+++ b/midend/include/Dialect/AME/AMEDialect.h
@@ -0,0 +1,24 @@
+//====- AMEDialect.h - MLIR Dialect for RISC-V AME extension --------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef AME_AMEDIALECT_H +#define AME_AMEDIALECT_H + +#include "mlir/IR/Dialect.h" + +#include "AME/AMEDialect.h.inc" + +#endif // AME_AMEDIALECT_H diff --git a/midend/include/Dialect/AME/AMEOps.h b/midend/include/Dialect/AME/AMEOps.h new file mode 100644 index 0000000000..550b665bbe --- /dev/null +++ b/midend/include/Dialect/AME/AMEOps.h @@ -0,0 +1,30 @@ +//====- AMEOps.h - MLIR Dialect for RISC-V AME extension ------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#ifndef AME_AMEOPS_H +#define AME_AMEOPS_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_OP_CLASSES +#include "AME/AME.h.inc" + +#endif // AME_AMEOPS_H diff --git a/midend/include/Dialect/AME/CMakeLists.txt b/midend/include/Dialect/AME/CMakeLists.txt new file mode 100644 index 0000000000..24427c0195 --- /dev/null +++ b/midend/include/Dialect/AME/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(AME ame) +add_mlir_doc(AME AME Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS AME.td) +mlir_tablegen(AMEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(BuddyAMEConversionsIncGen) diff --git a/midend/include/Dialect/AME/Transform.h b/midend/include/Dialect/AME/Transform.h new file mode 100644 index 0000000000..3b7f5b8cdd --- /dev/null +++ b/midend/include/Dialect/AME/Transform.h @@ -0,0 +1,42 @@ +//===- Transform.h - AME Dialect Transformation Passes ----------*- C++ -*-===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AME_TRANSFORM_H
+#define AME_TRANSFORM_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+void populateAMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns);
+void configureAMELegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+namespace buddy {
+namespace ame {
+
+std::unique_ptr<mlir::Pass> createLegalizeForLLVMExportPass();
+
+} // namespace ame
+} // namespace buddy
+
+#endif // AME_TRANSFORM_H
diff --git a/midend/include/Dialect/CMakeLists.txt b/midend/include/Dialect/CMakeLists.txt
index c1690ba540..ed346db7c3 100644
--- a/midend/include/Dialect/CMakeLists.txt
+++ b/midend/include/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(AME)
 add_subdirectory(Bud)
 add_subdirectory(DAP)
 add_subdirectory(DIP)
diff --git a/midend/include/Dialect/IME/IME.td b/midend/include/Dialect/IME/IME.td
index 965f3621ab..c47ab1a97a 100644
--- a/midend/include/Dialect/IME/IME.td
+++ b/midend/include/Dialect/IME/IME.td
@@ -44,12 +44,15 @@ def VmadotOp : IME_Op<"vmadot"> {
   let summary = "Integer matrix multiply-accumulate (signed × signed)";
   let description = [{
     Performs matrix multiply-accumulate: vd += vs1 × vs2
-    where vs1 and vs2 are int8 matrices, and vd is int32 accumulator.
+    where vs1 and vs2 are int8 or int16 matrices, and vd is int32 accumulator.
+ + For int8 (SEW=e8): MAC unit 4x4x8, matrices A(4x8), B(8x4), C(4x4) + For int16 (SEW=e16): MAC unit 4x4x4, matrices A(4x4), B(4x4), C(4x4) }]; let arguments = (ins MemRefRankOf<[I32], [2]>:$vd, - MemRefRankOf<[I8], [2]>:$vs1, - MemRefRankOf<[I8], [2]>:$vs2 + MemRefRankOf<[I8, I16], [2]>:$vs1, + MemRefRankOf<[I8, I16], [2]>:$vs2 ); let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; } @@ -94,6 +97,233 @@ def VfmadotOp : IME_Op<"vfmadot"> { let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; } +//===----------------------------------------------------------------------===// +// Sliding-window Integer Matrix Multiply-Accumulate Operations +//===----------------------------------------------------------------------===// +// These operations select values from VS1 and VS1+1 using a sliding window. +// slide-1/2/3 use fixed slide values, while slide-n uses a dynamic value. + +// slide-1 operations +def Vmadot1Op : IME_Op<"vmadot1"> { + let summary = "Sliding-window integer MMA (signed × signed, slide=1)"; + let description = [{ + Performs sliding-window matrix multiply-accumulate with slide=1. + Input A is selected from VS1 and VS1+1 with a slide offset of 1*K elements. 
+ }]; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot1uOp : IME_Op<"vmadot1u"> { + let summary = "Sliding-window integer MMA (unsigned × unsigned, slide=1)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot1suOp : IME_Op<"vmadot1su"> { + let summary = "Sliding-window integer MMA (signed × unsigned, slide=1)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot1usOp : IME_Op<"vmadot1us"> { + let summary = "Sliding-window integer MMA (unsigned × signed, slide=1)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +// slide-2 operations +def Vmadot2Op : IME_Op<"vmadot2"> { + let summary = "Sliding-window integer MMA (signed × signed, slide=2)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot2uOp : IME_Op<"vmadot2u"> { + let summary = "Sliding-window integer MMA (unsigned × unsigned, slide=2)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def 
Vmadot2suOp : IME_Op<"vmadot2su"> { + let summary = "Sliding-window integer MMA (signed × unsigned, slide=2)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot2usOp : IME_Op<"vmadot2us"> { + let summary = "Sliding-window integer MMA (unsigned × signed, slide=2)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +// slide-3 operations +def Vmadot3Op : IME_Op<"vmadot3"> { + let summary = "Sliding-window integer MMA (signed × signed, slide=3)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot3uOp : IME_Op<"vmadot3u"> { + let summary = "Sliding-window integer MMA (unsigned × unsigned, slide=3)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot3suOp : IME_Op<"vmadot3su"> { + let summary = "Sliding-window integer MMA (signed × unsigned, slide=3)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vmadot3usOp : IME_Op<"vmadot3us"> { + let summary = "Sliding-window integer MMA (unsigned × signed, slide=3)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2 + ); + let 
assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +// slide-n operations (dynamic slide value) +def VmadotnOp : IME_Op<"vmadotn"> { + let summary = "Sliding-window integer MMA (signed × signed, slide=n)"; + let description = [{ + Performs sliding-window matrix multiply-accumulate with dynamic slide value. + The slide value is passed as an additional operand. + }]; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2, + I64:$slide + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 `,` $slide attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def VmadotnuOp : IME_Op<"vmadotnu"> { + let summary = "Sliding-window integer MMA (unsigned × unsigned, slide=n)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2, + I64:$slide + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 `,` $slide attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def VmadotnsuOp : IME_Op<"vmadotnsu"> { + let summary = "Sliding-window integer MMA (signed × unsigned, slide=n)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[I8], [2]>:$vs1, + MemRefRankOf<[UI8], [2]>:$vs2, + I64:$slide + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 `,` $slide attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def VmadotnusOp : IME_Op<"vmadotnus"> { + let summary = "Sliding-window integer MMA (unsigned × signed, slide=n)"; + let arguments = (ins + MemRefRankOf<[I32], [2]>:$vd, + MemRefRankOf<[UI8], [2]>:$vs1, + MemRefRankOf<[I8], [2]>:$vs2, + I64:$slide + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 `,` $slide attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +//===----------------------------------------------------------------------===// +// Sliding-window Floating-Point Matrix Multiply-Accumulate Operations 
+//===----------------------------------------------------------------------===// + +def Vfmadot1Op : IME_Op<"vfmadot1"> { + let summary = "Sliding-window floating-point MMA (slide=1)"; + let arguments = (ins + MemRefRankOf<[F16], [2]>:$vd, + MemRefRankOf<[F16], [2]>:$vs1, + MemRefRankOf<[F16], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vfmadot2Op : IME_Op<"vfmadot2"> { + let summary = "Sliding-window floating-point MMA (slide=2)"; + let arguments = (ins + MemRefRankOf<[F16], [2]>:$vd, + MemRefRankOf<[F16], [2]>:$vs1, + MemRefRankOf<[F16], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def Vfmadot3Op : IME_Op<"vfmadot3"> { + let summary = "Sliding-window floating-point MMA (slide=3)"; + let arguments = (ins + MemRefRankOf<[F16], [2]>:$vd, + MemRefRankOf<[F16], [2]>:$vs1, + MemRefRankOf<[F16], [2]>:$vs2 + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + +def VfmadotnOp : IME_Op<"vfmadotn"> { + let summary = "Sliding-window floating-point MMA (slide=n)"; + let arguments = (ins + MemRefRankOf<[F16], [2]>:$vd, + MemRefRankOf<[F16], [2]>:$vs1, + MemRefRankOf<[F16], [2]>:$vs2, + I64:$slide + ); + let assemblyFormat = "$vd `,` $vs1 `,` $vs2 `,` $slide attr-dict `:` type($vd) `,` type($vs1) `,` type($vs2)"; +} + //===----------------------------------------------------------------------===// // IME Intrinsic operation definitions //===----------------------------------------------------------------------===// @@ -123,4 +353,76 @@ def IME_Vmadotus_IntrOp : IME_IntrOpBase<"vmadotus">, def IME_Vfmadot_IntrOp : IME_IntrOpBase<"vfmadot">, Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; +//===----------------------------------------------------------------------===// +// Sliding-window Integer Matrix Multiply-Accumulate Intrinsics 
+//===----------------------------------------------------------------------===// + +// slide-1 intrinsics +def IME_Vmadot1_IntrOp : IME_IntrOpBase<"vmadot1">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot1u_IntrOp : IME_IntrOpBase<"vmadot1u">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot1su_IntrOp : IME_IntrOpBase<"vmadot1su">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot1us_IntrOp : IME_IntrOpBase<"vmadot1us">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +// slide-2 intrinsics +def IME_Vmadot2_IntrOp : IME_IntrOpBase<"vmadot2">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot2u_IntrOp : IME_IntrOpBase<"vmadot2u">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot2su_IntrOp : IME_IntrOpBase<"vmadot2su">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot2us_IntrOp : IME_IntrOpBase<"vmadot2us">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +// slide-3 intrinsics +def IME_Vmadot3_IntrOp : IME_IntrOpBase<"vmadot3">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot3u_IntrOp : IME_IntrOpBase<"vmadot3u">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot3su_IntrOp : IME_IntrOpBase<"vmadot3su">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vmadot3us_IntrOp : IME_IntrOpBase<"vmadot3us">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +// slide-n intrinsics (with dynamic slide value in GPR) +def IME_Vmadotn_IntrOp : IME_IntrOpBase<"vmadotn">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2, LLVM_Type:$slide)>; + +def IME_Vmadotnu_IntrOp : IME_IntrOpBase<"vmadotnu">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2, LLVM_Type:$slide)>; + +def IME_Vmadotnsu_IntrOp : 
IME_IntrOpBase<"vmadotnsu">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2, LLVM_Type:$slide)>; + +def IME_Vmadotnus_IntrOp : IME_IntrOpBase<"vmadotnus">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2, LLVM_Type:$slide)>; + +//===----------------------------------------------------------------------===// +// Sliding-window Floating-Point Matrix Multiply-Accumulate Intrinsics +//===----------------------------------------------------------------------===// + +def IME_Vfmadot1_IntrOp : IME_IntrOpBase<"vfmadot1">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vfmadot2_IntrOp : IME_IntrOpBase<"vfmadot2">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vfmadot3_IntrOp : IME_IntrOpBase<"vfmadot3">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2)>; + +def IME_Vfmadotn_IntrOp : IME_IntrOpBase<"vfmadotn">, + Arguments<(ins LLVM_Type:$vd, LLVM_Type:$vs1, LLVM_Type:$vs2, LLVM_Type:$slide)>; + #endif // IME_DIALECT diff --git a/midend/include/Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.h b/midend/include/Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.h new file mode 100644 index 0000000000..d424c02870 --- /dev/null +++ b/midend/include/Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.h @@ -0,0 +1,32 @@ +//===- AMEToLLVMIRTranslation.h - AME to LLVM IR ----------------*- C++ -*-===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for AME dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TARGET_LLVMIR_DIALECT_AME_AMETOLLVMIRTRANSLATION_H
+#define TARGET_LLVMIR_DIALECT_AME_AMETOLLVMIRTRANSLATION_H
+
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace buddy {
+void registerAMEDialectTranslation(mlir::DialectRegistry &registry);
+void registerAMEDialectTranslation(mlir::MLIRContext &context);
+} // namespace buddy
+
+#endif
diff --git a/midend/lib/CMakeLists.txt b/midend/lib/CMakeLists.txt
index b4db4108b1..1324511674 100644
--- a/midend/lib/CMakeLists.txt
+++ b/midend/lib/CMakeLists.txt
@@ -12,6 +12,8 @@ set(LinkedLibs
   MLIRSupport
   ${extension_libs}
 
+  BuddyAME
+  BuddyAMETransforms
   BuddyTile
   ConvOptimization
   CBConvVectorization
@@ -23,6 +25,7 @@ set(LinkedLibs
   LowerBuckyballPass
   LowerGemminiPass
   LowerLinalgToGemminiPass
+  LowerLinalgToIMEPass
   LowerIMEPass
   LowerRVVPass
   LowerVectorExpPass
diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt
index 76bd3218d6..bc7d24d7c5 100644
--- a/midend/lib/Conversion/CMakeLists.txt
+++ b/midend/lib/Conversion/CMakeLists.txt
@@ -14,6 +14,7 @@ add_subdirectory(LowerTileToBuckyball)
 add_subdirectory(LowerBuckyball)
 add_subdirectory(LowerGemmini)
 add_subdirectory(LowerLinalgToGemmini)
+add_subdirectory(LowerLinalgToIME)
 add_subdirectory(LowerIME)
 add_subdirectory(DepthwiseConvOptimization)
 add_subdirectory(MLIRGPU)
diff --git a/midend/lib/Conversion/LowerLinalgToIME/CMakeLists.txt b/midend/lib/Conversion/LowerLinalgToIME/CMakeLists.txt
new file mode 100644
index 0000000000..c8e083cf87
--- /dev/null
+++ b/midend/lib/Conversion/LowerLinalgToIME/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_mlir_library(LowerLinalgToIMEPass
+  LowerLinalgToIME.cpp
+)
diff --git a/midend/lib/Conversion/LowerLinalgToIME/LowerLinalgToIME.cpp b/midend/lib/Conversion/LowerLinalgToIME/LowerLinalgToIME.cpp
new file mode 100644
index 0000000000..d7e9c2bc76
--- /dev/null
+++ b/midend/lib/Conversion/LowerLinalgToIME/LowerLinalgToIME.cpp
@@ -0,0 +1,1114 @@
+//====- LowerLinalgToIME.cpp - Linalg to IME Dialect Lowering Pass --------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines Linalg dialect lowering pass to IME dialect.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "Dialect/IME/IMEDialect.h" +#include "Dialect/IME/IMEOps.h" + +using namespace mlir; +using namespace buddy::ime; + + +static void getTileSizes(Type elemType, int64_t &tileM, int64_t &tileK, + int64_t &tileN) { + if (elemType.isInteger(8)) { + tileM = 4; + tileK = 8; + tileN = 4; + } else if (elemType.isInteger(16)) { + tileM = 4; + tileK = 4; + tileN = 4; + } else { + // Default to int8 tile sizes + tileM = 4; + tileK = 8; + tileN = 4; + } +} + +static bool isSupportedElementType(Type elemType) { + return elemType.isInteger(8) || elemType.isInteger(16); +} + +namespace { + +class MatmulToIMELowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + Location loc = matmulOp.getLoc(); + + Value A = matmulOp.getInputs()[0]; // M x K + Value B = matmulOp.getInputs()[1]; // K x N + Value C = matmulOp.getOutputs()[0]; // M x N + + auto AType = dyn_cast(A.getType()); + auto BType = dyn_cast(B.getType()); + auto CType = dyn_cast(C.getType()); + + if (!AType || !BType || !CType) { + return rewriter.notifyMatchFailure(matmulOp, + "operands must be memref types"); + } + + Type AElemType = AType.getElementType(); + Type BElemType = BType.getElementType(); + Type CElemType = CType.getElementType(); + + if (!isSupportedElementType(AElemType) || + !isSupportedElementType(BElemType)) { + return rewriter.notifyMatchFailure( + matmulOp, 
"only int8 and int16 element types are supported"); + } + + if (AElemType != BElemType) { + return rewriter.notifyMatchFailure( + matmulOp, "A and B must have the same element type"); + } + + if (!CElemType.isInteger(32)) { + return rewriter.notifyMatchFailure( + matmulOp, "output C must be int32 for accumulation"); + } + + ArrayRef AShape = AType.getShape(); + ArrayRef BShape = BType.getShape(); + + if (AShape.size() != 2 || BShape.size() != 2) { + return rewriter.notifyMatchFailure(matmulOp, + "only 2D matrices are supported"); + } + + int64_t M = AShape[0]; + int64_t K = AShape[1]; + int64_t N = BShape[1]; + + bool isDynamic = ShapedType::isDynamic(M) || ShapedType::isDynamic(K) || + ShapedType::isDynamic(N); + + int64_t tileM, tileK, tileN; + getTileSizes(AElemType, tileM, tileK, tileN); + + // This pattern only handles aligned dimensions. + // Non-aligned dimensions are handled by MatmulWithBoundaryToIMELowering. + if (!isDynamic) { + bool isAligned = (M % tileM == 0) && (K % tileK == 0) && (N % tileN == 0); + if (!isAligned) { + return rewriter.notifyMatchFailure( + matmulOp, "non-aligned dimensions - use boundary handling pattern"); + } + } + getTileSizes(AElemType, tileM, tileK, tileN); + + Value c0 = rewriter.create(loc, 0); + Value stepM = rewriter.create(loc, tileM); + Value stepK = rewriter.create(loc, tileK); + Value stepN = rewriter.create(loc, tileN); + + Value boundM, boundK, boundN; + if (isDynamic) { + boundM = rewriter.create(loc, A, 0); + boundK = rewriter.create(loc, A, 1); + boundN = rewriter.create(loc, B, 1); + } else { + boundM = rewriter.create(loc, M); + boundK = rewriter.create(loc, K); + boundN = rewriter.create(loc, N); + } + + Value c1 = rewriter.create(loc, 1); + Value tileKVal = rewriter.create(loc, tileK); + Value tileNVal = rewriter.create(loc, tileN); + + // Allocate contiguous buffer for B tile in column-major pack format + // BTile[N][K] for column-major packing (IME expects B transposed) + auto BTileType = 
MemRefType::get({tileN, tileK}, AElemType); + Value BTileBuffer = rewriter.create(loc, BTileType); + + auto loopI = rewriter.create(loc, c0, boundM, stepM); + rewriter.setInsertionPointToStart(loopI.getBody()); + Value ivI = loopI.getInductionVar(); + + auto loopJ = rewriter.create(loc, c0, boundN, stepN); + rewriter.setInsertionPointToStart(loopJ.getBody()); + Value ivJ = loopJ.getInductionVar(); + + auto loopK = rewriter.create(loc, c0, boundK, stepK); + rewriter.setInsertionPointToStart(loopK.getBody()); + Value ivK = loopK.getInductionVar(); + + // A tile: use SubView (no transpose needed) + SmallVector aOffsets = {ivI, ivK}; + SmallVector aSizes = {rewriter.getIndexAttr(tileM), + rewriter.getIndexAttr(tileK)}; + SmallVector aStrides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + + Value ATile = + rewriter.create(loc, A, aOffsets, aSizes, aStrides); + + // B tile: copy with transpose (B[k][n] -> BTile[n][k]) + // IME requires B to be in column-major pack format + auto copyBLoopN = rewriter.create(loc, c0, tileNVal, c1); + rewriter.setInsertionPointToStart(copyBLoopN.getBody()); + Value bn = copyBLoopN.getInductionVar(); + + auto copyBLoopK = rewriter.create(loc, c0, tileKVal, c1); + rewriter.setInsertionPointToStart(copyBLoopK.getBody()); + Value bk = copyBLoopK.getInductionVar(); + + // Global indices in B matrix + Value globalBk = rewriter.create(loc, ivK, bk); + Value globalBn = rewriter.create(loc, ivJ, bn); + + // Load B[globalBk][globalBn] and store to BTile[bn][bk] (transposed) + Value bVal = rewriter.create(loc, B, ValueRange{globalBk, globalBn}); + rewriter.create(loc, bVal, BTileBuffer, ValueRange{bn, bk}); + + rewriter.setInsertionPointAfter(copyBLoopN); + + // C tile: use SubView + SmallVector cOffsets = {ivI, ivJ}; + SmallVector cSizes = {rewriter.getIndexAttr(tileM), + rewriter.getIndexAttr(tileN)}; + SmallVector cStrides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + + Value CTile = + rewriter.create(loc, C, cOffsets, 
cSizes, cStrides); + + rewriter.create(loc, CTile, ATile, BTileBuffer); + + rewriter.setInsertionPointAfter(loopI); + + rewriter.eraseOp(matmulOp); + + return success(); + } +}; + +class MatmulWithBoundaryToIMELowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + MatmulWithBoundaryToIMELowering(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + Location loc = matmulOp.getLoc(); + + Value A = matmulOp.getInputs()[0]; // M x K + Value B = matmulOp.getInputs()[1]; // K x N + Value C = matmulOp.getOutputs()[0]; // M x N + + auto AType = dyn_cast(A.getType()); + auto BType = dyn_cast(B.getType()); + auto CType = dyn_cast(C.getType()); + + if (!AType || !BType || !CType) { + return rewriter.notifyMatchFailure(matmulOp, + "operands must be memref types"); + } + + Type AElemType = AType.getElementType(); + Type BElemType = BType.getElementType(); + Type CElemType = CType.getElementType(); + + if (!isSupportedElementType(AElemType) || + !isSupportedElementType(BElemType)) { + return rewriter.notifyMatchFailure( + matmulOp, "only int8 and int16 element types are supported"); + } + + if (AElemType != BElemType) { + return rewriter.notifyMatchFailure( + matmulOp, "A and B must have the same element type"); + } + + if (!CElemType.isInteger(32)) { + return rewriter.notifyMatchFailure( + matmulOp, "output C must be int32 for accumulation"); + } + + ArrayRef AShape = AType.getShape(); + ArrayRef BShape = BType.getShape(); + + if (AShape.size() != 2 || BShape.size() != 2) { + return rewriter.notifyMatchFailure(matmulOp, + "only 2D matrices are supported"); + } + + int64_t M = AShape[0]; + int64_t K = AShape[1]; + int64_t N = BShape[1]; + + // Get tile sizes for the element type + int64_t tileM, tileK, tileN; + getTileSizes(AElemType, tileM, tileK, tileN); + + // This pattern only handles static dimensions + bool 
hasStaticDims = !ShapedType::isDynamic(M) && + !ShapedType::isDynamic(K) && !ShapedType::isDynamic(N); + + if (!hasStaticDims) { + return rewriter.notifyMatchFailure( + matmulOp, "dynamic dimensions not supported in boundary pattern"); + } + + // For static dimensions, check if they are aligned + bool isAligned = (M % tileM == 0) && (K % tileK == 0) && (N % tileN == 0); + if (isAligned) { + // Let the simpler MatmulToIMELowering handle aligned cases + return rewriter.notifyMatchFailure( + matmulOp, "aligned dimensions - use simple lowering"); + } + + // Calculate number of tiles (ceiling division) + int64_t numTilesM = (M + tileM - 1) / tileM; + int64_t numTilesK = (K + tileK - 1) / tileK; + int64_t numTilesN = (N + tileN - 1) / tileN; + + // Create constants + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value tileMVal = rewriter.create(loc, tileM); + Value tileKVal = rewriter.create(loc, tileK); + Value tileNVal = rewriter.create(loc, tileN); + Value boundM = rewriter.create(loc, M); + Value boundK = rewriter.create(loc, K); + Value boundN = rewriter.create(loc, N); + Value numTilesMVal = rewriter.create(loc, numTilesM); + Value numTilesKVal = rewriter.create(loc, numTilesK); + Value numTilesNVal = rewriter.create(loc, numTilesN); + + // Create zero constants for padding + Value zeroElem = rewriter.create( + loc, AElemType, rewriter.getZeroAttr(AElemType)); + Value zeroI32 = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + + auto ATileType = MemRefType::get({tileM, tileK}, AElemType); + auto BTileType = MemRefType::get({tileN, tileK}, AElemType); // [N, K] for column-major pack + auto CTileType = MemRefType::get({tileM, tileN}, CElemType); + + Value ATile = rewriter.create(loc, ATileType); + Value BTile = rewriter.create(loc, BTileType); + Value CTile = rewriter.create(loc, CTileType); + + // Loop over M tiles + auto loopTileM = rewriter.create(loc, c0, numTilesMVal, c1); + 
rewriter.setInsertionPointToStart(loopTileM.getBody()); + Value tileIdxM = loopTileM.getInductionVar(); + Value baseM = rewriter.create(loc, tileIdxM, tileMVal); + + // Loop over N tiles + auto loopTileN = rewriter.create(loc, c0, numTilesNVal, c1); + rewriter.setInsertionPointToStart(loopTileN.getBody()); + Value tileIdxN = loopTileN.getInductionVar(); + Value baseN = rewriter.create(loc, tileIdxN, tileNVal); + + // Initialize CTile to zeros (or copy from C for initial values) + auto initCLoop1 = rewriter.create(loc, c0, tileMVal, c1); + rewriter.setInsertionPointToStart(initCLoop1.getBody()); + Value initCi = initCLoop1.getInductionVar(); + auto initCLoop2 = rewriter.create(loc, c0, tileNVal, c1); + rewriter.setInsertionPointToStart(initCLoop2.getBody()); + Value initCj = initCLoop2.getInductionVar(); + + // Calculate global indices + Value globalCi = rewriter.create(loc, baseM, initCi); + Value globalCj = rewriter.create(loc, baseN, initCj); + + // Check if within bounds + Value inBoundM = rewriter.create( + loc, arith::CmpIPredicate::ult, globalCi, boundM); + Value inBoundN = rewriter.create( + loc, arith::CmpIPredicate::ult, globalCj, boundN); + Value inBound = rewriter.create(loc, inBoundM, inBoundN); + + // Load from C if in bounds, else use zero + auto selectC = rewriter.create( + loc, CElemType, inBound, /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&selectC.getThenRegion().front()); + Value cLoadVal = rewriter.create(loc, C, ValueRange{globalCi, globalCj}); + rewriter.create(loc, cLoadVal); + rewriter.setInsertionPointToStart(&selectC.getElseRegion().front()); + rewriter.create(loc, zeroI32); + rewriter.setInsertionPointAfter(selectC); + + rewriter.create(loc, selectC.getResult(0), CTile, ValueRange{initCi, initCj}); + rewriter.setInsertionPointAfter(initCLoop1); + + // Loop over K tiles + auto loopTileK = rewriter.create(loc, c0, numTilesKVal, c1); + rewriter.setInsertionPointToStart(loopTileK.getBody()); + Value tileIdxK = 
loopTileK.getInductionVar(); + Value baseK = rewriter.create(loc, tileIdxK, tileKVal); + + // Copy A tile with boundary handling + auto copyALoop1 = rewriter.create(loc, c0, tileMVal, c1); + rewriter.setInsertionPointToStart(copyALoop1.getBody()); + Value copyAi = copyALoop1.getInductionVar(); + auto copyALoop2 = rewriter.create(loc, c0, tileKVal, c1); + rewriter.setInsertionPointToStart(copyALoop2.getBody()); + Value copyAk = copyALoop2.getInductionVar(); + + Value globalAi = rewriter.create(loc, baseM, copyAi); + Value globalAk = rewriter.create(loc, baseK, copyAk); + Value inBoundAM = rewriter.create( + loc, arith::CmpIPredicate::ult, globalAi, boundM); + Value inBoundAK = rewriter.create( + loc, arith::CmpIPredicate::ult, globalAk, boundK); + Value inBoundA = rewriter.create(loc, inBoundAM, inBoundAK); + + auto selectA = rewriter.create( + loc, AElemType, inBoundA, /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&selectA.getThenRegion().front()); + Value aLoadVal = rewriter.create(loc, A, ValueRange{globalAi, globalAk}); + rewriter.create(loc, aLoadVal); + rewriter.setInsertionPointToStart(&selectA.getElseRegion().front()); + rewriter.create(loc, zeroElem); + rewriter.setInsertionPointAfter(selectA); + + rewriter.create(loc, selectA.getResult(0), ATile, ValueRange{copyAi, copyAk}); + rewriter.setInsertionPointAfter(copyALoop1); + + // Copy B tile with boundary handling + // Note: B is stored in column-major pack format for IME + // B[k][n] in original matrix -> BTile[n][k] in packed format + auto copyBLoop1 = rewriter.create(loc, c0, tileNVal, c1); + rewriter.setInsertionPointToStart(copyBLoop1.getBody()); + Value copyBn = copyBLoop1.getInductionVar(); + auto copyBLoop2 = rewriter.create(loc, c0, tileKVal, c1); + rewriter.setInsertionPointToStart(copyBLoop2.getBody()); + Value copyBk = copyBLoop2.getInductionVar(); + + Value globalBk = rewriter.create(loc, baseK, copyBk); + Value globalBn = rewriter.create(loc, baseN, copyBn); + Value inBoundBK = 
rewriter.create( + loc, arith::CmpIPredicate::ult, globalBk, boundK); + Value inBoundBN = rewriter.create( + loc, arith::CmpIPredicate::ult, globalBn, boundN); + Value inBoundB = rewriter.create(loc, inBoundBK, inBoundBN); + + auto selectB = rewriter.create( + loc, AElemType, inBoundB, /*withElseRegion=*/true); + rewriter.setInsertionPointToStart(&selectB.getThenRegion().front()); + // Load from B[k][n] in row-major + Value bLoadVal = rewriter.create(loc, B, ValueRange{globalBk, globalBn}); + rewriter.create(loc, bLoadVal); + rewriter.setInsertionPointToStart(&selectB.getElseRegion().front()); + rewriter.create(loc, zeroElem); + rewriter.setInsertionPointAfter(selectB); + + // Store to BTile[n][k] in column-major pack format + rewriter.create(loc, selectB.getResult(0), BTile, ValueRange{copyBn, copyBk}); + rewriter.setInsertionPointAfter(copyBLoop1); + + // IME vmadot on contiguous tile buffers + rewriter.create(loc, CTile, ATile, BTile); + + // End of K tile loop + rewriter.setInsertionPointAfter(loopTileK); + + // Copy CTile back to C (only valid elements) + auto storeCLoop1 = rewriter.create(loc, c0, tileMVal, c1); + rewriter.setInsertionPointToStart(storeCLoop1.getBody()); + Value storeCi = storeCLoop1.getInductionVar(); + auto storeCLoop2 = rewriter.create(loc, c0, tileNVal, c1); + rewriter.setInsertionPointToStart(storeCLoop2.getBody()); + Value storeCj = storeCLoop2.getInductionVar(); + + Value globalStoreCi = rewriter.create(loc, baseM, storeCi); + Value globalStoreCj = rewriter.create(loc, baseN, storeCj); + Value inBoundStoreM = rewriter.create( + loc, arith::CmpIPredicate::ult, globalStoreCi, boundM); + Value inBoundStoreN = rewriter.create( + loc, arith::CmpIPredicate::ult, globalStoreCj, boundN); + Value inBoundStore = rewriter.create(loc, inBoundStoreM, inBoundStoreN); + + auto storeIf = rewriter.create(loc, inBoundStore, /*withElseRegion=*/false); + rewriter.setInsertionPointToStart(&storeIf.getThenRegion().front()); + Value cResult = 
rewriter.create(loc, CTile, ValueRange{storeCi, storeCj}); + rewriter.create(loc, cResult, C, ValueRange{globalStoreCi, globalStoreCj}); + + rewriter.setInsertionPointAfter(storeCLoop1); + + // End of N and M tile loops + rewriter.setInsertionPointAfter(loopTileM); + + // Erase the original operation + rewriter.eraseOp(matmulOp); + + return success(); + } +}; + +class GenericMatmulToIMELowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + + if (genericOp.getInputs().size() != 2 || genericOp.getOutputs().size() != 1) + return failure(); + + auto iteratorTypes = genericOp.getIteratorTypesArray(); + if (iteratorTypes.size() != 3) + return failure(); + + if (iteratorTypes[0] != utils::IteratorType::parallel || + iteratorTypes[1] != utils::IteratorType::parallel || + iteratorTypes[2] != utils::IteratorType::reduction) + return failure(); + + auto indexingMaps = genericOp.getIndexingMapsArray(); + if (indexingMaps.size() != 3) + return failure(); + + AffineMap mapA = indexingMaps[0]; + AffineMap mapB = indexingMaps[1]; + AffineMap mapC = indexingMaps[2]; + + auto d0 = rewriter.getAffineDimExpr(0); + auto d1 = rewriter.getAffineDimExpr(1); + auto d2 = rewriter.getAffineDimExpr(2); + + auto expectedMapA = AffineMap::get(3, 0, {d0, d2}, rewriter.getContext()); + auto expectedMapB = AffineMap::get(3, 0, {d2, d1}, rewriter.getContext()); + auto expectedMapC = AffineMap::get(3, 0, {d0, d1}, rewriter.getContext()); + + if (mapA != expectedMapA || mapB != expectedMapB || mapC != expectedMapC) + return failure(); + + Block &body = genericOp.getRegion().front(); + if (body.getOperations().size() != 3) + return failure(); + + auto ops = body.getOperations().begin(); + auto *firstOp = &*ops++; + auto *secondOp = &*ops++; + auto *yieldOp = &*ops; + + if (!isa(firstOp)) + return failure(); + if (!isa(secondOp)) + return failure(); + if 
(!isa(yieldOp)) + return failure(); + + Location loc = genericOp.getLoc(); + Value A = genericOp.getInputs()[0]; + Value B = genericOp.getInputs()[1]; + Value C = genericOp.getOutputs()[0]; + + auto AType = dyn_cast(A.getType()); + auto BType = dyn_cast(B.getType()); + auto CType = dyn_cast(C.getType()); + + if (!AType || !BType || !CType) + return failure(); + + Type AElemType = AType.getElementType(); + if (!isSupportedElementType(AElemType)) + return failure(); + + int64_t tileM, tileK, tileN; + getTileSizes(AElemType, tileM, tileK, tileN); + + ArrayRef AShape = AType.getShape(); + ArrayRef BShape = BType.getShape(); + int64_t M = AShape[0]; + int64_t K = AShape[1]; + int64_t N = BShape[1]; + + bool isDynamic = ShapedType::isDynamic(M) || ShapedType::isDynamic(K) || + ShapedType::isDynamic(N); + + Value c0 = rewriter.create(loc, 0); + Value stepM = rewriter.create(loc, tileM); + Value stepK = rewriter.create(loc, tileK); + Value stepN = rewriter.create(loc, tileN); + + Value boundM, boundK, boundN; + if (isDynamic) { + boundM = rewriter.create(loc, A, 0); + boundK = rewriter.create(loc, A, 1); + boundN = rewriter.create(loc, B, 1); + } else { + boundM = rewriter.create(loc, M); + boundK = rewriter.create(loc, K); + boundN = rewriter.create(loc, N); + } + + auto loopI = rewriter.create(loc, c0, boundM, stepM); + rewriter.setInsertionPointToStart(loopI.getBody()); + Value ivI = loopI.getInductionVar(); + + auto loopJ = rewriter.create(loc, c0, boundN, stepN); + rewriter.setInsertionPointToStart(loopJ.getBody()); + Value ivJ = loopJ.getInductionVar(); + + auto loopK = rewriter.create(loc, c0, boundK, stepK); + rewriter.setInsertionPointToStart(loopK.getBody()); + Value ivK = loopK.getInductionVar(); + + SmallVector aOffsets = {ivI, ivK}; + SmallVector aSizes = {rewriter.getIndexAttr(tileM), + rewriter.getIndexAttr(tileK)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + + Value ATile = + rewriter.create(loc, A, aOffsets, aSizes, 
strides); + + SmallVector bOffsets = {ivK, ivJ}; + SmallVector bSizes = {rewriter.getIndexAttr(tileK), + rewriter.getIndexAttr(tileN)}; + Value BTile = + rewriter.create(loc, B, bOffsets, bSizes, strides); + + SmallVector cOffsets = {ivI, ivJ}; + SmallVector cSizes = {rewriter.getIndexAttr(tileM), + rewriter.getIndexAttr(tileN)}; + Value CTile = + rewriter.create(loc, C, cOffsets, cSizes, strides); + + rewriter.create(loc, CTile, ATile, BTile); + + rewriter.setInsertionPointAfter(loopI); + rewriter.eraseOp(genericOp); + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Conv2D to IME Lowering Pattern (with Sliding-Window Instructions) +//===----------------------------------------------------------------------===// + +/// Pattern to lower linalg.conv_2d_nhwc_hwcf to IME sliding-window operations. + +class Conv2DNhwcHwcfToIMELowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, + PatternRewriter &rewriter) const override { + Location loc = convOp.getLoc(); + + Value input = convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + + auto inputType = dyn_cast(input.getType()); + auto filterType = dyn_cast(filter.getType()); + auto outputType = dyn_cast(output.getType()); + + if (!inputType || !filterType || !outputType) + return rewriter.notifyMatchFailure(convOp, + "operands must be memref types"); + + Type inputElemType = inputType.getElementType(); + if (!inputElemType.isInteger(8)) + return rewriter.notifyMatchFailure(convOp, "only int8 is supported"); + + ArrayRef inputShape = inputType.getShape(); + ArrayRef filterShape = filterType.getShape(); + ArrayRef outputShape = outputType.getShape(); + + if (inputShape.size() != 4 || filterShape.size() != 4 || + outputShape.size() != 4) + return rewriter.notifyMatchFailure(convOp, + "only 4D tensors are 
supported"); + + int64_t N = inputShape[0]; + int64_t IC = inputShape[3]; + int64_t FH = filterShape[0]; + int64_t FW = filterShape[1]; + int64_t OC = filterShape[3]; + int64_t OH = outputShape[1]; + int64_t OW = outputShape[2]; + + int64_t strideH = 1, strideW = 1; + + if (auto stridesAttr = convOp.getStrides()) { + auto strides = stridesAttr.getValues(); + strideH = strides[0]; + strideW = strides[1]; + } + + const int64_t TILE_M = 4; + const int64_t TILE_2M = 8; + const int64_t TILE_K = 8; + const int64_t TILE_N = 4; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + + Value boundN = rewriter.create(loc, N); + Value boundOH = rewriter.create(loc, OH); + Value boundOW = rewriter.create(loc, OW); + Value boundOC = rewriter.create(loc, OC); + Value boundFH = rewriter.create(loc, FH); + Value boundFW = rewriter.create(loc, FW); + Value boundIC = rewriter.create(loc, IC); + + Value stepM = rewriter.create(loc, TILE_M); + Value stepK = rewriter.create(loc, TILE_K); + Value stepN = rewriter.create(loc, TILE_N); + + Value strideHVal = rewriter.create(loc, strideH); + Value strideWVal = rewriter.create(loc, strideW); + + auto inputTileType = MemRefType::get({TILE_2M, TILE_K}, inputElemType); + auto filterTileType = MemRefType::get({TILE_K, TILE_N}, inputElemType); + auto outputTileType = + MemRefType::get({TILE_M, TILE_N}, rewriter.getI32Type()); + + auto loopN = rewriter.create(loc, c0, boundN, c1); + rewriter.setInsertionPointToStart(loopN.getBody()); + Value ivN = loopN.getInductionVar(); + + auto loopOH = rewriter.create(loc, c0, boundOH, stepM); + rewriter.setInsertionPointToStart(loopOH.getBody()); + Value ivOH = loopOH.getInductionVar(); + + auto loopOC = rewriter.create(loc, c0, boundOC, stepN); + rewriter.setInsertionPointToStart(loopOC.getBody()); + Value ivOC = loopOC.getInductionVar(); + + auto loopOW = rewriter.create(loc, c0, boundOW, c1); + rewriter.setInsertionPointToStart(loopOW.getBody()); + Value ivOW = 
loopOW.getInductionVar(); + + auto loopFH = rewriter.create(loc, c0, boundFH, c1); + rewriter.setInsertionPointToStart(loopFH.getBody()); + Value ivFH = loopFH.getInductionVar(); + + auto loopFW = rewriter.create(loc, c0, boundFW, c1); + rewriter.setInsertionPointToStart(loopFW.getBody()); + Value ivFW = loopFW.getInductionVar(); + + auto loopIC = rewriter.create(loc, c0, boundIC, stepK); + rewriter.setInsertionPointToStart(loopIC.getBody()); + Value ivIC = loopIC.getInductionVar(); + + Value inputTile = rewriter.create(loc, inputTileType); + Value filterTile = rewriter.create(loc, filterTileType); + Value outputTile = rewriter.create(loc, outputTileType); + + Value ihBase = rewriter.create(loc, ivOH, strideHVal); + ihBase = rewriter.create(loc, ihBase, ivFH); + Value iw = rewriter.create(loc, ivOW, strideWVal); + iw = rewriter.create(loc, iw, ivFW); + + Value tileMBound = rewriter.create(loc, TILE_M); + Value tile2MBound = rewriter.create(loc, TILE_2M); + Value tileKBound = rewriter.create(loc, TILE_K); + Value tileNBound = rewriter.create(loc, TILE_N); + + auto fillInputLoop = rewriter.create(loc, c0, tile2MBound, c1); + rewriter.setInsertionPointToStart(fillInputLoop.getBody()); + Value fillM = fillInputLoop.getInductionVar(); + + auto fillInputInnerLoop = + rewriter.create(loc, c0, tileKBound, c1); + rewriter.setInsertionPointToStart(fillInputInnerLoop.getBody()); + Value fillK = fillInputInnerLoop.getInductionVar(); + + Value mTimesStride = rewriter.create(loc, fillM, strideHVal); + Value inputIH = rewriter.create(loc, ihBase, mTimesStride); + Value inputIC = rewriter.create(loc, ivIC, fillK); + Value inputVal = rewriter.create( + loc, input, ValueRange{ivN, inputIH, iw, inputIC}); + rewriter.create(loc, inputVal, inputTile, + ValueRange{fillM, fillK}); + + rewriter.setInsertionPointAfter(fillInputLoop); + + auto fillFilterLoop = rewriter.create(loc, c0, tileKBound, c1); + rewriter.setInsertionPointToStart(fillFilterLoop.getBody()); + Value fillFK = 
fillFilterLoop.getInductionVar(); + + auto fillFilterInnerLoop = + rewriter.create(loc, c0, tileNBound, c1); + rewriter.setInsertionPointToStart(fillFilterInnerLoop.getBody()); + Value fillFN = fillFilterInnerLoop.getInductionVar(); + + Value filterIC = rewriter.create(loc, ivIC, fillFK); + Value filterOC = rewriter.create(loc, ivOC, fillFN); + Value filterVal = rewriter.create( + loc, filter, ValueRange{ivFH, ivFW, filterIC, filterOC}); + rewriter.create(loc, filterVal, filterTile, + ValueRange{fillFK, fillFN}); + + rewriter.setInsertionPointAfter(fillFilterLoop); + + auto loadOutputLoop = rewriter.create(loc, c0, tileMBound, c1); + rewriter.setInsertionPointToStart(loadOutputLoop.getBody()); + Value loadM = loadOutputLoop.getInductionVar(); + + auto loadOutputInnerLoop = + rewriter.create(loc, c0, tileNBound, c1); + rewriter.setInsertionPointToStart(loadOutputInnerLoop.getBody()); + Value loadN = loadOutputInnerLoop.getInductionVar(); + + Value outOH = rewriter.create(loc, ivOH, loadM); + Value outOC = rewriter.create(loc, ivOC, loadN); + Value outVal = rewriter.create( + loc, output, ValueRange{ivN, outOH, ivOW, outOC}); + rewriter.create(loc, outVal, outputTile, + ValueRange{loadM, loadN}); + + rewriter.setInsertionPointAfter(loadOutputLoop); + + if (strideH == 1) { + rewriter.create(loc, outputTile, inputTile, filterTile); + } else if (strideH == 2) { + rewriter.create(loc, outputTile, inputTile, filterTile); + } else if (strideH == 3) { + rewriter.create(loc, outputTile, inputTile, filterTile); + } else { + Value slide = rewriter.create(loc, 0, 64); + rewriter.create(loc, outputTile, inputTile, filterTile, slide); + } + + auto storeOutputLoop = rewriter.create(loc, c0, tileMBound, c1); + rewriter.setInsertionPointToStart(storeOutputLoop.getBody()); + Value storeM = storeOutputLoop.getInductionVar(); + + auto storeOutputInnerLoop = + rewriter.create(loc, c0, tileNBound, c1); + rewriter.setInsertionPointToStart(storeOutputInnerLoop.getBody()); + Value storeN = 
storeOutputInnerLoop.getInductionVar(); + + Value storeOH = rewriter.create(loc, ivOH, storeM); + Value storeOC = rewriter.create(loc, ivOC, storeN); + Value storeVal = rewriter.create( + loc, outputTile, ValueRange{storeM, storeN}); + rewriter.create(loc, storeVal, output, + ValueRange{ivN, storeOH, ivOW, storeOC}); + + rewriter.setInsertionPointAfter(storeOutputLoop); + + rewriter.setInsertionPointAfter(loopN); + + rewriter.eraseOp(convOp); + + return success(); + } +}; + +/// Pattern to lower linalg.conv_2d_nchw_fchw to IME sliding-window operations. + +class Conv2DNchwFchwToIMELowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, + PatternRewriter &rewriter) const override { + Location loc = convOp.getLoc(); + + Value input = convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + + auto inputType = dyn_cast(input.getType()); + auto filterType = dyn_cast(filter.getType()); + auto outputType = dyn_cast(output.getType()); + + if (!inputType || !filterType || !outputType) + return rewriter.notifyMatchFailure(convOp, + "operands must be memref types"); + + Type inputElemType = inputType.getElementType(); + if (!inputElemType.isInteger(8)) + return rewriter.notifyMatchFailure(convOp, "only int8 is supported"); + + ArrayRef inputShape = inputType.getShape(); + ArrayRef filterShape = filterType.getShape(); + ArrayRef outputShape = outputType.getShape(); + + if (inputShape.size() != 4 || filterShape.size() != 4 || + outputShape.size() != 4) + return rewriter.notifyMatchFailure(convOp, + "only 4D tensors are supported"); + + int64_t N = inputShape[0]; + int64_t IC = inputShape[1]; + int64_t OC = filterShape[0]; + int64_t FH = filterShape[2]; + int64_t FW = filterShape[3]; + int64_t OH = outputShape[2]; + int64_t OW = outputShape[3]; + + int64_t strideH = 1, strideW = 1; + + if (auto stridesAttr = convOp.getStrides()) 
{ + auto strides = stridesAttr.getValues(); + strideH = strides[0]; + strideW = strides[1]; + } + + const int64_t TILE_M = 4; + const int64_t TILE_2M = 8; + const int64_t TILE_K = 8; + const int64_t TILE_N = 4; + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + + Value boundN = rewriter.create(loc, N); + Value boundOC = rewriter.create(loc, OC); + Value boundOH = rewriter.create(loc, OH); + Value boundOW = rewriter.create(loc, OW); + Value boundIC = rewriter.create(loc, IC); + Value boundFH = rewriter.create(loc, FH); + Value boundFW = rewriter.create(loc, FW); + + Value stepM = rewriter.create(loc, TILE_M); + Value stepK = rewriter.create(loc, TILE_K); + Value stepN = rewriter.create(loc, TILE_N); + + Value strideHVal = rewriter.create(loc, strideH); + Value strideWVal = rewriter.create(loc, strideW); + + auto loopN = rewriter.create(loc, c0, boundN, c1); + rewriter.setInsertionPointToStart(loopN.getBody()); + Value ivN = loopN.getInductionVar(); + + auto loopOC = rewriter.create(loc, c0, boundOC, stepN); + rewriter.setInsertionPointToStart(loopOC.getBody()); + Value ivOC = loopOC.getInductionVar(); + + auto loopOH = rewriter.create(loc, c0, boundOH, stepM); + rewriter.setInsertionPointToStart(loopOH.getBody()); + Value ivOH = loopOH.getInductionVar(); + + auto loopOW = rewriter.create(loc, c0, boundOW, c1); + rewriter.setInsertionPointToStart(loopOW.getBody()); + Value ivOW = loopOW.getInductionVar(); + + auto loopIC = rewriter.create(loc, c0, boundIC, stepK); + rewriter.setInsertionPointToStart(loopIC.getBody()); + Value ivIC = loopIC.getInductionVar(); + + auto loopFH = rewriter.create(loc, c0, boundFH, c1); + rewriter.setInsertionPointToStart(loopFH.getBody()); + Value ivFH = loopFH.getInductionVar(); + + auto loopFW = rewriter.create(loc, c0, boundFW, c1); + rewriter.setInsertionPointToStart(loopFW.getBody()); + Value ivFW = loopFW.getInductionVar(); + + Value ihBase = rewriter.create(loc, ivOH, strideHVal); + ihBase = 
rewriter.create(loc, ihBase, ivFH); + Value iw = rewriter.create(loc, ivOW, strideWVal); + iw = rewriter.create(loc, iw, ivFW); + + auto inputTileType = MemRefType::get({TILE_2M, TILE_K}, inputElemType); + auto filterTileType = MemRefType::get({TILE_K, TILE_N}, inputElemType); + auto outputTileType = + MemRefType::get({TILE_M, TILE_N}, rewriter.getI32Type()); + + Value inputTile = rewriter.create(loc, inputTileType); + Value filterTile = rewriter.create(loc, filterTileType); + Value outputTile = rewriter.create(loc, outputTileType); + + Value tileMBound = rewriter.create(loc, TILE_M); + Value tile2MBound = rewriter.create(loc, TILE_2M); + Value tileKBound = rewriter.create(loc, TILE_K); + Value tileNBound = rewriter.create(loc, TILE_N); + + auto fillInputLoop = rewriter.create(loc, c0, tile2MBound, c1); + rewriter.setInsertionPointToStart(fillInputLoop.getBody()); + Value fillM = fillInputLoop.getInductionVar(); + + auto fillInputInnerLoop = + rewriter.create(loc, c0, tileKBound, c1); + rewriter.setInsertionPointToStart(fillInputInnerLoop.getBody()); + Value fillK = fillInputInnerLoop.getInductionVar(); + + Value mTimesStride = rewriter.create(loc, fillM, strideHVal); + Value inputIH = rewriter.create(loc, ihBase, mTimesStride); + Value inputIC = rewriter.create(loc, ivIC, fillK); + Value inputVal = rewriter.create( + loc, input, ValueRange{ivN, inputIC, inputIH, iw}); + rewriter.create(loc, inputVal, inputTile, + ValueRange{fillM, fillK}); + + rewriter.setInsertionPointAfter(fillInputLoop); + + auto fillFilterLoop = rewriter.create(loc, c0, tileKBound, c1); + rewriter.setInsertionPointToStart(fillFilterLoop.getBody()); + Value fillFK = fillFilterLoop.getInductionVar(); + + auto fillFilterInnerLoop = + rewriter.create(loc, c0, tileNBound, c1); + rewriter.setInsertionPointToStart(fillFilterInnerLoop.getBody()); + Value fillFN = fillFilterInnerLoop.getInductionVar(); + + Value filterOC = rewriter.create(loc, ivOC, fillFN); + Value filterIC = rewriter.create(loc, 
ivIC, fillFK); + Value filterVal = rewriter.create( + loc, filter, ValueRange{filterOC, filterIC, ivFH, ivFW}); + rewriter.create(loc, filterVal, filterTile, + ValueRange{fillFK, fillFN}); + + rewriter.setInsertionPointAfter(fillFilterLoop); + + auto loadOutputLoop = rewriter.create(loc, c0, tileMBound, c1); + rewriter.setInsertionPointToStart(loadOutputLoop.getBody()); + Value loadM = loadOutputLoop.getInductionVar(); + + auto loadOutputInnerLoop = + rewriter.create(loc, c0, tileNBound, c1); + rewriter.setInsertionPointToStart(loadOutputInnerLoop.getBody()); + Value loadN = loadOutputInnerLoop.getInductionVar(); + + Value outOC = rewriter.create(loc, ivOC, loadN); + Value outOH = rewriter.create(loc, ivOH, loadM); + Value outVal = rewriter.create( + loc, output, ValueRange{ivN, outOC, outOH, ivOW}); + rewriter.create(loc, outVal, outputTile, + ValueRange{loadM, loadN}); + + rewriter.setInsertionPointAfter(loadOutputLoop); + + if (strideH == 1) { + rewriter.create(loc, outputTile, inputTile, filterTile); + } else if (strideH == 2) { + rewriter.create(loc, outputTile, inputTile, filterTile); + } else if (strideH == 3) { + rewriter.create(loc, outputTile, inputTile, filterTile); + } else { + Value slide = rewriter.create(loc, 0, 64); + rewriter.create(loc, outputTile, inputTile, filterTile, slide); + } + + auto storeOutputLoop = rewriter.create(loc, c0, tileMBound, c1); + rewriter.setInsertionPointToStart(storeOutputLoop.getBody()); + Value storeM = storeOutputLoop.getInductionVar(); + + auto storeOutputInnerLoop = + rewriter.create(loc, c0, tileNBound, c1); + rewriter.setInsertionPointToStart(storeOutputInnerLoop.getBody()); + Value storeN = storeOutputInnerLoop.getInductionVar(); + + Value storeOC = rewriter.create(loc, ivOC, storeN); + Value storeOH = rewriter.create(loc, ivOH, storeM); + Value storeVal = rewriter.create( + loc, outputTile, ValueRange{storeM, storeN}); + rewriter.create(loc, storeVal, output, + ValueRange{ivN, storeOC, storeOH, ivOW}); + + 
rewriter.setInsertionPointAfter(storeOutputLoop); + + rewriter.setInsertionPointAfter(loopN); + + rewriter.eraseOp(convOp); + + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { +class LowerLinalgToIMEPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerLinalgToIMEPass) + + StringRef getArgument() const final { return "lower-linalg-to-ime"; } + StringRef getDescription() const final { + return "Lower linalg dialect operations to IME dialect operations."; + } + + LowerLinalgToIMEPass() = default; + LowerLinalgToIMEPass(const LowerLinalgToIMEPass &) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; +} // namespace + +void LowerLinalgToIMEPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + RewritePatternSet patterns(context); + + // Add patterns with higher benefit first (aligned dimensions) + patterns.add(context); + patterns.add(context); + + // Add boundary handling pattern with lower benefit (tried after aligned cases fail) + patterns.add(context); + + patterns.add(context); + patterns.add(context); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + } +} + +//===----------------------------------------------------------------------===// +// Pass Registration +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace buddy { +void registerLowerLinalgToIMEPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Dialect/AME/CMakeLists.txt 
b/midend/lib/Dialect/AME/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/midend/lib/Dialect/AME/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/midend/lib/Dialect/AME/IR/AMEDialect.cpp b/midend/lib/Dialect/AME/IR/AMEDialect.cpp new file mode 100644 index 0000000000..75c1c3fa0d --- /dev/null +++ b/midend/lib/Dialect/AME/IR/AMEDialect.cpp @@ -0,0 +1,37 @@ +//====- AMEDialect.cpp - MLIR AME dialect implementation ------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +#include "Dialect/AME/AMEDialect.h" +#include "Dialect/AME/AMEOps.h" + +using namespace mlir; +using namespace buddy::ame; + +#include "AME/AMEDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "AME/AME.cpp.inc" + +void AMEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "AME/AME.cpp.inc" + >(); +} diff --git a/midend/lib/Dialect/AME/IR/CMakeLists.txt b/midend/lib/Dialect/AME/IR/CMakeLists.txt new file mode 100644 index 0000000000..b6d884ae9b --- /dev/null +++ b/midend/lib/Dialect/AME/IR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect_library(BuddyAME + AMEDialect.cpp + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/midend/lib/Dialect/AME/Transforms/CMakeLists.txt b/midend/lib/Dialect/AME/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..79c0c88b0e --- /dev/null +++ b/midend/lib/Dialect/AME/Transforms/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_library(BuddyAMETransforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRAMEIncGen + + LINK_LIBS PUBLIC + BuddyAME + MLIRArithDialect + MLIRFuncDialect + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRMemRefDialect + MLIRPass + MLIRTransforms +) diff --git a/midend/lib/Dialect/AME/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/AME/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 index 0000000000..817bd95000 --- /dev/null +++ b/midend/lib/Dialect/AME/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,646 @@ +//====- LegalizeForLLVMExport.cpp - Prepare AME for LLVM translation ------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "Dialect/AME/AMEDialect.h" +#include "Dialect/AME/AMEOps.h" +#include "Dialect/AME/Transform.h" + +using namespace mlir; +using namespace buddy::ame; + +namespace { + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + +static FlatSymbolRefAttr +getOrInsertIntrinsic(ConversionPatternRewriter &rewriter, ModuleOp module, + StringRef intrinsicName, LLVM::LLVMFunctionType funcType) { + auto *ctx = rewriter.getContext(); + if (module.lookupSymbol(intrinsicName)) + return FlatSymbolRefAttr::get(ctx, intrinsicName); + + auto savedInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToEnd(module.getBody()); + rewriter.create(module.getLoc(), intrinsicName, funcType, + LLVM::Linkage::External, false, + LLVM::CConv::C); + rewriter.restoreInsertionPoint(savedInsertionPoint); + return FlatSymbolRefAttr::get(ctx, intrinsicName); +} + +static Value extractPointerFromMemref(ConversionPatternRewriter &rewriter, + Location loc, Value memref) { + auto *ctx = rewriter.getContext(); + auto ptrType = 
LLVM::LLVMPointerType::get(ctx); + auto i64Type = IntegerType::get(ctx, 64); + Value idx = + rewriter.create(loc, memref); + Value i64Val = rewriter.create(loc, i64Type, idx); + Value ptr = rewriter.create(loc, ptrType, i64Val); + return ptr; +} + +//===----------------------------------------------------------------------===// +// AME Lowering Patterns +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Configuration Operations Lowering +//===----------------------------------------------------------------------===// + +/// Lowering pattern for msettilemi (set tile M dimension with immediate) +struct AMEMSettilemiLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MSettilemiOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + + auto i64Type = IntegerType::get(ctx, 64); + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.msettilemi", funcType); + + Value tilemVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getTilem())); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{tilemVal}); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Lowering pattern for msettileni +struct AMEMSettileniLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MSettileniOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + 
+ auto i64Type = IntegerType::get(ctx, 64); + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.msettileni", funcType); + + Value tilenVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getTilen())); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{tilenVal}); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Lowering pattern for msettileki +struct AMEMSettilekiLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MSettilekiOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + + auto i64Type = IntegerType::get(ctx, 64); + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.msettileki", funcType); + + Value tilekVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getTilek())); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{tilekVal}); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Lowering pattern for mzero +struct AMEMzeroLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MzeroOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + + auto i64Type = IntegerType::get(ctx, 64); + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, 
"llvm.riscv.buddy.mzero", funcType); + + Value mdVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getMd())); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{mdVal}); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Load Operations Lowering +//===----------------------------------------------------------------------===// + +/// Lowering pattern for mlae32.m (load left matrix A) +struct AMEMlae32mLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Mlae32mOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + + auto i64Type = IntegerType::get(ctx, 64); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type, ptrType, i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.mlae32.m", funcType); + + Value mdVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getMd())); + Value basePtr = extractPointerFromMemref(rewriter, loc, op.getBase()); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{mdVal, basePtr, adaptor.getStride()}); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Lowering pattern for mlbe32.m (load right matrix B) +struct AMEMlbe32mLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Mlbe32mOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + + auto i64Type = IntegerType::get(ctx, 
64); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type, ptrType, i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.mlbe32.m", funcType); + + Value mdVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getMd())); + Value basePtr = extractPointerFromMemref(rewriter, loc, op.getBase()); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{mdVal, basePtr, adaptor.getStride()}); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Store Operations Lowering +//===----------------------------------------------------------------------===// + +/// Lowering pattern for msce32.m (store output matrix C) +struct AMEMsce32mLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Msce32mOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + + auto i64Type = IntegerType::get(ctx, 64); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type, ptrType, i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.msce32.m", funcType); + + Value ms3Val = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getMs3())); + Value basePtr = extractPointerFromMemref(rewriter, loc, op.getBase()); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{ms3Val, basePtr, adaptor.getStride()}); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Tile Register Matrix Multiply Lowering 
+//===----------------------------------------------------------------------===// + +/// Lowering pattern for mma.w.mm.tile (tile register matrix multiply) +struct AMEMmaWmmTileLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MmaWmmTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) return failure(); + + auto i64Type = IntegerType::get(ctx, 64); + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), {i64Type, i64Type, i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.mma.w.mm.tile", funcType); + + Value mdVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getMd())); + Value ms1Val = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getMs1())); + Value ms2Val = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(op.getMs2())); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{mdVal, ms1Val, ms2Val}); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// High-level Matrix Multiply Lowering (MemRef version) +//===----------------------------------------------------------------------===// + +/// Lowering pattern for mqma.b.mm (int8 quad-widen matrix multiply) +struct AMEMqmaBmmLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MqmaBmmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get memref types and compute dimensions + auto mdType = cast(op.getMd().getType()); + auto 
ms1Type = cast(op.getMs1().getType()); + + auto mdShape = mdType.getShape(); + auto ms1Shape = ms1Type.getShape(); + + int64_t M = mdShape[0]; + int64_t N = mdShape[1]; + int64_t K = ms1Shape[1]; // K dimension from ms1 + + // Define LLVM types + auto i64Type = IntegerType::get(ctx, 64); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Get base pointers from memrefs + Value mdBase = extractPointerFromMemref(rewriter, loc, op.getMd()); + Value ms1Base = extractPointerFromMemref(rewriter, loc, op.getMs1()); + Value ms2Base = extractPointerFromMemref(rewriter, loc, op.getMs2()); + + // Create constants for dimensions + Value mVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(M)); + Value nVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(N)); + Value kVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(K)); + + // Create intrinsic function type for mqma.b.mm + // void @llvm.riscv.buddy.mqma.b.mm(ptr md, ptr ms1, ptr ms2, i64 M, i64 N, i64 K) + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), + {ptrType, ptrType, ptrType, i64Type, i64Type, i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.mqma.b.mm", funcType); + + // Call the intrinsic + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{mdBase, ms1Base, ms2Base, + mVal, nVal, kVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Lowering pattern for mma.w.mm (int32 matrix multiply) +struct AMEMmaWmmLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MmaWmmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get memref types and compute dimensions + auto mdType = cast(op.getMd().getType()); + auto ms1Type = 
cast(op.getMs1().getType()); + + auto mdShape = mdType.getShape(); + auto ms1Shape = ms1Type.getShape(); + + int64_t M = mdShape[0]; + int64_t N = mdShape[1]; + int64_t K = ms1Shape[1]; + + // Define LLVM types + auto i64Type = IntegerType::get(ctx, 64); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Get base pointers from memrefs + Value mdBase = extractPointerFromMemref(rewriter, loc, op.getMd()); + Value ms1Base = extractPointerFromMemref(rewriter, loc, op.getMs1()); + Value ms2Base = extractPointerFromMemref(rewriter, loc, op.getMs2()); + + // Create constants for dimensions + Value mVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(M)); + Value nVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(N)); + Value kVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(K)); + + // Create intrinsic function type for mma.w.mm + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), + {ptrType, ptrType, ptrType, i64Type, i64Type, i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.mma.w.mm", funcType); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{mdBase, ms1Base, ms2Base, + mVal, nVal, kVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Lowering pattern for mma.dw.mm (int64 matrix multiply) +struct AMEMmaDwmmLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MmaDwmmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get memref types and compute dimensions + auto mdType = cast(op.getMd().getType()); + auto ms1Type = cast(op.getMs1().getType()); + + auto mdShape = mdType.getShape(); + auto ms1Shape = ms1Type.getShape(); + + int64_t M = mdShape[0]; + int64_t N = 
mdShape[1]; + int64_t K = ms1Shape[1]; + + // Define LLVM types + auto i64Type = IntegerType::get(ctx, 64); + auto ptrType = LLVM::LLVMPointerType::get(ctx); + + // Get base pointers from memrefs + Value mdBase = extractPointerFromMemref(rewriter, loc, op.getMd()); + Value ms1Base = extractPointerFromMemref(rewriter, loc, op.getMs1()); + Value ms2Base = extractPointerFromMemref(rewriter, loc, op.getMs2()); + + // Create constants for dimensions + Value mVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(M)); + Value nVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(N)); + Value kVal = rewriter.create(loc, i64Type, + rewriter.getI64IntegerAttr(K)); + + // Create intrinsic function type for mma.dw.mm + auto funcType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(ctx), + {ptrType, ptrType, ptrType, i64Type, i64Type, i64Type}); + + auto intrinsicName = getOrInsertIntrinsic( + rewriter, module, "llvm.riscv.buddy.mma.dw.mm", funcType); + + rewriter.create(loc, TypeRange{}, intrinsicName, + ValueRange{mdBase, ms1Base, ms2Base, + mVal, nVal, kVal}); + + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct LegalizeAMEForLLVMExport + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeAMEForLLVMExport) + + StringRef getArgument() const final { return "lower-ame"; } + StringRef getDescription() const final { + return "AME dialect lowering pass."; + } + + LegalizeAMEForLLVMExport() = default; + LegalizeAMEForLLVMExport(const LegalizeAMEForLLVMExport &) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module = getOperation(); + MLIRContext &context = getContext(); 
+ + LLVMConversionTarget target(context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + // Configuration operations + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + // Load/Store operations + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + // Tile register matrix multiply + target.addIllegalOp(); + + // High-level matrix multiply (MemRef) + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + LLVMTypeConverter typeConverter(&context); + + RewritePatternSet patterns(&context); + + // Configuration patterns + patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); + + // Load/Store patterns + patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); + + // Tile register matrix multiply patterns + patterns.add(typeConverter); + + // High-level matrix multiply patterns + patterns.add(typeConverter); + patterns.add(typeConverter); + patterns.add(typeConverter); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +void mlir::populateAMELegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // Configuration patterns + patterns.add(converter); + patterns.add(converter); + patterns.add(converter); + patterns.add(converter); + + // Load/Store patterns + patterns.add(converter); + patterns.add(converter); + patterns.add(converter); + + // Tile register matrix multiply patterns + patterns.add(converter); + + // High-level matrix multiply patterns + patterns.add(converter); + patterns.add(converter); + patterns.add(converter); +} + +void mlir::configureAMELegalizeForExportTarget(LLVMConversionTarget &target) { + target.addLegalDialect(); + target.addLegalDialect(); + + // Configuration operations + 
target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + // Load/Store operations + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + // Tile register matrix multiply + target.addIllegalOp(); + + // High-level matrix multiply (MemRef) + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); +} + +std::unique_ptr buddy::ame::createLegalizeForLLVMExportPass() { + return std::make_unique(); +} + +namespace mlir { +namespace buddy { +void registerLowerAMEPass() { PassRegistration(); } +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Dialect/CMakeLists.txt b/midend/lib/Dialect/CMakeLists.txt index c1690ba540..ed346db7c3 100644 --- a/midend/lib/Dialect/CMakeLists.txt +++ b/midend/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(AME) add_subdirectory(Bud) add_subdirectory(DAP) add_subdirectory(DIP) diff --git a/midend/lib/Dialect/IME/Transforms/LegalizeForLLVMExport.cpp b/midend/lib/Dialect/IME/Transforms/LegalizeForLLVMExport.cpp index dc36c8803f..3d6b30e215 100644 --- a/midend/lib/Dialect/IME/Transforms/LegalizeForLLVMExport.cpp +++ b/midend/lib/Dialect/IME/Transforms/LegalizeForLLVMExport.cpp @@ -31,6 +31,107 @@ using namespace buddy::ime; namespace { +//===----------------------------------------------------------------------===// +// IME Type Configuration +//===----------------------------------------------------------------------===// + +/// Enumeration of supported IME data types +enum class IMEDataType { + Int8, // 8-bit integer (signed/unsigned) + Int16, // 16-bit integer (signed/unsigned) + FP16 // 16-bit floating point +}; + +/// Configuration structure for IME operations based on data type +struct IMETypeConfig { + int64_t targetVL; // Vector length + int64_t sew; // SEW encoding: 0=e8, 1=e16, 2=e32, 3=e64 + int64_t lmul; // LMUL encoding: 0=m1, 1=m2, 2=m4, 3=m8 + int64_t extendedVL; // Extended VL for sliding operations (2x targetVL) + Type 
elementType; // Element type for input vectors + Type outputElementType; // Element type for output vectors + VectorType inputVecType; // Input vector type (e.g., nxv32i8) + VectorType outputVecType;// Output vector type (e.g., nxv8i32) + VectorType extendedInputVecType; // Extended input vector for sliding ops + std::string inputVecSuffix; // e.g., "nxv32i8" + std::string outputVecSuffix; // e.g., "nxv8i32" + std::string extendedInputVecSuffix; // e.g., "nxv64i8" +}; + +/// Get IME type configuration from memref element type +static IMETypeConfig getIMETypeConfig(MLIRContext *ctx, Type elementType) { + IMETypeConfig config; + config.lmul = 0; // Always m1 for IME operations + + if (elementType.isInteger(8)) { + // int8: SEW=e8, VL=32, MAC unit 4x4x8 + config.targetVL = 32; + config.sew = 0; // e8 + config.extendedVL = 64; + config.elementType = IntegerType::get(ctx, 8); + config.outputElementType = IntegerType::get(ctx, 32); + config.inputVecType = VectorType::get({32}, config.elementType, /*scalableDims=*/true); + config.outputVecType = VectorType::get({8}, config.outputElementType, /*scalableDims=*/true); + config.extendedInputVecType = VectorType::get({64}, config.elementType, /*scalableDims=*/true); + config.inputVecSuffix = "nxv32i8"; + config.outputVecSuffix = "nxv8i32"; + config.extendedInputVecSuffix = "nxv64i8"; + } else if (elementType.isInteger(16)) { + // int16: SEW=e16, VL=16, MAC unit 4x4x4 + config.targetVL = 16; + config.sew = 1; // e16 + config.extendedVL = 32; + config.elementType = IntegerType::get(ctx, 16); + config.outputElementType = IntegerType::get(ctx, 32); + config.inputVecType = VectorType::get({16}, config.elementType, /*scalableDims=*/true); + config.outputVecType = VectorType::get({8}, config.outputElementType, /*scalableDims=*/true); + config.extendedInputVecType = VectorType::get({32}, config.elementType, /*scalableDims=*/true); + config.inputVecSuffix = "nxv16i16"; + config.outputVecSuffix = "nxv8i32"; + 
config.extendedInputVecSuffix = "nxv32i16"; + } else if (elementType.isF16()) { + // fp16: SEW=e16, VL=16, MAC unit 4x4x4 + config.targetVL = 16; + config.sew = 1; // e16 + config.extendedVL = 32; + config.elementType = Float16Type::get(ctx); + config.outputElementType = Float16Type::get(ctx); // fp16 output is also f16 + config.inputVecType = VectorType::get({16}, config.elementType, /*scalableDims=*/true); + config.outputVecType = VectorType::get({16}, config.outputElementType, /*scalableDims=*/true); + config.extendedInputVecType = VectorType::get({32}, config.elementType, /*scalableDims=*/true); + config.inputVecSuffix = "nxv16f16"; + config.outputVecSuffix = "nxv16f16"; + config.extendedInputVecSuffix = "nxv32f16"; + } else { + // Default to int8 if unknown type + config.targetVL = 32; + config.sew = 0; + config.extendedVL = 64; + config.elementType = IntegerType::get(ctx, 8); + config.outputElementType = IntegerType::get(ctx, 32); + config.inputVecType = VectorType::get({32}, config.elementType, /*scalableDims=*/true); + config.outputVecType = VectorType::get({8}, config.outputElementType, /*scalableDims=*/true); + config.extendedInputVecType = VectorType::get({64}, config.elementType, /*scalableDims=*/true); + config.inputVecSuffix = "nxv32i8"; + config.outputVecSuffix = "nxv8i32"; + config.extendedInputVecSuffix = "nxv64i8"; + } + + return config; +} + +/// Get element type from a memref value +static Type getMemRefElementType(Value memref) { + if (auto memrefType = dyn_cast(memref.getType())) { + return memrefType.getElementType(); + } + return nullptr; +} + +//===----------------------------------------------------------------------===// +// Helper Functions +//===----------------------------------------------------------------------===// + static FlatSymbolRefAttr getOrInsertIntrinsic(ConversionPatternRewriter &rewriter, ModuleOp module, StringRef intrinsicName, LLVM::LLVMFunctionType funcType) { @@ -47,9 +148,37 @@ 
getOrInsertIntrinsic(ConversionPatternRewriter &rewriter, ModuleOp module, return FlatSymbolRefAttr::get(ctx, intrinsicName); } +// Create vsetvli intrinsic call to configure vector type +// SEW encoding: 0=e8, 1=e16, 2=e32, 3=e64 +// LMUL encoding: 0=m1, 1=m2, 2=m4, 3=m8, 5=mf8, 6=mf4, 7=mf2 +static Value createVsetvli(ConversionPatternRewriter &rewriter, Location loc, + ModuleOp module, int64_t avl, int64_t sew, + int64_t lmul) { + auto *ctx = rewriter.getContext(); + auto i64Type = IntegerType::get(ctx, 64); + + // vsetvli intrinsic: i64 @llvm.riscv.vsetvli.i64(i64 avl, i64 sew, i64 lmul) + auto funcType = + LLVM::LLVMFunctionType::get(i64Type, {i64Type, i64Type, i64Type}, false); + auto funcRef = + getOrInsertIntrinsic(rewriter, module, "llvm.riscv.vsetvli.i64", funcType); + + auto avlVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(avl)); + auto sewVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(sew)); + auto lmulVal = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(lmul)); + + auto call = rewriter.create(loc, TypeRange{i64Type}, funcRef, + ValueRange{avlVal, sewVal, lmulVal}); + return call.getResult(); +} + static Value createRVVVectorLoad(ConversionPatternRewriter &rewriter, Location loc, ModuleOp module, Value pointer, - Type vectorType, StringRef baseIntrinsicName) { + Type vectorType, StringRef baseIntrinsicName, + Value vl) { auto *ctx = rewriter.getContext(); auto i64Type = IntegerType::get(ctx, 64); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -58,8 +187,6 @@ static Value createRVVVectorLoad(ConversionPatternRewriter &rewriter, vectorType, {vectorType, ptrType, i64Type}, false); auto funcRef = getOrInsertIntrinsic(rewriter, module, mangledName, funcType); auto undefPassthru = rewriter.create(loc, vectorType); - auto vl = rewriter.create(loc, i64Type, - rewriter.getI64IntegerAttr(-1)); auto call = rewriter.create(loc, TypeRange{vectorType}, funcRef, ValueRange{undefPassthru, pointer, vl}); @@ -68,7 
+195,8 @@ static Value createRVVVectorLoad(ConversionPatternRewriter &rewriter, static void createRVVVectorStore(ConversionPatternRewriter &rewriter, Location loc, ModuleOp module, Value vector, - Value pointer, StringRef baseIntrinsicName) { + Value pointer, StringRef baseIntrinsicName, + Value vl) { auto *ctx = rewriter.getContext(); auto i64Type = IntegerType::get(ctx, 64); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -78,8 +206,6 @@ static void createRVVVectorStore(ConversionPatternRewriter &rewriter, auto funcType = LLVM::LLVMFunctionType::get( voidType, {vectorType, ptrType, i64Type}, false); auto funcRef = getOrInsertIntrinsic(rewriter, module, mangledName, funcType); - auto vl = rewriter.create(loc, i64Type, - rewriter.getI64IntegerAttr(-1)); rewriter.create(loc, TypeRange{}, funcRef, ValueRange{vector, pointer, vl}); } @@ -103,6 +229,27 @@ static Value createIMEVmadotIntrinsic(ConversionPatternRewriter &rewriter, return call.getResult(); } +static Value createIMEVmadotnIntrinsic(ConversionPatternRewriter &rewriter, + Location loc, ModuleOp module, + Value vdVector, Value vs1Vector, + Value vs2Vector, Value slideVal, + StringRef baseIntrinsicName, + StringRef typeSuffix) { + auto *ctx = rewriter.getContext(); + auto vdType = vdVector.getType(); + auto vs1Type = vs1Vector.getType(); + auto vs2Type = vs2Vector.getType(); + auto i64Type = IntegerType::get(ctx, 64); + std::string mangledName = (baseIntrinsicName + typeSuffix).str(); + auto funcType = LLVM::LLVMFunctionType::get( + vdType, {vdType, vs1Type, vs2Type, i64Type}, false); + auto funcRef = getOrInsertIntrinsic(rewriter, module, mangledName, funcType); + auto call = rewriter.create( + loc, TypeRange{vdType}, funcRef, + ValueRange{vdVector, vs1Vector, vs2Vector, slideVal}); + return call.getResult(); +} + static Value extractPointerFromMemref(ConversionPatternRewriter &rewriter, Location loc, Value memref) { auto *ctx = rewriter.getContext(); @@ -127,26 +274,40 @@ struct IMEVmadotLowering : public 
ConvertOpToLLVMPattern { if (!module) return failure(); + // Get element type from vs1 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs1()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); - auto i8Type = IntegerType::get(ctx, 8); - auto i32Type = IntegerType::get(ctx, 32); - auto i8VecType = VectorType::get({32}, i8Type, /*scalableDims=*/true); - auto i32VecType = VectorType::get({8}, i32Type, /*scalableDims=*/true); - - Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, i32VecType, - "llvm.riscv.vle.nxv8i32"); + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.inputVecSuffix + "." 
+ config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); Value result = createIMEVmadotIntrinsic( rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadot", - ".nxv8i32.nxv32i8.nxv32i8"); + typeSuffix); createRVVVectorStore(rewriter, loc, module, result, vdPtr, - "llvm.riscv.vse.nxv8i32"); + storeIntrinsic, vlValue); rewriter.eraseOp(op); return success(); } @@ -164,26 +325,40 @@ struct IMEVmadotuLowering : public ConvertOpToLLVMPattern { if (!module) return failure(); + // Get element type from vs1 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs1()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); - auto i8Type = IntegerType::get(ctx, 8); - auto i32Type = IntegerType::get(ctx, 32); - auto i8VecType = VectorType::get({32}, i8Type, /*scalableDims=*/true); - auto i32VecType = VectorType::get({8}, i32Type, /*scalableDims=*/true); - - Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vdVec = createRVVVectorLoad(rewriter, loc, 
module, vdPtr, i32VecType, - "llvm.riscv.vle.nxv8i32"); + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.inputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); Value result = createIMEVmadotIntrinsic( rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadotu", - ".nxv8i32.nxv32i8.nxv32i8"); + typeSuffix); createRVVVectorStore(rewriter, loc, module, result, vdPtr, - "llvm.riscv.vse.nxv8i32"); + storeIntrinsic, vlValue); rewriter.eraseOp(op); return success(); } @@ -201,26 +376,40 @@ struct IMEVmadotsuLowering : public ConvertOpToLLVMPattern { if (!module) return failure(); + // Get element type from vs1 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs1()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); - auto i8Type = IntegerType::get(ctx, 8); - auto i32Type = IntegerType::get(ctx, 32); - auto i8VecType = 
VectorType::get({32}, i8Type, /*scalableDims=*/true); - auto i32VecType = VectorType::get({8}, i32Type, /*scalableDims=*/true); - - Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, i32VecType, - "llvm.riscv.vle.nxv8i32"); + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.inputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); Value result = createIMEVmadotIntrinsic( rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadotsu", - ".nxv8i32.nxv32i8.nxv32i8"); + typeSuffix); createRVVVectorStore(rewriter, loc, module, result, vdPtr, - "llvm.riscv.vse.nxv8i32"); + storeIntrinsic, vlValue); rewriter.eraseOp(op); return success(); } @@ -238,26 +427,40 @@ struct IMEVmadotusLowering : public ConvertOpToLLVMPattern { if (!module) return failure(); + // Get element type from vs1 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs1()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = 
rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); - auto i8Type = IntegerType::get(ctx, 8); - auto i32Type = IntegerType::get(ctx, 32); - auto i8VecType = VectorType::get({32}, i8Type, /*scalableDims=*/true); - auto i32VecType = VectorType::get({8}, i32Type, /*scalableDims=*/true); - - Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, i8VecType, - "llvm.riscv.vle.nxv32i8"); - Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, i32VecType, - "llvm.riscv.vle.nxv8i32"); + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.inputVecSuffix + "." 
+ config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); Value result = createIMEVmadotIntrinsic( rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadotus", - ".nxv8i32.nxv32i8.nxv32i8"); + typeSuffix); createRVVVectorStore(rewriter, loc, module, result, vdPtr, - "llvm.riscv.vse.nxv8i32"); + storeIntrinsic, vlValue); rewriter.eraseOp(op); return success(); } @@ -275,24 +478,1126 @@ struct IMEVfmadotLowering : public ConvertOpToLLVMPattern { if (!module) return failure(); + // Get element type from vs1 memref - should be f16 + Type elemType = getMemRefElementType(op.getVs1()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type (fp16: vl=16, SEW=e16, LMUL=m1) + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); - auto f16Type = Float16Type::get(ctx); - auto f16VecType = VectorType::get({32}, f16Type, /*scalableDims=*/true); + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.inputVecSuffix + "." 
+ config.inputVecSuffix; Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, - f16VecType, "llvm.riscv.vle.nxv32f16"); + config.inputVecType, loadIntrinsic, vlValue); Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, - f16VecType, "llvm.riscv.vle.nxv32f16"); - Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, f16VecType, - "llvm.riscv.vle.nxv32f16"); + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); Value result = createIMEVmadotIntrinsic( rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vfmadot", - ".nxv32f16.nxv32f16.nxv32f16"); + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot1Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot1Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + // For extended loads (2x elements for vs1) + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = 
extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + // Load extended elements for VS1 (two consecutive vector registers) + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadot1", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot1uLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot1uOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + 
createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadot1u", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot1suLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot1suOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type 
from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." 
+ config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, + "llvm.riscv.ime.vmadot1su", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot1usLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot1usOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." 
+ config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, + "llvm.riscv.ime.vmadot1us", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot2Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, 
rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadot2", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot2uLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot2uOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + 
// Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadot2u", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot2suLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot2suOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + 
return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." 
+ config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, + "llvm.riscv.ime.vmadot2su", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot2usLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot2usOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." 
+ config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, + "llvm.riscv.ime.vmadot2us", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot3Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot3Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, 
rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadot3", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot3uLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot3uOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + 
// Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vmadot3u", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot3suLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot3suOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + 
return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." 
+ config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, + "llvm.riscv.ime.vmadot3su", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadot3usLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vmadot3usOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." 
+ config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, + "llvm.riscv.ime.vmadot3us", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVfmadot1Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vfmadot1Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref - should be f16 + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type (fp16: vl=16, SEW=e16, LMUL=m1) + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, 
rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vfmadot1", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVfmadot2Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vfmadot2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref - should be f16 + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // 
Configure vtype based on element type (fp16: vl=16, SEW=e16, LMUL=m1) + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vfmadot2", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVfmadot3Lowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Vfmadot3Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = 
op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref - should be f16 + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type (fp16: vl=16, SEW=e16, LMUL=m1) + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." 
+ config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, "llvm.riscv.ime.vfmadot3", + typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadotnLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(VmadotnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + Value slideVal = op.getSlide(); + + std::string extLoadIntrinsic = "llvm.riscv.vle." 
+ config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotnIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, slideVal, + "llvm.riscv.ime.vmadotn", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadotnuLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(VmadotnuOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, 
rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + Value slideVal = op.getSlide(); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotnIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, slideVal, + "llvm.riscv.ime.vmadotnu", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadotnsuLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(VmadotnsuOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); 
+ auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + Value slideVal = op.getSlide(); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." 
+ config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotnIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, slideVal, + "llvm.riscv.ime.vmadotnsu", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVmadotnusLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(VmadotnusOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + Value slideVal = op.getSlide(); + + std::string extLoadIntrinsic = "llvm.riscv.vle." 
+ config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string outputLoadIntrinsic = "llvm.riscv.vle." + config.outputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, outputLoadIntrinsic, vlValue); + Value result = createIMEVmadotnIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, slideVal, + "llvm.riscv.ime.vmadotnus", typeSuffix); + createRVVVectorStore(rewriter, loc, module, result, vdPtr, + storeIntrinsic, vlValue); + rewriter.eraseOp(op); + return success(); + } +}; + +struct IMEVfmadotnLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(VfmadotnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + if (!module) + return failure(); + + // Get element type from vs2 memref and configure accordingly + Type elemType = getMemRefElementType(op.getVs2()); + if (!elemType) + return failure(); + + IMETypeConfig config = getIMETypeConfig(ctx, elemType); + auto i64Type = IntegerType::get(ctx, 64); + + // Configure vtype based on element type + createVsetvli(rewriter, loc, module, config.targetVL, config.sew, config.lmul); + Value vlValue = rewriter.create( + loc, i64Type, rewriter.getI64IntegerAttr(config.targetVL)); + Value vlValueExtended = rewriter.create( + loc, i64Type, 
rewriter.getI64IntegerAttr(config.extendedVL)); + + Value vdPtr = extractPointerFromMemref(rewriter, loc, op.getVd()); + Value vs1Ptr = extractPointerFromMemref(rewriter, loc, op.getVs1()); + Value vs2Ptr = extractPointerFromMemref(rewriter, loc, op.getVs2()); + Value slideVal = op.getSlide(); + + std::string extLoadIntrinsic = "llvm.riscv.vle." + config.extendedInputVecSuffix; + std::string loadIntrinsic = "llvm.riscv.vle." + config.inputVecSuffix; + std::string storeIntrinsic = "llvm.riscv.vse." + config.outputVecSuffix; + std::string typeSuffix = "." + config.outputVecSuffix + "." + + config.extendedInputVecSuffix + "." + config.inputVecSuffix; + + Value vs1Vec = createRVVVectorLoad(rewriter, loc, module, vs1Ptr, + config.extendedInputVecType, extLoadIntrinsic, vlValueExtended); + Value vs2Vec = createRVVVectorLoad(rewriter, loc, module, vs2Ptr, + config.inputVecType, loadIntrinsic, vlValue); + Value vdVec = createRVVVectorLoad(rewriter, loc, module, vdPtr, + config.outputVecType, loadIntrinsic, vlValue); + Value result = createIMEVmadotnIntrinsic( + rewriter, loc, module, vdVec, vs1Vec, vs2Vec, slideVal, + "llvm.riscv.ime.vfmadotn", typeSuffix); createRVVVectorStore(rewriter, loc, module, result, vdPtr, - "llvm.riscv.vse.nxv32f16"); + storeIntrinsic, vlValue); rewriter.eraseOp(op); return success(); } @@ -316,14 +1621,26 @@ struct LegalizeIMEForLLVMExport target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); - target - .addIllegalOp(); + target.addIllegalOp(); LLVMTypeConverter typeConverter(&context); RewritePatternSet patterns(&context); - patterns.add(typeConverter); + patterns + .add(typeConverter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); @@ -335,13 +1652,25 @@ struct LegalizeIMEForLLVMExport void mlir::populateIMELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(converter); + IMEVmadotusLowering, IMEVfmadotLowering, 
IMEVmadot1Lowering, + IMEVmadot1uLowering, IMEVmadot1suLowering, IMEVmadot1usLowering, + IMEVmadot2Lowering, IMEVmadot2uLowering, IMEVmadot2suLowering, + IMEVmadot2usLowering, IMEVmadot3Lowering, IMEVmadot3uLowering, + IMEVmadot3suLowering, IMEVmadot3usLowering, IMEVfmadot1Lowering, + IMEVfmadot2Lowering, IMEVfmadot3Lowering, IMEVmadotnLowering, + IMEVmadotnuLowering, IMEVmadotnsuLowering, IMEVmadotnusLowering, + IMEVfmadotnLowering>(converter); } void mlir::configureIMELegalizeForExportTarget(LLVMConversionTarget &target) { target.addLegalDialect(); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); } std::unique_ptr buddy::ime::createLegalizeForLLVMExportPass() { diff --git a/midend/lib/InitAll.cpp b/midend/lib/InitAll.cpp index d2e8fdde4d..24053c33cd 100644 --- a/midend/lib/InitAll.cpp +++ b/midend/lib/InitAll.cpp @@ -26,6 +26,7 @@ #include "Dialect/Tile/TileDialect.h" #include "Dialect/Buckyball/BuckyballDialect.h" #include "Dialect/Gemmini/GemminiDialect.h" +#include "Dialect/AME/AMEDialect.h" #include "Dialect/IME/IMEDialect.h" #include "Dialect/RVV/RVVDialect.h" #include "Dialect/VectorExp/VectorExpDialect.h" @@ -41,6 +42,7 @@ void registerLowerDAPPass(); void registerLowerDIPPass(); void registerLowerGemminiPass(); void registerLowerLinalgToGemminiPass(); +void registerLowerLinalgToIMEPass(); void registerLowerIMEPass(); void registerLowerRVVPass(); void registerLowerVectorExpPass(); @@ -64,6 +66,7 @@ void registerEliminateMemRefCopyPass(); } // namespace mlir void mlir::buddy::registerAllDialects(mlir::DialectRegistry ®istry) { + registry.insert<::buddy::ame::AMEDialect>(); registry.insert<::buddy::bud::BudDialect>(); registry.insert<::buddy::dap::DAPDialect>(); registry.insert<::buddy::dip::DIPDialect>(); @@ -88,6 +91,7 @@ void mlir::buddy::registerAllPasses() { mlir::buddy::registerLowerBuckyballPass(); mlir::buddy::registerLowerGemminiPass(); mlir::buddy::registerLowerLinalgToGemminiPass(); + 
mlir::buddy::registerLowerLinalgToIMEPass(); mlir::buddy::registerLowerIMEPass(); mlir::buddy::registerLowerRVVPass(); mlir::buddy::registerLowerVectorExpPass(); diff --git a/midend/lib/Target/LLVMIR/CMakeLists.txt b/midend/lib/Target/LLVMIR/CMakeLists.txt index 59c517cb92..34c10c95f7 100644 --- a/midend/lib/Target/LLVMIR/CMakeLists.txt +++ b/midend/lib/Target/LLVMIR/CMakeLists.txt @@ -12,4 +12,5 @@ add_mlir_translation_library(BuddyToLLVMIRTranslationRegistration BuddyBuckyballToLLVMIRTranslation BuddyGemminiToLLVMIRTranslation MLIRIMEToLLVMIRTranslation + MLIRAMEToLLVMIRTranslation ) diff --git a/midend/lib/Target/LLVMIR/ConvertBuddyToLLVMIR.cpp b/midend/lib/Target/LLVMIR/ConvertBuddyToLLVMIR.cpp index 2c1600cf61..468bbfa116 100644 --- a/midend/lib/Target/LLVMIR/ConvertBuddyToLLVMIR.cpp +++ b/midend/lib/Target/LLVMIR/ConvertBuddyToLLVMIR.cpp @@ -27,6 +27,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.h" #include "Target/LLVMIR/Dialect/Buckyball/BuckyballToLLVMIRTranslation.h" #include "Target/LLVMIR/Dialect/Gemmini/GemminiToLLVMIRTranslation.h" #include "Target/LLVMIR/Dialect/IME/IMEToLLVMIRTranslation.h" @@ -57,6 +58,7 @@ void registerBuddyToLLVMIRTranslation() { registerBuckyballDialectTranslation(registry); registerGemminiDialectTranslation(registry); registerIMEDialectTranslation(registry); + registerAMEDialectTranslation(registry); }); } } // namespace buddy diff --git a/midend/lib/Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.cpp b/midend/lib/Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.cpp new file mode 100644 index 0000000000..3c2bf8db1e --- /dev/null +++ b/midend/lib/Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.cpp @@ -0,0 +1,67 @@ +//======- AMEToLLVMIRTranslation.cpp - Translate AME to LLVM IR ----------====// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the AME dialect and LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "backend/include/llvm/IR/IntrinsicsRISCV.h" +#include "llvm/IR/IRBuilder.h" + +#include "AME/AMEDialect.h" +#include "AME/AMEOps.h" +#include "Target/LLVMIR/Dialect/AME/AMEToLLVMIRTranslation.h" + +using namespace mlir; +using namespace mlir::LLVM; +using namespace buddy; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the AME dialect to LLVM IR. +class AMEDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. 
+ LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; +#include "AME/AMEConversions.inc" + + return failure(); + } +}; +} // end namespace + +void buddy::registerAMEDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addExtension(+[](MLIRContext *ctx, ame::AMEDialect *dialect) { + dialect->addInterfaces(); + }); +} + +void buddy::registerAMEDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerAMEDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/midend/lib/Target/LLVMIR/Dialect/AME/CMakeLists.txt b/midend/lib/Target/LLVMIR/Dialect/AME/CMakeLists.txt new file mode 100644 index 0000000000..5996d9ee85 --- /dev/null +++ b/midend/lib/Target/LLVMIR/Dialect/AME/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_translation_library(MLIRAMEToLLVMIRTranslation + AMEToLLVMIRTranslation.cpp + + DEPENDS + BuddyAMEConversionsIncGen + buddy_intrinsics_gen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + BuddyAME + ) diff --git a/midend/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/midend/lib/Target/LLVMIR/Dialect/CMakeLists.txt index 6309423ac6..1ad9d3da8a 100644 --- a/midend/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/midend/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(RVV) add_subdirectory(Buckyball) add_subdirectory(Gemmini) add_subdirectory(IME) +add_subdirectory(AME) diff --git a/midend/python/CMakeLists.txt b/midend/python/CMakeLists.txt index c37eb4a07e..bf1aff7d0b 100644 --- a/midend/python/CMakeLists.txt +++ b/midend/python/CMakeLists.txt @@ -138,3 +138,5 @@ add_mlir_python_modules(BuddyMLIRPythonModules BuddyMLIRPythonCAPI MLIRPythonCAPI ) + +add_custom_target(python-package-buddy-mlir DEPENDS BuddyMLIRPythonModules) diff --git a/pyproject.toml b/pyproject.toml index 83c116eb5f..ea15669b04 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -1,2 +1,20 @@ [tool.black] line-length = 80 + +[build-system] +requires = [ + "setuptools>=69", + "wheel", + "build>=1.2", +] +build-backend = "setuptools.build_meta" + +[project] +name = "buddy" +version = "0.0.1" +description = "An MLIR-based compiler framework bridges DSLs (domain-specific languages) to DSAs (domain-specific architectures)." +requires-python = ">=3.10" +dynamic = ["entry-points", "scripts"] + +[tool.setuptools] +include-package-data = true diff --git a/requirements.txt b/requirements.txt index 95ffe1fad7..8e1a01c5fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ --pre --extra-index-url https://download.pytorch.org/whl/cpu -torch == 2.8.0 -numpy < 2 +torch == 2.10.0 +numpy >= 2 transformers == 4.56.2 tokenizers >= 0.20 sentencepiece == 0.2.0 diff --git a/scripts/release_wheel_manylinux.sh b/scripts/release_wheel_manylinux.sh new file mode 100755 index 0000000000..8aeba755aa --- /dev/null +++ b/scripts/release_wheel_manylinux.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash +# Build a manylinux_x86_64 wheel inside the official manylinux container. +# This script must be run on a host with Docker available. +# +# Usage: +# ./scripts/build_manylinux.sh [cp_tag] +# cp_tag defaults to cp310-cp310. Other valid tags are the Python versions +# present under /opt/python in the manylinux image (e.g., cp311-cp311). + +set -euo pipefail + +IMAGE="quay.io/pypa/manylinux_2_28_x86_64" + +PY_TAG="${1:-cp310-cp310}" +TORCH_VERSION="${TORCH_VERSION:-2.8}" +# MLIR version is calculated in setup.py + +# Host dir +REPO_ROOT=$(cd "$(dirname "$0")/.." && pwd) +# Docker dir (default mount) +WORKSPACE=/workspace/buddy-mlir + +# Note: outputs are placed under build.docker/ to avoid clashing with host builds. 
+BUDDY_BUILD_DIR="${WORKSPACE}/build.docker" +LLVM_BUILD_DIR="${WORKSPACE}/llvm/build.docker" + +docker run --rm -i \ + -e WORKSPACE="${WORKSPACE}" \ + -e BUDDY_BUILD_DIR="${BUDDY_BUILD_DIR}" \ + -e LLVM_BUILD_DIR="${LLVM_BUILD_DIR}" \ + -e LLVM_CACHE_HIT="${LLVM_CACHE_HIT:-false}" \ + -e CLEAN_BUILD="${CLEAN_BUILD:-0}" \ + -e PY_TAG="${PY_TAG}" \ + -e TORCH_VERSION="${TORCH_VERSION}" \ + -e HOST_UID="$(id -u)" \ + -e HOST_GID="$(id -g)" \ + -e HOME=/workspace \ + -v "${REPO_ROOT}:${WORKSPACE}" \ + -w "${WORKSPACE}" \ + "${IMAGE}" \ + /bin/bash -s <<'BASH' + set -euo pipefail + set -x + + # manylinux stores multiple Python versions under /opt/python; PATH does not + # select a version by default, so we choose explicitly. + # Docs: https://github.com/pypa/manylinux#docker-images + PYBIN=/opt/python/${PY_TAG}/bin/python + if [ ! -x "$PYBIN" ]; then + echo "Python tag ${PY_TAG} not found under /opt/python" >&2 + ls /opt/python >&2 + exit 1 + fi + export PATH="/opt/python/${PY_TAG}/bin:$PATH" + + # manylinux images ship newer GCC via gcc-toolset; it is not enabled by default. + # Docs: https://github.com/pypa/manylinux#manylinux2014-2_28-and-2_34-images + if [ -f /opt/rh/gcc-toolset-14/enable ]; then + source /opt/rh/gcc-toolset-14/enable + fi + export CC=gcc + export CXX=g++ + + "$PYBIN" -m pip install --upgrade pip build auditwheel ninja cmake numpy pybind11==2.10.* nanobind==2.4.* PyYAML >/dev/null + + # Optional clean rebuild (set CLEAN_BUILD=1 to force) + if [ "${CLEAN_BUILD:-0}" = "1" ]; then + rm -rf "${LLVM_BUILD_DIR}" "${BUDDY_BUILD_DIR}" + fi + + if [ "${LLVM_CACHE_HIT:-false}" = "true" ]; then + echo "LLVM build cache hit; skipping LLVM build." 
+ else + # Build LLVM/MLIR first + cmake -G Ninja -S "${WORKSPACE}/llvm/llvm" -B "${LLVM_BUILD_DIR}" \ + -DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \ + -DLLVM_TARGETS_TO_BUILD="host;RISCV" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DOPENMP_ENABLE_LIBOMPTARGET=OFF \ + -DCMAKE_BUILD_TYPE=RELEASE \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DPython3_EXECUTABLE="$PYBIN" + ninja -C "${LLVM_BUILD_DIR}" check-clang check-mlir omp || true + fi + ${LLVM_BUILD_DIR}/bin/mlir-opt --version + + # Build buddy-mlir with Python packages enabled + cmake -G Ninja -S "${WORKSPACE}" -B "${BUDDY_BUILD_DIR}" \ + -DLLVM_DIR="${LLVM_BUILD_DIR}/lib/cmake/llvm" \ + -DMLIR_DIR="${LLVM_BUILD_DIR}/lib/cmake/mlir" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \ + -DPython3_EXECUTABLE="$PYBIN" + ninja -C "${BUDDY_BUILD_DIR}" + ninja -C "${BUDDY_BUILD_DIR}" python-package-buddy python-package-buddy-mlir || true + ${BUDDY_BUILD_DIR}/bin/buddy-opt --version + + # Optional build tag (must start with a digit). Example: 1pytorch2_2mlir19 + "$PYBIN" -m build --wheel --outdir "${BUDDY_BUILD_DIR}/dist" + auditwheel repair "${BUDDY_BUILD_DIR}/dist"/buddy-*.whl -w "${BUDDY_BUILD_DIR}/dist" + + # Fix ownership for host user + chown -R "$HOST_UID":"$HOST_GID" "${BUDDY_BUILD_DIR}" "${LLVM_BUILD_DIR}" || true +BASH + +echo "Wheels are in ${REPO_ROOT}/build.docker/dist" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..46db5ed590 --- /dev/null +++ b/setup.py @@ -0,0 +1,199 @@ +# ===- setup.py ----------------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- + +from __future__ import annotations + +import os +import shutil +from pathlib import Path + +from setuptools import find_namespace_packages, find_packages, setup +from setuptools.dist import Distribution +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from setuptools.command.build import build as _build +from setuptools.command.build_py import build_py as _build_py + +ROOT = Path(__file__).parent.resolve() + + +def _resolve_build_dir() -> Path: + """Resolve the CMake build directory that already contains compiled outputs.""" + build_dir = Path(os.environ.get("BUDDY_BUILD_DIR", "build")) + if not build_dir.is_absolute(): + build_dir = ROOT / build_dir + return build_dir.resolve() + + +CMAKE_BUILD = _resolve_build_dir() +PYTHON_PACKAGES_DIR = CMAKE_BUILD / "python_packages" +BIN_DIR = CMAKE_BUILD / "bin" +LIB_DIR = CMAKE_BUILD / "lib" + +if not PYTHON_PACKAGES_DIR.exists(): + raise SystemExit( + "buddy-mlir expects a populated CMake build tree. " + "Please configure and build with BUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON " + "before building the wheel (default build dir: ./build)." + ) + +REL_PYTHON_PACKAGES_DIR = os.path.relpath(PYTHON_PACKAGES_DIR, ROOT) +if REL_PYTHON_PACKAGES_DIR.startswith(".."): + raise SystemExit( + f"BUDDY_BUILD_DIR must reside inside the project root ({ROOT}) so packaging " + f"can use relative paths. Current: {PYTHON_PACKAGES_DIR}" + ) + +# Stage python packages into the build tree so setuptools never sees absolute paths. 
+STAGING_ROOT = CMAKE_BUILD / "py-stage" +if STAGING_ROOT.exists(): + shutil.rmtree(STAGING_ROOT) +STAGING_ROOT.mkdir(parents=True, exist_ok=True) +STAGING_SRC = STAGING_ROOT / "python_packages" +# Copy python outputs but drop any pre-existing egg-info that may carry absolute paths. +shutil.copytree( + PYTHON_PACKAGES_DIR, + STAGING_SRC, + ignore=shutil.ignore_patterns("*.egg-info"), +) + +SRC_DIR = os.path.relpath(STAGING_SRC, ROOT) + +buddy_pkgs = find_packages(where=SRC_DIR, include=["buddy*"]) +mlir_pkgs = find_namespace_packages(where=SRC_DIR, include=["buddy_mlir*"]) +wrapper_pkgs = find_packages(where="tools", include=["buddy_tools*"]) +packages = sorted(set(buddy_pkgs + mlir_pkgs + wrapper_pkgs)) + +package_dir = { + "": SRC_DIR, + "buddy_tools": "tools/buddy_tools", +} + + +class build_py(_build_py): + """Copy prebuilt artifacts (Python, bin, lib) into the wheel.""" + + def run(self): + self._extra_outputs = [] + super().run() + + tools_root = Path(self.build_lib) / "buddy_tools" + self._copy_tree(BIN_DIR, tools_root / "bin", allow_missing=True) + self._copy_tree(LIB_DIR, tools_root / "lib", allow_missing=True) + + def get_outputs(self, include_bytecode: bool = True): + outputs = super().get_outputs(include_bytecode) + return outputs + getattr(self, "_extra_outputs", []) + + # Helpers + def _copy_tree(self, src: Path, dst: Path, allow_missing: bool = False): + src = Path(src) + if not src.exists(): + if allow_missing: + self.warn(f"Skip missing path: {src}") + return + raise FileNotFoundError(f"Expected path not found: {src}") + + # Avoid recursive self-copy if src overlaps dst. 
+ try: + if dst.resolve().is_relative_to(src.resolve()): + raise SystemExit(f"Refusing to copy {src} into itself ({dst})") + except AttributeError: + # Python <3.9 compatibility: manual check + src_resolved = src.resolve() + dst_resolved = dst.resolve() + if str(dst_resolved).startswith(str(src_resolved)): + raise SystemExit(f"Refusing to copy {src} into itself ({dst})") + + for path in src.rglob("*"): + if not path.is_file(): + continue + rel = path.relative_to(src) + target = dst / rel + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(path, target) + self._extra_outputs.append(str(target)) + + +ENTRY_POINTS = { + "console_scripts": [ + "buddy-opt=buddy_tools.cli:buddy_opt", + "buddy-translate=buddy_tools.cli:buddy_translate", + "buddy-llc=buddy_tools.cli:buddy_llc", + "buddy-lsp-server=buddy_tools.cli:buddy_lsp_server", + "buddy-frontendgen=buddy_tools.cli:buddy_frontendgen", + "buddy-audio-container-test=buddy_tools.cli:buddy_audio_container_test", + "buddy-text-container-test=buddy_tools.cli:buddy_text_container_test", + "buddy-container-test=buddy_tools.cli:buddy_container_test", + ] +} + + +def _resolve_install_requires() -> list[str]: + """Pin torch version when TORCH_VERSION is provided.""" + torch_ver = os.environ.get("TORCH_VERSION", "").strip() + if torch_ver: + return [f"torch=={torch_ver}"] + return [] + + +class build(_build): + def initialize_options(self): + super().initialize_options() + # Use a staging directory separate from the CMake build tree. + self.build_base = str(CMAKE_BUILD / "py-build") + + +class bdist_wheel(_bdist_wheel): + """Mark the wheel as non-pure and optionally set a build tag.""" + + def finalize_options(self): + super().finalize_options() + # Force platform-specific wheel since we bundle native libs. 
+ self.root_is_pure = False + + torch_ver = os.environ.get("TORCH_VERSION", "unknown") + mlir_major = "unknown" + + llvm_version_path = ROOT / "llvm" / "cmake" / "Modules" / "LLVMVersion.cmake" + if llvm_version_path.is_file(): + for line in llvm_version_path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if line.startswith("set(LLVM_VERSION_MAJOR"): + parts = line.split() + if len(parts) >= 2: + mlir_major = parts[1].rstrip(")") + break + + self.build_number = f"1torch{torch_ver.replace('.', '')}_2mlir{mlir_major}" + + +class BinaryDistribution(Distribution): + """Force a non-pure wheel since we bundle native binaries.""" + + def has_ext_modules(self): + return True + + +setup( + packages=packages, + package_dir=package_dir, + include_package_data=True, + cmdclass={"build_py": build_py, "build": build, "bdist_wheel": bdist_wheel}, + distclass=BinaryDistribution, + install_requires=_resolve_install_requires(), + entry_points=ENTRY_POINTS, + zip_safe=False, +) diff --git a/sync_and_test.sh b/sync_and_test.sh index 8b03fd8585..c00ce5771d 100755 --- a/sync_and_test.sh +++ b/sync_and_test.sh @@ -101,7 +101,7 @@ if [ ! -d "$LLVM_MLIR_BUILD_DIR" ]; then fi # Set PYTHONPATH -export PYTHONPATH="${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}" +export PYTHONPATH="${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}" print_success "Environment variables setup completed" print_info "BUDDY_MLIR_BUILD_DIR=$BUDDY_MLIR_BUILD_DIR" diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2340c2fb46..d16e1ea9b8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -29,7 +29,7 @@ if(BUDDY_MLIR_ENABLE_DIP_LIB) endif() if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES) - list(APPEND BUDDY_TEST_DEPENDS BuddyMLIRPythonModules) + list(APPEND BUDDY_TEST_DEPENDS python-package-buddy-mlir) endif() add_lit_testsuite(check-tests "Running the buddy regression tests..." 
@@ -59,6 +59,7 @@ if(BUDDY_ENABLE_E2E_TESTS) buddy-translate mlir-runner buddy-lenet-run-test-cpu + python-package-buddy ) add_lit_testsuite(check-e2e "Running E2E tests for Models directory..." diff --git a/tests/Python/test_gqa_attention.py b/tests/Python/test_gqa_attention.py index 98457e80ba..9a2199bdfc 100644 --- a/tests/Python/test_gqa_attention.py +++ b/tests/Python/test_gqa_attention.py @@ -13,19 +13,16 @@ from buddy.compiler.ops import tosa -def foo(query, k_cache, v_cache, mask, scale): - - k_unsqueeze = torch.unsqueeze(k_cache, 2) - k_slice1 = torch.narrow(k_unsqueeze, 3, 0, k_unsqueeze.size(3)) - k_slice2 = torch.narrow(k_slice1, 4, 0, k_slice1.size(4)) - k_expanded = k_slice2.expand(1, 2, 6, 1024, 128) +def foo(query, k_cache, v_cache, index, mask, scale): + k_updated = torch.index_put(k_cache, (index,), k_cache[0:1]) + k_unsqueeze = torch.unsqueeze(k_updated, 2) + k_expanded = k_unsqueeze.expand(1, 2, 6, 1024, 128) k_clone = k_expanded.clone() k_view = k_clone.view(1, 12, 1024, 128) - v_unsqueeze = torch.unsqueeze(v_cache, 2) - v_slice1 = torch.narrow(v_unsqueeze, 3, 0, v_unsqueeze.size(3)) - v_slice2 = torch.narrow(v_slice1, 4, 0, v_slice1.size(4)) - v_expanded = v_slice2.expand(1, 2, 6, 1024, 128) + v_updated = torch.index_put(v_cache, (index,), v_cache[0:1]) + v_unsqueeze = torch.unsqueeze(v_updated, 2) + v_expanded = v_unsqueeze.expand(1, 2, 6, 1024, 128) v_clone = v_expanded.clone() v_view = v_clone.view(1, 12, 1024, 128) @@ -39,8 +36,9 @@ def foo(query, k_cache, v_cache, mask, scale): in1 = torch.randn(1, 12, 1, 128) # [Batch, Head, MaxSeq, Dim] in2 = torch.randn(1, 2, 1024, 128) in3 = torch.randn(1, 2, 1024, 128) # [Batch, Head, 1, Dim] -in4 = torch.randn(1, 1, 1, 1024) -in5 = 1.0 / (128**0.5) +in4 = torch.tensor([0], dtype=torch.int64) +in5 = torch.randn(1, 1, 1, 1024) +in6 = 1.0 / (128**0.5) # Initialize the dynamo compiler. 
dynamo_compiler = DynamoCompiler( @@ -49,7 +47,7 @@ def foo(query, k_cache, v_cache, mask, scale): verbose=False, ) -graphs = dynamo_compiler.importer(foo, in1, in2, in3, in4, in5) +graphs = dynamo_compiler.importer(foo, in1, in2, in3, in4, in5, in6) assert len(graphs) == 1 graph = graphs[0] diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt index cec3255a21..4036d56eb1 100644 --- a/tools/buddy-opt/CMakeLists.txt +++ b/tools/buddy-opt/CMakeLists.txt @@ -40,9 +40,12 @@ target_link_libraries(buddy-opt BuddyGemmini LowerGemminiPass LowerLinalgToGemminiPass + LowerLinalgToIMEPass BuddyIME BuddyIMETransforms LowerIMEPass + BuddyAME + BuddyAMETransforms MLIRGPUPasses BuddyGPUTransformOPs MLIRTestTransforms diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 16de902942..53f5a82260 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -47,6 +47,9 @@ #include "GPU/TransformOps.h" #include "Gemmini/GemminiDialect.h" #include "Gemmini/GemminiOps.h" +#include "AME/AMEDialect.h" +#include "AME/AMEOps.h" +#include "AME/Transform.h" #include "IME/IMEDialect.h" #include "IME/IMEOps.h" #include "RVV/RVVDialect.h" @@ -92,7 +95,9 @@ void registerLowerTileToBuckyballPass(); void registerLowerBuckyballPass(); void registerLowerGemminiPass(); void registerLowerLinalgToGemminiPass(); +void registerLowerLinalgToIMEPass(); void registerLowerIMEPass(); +void registerLowerAMEPass(); void registerAssumeTightMemRefLayoutPass(); void registerStaticizeMemRefLayoutPass(); void registerConvertMemcpyToGPUPass(); @@ -132,7 +137,9 @@ int main(int argc, char **argv) { mlir::buddy::registerLowerBuckyballPass(); mlir::buddy::registerLowerGemminiPass(); mlir::buddy::registerLowerLinalgToGemminiPass(); + mlir::buddy::registerLowerLinalgToIMEPass(); mlir::buddy::registerLowerIMEPass(); + mlir::buddy::registerLowerAMEPass(); // Register Several Optimize Pass. 
mlir::buddy::registerMatMulVectorizationBLISPass(); @@ -181,6 +188,7 @@ int main(int argc, char **argv) { buddy::vector_exp::VectorExpDialect, buddy::vir::VIRDialect, buddy::gemmini::GemminiDialect, + buddy::ame::AMEDialect, buddy::tile::TileDialect, buddy::buckyball::BuckyballDialect, buddy::ime::IMEDialect>(); diff --git a/tools/buddy_tools/__init__.py b/tools/buddy_tools/__init__.py new file mode 100644 index 0000000000..6249b16529 --- /dev/null +++ b/tools/buddy_tools/__init__.py @@ -0,0 +1,19 @@ +# ===- __init__.py ------------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- +# +# Init the packages in buddy_tools directory. +# +# ===--------------------------------------------------------------------------- diff --git a/tools/buddy_tools/cli.py b/tools/buddy_tools/cli.py new file mode 100644 index 0000000000..f17a3cb5a5 --- /dev/null +++ b/tools/buddy_tools/cli.py @@ -0,0 +1,71 @@ +# ===- __init__.py ------------------------------------------------------------- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===--------------------------------------------------------------------------- +# +# Console entry points that call the bundled Buddy executables. +# +# ===--------------------------------------------------------------------------- + +from __future__ import annotations + +import subprocess +import sys +from importlib import resources + + +def _exe_resource(exe_name: str): + return resources.files(__package__) / "bin" / exe_name + + +def _run(exe_name: str) -> int: + exe_resource = _exe_resource(exe_name) + with resources.as_file(exe_resource) as exe: + if not exe.exists(): + raise SystemExit( + f"Bundled executable '{exe_name}' is missing in the wheel." + ) + return subprocess.call([str(exe), *sys.argv[1:]]) + + +def buddy_opt() -> int: + return _run("buddy-opt") + + +def buddy_translate() -> int: + return _run("buddy-translate") + + +def buddy_llc() -> int: + return _run("buddy-llc") + + +def buddy_lsp_server() -> int: + return _run("buddy-lsp-server") + + +def buddy_frontendgen() -> int: + return _run("buddy-frontendgen") + + +def buddy_audio_container_test() -> int: + return _run("buddy-audio-container-test") + + +def buddy_text_container_test() -> int: + return _run("buddy-text-container-test") + + +def buddy_container_test() -> int: + return _run("buddy-container-test")