diff --git a/.clang-format b/.clang-format index 3bb927623ba..4bb0ebed189 100644 --- a/.clang-format +++ b/.clang-format @@ -16,7 +16,7 @@ --- Language: Cpp BasedOnStyle: Google -IndentWidth: 4 +IndentWidth: 2 TabWidth: 2 ContinuationIndentWidth: 4 AccessModifierOffset: -1 # The private/protected/public has no indent in class @@ -26,4 +26,5 @@ BinPackParameters: false BinPackArguments: false IncludeBlocks: Preserve IncludeIsMainSourceRegex: (\.cu)$ +SortIncludes: false ... diff --git a/.claude/skills/cuda-kernel-unittest.md b/.claude/skills/cuda-kernel-unittest.md new file mode 100644 index 00000000000..600708f5784 --- /dev/null +++ b/.claude/skills/cuda-kernel-unittest.md @@ -0,0 +1,174 @@ +# Skill: CUDA Kernel Unit Test + +Write unit tests for PaddlePaddle CUDA custom ops following a modular 4-layer architecture. + +## Trigger + +When the user asks to write/create/add unit tests for a CUDA kernel (`.cu` file in `custom_ops/`). + +## Steps + +1. **Read the CUDA kernel source** to understand: input/output tensors, dtypes, shapes, which tensors are CPU vs GPU, scalar attrs, in-place semantics. +2. **Write the test file** in `tests/operators/test_.py` following the structure below. 
+ +## Test File Structure + +```python +import unittest +from typing import Any, Dict +import numpy as np +import paddle + +# --- Import ops (bypass fastdeploy.__init__) --- +try: + import sys, os + _fd_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + if _fd_root not in sys.path: + sys.path.insert(0, _fd_root) + from fastdeploy.import_ops import import_custom_ops + _package = "fastdeploy.model_executor.ops.gpu" + import_custom_ops(_package, ".fastdeploy_ops", globals()) +except ImportError as e: + print(f"Import error: {e}") + raise + +CUDA_PLACE = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() +CPU_PLACE = paddle.CPUPlace() + + +# ============================================================ +# Layer 1: Helpers — tensor creation / kernel invocation / output extraction +# ============================================================ + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy dict → paddle tensors. CPU tensors must be explicitly handled.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif k in ("",): # <-- tensors the kernel expects on CPU + paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE) + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + +def run_kernel(paddle_inputs, inputs): + """Call the CUDA kernel with paddle tensors + scalar attrs.""" + kernel_name( + paddle_inputs["tensor_a"], + # ... all tensor args ... + inputs["scalar_attr"], # scalar attrs from raw dict + ) + +def get_outputs(paddle_inputs) -> Dict[str, np.ndarray]: + """Extract ALL in-place-modified tensors back to numpy.""" + keys = ["tensor_a", "tensor_b", ...] 
+ return {k: paddle_inputs[k].numpy() for k in keys} + + +# ============================================================ +# Layer 2: Input generation +# ============================================================ + +def gen__inputs(real_bsz=8, ..., seed=42) -> Dict[str, Any]: + """Generate randomized test inputs. Returns dict with both numpy arrays and scalar configs.""" + rng = np.random.default_rng(seed) + # ... generate all numpy arrays with correct dtypes/shapes ... + return { "tensor_a": ..., "scalar_attr": ..., "real_bsz": real_bsz, ... } + + +# ============================================================ +# Layer 3: Reference implementation (pure Python/NumPy) +# ============================================================ + +def reference_(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Python reference — must match CUDA kernel logic exactly.""" + # Deep-copy all mutable arrays + tensor_a = inputs["tensor_a"].copy() + # ... replicate kernel logic ... + return {"tensor_a": tensor_a, ...} + + +# ============================================================ +# Layer 4a: TEST_CONFIGS — all pure-parameter test scenarios +# ============================================================ + +TEST_CONFIGS = [ + # Each config is a dict of gen__inputs kwargs + a "name" key. + # Pure parameter variations go here — do NOT create separate test methods for them. 
+ # + # --- basic coverage --- + {"name": "small_batch", "real_bsz": 1, "seed": 42, ...}, + {"name": "large_batch", "real_bsz": 64, "seed": 42, ...}, + # --- mode / strategy variants --- + {"name": "mode_a", "real_bsz": 8, "mode": "a", "seed": 42, ...}, + {"name": "mode_b", "real_bsz": 8, "mode": "b", "seed": 42, ...}, + # --- flags --- + {"name": "reject_all", "real_bsz": 8, "reject_all": True, "seed": 42, ...}, + {"name": "accept_all", "real_bsz": 8, "accept_all": True, "seed": 42, ...}, + # --- edge cases --- + {"name": "min_batch", "real_bsz": 1, "max_tokens": 1, "seed": 42, ...}, +] + + +# ============================================================ +# Layer 4b: Test suite +# ============================================================ + +class Test(unittest.TestCase): + + # ------ shared helpers ------ + + def _run_and_get(self, inputs): + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + return get_outputs(paddle_inputs) + + def _check_all_outputs(self, inputs, outputs): + """Compare ALL output tensors against reference + sanity checks.""" + ref = reference_(inputs) + all_keys = ["tensor_a", "tensor_b", ...] 
+ for key in all_keys: + np.testing.assert_array_equal( + outputs[key], ref[key], err_msg=f"{key} mismatch" + ) + # Add domain-specific sanity checks here + + def _run_full_test(self, config): + inputs = gen__inputs(**config) + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + return outputs + + # ------ test cases ------ + + def test_configs(self): + """Run all TEST_CONFIGS via subTest (one subTest per config).""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + self._run_full_test(test_cfg) + + # Only keep separate test methods for scenarios that need tensor overrides: + def test_special_scenario(self): + """Scenarios that need manual tensor setup beyond gen_inputs params.""" + inputs = gen__inputs(real_bsz=2, seed=42) + inputs["some_tensor"][0, 2] = special_value # override specific tensor + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + +if __name__ == "__main__": + unittest.main() +``` + +## Key Rules + +1. **CPU vs GPU tensors**: Read the CUDA kernel `.cu` file carefully. If a tensor is `copy_to(place, false)` inside the host function, it's a CPU tensor input — must use `CPU_PLACE` in `to_paddle_inputs`. +2. **`_check_all_outputs` checks ALL tensors**: Every in-place-modified output tensor must be compared against reference. Never scatter `assertEqual`/`assertTrue` across individual test methods — all checks go through `_check_all_outputs`. +3. **Stochastic kernels**: If the kernel uses `curand` (e.g., top-p sampling), compare only deterministic positions. Skip the last sampled token in `compare_results`. Note: `curand_states` in reference should be sized to `max_step_tokens` (position count), not `bsz` (batch count). +4. **TEST_CONFIGS for pure-parameter scenarios**: Any test that only differs by `gen_inputs` parameters belongs in `TEST_CONFIGS`, not a separate `test_*` method. 
Only create separate methods when you need to **override specific tensor values** after generation. +5. **Test cases are thin**: Each `test_*` method should be 3-15 lines. It either calls `_run_full_test(config)` or does `gen → override → _run_and_get → _check_all_outputs`. +6. **No `fastdeploy.__init__`**: Import ops via `import_custom_ops` directly to avoid heavy dependency chain. +7. **Padding slots**: Kernel may have `max_bsz > real_bsz`. Reference impl must handle padding slots the same way as the kernel (typically no-op or stop_count++). diff --git a/.flake8 b/.flake8 index 869c57d3e61..1656330a998 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -ignore = E203, E402, E501, E731, E741, W503, W605, E722 +ignore = E203, E402, E501, E731, E741, W503, W605, E722, E231, W604, E702, E226, E221, E713, E271 max-line-length = 119 # E402: module level import not at top of file diff --git a/.github/actions/rerun-workflow/action.yml b/.github/actions/rerun-workflow/action.yml new file mode 100644 index 00000000000..f4d72f32d6b --- /dev/null +++ b/.github/actions/rerun-workflow/action.yml @@ -0,0 +1,30 @@ +name: 'Rerun Workflow' +description: 'Re-run GitHub Actions workflow for a given Pull Request' +inputs: + GITHUB_TOKEN: + description: 'GitHub token with repo scope' + required: true + OWNER: + description: 'Repository owner' + required: true + REPO: + description: 'Repository name' + required: true + PR_ID: + description: 'Pull Request ID' + required: true + JOB_NAME: + description: 'Job name to rerun' + required: true + +runs: + using: 'composite' + steps: + - run: bash ./.github/actions/rerun-workflow/rerun.sh + shell: bash + env: + GITHUB_TOKEN: ${{ inputs.GITHUB_TOKEN }} + OWNER: ${{ inputs.OWNER }} + REPO: ${{ inputs.REPO }} + PR_ID: ${{ inputs.PR_ID }} + JOB_NAME: ${{ inputs.JOB_NAME }} diff --git a/.github/actions/rerun-workflow/rerun.sh b/.github/actions/rerun-workflow/rerun.sh new file mode 100644 index 00000000000..dce8a7fad3e --- /dev/null +++ 
b/.github/actions/rerun-workflow/rerun.sh @@ -0,0 +1,77 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +COMMIT_SHA=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \ + "https://api.github.com/repos/$OWNER/$REPO/pulls/$PR_ID" | jq -r '.head.sha') + +echo "Commit SHA: $COMMIT_SHA" + +response=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \ + "https://api.github.com/repos/$OWNER/$REPO/actions/runs?head_sha=$COMMIT_SHA&per_page=100") + +echo "Response: $response" + +run_ids=$(echo "$response" | jq -r '.workflow_runs[].id') + +if [ -n "$run_ids" ]; then + echo "Found run_ids for commit $COMMIT_SHA: $run_ids" + + for run_id in $run_ids; do + if [ "$JOB_NAME" = "all-failed" ]; then + echo "Rerunning all failed jobs for run_id: $run_id" + + rerun_response=$(curl -X POST -s -w "%{http_code}" -o /dev/null \ + -H "Accept: application/vnd.github.v3+json" \ + -H "Authorization: Bearer $GITHUB_TOKEN" \ + "https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/rerun-failed-jobs") + if [ "$rerun_response" -eq 201 ]; then + echo "Successfully requested rerun for all blocked jobs in run_id: $run_id" + else + echo "Failed to request rerun for run_id: $run_id with status code $rerun_response" + fi + + else + jobs_response=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \ + "https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/jobs") + + echo "Jobs Response for run_id $run_id: 
$jobs_response" + + # if [[ "$JOB_NAME" == *"bypass"* ]]; then + block_jobs=$(echo "$jobs_response" | jq -r --arg job_name "$JOB_NAME" \ + '.jobs[] | select(.name == $job_name) | .id') + # else + # block_jobs=$(echo "$jobs_response" | jq -r --arg job_name "$JOB_NAME" \ + # '.jobs[] | select(.name == $job_name and .conclusion != "success") | .id') + # fi + + if [ -n "$block_jobs" ]; then + echo "Found block jobs for run_id $run_id: $block_jobs" + + for job_id in $block_jobs; do + echo "Rerunning job_id: $job_id" + curl -X POST -H "Accept: application/vnd.github.v3+json" \ + -H "Authorization: token $GITHUB_TOKEN" \ + "https://api.github.com/repos/$OWNER/$REPO/actions/jobs/$job_id/rerun" + done + else + echo "No block jobs found for run_id $run_id with name $JOB_NAME." + fi + fi + done +else + echo "No matching workflow runs found for commit $COMMIT_SHA." + exit 1 +fi diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000000..469c1bf910d --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,54 @@ +# GitHub Copilot Custom Review Instructions + +When reviewing code, focus on: + +## Security Critical Issues +- Check for hardcoded secrets, API keys, or credentials +- Look for SQL injection and XSS vulnerabilities +- Verify proper input validation and sanitization +- Review authentication and authorization logic + +## Performance Red Flags +- Identify N+1 database query problems +- Spot inefficient loops and algorithmic issues +- Check for memory leaks and resource cleanup +- Review caching opportunities for expensive operations + +## Code Quality Essentials +- Functions should be focused and appropriately sized +- Use clear, descriptive naming conventions +- Ensure proper error handling throughout + +## Review Style +- Be specific and actionable in feedback +- Explain the "why" behind recommendations +- Acknowledge good patterns when you see them +- Ask clarifying questions when code intent is unclear + 
+Always prioritize security vulnerabilities and performance issues that could impact users. + +Always suggest changes to improve readability. For example, this suggestion seeks to make the code more readable and also makes the validation logic reusable and testable. + +// Instead of: +if (user.email && user.email.includes('@') && user.email.length > 5) { + submitButton.enabled = true; +} else { + submitButton.enabled = false; +} + +// Consider: +function isValidEmail(email) { + return email && email.includes('@') && email.length > 5; +} + +submitButton.enabled = isValidEmail(user.email); + +## Description for pull request + +- Please check the title of the Pull Request. It needs to follow the format of [CLASS]Title, for example, [BugFix] Fix memory leak of data processor. If the title is incorrect, provide suggestions on how the committer should modify it. +- Please check the description information of the Pull Request. At a minimum, it should explain why these modifications are being made in this Pull Request and what problem is being solved. If the committer hasn't written the corresponding information or the information is incomplete, prompt the committer to make modifications. +- For all Pull Requests, please confirm whether it is necessary to add, update, or delete documentation, and remind the committer to handle it accordingly. + +## Others +- 对于所有提交的PR,你提交的评论都使用中文语言,但需要注意,代码中的注释仍然需要使用英文 +- 在你提交Pull Request时,需要注意代码风格要满足本Repo的设定,commit代码前需要`pip install pre-commit==4.2.0`并且执行`pre-commit install` diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000000..7f3c72aa8cd --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,34 @@ + + + + +## Motivation + + + +> :bulb: If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. 
For example, [Cherry-Pick][CI] Add check trigger and logic(#5191) + +> :bulb: 如若此PR是Cherry Pick,PR标题需遵循格式,在最开始加上[Cherry-Pick]标签,以及最后面加上原PR ID,例如[Cherry-Pick][CI] Add check trigger and logic(#5191) + +## Modifications + + + +## Usage or Command + + + + +## Accuracy Tests + + + +## Checklist + +- [ ] Add at least a tag in the PR title. + - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`] + - You can add new tags based on the PR content, but the semantics must be clear. +- [ ] Format your code, run `pre-commit` before commit. +- [ ] Add unit tests. Please write the reason in this PR if no unit tests. +- [ ] Provide accuracy results. +- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag. 
diff --git a/.github/workflows/CheckPRTemplate.yml b/.github/workflows/CheckPRTemplate.yml new file mode 100644 index 00000000000..e5b3dcd3ad9 --- /dev/null +++ b/.github/workflows/CheckPRTemplate.yml @@ -0,0 +1,54 @@ +name: Check PR Template + +on: + pull_request: + branches: + - develop + - 'release/*' + +jobs: + check: + name: Check PR Template + if: ${{ github.repository_owner == 'PaddlePaddle' }} + runs-on: ubuntu-latest + env: + PR_ID: ${{ github.event.pull_request.number }} + BASE_BRANCH: ${{ github.event.pull_request.base.ref }} + AUTHOR: ${{ github.event.pull_request.user.login }} + + steps: + - name: Cleanup + run: | + rm -rf * .[^.]* + + - name: Checkout base branch + uses: actions/checkout@v6 + with: + ref: ${{ github.event.pull_request.base.ref }} + fetch-depth: 1000 + + - name: Merge PR to test branch + run: | + git fetch origin pull/${PR_ID}/merge + git checkout -b test FETCH_HEAD + + - name: Setup Python 3.10 + uses: actions/setup-python@v6 + with: + python-version: '3.10' + cache: 'pip' + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install httpx + + - name: Check PR Template + env: + AGILE_PULL_ID: ${{ env.PR_ID }} + AGILE_COMPILE_BRANCH: ${{ env.BASE_BRANCH }} + AGILE_CHECKIN_AUTHOR: ${{ env.AUTHOR }} + GITHUB_API_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + python scripts/CheckPRTemplate.py; EXCODE=$? 
+ exit $EXCODE diff --git a/.github/workflows/Codestyle-Check.yml b/.github/workflows/Codestyle-Check.yml new file mode 100644 index 00000000000..6811e3fb38d --- /dev/null +++ b/.github/workflows/Codestyle-Check.yml @@ -0,0 +1,50 @@ +name: Codestyle-Check + +on: + pull_request: + branches: + - develop + - 'release/*' + +jobs: + pre-commit: + name: Pre Commit + if: ${{ github.repository_owner == 'PaddlePaddle' }} + runs-on: ubuntu-latest + env: + PR_ID: ${{ github.event.pull_request.number }} + BRANCH: ${{ github.event.pull_request.base.ref }} + + steps: + - name: Cleanup + run: | + rm -rf * .[^.]* + + - name: Checkout base repo + uses: actions/checkout@v6 + with: + ref: ${{ github.event.pull_request.base.ref }} + fetch-depth: 1000 + + - name: Merge PR to test branch + run: | + git fetch origin pull/${PR_ID}/merge + git checkout -b test FETCH_HEAD + + - name: Setup python3.10 + uses: actions/setup-python@v6 + with: + python-version: '3.10' + cache: 'pip' + + - name: Install dependencies + run: | + pip install pre-commit==4.2.0 cpplint==1.6.0 clang-format==13.0.0 + + - name: Check pre-commit + env: + SKIP_CLANG_TIDY_CHECK: "ON" + run: | + set +e + bash -x tools/codestyle/pre_commit.sh;EXCODE=$? + exit $EXCODE diff --git a/.github/workflows/_accuracy_test.yml b/.github/workflows/_accuracy_test.yml new file mode 100644 index 00000000000..4efb008da17 --- /dev/null +++ b/.github/workflows/_accuracy_test.yml @@ -0,0 +1,206 @@ +name: Accuracy Test +description: "Run Accuracy Tests" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." 
+ required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + +jobs: + accuracy_tests: + runs-on: [self-hosted, GPU-h20-1Cards] + timeout-minutes: 60 + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run FastDeploy Base Tests + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + run: | + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8098 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo "FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo "FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + if [ ! -d "${MODEL_CACHE_DIR}" ]; then + echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist." 
+ exit 1 + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." + if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + + docker run --rm --ipc=host --pid=host --net=host \ + --name ${runner_name} \ + -v $(pwd):/workspace \ + -w /workspace \ + -e fastdeploy_wheel_url=${fastdeploy_wheel_url} \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -v "${MODEL_CACHE_DIR}:/MODELDATA" \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install ${fastdeploy_wheel_url} + python -m pip install pytest + + wget --no-proxy https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64 + chmod +x 
./llm-deploy-linux-amd64 + ./llm-deploy-linux-amd64 -python python3.10 \ + -model_name ERNIE-4.5-0.3B-Paddle \ + -model_path /MODELDATA \ + --skip install,model + + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + pushd tests/ce/deploy + ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9 + ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9 + python3.10 deploy.py > dd.log 2>&1 & + sleep 3 + curl -X POST http://0.0.0.0:${FLASK_PORT}/start \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}" + + curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90 + popd + + pushd tests/ce/accuracy_cases + export URL=http://localhost:${FD_API_PORT}/v1/chat/completions + export TEMPLATE=TOKEN_LOGPROB + export MODEL_SIZE=0.3B + TEST_EXIT_CODE=0 + python gsm8k.py || TEST_EXIT_CODE=1 + popd + echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env + ' + if [ -f ./FastDeploy/exit_code.env ]; then + source ./FastDeploy/exit_code.env + cat ./FastDeploy/exit_code.env >> $GITHUB_ENV + fi + echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" + exit ${TEST_EXIT_CODE} diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml new file mode 100644 index 00000000000..e4e53bf1b28 --- /dev/null +++ b/.github/workflows/_base_test.yml @@ -0,0 +1,293 @@ +name: Base Test +description: "Run Base Tests" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." 
+ required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: base_test + + base_tests: + runs-on: [self-hosted, GPU-h20-1Cards] + needs: check_bypass + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 60 + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + # Download with retry and validation + MAX_RETRIES=3 + RETRY_COUNT=0 + while [ $RETRY_COUNT -lt $MAX_RETRIES ]; do + if wget -q --no-proxy ${fd_archive_url} && [ -f FastDeploy.tar.gz ] && [ -s FastDeploy.tar.gz ]; then + echo "Download successful, file size: $(stat -c%s FastDeploy.tar.gz) bytes" + break + else + RETRY_COUNT=$((RETRY_COUNT + 1)) + echo "Download failed or file is empty, retry $RETRY_COUNT/$MAX_RETRIES..." + rm -f FastDeploy.tar.gz + sleep 2 + fi + done + + if [ ! -f FastDeploy.tar.gz ] || [ ! -s FastDeploy.tar.gz ]; then + echo "ERROR: Failed to download FastDeploy.tar.gz after $MAX_RETRIES attempts" + exit 1 + fi + + # Verify tar.gz integrity before extraction + if ! 
tar -tzf FastDeploy.tar.gz > /dev/null 2>&1; then + echo "ERROR: FastDeploy.tar.gz is corrupted or incomplete" + exit 1 + fi + + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run FastDeploy Base Tests + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + run: | + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8098 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo "FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo "FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + if [ ! -d "${MODEL_CACHE_DIR}" ]; then + echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist." 
+ exit 1 + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." + if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + + docker run --rm --ipc=host --pid=host --net=host \ + --name ${runner_name} \ + -v $(pwd):/workspace \ + -w /workspace \ + -e fastdeploy_wheel_url=${fastdeploy_wheel_url} \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -v "${MODEL_CACHE_DIR}:/MODELDATA" \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install ${fastdeploy_wheel_url} + python -m pip install pytest + + wget --no-proxy https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64 + chmod +x 
./llm-deploy-linux-amd64 + ./llm-deploy-linux-amd64 -python python3.10 \ + -model_name ERNIE-4.5-0.3B-Paddle \ + -model_path /MODELDATA \ + --skip install,model + + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + pushd tests/ce/deploy + ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9 + ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9 + python3.10 deploy.py > dd.log 2>&1 & + sleep 3 + curl -X POST http://0.0.0.0:${FLASK_PORT}/start \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}" + + check_service() { + local timeout=${1:-90} + local url="http://localhost:${FLASK_PORT}/wait_for_infer?timeout=${timeout}" + local resp + + resp=$(curl -s -X POST "$url") + + if echo "$resp" | grep -q "服务启动超时"; then + exit 8 + fi + } + + check_service 90 + popd + + pushd tests/ce/server + export URL=http://localhost:${FD_API_PORT}/v1/chat/completions + export TEMPLATE=TOKEN_LOGPROB + TEST_EXIT_CODE=0 + python -m pytest -sv test_base_chat.py test_compare_top_logprobs.py test_logprobs.py test_params_boundary.py test_seed_usage.py test_stream.py test_evil_cases.py test_completions.py test_return_token_ids.py test_update_weight.py || TEST_EXIT_CODE=1 + curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--early-stop-config\": \"{\\\"enable_early_stop\\\":true, \\\"window_size\\\":6, \\\"threshold\\\":0.93}\"}" + check_service 90 + python -m pytest -sv test_repetition_early_stop.py || TEST_EXIT_CODE=1 + + curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }" + check_service 90 + python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1 + + curl -X POST 
http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }" + check_service 90 + python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1 + + curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_sot_wint4.yaml\", \"--enable-logprob\": \"False\"}" + check_service 360 + export TEMPLATE=TOKEN_NORMAL + python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1 + + curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"ernie45t_21b_cinn_wint4.yaml\", \"--enable-logprob\": \"False\"}" + check_service 360 + export TEMPLATE=TOKEN_NORMAL + python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1 + + export TEMPLATE=TOKEN_NORMAL + curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ERNIE-4.5-VL-28B-A3B-Thinking\", \"--reasoning-parser\": \"ernie-45-vl-thinking\", \"--tool-call-parser\": \"ernie-45-vl-thinking\", \"--tensor-parallel-size\": 1, \"--quantization\": \"wint4\", \"--max-model-len\": 131072, \"--max-num-seqs\": 32, \"--no-enable-prefix-caching\": true}" + check_service 180 + python -m pytest -sv test_prompt_ids.py || TEST_EXIT_CODE=1 + + popd + echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env + ' + if [ -f ./FastDeploy/exit_code.env ]; then + source ./FastDeploy/exit_code.env + cat ./FastDeploy/exit_code.env >> $GITHUB_ENV + fi + echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" + exit ${TEST_EXIT_CODE} diff --git a/.github/workflows/_build_linux.yml b/.github/workflows/_build_linux.yml new file mode 100644 index 
00000000000..172f07cfd73 --- /dev/null +++ b/.github/workflows/_build_linux.yml @@ -0,0 +1,250 @@ +name: FastDeploy Linux GPU Build Task +description: "FastDeploy packages build and upload" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + COMPILE_ARCH: + description: "Build GPU Archs" + required: true + type: string + default: "80,90" + WITH_NIGHTLY_BUILD: + description: "Enable nightly build mode (e.g. add date suffix to version)" + required: false + type: string + default: "OFF" + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + UPLOAD: + description: "Upload Package" + required: false + type: string + default: "ON" + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + FD_UNIFY_BUILD: + description: "Enable unified build mode; build once without arch-specific compilation" + required: false + type: string + default: "" + outputs: + wheel_path: + description: "Output path of the generated wheel" + value: ${{ jobs.fd-build.outputs.wheel_path }} + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: build_gpu + + fd-build: + runs-on: [self-hosted, GPU-Build] + needs: check_bypass + if: ${{ needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 360 + outputs: + wheel_path: ${{ steps.set_output.outputs.wheel_path }} + steps: + - name: Code Prepare + 
shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + - name: FastDeploy Build + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + compile_arch: ${{ inputs.COMPILE_ARCH }} + fd_version: ${{ inputs.FD_VERSION }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BRANCH_REF: ${{ github.ref_name }} + PADDLEVERSION: ${{ inputs.PADDLEVERSION }} + PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }} + WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }} + FD_UNIFY_BUILD: ${{ inputs.FD_UNIFY_BUILD }} + run: | + set -x + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + + IFS='/' read 
-ra parts <<< "${GITHUB_WORKSPACE}" + len=${#parts[@]} + CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")" + echo "$CCACHE_DEFAULT_DIR" + + CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host \ + --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/.ccache:/root/.ccache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + -e "COMPILE_ARCH=${compile_arch}" \ + -e "FD_VERSION=${fd_version}" \ + -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \ + -e "FD_UNIFY_BUILD=${FD_UNIFY_BUILD}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "BRANCH_REF=${BRANCH_REF}" \ + -e "CCACHE_MAXSIZE=50G" \ + --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' + if [[ -n "${FD_VERSION}" ]]; then + export FASTDEPLOY_VERSION=${FD_VERSION} + echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}" + fi + + git config --global --add safe.directory /workspace/FastDeploy + chown -R $(whoami) /workspace/FastDeploy + cd FastDeploy + if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then + GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD) + DATE_ONLY=$(echo $GIT_COMMIT_TIME | sed "s/ .*//;s/-//g") + echo "Git Commit Time: $GIT_COMMIT_TIME" + echo "Date Only: $DATE_ONLY" + export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}" + fi + # 针对不同分支和tag使用不同的PaddlePaddle安装包 + if [[ "${PADDLE_WHL_URL}" != "" ]];then + python -m pip install ${PADDLE_WHL_URL} + elif [[ "${PADDLEVERSION}" != "" ]];then + python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ + else + python -m pip install --pre 
paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + fi + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + python -m pip install wheel + # 编译RDMA + export FD_ENABLE_RDMA_COMPILE=1 + export FD_UNIFY_BUILD="${FD_UNIFY_BUILD}" + + if [[ "${FD_UNIFY_BUILD}" == "true" ]]; then + bash build.sh 1 python false + else + bash build.sh 1 python false [${COMPILE_ARCH}] + fi + ls ./dist/*.whl + ' + - name: Package Upload + id: set_output + env: + compile_arch: ${{ inputs.COMPILE_ARCH }} + run: | + set -x + if [[ "${{ github.event_name }}" == "pull_request" ]];then + commit_id=${{ github.event.pull_request.head.sha }} + pr_num=${{ github.event.pull_request.number }} + target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_} + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/SM${compile_arch//,/_} + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/SM${compile_arch//,/_} + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python --version + python -m pip install bce-python-sdk==0.9.29 + cd FastDeploy/dist/ + matches=($(ls fastdeploy*.whl)) + if [ ${#matches[@]} -ne 1 ]; then + echo "Error: Found ${#matches[@]} matching files, expected exactly 1" + exit 1 + fi + fd_wheel_name=${matches[0]} + echo "Found: $fd_wheel_name" + tree -L 3 + python ${push_file} fastdeploy*.whl ${target_path} + target_path_stripped="${target_path#paddle-github-action/}" + 
WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} + echo "wheel_path=${WHEEL_PATH}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_build_linux_cu129.yml b/.github/workflows/_build_linux_cu129.yml new file mode 100644 index 00000000000..6370268c7cb --- /dev/null +++ b/.github/workflows/_build_linux_cu129.yml @@ -0,0 +1,237 @@ +name: FastDeploy Linux GPU Build Task +description: "FastDeploy packages build and upload" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda129-manylinux" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + COMPILE_ARCH: + description: "Build GPU Archs" + required: true + type: string + default: "80,90" + WITH_NIGHTLY_BUILD: + description: "Enable nightly build mode (e.g. add date suffix to version)" + required: false + type: string + default: "OFF" + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + UPLOAD: + description: "Upload Package" + required: false + type: string + default: "ON" + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + FD_UNIFY_BUILD: + description: "Enable unified build mode; build once without arch-specific compilation" + required: false + type: string + default: "" + outputs: + wheel_path_cu129: + description: "Output path of the generated wheel" + value: ${{ jobs.fd-build-cu129.outputs.wheel_path_cu129 }} +jobs: + fd-build-cu129: + runs-on: [self-hosted, GPU-Build-Cu129] + timeout-minutes: 360 + outputs: + wheel_path_cu129: ${{ 
steps.set_output.outputs.wheel_path_cu129 }} + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + - name: FastDeploy Build + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + compile_arch: ${{ inputs.COMPILE_ARCH }} + fd_version: ${{ inputs.FD_VERSION }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BRANCH_REF: ${{ github.ref_name }} + PADDLEVERSION: ${{ inputs.PADDLEVERSION }} + PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }} + WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }} + FD_UNIFY_BUILD: ${{ inputs.FD_UNIFY_BUILD }} + run: | + set -x + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + gpu_id=$(echo 
"$CARD_ID" | fold -w1 | paste -sd,) + + IFS='/' read -ra parts <<< "${GITHUB_WORKSPACE}" + len=${#parts[@]} + CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")" + echo "$CCACHE_DEFAULT_DIR" + + CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host \ + --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/.ccache:/root/.ccache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + -e "COMPILE_ARCH=${compile_arch}" \ + -e "FD_VERSION=${fd_version}" \ + -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \ + -e "FD_UNIFY_BUILD=${FD_UNIFY_BUILD}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "BRANCH_REF=${BRANCH_REF}" \ + -e "CCACHE_MAXSIZE=50G" \ + --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' + if [[ -n "${FD_VERSION}" ]]; then + export FASTDEPLOY_VERSION=${FD_VERSION} + echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}" + fi + + git config --global --add safe.directory /workspace/FastDeploy + chown -R $(whoami) /workspace/FastDeploy + cd FastDeploy + if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then + GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD) + DATE_ONLY=$(echo $GIT_COMMIT_TIME | sed "s/ .*//;s/-//g") + echo "Git Commit Time: $GIT_COMMIT_TIME" + echo "Date Only: $DATE_ONLY" + export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}" + fi + # 针对不同分支和tag使用不同的PaddlePaddle安装包 + if [[ "${PADDLE_WHL_URL}" != "" ]];then + python -m pip install ${PADDLE_WHL_URL} + elif [[ "${PADDLEVERSION}" != "" ]];then + python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i 
https://www.paddlepaddle.org.cn/packages/stable/cu129/ + else + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ + fi + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + python -m pip install wheel + # 编译RDMA + export FD_ENABLE_RDMA_COMPILE=1 + export FD_UNIFY_BUILD="${FD_UNIFY_BUILD}" + + if [[ "${FD_UNIFY_BUILD}" == "true" ]]; then + bash build.sh 1 python false + else + bash build.sh 1 python false [${COMPILE_ARCH}] + fi + ls ./dist/*.whl + ' + - name: Package Upload + id: set_output + env: + compile_arch: ${{ inputs.COMPILE_ARCH }} + run: | + set -x + if [[ "${{ github.event_name }}" == "pull_request" ]];then + commit_id=${{ github.event.pull_request.head.sha }} + pr_num=${{ github.event.pull_request.number }} + target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}/cu129 + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/SM${compile_arch//,/_}/cu129 + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/SM${compile_arch//,/_}/cu129 + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python --version + python -m pip install bce-python-sdk==0.9.29 + cd FastDeploy/dist/ + matches=($(ls fastdeploy*.whl)) + if [ ${#matches[@]} -ne 1 ]; then + echo "Error: Found ${#matches[@]} matching files, expected exactly 1" + exit 1 + fi + fd_wheel_name=${matches[0]} + echo "Found: $fd_wheel_name" + tree -L 3 + python ${push_file} fastdeploy*.whl ${target_path} + 
target_path_stripped="${target_path#paddle-github-action/}" + WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} + echo "wheel_path_cu129=${WHEEL_PATH}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_build_linux_cu130.yml b/.github/workflows/_build_linux_cu130.yml new file mode 100644 index 00000000000..278aff6956b --- /dev/null +++ b/.github/workflows/_build_linux_cu130.yml @@ -0,0 +1,237 @@ +name: FastDeploy Linux GPU Build Task +description: "FastDeploy packages build and upload" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda130-manylinux" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + COMPILE_ARCH: + description: "Build GPU Archs" + required: true + type: string + default: "80,90" + WITH_NIGHTLY_BUILD: + description: "Enable nightly build mode (e.g. 
add date suffix to version)" + required: false + type: string + default: "OFF" + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + UPLOAD: + description: "Upload Package" + required: false + type: string + default: "ON" + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + FD_UNIFY_BUILD: + description: "Enable unified build mode; build once without arch-specific compilation" + required: false + type: string + default: "" + outputs: + wheel_path_cu130: + description: "Output path of the generated wheel" + value: ${{ jobs.fd-build-cu130.outputs.wheel_path_cu130 }} +jobs: + fd-build-cu130: + runs-on: [self-hosted, GPU-Build-Cu130] + timeout-minutes: 360 + outputs: + wheel_path_cu130: ${{ steps.set_output.outputs.wheel_path_cu130 }} + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + - name: FastDeploy Build + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + compile_arch: ${{ inputs.COMPILE_ARCH }} + fd_version: ${{ inputs.FD_VERSION }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BRANCH_REF: ${{ github.ref_name }} + PADDLEVERSION: ${{ inputs.PADDLEVERSION }} + PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }} + WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }} + FD_UNIFY_BUILD: ${{ inputs.FD_UNIFY_BUILD }} + run: | + set -x + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + + IFS='/' read -ra parts <<< "${GITHUB_WORKSPACE}" + len=${#parts[@]} + CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")" + echo "$CCACHE_DEFAULT_DIR" + + CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! 
-f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host \ + --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache_cu130:/root/.cache" \ + -v "${CACHE_DIR}/.ccache_cu130:/root/.ccache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + -e "COMPILE_ARCH=${compile_arch}" \ + -e "FD_VERSION=${fd_version}" \ + -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \ + -e "FD_UNIFY_BUILD=${FD_UNIFY_BUILD}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "BRANCH_REF=${BRANCH_REF}" \ + -e "CCACHE_MAXSIZE=50G" \ + --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' + if [[ -n "${FD_VERSION}" ]]; then + export FASTDEPLOY_VERSION=${FD_VERSION} + echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}" + fi + + git config --global --add safe.directory /workspace/FastDeploy + chown -R $(whoami) /workspace/FastDeploy + cd FastDeploy + if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then + GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD) + DATE_ONLY=$(echo $GIT_COMMIT_TIME | sed "s/ .*//;s/-//g") + echo "Git Commit Time: $GIT_COMMIT_TIME" + echo "Date Only: $DATE_ONLY" + export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}" + fi + # 针对不同分支和tag使用不同的PaddlePaddle安装包 + if [[ "${PADDLE_WHL_URL}" != "" ]];then + python -m pip install ${PADDLE_WHL_URL} + elif [[ "${PADDLEVERSION}" != "" ]];then + python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu130/ + else + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu130/ + fi + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install --upgrade pip + python -m pip install -r 
requirements.txt + python -m pip install wheel + # 编译RDMA + export FD_ENABLE_RDMA_COMPILE=1 + export FD_UNIFY_BUILD="${FD_UNIFY_BUILD}" + + if [[ "${FD_UNIFY_BUILD}" == "true" ]]; then + bash build.sh 1 python false + else + bash build.sh 1 python false [${COMPILE_ARCH}] + fi + ls ./dist/*.whl + ' + - name: Package Upload + id: set_output + env: + compile_arch: ${{ inputs.COMPILE_ARCH }} + run: | + set -x + if [[ "${{ github.event_name }}" == "pull_request" ]];then + commit_id=${{ github.event.pull_request.head.sha }} + pr_num=${{ github.event.pull_request.number }} + target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}/cu130 + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/SM${compile_arch//,/_}/cu130 + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/SM${compile_arch//,/_}/cu130 + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python --version + python -m pip install bce-python-sdk==0.9.29 + cd FastDeploy/dist/ + matches=($(ls fastdeploy*.whl)) + if [ ${#matches[@]} -ne 1 ]; then + echo "Error: Found ${#matches[@]} matching files, expected exactly 1" + exit 1 + fi + fd_wheel_name=${matches[0]} + echo "Found: $fd_wheel_name" + tree -L 3 + python ${push_file} fastdeploy*.whl ${target_path} + target_path_stripped="${target_path#paddle-github-action/}" + WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} + echo "wheel_path_cu130=${WHEEL_PATH}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_build_linux_fd_router.yml b/.github/workflows/_build_linux_fd_router.yml new file mode 100644 index 
00000000000..b600cc2328e --- /dev/null +++ b/.github/workflows/_build_linux_fd_router.yml @@ -0,0 +1,213 @@ +name: FastDeploy Linux GPU FD_ROUTER Build Task +description: "FastDeploy FD_ROUTER build and upload" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda129-manylinux" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + COMPILE_ARCH: + description: "Build GPU Archs" + required: true + type: string + default: "80,90" + WITH_NIGHTLY_BUILD: + description: "Enable nightly build mode (e.g. add date suffix to version)" + required: false + type: string + default: "OFF" + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + UPLOAD: + description: "Upload Package" + required: false + type: string + default: "ON" + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + outputs: + fd_router_path: + description: "Output path of the generated wheel" + value: ${{ jobs.fd-router-build.outputs.fd_router_path }} +jobs: + fd-router-build: + runs-on: [self-hosted, GPU-Build-Cu129] + timeout-minutes: 360 + outputs: + fd_router_path: ${{ steps.set_output.outputs.fd_router_path }} + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + # 
docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + - name: FastDeploy FD_ROUTER Build + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + compile_arch: ${{ inputs.COMPILE_ARCH }} + fd_version: ${{ inputs.FD_VERSION }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BRANCH_REF: ${{ github.ref_name }} + PADDLEVERSION: ${{ inputs.PADDLEVERSION }} + PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }} + WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }} + run: | + set -x + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + + IFS='/' read -ra parts <<< "${GITHUB_WORKSPACE}" + len=${#parts[@]} + CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")" + echo "$CCACHE_DEFAULT_DIR" + + CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! 
-f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host \ + --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/.ccache:/root/.ccache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + -e "COMPILE_ARCH=${compile_arch}" \ + -e "FD_VERSION=${fd_version}" \ + -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "BRANCH_REF=${BRANCH_REF}" \ + -e "CCACHE_MAXSIZE=50G" \ + --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' + if [[ -n "${FD_VERSION}" ]]; then + export FASTDEPLOY_VERSION=${FD_VERSION} + echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}" + fi + + git config --global --add safe.directory /workspace/FastDeploy + chown -R $(whoami) /workspace/FastDeploy + cd FastDeploy + if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then + GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD) + DATE_ONLY=$(echo $GIT_COMMIT_TIME | sed "s/ .*//;s/-//g") + echo "Git Commit Time: $GIT_COMMIT_TIME" + echo "Date Only: $DATE_ONLY" + export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}" + fi + cd fastdeploy/golang_router + go --version + + bash ./build.sh + ls /usr/local/bin/fd-router* + cp -r /usr/local/bin/fd-router* ./ + ' + + - name: Package Upload + id: set_output + env: + compile_arch: ${{ inputs.COMPILE_ARCH }} + run: | + set -x + if [[ "${{ github.event_name }}" == "pull_request" ]];then + commit_id=${{ github.event.pull_request.head.sha }} + pr_num=${{ github.event.pull_request.number }} + target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id} + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + 
target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id} + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id} + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python --version + python -m pip install bce-python-sdk==0.9.29 + + cd FastDeploy/fastdeploy/golang_router + if [ ! -f fd-router ]; then + echo "Error: fd-router not found in FastDeploy/fastdeploy/golang_router" + exit 1 + fi + + echo "Found: fd-router" + + python ${push_file} fd-router ${target_path} + target_path_stripped="${target_path#paddle-github-action/}" + FD_ROUTER_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/fd-router + echo "fd_router_path=${FD_ROUTER_PATH}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_build_linux_rl.yml b/.github/workflows/_build_linux_rl.yml new file mode 100644 index 00000000000..ede288c805a --- /dev/null +++ b/.github/workflows/_build_linux_rl.yml @@ -0,0 +1,204 @@ +name: FastDeploy Linux GPU Build Task +description: "FastDeploy packages build and upload" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "iregistry.baidu-int.com/tiangexiao/base-images:paddlecloud-ubuntu24.04-gcc13.3-cuda12.9-cudnn9.9-bccl1.4.1.4-nccl2.26.5-openmpi4.1.5-FleetY13.0.0-rc2" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + COMPILE_ARCH: + description: "Build GPU Archs" + required: true + type: string + default: "80,90" + WITH_NIGHTLY_BUILD: + description: "Enable nightly build mode (e.g. 
add date suffix to version)" + required: false + type: string + default: "OFF" + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + UPLOAD: + description: "Upload Package" + required: false + type: string + default: "ON" + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + outputs: + wheel_path_rl: + description: "Output path of the generated wheel" + value: ${{ jobs.fd-build-rl.outputs.wheel_path_rl }} +jobs: + fd-build-rl: + runs-on: [self-hosted, GPU-Build] + timeout-minutes: 360 + outputs: + wheel_path_rl: ${{ steps.set_output.outputs.wheel_path_rl }} + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + - name: FastDeploy Build + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + compile_arch: ${{ inputs.COMPILE_ARCH }} + fd_version: ${{ inputs.FD_VERSION }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BRANCH_REF: ${{ github.ref_name }} + PADDLEVERSION: ${{ inputs.PADDLEVERSION }} + PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }} + WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }} + run: | + set -x + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + + IFS='/' read -ra parts <<< "${GITHUB_WORKSPACE}" + len=${#parts[@]} + CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")" + echo "$CCACHE_DEFAULT_DIR" + + CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! 
-f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host \ + --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache_rl:/root/.cache" \ + -v "${CACHE_DIR}/.ccache_rl:/root/.ccache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + -e "COMPILE_ARCH=${compile_arch}" \ + -e "FD_VERSION=${fd_version}" \ + -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "BRANCH_REF=${BRANCH_REF}" \ + -e "CCACHE_MAXSIZE=50G" \ + --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' + if [[ -n "${FD_VERSION}" ]]; then + export FASTDEPLOY_VERSION=${FD_VERSION} + echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}" + fi + + git config --global --add safe.directory /workspace/FastDeploy + chown -R $(whoami) /workspace/FastDeploy + cd FastDeploy + + python -m pip uninstall paddlepaddle-gpu -y || true + wget -q --no-proxy https://paddle-qa.bj.bcebos.com/paddle-pipeline/Develop-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/latest/paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl + python -m pip install paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + python -m pip install wheel + # 编译RDMA + export FD_ENABLE_RDMA_COMPILE=1 + bash build.sh 1 python false [${COMPILE_ARCH}] + ls ./dist/*.whl + ' + - name: Package Upload + id: set_output + env: + compile_arch: ${{ inputs.COMPILE_ARCH }} + run: | + set -x + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + 
target_path=paddle-github-action/BRANCH/FastDeploy_RL/${branch_name}/${commit_id}/SM${compile_arch//,/_} + + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python --version + python -m pip install bce-python-sdk==0.9.29 + cd FastDeploy/dist/ + matches=($(ls fastdeploy*.whl)) + if [ ${#matches[@]} -ne 1 ]; then + echo "Error: Found ${#matches[@]} matching files, expected exactly 1" + exit 1 + fi + fd_wheel_name=${matches[0]} + echo "Found: $fd_wheel_name" + tree -L 3 + python ${push_file} fastdeploy*.whl ${target_path} + target_path_stripped="${target_path#paddle-github-action/}" + WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} + echo "wheel_path_rl=${WHEEL_PATH}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_build_xpu.yml b/.github/workflows/_build_xpu.yml new file mode 100644 index 00000000000..b9bab8381d0 --- /dev/null +++ b/.github/workflows/_build_xpu.yml @@ -0,0 +1,207 @@ +name: XPU-Build-Test + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:ci" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + WITH_NIGHTLY_BUILD: + description: "Enable nightly build mode (e.g. 
add date suffix to version)" + required: false + type: string + default: "OFF" + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + outputs: + wheel_path: + description: "Output path of the generated wheel" + value: ${{ jobs.xpu-build-test.outputs.wheel_path }} + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: build_xpu + + xpu-build-test: + runs-on: [self-hosted, XPU-P800] + needs: check_bypass + if: ${{ needs.check_bypass.outputs.can-skip != 'true' }} + outputs: + wheel_path: ${{ steps.set_output.outputs.wheel_path }} + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} || true + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + - name: FastDeploy Build + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_version: ${{ inputs.FD_VERSION }} + BRANCH_REF: ${{ github.ref_name }} + PADDLEVERSION: ${{ inputs.PADDLEVERSION }} + PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }} + WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }} + run: | + set -x + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host \ + --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "/ssd3:/ssd3" \ + -e "MODEL_PATH=/ssd3/model" \ + -e "http_proxy=$(git config --global --get http.proxy)" \ + -e "https_proxy=$(git config --global --get https.proxy)" \ + -e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \ + -e TZ="Asia/Shanghai" \ + -e "FD_VERSION=${fd_version}" \ + -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "BRANCH_REF=${BRANCH_REF}" \ + ${docker_image} /bin/bash -c ' + if [[ -n "${FD_VERSION}" ]]; then + export FASTDEPLOY_VERSION=${FD_VERSION} + echo "Custom FastDeploy version: ${FASTDEPLOY_VERSION}" + fi + + git config --global --add safe.directory /workspace/FastDeploy + chown -R $(whoami) /workspace/FastDeploy + cd 
FastDeploy + if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then + GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD) + DATE_ONLY=$(echo $GIT_COMMIT_TIME | sed "s/ .*//;s/-//g") + echo "Git Commit Time: $GIT_COMMIT_TIME" + echo "Date Only: $DATE_ONLY" + export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}" + fi + python -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + # 针对不同分支和tag使用不同的PaddlePaddle安装包 + if [[ "${PADDLE_WHL_URL}" != "" ]];then + python -m pip install ${PADDLE_WHL_URL} + elif [[ "${PADDLEVERSION}" != "" ]];then + python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y + python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ + else + python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y + python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + fi + + + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + bash custom_ops/xpu_ops/download_dependencies.sh develop + export CLANG_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xtdk + export XVLLM_PATH=$(pwd)/custom_ops/xpu_ops/third_party/xvllm + bash build.sh + ls ./dist/*.whl + ' + - name: Package Upload + id: set_output + run: | + set -x + if [[ "${{ github.event_name }}" == "pull_request" ]];then + commit_id=${{ github.event.pull_request.head.sha }} + pr_num=${{ github.event.pull_request.number }} + target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/xpu + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/xpu + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/xpu + fi + wget -q --no-proxy --no-check-certificate 
https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python3 --version + python3 -m pip install bce-python-sdk==0.9.29 + cd FastDeploy/dist/ + matches=($(ls fastdeploy*.whl)) + if [ ${#matches[@]} -ne 1 ]; then + echo "Error: Found ${#matches[@]} matching files, expected exactly 1" + exit 1 + fi + fd_wheel_name=${matches[0]} + echo "Found: $fd_wheel_name" + # tree -L 3 + python3 ${push_file} fastdeploy*.whl ${target_path} + target_path_stripped="${target_path#paddle-github-action/}" + WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} + echo "wheel_path=${WHEEL_PATH}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_ci_gcu.yml b/.github/workflows/_ci_gcu.yml new file mode 100644 index 00000000000..968fc5a68c4 --- /dev/null +++ b/.github/workflows/_ci_gcu.yml @@ -0,0 +1,98 @@ +name: CI_GCU + +on: + #pull_request: + #branches: + #- develop + #- 'release/*' + workflow_dispatch: + +concurrency: + group: ${{ github.event.pull_request.number }}-gcu-ci + cancel-in-progress: true + +jobs: + CI_GCU: + runs-on: + group: GCU + steps: + - name: Print current runner name + run: | + echo "Current runner name: ${{ runner.name }}" + + - name: Code Checkout + env: + docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 + run: | + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace \ + -v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \ + -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + -e "BASE_BRANCH=${BASE_BRANCH}" \ + ${docker_image} /bin/bash -c ' + if [ -d ${REPO_NAME} ]; then + echo "Directory ${REPO_NAME} exists, removing it..." 
+ rm -rf ${REPO_NAME} + fi + ' + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + source ${{ github.workspace }}/../../../proxy + git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH} + cd FastDeploy + if [ "${{ github.event_name }}" = "pull_request" ]; then + git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }} + git merge pr/${{ github.event.pull_request.number }} + git log -n 3 --oneline + else + git checkout ${{ github.sha }} + git log -n 3 --oneline + fi + echo "Copy models..." + sudo mkdir -p ci_models && sudo cp -r /work/deps/ERNIE-4.5-21B-A3B-Paddle ci_models + echo "Copy models done." + + - name: Run CI unittest + env: + docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 + run: | + runner_name="${{ runner.name }}" + last_char="${runner_name: -1}" + + if [[ "$last_char" =~ [0-3] ]]; then + gcu_id="$last_char" + else + gcu_id="0" + fi + FD_API_PORT=$((9180 + gcu_id * 100)) + FD_ENGINE_QUEUE_PORT=$((9150 + gcu_id * 100)) + FD_METRICS_PORT=$((9170 + gcu_id * 100)) + + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + echo "Install drivers..." + cd /work/deps + sudo bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y + cd - + echo "Create docker..." 
+ docker run --rm --network=host --ipc=host --privileged \ + -v $(pwd):/workspace \ + -v /home:/home \ + -v /work:/work \ + -w /workspace \ + -e "MODEL_PATH=./ci_models" \ + -e "http_proxy=$(git config --global --get http.proxy)" \ + -e "https_proxy=$(git config --global --get https.proxy)" \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + ${docker_image} /bin/bash -c " + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + bash scripts/run_ci_gcu.sh + " diff --git a/.github/workflows/_ci_image_build.yml b/.github/workflows/_ci_image_build.yml new file mode 100644 index 00000000000..a498d63d1f3 --- /dev/null +++ b/.github/workflows/_ci_image_build.yml @@ -0,0 +1,73 @@ +name: Docker Build +description: "FastDeploy CI Image Build" + +on: + workflow_call: + inputs: + CI_DOCKER_IMAGE_NAME: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." 
+ required: true + type: string + DOCKER_IMAGE_NAME: + description: "Build Images" + required: false + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate" + outputs: + docker_name_precheck: + description: "Output path of the generated wheel" + value: ${{ jobs.docker_build.outputs.docker_name_precheck }} + +jobs: + docker_build: + runs-on: [self-hosted, Docker-Build] + outputs: + docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }} + steps: + - name: Docker Build + id: docker_build + shell: bash + env: + docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }} + docker_image: ${{ inputs.DOCKER_IMAGE_NAME }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + if [ -d ${REPO_NAME} ]; then + echo "Directory ${REPO_NAME} exists, removing it..." + rm -rf ${REPO_NAME}* + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + # Docker Build + cd tools/dockerfile/ + set -e + cp ../../requirements.txt ./ + cp ../../scripts/unittest_requirement.txt ./ + docker build -t ${docker_image_name} -f Dockerfile.ci . 
\ + --network host \ + --no-cache + docker push ${docker_image_name} + echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_clone_linux.yml b/.github/workflows/_clone_linux.yml new file mode 100644 index 00000000000..9b2c5a76667 --- /dev/null +++ b/.github/workflows/_clone_linux.yml @@ -0,0 +1,89 @@ +name: FastDeploy Code Clone +description: "FastDeploy clone and upload" + +on: + workflow_call: + inputs: + bos_dir: + type: string + required: false + default: 'FastDeploy' + outputs: + repo_archive_url: + description: "Compressed source code archive." + value: ${{ jobs.code-clone.outputs.repo_archive_url }} + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: code_clone + + code-clone: + runs-on: + group: HK-Clone + needs: check_bypass + if: ${{ needs.check_bypass.outputs.can-skip != 'true' }} + outputs: + repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} + steps: + - name: Clone FastDeploy + uses: actions/checkout@v6 + with: + ref: ${{ github.event_name == 'pull_request' + && github.event.pull_request.base.ref + || github.ref_name }} + submodules: 'recursive' + fetch-depth: 1000 + + - name: Merge PR (if needed) + if: ${{ github.event_name == 'pull_request' }} + run: | + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + echo "Fetching and merging PR..." 
+ git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }} + git merge --no-ff pr/${{ github.event.pull_request.number }} + echo "PR Branch log " + git log --oneline -n 5 pr/${{ github.event.pull_request.number }} + - uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Code Info Show and Upload + id: set_output + env: + AK: paddle + SK: paddle + run: | + git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'" + echo "Current HEAD Log:" + git log --oneline -n 5 + ls + cd .. + tar -zcf FastDeploy.tar.gz FastDeploy + if [[ "${{ github.event_name }}" == "pull_request" ]];then + commit_id=${{ github.event.pull_request.head.sha }} + pr_num=${{ github.event.pull_request.number }} + target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id} + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id} + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id} + fi + wget -O bos_tools.py -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python -m pip install bce-python-sdk==0.9.29 + ls + python ${push_file} FastDeploy.tar.gz ${target_path} + target_path_stripped="${target_path#paddle-github-action/}" + REPO_ARCHIVE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz + echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT diff --git a/.github/workflows/_golang_router_test.yml b/.github/workflows/_golang_router_test.yml new file mode 100644 index 00000000000..4964f3a3a05 --- /dev/null +++ b/.github/workflows/_golang_router_test.yml @@ -0,0 +1,213 @@ +name: 
GOLANG_ROUTER Tests +description: "Run FastDeploy golang_router tests" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." + required: true + type: string + FASTDEPLOY_ROUTER_URL: + description: "URL of the FastDeploy Router" + required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + run_golang_router_tests: + runs-on: [self-hosted, GPU-h20-2Cards] + timeout-minutes: 30 + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run Golang_Router Tests + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + fd_router_url: ${{ inputs.FASTDEPLOY_ROUTER_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BASE_REF: ${{ github.event.pull_request.base.ref }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + if [[ "$IS_PR" == "true" ]]; then + echo "Running on PR" + else + echo "Not a PR" + fi + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8098 + DEVICE_PORT * 100)) + FD_ROUTER_PORT=$((8048 + DEVICE_PORT * 100)) + FD_CONNECTOR_PORT=$((8038 + DEVICE_PORT * 100)) + FD_RDMA_PORT=$((8028 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo "FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo "FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "FD_ROUTER_PORT=${FD_ROUTER_PORT}" + echo 
"FD_CONNECTOR_PORT=${FD_CONNECTOR_PORT}" + echo "FD_RDMA_PORT=${FD_RDMA_PORT}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." 
+ if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + + export RDMA_DEVICES=$(find /dev/infiniband/uverbs* -maxdepth 1 -not -type d | xargs -I{} echo '--device {}:{}') + + docker run --rm --net=host \ + --name ${runner_name} \ + --cap-add=SYS_PTRACE --cap-add=IPC_LOCK \ + --shm-size=64G \ + ${RDMA_DEVICES} \ + --device=/dev/infiniband/rdma_cm \ + --ulimit memlock=-1:-1 \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -v "${MODEL_CACHE_DIR}:/ModelData:ro" \ + -e "MODEL_PATH=/ModelData" \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -e "FD_ROUTER_PORT=${FD_ROUTER_PORT}" \ + -e "FD_CONNECTOR_PORT=${FD_CONNECTOR_PORT}" \ + -e "FD_RDMA_PORT=${FD_RDMA_PORT}" \ + -e "CLEAN_CUDA=1" \ + -e TZ="Asia/Shanghai" \ + -e "fd_wheel_url=${fd_wheel_url}" \ + -e "fd_router_url=${fd_router_url}" \ + -e "BASE_REF=${BASE_REF}" \ + -e "IS_PR=${IS_PR}" \ + --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' + + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install -r scripts/unittest_requirement.txt + python -m pip install ${fd_wheel_url} + rm -rf fastdeploy + python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy + export PYTHONPATH=/workspace/FastDeploy/ + + # download fd-router binary + wget -q --no-proxy ${fd_router_url} -O /usr/local/bin/fd-router + chmod +x /usr/local/bin/fd-router + + bash 
scripts/run_golang_router.sh + ' diff --git a/.github/workflows/_gpu_4cards_case_test.yml b/.github/workflows/_gpu_4cards_case_test.yml new file mode 100644 index 00000000000..9b1455e8c06 --- /dev/null +++ b/.github/workflows/_gpu_4cards_case_test.yml @@ -0,0 +1,203 @@ +name: 4-GPU E2E Tests +description: "Run FastDeploy e2e tests on 4 GPUs" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." + required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: gpu_4cards_test + + run_4_cards_tests: + runs-on: [self-hosted, GPU-h20-4Cards] + needs: check_bypass + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 30 + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt 
$CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run Four Cards Tests + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BASE_REF: ${{ github.event.pull_request.base.ref }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + if [[ "$IS_PR" == "true" ]]; then + echo "Running on PR" + else + echo "Not a PR" + fi + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8098 + DEVICE_PORT * 100)) + FD_ROUTER_PORT=$((8048 + DEVICE_PORT * 100)) + FD_CONNECTOR_PORT=$((8038 + DEVICE_PORT * 100)) + FD_RDMA_PORT=$((8028 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo "FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo 
"FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "FD_ROUTER_PORT=${FD_ROUTER_PORT}" + echo "FD_CONNECTOR_PORT=${FD_CONNECTOR_PORT}" + echo "FD_RDMA_PORT=${FD_RDMA_PORT}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." 
+ if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + + docker run --rm --ipc=host --net=host \ + --name ${runner_name} \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -v "${MODEL_CACHE_DIR}:/ModelData:ro" \ + -e "MODEL_PATH=/ModelData" \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -e TZ="Asia/Shanghai" \ + -e "fd_wheel_url=${fd_wheel_url}" \ + -e "BASE_REF=${BASE_REF}" \ + -e "IS_PR=${IS_PR}" \ + --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c ' + + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt + + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install -r scripts/unittest_requirement.txt + python -m pip install ${fd_wheel_url} + rm -rf fastdeploy + python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy + export PYTHONPATH=/workspace/FastDeploy/ + + export CUDA_VISIBLE_DEVICES=0,1,2,3 + bash scripts/run_gpu_4cards.sh + ' diff --git a/.github/workflows/_iluvatar_cases.yml b/.github/workflows/_iluvatar_cases.yml new file mode 100644 index 00000000000..036d7df78ef --- /dev/null +++ b/.github/workflows/_iluvatar_cases.yml @@ -0,0 +1,82 @@ +name: ILUVATAR-Test + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:3.3.0-20260312" + 
FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: ci_iluvatar + + run_iluvatar_cases: + runs-on: iluvatar-gpu-2 + needs: check_bypass + if: ${{ needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 60 + container: + image: ${{ inputs.DOCKER_IMAGE }} + env: + LD_LIBRARY_PATH: /usr/local/corex/lib + LIBRARY_PATH: /usr/local/corex/lib + steps: + - name: Print current runner name + run: | + echo "Current runner name: ${{ runner.name }}" + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + # Clean the repository directory before starting + if [ -d ${REPO_NAME} ]; then + echo "Directory ${REPO_NAME} exists, removing it..." 
+ rm -rf ${REPO_NAME}* + fi + git config --global --add safe.directory '*' + wget -q ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run CI unittest + env: + CACHE_DIR: ${{ inputs.CACHE_DIR }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + run: | + cd FastDeploy + bash scripts/run_ci_iluvatar.sh diff --git a/.github/workflows/_logprob_test_linux.yml b/.github/workflows/_logprob_test_linux.yml new file mode 100644 index 00000000000..ac3ebeff2fc --- /dev/null +++ b/.github/workflows/_logprob_test_linux.yml @@ -0,0 +1,221 @@ +name: Run FastDeploy LogProb Tests +description: "Run FastDeploy LogProb Tests" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310" + PADDLETEST_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + default: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz" + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." 
+ required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: logprob_test + + run_tests_logprob: + runs-on: [self-hosted, GPU-h20-1Cards] + needs: check_bypass + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 60 + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }} + run: | + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + -e "BASE_BRANCH=${BASE_BRANCH}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove /workspace/* ..." + rm -rf /workspace/* || true + sleep 2 + + # Check if anything matching /workspace/* still exists + if ! 
ls /workspace/* >/dev/null 2>&1; then + echo "All /workspace/* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls /workspace/* >/dev/null 2>&1; then + echo "ERROR: Failed to clean /workspace/* after multiple attempts" + ls -ld /workspace/* + exit 1 + fi + ' + wget -q --no-proxy ${paddletest_archive_url} + tar -xf PaddleTest.tar.gz + rm -rf PaddleTest.tar.gz + cd PaddleTest + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: logprob test + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + run: | + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8098 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo "FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo "FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + if [ ! -d "${MODEL_CACHE_DIR}" ]; then + echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist." 
+ exit 1 + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." + if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + docker run --rm --ipc=host --pid=host --net=host \ + --name ${runner_name} \ + -v $(pwd):/workspace \ + -w /workspace \ + -e fastdeploy_wheel_url=${fastdeploy_wheel_url} \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -v "${MODEL_CACHE_DIR}:/MODELDATA" \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install ${fastdeploy_wheel_url} + + wget --no-proxy https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64 + chmod +x ./llm-deploy-linux-amd64 + ./llm-deploy-linux-amd64 
-python python3.10 \ + -model_name ERNIE-4.5-0.3B-Paddle \ + -model_path /MODELDATA \ + --skip install,model + + cd PaddleTest/framework/ServeTest + ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9 + ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9 + python3.10 deploy.py > dd.log 2>&1 & + sleep 3 + curl -X POST http://0.0.0.0:${FLASK_PORT}/start \ + -H "Content-Type: application/json" \ + -d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}" + + curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90 + curl -s -o /dev/null -w "%{http_code}" -m 2 "http://0.0.0.0:${FD_API_PORT}/health" + curl -X POST "http://0.0.0.0:${FD_API_PORT}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}" + set +e + rm -rf ./baseline_output + cp -r baseline_dev_0311/ERNIE-4.5-0.3B-Paddle ./baseline_output + LOGPROB_EXIT_CODE=0 + python3.10 lanucher.py --request_template TOKEN_LOGPROB --url http://localhost:${FD_API_PORT}/v1/chat/completions --case ./cases/demo.yaml --concurrency 1 --name demo --exe logprob || LOGPROB_EXIT_CODE=$? + echo "LOGPROB_EXIT_CODE=${LOGPROB_EXIT_CODE}" > /workspace/exit_code.env + curl -X POST http://localhost:${FLASK_PORT}/stop + sleep 10s + cat *result.log + exit 0 + ' + if [ $? 
-ne 0 ];then + exit 1 + fi + + if [ -f exit_code.env ]; then + cat exit_code.env >> $GITHUB_ENV + fi + - name: logprob test result + if: ${{ env.LOGPROB_EXIT_CODE != 0 }} + shell: bash + run: | + echo "logprob test failed with exit code ${{ env.LOGPROB_EXIT_CODE }}" + exit 8 diff --git a/.github/workflows/_pre_ce_test.yml b/.github/workflows/_pre_ce_test.yml new file mode 100644 index 00000000000..d5dd490828b --- /dev/null +++ b/.github/workflows/_pre_ce_test.yml @@ -0,0 +1,190 @@ +name: Pre-CE-Test + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." + required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: pre_ce_test + + run_ce_cases: + runs-on: [self-hosted, PRE_CE_RUN_2Card] + needs: check_bypass + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 60 + steps: + - name: Print current runner name + run: | + echo "Current runner name: ${{ runner.name }}" + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean 
the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run CI unittest + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + run: | + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8098 + DEVICE_PORT * 100)) + FD_CONTROLLER_PORT=$((8018 + DEVICE_PORT * 100)) + FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((8048 + DEVICE_PORT * 100)) + FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((8038 + DEVICE_PORT * 100)) + FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((8028 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo 
"FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo "FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "FD_CONTROLLER_PORT=${FD_CONTROLLER_PORT}" + echo "FD_ZMQ_RECV_REQUEST_SERVER_PORT=${FD_ZMQ_RECV_REQUEST_SERVER_PORT}" + echo "FD_ZMQ_SEND_RESPONSE_SERVER_PORT=${FD_ZMQ_SEND_RESPONSE_SERVER_PORT}" + echo "FD_ZMQ_CONTROL_CMD_SERVER_PORTS=${FD_ZMQ_CONTROL_CMD_SERVER_PORTS}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." 
+ if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + + docker run --rm --net=host \ + --name ${runner_name} \ + -v $(pwd):/workspace \ + -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -v "${MODEL_CACHE_DIR}:/ModelData:ro" \ + -e "MODEL_PATH=/ModelData" \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -e "FD_CONTROLLER_PORT=${FD_CONTROLLER_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -e "FD_ZMQ_RECV_REQUEST_SERVER_PORT=${FD_ZMQ_RECV_REQUEST_SERVER_PORT}" \ + -e "FD_ZMQ_SEND_RESPONSE_SERVER_PORT=${FD_ZMQ_SEND_RESPONSE_SERVER_PORT}" \ + -e "FD_ZMQ_CONTROL_CMD_SERVER_PORTS=${FD_ZMQ_CONTROL_CMD_SERVER_PORTS}" \ + -e "fd_wheel_url=${fd_wheel_url}" \ + --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install ${fd_wheel_url} + bash scripts/run_pre_ce.sh + ' diff --git a/.github/workflows/_stable_test.yml b/.github/workflows/_stable_test.yml new file mode 100644 index 00000000000..3e7fb293e66 --- /dev/null +++ b/.github/workflows/_stable_test.yml @@ -0,0 +1,222 @@ +name: Stable Test +description: "Run Stable Tests" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." 
+ required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: stable_test + + stable_tests: + runs-on: [self-hosted, GPU-h20-2Cards] + needs: check_bypass + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 60 + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run FastDeploy Stable Tests + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + run: | + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8038 + DEVICE_PORT * 100)) + FD_INFERENCE_MSG_QUEUE_ID=$(( 8048 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo "FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo "FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + if [ ! 
-d "${MODEL_CACHE_DIR}" ]; then + echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist." + exit 1 + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." + if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + + docker run --rm --net=host \ + --name ${runner_name} \ + -v $(pwd):/workspace \ + -w /workspace \ + -e fastdeploy_wheel_url=${fastdeploy_wheel_url} \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -e "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -v "${MODEL_CACHE_DIR}:/MODELDATA" \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -e TZ="Asia/Shanghai" \ + --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install ${fastdeploy_wheel_url} + python -m pip install pytest + + 
git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + TEST_EXIT_CODE=0 + pushd tests/ce/stable_cases + bash launch_model.sh /MODELDATA + + TEST_EXIT_CODE=0 + bash run.sh || { + TEST_EXIT_CODE=1 + echo "==================== run.sh FAILED ====================" + + if [ -d log ]; then + echo ">>> grep error in ./log/" + grep -Rni --color=auto "error" log || true + else + echo "log/ directory not found" + fi + + if [ -f log/workerlog.0 ]; then + echo ">>> tail -n 100 log/workerlog.0" + tail -n 100 log/workerlog.0 + else + echo "log/workerlog.0 not found" + fi + + echo "=======================================================" + } + + popd + echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env + ' + if [ -f ./FastDeploy/exit_code.env ]; then + source ./FastDeploy/exit_code.env + cat ./FastDeploy/exit_code.env >> $GITHUB_ENV + fi + echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" + exit ${TEST_EXIT_CODE} diff --git a/.github/workflows/_unit_test_coverage.yml b/.github/workflows/_unit_test_coverage.yml new file mode 100644 index 00000000000..8b2d0272bc8 --- /dev/null +++ b/.github/workflows/_unit_test_coverage.yml @@ -0,0 +1,415 @@ +name: Coverage Check +description: "Run FastDeploy Unit Tests and Coverage" + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the FastDeploy Wheel." 
+ required: true + type: string + CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + MODEL_CACHE_DIR: + description: "Cache Dir Use" + required: false + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_cov_skip: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: coverage + + run_tests_with_coverage: + runs-on: [self-hosted, GPU-h1z1-2Cards] + timeout-minutes: 105 + needs: check_cov_skip + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_cov_skip.outputs.can-skip != 'true' }} + outputs: + all_cov_file_url: ${{ steps.cov_upload.outputs.all_cov_file_url }} + unittest_failed_url: ${{ steps.cov_upload.outputs.unittest_failed_url }} + diff_cov_result_json_url: ${{ steps.cov_upload.outputs.diff_cov_result_json_url }} + steps: + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run FastDeploy Unit Tests and Coverage + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + CACHE_DIR: ${{ inputs.CACHE_DIR }} + BASE_REF: ${{ github.event.pull_request.base.ref }} + MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }} + IS_PR: ${{ github.event_name == 'pull_request' }} + run: | + if [[ "$IS_PR" == "true" ]]; then + echo "Running on PR" + else + echo "Not a PR" + fi + runner_name="${{ runner.name }}" + CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}') + DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,) + DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1) + + FLASK_PORT=$((8068 + DEVICE_PORT * 100)) + FD_API_PORT=$((8088 + DEVICE_PORT * 100)) + FD_ENGINE_QUEUE_PORT=$((8058 + DEVICE_PORT * 100)) + FD_METRICS_PORT=$((8078 + DEVICE_PORT * 100)) + FD_CACHE_QUEUE_PORT=$((8098 + DEVICE_PORT * 100)) + FD_ROUTER_PORT=$((8048 + DEVICE_PORT * 100)) + FD_CONNECTOR_PORT=$((8038 + DEVICE_PORT * 100)) + FD_RDMA_PORT=$((8028 + DEVICE_PORT * 100)) + echo "Test ENV Parameter:" + echo "=========================================================" + echo "FLASK_PORT=${FLASK_PORT}" + echo "FD_API_PORT=${FD_API_PORT}" + echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" + echo "FD_METRICS_PORT=${FD_METRICS_PORT}" + echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" + echo "FD_ROUTER_PORT=${FD_ROUTER_PORT}" + echo "FD_CONNECTOR_PORT=${FD_CONNECTOR_PORT}" + 
echo "FD_RDMA_PORT=${FD_RDMA_PORT}" + echo "DEVICES=${DEVICES}" + echo "=========================================================" + + CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}" + echo "CACHE_DIR is set to ${CACHE_DIR}" + if [ ! -f "${CACHE_DIR}/gitconfig" ]; then + touch "${CACHE_DIR}/gitconfig" + fi + + PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT) + LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log" + echo "==== LOG_FILE is ${LOG_FILE} ====" + + echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE + + for port in "${PORTS[@]}"; do + PIDS=$(lsof -t -i :$port || true) + if [ -n "$PIDS" ]; then + echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE + echo "$PIDS" | xargs -r kill -9 + echo "Port $port cleared" | tee -a $LOG_FILE + else + echo "Port $port is free" | tee -a $LOG_FILE + fi + done + + echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE + + echo "=========================================================" + echo "Ensuring no stale container named ${runner_name} ..." 
+ if [ "$(docker ps -a -q -f name=${runner_name})" ]; then + echo "Removing stale container: ${runner_name}" + docker rm -f ${runner_name} || true + fi + + export RDMA_DEVICES=$(find /dev/infiniband/uverbs* -maxdepth 1 -not -type d | xargs -I{} echo '--device {}:{}') + + docker run --rm --net=host \ + --sysctl kernel.msgmax=1048576 \ + --sysctl kernel.msgmnb=268435456 \ + --name ${runner_name} \ + --cap-add=SYS_PTRACE --cap-add=IPC_LOCK \ + --shm-size=64G \ + ${RDMA_DEVICES} \ + --device=/dev/infiniband/rdma_cm \ + --ulimit memlock=-1:-1 \ + -v $(pwd):/workspace -w /workspace \ + -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ + -v "${CACHE_DIR}/.cache:/root/.cache" \ + -v "${CACHE_DIR}/ConfigDir:/root/.config" \ + -v "${MODEL_CACHE_DIR}:/ModelData:ro" \ + -e "MODEL_PATH=/ModelData" \ + -e "FD_API_PORT=${FD_API_PORT}" \ + -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ + -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ + -e "FLASK_PORT=${FLASK_PORT}" \ + -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \ + -e "FD_ROUTER_PORT=${FD_ROUTER_PORT}" \ + -e "FD_CONNECTOR_PORT=${FD_CONNECTOR_PORT}" \ + -e "FD_RDMA_PORT=${FD_RDMA_PORT}" \ + -e "CLEAN_CUDA=1" \ + -e TZ="Asia/Shanghai" \ + -e "fd_wheel_url=${fd_wheel_url}" \ + -e "BASE_REF=${BASE_REF}" \ + -e "IS_PR=${IS_PR}" \ + --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' + + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt + python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + + python -m pip install -r scripts/unittest_requirement.txt + python -m pip install ${fd_wheel_url} + rm -rf fastdeploy + # coverage subprocess use + python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy + export PYTHONPATH=/workspace/FastDeploy/ + if [ -d "tests/plugins" ]; then + 
cd tests/plugins + python setup.py install + cd ../.. + else + echo "Warning: tests/plugins directory not found, skipping setup.py install" + fi + export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage + export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc + TEST_EXIT_CODE=0 + bash scripts/coverage_run.sh || TEST_EXIT_CODE=8 + echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env + coverage combine coveragedata/ || echo "No data to combine" + coverage report + coverage xml -o python_coverage_all.xml + COVERAGE_EXIT_CODE=0 + if [[ "$IS_PR" == "true" ]]; then + echo "Running diff coverage for PR..." + diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9 + # python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml + else + echo "Running full coverage" + coverage report -m > full_coverage_report.txt + python scripts/generate_full_coverage_csv.py full_coverage_report.txt full_coverage_report.csv + fi + echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env + ' + if [ -f FastDeploy/exit_code.env ]; then + cat FastDeploy/exit_code.env >> $GITHUB_ENV + fi + - name: Upload coverage and unit test results to BOS + id: cov_upload + shell: bash + env: + IS_PR: ${{ github.event_name == 'pull_request' }} + GITHUB_SHA: ${{ github.sha }} + BRANCH: ${{ github.ref_name }} + PR_COMMIT_SHA: ${{ github.event.pull_request.head.sha }} + PR_NUMBER: ${{ github.event.pull_request.number }} + run: | + cd FastDeploy + python -m pip install -q bce-python-sdk==0.9.29 + wget -q --no-proxy --no-check-certificate \ + https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py \ + -O bos_tools.py + push_file=$(realpath bos_tools.py) + + if [[ "$IS_PR" == "true" ]]; then + commit_id=${PR_COMMIT_SHA} + pr_num=${PR_NUMBER} + target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_} + elif [[ "${{ 
github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/SM${compile_arch//,/_} + target_path_latest=paddle-github-action/TAG/FastDeploy/${tag_name}/latest/SM${compile_arch//,/_} + target_path_stripped_latest="${target_path_latest#paddle-github-action/}" + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/SM${compile_arch//,/_} + target_path_latest=paddle-github-action/BRANCH/FastDeploy/${branch_name}/latest/SM${compile_arch//,/_} + target_path_stripped_latest="${target_path_latest#paddle-github-action/}" + fi + + target_path_stripped="${target_path#paddle-github-action/}" + + all_coverage_file="python_coverage_all.xml" + if [ -f ${all_coverage_file} ]; then + python ${push_file} ${all_coverage_file} ${target_path}/CoverageData + ALL_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${all_coverage_file} + echo "all_cov_file_url=${ALL_COV_FILE_URL}" >> $GITHUB_OUTPUT + echo "all_cov_file_url=${ALL_COV_FILE_URL}" >> $GITHUB_ENV + fi + + if [[ "$IS_PR" == "true" ]]; then + diff_cov_result_json="diff_coverage.json" + if [ -f ${diff_cov_result_json} ]; then + python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData + DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json} + echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT + echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV + fi + fi + + HAS_FAILED_TESTS=false + unittest_result="failed_tests.log" + if [ -s ${unittest_result} ]; then + HAS_FAILED_TESTS=true + python ${push_file} ${unittest_result} ${target_path}/UnitTestResult + 
UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
+          echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
+          echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
+        fi
+
+        if [[ "$IS_PR" != "true" ]]; then
+          full_cov_file="full_coverage_report.txt"
+          full_cov_csv="full_coverage_report.csv"
+
+          if [ -f ${full_cov_file} ]; then
+            python ${push_file} ${full_cov_file} ${target_path}/CoverageData
+            python ${push_file} ${full_cov_file} ${target_path_latest}/CoverageData
+            FULL_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${full_cov_file}
+            echo "full_coverage_report_url=${FULL_COV_FILE_URL}" >> $GITHUB_OUTPUT
+            echo "full_coverage_report_url=${FULL_COV_FILE_URL}" >> $GITHUB_ENV
+          fi
+
+          if [ "$HAS_FAILED_TESTS" = false ] && [ -f ${full_cov_csv} ]; then
+            python ${push_file} ${full_cov_csv} ${target_path}/CoverageData
+            python ${push_file} ${full_cov_csv} ${target_path_latest}/CoverageData
+            FULL_COV_CSV_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${full_cov_csv}
+            echo "full_coverage_csv_url=${FULL_COV_CSV_URL}" >> $GITHUB_OUTPUT
+            echo "full_coverage_csv_url=${FULL_COV_CSV_URL}" >> $GITHUB_ENV
+          fi
+        fi
+    - name: Check Unit Test Success
+      shell: bash
+      run: |
+        cd FastDeploy
+        if [ "$TEST_EXIT_CODE" -eq 8 ]; then
+          filename=$(basename "$unittest_failed_url")
+          if [ -z "${unittest_failed_url}" ]; then
+            echo "No diff unit failed file URL provided."
+          else
+            rm -rf "${filename}"
+            wget -O ${filename} ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
+          fi
+          echo "Unit tests failed (exit code 8)"
+          if [ -f "${filename}" ];then
+            echo "Failed test cases:"
+            cat "${filename}"
+          fi
+          exit "$TEST_EXIT_CODE"
+        fi
+        echo "All tests passed"
+
+    - name: Verify Code Coverage Threshold (80%)
+      if: ${{ github.event_name == 'pull_request' }}
+      shell: bash
+      run: |
+        cd FastDeploy
+        if [ "$COVERAGE_EXIT_CODE" -eq 9 ]; then
+          echo "Coverage generation failed (exit code 9)"
+          filename=$(basename "$diff_cov_result_json_url")
+          if [ -z "${diff_cov_result_json_url}" ]; then
+            echo "No diff cov result file URL provided."
+          else
+            rm -rf "${filename}"
+            wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
+          fi
+          if [ -f "${filename}" ];then
+            echo "Failed test cases:"
+            if command -v jq >/dev/null 2>&1; then
+              jq . "${filename}"
+            else
+              cat "${filename}"
+            fi
+          fi
+          exit "$COVERAGE_EXIT_CODE"
+        fi
+        echo "coverage passed"
+        exit 0
+
+  diff_coverage_report:
+    needs: run_tests_with_coverage
+    if: always()
+    runs-on: ubuntu-latest
+    timeout-minutes: 15
+    env:
+      all_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.all_cov_file_url }}
+    steps:
+      - name: Clone FastDeploy
+        uses: actions/checkout@v6
+        with:
+          fetch-depth: 0
+      - uses: actions/setup-python@v6
+        with:
+          python-version: '3.10'
+      - name: Download diff coverage file
+        shell: bash
+        run: |
+          echo "Downloading all coverage file..."
+          if ! wget --no-proxy "${all_cov_file_url}" -O python_coverage_all.xml; then
+            echo "Download failed, skipping upload."
+ exit 0 + fi + + sed -i 's|/workspace/FastDeploy/fastdeploy|fastdeploy|' python_coverage_all.xml + + - name: Upload diff coverage report + if: always() && hashFiles('python_coverage_all.xml') != '' + uses: codecov/codecov-action@v6 + with: + files: ./python_coverage_all.xml + flags: GPU + name: python diff coverage + fail_ci_if_error: false + verbose: true + disable_search: true diff --git a/.github/workflows/_xpu_4cards_case_test.yml b/.github/workflows/_xpu_4cards_case_test.yml new file mode 100644 index 00000000000..f3c97f40dc6 --- /dev/null +++ b/.github/workflows/_xpu_4cards_case_test.yml @@ -0,0 +1,221 @@ +name: xpu_4cards_case_test + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:ci" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the compressed FastDeploy whl ." 
+ required: true + type: string + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + MODEL_PATH: + description: "MODEL Dir Use" + required: true + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: xpu_4cards_test + + run_xpu_4cards_cases: + runs-on: [self-hosted, XPU-P800-4Cards] + needs: check_bypass + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 60 + steps: + - name: Print current runner name + run: | + echo "Current runner name: ${{ runner.name }}" + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + model_path: ${{ inputs.MODEL_PATH }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} || true + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run CI unittest + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + model_path: ${{ inputs.MODEL_PATH }} + run: | + runner_name="${{ runner.name }}" + last_char="${runner_name: -1}" + + if [[ "$last_char" == "1" ]]; then + xpu_id="4" + else + xpu_id="0" + fi + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "/ssd3:/ssd3" \ + -e "MODEL_PATH=${model_path}" \ + -e "FASTDEPLOY_ARCHIVE_URL=${fd_archive_url}" \ + -e "FASTDEPLOY_WHEEL_URL=${fd_wheel_url}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "http_proxy=$(git config --global --get http.proxy)" \ + -e "https_proxy=$(git config --global --get https.proxy)" \ + -e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \ + -e "XPU_ID=${xpu_id}" \ + ${docker_image} /bin/bash -c ' + echo "安装lsof工具..." + apt install -y lsof + + # 设置XPU_VISIBLE_DEVICES + if [[ "$XPU_ID" == "0" ]]; then + export XPU_VISIBLE_DEVICES="0,1,2,3" + else + export XPU_VISIBLE_DEVICES="4,5,6,7" + fi + echo "XPU_VISIBLE_DEVICES=$XPU_VISIBLE_DEVICES" + + # 下载和安装xre + echo "下载和安装xre..." + mkdir -p /workspace/deps + cd /workspace/deps + if [ ! 
-d "xre" ]; then + wget -q https://klx-sdk-release-public.su.bcebos.com/xre/kl3-release/5.0.21.21/xre-Linux-x86_64-5.0.21.21.tar.gz + tar -zxf xre-Linux-x86_64-5.0.21.21.tar.gz && mv xre-Linux-x86_64-5.0.21.21 xre + fi + cd - + export PATH=/workspace/deps/xre/bin:$PATH + + # 重启XPU卡 + echo "重启XPU卡..." + xpu-smi -r -i $XPU_VISIBLE_DEVICES + xpu-smi + set -e + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + python -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + python -m pip install -r requirements.txt + echo "安装PaddlePaddle..." + # 针对不同分支和tag使用不同的PaddlePaddle安装包 + if [[ "${PADDLE_WHL_URL}" != "" ]];then + python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y + python -m pip install ${PADDLE_WHL_URL} + elif [[ "${PADDLEVERSION}" != "" ]];then + python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y + python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ + else + python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y + python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + fi + echo "安装上游任务编译的fastdeploy-xpu..." + python -m pip install ${FASTDEPLOY_WHEEL_URL} + rm -rf fastdeploy + python -m pip install ${FASTDEPLOY_WHEEL_URL} --no-deps --target=/workspace/FastDeploy + echo "============================安装测试依赖============================" + python -m pip install openai -U + python -m pip install pytest + python -m pip install pytest-timeout + unset http_proxy + unset https_proxy + echo "============================开始运行pytest测试============================" + export PYTHONPATH=/workspace/FastDeploy/ + export PYTHONPATH=$(pwd)/tests/xpu_ci:$PYTHONPATH + mkdir -p case_logs + set +e + python -m pytest -v -s --tb=short tests/xpu_ci/4cards_cases/ + exit_code=$? 
+ set -e + + # 修改case_logs权限,确保Docker外部的runner用户可以读取并上传 + chmod -R a+rX case_logs/ 2>/dev/null || true + + if [ $exit_code -eq 0 ]; then + echo "============================4卡cases测试通过!============================" + exit $exit_code + else + echo "============================4卡cases测试失败,请检查日志!============================" + exit $exit_code + fi + ' + + - name: Upload case logs + if: always() + uses: actions/upload-artifact@v6 + with: + name: xpu-4cards-case-logs + path: FastDeploy/case_logs/ + retention-days: 7 + if-no-files-found: ignore diff --git a/.github/workflows/_xpu_8cards_case_test.yml b/.github/workflows/_xpu_8cards_case_test.yml new file mode 100644 index 00000000000..c9ed0fa2314 --- /dev/null +++ b/.github/workflows/_xpu_8cards_case_test.yml @@ -0,0 +1,209 @@ +name: xpu_8cards_case_test + +on: + workflow_call: + inputs: + DOCKER_IMAGE: + description: "Build Images" + required: true + type: string + default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:ci" + FASTDEPLOY_ARCHIVE_URL: + description: "URL of the compressed FastDeploy code archive." + required: true + type: string + FASTDEPLOY_WHEEL_URL: + description: "URL of the compressed FastDeploy whl ." 
+ required: true + type: string + FD_VERSION: + description: "FastDeploy Package Version" + required: false + type: string + default: "" + PADDLEVERSION: + description: "Paddle Version Build Use" + required: false + type: string + default: "" + PADDLE_WHL_URL: + description: "Paddle Wheel Package URL" + required: false + type: string + default: "" + MODEL_PATH: + description: "MODEL Dir Use" + required: true + type: string + default: "" + secrets: + github-token: + required: true + +jobs: + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.github-token }} + with: + workflow-name: xpu_8cards_test + + run_xpu_8cards_cases: + runs-on: [self-hosted, XPU-P800-8Cards] + needs: check_bypass + if: ${{ inputs.FASTDEPLOY_WHEEL_URL != '' && needs.check_bypass.outputs.can-skip != 'true' }} + timeout-minutes: 60 + steps: + - name: Print current runner name + run: | + echo "Current runner name: ${{ runner.name }}" + - name: Code Prepare + shell: bash + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + model_path: ${{ inputs.MODEL_PATH }} + run: | + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} || true + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + CLEAN_RETRIES=3 + CLEAN_COUNT=0 + + while [ $CLEAN_COUNT -lt $CLEAN_RETRIES ]; do + echo "Attempt $((CLEAN_COUNT+1)) to remove ${REPO_NAME}* ..." + rm -rf "${REPO_NAME}"* || true + sleep 2 + + # Check if anything matching ${REPO_NAME}* still exists + if ! 
ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "All ${REPO_NAME}* removed successfully" + break + fi + + CLEAN_COUNT=$((CLEAN_COUNT + 1)) + done + + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" + ls -ld "${REPO_NAME}"* + exit 1 + fi + ' + + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + - name: Run CI unittest + env: + docker_image: ${{ inputs.DOCKER_IMAGE }} + fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} + fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }} + model_path: ${{ inputs.MODEL_PATH }} + run: | + runner_name="${{ runner.name }}" + last_char="${runner_name: -1}" + + PARENT_DIR=$(dirname "$WORKSPACE") + echo "PARENT_DIR:$PARENT_DIR" + docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "/ssd3:/ssd3" \ + -e "MODEL_PATH=${model_path}" \ + -e "FASTDEPLOY_ARCHIVE_URL=${fd_archive_url}" \ + -e "FASTDEPLOY_WHEEL_URL=${fd_wheel_url}" \ + -e "PADDLEVERSION=${PADDLEVERSION}" \ + -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ + -e "http_proxy=$(git config --global --get http.proxy)" \ + -e "https_proxy=$(git config --global --get https.proxy)" \ + -e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \ + ${docker_image} /bin/bash -c ' + echo "安装lsof工具..." + apt install -y lsof + + # 设置XPU_VISIBLE_DEVICES + export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + echo "XPU_VISIBLE_DEVICES=$XPU_VISIBLE_DEVICES" + + # 下载和安装xre + echo "下载和安装xre..." + mkdir -p /workspace/deps + cd /workspace/deps + if [ ! 
-d "xre" ]; then + wget -q https://klx-sdk-release-public.su.bcebos.com/xre/kl3-release/5.0.21.21/xre-Linux-x86_64-5.0.21.21.tar.gz + tar -zxf xre-Linux-x86_64-5.0.21.21.tar.gz && mv xre-Linux-x86_64-5.0.21.21 xre + fi + cd - + export PATH=/workspace/deps/xre/bin:$PATH + + # 重启XPU卡 + echo "重启XPU卡..." + xpu-smi -r -i $XPU_VISIBLE_DEVICES + xpu-smi + set -e + git config --global --add safe.directory /workspace/FastDeploy + cd FastDeploy + python -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + python -m pip install -r requirements.txt + echo "安装PaddlePaddle..." + # 针对不同分支和tag使用不同的PaddlePaddle安装包 + if [[ "${PADDLE_WHL_URL}" != "" ]];then + python -m pip install ${PADDLE_WHL_URL} + elif [[ "${PADDLEVERSION}" != "" ]];then + python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y + python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ + else + python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y + python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + fi + echo "安装上游任务编译的fastdeploy-xpu..." + python -m pip install ${FASTDEPLOY_WHEEL_URL} + rm -rf fastdeploy + python -m pip install ${FASTDEPLOY_WHEEL_URL} --no-deps --target=/workspace/FastDeploy + echo "============================安装测试依赖============================" + python -m pip install openai -U + python -m pip install pytest + python -m pip install pytest-timeout + unset http_proxy + unset https_proxy + echo "============================开始运行pytest测试============================" + export PYTHONPATH=/workspace/FastDeploy/ + export PYTHONPATH=$(pwd)/tests/xpu_ci:$PYTHONPATH + mkdir -p case_logs + set +e + python -m pytest -v -s --tb=short tests/xpu_ci/8cards_cases/ + exit_code=$? 
+ set -e + + # 修改case_logs权限,确保Docker外部的runner用户可以读取并上传 + chmod -R a+rX case_logs/ 2>/dev/null || true + + if [ $exit_code -eq 0 ]; then + echo "============================8卡cases测试通过!============================" + else + echo "============================8卡cases测试失败,请检查日志!============================" + exit $exit_code + fi + ' + + - name: Upload case logs + if: always() + uses: actions/upload-artifact@v6 + with: + name: xpu-8cards-case-logs + path: FastDeploy/case_logs/ + retention-days: 7 + if-no-files-found: ignore diff --git a/.github/workflows/approve.yml b/.github/workflows/approve.yml new file mode 100644 index 00000000000..6de30d6f564 --- /dev/null +++ b/.github/workflows/approve.yml @@ -0,0 +1,42 @@ +name: Approval + +on: + pull_request: + branches: + - develop + - 'release/*' + +env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + +jobs: + Approval: + name: Approval + if: ${{ github.repository_owner == 'PaddlePaddle' }} + runs-on: ubuntu-latest + env: + PR_ID: ${{ github.event.pull_request.number }} + BRANCH: ${{ github.event.pull_request.base.ref }} + steps: + - name: Checkout base repo + uses: actions/checkout@v6 + with: + ref: ${{ github.event.pull_request.base.ref }} + fetch-depth: 1000 + + - name: Merge PR to test branch + run: | + git fetch origin pull/${PR_ID}/merge + git checkout -b test FETCH_HEAD + git log -n 3 --oneline + git remote add upstream https://github.com/PaddlePaddle/FastDeploy.git + git fetch upstream $BRANCH + + - name: Setup python3.10 + uses: actions/setup-python@v6 + with: + python-version: '3.10' + + - name: Run approval check script + run: | + bash scripts/check_approval.sh diff --git a/.github/workflows/cancel_ci_iluvatar.yml b/.github/workflows/cancel_ci_iluvatar.yml new file mode 100644 index 00000000000..9dba9a7d1e0 --- /dev/null +++ b/.github/workflows/cancel_ci_iluvatar.yml @@ -0,0 +1,20 @@ +name: ILUVATAR-CI + +on: + pull_request: + types: [closed] + branches: [develop, release/**] +permissions: read-all + +concurrency: 
+ group: ${{ github.event.pull_request.number }}-${{ github.workflow }} + cancel-in-progress: true + +jobs: + cancel: + name: Cancel ILUVATAR-CI for ${{ github.event.pull_request.number }} + runs-on: ubuntu-latest + steps: + - name: Cancel ILUVATAR-CI + run: | + exit 0 diff --git a/.github/workflows/cancel_ci_xpu.yml b/.github/workflows/cancel_ci_xpu.yml new file mode 100644 index 00000000000..befd59796e9 --- /dev/null +++ b/.github/workflows/cancel_ci_xpu.yml @@ -0,0 +1,20 @@ +name: CI_XPU + +on: + pull_request: + types: [closed] + branches: [develop, release/**] +permissions: read-all + +concurrency: + group: ${{ github.event.pull_request.number }}-${{ github.workflow }} + cancel-in-progress: true + +jobs: + cancel: + name: Cancel CI_XPU for ${{ github.event.pull_request.number }} + runs-on: ubuntu-latest + steps: + - name: Cancel CI_XPU + run: | + exit 0 diff --git a/.github/workflows/cancel_pr_build_and_test.yml b/.github/workflows/cancel_pr_build_and_test.yml new file mode 100644 index 00000000000..bb488a529ea --- /dev/null +++ b/.github/workflows/cancel_pr_build_and_test.yml @@ -0,0 +1,19 @@ +name: PR Build and Test +on: + pull_request: + types: [closed] + branches: [develop, release/**] +permissions: read-all + +concurrency: + group: ${{ github.event.pull_request.number }}-${{ github.workflow }} + cancel-in-progress: true + +jobs: + cancel: + name: Cancel PR Build and Test for ${{ github.event.pull_request.number }} + runs-on: ubuntu-latest + steps: + - name: Cancel PR Build and Test + run: | + exit 0 diff --git a/.github/workflows/ce_job.yml b/.github/workflows/ce_job.yml new file mode 100644 index 00000000000..5b20eccdf2e --- /dev/null +++ b/.github/workflows/ce_job.yml @@ -0,0 +1,350 @@ +name: CE Compile Job + +on: + workflow_dispatch: + push: + branches: + - develop + - 'release/*' +permissions: read-all + +concurrency: + group: CE-Job-${{ github.ref }}-${{ github.sha }} + cancel-in-progress: true + +jobs: + ce_job_pre_check: + runs-on: ubuntu-latest + 
env: + COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }} + CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }} + COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }} + outputs: + branch_match: ${{ steps.set_output.outputs.branch_match }} + compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }} + sm8689_match: ${{ steps.set_output.outputs.sm8689_match }} + sm8090_match: ${{ steps.set_output.outputs.sm8090_match }} + + steps: + - name: Set Version + id: set_output + env: + COMPILE_BRANCH: ${{ env.COMPILE_BRANCH }} + CE_COMPILE_SELECTION: ${{ env.CE_COMPILE_SELECTION }} + COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ env.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }} + GITHUB_REF_NAME: ${{ github.ref_name }} + run: | + # 选择要触发编译任务的分支 done + # 选择指定分支要编译的任务 8090或者8689 + # 指定分支编译要使用的Paddle的安装包,默认使用nightly最新的 + + IFS=',' read -ra BRANCHES <<< "$COMPILE_BRANCH" + MATCH=false + for b in "${BRANCHES[@]}"; do + if [[ "$b" == "${GITHUB_REF_NAME}" ]]; then + MATCH=true + break + fi + done + echo "branch_match=$MATCH" >> $GITHUB_OUTPUT + + # 通过变量CE_COMPILE_SELECTION中的映射关系,决定分支是编译sm8090还是sm8689 + for pair in $(echo "$CE_COMPILE_SELECTION" | tr ';' ' '); do + branch=$(echo "$pair" | cut -d',' -f1) + compile_task_list=$(echo "$pair" | cut -d',' -f2) + + if [[ "$branch" == "$GITHUB_REF_NAME" ]]; then + + # 判断里面是否包含 sm8090 或 sm8689 + if [[ "$compile_task_list" == *"sm8090"* ]]; then + echo "sm8090_match=true" >> $GITHUB_OUTPUT + fi + if [[ "$compile_task_list" == *"sm8689"* ]]; then + echo "sm8689_match=true" >> $GITHUB_OUTPUT + fi + break + fi + done + + # 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL + for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do + branch=$(echo "$pair" | cut -d',' -f1) + paddle_whl_url=$(echo "$pair" | cut -d',' -f2) + if [[ "$branch" == "${{ github.ref_name }}" ]]; then + FOUND_PADDLE_URL="$paddle_whl_url" + echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" 
>> $GITHUB_OUTPUT + break + fi + done + + print_ce_job_pre_check_outputs: + runs-on: ubuntu-latest + needs: ce_job_pre_check + steps: + - name: Print outputs as JSON + run: | + echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}' + + + clone: + environment: CodeSync + name: FD-Clone-Linux + runs-on: ubuntu-latest + needs: ce_job_pre_check + if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }} + outputs: + repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} + steps: + - name: Clone FastDeploy + uses: actions/checkout@v6 + with: + ref: ${{ github.event_name == 'pull_request' + && github.event.pull_request.base.ref + || github.ref_name }} + submodules: 'recursive' + fetch-depth: 1000 + + - name: Python Setup + uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Code Info Show and Upload + id: set_output + env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + run: | + git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'" + echo "Current HEAD Log:" + git log --oneline -n 5 + ls + cd .. 
+          tar -zcf FastDeploy.tar.gz FastDeploy
+          commit_id=${{ github.sha }}
+          branch_name=${{ github.ref_name }}
+
+          wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
+          push_file=$(realpath bos_tools.py)
+          python -m pip install bce-python-sdk==0.9.29
+
+          filename="FastDeploy.tar.gz"
+          target_paths=(
+            "paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}"
+            "paddle-qa/BRANCH/FastDeploy/${branch_name}/latest"
+          )
+
+          for target_path in "${target_paths[@]}"; do
+            echo "Uploading ${filename} to ${target_path}"
+            python "${push_file}" "${filename}" "${target_path}"
+          done
+
+          base_prefix="paddle-qa/"
+
+          commit_path_stripped="${target_paths[0]#${base_prefix}}"
+          latest_path_stripped="${target_paths[1]#${base_prefix}}"
+
+          REPO_ARCHIVE_URL="https://paddle-qa.bj.bcebos.com/${commit_path_stripped}/${filename}"
+          LATEST_REPO_ARCHIVE_URL="https://paddle-qa.bj.bcebos.com/${latest_path_stripped}/${filename}"
+
+          echo "commit archive url is ${REPO_ARCHIVE_URL}"
+          echo "latest archive url is ${LATEST_REPO_ARCHIVE_URL}"
+
+          echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
+
+  resultshow:
+    name: Show Code Archive Output
+    needs: clone
+    runs-on: ubuntu-latest
+    steps:
+      - name: Print repo_archive_url path
+        run: |
+          echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
+
+  build_sm8090:
+    name: BUILD_SM8090
+    needs: [clone, ce_job_pre_check]
+    if: ${{ needs.ce_job_pre_check.outputs.sm8090_match == 'true' }}
+    uses: ./.github/workflows/_build_linux.yml
+    with:
+      DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda126-manylinux
+      FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
+      COMPILE_ARCH: "80,90"
+      WITH_NIGHTLY_BUILD: OFF
+      FD_VERSION: 0.0.0
+      PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+
build_sm8090_rl: + name: BUILD_SM8090_RL + needs: [clone, ce_job_pre_check] + if: ${{ needs.ce_job_pre_check.outputs.sm8090_match == 'true' }} + uses: ./.github/workflows/_build_linux_rl.yml + with: + DOCKER_IMAGE: iregistry.baidu-int.com/new_rl_infra/base-images:paddlecloud-ubuntu24.04-gcc13.3-cuda12.9-cudnn9.9-bccl1.4.1.4-nccl2.26.5-openmpi4.1.5-FleetY13.0.0-v2.4.0-rc1 + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "80,90" + WITH_NIGHTLY_BUILD: OFF + FD_VERSION: 0.0.0 + PADDLE_WHL_URL: https://paddle-qa.bj.bcebos.com/paddle-pipeline/Paddle-RL-Compile/develop/latest/paddlepaddle_gpu-3.3.0.dev-cp310-cp310-linux_x86_64.whl + + build_sm8689: + name: BUILD_SM8689 + needs: [clone, ce_job_pre_check] + if: ${{ needs.ce_job_pre_check.outputs.sm8689_match == 'true' }} + uses: ./.github/workflows/_build_linux.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda126-manylinux + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "86,89" + WITH_NIGHTLY_BUILD: OFF + FD_VERSION: 0.0.0 + PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }} + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + ce_upload_sm8090: + environment: CodeSync + name: CE_UPLOAD + needs: build_sm8090 + runs-on: ubuntu-latest + env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} + COMPILE_ARCH: "80,90" + steps: + - uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Wheel Info Show and Upload + run: | + echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}" + wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }} + filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }}) + + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + + wget -q --no-proxy --no-check-certificate 
https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
+ push_file=$(realpath bos_tools.py)
+ python -m pip install bce-python-sdk==0.9.29
+
+ target_paths=(
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_80/${branch_name}/${commit_id}"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_90/${branch_name}/${commit_id}"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_80/${branch_name}/latest"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_90/${branch_name}/latest"
+ )
+
+ for target_path in "${target_paths[@]}"; do
+ echo "Uploading ${filename} to ${target_path}"
+ python "${push_file}" "${filename}" "${target_path}"
+ done
+
+ base_prefix="paddle-qa/"
+ commit_path_stripped="${target_paths[0]#${base_prefix}}"
+ latest_path_stripped="${target_paths[3]#${base_prefix}}"
+ WHEEL_PATH="https://paddle-qa.bj.bcebos.com/${commit_path_stripped}/${filename}"
+ WHEEL_PATH_LATEST="https://paddle-qa.bj.bcebos.com/${latest_path_stripped}/${filename}"
+
+ echo "commit wheel url is ${WHEEL_PATH}"
+ echo "latest wheel url is ${WHEEL_PATH_LATEST}"
+
+ ce_upload_sm8090_rl:
+ environment: CodeSync
+ name: CE_UPLOAD_RL
+ needs: build_sm8090_rl
+ runs-on: ubuntu-latest
+ env:
+ AK: ${{ secrets.BOS_AK }}
+ SK: ${{ secrets.BOS_SK }}
+ FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090_rl.outputs.wheel_path_rl }}
+ COMPILE_ARCH: "80,90"
+ steps:
+ - uses: actions/setup-python@v6
+ with:
+ python-version: '3.10'
+ - name: Wheel Info Show and Upload
+ run: |
+ echo "The wheel is located at: ${{ needs.build_sm8090_rl.outputs.wheel_path_rl }}"
+ wget -q --no-check-certificate ${{ needs.build_sm8090_rl.outputs.wheel_path_rl }}
+ filename=$(basename ${{ needs.build_sm8090_rl.outputs.wheel_path_rl }})
+
+ commit_id=${{ github.sha }}
+ 
branch_name=${{ github.ref_name }}
+
+ wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
+ push_file=$(realpath bos_tools.py)
+ python -m pip install bce-python-sdk==0.9.29
+
+ target_paths=(
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE_RL/cu129/SM_8090/${branch_name}/${commit_id}"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE_RL/cu129/SM_8090/${branch_name}/latest"
+ )
+
+ for target_path in "${target_paths[@]}"; do
+ echo "Uploading ${filename} to ${target_path}"
+ python "${push_file}" "${filename}" "${target_path}"
+ done
+
+ base_prefix="paddle-qa/"
+ commit_path_stripped="${target_paths[0]#${base_prefix}}"
+ latest_path_stripped="${target_paths[1]#${base_prefix}}"
+ WHEEL_PATH="https://paddle-qa.bj.bcebos.com/${commit_path_stripped}/${filename}"
+ WHEEL_PATH_LATEST="https://paddle-qa.bj.bcebos.com/${latest_path_stripped}/${filename}"
+
+ echo "commit wheel url is ${WHEEL_PATH}"
+ echo "latest wheel url is ${WHEEL_PATH_LATEST}"
+
+ ce_upload_sm8689:
+ environment: CodeSync
+ name: CE_UPLOAD_SM8689
+ needs: build_sm8689
+ runs-on: ubuntu-latest
+ env:
+ AK: ${{ secrets.BOS_AK }}
+ SK: ${{ secrets.BOS_SK }}
+ FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
+ COMPILE_ARCH: "86,89"
+ steps:
+ - uses: actions/setup-python@v6
+ with:
+ python-version: '3.10'
+ - name: Wheel Info Show and Upload
+ run: |
+ echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
+ wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
+ filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
+
+ commit_id=${{ github.sha }}
+ branch_name=${{ github.ref_name }}
+
+ wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
+ push_file=$(realpath bos_tools.py)
+ python -m pip install bce-python-sdk==0.9.29
+
+ target_paths=(
+ 
"paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_86/${branch_name}/${commit_id}"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_89/${branch_name}/${commit_id}"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_86/${branch_name}/latest"
+ "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/cu126/SM_89/${branch_name}/latest"
+ )
+
+ for target_path in "${target_paths[@]}"; do
+ echo "Uploading ${filename} to ${target_path}"
+ python "${push_file}" "${filename}" "${target_path}"
+ done
+
+ base_prefix="paddle-qa/"
+ commit_path_stripped="${target_paths[0]#${base_prefix}}"
+ latest_path_stripped="${target_paths[3]#${base_prefix}}"
+ WHEEL_PATH="https://paddle-qa.bj.bcebos.com/${commit_path_stripped}/${filename}"
+ WHEEL_PATH_LATEST="https://paddle-qa.bj.bcebos.com/${latest_path_stripped}/${filename}"
+
+ echo "commit wheel url is ${WHEEL_PATH}"
+ echo "latest wheel url is ${WHEEL_PATH_LATEST}"
diff --git a/.github/workflows/check-bypass.yml b/.github/workflows/check-bypass.yml
new file mode 100644
index 00000000000..1d50c641675
--- /dev/null
+++ b/.github/workflows/check-bypass.yml
@@ -0,0 +1,103 @@
+on:
+ workflow_call:
+ inputs:
+ workflow-name:
+ required: true
+ type: string
+ secrets:
+ github-token:
+ required: true
+ outputs:
+ can-skip:
+ description: "Whether the workflow can be skipped."
+ value: ${{ jobs.check-bypass.outputs.can-skip }}
+ can-skip-docs:
+ description: "Whether the workflow can be skipped due to docs-only change."
+ value: ${{ jobs.check-bypass.outputs.can-skip-docs }} + +jobs: + check-bypass: + name: Check bypass + runs-on: ubuntu-latest + permissions: + contents: read + env: + CI_TEAM_MEMBERS: '["yuanlehome","YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen","EmmonsCurse","CSWYF3634076","plusNew001"]' + outputs: + can-skip: ${{ steps.final-output.outputs.can-skip }} + can-skip-docs: ${{ steps.final-output.outputs.can-skip-docs }} + steps: + - name: Cleanup + run: | + rm -rf * .[^.]* + + - id: check-bypass + name: Check Bypass + uses: PFCCLab/ci-bypass@v2 + with: + github-token: ${{ secrets.github-token }} + non-pull-request-event-strategy: 'never-skipped' + type: 'composite' + composite-rule: | + { + "any": [ + { + "type": "labeled", + "label": ["skip-ci: ${{ inputs.workflow-name }}", "skip-ci: all"], + "username": ${{ env.CI_TEAM_MEMBERS }} + }, + { + "type": "commented", + "comment-pattern": [".*/skip-ci ${{ inputs.workflow-name }}.*", ".*/skip-ci all.*"], + "username": ${{ env.CI_TEAM_MEMBERS }} + } + ] + } + + - id: check-only-docs + name: Check if only Docs files changed + env: + GITHUB_TOKEN: ${{ secrets.github-token }} + run: | + if [[ "${{ github.event_name }}" != "pull_request" ]]; then + echo "can-skip-docs=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + + files=$(gh pr view ${{ github.event.pull_request.number }} --repo ${{ github.repository }} --json files --jq '.files[].path') + echo "$files" + + can_skip_docs=true + for f in $files; do + if [[ ! 
"$f" =~ \.(md|txt|yaml|go)$ ]]; then + can_skip_docs=false + break + fi + done + + echo "can-skip-docs=$can_skip_docs" >> "$GITHUB_OUTPUT" + + - id: final-output + name: Final can-skip result + run: | + if [[ "${{ steps.check-only-docs.outputs['can-skip-docs'] }}" == "true" ]]; then + echo "can-skip=true" >> "$GITHUB_OUTPUT" + echo "can-skip-docs=true" >> "$GITHUB_OUTPUT" + exit 0 + fi + + if [[ "${{ steps.check-bypass.outputs['can-skip'] }}" == "true" ]]; then + echo "can-skip=true" >> "$GITHUB_OUTPUT" + echo "can-skip-docs=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + + echo "can-skip=false" >> "$GITHUB_OUTPUT" + echo "can-skip-docs=false" >> "$GITHUB_OUTPUT" + + - id: print-final-output + name: Print final result + run: | + echo "===== Final can-skip result =====" + echo "can-skip=${{ steps.final-output.outputs['can-skip'] }}" + echo "can-skip-docs=${{ steps.final-output.outputs['can-skip-docs'] }}" diff --git a/.github/workflows/cherry-pick.yml b/.github/workflows/cherry-pick.yml new file mode 100644 index 00000000000..c6e1bad992e --- /dev/null +++ b/.github/workflows/cherry-pick.yml @@ -0,0 +1,191 @@ +name: Cherry Pick + +on: + pull_request_target: + branches: [develop] + types: [closed, labeled] + +permissions: + contents: write + pull-requests: write + issues: write + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} + cancel-in-progress: false + +jobs: + cherry-pick: + if: > + github.event.pull_request.merged == true && + ( + github.event.action == 'labeled' || + contains(join(github.event.pull_request.labels.*.name, ' '), 'cherry-pick') + ) + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Cherry Pick + env: + GH_TOKEN: ${{ secrets.CHERRY_PICK_BOT_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PR_TITLE: ${{ github.event.pull_request.title }} + PR_BODY: ${{ github.event.pull_request.body }} + PR_AUTHOR: ${{ 
github.event.pull_request.user.login }} + MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }} + BOT_USERNAME: EmmonsCurse + REPO_NAME: EmmonsCurse/FastDeploy + run: | + # Function to post comment + post_comment() { + gh pr comment "$PR_NUMBER" --body "$1" + } + + # Configure git for the original author + echo "Fetching author info for $PR_AUTHOR..." + AUTHOR_INFO=$(gh api "/users/$PR_AUTHOR" --jq '{email: .email, name: .name}') + AUTHOR_EMAIL=$(echo "$AUTHOR_INFO" | jq -r '.email') + AUTHOR_NAME=$(echo "$AUTHOR_INFO" | jq -r '.name') + + if [ "$AUTHOR_EMAIL" = "null" ] || [ -z "$AUTHOR_EMAIL" ]; then + AUTHOR_EMAIL="${PR_AUTHOR}@users.noreply.github.com" + echo "Author email not found, using default: $AUTHOR_EMAIL" + fi + if [ "$AUTHOR_NAME" = "null" ] || [ -z "$AUTHOR_NAME" ]; then + AUTHOR_NAME="${PR_AUTHOR}" + echo "Author name not found, using username: $AUTHOR_NAME" + fi + + git config user.name "$AUTHOR_NAME" + git config user.email "$AUTHOR_EMAIL" + + # Capture current SHA to return to later + ORIGINAL_HEAD_SHA=$(git rev-parse HEAD) + + # Get labels + LABELS=$(gh pr view "$PR_NUMBER" --json labels --jq '.labels[].name') + + if [ -z "$LABELS" ]; then + echo "No labels found." + exit 0 + fi + + # Loop through labels + while read -r label; do + if [[ "$label" == cherry-pick:* ]]; then + TARGET_BRANCH=$(echo "${label#cherry-pick:}" | xargs) + + if [ -z "$TARGET_BRANCH" ]; then + echo "Empty target branch for label '$label', skipping." + continue + fi + + echo "Processing cherry-pick to $TARGET_BRANCH" + + # Check if target branch exists on remote + if ! git ls-remote --exit-code --heads origin "$TARGET_BRANCH"; then + echo "Target branch $TARGET_BRANCH does not exist." + post_comment "❌ Cherry-pick failed: Target branch \`$TARGET_BRANCH\` does not exist." 
+ continue + fi + + # Create a new branch for the cherry-pick + NEW_BRANCH="cherry-pick/$PR_NUMBER/$TARGET_BRANCH" + + # Clean up local branch if it exists (from previous run) + if git show-ref --verify --quiet "refs/heads/$NEW_BRANCH"; then + git branch -D "$NEW_BRANCH" + fi + + # Fetch the target branch and checkout a new branch from it + git fetch origin "$TARGET_BRANCH" + git checkout -b "$NEW_BRANCH" "origin/$TARGET_BRANCH" + + # Cherry pick + # Try standard cherry-pick first (for squash merges or single commits) + if git cherry-pick "$MERGE_COMMIT_SHA"; then + echo "Cherry-pick successful." + else + echo "Standard cherry-pick failed, trying with -m 1 (for merge commits)..." + git cherry-pick --abort + if git cherry-pick -m 1 "$MERGE_COMMIT_SHA"; then + echo "Cherry-pick with -m 1 successful." + else + echo "Cherry-pick failed." + git cherry-pick --abort + post_comment "❌ Cherry-pick failed: Conflicts detected when cherry-picking to \`$TARGET_BRANCH\`. Please resolve manually." + + # Cleanup + git checkout "$ORIGINAL_HEAD_SHA" + git branch -D "$NEW_BRANCH" + continue + fi + fi + + # Push + # Construct authenticated URL for the fork + FORK_URL_AUTH="https://${BOT_USERNAME}:${GH_TOKEN}@github.com/${REPO_NAME}.git" + + echo "Pushing to fork..." + git push "$FORK_URL_AUTH" "$NEW_BRANCH" --force + + # Create PR + # If PR_TITLE starts with "[", don't insert an extra space. + if [ "${PR_TITLE:0:1}" = "[" ]; then + NEW_TITLE="[Cherry-Pick]$PR_TITLE(#$PR_NUMBER)" + else + NEW_TITLE="[Cherry-Pick] $PR_TITLE(#$PR_NUMBER)" + fi + + NEW_BODY="Cherry-pick of #$PR_NUMBER (authored by @${PR_AUTHOR}) to \`$TARGET_BRANCH\`. 
+ + devPR:https://github.com/PaddlePaddle/FastDeploy/pull/$PR_NUMBER + + --- + + $PR_BODY" + # Remove leading whitespace + NEW_BODY=$(echo "$NEW_BODY" | sed 's/^[ \t]*//') + + # Prepare head ref for PR creation (owner:branch) + HEAD_REF="${BOT_USERNAME}:${NEW_BRANCH}" + + # Check if PR already exists + EXISTING_PR=$(gh pr list --base "$TARGET_BRANCH" --head "$NEW_BRANCH" --json url --jq '.[0].url') + + if [ -n "$EXISTING_PR" ]; then + echo "PR already exists: $EXISTING_PR" + post_comment "ℹ️ Cherry-pick PR already exists: $EXISTING_PR" + else + # Create PR using gh CLI, ignoring errors because of "Resource not accessible" false positives + gh pr create --base "$TARGET_BRANCH" --head "$HEAD_REF" --title "$NEW_TITLE" --body "$NEW_BODY" || true + + # Wait a bit for eventual consistency + sleep 2 + + # Search for the created PR + CREATED_PR_URL=$(gh pr list --head "$NEW_BRANCH" --state all --json url --jq '.[0].url') + + if [ -n "$CREATED_PR_URL" ]; then + echo "Created PR: $CREATED_PR_URL" + post_comment "✅ Cherry-pick successful! Created PR: $CREATED_PR_URL" + + # Request review + gh pr review-request "$CREATED_PR_URL" --reviewer "$PR_AUTHOR" || true + else + echo "Failed to create PR." + post_comment "❌ Cherry-pick failed: Could not create PR to \`$TARGET_BRANCH\`." 
+ continue + fi + fi + + # Cleanup for next loop + git checkout "$ORIGINAL_HEAD_SHA" + git branch -D "$NEW_BRANCH" + fi + done <<< "$LABELS" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci_hpu.yml similarity index 65% rename from .github/workflows/ci.yml rename to .github/workflows/ci_hpu.yml index 518b15eb998..1bf7d7e4708 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci_hpu.yml @@ -1,4 +1,4 @@ -name: CI +name: CI_HPU on: pull_request: @@ -8,23 +8,29 @@ on: workflow_dispatch: concurrency: - group: ${{ github.event.pull_request.number }} + group: ${{ github.event.pull_request.number }}-hpu-ci cancel-in-progress: true jobs: - build: - runs-on: [self-hosted, GPU-L20-4Card] + check_bypass: + uses: ./.github/workflows/check-bypass.yml + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + with: + workflow-name: ci_hpu + + CI_HPU: + runs-on: [self-hosted, HPU-8Card] + needs: check_bypass + if: ${{ needs.check_bypass.outputs.can-skip != 'true' }} steps: - name: Print current runner name run: | echo "Current runner name: ${{ runner.name }}" - # Because the system version is lower than 2.23, the checkout cannot be used. 
- # - name: Checkout code - # uses: actions/checkout@v4 - name: Code Checkout env: - docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126 + docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-hpu:latest run: | REPO="https://github.com/${{ github.repository }}.git" FULL_REPO="${{ github.repository }}" @@ -55,35 +61,34 @@ jobs: - name: Run CI unittest env: - docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126 + docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-hpu:latest run: | runner_name="${{ runner.name }}" last_char="${runner_name: -1}" - if [ "${last_char}" = "1" ]; then - gpu_id=2 - DEVICES="2,3" + if [[ "$last_char" =~ [0-3] ]]; then + hpu_id="$last_char" else - gpu_id=0 - DEVICES="0,1" + hpu_id="0" fi - FD_API_PORT=$((9180 + gpu_id * 100)) - FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100)) - FD_METRICS_PORT=$((9170 + gpu_id * 100)) + FD_API_PORT=8388 + FD_ENGINE_QUEUE_PORT=8902 + FD_METRICS_PORT=8202 PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" - docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ - -v "/ssd4/GithubActions/gitconfig:/etc/gitconfig:ro" \ - -v "/ssd4/GithubActions/ModelData:/ModelData:ro" \ - -v "/ssd4/GithubActions/CacheDir:/root/.cache" \ - -v "/ssd4/GithubActions/ConfigDir:/root/.config" \ - -e "MODEL_PATH=/ModelData" \ + docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + -v $(pwd):/workspace -w /workspace \ + -v "/ssd1:/ssd1" \ + -e "MODEL_PATH=/ssd1" \ + -e "http_proxy=$(git config --global --get http.proxy)" \ + -e "https_proxy=$(git config --global --get https.proxy)" \ + -e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,localhost,127.0.0.1,0.0.0.0,10.0.0.0/8,192.168.1.0/24" \ -e "FD_API_PORT=${FD_API_PORT}" \ -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ - --gpus '"device='"${DEVICES}"'"' 
${docker_image} /bin/bash -c " + ${docker_image} /bin/bash -c " git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy - bash scripts/run_ci.sh + bash scripts/run_ci_hpu.sh " diff --git a/.github/workflows/ci_iluvatar.yml b/.github/workflows/ci_iluvatar.yml new file mode 100644 index 00000000000..3c46306ba92 --- /dev/null +++ b/.github/workflows/ci_iluvatar.yml @@ -0,0 +1,27 @@ +name: ILUVATAR-CI +on: + pull_request: + types: [opened, synchronize] + branches: [develop, release/**] +permissions: read-all + +concurrency: + group: ${{ github.event.pull_request.number }}-${{ github.workflow }} + cancel-in-progress: true + +jobs: + clone: + name: FD-Clone-Linux-ILUVATAR + uses: ./.github/workflows/_clone_linux.yml + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + iluvatar_test: + name: Run iluvatar Tests + needs: [clone] + uses: ./.github/workflows/_iluvatar_cases.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-ixuca:3.3.0-20260312 + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ci_image_update.yml b/.github/workflows/ci_image_update.yml new file mode 100644 index 00000000000..762cad91023 --- /dev/null +++ b/.github/workflows/ci_image_update.yml @@ -0,0 +1,181 @@ +name: CI Images Build + +on: + workflow_dispatch: + schedule: + - cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8) + +permissions: read-all + +concurrency: + group: CI-Images-Build-${{ github.ref }}-${{ github.sha }} + cancel-in-progress: true + + +jobs: + clone: + environment: CodeSync + name: FD-Clone-Linux + runs-on: ubuntu-latest + outputs: + repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} + steps: + - name: Clone FastDeploy + uses: actions/checkout@v6 + with: + ref: ${{ github.ref_name }} + submodules: 'recursive' + fetch-depth: 1000 + + - name: Python Setup + uses: actions/setup-python@v6 + with: + 
python-version: '3.10' + - name: Code Info Show and Upload + id: set_output + env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + run: | + git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'" + echo "Current HEAD Log:" + git log --oneline -n 5 + ls + cd .. + tar -zcf FastDeploy.tar.gz FastDeploy + if [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id} + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id} + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python -m pip install bce-python-sdk==0.9.29 + ls + python ${push_file} FastDeploy.tar.gz ${target_path} + target_path_stripped="${target_path#paddle-qa/}" + REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz + echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT + + resultshow: + name: Show Code Archive Output + needs: clone + runs-on: ubuntu-latest + steps: + - name: Print wheel path + run: | + echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" + + ci_image_build: + name: CI Images Build + needs: clone + uses: ./.github/workflows/_ci_image_build.yml + with: + CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + + build_sm8090: + name: BUILD_SM8090 + needs: [clone, ci_image_build] + uses: ./.github/workflows/_build_linux.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda126-manylinux + FASTDEPLOY_ARCHIVE_URL: ${{ 
needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "90" + WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }} + FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }} + PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }} + PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }} + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + unittest_coverage: + name: Run FastDeploy Unit Tests and Coverage + needs: [clone,build_sm8090,ci_image_build] + uses: ./.github/workflows/_unit_test_coverage.yml + with: + DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }} + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + four_cards_test: + name: Run Four Cards Tests + needs: [clone,build_sm8090,ci_image_build] + uses: ./.github/workflows/_gpu_4cards_case_test.yml + with: + DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }} + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + logprob_test: + name: Run FastDeploy LogProb Tests + needs: [build_sm8090,ci_image_build] + uses: ./.github/workflows/_logprob_test_linux.yml + with: + DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }} + PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz" + FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + pre_ce_test: + name: Extracted partial CE model tasks to run in CI. 
+ needs: [clone,build_sm8090,ci_image_build] + uses: ./.github/workflows/_pre_ce_test.yml + with: + DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }} + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + base_test: + name: Run Base Tests + needs: [clone,build_sm8090,ci_image_build] + uses: ./.github/workflows/_base_test.yml + with: + DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }} + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + stable_test: + name: Run Stable Tests + needs: [clone,build_sm8090,ci_image_build] + uses: ./.github/workflows/_stable_test.yml + with: + DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }} + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + publish_pre_check: + name: Publish Docker Images Pre Check + needs: [ci_image_build,unittest_coverage,four_cards_test,logprob_test,pre_ce_test,base_test,stable_test] + runs-on: [self-hosted, Docker-Build] + steps: + - name: Images Uploading + env: + images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }} + ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev" + run: | + echo "images_name=${images_name}" + docker images ${ci_image_name} + docker tag ${images_name} ${ci_image_name} + docker push ${ci_image_name} diff --git a/.github/workflows/ci_metax.yml b/.github/workflows/ci_metax.yml 
new file mode 100644 index 00000000000..5584147eb8c --- /dev/null +++ b/.github/workflows/ci_metax.yml @@ -0,0 +1,34 @@ +name: CI_METAX + +on: + pull_request_target: + types: + - opened + - synchronize + branches: + - develop + - release/** + +permissions: + contents: read + +concurrency: + group: jenkins-pr-${{ github.event.pull_request.number }} + cancel-in-progress: true + +jobs: + trigger-jenkins: + name: Trigger Jenkins for PR + runs-on: ubuntu-latest + environment: Metax_ci + + steps: + - name: Trigger Jenkins job + timeout-minutes: 120 + uses: MetaX-MACA/simple-jenkins-githubaction@v1.1 + with: + job_name: paddle_fastdeploy_metax_smoketest + username: fastdeploy_builder + api_token: ${{ secrets.METAX_JENKINS_API_TOKEN }} + pr_number: ${{ github.event.pull_request.number }} + project_branch: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/ci_xpu.yml b/.github/workflows/ci_xpu.yml index 7bb267fd202..cf67385c24f 100644 --- a/.github/workflows/ci_xpu.yml +++ b/.github/workflows/ci_xpu.yml @@ -2,86 +2,51 @@ name: CI_XPU on: pull_request: - branches: - - develop - - 'release/*' - workflow_dispatch: + types: [opened, synchronize] + branches: [develop, release/**] +permissions: read-all concurrency: - group: ${{ github.event.pull_request.number }}-xpu-ci + group: ${{ github.event.pull_request.number }}-${{ github.workflow }} cancel-in-progress: true jobs: - CI_XPU: - runs-on: [self-hosted, XPU-P800-8Card] - steps: - - name: Print current runner name - run: | - echo "Current runner name: ${{ runner.name }}" - # Because the system version is lower than 2.23, the checkout cannot be used. 
- # - name: Checkout code - # uses: actions/checkout@v4 + clone: + name: FD-Clone-Linux-XPU + uses: ./.github/workflows/_clone_linux.yml + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Code Checkout - env: - docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0 - run: | - REPO="https://github.com/${{ github.repository }}.git" - FULL_REPO="${{ github.repository }}" - REPO_NAME="${FULL_REPO##*/}" - BASE_BRANCH="${{ github.base_ref }}" - # Clean the repository directory before starting - docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ - -e "REPO_NAME=${REPO_NAME}" \ - -e "BASE_BRANCH=${BASE_BRANCH}" \ - ${docker_image} /bin/bash -c ' - if [ -d ${REPO_NAME} ]; then - echo "Directory ${REPO_NAME} exists, removing it..." - rm -rf ${REPO_NAME} - fi - ' - git config --global user.name "FastDeployCI" - git config --global user.email "fastdeploy_ci@example.com" - git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH} - cd FastDeploy - if [ "${{ github.event_name }}" = "pull_request" ]; then - git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }} - git merge pr/${{ github.event.pull_request.number }} - git log -n 3 --oneline - else - git checkout ${{ github.sha }} - git log -n 3 --oneline - fi + xpu_build_test: + name: xpu_build_test + needs: [clone] + uses: ./.github/workflows/_build_xpu.yml + with: + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:ci + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Run CI unittest - env: - docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0 - run: | - runner_name="${{ runner.name }}" - last_char="${runner_name: -1}" + xpu_4cards_case_test: + name: xpu_4cards_case_test + needs: [clone, xpu_build_test] + uses: ./.github/workflows/_xpu_4cards_case_test.yml + with: + 
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:ci + FASTDEPLOY_WHEEL_URL: ${{ needs.xpu_build_test.outputs.wheel_path }} + MODEL_PATH: /ssd3/model + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} - if [[ "$last_char" =~ [0-3] ]]; then - gpu_id="$last_char" - else - gpu_id="0" - fi - FD_API_PORT=$((9180 + gpu_id * 100)) - FD_ENGINE_QUEUE_PORT=$((9150 + gpu_id * 100)) - FD_METRICS_PORT=$((9170 + gpu_id * 100)) - - PARENT_DIR=$(dirname "$WORKSPACE") - echo "PARENT_DIR:$PARENT_DIR" - docker run --rm --net=host --cap-add=SYS_PTRACE --privileged --shm-size=64G \ - -v $(pwd):/workspace -w /workspace \ - -v "/ssd3:/ssd3" \ - -e "MODEL_PATH=/ssd3/model" \ - -e "http_proxy=$(git config --global --get http.proxy)" \ - -e "https_proxy=$(git config --global --get https.proxy)" \ - -e "FD_API_PORT=${FD_API_PORT}" \ - -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \ - -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \ - ${docker_image} /bin/bash -c " - git config --global --add safe.directory /workspace/FastDeploy - cd FastDeploy - bash scripts/run_ci_xpu.sh - " + xpu_8cards_case_test: + name: xpu_8cards_case_test + needs: [clone, xpu_build_test] + uses: ./.github/workflows/_xpu_8cards_case_test.yml + with: + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:ci + FASTDEPLOY_WHEEL_URL: ${{ needs.xpu_build_test.outputs.wheel_path }} + MODEL_PATH: /ssd3/model + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 17234b63908..6c06ed0a6aa 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -11,11 +11,11 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: 
actions/setup-python@v6 with: python-version: 3.x - - run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang + - run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang mkdocs-static-i18n - name: Deploy to GitHub Pages env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/pr_build_and_test.yml b/.github/workflows/pr_build_and_test.yml new file mode 100644 index 00000000000..9ffcd75ee5c --- /dev/null +++ b/.github/workflows/pr_build_and_test.yml @@ -0,0 +1,111 @@ +name: PR Build and Test +on: + pull_request: + types: [opened, synchronize] + branches: [develop, release/**] +permissions: read-all + +concurrency: + group: ${{ github.event.pull_request.number }}-${{ github.workflow }} + cancel-in-progress: true + +jobs: + clone: + name: FD-Clone-Linux + uses: ./.github/workflows/_clone_linux.yml + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + build: + name: FD-Build-Linux + needs: clone + uses: ./.github/workflows/_build_linux.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda126-manylinux + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "90" + WITH_NIGHTLY_BUILD: "OFF" + FD_VERSION: "0.0.0" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + resultshow: + name: Use Build Output + needs: build + runs-on: ubuntu-latest + steps: + - name: Print wheel path + run: | + echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}" + + unittest_coverage: + name: Run FastDeploy Unit Tests and Coverage + needs: [clone,build] + uses: ./.github/workflows/_unit_test_coverage.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }} + MODEL_CACHE_DIR: 
"/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + four_cards_test: + name: Run Four Cards Tests + needs: [clone,build] + uses: ./.github/workflows/_gpu_4cards_case_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + logprob_test: + name: Run FastDeploy LogProb Tests + needs: [build] + uses: ./.github/workflows/_logprob_test_linux.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev + PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz" + FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + pre_ce_test: + name: Extracted partial CE model tasks to run in CI. 
+ needs: [clone,build] + uses: ./.github/workflows/_pre_ce_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + base_test: + name: Run Base Tests + needs: [clone,build] + uses: ./.github/workflows/_base_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + stable_test: + name: Run Stable Tests + needs: [clone,build] + uses: ./.github/workflows/_stable_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/publish_job.yml b/.github/workflows/publish_job.yml new file mode 100644 index 00000000000..9207d58a497 --- /dev/null +++ b/.github/workflows/publish_job.yml @@ -0,0 +1,531 @@ +name: Publish Job + +on: + workflow_dispatch: + schedule: + - cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8) + push: + # branches: + # - develop + tags: + - '*' + +permissions: read-all + +concurrency: + group: Publish-Job-${{ github.ref }}-${{ github.sha }} + cancel-in-progress: true + + +jobs: + publish_pre_check: + runs-on: ubuntu-latest + if: | + github.event.repository.fork == false && + ( + (github.event_name == 
'schedule' && github.ref_name == 'develop') || + (github.event_name == 'push' && github.ref_type == 'tag') || + ((github.event_name == 'workflow_dispatch') && + (github.ref_name == 'develop' || github.ref_type == 'tag')) + ) + env: + TAG_VERSION_MAPPINGS: ${{ vars.TAG_VERSION_MAPPINGS }} + FD_VERSION_DEV: ${{ vars.FD_VERSION_DEV }} + COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }} + outputs: + compile_use_paddle_version: ${{ steps.set_output.outputs.compile_use_paddle_version }} + compile_continue: ${{ steps.set_output.outputs.compile_continue }} + fd_version: ${{ steps.set_output.outputs.fd_version }} + with_nightly_build: ${{ steps.set_output.outputs.with_nightly_build }} + compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }} + + steps: + - name: Get tag version + if: github.ref_type == 'tag' + run: | + TAG_NAME="${GITHUB_REF##*/}" # 提取 tag 名称,比如 v2.1.0 + TAG_VERSION="${TAG_NAME#v}" # 去掉前缀 v + echo "FD_VERSION=$TAG_VERSION" >> $GITHUB_ENV + + - name: Check FD version to Paddle version mapping + if: github.ref_type == 'tag' + env: + TARGET_FD: ${{ env.FD_VERSION }} + run: | + FOUND_PADDLE="" + # 遍历映射 + for pair in $(echo $TAG_VERSION_MAPPINGS | tr ';' ' '); do + fd=$(echo "$pair" | cut -d',' -f1) + paddle=$(echo "$pair" | cut -d',' -f2) + if [[ "$fd" == "$TARGET_FD" ]]; then + FOUND_PADDLE="$paddle" + break + fi + done + + if [[ -z "$FOUND_PADDLE" ]]; then + echo "No Paddle version found for FD $TARGET_FD" + else + echo "FD $TARGET_FD maps to Paddle $FOUND_PADDLE" + echo "PADDLE_VERSION=$FOUND_PADDLE" >> $GITHUB_ENV + fi + - name: Set Version + id: set_output + env: + PADDLE_VERSION: ${{ env.PADDLE_VERSION }} + FD_VERSION: ${{ env.FD_VERSION }} + run: | + if [[ "${{ github.ref_type }}" == "tag" ]]; then + if [[ -z "$PADDLE_VERSION" ]]; then + compile_continue=false + else + compile_use_paddle_version=$PADDLE_VERSION + compile_continue=true + fi + fd_version=$FD_VERSION + fi + if [[ "${{ 
github.ref_name }}" == "develop" ]];then + compile_continue=true + compile_use_paddle_version="" + fd_version=${FD_VERSION_DEV} + with_nightly_build=ON + fi + # Todo + # 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL + for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do + branch=$(echo "$pair" | cut -d',' -f1) + paddle_whl_url=$(echo "$pair" | cut -d',' -f2) + if [[ "$branch" == "${{ github.ref_name }}" ]]; then + FOUND_PADDLE_URL="$paddle_whl_url" + echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT + compile_continue=true + break + fi + done + echo "compile_continue=${compile_continue}" >> $GITHUB_OUTPUT + echo "compile_use_paddle_version=${compile_use_paddle_version}" >> $GITHUB_OUTPUT + echo "fd_version=${fd_version}" >> $GITHUB_OUTPUT + echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT + + print_publish_pre_check_outputs: + runs-on: ubuntu-latest + needs: publish_pre_check + steps: + - name: Print outputs as JSON + run: | + echo '${{ toJSON(needs.publish_pre_check.outputs) }}' + + clone: + environment: CodeSync + name: FD-Clone-Linux + runs-on: ubuntu-latest + needs: publish_pre_check + if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }} + outputs: + repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} + steps: + - name: Clone FastDeploy + uses: actions/checkout@v6 + with: + ref: ${{ github.ref_name }} + submodules: 'recursive' + fetch-depth: 1000 + + - name: Python Setup + uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Code Info Show and Upload + id: set_output + env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + run: | + git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'" + echo "Current HEAD Log:" + git log --oneline -n 5 + ls + cd .. 
+ tar -zcf FastDeploy.tar.gz FastDeploy + if [[ "${{ github.ref_type }}" == "tag" ]]; then + commit_id=${{ github.sha }} + tag_name=${{ github.ref_name }} + target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id} + else + commit_id=${{ github.sha }} + branch_name=${{ github.ref_name }} + target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id} + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python -m pip install bce-python-sdk==0.9.29 + ls + python ${push_file} FastDeploy.tar.gz ${target_path} + target_path_stripped="${target_path#paddle-qa/}" + REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz + echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT + + resultshow: + name: Show Code Archive Output + needs: clone + runs-on: ubuntu-latest + steps: + - name: Print wheel path + run: | + echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" + + build_cu126: + name: BUILD_cu126 + needs: [clone, publish_pre_check] + uses: ./.github/workflows/_build_linux.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda126-manylinux + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "86,89,80,90" + WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }} + FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }} + PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }} + PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }} + FD_UNIFY_BUILD: "true" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + build_cu129: + name: BUILD_cu129 + needs: [clone, publish_pre_check] + uses: ./.github/workflows/_build_linux_cu129.yml + with: + DOCKER_IMAGE: 
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda129-manylinux + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "86,89,80,90" + WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }} + FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }} + PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }} + PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }} + FD_UNIFY_BUILD: "true" + + build_cu130: + name: BUILD_cu130 + needs: [ clone, publish_pre_check ] + uses: ./.github/workflows/_build_linux_cu130.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda130-manylinux + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "86,89,80,90" + WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }} + FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }} + PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }} + PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }} + FD_UNIFY_BUILD: "true" + + build_fd_router: + name: BUILD_FD_ROUTER + needs: [clone, publish_pre_check] + uses: ./.github/workflows/_build_linux_fd_router.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-build-cuda129-manylinux + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + COMPILE_ARCH: "80,90" + WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }} + FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }} + PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }} + PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }} + + ce_upload_fd_router: + environment: CodeSync + name: CE_UPLOAD_FD_ROUTER + needs: build_fd_router + runs-on: ubuntu-latest + 
env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + FD_ROUTER_URL: ${{ needs.build_fd_router.outputs.fd_router_path }} + steps: + - uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Fd-Router Info Show and Upload + if: github.ref_name == 'develop' || github.ref_type == 'tag' + run: | + echo "The fd_router is located at: ${{ needs.build_fd_router.outputs.fd_router_path }}" + wget -q --no-check-certificate ${{ needs.build_fd_router.outputs.fd_router_path }} + filename=$(basename ${{ needs.build_fd_router.outputs.fd_router_path }}) + + commit_id=${{ github.sha }} + + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python -m pip install bce-python-sdk==0.9.29 + + if [[ "${{ github.ref_name }}" == "develop" ]];then + branch_name=${{ github.ref_name }} + target_paths=( + "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/${branch_name}/${commit_id}" + "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/${branch_name}/latest" + ) + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + tag_name=${{ github.ref_name }} + target_paths=( + "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/${tag_name}/${commit_id}" + "paddle-qa/paddle-pipeline/FastDeploy_ActionCE/${tag_name}/latest" + ) + else + echo "Not develop or tag, do nothing" + fi + + for target_path in "${target_paths[@]}"; do + echo "Uploading ${filename} to ${target_path}" + python "${push_file}" "${filename}" "${target_path}" + done + + base_prefix="paddle-qa/" + commit_path_stripped="${target_paths[0]#${base_prefix}}" + latest_path_stripped="${target_paths[1]#${base_prefix}}" + FD_ROUTER_PATH="https://paddle-qa.bj.bcebos.com/${commit_path_stripped}/${filename}" + FD_ROUTER_PATH_LATEST="https://paddle-qa.bj.bcebos.com/${latest_path_stripped}/${filename}" + + echo "commit fd-router url is ${FD_ROUTER_PATH}" + echo "latest fd-router url is ${FD_ROUTER_PATH_LATEST}" + + 
paddle_pypi_upload_cu126: + environment: PaddleSourceUpload + name: PADDLE_PYPI_UPLOAD_cu126 + needs: build_cu126 + runs-on: ubuntu-latest + env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + steps: + - uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Wheel Info Show and Upload + if: github.ref_name == 'develop' || github.ref_type == 'tag' + run: | + echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}" + wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL} + filename=$(basename ${FASTDEPLOY_WHEEL_URL}) + if [[ "${{ github.ref_name }}" == "develop" ]];then + target_path=paddle-whl/nightly/cu126/fastdeploy-gpu + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + target_path=paddle-whl/stable/cu126/fastdeploy-gpu + else + echo "Not develop or tag, do nothing" + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python -m pip install bce-python-sdk==0.9.29 + ls + python ${push_file} ${filename} ${target_path} + + paddle_pypi_upload_cu129: + environment: PaddleSourceUpload + name: PADDLE_PYPI_UPLOAD_cu129 + needs: build_cu129 + runs-on: ubuntu-latest + env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu129.outputs.wheel_path_cu129 }} + steps: + - uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Wheel Info Show and Upload + if: github.ref_name == 'develop' || github.ref_type == 'tag' + run: | + echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}" + wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL} + filename=$(basename ${FASTDEPLOY_WHEEL_URL}) + if [[ "${{ github.ref_name }}" == "develop" ]];then + target_path=paddle-whl/nightly/cu129/fastdeploy-gpu + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + target_path=paddle-whl/stable/cu129/fastdeploy-gpu + 
else + echo "Not develop or tag, do nothing" + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python -m pip install bce-python-sdk==0.9.29 + ls + python ${push_file} ${filename} ${target_path} + + paddle_pypi_upload_cu130: + environment: PaddleSourceUpload + name: PADDLE_PYPI_UPLOAD_cu130 + needs: build_cu130 + runs-on: ubuntu-latest + env: + AK: ${{ secrets.BOS_AK }} + SK: ${{ secrets.BOS_SK }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu130.outputs.wheel_path_cu130 }} + steps: + - uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Wheel Info Show and Upload + if: github.ref_name == 'develop' || github.ref_type == 'tag' + run: | + echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}" + wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL} + filename=$(basename ${FASTDEPLOY_WHEEL_URL}) + if [[ "${{ github.ref_name }}" == "develop" ]];then + target_path=paddle-whl/nightly/cu130/fastdeploy-gpu + elif [[ "${{ github.ref_type }}" == "tag" ]]; then + target_path=paddle-whl/stable/cu130/fastdeploy-gpu + else + echo "Not develop or tag, do nothing" + fi + wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py + push_file=$(realpath bos_tools.py) + python -m pip install bce-python-sdk==0.9.29 + ls + python ${push_file} ${filename} ${target_path} + + images_build: + name: Run FD Image Build + needs: [clone, publish_pre_check, build_cu126] + runs-on: [self-hosted, Docker-Build] + if: | + github.event.repository.fork == false && + ( + (github.event_name == 'push' && github.ref_type == 'tag') || + (github.event_name == 'workflow_dispatch' && github.ref_type == 'tag') + ) + env: + FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }} + PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }} + FASTDEPLOY_ARCHIVE_URL: 
${{ needs.clone.outputs.repo_archive_url }} + steps: + - name: Images Build + shell: bash + env: + docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + run: | + set -x + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + fd_archive_url="${{ needs.clone.outputs.repo_archive_url }}" + + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + if [ -d ${REPO_NAME} ]; then + echo "Directory ${REPO_NAME} exists, removing it..." + rm -rf ${REPO_NAME}* + fi + ' + wget -q --no-proxy ${fd_archive_url} + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline + + cd ./dockerfiles + + PRODUCT_NAME=ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:${FD_VERSION} + docker build --no-cache -t ${PRODUCT_NAME} -f Dockerfile.gpu . 
\ + --network host \ + --build-arg PADDLE_VERSION=${PADDLEVERSION} \ + --build-arg FD_VERSION=${FD_VERSION} + + docker push ${PRODUCT_NAME} + + unittest_coverage: + name: Run FastDeploy Unit Tests and Coverage + needs: [clone,build_cu126] + uses: ./.github/workflows/_unit_test_coverage.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + four_cards_test: + name: Run Four Cards Tests + needs: [clone,build_cu126] + uses: ./.github/workflows/_gpu_4cards_case_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-paddle-dev + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + golang_router_test: + name: Run Golang Router Tests + needs: [ clone,build_cu126,build_fd_router ] + uses: ./.github/workflows/_golang_router_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + FASTDEPLOY_ROUTER_URL: ${{ needs.build_fd_router.outputs.fd_router_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + logprob_test: + name: Run FastDeploy LogProb Tests + needs: [build_cu126] + uses: ./.github/workflows/_logprob_test_linux.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate 
+ PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz" + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + pre_ce_test: + name: Extracted partial CE model tasks to run in CI. + needs: [clone,build_cu126] + uses: ./.github/workflows/_pre_ce_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + base_test: + name: Run Base Tests + needs: [clone,build_cu126] + uses: ./.github/workflows/_base_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + accuracy_test: + name: Run Accuracy Tests + needs: [clone,build_cu126] + uses: ./.github/workflows/_accuracy_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + + stable_test: + name: Run Stable Tests + needs: [clone,build_cu126] + uses: ./.github/workflows/_stable_test.yml + with: + DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate + FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }} + 
FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} + MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData" + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/remove-skip-ci-labels.yml b/.github/workflows/remove-skip-ci-labels.yml new file mode 100644 index 00000000000..978f70ea240 --- /dev/null +++ b/.github/workflows/remove-skip-ci-labels.yml @@ -0,0 +1,53 @@ +name: Remove Skip-CI Labels + +on: + pull_request_target: + types: [synchronize] + +permissions: + pull-requests: write + +jobs: + remove-skip-ci-labels: + name: Remove skip-ci labels on new commits + runs-on: ubuntu-latest + steps: + - name: Get PR labels + id: get-labels + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const { data: labels } = await github.rest.issues.listLabelsOnIssue({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number + }); + + const skipCiLabels = labels + .filter(label => label.name.startsWith('skip-ci:')) + .map(label => label.name); + + console.log('Found skip-ci labels:', skipCiLabels); + core.setOutput('skip-ci-labels', JSON.stringify(skipCiLabels)); + core.setOutput('has-skip-ci-labels', skipCiLabels.length > 0 ? 
'true' : 'false'); + + - name: Remove skip-ci labels + if: steps.get-labels.outputs.has-skip-ci-labels == 'true' + uses: actions/github-script@v8 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const skipCiLabels = JSON.parse('${{ steps.get-labels.outputs.skip-ci-labels }}'); + + for (const label of skipCiLabels) { + console.log(`Removing label: ${label}`); + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + name: label + }); + } + + console.log(`Successfully removed ${skipCiLabels.length} skip-ci label(s)`); diff --git a/.github/workflows/rerun.yml b/.github/workflows/rerun.yml new file mode 100644 index 00000000000..bbc96edd37e --- /dev/null +++ b/.github/workflows/rerun.yml @@ -0,0 +1,217 @@ +name: Re-run + +on: + issue_comment: + types: [created] + +jobs: + re-run: + if: ${{ github.event.issue.pull_request && contains(github.event.comment.body, '/re-run') && github.event.comment.user.login == github.event.issue.user.login }} + runs-on: ubuntu-latest + steps: + - name: Cleanup + run: | + rm -rf * .[^.]* + + - name: Checkout code + uses: actions/checkout@v6 + + - name: Rerun all failed jobs + if: ${{ contains(github.event.comment.body, 'all-failed') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'all-failed' + + - name: Rerun Approval + if: ${{ contains(github.event.comment.body, 'approval') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Approval' + + - name: Rerun CI_ILUVATAR + if: ${{ contains(github.event.comment.body, 'ci_iluvatar') }} + uses: ./.github/actions/rerun-workflow + with: + 
PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run iluvatar Tests / run_iluvatar_cases' + + - name: Rerun CI_XPU + if: ${{ contains(github.event.comment.body, 'ci_xpu') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'xpu_build_test / xpu-build-test' + + - name: Rerun run_xpu_4cards_cases + if: ${{ contains(github.event.comment.body, 'run_xpu_4cards_cases') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'xpu_4cards_case_test / run_xpu_4cards_cases' + + - name: Rerun run_xpu_8cards_cases + if: ${{ contains(github.event.comment.body, 'run_xpu_8cards_cases') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'xpu_8cards_case_test / run_xpu_8cards_cases' + + - name: Rerun CI_HPU + if: ${{ contains(github.event.comment.body, 'ci_hpu') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'CI_HPU' + + - name: Rerun CI_METAX + if: ${{ contains(github.event.comment.body, 'ci_metax') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 
'Trigger Jenkins for PR' + + - name: Rerun Check PR Template + if: ${{ contains(github.event.comment.body, 'check_pr_template') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Check PR Template' + + - name: Rerun Codestyle-check + if: ${{ contains(github.event.comment.body, 'codestyle') || contains(github.event.comment.body, 'pre_commit') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Pre Commit' + + - name: Rerun Clone + if: ${{ contains(github.event.comment.body, 'clone') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'FD-Clone-Linux / code-clone' + + - name: Rerun Build + if: ${{ contains(github.event.comment.body, 'build') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'FD-Build-Linux / fd-build' + + - name: Rerun run_ce_cases + if: ${{ contains(github.event.comment.body, 'run_ce_cases') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Extracted partial CE model tasks to run in CI. 
/ run_ce_cases' + + - name: Rerun accuracy_tests + if: ${{ contains(github.event.comment.body, 'accuracy_tests') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run Accuracy Tests / accuracy_tests' + + - name: Rerun base_tests + if: ${{ contains(github.event.comment.body, 'base_tests') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run Base Tests / base_tests' + + - name: Rerun run_tests_logprob + if: ${{ contains(github.event.comment.body, 'run_tests_logprob') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run FastDeploy LogProb Tests / run_tests_logprob' + + - name: Rerun run_tests_with_coverage + if: ${{ contains(github.event.comment.body, 'run_tests_with_coverage') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage' + + - name: Rerun diff_coverage_report + if: ${{ contains(github.event.comment.body, 'diff_coverage_report') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run FastDeploy Unit Tests and Coverage / diff_coverage_report' + + - name: Rerun stable_tests + if: ${{ 
contains(github.event.comment.body, 'stable_tests') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run Stable Tests / stable_tests' + + - name: Rerun run_4_cards_tests + if: ${{ contains(github.event.comment.body, 'run_4_cards_tests') }} + uses: ./.github/actions/rerun-workflow + with: + PR_ID: ${{ github.event.issue.number }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} + JOB_NAME: 'Run Four Cards Tests / run_4_cards_tests' diff --git a/.gitignore b/.gitignore index b7c91af7730..2b35f3a83b6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,10 @@ /.venv/ /venv/ +tests/log_* +benchmarks/openai-chat-infqps* +splitwise/log* + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -121,7 +125,7 @@ dmypy.json FETCH_HEAD #log -log*/ +log/ checkpoints/ checkpoints_origin/ @@ -136,6 +140,7 @@ kernel_meta* fastdeploy_ops.py version.txt EGG-INFO/ +**/fastdeploy_ops/__init__.py # fp8 generated codes autogen/ @@ -156,6 +161,12 @@ nohup.out custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute +#marlin_kernel +custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu + +#machete_kernel +custom_ops/gpu_ops/machete/generated + # buff custom_ops/tmp* @@ -164,3 +175,9 @@ build .ccls-cache third_party + +custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu +custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h + +custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu +custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000..dc0ae4b385d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,10 @@ +[submodule "custom_ops/third_party/DeepGEMM"] + path = 
custom_ops/third_party/DeepGEMM + url = https://github.com/deepseek-ai/DeepGEMM.git + ignore = all +[submodule "custom_ops/third_party/cutlass"] + path = custom_ops/third_party/cutlass + url = https://github.com/NVIDIA/cutlass.git +[submodule "custom_ops/third_party/nlohmann_json"] + path = custom_ops/third_party/nlohmann_json + url = https://github.com/nlohmann/json.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ce894293357..1f5791e398a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,18 @@ +exclude: | + (?x)^( + dockerfiles/.+| + custom_ops/third_party/.+ + )$ default_install_hook_types: - pre-commit - commit-msg default_stages: - pre-commit # Run locally + - commit-msg # - manual # Run in CI repos: - repo: https://github.com/psf/black.git - rev: 22.8.0 + rev: 25.1.0 hooks: - id: black files: \.(py|pyi)$ @@ -17,7 +23,7 @@ repos: hooks: - id: isort - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 7.0.0 hooks: - id: flake8 # 代码检查 @@ -26,6 +32,15 @@ repos: hooks: - id: ruff args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml] +# For C++ files +- repo: local + hooks: + - id: clang-format + name: clang-format + description: Format files with ClangFormat. + entry: clang-format -i + language: system + files: \.(c|cc|cxx|cpp|cu|h|cuh|hpp|hxx|xpu|kps)$ # # 拼写检查 # - repo: https://github.com/codespell-project/codespell # rev: v2.4.1 @@ -50,3 +65,4 @@ repos: - id: detect-private-key - id: check-symlinks - id: check-added-large-files + args: ["--maxkb=1024"] diff --git a/README.md b/README.md deleted file mode 100644 index fd94d27c5fe..00000000000 --- a/README.md +++ /dev/null @@ -1,90 +0,0 @@ -

- -

-

- - - - - - - -

- -

- PaddlePaddle%2FFastDeploy | Trendshift
- Installation - | - Quick Start - | - Supported Models - -

- --------------------------------------------------------------------------------- -# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle - -## News - -**[2025-06] 🔥 Released FastDeploy v2.0:** Supports inference and deployment for ERNIE 4.5. Furthermore, we open-source an industrial-grade PD disaggregation with context caching, dynamic role switching for effective resource utilization to further enhance inference performance for MoE models. - -## About - -**FastDeploy** is an inference and deployment toolkit for large language models and visual language models based on PaddlePaddle. It delivers **production-ready, out-of-the-box deployment solutions** with core acceleration technologies: - -- 🚀 **Load-Balanced PD Disaggregation**: Industrial-grade solution featuring context caching and dynamic instance role switching. Optimizes resource utilization while balancing SLO compliance and throughput. -- 🔄 **Unified KV Cache Transmission**: Lightweight high-performance transport library with intelligent NVLink/RDMA selection. -- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility. -- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more. -- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill. -- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc. - -## Requirements - -- OS: Linux -- Python: 3.10 ~ 3.12 - -## Installation - -FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. 
For detailed installation instructions: - -- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md) -- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md) -- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md) -- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md) - -**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates! - -## Get Started - -Learn how to use FastDeploy through our documentation: -- [10-Minutes Quick Deployment](./docs/get_started/quick_start.md) -- [ERNIE-4.5 Large Language Model Deployment](./docs/get_started/ernie-4.5.md) -- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md) -- [Offline Inference Development](./docs/offline_inference.md) -- [Online Service Deployment](./docs/online_serving/README.md) -- [Full Supported Models List](./docs/supported_models.md) - -## Supported Models - -| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length | -|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- | -|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅(WINT4)| WIP |128K | -|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|✅(WINT4)| WIP | 128K | -|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K | -|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K | -|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K | -|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K | -|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K | - -## Advanced Usage - -- [Quantization](./docs/quantization/README.md) -- [PD Disaggregation Deployment](./docs/features/disaggregated.md) -- [Speculative Decoding](./docs/features/speculative_decoding.md) -- [Prefix 
Caching](./docs/features/prefix_caching.md) -- [Chunked Prefill](./docs/features/chunked_prefill.md) - -## Acknowledgement - -FastDeploy is licensed under the [Apache-2.0 open-source license](./LICENSE). During development, portions of [vLLM](https://github.com/vllm-project/vllm) code were referenced and incorporated to maintain interface compatibility, for which we express our gratitude. diff --git a/README.md b/README.md new file mode 120000 index 00000000000..bacd3186b4b --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +README_CN.md \ No newline at end of file diff --git a/README_CN.md b/README_CN.md new file mode 100644 index 00000000000..4d110b8d830 --- /dev/null +++ b/README_CN.md @@ -0,0 +1,93 @@ +[English](README_EN.md) | 简体中文 +

+ +

+

+ + + + + + + +

+ +

+ PaddlePaddle%2FFastDeploy | Trendshift
+ 安装指导 + | + 快速入门 + | + 支持模型列表 + +

+ +-------------------------------------------------------------------------------- +# FastDeploy 飞桨大模型高效部署套件 + +## 最新活动 + +**[2026-01] FastDeploy v2.4 全新发布!** 新增 DeepSeek V3 与 Qwen3-MoE 模型的 PD 分离部署,增强MTP 投机解码能力,全面优化多硬件平台上的 MoE 推理与多模态前缀缓存性能,升级全部内容参阅 [v2.4 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.4.0)。 + +**[2025-11] FastDeploy v2.3**: 新增[ERNIE-4.5-VL-28B-A3B-Thinking](docs/zh/get_started/ernie-4.5-vl-thinking.md)与[PaddleOCR-VL-0.9B](docs/zh/best_practices/PaddleOCR-VL-0.9B.md)两大重磅模型在多硬件平台上的部署支持,进一步优化全方位推理性能,以及带来更多部署功能和易用性的提升,升级全部内容参阅[v2.3 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.3.0)。 + +**[2025-09] FastDeploy v2.2**: HuggingFace生态模型兼容,性能进一步优化,更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持! + +**[2025-08] FastDeploy v2.1**:全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。 + +## 关于 + +**FastDeploy** 是基于飞桨(PaddlePaddle)的大语言模型(LLM)与视觉语言模型(VLM)推理部署工具包,提供**开箱即用的生产级部署方案**,核心技术特性包括: + +- 🚀 **负载均衡式PD分解**:工业级解决方案,支持上下文缓存与动态实例角色切换,在保障SLO达标和吞吐量的同时优化资源利用率 +- 🔄 **统一KV缓存传输**:轻量级高性能传输库,支持智能NVLink/RDMA选择 +- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口 +- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等 +- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充 +- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、天数智芯GPU、燧原GCU、沐曦GPU、英特尔Gaudi等 + +## 要求 + +- 操作系统: Linux +- Python: 3.10 ~ 3.12 + +## 安装 + +FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU**、**天数(Iluvatar)GPU**、**燧原(Enflame)GCU**、**海光(Hygon)DCU** 以及其他硬件上进行推理部署。详细安装说明如下: + +- [英伟达 GPU](./docs/zh/get_started/installation/nvidia_gpu.md) +- [昆仑芯 XPU](./docs/zh/get_started/installation/kunlunxin_xpu.md) +- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md) +- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md) +- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md) +- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md) +- [英特尔 
Gaudi](./docs/zh/get_started/installation/intel_gaudi.md) + +## 入门指南 + +通过我们的文档了解如何使用 FastDeploy: +- [10分钟快速部署](./docs/zh/get_started/quick_start.md) +- [ERNIE-4.5 部署](./docs/zh/get_started/ernie-4.5.md) +- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md) +- [离线推理](./docs/zh/offline_inference.md) +- [在线服务](./docs/zh/online_serving/README.md) +- [最佳实践](./docs/zh/best_practices/README.md) + +## 支持模型列表 + +通过我们的文档了解如何下载模型,如何支持torch格式等: +- [模型支持列表](./docs/zh/supported_models.md) + +## 进阶用法 + +- [量化](./docs/zh/quantization/README.md) +- [分离式部署](./docs/zh/features/disaggregated.md) +- [投机解码](./docs/zh/features/speculative_decoding.md) +- [前缀缓存](./docs/zh/features/prefix_caching.md) +- [分块预填充](./docs/zh/features/chunked_prefill.md) +- [负载均衡调度Router](./docs/zh/online_serving/router.md) +- [全局Cache池化](./docs/zh/features/global_cache_pooling.md) + +## 致谢 + +FastDeploy 依据 [Apache-2.0 开源许可证](./LICENSE). 进行授权。在开发过程中,我们参考并借鉴了 [vLLM](https://github.com/vllm-project/vllm) 的部分代码,以保持接口兼容性,在此表示衷心感谢。 diff --git a/README_EN.md b/README_EN.md new file mode 100644 index 00000000000..4d918455d5f --- /dev/null +++ b/README_EN.md @@ -0,0 +1,91 @@ +English | [简体中文](README_CN.md) +

+ +

+

+ + + + + + + +

+ +

+ PaddlePaddle%2FFastDeploy | Trendshift
+ Installation + | + Quick Start + | + Supported Models + +

+ +-------------------------------------------------------------------------------- +# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle + +## News + +[2026-01] FastDeploy v2.4 is released! Featuring PD-separated deployment for DeepSeek V3 and Qwen3-MoE, enhanced MTP speculative decoding, and comprehensive performance boosts for MoE inference and multi-modal Prefix Caching across various hardware backends. See the full v2.4 ReleaseNote for more details. + +**[2025-11] FastDeploy v2.3**: It adds deployment support for two major models, [ERNIE-4.5-VL-28B-A3B-Thinking](docs/get_started/ernie-4.5-vl-thinking.md) and [PaddleOCR-VL-0.9B](docs/best_practices/PaddleOCR-VL-0.9B.md), across multiple hardware platforms. It further optimizes comprehensive inference performance and brings more deployment features and usability enhancements. For all the upgrade details, refer to the [v2.3 Release Note](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.3.0). + +**[2025-09] FastDeploy v2.2**: It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)! + +## About + +**FastDeploy** is an inference and deployment toolkit for large language models and visual language models based on PaddlePaddle. It delivers **production-ready, out-of-the-box deployment solutions** with core acceleration technologies: + +- 🚀 **Load-Balanced PD Disaggregation**: Industrial-grade solution featuring context caching and dynamic instance role switching. Optimizes resource utilization while balancing SLO compliance and throughput. +- 🔄 **Unified KV Cache Transmission**: Lightweight high-performance transport library with intelligent NVLink/RDMA selection. +- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility. 
+- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more. +- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill. +- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi etc. + +## Requirements + +- OS: Linux +- Python: 3.10 ~ 3.12 + +## Installation + +FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, **Hygon DCUs** and other hardware. For detailed installation instructions: + +- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md) +- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md) +- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md) +- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md) +- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md) +- [MetaX GPU](./docs/get_started/installation/metax_gpu.md) +- [Intel Gaudi](./docs/get_started/installation/intel_gaudi.md) + +## Get Started + +Learn how to use FastDeploy through our documentation: +- [10-Minutes Quick Deployment](./docs/get_started/quick_start.md) +- [ERNIE-4.5 Large Language Model Deployment](./docs/get_started/ernie-4.5.md) +- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md) +- [Offline Inference Development](./docs/offline_inference.md) +- [Online Service Deployment](./docs/online_serving/README.md) +- [Best Practices](./docs/best_practices/README.md) + +## Supported Models + +Learn how to download models, enable using the torch format, and more: +- [Full Supported Models List](./docs/supported_models.md) + +## Advanced Usage + +- [Quantization](./docs/quantization/README.md) +- [PD Disaggregation Deployment](./docs/features/disaggregated.md) +- [Speculative Decoding](./docs/features/speculative_decoding.md) +- [Prefix Caching](./docs/features/prefix_caching.md) +- [Chunked 
Prefill](./docs/features/chunked_prefill.md) +- [Load-Balancing Scheduling Router](./docs/online_serving/router.md) +- [Global Cache Pooling](./docs/features/global_cache_pooling.md) + +## Acknowledgement + +FastDeploy is licensed under the [Apache-2.0 open-source license](./LICENSE). During development, portions of [vLLM](https://github.com/vllm-project/vllm) code were referenced and incorporated to maintain interface compatibility, for which we express our gratitude. diff --git a/benchmarks/README.md b/benchmarks/README.md index 85a0a6f4131..a71ff5cc333 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -41,7 +41,19 @@ python -m pip install -r requirements.txt --metric-percentiles 80,95,99,99.9,99.95,99.99:性能结果中展示的性能指标分位值 --num-prompts 1:总计发送多少条请求 --max-concurrency 1:压测并发数 ---save-result:开启结果保存,结果文件会存入json +--save-result:开启结果保存,结果文件会存入json,默认False不保存 +--debug:开启debug模式,逐条打印payload和output内容,默认False +--shuffle:是否打乱数据集,默认False不打乱 +--seed:打乱数据集时的随机种子,默认0 +--pd-metrics:开启PD分离metrics指标收集,会添加请求参数collect_metrics=True,默认False +--ip-list:支持多个ip:port,将总请求数以及总并发数均分到每个IP,按整除取余分配。例:0.0.0.0:1211,0.0.0.0:1222,默认为空 +--multi-turn:开启多轮对话,将数据集messages中的多轮对话逐轮请求,默认False不区分多轮。若需要添加tool_call,需在hyperparameter-path超参yaml中配置tools,参考yaml/request_yaml/GLM-32k-tool-call.yaml,数据集中需要指定tool_url,max_loop(非必选,默认10)为单轮调用最大次数 +``` +多轮对话使用prompt_token_ids模式请求 +```bash +开启--multi-turn +--tokenizer-model:使用prompt_token_ids请求时指定,多轮对话tokenizer模型类型,可选"eb": ErnieBotTokenizer, "eb5": Ernie5Tokenizer, "eb_mm": Ernie4_5Tokenizer +--tokenizer-path:使用prompt_token_ids请求时指定,模型tokenizer路径 ``` ##### /v1/chat/completions接口压测单条数据调试 @@ -55,7 +67,7 @@ python benchmark_serving.py \ --port 9812 \ --dataset-name EBChat \ --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \ - --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \ + --hyperparameter-path yaml/request_yaml/eb45-32k.yaml \ --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len 
\ --metric-percentiles 80,95,99,99.9,99.95,99.99 \ --num-prompts 1 \ @@ -75,7 +87,7 @@ python benchmark_serving.py \ --port 9812 \ --dataset-name EBChat \ --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \ - --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \ + --hyperparameter-path yaml/request_yaml/eb45-32k.yaml \ --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \ --metric-percentiles 80,95,99,99.9,99.95,99.99 \ --num-prompts 2000 \ @@ -97,7 +109,7 @@ python benchmark_serving.py \ --port 9812 \ --dataset-name EBChat \ --dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json \ - --hyperparameter-path yaml/request_yaml/eb45t-32k.yaml \ + --hyperparameter-path yaml/request_yaml/eb45-32k.yaml \ --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \ --metric-percentiles 80,95,99,99.9,99.95,99.99 \ --num-prompts 2000 \ @@ -132,3 +144,30 @@ python benchmarks/benchmark_mtp.py \ --dataset-name:指定数据集类,指定为"EBChat"可读取转存的FD格式数据集 --dataset-path:测试数据集路径 ``` + +### 指定输入输出长度,构造随机纯文输入测试 + +相关参数: +- --dataset-name:指定数据集类,指定为"random"可构造随机纯文输入 +- --random-input-len:随机输入长度,对应英文单词数,默认200 +- --random-output-len:随机输出长度,默认1024 +- --random-range-ratio:输入输出长度变化范围比,[length *(1 - range_ratio), length* (1 + range_ratio)],默认0.1 + +#### 使用方式: +```bash +python benchmark_serving.py \ + --backend openai-chat \ + --model EB45T \ + --endpoint /v1/chat/completions \ + --host 0.0.0.0 \ + --port 9812 \ + --dataset-name random \ + --random-input-len 200 \ + --random-output-len 1024 \ + --random-range-ratio 0.1 \ + --percentile-metrics ttft,tpot,itl,e2el,s_ttft,s_itl,s_e2el,s_decode,input_len,s_input_len,output_len \ + --metric-percentiles 80,95,99,99.9,99.95,99.99 \ + --num-prompts 2000 \ + --max-concurrency 100 \ + --save-result > infer_log.txt 2>&1 & +``` diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 
c83b725ecbf..b8c27f96d26 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -17,8 +17,10 @@ # This file is modified from https://github.com/vllm-project/vllm/blob/main/benchmarks/backend_request_func.py +import copy import io import json +import logging import os import sys import time @@ -50,6 +52,14 @@ class RequestFuncInput: multi_modal_content: Optional[dict] = None ignore_eos: bool = False language: Optional[str] = None + debug: bool = False + pd_metrics: bool = False + response_format: Optional[dict] = None + random_flag: bool = False + json_data: Optional[dict] = None + prompt_token_ids: Optional[list] = None + tokenizer_model: str = None + tokenizer_path: str = None @dataclass @@ -57,10 +67,12 @@ class RequestFuncOutput: """Output for requesting LLMs via API""" no: int = 0 + request_id: str = "" generated_text: str = "" reasoning_content: str = "" success: bool = False latency: float = 0.0 + end_timestamp: float = 0.0 # 模型完全返回的时间戳(秒, perf_counter基准) output_tokens: int = 0 ttft: float = 0.0 # Time to first token arrival_time: list = field(default_factory=list) # arrival_time @@ -68,119 +80,742 @@ class RequestFuncOutput: tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 prompt_tokens: int = 0 # 推理侧返回输入token数 + reasoning_tokens: int = 0 # 思考长度 + res_ttft: int = 0 # 包含思考首token时延 error: str = "" + metrics: dict = field(default_factory=dict) + tool_calls: list = field(default_factory=list) + output_ids: list = field(default_factory=list) + + +@dataclass +class SessionMetrics: + """多轮对话指标""" + + session_no: int + session_e2e_time: float + pure_llm_time: float + input_tokens: int + output_tokens: int + tool_calls: int + + +def safe_cost(a, b): + """时间差计算""" + if a is None or b is None: + return None + return a - b + + +def metrics_summary(metrics, token_timestamps): + """Summarize metrics""" + if not metrics or len(token_timestamps) < 2: + return {} + + m0 = metrics[0] + m_last = metrics[-1] + + summary = {} 
+ + arrival_time = m0.get("arrival_time") + inference_start_time = m0.get("inference_start_time") + + # prefill 总耗时 + summary["prefill_cost_time"] = safe_cost(m0.get("send_request_output_to_decode_time"), arrival_time) + # prefill准备总耗时 + summary["prefill_prepare_cost_time"] = safe_cost(inference_start_time, arrival_time) + # 预处理耗时 + summary["preprocess_cost_time"] = safe_cost(m0.get("scheduler_recv_req_time"), arrival_time) + # 请求缓存耗时 + summary["cache_in_scheduler_cost_time"] = safe_cost( + m0.get("engine_get_req_time"), m0.get("scheduler_recv_req_time") + ) + # 申请 decode资源耗时 + summary["ask_decode_resource_cost_time"] = safe_cost( + m0.get("ask_decode_resource_finish_time"), m0.get("ask_decode_resource_start_time") + ) + # scheduler调度耗时 + summary["schedule_cost_time"] = safe_cost( + m0.get("inference_start_time"), m0.get("ask_decode_resource_finish_time") + ) + # prefill 的首 token 推理耗时 + summary["prefill_first_token_infer_cost_time"] = safe_cost( + m0.get("engine_recv_first_token_time"), inference_start_time + ) + # prefill 等待 cache 传输耗时 + summary["wait_sending_cache_cost_time"] = safe_cost( + m0.get("send_request_output_to_decode_time"), m0.get("wait_for_sending_cache_time") + ) + # decode分配资源耗时 + summary["decode_preallocate_cost_time"] = safe_cost( + m_last.get("decode_preallocate_req_time"), m_last.get("decode_recv_req_time") + ) + # decode准备推理耗时 + summary["decode_prepare_cost_time"] = safe_cost( + m_last.get("decode_inference_start_time"), m_last.get("decode_recv_first_token_time") + ) + # decode次token推理耗时 + summary["decode_second_token_infer_cost_time"] = safe_cost( + m_last.get("decode_recv_second_token_time"), m_last.get("decode_inference_start_time") + ) + # 返回首 token 链路耗时 + summary["first_token_transmission_cost_time"] = safe_cost( + token_timestamps[0], m_last.get("decode_recv_first_token_time") + ) + # 返回次 token 链路耗时 + summary["second_token_transmission_cost_time"] = safe_cost( + token_timestamps[1], m_last.get("decode_recv_second_token_time") + ) + + # 
MIX 模式下,scheduler调度耗时 + summary["mixed_schedule_cost_time"] = safe_cost(m0.get("inference_start_time"), m0.get("engine_get_req_time")) + # MIX 模式下,返回首 token 链路耗时 + summary["mixed_first_token_transmission_cost_time"] = safe_cost( + token_timestamps[0], m0.get("engine_recv_first_token_time") + ) + + summary["gpu_cache_token_num"] = m0.get("gpu_cache_token_num") + summary["cpu_cache_token_num"] = m0.get("cpu_cache_token_num") + summary["storage_cache_token_num"] = m0.get("storage_cache_token_num") + summary["cpu_cache_prepare_time"] = m0.get("cpu_cache_prepare_time") + summary["storage_cache_prepare_time"] = m0.get("storage_cache_prepare_time") + + return summary + + +def load_tokenizer(model, actor_tokenizer_path): + """加载tokenizer""" + from ernie_tokenizer import Ernie5Tokenizer, ErnieBotTokenizer + from paddleformers.transformers import AutoTokenizer + + from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer + + vocab_file_names = ["tokenizer.model", "spm.model", "ernie_token_100k.model"] + + try: + if model == "eb": + for i in range(len(vocab_file_names)): + if os.path.exists(os.path.join(actor_tokenizer_path, vocab_file_names[i])): + ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i] + break + tokenizer = ErnieBotTokenizer.from_pretrained(actor_tokenizer_path) + elif model == "eb_mm": + for vocab_file in vocab_file_names: + full_path = os.path.join(actor_tokenizer_path, vocab_file) + if os.path.exists(full_path): + Ernie4_5Tokenizer.resource_files_names["vocab_file"] = vocab_file + # for i in range(len(vocab_file_names)): + # if os.path.exists(os.path.join(actor_tokenizer_path, vocab_file_names[i])): + # Ernie45Tokenizer.resource_files_names["vocab_file"] = vocab_file_names[i] + # break + tokenizer = Ernie4_5Tokenizer.from_pretrained(actor_tokenizer_path) + # tokenizer.ignored_index = -100 + elif model == "eb5": + for i in range(len(vocab_file_names)): + if os.path.exists(os.path.join(actor_tokenizer_path, vocab_file_names[i])): 
+ Ernie5Tokenizer.resource_files_names["vocab_file"] = vocab_file_names[i] + break + tokenizer = Ernie5Tokenizer.from_pretrained(actor_tokenizer_path) + else: + print("tokenizer: AUTO") + tokenizer = AutoTokenizer.from_pretrained(actor_tokenizer_path, padding_side="left", use_fast=True) + except Exception as e: + tokenizer = None + logging.warning(f"Load tokenizer error: {e}") + + return tokenizer async def async_request_eb_openai_chat_completions( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, + session: aiohttp.ClientSession | None = None, ) -> RequestFuncOutput: """Request an LLM using EB OpenAI""" api_url = request_func_input.api_url assert api_url.endswith(("completions", "profile")), "OpenAI Chat Completions API URL must end with 'completions'." - async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session: - content = [{"type": "text", "text": request_func_input.prompt}] - if request_func_input.multi_modal_content: - content.append(request_func_input.multi_modal_content) - payload = { - "model": request_func_input.model, - "messages": request_func_input.history_QA, - "stream": True, - "stream_options": { - "include_usage": True, - "continuous_usage_stats": True, - }, - } - # 超参由yaml传入 - payload.update(request_func_input.hyper_parameters) + own_session = session is None + if own_session: + session = aiohttp.ClientSession( + trust_env=True, + read_bufsize=10 * 1024 * 1024, + timeout=AIOHTTP_TIMEOUT, + ) + + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) + # print("######json_data:", request_func_input.json_data) + payload = { + "model": request_func_input.model, + "messages": request_func_input.history_QA, + "stream": True, + "stream_options": { + "include_usage": True, + "continuous_usage_stats": True, + }, + "max_tokens": request_func_input.output_len, + "collect_metrics": 
request_func_input.pd_metrics, + } + if request_func_input.json_data: + json_data = request_func_input.json_data + + if json_data.get("max_tokens"): + payload["max_tokens"] = json_data["max_tokens"] + + if json_data.get("min_tokens"): + payload["min_tokens"] = json_data["min_tokens"] + if request_func_input.response_format: + payload["response_format"] = request_func_input.response_format + + # 支持传入prompt_token_ids + if request_func_input.prompt_token_ids: + # 不走messages + payload["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}] + payload["prompt_token_ids"] = request_func_input.prompt_token_ids + payload["return_token_ids"] = True + # print("use_token_ids:", payload) + + # 超参由yaml传入 + payload.update(request_func_input.hyper_parameters) + + # tools信息,yaml优先级最高 + json_data = request_func_input.json_data or {} + hyper = request_func_input.hyper_parameters or {} + + tools = None + tool_choice = None + + if hyper.get("tools"): + tools = hyper.get("tools") + tool_choice = hyper.get("tool_choice", "auto") + elif json_data.get("tools"): + tools = json_data.get("tools") + tool_choice = json_data.get("tool_choice", "auto") + + if tools: + payload["tools"] = tools + payload["tool_choice"] = tool_choice + + # 随机输入开关 + if request_func_input.random_flag: + payload["max_tokens"] = request_func_input.output_len + metadata = payload.get("metadata", {}) + metadata["min_tokens"] = request_func_input.output_len + payload["metadata"] = metadata + + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + output.prompt_len = 0 + output.no = request_func_input.no + payload["no"] = request_func_input.no + if request_func_input.debug: + print(f"payload:{json.dumps(payload, ensure_ascii=False)}") + metrics_list = [] + request_id = "None" + + ttft = 0.0 + res_ttft = 0.0 + st = 
time.perf_counter() + most_recent_timestamp = st + token_timestamps = [] + tool_call_buffer = {} + try: + async with session.post(url=api_url, json=payload, headers=headers, read_bufsize=10 * 1024 * 1024) as response: + data = {} + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") + if chunk != "[DONE]": + # print("####chunk:", chunk, type(chunk)) + timestamp = time.perf_counter() + data = json.loads(chunk) + # print("####data:", json.dumps(data, indent=2, ensure_ascii=False)) + + if "metrics" in data: + metrics_list.append(data["metrics"]) + + if request_id == "None" and "id" in data: + request_id = data["id"] + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + reason_content = choices[0]["delta"].get("reasoning_content") + tool_calls = choices[0]["delta"].get("tool_calls") + completion_token_ids = choices[0]["delta"].get("completion_token_ids", []) + if tool_calls: + for tc in tool_calls: + idx = tc.get("index", 0) + + if idx not in tool_call_buffer: + tool_call_buffer[idx] = { + "id": tc.get("id"), + "name": "", + "arguments": "", + } + + func = tc.get("function", {}) + + if "name" in func: + tool_call_buffer[idx]["name"] = func["name"] + + if "arguments" in func: + tool_call_buffer[idx]["arguments"] += func["arguments"] + + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + # cached_tokens + if data["usage"] and data["usage"].get("prompt_tokens_details", {}): + output.prompt_len = ( + data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0) + ) + else: + output.prompt_len = 0 + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + # response首token + if res_ttft == 0.0: + if content: + res_ttft = choices[0].get("arrival_time", timestamp) + output.res_ttft = res_ttft + usage = data.get("usage") or {} + 
output.reasoning_tokens = max(usage.get("completion_tokens", 0) - 1, 0) + + output.generated_text += content or "" + output.reasoning_content += reason_content or "" + if completion_token_ids: + output.output_ids.extend(completion_token_ids) + # print(f"####content:{data}") + output.arrival_time.append(choices[0].get("arrival_time", timestamp)) + elif usage := data.get("usage", {}): + output.output_tokens = usage.get("completion_tokens", 0) + output.prompt_tokens = usage.get("prompt_tokens", 0) + if output.prompt_len == 0: + if data["usage"] and data["usage"].get("prompt_tokens_details", {}): + output.prompt_len = ( + data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0) + ) - if request_func_input.ignore_eos: - payload["ignore_eos"] = request_func_input.ignore_eos + most_recent_timestamp = timestamp + token_timestamps.append(time.time()) - print(f"payload:{json.dumps(payload, ensure_ascii=False)}") + # output.generated_text = generated_text + # 在流式结束时,记录最后一个 chunk 收到的时间戳 + output.end_timestamp = most_recent_timestamp - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - } + if tool_call_buffer: + for _, tc in tool_call_buffer.items(): + try: + args = json.loads(tc["arguments"]) if tc["arguments"] else {} + except: + args = {} - output = RequestFuncOutput() - output.prompt_len = 0 - output.no = request_func_input.no + output.tool_calls.append({"id": tc["id"], "name": tc["name"], "arguments": args}) - ttft = 0.0 - st = time.perf_counter() - most_recent_timestamp = st - try: - async with session.post(url=api_url, json=payload, headers=headers) as response: - if response.status == 200: - async for chunk_bytes in response.content: - chunk_bytes = chunk_bytes.strip() - if not chunk_bytes: - continue + # 新增metrics统计,计算首token过滤空包 + output.metrics = metrics_summary(metrics_list, token_timestamps[1:]) - chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") - if chunk != "[DONE]": - # 
print("####chunk:", chunk, type(chunk)) - timestamp = time.perf_counter() - data = json.loads(chunk) + has_text = output.generated_text.strip() or output.reasoning_content.strip() + has_tool = getattr(output, "tool_calls", None) - if choices := data.get("choices"): - content = choices[0]["delta"].get("content") - reason_content = choices[0]["delta"].get("reasoning_content") - # First token - if ttft == 0.0: - ttft = timestamp - st - output.ttft = ttft - # cached_tokens - output.prompt_len = ( - data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0) - ) + # 兼容思考内容超长截断的情况,此时回复内容为空 + if not has_text and not has_tool: + output.success = False + output.reasoning_tokens = output.output_tokens + output.error = "No generated text found!" + else: + output.success = True + output.latency = most_recent_timestamp - st + else: + error_text = await response.text() + print( + "####error response:", + error_text, + "####payload:", + payload, + ) + output.error = error_text or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + finally: + if own_session: + await session.close() + + output.request_id = request_id + + # 保存失败请求结果 + if not output.success or output.output_tokens == 0: + with open("error_output.txt", "a") as f: + f.write(str(output) + "\n") + if pbar: + pbar.update(1) + if request_func_input.debug: + print("#####final_output:", output) + return output - # Decoding phase - else: - output.itl.append(timestamp - most_recent_timestamp) - output.generated_text += content or "" - output.reasoning_content += reason_content or "" - output.arrival_time.append(choices[0].get("arrival_time", timestamp)) - elif usage := data.get("usage", {}): - output.output_tokens = usage.get("completion_tokens", 0) - output.prompt_tokens = usage.get("prompt_tokens", 0) +async def simple_tool_call(model_output, tool_url: str, timeout=60): + """调用工具函数""" + import re - 
most_recent_timestamp = timestamp + import httpx - # output.generated_text = generated_text - if output.generated_text.strip() == "": - output.success = False - output.error = "No generated text found!" + tool_id = None + + if getattr(model_output, "tool_calls", None): + tc = model_output.tool_calls[0] + tool_name = tc["name"] + args = tc.get("arguments", {}) + tool_id = tc.get("id") + else: + match = re.search(r"(.*?)", model_output.generated_text, re.S) + if not match: + return "", False, "", tool_id + + block = match.group(1).strip() + lines = block.splitlines() + tool_name = lines[0].strip() + + key = re.search(r"(.*?)", block) + val = re.search(r"(.*?)", block) + + args = {key.group(1): val.group(1)} if key and val else {} + + if not tool_name: + return "", False, "", tool_id + + headers = {"Content-Type": "application/json"} + + try: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + tool_url, + headers=headers, + json={"tool_name": tool_name, "arguments": args}, + ) + + resp.raise_for_status() + obj = resp.json() + + return obj.get("result", resp.text), "result" in obj, tool_name, tool_id + + except Exception as e: + print(f"[TOOL ERROR] {tool_name}: {repr(e)}") + return str(e), False, tool_name, tool_id + + +async def async_request_eb_openai_chat_completions_multi_turn( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +): + # yaml中或数据集中带tools才走工具调用逻辑 + json_data = request_func_input.json_data or {} + hyper = request_func_input.hyper_parameters or {} + enable_tools = bool(json_data.get("tools") or hyper.get("tools")) + + outputs = [] + + tool_call_count = 0 + llm_time = 0.0 + tool_time = 0.0 + input_tokens = 0 + output_tokens = 0 + + ori_history = request_func_input.history_QA + user_count = sum(msg.get("role") == "user" for msg in ori_history) + print("START", request_func_input.no, "user对话轮数:", user_count, flush=True) + history = [] + prompt_no = 0 + max_prompt_len = ( + 
hyper.get("max_prompt_len") if hyper.get("max_prompt_len") is not None else json_data.get("max_prompt_len") + ) + print("max_prompt_len:", max_prompt_len) + input_ids_all = [] + # FD每轮 completion_token_ids + output_ids = [] + use_token_ids = bool(request_func_input.tokenizer_model and request_func_input.tokenizer_path) + tokenizer = None + + if use_token_ids: + print("token ids 拼接模式") + enable_tools = False + print("tokenizer_model:", request_func_input.tokenizer_model) + print("tokenizer_path:", request_func_input.tokenizer_path) + tokenizer = load_tokenizer( + request_func_input.tokenizer_model, + request_func_input.tokenizer_path, + ) + else: + print("messages 明文拼接模式") + + # 只创建一次 session + session_start = time.perf_counter() + connector = aiohttp.TCPConnector( + limit=0, + limit_per_host=0, + keepalive_timeout=60, + ) + + async with aiohttp.ClientSession( + connector=connector, + trust_env=True, + read_bufsize=10 * 1024 * 1024, + timeout=AIOHTTP_TIMEOUT, + ) as session: + for i, message in enumerate(ori_history): + if message["role"] == "user" or message["role"] == "tool": + history.append(message) + round_input = copy.deepcopy(request_func_input) + round_input.history_QA = history + round_input.no = f"{round_input.no}_{prompt_no}" + if use_token_ids: + if len(input_ids_all) == 0: + # 拼接token_ids模式,首轮token_ids + spliced_text = tokenizer.apply_chat_template( + history, + tokenize=False, + split_special_tokens=False, + add_special_tokens=False, + ) + # 转换为token ids + tokens = tokenizer.tokenize(spliced_text) + prompt_token_ids = tokenizer.convert_tokens_to_ids(tokens) + input_ids_all.extend(prompt_token_ids) + round_input.prompt_token_ids = input_ids_all else: - output.success = True - output.latency = most_recent_timestamp - st + prompt_length = len(input_ids_all) + len(output_ids) + if max_prompt_len and prompt_length >= max_prompt_len: + # 超长截断 + print( + f"[SESSION STOP] {round_input.no} reach max_prompt_len={max_prompt_len}, stop session" + ) + break + # 
拼接token_ids模式,后续轮 + input_ids_all.extend(output_ids) + user_prompt = message["content"] + # 拼接user_prompt + if round_input.tokenizer_model == "eb5": + # EB5模型 + user_prompt = ( + f"\n\n<|im_start|>user\n{user_prompt}<|im_end|>\n\n<|im_start|>assistant\n\n" + ) + else: + # 0.3B模型,2 ,拼接时会被替换成100272 <|end_of_sentence|> + input_ids_all[-1] = 100272 + user_prompt = f"User: {user_prompt}\nAssistant: " + prompt_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(user_prompt)) + input_ids_all.extend(prompt_token_ids) + round_input.prompt_token_ids = input_ids_all + # 复用 session + s0 = time.perf_counter() + output = await async_request_eb_openai_chat_completions( + round_input, + pbar=None, + session=session, + ) + s1 = time.perf_counter() + llm_time += s1 - s0 + + outputs.append(output) + + if not output.success: + session_end = time.perf_counter() + metrics = SessionMetrics( + session_no=request_func_input.no, + session_e2e_time=session_end - session_start, + pure_llm_time=llm_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + tool_calls=tool_call_count, + ) + return outputs, metrics + + # llm_cost = s1 - s0 + input_tokens += output.prompt_tokens + output_tokens += output.output_tokens + + # 更新output_ids + output_ids = output.output_ids + + if max_prompt_len and input_tokens >= max_prompt_len: + # 后验超长截断 + print(f"[SESSION STOP] {round_input.no} reach max_prompt_len={max_prompt_len}, stop session") + break + + if enable_tools: + # 循环调用工具 + max_loop = json_data.get("max_loop", 10) + tool_url = json_data.get("tool_url", "") + max_prompt_len = json_data.get("max_prompt_len") + if not tool_url: + raise ValueError("tool_url is empty.") + for _ in range(max_loop): + t0 = time.perf_counter() + tool_result, is_tool_result, tool_name, tool_id = await simple_tool_call( + output, + tool_url, + ) + t1 = time.perf_counter() + tool_time += t1 - t0 + # print(f"#### tool_time: {t1 - t0:.3f}") + # print(f"#### tool_result: {tool_result}") + # print(f"#### 
is_tool_result: {is_tool_result}") + + # 工具调用失败 + if tool_name and not is_tool_result: + print(f"[SESSION FAIL] tool call failed: {tool_name}") + + output.success = False + + session_end = time.perf_counter() + session_e2e_time = session_end - session_start + tool_call_count += 1 + + metrics = SessionMetrics( + session_no=request_func_input.no, + session_e2e_time=session_e2e_time, + pure_llm_time=llm_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + tool_calls=tool_call_count, + ) + + return outputs, metrics + + if not is_tool_result: + history.append( + { + "role": "assistant", + "content": output.generated_text, + } + ) + break + + assistant_msg = { + "role": "assistant", + "content": output.generated_text, + } + + if getattr(output, "tool_calls", None): + assistant_msg["tool_calls"] = [ + { + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": json.dumps(tc["arguments"], ensure_ascii=False), + }, + } + for tc in output.tool_calls + ] + + history.append(assistant_msg) + + history.append( + { + "role": "tool", + "content": json.dumps(tool_result, ensure_ascii=False), + "tool_call_id": tool_id or tool_name, + } + ) + tool_call_count += 1 + + round_input.history_QA = history + + s0 = time.perf_counter() + output = await async_request_eb_openai_chat_completions( + round_input, + pbar=None, + session=session, + ) + s1 = time.perf_counter() + llm_time += s1 - s0 + + outputs.append(output) + + if not output.success: + session_end = time.perf_counter() + metrics = SessionMetrics( + session_no=request_func_input.no, + session_e2e_time=session_end - session_start, + pure_llm_time=llm_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + tool_calls=tool_call_count, + ) + return outputs, metrics + + input_tokens += output.prompt_tokens + output_tokens += output.output_tokens + # 若session输入长度超过max_prompt_len,则停止session + if max_prompt_len and input_tokens >= max_prompt_len: + print( + f"[SESSION STOP] 
{round_input.no} reach max_prompt_len={max_prompt_len}, stop session" + ) + session_end = time.perf_counter() + metrics = SessionMetrics( + session_no=request_func_input.no, + session_e2e_time=session_end - session_start, + pure_llm_time=llm_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + tool_calls=tool_call_count, + ) + return outputs, metrics + else: + print(f"Warning {prompt_no} exceed max_loop={max_loop}, force stop tool loop") + else: - error_text = await response.text() - print( - "####error response:", - error_text, - "####payload:", - payload, + # 无tools + history.append( + { + "role": "assistant", + "content": output.generated_text, + } ) - output.error = error_text or "" - output.success = False - except Exception: - output.success = False - exc_info = sys.exc_info() - output.error = "".join(traceback.format_exception(*exc_info)) - # 保存失败请求结果 - if not output.success: - with open("error_output.txt", "a") as f: - f.write(str(output) + "\n") + prompt_no += 1 + elif message["role"] == "assistant": + continue + else: + history.append(message) + + session_end = time.perf_counter() + session_e2e_time = session_end - session_start + if pbar: pbar.update(1) - print("#####final_output:", output) - return output + + metrics = SessionMetrics( + session_no=request_func_input.no, + session_e2e_time=session_e2e_time, + pure_llm_time=llm_time, + input_tokens=input_tokens, + output_tokens=output_tokens, + tool_calls=tool_call_count, + ) + + return outputs, metrics async def async_request_eb_openai_completions( @@ -193,7 +828,9 @@ async def async_request_eb_openai_completions( ("completions", "profile") ), "OpenAI Completions API URL must end with 'completions' or 'profile'." 
- async with aiohttp.ClientSession(trust_env=True, timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, read_bufsize=10 * 1024 * 1024, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { "model": request_func_input.model, "prompt": request_func_input.prompt, @@ -209,7 +846,8 @@ async def async_request_eb_openai_completions( if request_func_input.ignore_eos: payload["ignore_eos"] = request_func_input.ignore_eos - print("payload:", json.dumps(payload, ensure_ascii=False)) + if request_func_input.debug: + print("payload:", json.dumps(payload, ensure_ascii=False)) headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", @@ -288,7 +926,8 @@ async def async_request_eb_openai_completions( exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) - print(f"final_output:{output}") + if request_func_input.debug: + print(f"final_output:{output}") if pbar: pbar.update(1) @@ -680,6 +1319,7 @@ def to_bytes(y, sr): "deepspeed-mii": async_request_deepspeed_mii, "openai": async_request_eb_openai_completions, "openai-chat": async_request_eb_openai_chat_completions, + "openai-chat-multi-turn": async_request_eb_openai_chat_completions_multi_turn, "openai-audio": async_request_openai_audio, "tensorrt-llm": async_request_trt_llm, "scalellm": async_request_openai_completions, diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 551f0c9d52b..ab7c8deb3ee 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -45,6 +45,8 @@ class SampleRequest: json_data: Optional[dict] prompt_len: int expected_output_len: int + response_format: Optional[dict] = None + random_flag: bool = False class BenchmarkDataset(ABC): @@ -57,6 +59,7 @@ def __init__( self, dataset_path: Optional[str] = None, random_seed: int = DEFAULT_SEED, + shuffle: bool = False, hyperparameter_path: Optional[str] = None, ) -> None: """ @@ -72,6 +75,7 @@ def __init__( # default seed. 
self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED self.data = None + self.shuffle = shuffle self.hyperparameter_path = hyperparameter_path self.hyperparameters = {} @@ -211,6 +215,10 @@ def load_data(self) -> None: with open(self.dataset_path, encoding="utf-8") as f: self.data = [json.loads(i.strip()) for i in f.readlines()] + if self.shuffle: + random.seed(self.random_seed) + random.shuffle(self.data) + def sample( self, num_requests: int, @@ -225,20 +233,23 @@ def sample( for entry in self.data: if len(samples) >= num_requests: break + json_data = entry + prompt = entry["text"] - self.temperature = float(entry["temperature"]) - self.repetition_penalty = float(entry["penalty_score"]) - self.frequency_penalty = float(entry["frequency_score"]) - self.presence_penalty = float(entry["presence_score"]) - self.top_p = float(entry["topp"]) - self.prompt_len = int(entry["input_token_num"]) - new_output_len = int(entry["max_dec_len"]) + self.temperature = float(entry.get("temperature", 1)) + self.repetition_penalty = float(entry.get("penalty_score", 0)) + self.frequency_penalty = float(entry.get("frequency_score", 0)) + self.presence_penalty = float(entry.get("presence_score", 0)) + self.top_p = float(entry.get("topp", 1)) + self.prompt_len = int(entry.get("input_token_num", 0)) + new_output_len = int(entry.get("max_dec_len", 0)) if enable_multimodal_chat: prompt = self.apply_multimodal_chat_transformation(prompt, None) samples.append( SampleRequest( no=cnt, + json_data=json_data, prompt=prompt, prompt_len=self.prompt_len, history_QA=[], @@ -270,6 +281,10 @@ def load_data(self) -> None: with open(self.dataset_path, encoding="utf-8") as f: self.data = [json.loads(i.strip()) for i in f.readlines()] + if self.shuffle: + random.seed(self.random_seed) + random.shuffle(self.data) + def sample( self, num_requests: int, @@ -287,7 +302,8 @@ def sample( json_data = entry prompt = entry["messages"][-1].get("content", "") history_QA = 
entry.get("messages", []) - new_output_len = int(entry.get("max_tokens", 12288)) + response_format = entry.get("response_format") + new_output_len = int(entry.get("max_tokens", output_len if output_len else 12288)) if enable_multimodal_chat: prompt = self.apply_multimodal_chat_transformation(prompt, None) @@ -299,9 +315,506 @@ def sample( prompt_len=0, history_QA=history_QA, expected_output_len=new_output_len, + response_format=response_format, ) ) cnt += 1 self.maybe_oversample_requests(samples, num_requests) return samples + + +class RandomTextDataset(BenchmarkDataset): + """ + Generates random English words for pure text benchmarking. + """ + + # Common English words vocabulary + COMMON_WORDS = [ + "the", + "be", + "to", + "of", + "and", + "a", + "in", + "that", + "have", + "i", + "it", + "for", + "not", + "on", + "with", + "he", + "as", + "you", + "do", + "at", + "this", + "but", + "his", + "by", + "from", + "they", + "we", + "say", + "her", + "she", + "or", + "an", + "will", + "my", + "one", + "all", + "would", + "there", + "their", + "what", + "so", + "up", + "out", + "if", + "about", + "who", + "get", + "which", + "go", + "me", + "when", + "make", + "can", + "like", + "time", + "no", + "just", + "him", + "know", + "take", + "people", + "into", + "year", + "your", + "good", + "some", + "could", + "them", + "see", + "other", + "than", + "then", + "now", + "look", + "only", + "come", + "its", + "over", + "think", + "also", + "back", + "after", + "use", + "two", + "how", + "our", + "work", + "first", + "well", + "way", + "even", + "new", + "want", + "because", + "any", + "these", + "give", + "day", + "most", + "us", + "is", + "are", + "was", + "were", + "been", + "has", + "had", + "did", + "done", + "said", + "told", + "asked", + "thought", + "went", + "saw", + "looked", + "found", + "took", + "gave", + "made", + "put", + "set", + "got", + "ran", + "came", + "walked", + "stood", + "sat", + "lay", + "felt", + "heard", + "saw", + "knew", + "thought", + 
"understood", + "believed", + "wanted", + "needed", + "liked", + "loved", + "hated", + "feared", + "hoped", + "expected", + "planned", + "decided", + "agreed", + "disagreed", + "argued", + "discussed", + "explained", + "described", + "reported", + "announced", + "declared", + "stated", + "claimed", + "suggested", + "proposed", + "recommended", + "advised", + "warned", + "threatened", + "promised", + "offered", + "refused", + "denied", + "admitted", + "confessed", + "apologized", + "forgave", + "thanked", + "congratulated", + "celebrated", + "welcomed", + "greeted", + "introduced", + "presented", + "showed", + "demonstrated", + "proved", + "tested", + "examined", + "studied", + "learned", + "taught", + "trained", + "practiced", + "performed", + "played", + "worked", + "built", + "created", + "designed", + "developed", + "improved", + "changed", + "fixed", + "solved", + "completed", + "finished", + "started", + "began", + "continued", + "stopped", + "ended", + "left", + "arrived", + "departed", + "traveled", + "moved", + "stayed", + "waited", + "rested", + "slept", + "woke", + "ate", + "drank", + "cooked", + "cleaned", + "washed", + "dressed", + "undressed", + "showered", + "bathed", + "brushed", + "combed", + "shaved", + "cut", + "trimmed", + "painted", + "drew", + "wrote", + "read", + "spoke", + "listened", + "heard", + "saw", + "watched", + "looked", + "observed", + "noticed", + "recognized", + "remembered", + "forgot", + "learned", + "understood", + "knew", + "believed", + "doubted", + "wondered", + "thought", + "considered", + "decided", + "chose", + "selected", + "preferred", + "liked", + "loved", + "hated", + "feared", + "worried", + "hoped", + "expected", + "planned", + "prepared", + "organized", + "arranged", + "scheduled", + "timed", + "measured", + "counted", + "calculated", + "estimated", + "valued", + "priced", + "cost", + "paid", + "bought", + "sold", + "traded", + "exchanged", + "shared", + "divided", + "combined", + "joined", + "connected", + 
"attached", + "separated", + "divided", + "cut", + "broke", + "fixed", + "repaired", + "built", + "created", + "made", + "produced", + "manufactured", + "assembled", + "constructed", + "designed", + "planned", + "developed", + "improved", + "enhanced", + "changed", + "modified", + "adjusted", + "adapted", + "converted", + "transformed", + "turned", + "became", + "grew", + "developed", + "evolved", + "progressed", + "advanced", + "moved", + "went", + "came", + "arrived", + "departed", + "left", + "returned", + "went back", + "came back", + "arrived back", + "departed again", + "left again", + "returned again", + "went away", + "came close", + "moved away", + "approached", + "reached", + "arrived at", + "departed from", + "left from", + "returned to", + "went to", + "came from", + "traveled to", + "traveled from", + "moved to", + "moved from", + "stayed at", + "remained at", + "waited for", + "rested at", + "slept at", + "woke up at", + "ate at", + "drank at", + "cooked at", + "cleaned at", + "washed at", + "dressed at", + "undressed at", + "showered at", + "bathed at", + "brushed at", + "combed at", + "shaved at", + "cut at", + "trimmed at", + "painted at", + "drew at", + "wrote at", + "read at", + "spoke at", + "listened at", + "heard at", + "saw at", + "watched at", + "looked at", + "observed at", + "noticed at", + "recognized at", + "remembered at", + "forgot at", + "learned at", + "understood at", + "knew at", + "believed at", + "doubted at", + "wondered at", + "thought at", + "considered at", + "decided at", + "chose at", + "selected at", + "preferred at", + "liked at", + "loved at", + "hated at", + "feared at", + "worried at", + "hoped at", + "expected at", + "planned at", + "prepared at", + "organized at", + "arranged at", + "scheduled at", + "timed at", + "measured at", + "counted at", + "calculated at", + "estimated at", + "valued at", + "priced at", + "cost at", + "paid at", + "bought at", + "sold at", + "traded at", + "exchanged at", + "shared at", + 
"divided at", + "combined at", + "joined at", + "connected at", + "attached at", + "separated at", + "divided at", + "cut at", + "broke at", + "fixed at", + "repaired at", + "built at", + "created at", + "made at", + "produced at", + "manufactured at", + ] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def sample( + self, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + random_input_len: Optional[int] = None, + random_output_len: Optional[int] = None, + random_range_ratio: Optional[float] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples = [] + + def sample_len(base_len: int, ratio: float) -> int: + if base_len is None: + return None + if ratio is None or ratio <= 0: + return base_len + lo = max(1, int(base_len * (1 - ratio))) + hi = int(base_len * (1 + ratio)) + return random.randint(lo, hi) + + for i in range(1, num_requests + 1): + # [length * (1 - range_ratio), length * (1 + range_ratio)] + sampled_input_len = sample_len(random_input_len, random_range_ratio) + sampled_output_len = sample_len(random_output_len, random_range_ratio) + + words = [random.choice(self.COMMON_WORDS) for _ in range(sampled_input_len)] + prompt_text = " ".join(words) + + data = { + "messages": [{"role": "user", "content": prompt_text}], + } + + samples.append( + SampleRequest( + no=i, + json_data=data, + prompt=prompt_text, + prompt_len=sampled_input_len, + history_QA=data["messages"], + expected_output_len=sampled_output_len, + random_flag=True, + ) + ) + return samples diff --git a/benchmarks/benchmark_fmq.py b/benchmarks/benchmark_fmq.py new file mode 100644 index 00000000000..3878f790cc6 --- /dev/null +++ b/benchmarks/benchmark_fmq.py @@ -0,0 +1,233 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import multiprocessing as mp +import os +import statistics +import time + +from tqdm import tqdm + +from fastdeploy.inter_communicator.fmq import FMQ + + +# ============================================================ +# Producer Task +# ============================================================ +async def producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q): + fmq = FMQ() + q = fmq.queue("mp_bench_latency", role="producer") + payload = b"x" * payload_size + + # tqdm 进度条 + pbar = tqdm(total=msg_count, desc=f"Producer-{proc_id}", position=proc_id, leave=True, disable=False) + + t0 = time.perf_counter() + for i in range(msg_count): + send_ts = time.perf_counter() + await q.put(data={"pid": proc_id, "i": i, "send_ts": send_ts, "payload": payload}, shm_threshold=shm_threshold) + pbar.update(1) + # pbar.write(f"send {i}") + t1 = time.perf_counter() + result_q.put({"producer_id": proc_id, "count": msg_count, "time": t1 - t0}) + + pbar.close() + + # wait for 2 seconds before closing + await asyncio.sleep(5) + + +def producer_process(proc_id, msg_count, payload_size, shm_threshold, result_q): + async def run(): + await producer_task(proc_id, msg_count, payload_size, shm_threshold, result_q) + + asyncio.run(run()) + + +# ============================================================ +# Consumer Task +# ============================================================ +async def consumer_task(consumer_id, total_msgs, result_q, consumer_event): + fmq = FMQ() + q = fmq.queue("mp_bench_latency", role="consumer") + 
consumer_event.set() + + latencies = [] + recv = 0 + + # tqdm 显示进度 + pbar = tqdm(total=total_msgs, desc=f"Consumer-{consumer_id}", position=consumer_id + 1, leave=True, disable=False) + + first_recv = None + last_recv = None + + while recv < total_msgs: + msg = await q.get() + recv_ts = time.perf_counter() + if msg is None: + pbar.write("recv None") + continue + if first_recv is None: + first_recv = recv_ts + last_recv = recv_ts + send_ts = msg.payload["send_ts"] + latencies.append((recv_ts - send_ts) * 1000) # ms + pbar.update(1) + recv += 1 + + pbar.close() + + result_q.put( + {"consumer_id": consumer_id, "latencies": latencies, "first_recv": first_recv, "last_recv": last_recv} + ) + + +def consumer_process(consumer_id, total_msgs, result_q, consumer_event): + async def run(): + await consumer_task(consumer_id, total_msgs, result_q, consumer_event) + + asyncio.run(run()) + + +# ============================================================ +# MAIN benchmark +# ============================================================ +def run_benchmark( + NUM_PRODUCERS=1, + NUM_CONSUMERS=1, + NUM_MESSAGES_PER_PRODUCER=1000, + PAYLOAD_SIZE=1 * 1024 * 1024, + SHM_THRESHOLD=1 * 1024 * 1024, +): + total_messages = NUM_PRODUCERS * NUM_MESSAGES_PER_PRODUCER + total_bytes = total_messages * PAYLOAD_SIZE + + print(f"\nFastDeploy Message Queue Benchmark, pid:{os.getpid()}") + print(f"Producers: {NUM_PRODUCERS}") + print(f"Consumers: {NUM_CONSUMERS}") + print(f"Messages per producer: {NUM_MESSAGES_PER_PRODUCER}") + print(f"Total bytes: {total_bytes / 1024 / 1024 / 1024:.2f} GB") + print(f"Total messages: {total_messages:,}") + print(f"Payload per message: {PAYLOAD_SIZE / 1024 / 1024:.2f} MB") + + mp.set_start_method("fork") + manager = mp.Manager() + result_q = manager.Queue() + + # 两个信号事件 + consumer_event = manager.Event() + + procs = [] + + # Start Consumers + msgs_per_consumer = total_messages // NUM_CONSUMERS + for i in range(NUM_CONSUMERS): + p = mp.Process(target=consumer_process, 
args=(i, msgs_per_consumer, result_q, consumer_event)) + procs.append(p) + p.start() + + consumer_event.wait() + + # Start Producers + for i in range(NUM_PRODUCERS): + p = mp.Process( + target=producer_process, args=(i, NUM_MESSAGES_PER_PRODUCER, PAYLOAD_SIZE, SHM_THRESHOLD, result_q) + ) + procs.append(p) + p.start() + + # Join + for p in procs: + p.join() + + # Collect results + producer_stats = [] + consumer_stats = {} + + while not result_q.empty(): + item = result_q.get() + if "producer_id" in item: + producer_stats.append(item) + if "consumer_id" in item: + consumer_stats[item["consumer_id"]] = item + + # Producer stats + print("\nProducer Stats:") + for p in producer_stats: + throughput = p["count"] / p["time"] + bandwidth = (p["count"] * PAYLOAD_SIZE) / (1024**2 * p["time"]) + print( + f"[Producer-{p['producer_id']}] Sent {p['count']:,} msgs " + f"in {p['time']:.3f} s | Throughput: {throughput:,.0f} msg/s | Bandwidth: {bandwidth:.2f} MB/s" + ) + + # Consumer latency stats + print("\nConsumer Latency Stats:") + all_latencies = [] + first_recv_times = [] + last_recv_times = [] + + for cid, data in consumer_stats.items(): + lats = data["latencies"] + if len(lats) == 0: + continue + all_latencies.extend(lats) + first_recv_times.append(data["first_recv"]) + last_recv_times.append(data["last_recv"]) + + avg = statistics.mean(lats) + p50 = statistics.median(lats) + p95 = statistics.quantiles(lats, n=20)[18] + p99 = statistics.quantiles(lats, n=100)[98] + + print( + f"[Consumer-{cid}] msgs={len(lats):5d} | avg={avg:.3f} ms | " + f"P50={p50:.3f} ms | P95={p95:.3f} ms | P99={p99:.3f} ms" + ) + + # Global summary + if first_recv_times and last_recv_times: + total_time = max(last_recv_times) - min(first_recv_times) + global_throughput = total_messages / total_time + global_bandwidth = total_bytes / (1024**2 * total_time) + + if all_latencies: + avg_latency = statistics.mean(all_latencies) + min_latency = min(all_latencies) + max_latency = max(all_latencies) + 
p50_latency = statistics.median(all_latencies) + p95_latency = statistics.quantiles(all_latencies, n=20)[18] + p99_latency = statistics.quantiles(all_latencies, n=100)[98] + else: + avg_latency = min_latency = max_latency = p50_latency = p95_latency = p99_latency = 0.0 + + print("\nGlobal Summary:") + print(f"Total messages : {total_messages:,}") + print(f"Total data : {total_bytes / 1024**2:.2f} MB") + print(f"Total time : {total_time:.3f} s") + print(f"Global throughput: {global_throughput:,.0f} msg/s") + print(f"Global bandwidth : {global_bandwidth:.2f} MB/s") + print( + f"Latency (ms) : avg={avg_latency:.3f} " + f"| min={min_latency:.3f} | max={max_latency:.3f} " + f"| P50={p50_latency:.3f} | P95={p95_latency:.3f} | P99={p99_latency:.3f}\n" + ) + + +# Entry +if __name__ == "__main__": + run_benchmark() diff --git a/benchmarks/benchmark_mtp.py b/benchmarks/benchmark_mtp.py index 2698a553b69..a28cc7b1285 100644 --- a/benchmarks/benchmark_mtp.py +++ b/benchmarks/benchmark_mtp.py @@ -98,7 +98,7 @@ def main(args): raise ValueError("--max_concurrency should be same length as --s_itl_base_model") for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model): - # Wramup + # Warmup print("Starting warmup...") with open(os.devnull, "w") as f: with contextlib.redirect_stdout(f): diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 25825061ada..c7cb9c5806a 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -39,7 +39,7 @@ RequestFuncInput, RequestFuncOutput, ) -from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest +from benchmark_dataset import EBChatDataset, EBDataset, RandomTextDataset, SampleRequest from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm.asyncio import tqdm @@ -104,6 +104,14 @@ class BenchmarkMetrics: median_output_len: float std_output_len: float percentiles_output_len: list[tuple[float, float]] + mean_reasoning_len: float 
+ median_reasoning_len: float + std_reasoning_len: float + percentiles_reasoning_len: list[tuple[float, float]] + mean_res_ttft_ms: float + median_res_ttft_ms: float + std_res_ttft_ms: float + percentiles_res_ttft_ms: list[tuple[float, float]] async def get_request( @@ -150,7 +158,7 @@ async def get_request( def calculate_metrics( - input_requests: list[SampleRequest], + # input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float], @@ -160,6 +168,7 @@ def calculate_metrics( input_lens: list[int] = [] infer_input_lens: list[int] = [] # 推理侧输入token数 actual_output_lens: list[int] = [] + reasoning_output_lens: list[int] = [] total_input = 0 completed = 0 good_completed = 0 @@ -169,6 +178,7 @@ def calculate_metrics( all_tpots: list[float] = [] ttfts: list[float] = [] s_ttfts: list[float] = [] + res_ttfts: list[float] = [] e2els: list[float] = [] s_e2els: list[float] = [] s_decodes: list[float] = [] @@ -177,7 +187,7 @@ def calculate_metrics( output_len = outputs[i].output_tokens if not output_len: - print("no output_len") + print("no output_len", outputs[i]) # We use the tokenizer to count the number of output tokens # for some serving backends instead of looking at # len(outputs[i].itl) since multiple output tokens may be @@ -186,6 +196,7 @@ def calculate_metrics( continue actual_output_lens.append(output_len) + reasoning_output_lens.append(outputs[i].reasoning_tokens) input_lens.append(outputs[i].prompt_len) infer_input_lens.append(outputs[i].prompt_tokens) total_input += outputs[i].prompt_tokens @@ -204,6 +215,7 @@ def calculate_metrics( ttfts.append(outputs[i].ttft) # 推理侧TTFT s_ttfts.append(outputs[i].arrival_time[1]) + res_ttfts.append(outputs[i].res_ttft) e2els.append(outputs[i].latency) # 推理侧整句时延 s_e2els.append(outputs[i].arrival_time[-1]) @@ -296,6 +308,14 @@ def calculate_metrics( std_output_len=np.std(actual_output_lens or 0) * 1, median_output_len=np.median(actual_output_lens or 0) * 1, 
percentiles_output_len=[(p, np.percentile(actual_output_lens or 0, p)) for p in selected_percentiles], + mean_reasoning_len=np.mean(reasoning_output_lens or 0) * 1, + std_reasoning_len=np.std(reasoning_output_lens or 0) * 1, + median_reasoning_len=np.median(reasoning_output_lens or 0) * 1, + percentiles_reasoning_len=[(p, np.percentile(reasoning_output_lens or 0, p)) for p in selected_percentiles], + mean_res_ttft_ms=np.mean(res_ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend + std_res_ttft_ms=np.std(res_ttfts or 0) * 1000, + median_res_ttft_ms=np.median(res_ttfts or 0) * 1000, + percentiles_res_ttft_ms=[(p, np.percentile(res_ttfts or 0, p) * 1000) for p in selected_percentiles], ) return metrics, actual_output_lens @@ -317,10 +337,13 @@ async def benchmark( selected_percentile_metrics: list[str], selected_percentiles: list[float], ignore_eos: bool, + debug: bool, + pd_metrics: bool, goodput_config_dict: dict[str, float], max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], extra_body: Optional[dict], + ip_list: Optional[list[str]] = None, ): """Benchmarks an API endpoint using a given set of sample inputs and returns""" if backend in ASYNC_REQUEST_FUNCS: @@ -329,12 +352,18 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_output_len, test_no = ( + test_prompt, test_output_len, test_no, test_json_data = ( input_requests[0].prompt, input_requests[0].expected_output_len, input_requests[0].no, + input_requests[0].json_data, ) test_history_QA = input_requests[0].history_QA + response_format = input_requests[0].response_format + random_flag = input_requests[0].random_flag + + if len(ip_list) >= 1: + api_url = f"http://{ip_list[0]}{args.endpoint}" test_input = RequestFuncInput( model=model_id, @@ -348,22 +377,32 @@ async def benchmark( output_len=test_output_len, logprobs=logprobs, ignore_eos=ignore_eos, + debug=debug, + 
pd_metrics=pd_metrics, extra_body=extra_body, + response_format=response_format, + random_flag=random_flag, + json_data=test_json_data, + tokenizer_model=args.tokenizer_model, + tokenizer_path=args.tokenizer_path, ) - print("test_input:", test_input) + if not debug: + print("test_input:", test_input) - test_output = await request_func(request_func_input=test_input) + test_output = await request_func(request_func_input=test_input) - print("test_output:", test_output) + if args.multi_turn: + out_list, metrics = test_output + test_output = out_list[0] - if not test_output.success: - raise ValueError( - "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}" - ) - else: - print("Initial test run completed. Starting main benchmark run...") + if not test_output.success: + print("test_output:", test_output, flush=True) + raise ValueError( + f"Initial test run failed - Please make sure that 1. benchmark arguments are correctly specified and 2. the http_proxy and https_proxy are turned off. Error: {test_output.error}" + ) + else: + print("Initial test run completed. Starting main benchmark run...") if lora_modules: # For each input request, choose a LoRA module at random. @@ -381,6 +420,8 @@ async def benchmark( logprobs=logprobs, ignore_eos=ignore_eos, extra_body=extra_body, + response_format=response_format, + random_flag=random_flag, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -394,6 +435,7 @@ async def benchmark( print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Maximum request concurrency: {max_concurrency}") + print(f"Drop ratio: {args.drop_ratio}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -401,45 +443,168 @@ async def benchmark( # and it will simplify the code in limited_request_func. 
# semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None - - async def limited_request_func(request_func_input, pbar): - if semaphore is None: - return await request_func(request_func_input=request_func_input, pbar=pbar) - async with semaphore: - return await request_func(request_func_input=request_func_input, pbar=pbar) - - benchmark_start_time = time.perf_counter() - tasks: list[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate, burstiness): - prompt, output_len, no = ( - request.prompt, - request.expected_output_len, - request.no, - ) - history_QA = request.history_QA - - req_model_id, req_model_name = model_id, model_name - if lora_modules: - req_lora_module = next(lora_modules) - req_model_id, req_model_name = req_lora_module, req_lora_module - - request_func_input = RequestFuncInput( - model=req_model_id, - model_name=req_model_name, - prompt=prompt, - no=no, - prompt_len=0, - history_QA=history_QA, - hyper_parameters=hyper_parameters, - api_url=api_url, - output_len=output_len, - logprobs=logprobs, - ignore_eos=ignore_eos, - extra_body=extra_body, - ) - tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar))) - outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + ip_list = ip_list or [] + + if len(ip_list) <= 1: + if len(ip_list) == 1: + api_url = f"http://{ip_list[0]}{args.endpoint}" + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, pbar=pbar) + + tasks: list[asyncio.Task] = [] + benchmark_start_time = time.perf_counter() + + async for request in 
get_request(input_requests, request_rate, burstiness): + prompt, output_len, no, json_data = ( + request.prompt, + request.expected_output_len, + request.no, + request.json_data, + ) + history_QA = request.history_QA + response_format = request.response_format + random_flag = request.random_flag + + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + no=no, + prompt_len=0, + history_QA=history_QA, + hyper_parameters=hyper_parameters, + api_url=api_url, + output_len=output_len, + logprobs=logprobs, + debug=debug, + pd_metrics=pd_metrics, + ignore_eos=ignore_eos, + extra_body=extra_body, + response_format=response_format, + random_flag=random_flag, + json_data=json_data, + tokenizer_model=args.tokenizer_model, + tokenizer_path=args.tokenizer_path, + ) + tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar))) + + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + else: + # 多ip按DP均分并发 + assert max_concurrency, "multi-IP 模式必须指定 max_concurrency" + n_ip = len(ip_list) + if max_concurrency < n_ip: + print( + f"[WARN] max_concurrency({max_concurrency}) < IP 数({n_ip})," + f"已自动兜底为每个 IP 1 并发," + f"实际总并发将变为 {n_ip}" + ) + concurrency_per_ip = max(1, max_concurrency // n_ip) + concurrency_remainder = max(0, max_concurrency - concurrency_per_ip * n_ip) + + # 分配请求 + req_per_ip = len(input_requests) // n_ip + remainder = len(input_requests) % n_ip + + ip_requests_map = {} + start = 0 + for i, ip in enumerate(ip_list): + count = req_per_ip + (1 if i < remainder else 0) + print(f"IP: {ip}, requests: {count}") + print(f"start: {start}, end: {start + count}") + ip_requests_map[ip] = input_requests[start : start + count] + start += count + + # exit(8) + + semaphores = { + ip: 
asyncio.Semaphore(concurrency_per_ip + (1 if i < concurrency_remainder else 0)) + for i, ip in enumerate(ip_list) + } + + async def limited_request_func_per_ip(req_input, semaphore, pbar): + async with semaphore: + return await request_func(request_func_input=req_input, pbar=pbar) + + tasks = [] + for i, ip in enumerate(ip_list): + print( + f"Starting benchmark for IP: {ip}, " + f"concurrency per IP: {semaphores[ip]._value}, " + f"requests per IP: {len(ip_requests_map[ip])}", + flush=True, + ) + benchmark_start_time = time.perf_counter() + + for i, ip in enumerate(ip_list): + semaphore = semaphores[ip] + + for request in ip_requests_map[ip]: + prompt, output_len, no, json_data = ( + request.prompt, + request.expected_output_len, + request.no, + request.json_data, + ) + history_QA = request.history_QA + + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id = req_model_name = req_lora_module + + req_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + no=no, + prompt_len=0, + history_QA=history_QA, + hyper_parameters=hyper_parameters, + api_url=f"http://{ip}{args.endpoint}", # ★ 多 IP 模式仅替换 host:port + output_len=output_len, + logprobs=logprobs, + ignore_eos=ignore_eos, + debug=debug, + pd_metrics=pd_metrics, + extra_body=extra_body, + response_format=response_format, + random_flag=random_flag, + json_data=json_data, + tokenizer_model=args.tokenizer_model, + tokenizer_path=args.tokenizer_path, + ) + + tasks.append(asyncio.create_task(limited_request_func_per_ip(req_input, semaphore, pbar))) + + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + # 多轮对话需要flatten后统计 + if args.multi_turn: + results = outputs + session_metrics = [] + all_outputs = [] + + for out_list, metrics in results: + session_metrics.append(metrics) + all_outputs.extend(out_list) + + outputs = all_outputs + + print(f"####len session_metrics: {len(session_metrics)}") + 
print(f"####len outputs: {len(outputs)}") + + outputs.sort(key=lambda x: x.end_timestamp) if profile: print("Stopping profiler...") @@ -450,6 +615,8 @@ async def limited_request_func(request_func_input, pbar): api_url=base_url + "/stop_profile", output_len=test_output_len, logprobs=logprobs, + response_format=response_format, + random_flag=random_flag, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -458,12 +625,44 @@ async def limited_request_func(request_func_input, pbar): if pbar is not None: pbar.close() - benchmark_duration = time.perf_counter() - benchmark_start_time - print("benchmark_duration:", benchmark_duration) + benchmark_outputs = outputs + drop_ratio = args.drop_ratio + if 0.0 < drop_ratio < 1: + # 按drop_ratio头尾各舍弃一半请求,不计入benchmark统计 + n = len(outputs) + drop_count = int(n * drop_ratio) + half = drop_count // 2 + if half > 0: + benchmark_outputs = outputs[half : n - half] + + # 先过滤掉 end_timestamp == 0.0 的请求(失败请求) + benchmark_outputs = [o for o in benchmark_outputs if o.end_timestamp != 0.0] + + # 根据收到最后一个chunk的时间戳计算总时长 + if len(benchmark_outputs) >= 2: + benchmark_duration = benchmark_outputs[-1].end_timestamp - benchmark_outputs[0].end_timestamp + else: + benchmark_duration = 0.0 + print(f"丢弃前数量: {n}") + print(f"丢弃后数量: {len(benchmark_outputs)}, 返回结果异常") + exit(8) + + print(f"丢弃前数量: {n}") + print(f"丢弃后数量: {len(benchmark_outputs)}") + print(f"benchmark_duration: {benchmark_duration} 秒") + else: + benchmark_duration = time.perf_counter() - benchmark_start_time + print(f"benchmark_duration: {benchmark_duration} 秒") + + if random_flag: + print("指定随机输入输出长度测试") + print(f"random_input_len: {args.random_input_len}") + print(f"random_output_len: {args.random_output_len}") + print(f"random_range_ratio: {args.random_range_ratio}") metrics, actual_output_lens = calculate_metrics( - input_requests=input_requests, - outputs=outputs, + # input_requests=input_requests, + outputs=benchmark_outputs, 
dur_s=benchmark_duration, # tokenizer=tokenizer, selected_percentiles=selected_percentiles, @@ -489,16 +688,19 @@ async def limited_request_func(request_func_input, pbar): "request_throughput": metrics.request_throughput, "request_goodput:": (metrics.request_goodput if goodput_config_dict else None), "output_throughput": metrics.output_throughput, + "reasoning_lens": [output.reasoning_tokens for output in outputs], "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], "infer_input_lens": [output.prompt_tokens for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], + "res_ttfts": [output.res_ttft for output in outputs], "itls": [output.itl for output in outputs], "input_texts": [input.prompt for input in input_requests], "generated_texts": [output.generated_text for output in outputs], "reasoning_contents": [output.reasoning_content for output in outputs], "errors": [output.error for output in outputs], + "metrics": [output.metrics for output in outputs], } def process_one_metric( @@ -534,6 +736,86 @@ def process_one_metric( print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value + def process_pd_metrics(model_outputs, metric_key, is_time=True): + # 收集所有该 metric 的数值 + values = [] + percentiles = [] + for p in args.metric_percentiles.split(","): + p = p.strip() + if p: + percentiles.append(float(p)) + for item in model_outputs: + metrics = item.metrics + if metrics.get(metric_key, None) is not None: + values.append(metrics[metric_key]) + + if not values: + print(f"[WARN] metric_key '{metric_key}' not found in outputs.") + return + + if is_time: + arr = np.array(values) * 1000 # 秒 -> 毫秒 + suffix = "(ms)" + else: + arr = np.array(values) + suffix = "" + + print("{s:{c}^{n}}".format(s=metric_key, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_key} {suffix}:", + 
np.mean(arr), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_key} {suffix}:", + np.median(arr), + ) + ) + for p in percentiles: + v = np.percentile(arr, p) + print("{:<40} {:<10.2f}".format(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} {suffix}:", v)) + # print(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} (ms): {v:10.2f}") + print( + "{:<40} {:<10.2f}".format( + f"Successful {metric_key}:", + len(arr), + ) + ) + + def print_metric_from_array(values, metric_key, is_time=True): + if not values: + print(f"[WARN] metric_key '{metric_key}' empty.") + return + + percentiles = [float(p.strip()) for p in args.metric_percentiles.split(",") if p.strip()] + + if is_time: + arr = np.array(values) * 1000 + suffix = "(ms)" + else: + arr = np.array(values) + suffix = "" + + print("{s:{c}^{n}}".format(s=metric_key, n=50, c="-")) + + print(f"{f'Mean {metric_key} {suffix}:':<40} {arr.mean():<10.2f}") + print(f"{f'Median {metric_key} {suffix}:':<40} {np.median(arr):<10.2f}") + print(f"{f'Min {metric_key} {suffix}:':<40} {arr.min():<10.2f}") + print(f"{f'Max {metric_key} {suffix}:':<40} {arr.max():<10.2f}") + + for p in percentiles: + v = np.percentile(arr, p) + label = f"P{int(p) if int(p) == p else p}" + print(f"{f'{label} {metric_key} {suffix}:':<40} {v:<10.2f}") + + # print(f"{f'Successful {metric_key}:':<40} {len(arr):<10}") + + def process_session_metrics(session_metrics, attr, metric_key, is_time=True): + values = [getattr(m, attr) for m in session_metrics] + print_metric_from_array(values, metric_key, is_time) + def process_one_length( # E.g., "ttft" metric_attribute_name: str, @@ -570,14 +852,50 @@ def process_one_length( process_one_length("s_decode", "Decode", "解码速度(tok/s)") process_one_metric("ttft", "TTFT", "Time to First Token") process_one_metric("s_ttft", "S_TTFT", "Infer Time to First Token") + process_one_metric("res_ttft", "Response TTFT", "包含思考首token耗时") process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 
1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency") + if any(item.metrics for item in outputs): + process_pd_metrics(outputs, "prefill_cost_time") + process_pd_metrics(outputs, "prefill_prepare_cost_time") + process_pd_metrics(outputs, "preprocess_cost_time") + process_pd_metrics(outputs, "cache_in_scheduler_cost_time") + process_pd_metrics(outputs, "schedule_cost_time") + process_pd_metrics(outputs, "ask_decode_resource_cost_time") + process_pd_metrics(outputs, "prefill_first_token_infer_cost_time") + process_pd_metrics(outputs, "wait_sending_cache_cost_time") + process_pd_metrics(outputs, "decode_preallocate_cost_time") + process_pd_metrics(outputs, "decode_prepare_cost_time") + process_pd_metrics(outputs, "decode_second_token_infer_cost_time") + process_pd_metrics(outputs, "first_token_transmission_cost_time") + process_pd_metrics(outputs, "second_token_transmission_cost_time") + process_pd_metrics(outputs, "mixed_schedule_cost_time") + process_pd_metrics(outputs, "gpu_cache_token_num", is_time=False) + process_pd_metrics(outputs, "cpu_cache_token_num", is_time=False) + process_pd_metrics(outputs, "storage_cache_token_num", is_time=False) + process_pd_metrics(outputs, "cpu_cache_prepare_time") + process_pd_metrics(outputs, "storage_cache_prepare_time") process_one_length("input_len", "Cached Tokens", "Cached Tokens") process_one_length("s_input_len", "Input Length", "Infer Input Length") + process_one_length("reasoning_len", "Reasoning Lenth", "思考长度") process_one_length("output_len", "Output Length", "Output Length") + # 多轮metrcis统计 + if args.multi_turn: + process_session_metrics(session_metrics, "session_e2e_time", "Session E2EL") + process_session_metrics(session_metrics, "pure_llm_time", "Session llm_E2EL") + process_session_metrics(session_metrics, 
"tool_calls", "Tool Calls", is_time=False) + process_session_metrics(session_metrics, "input_tokens", "Session Input Tokens", is_time=False) + process_session_metrics(session_metrics, "output_tokens", "Session Output Tokens", is_time=False) + total_sessions = len(session_metrics) + total_requests = len(outputs) + success_requests = sum(1 for o in outputs if getattr(o, "success", False)) + failed_requests = total_requests - success_requests + print(f"{'Total sessions :':<40} {total_sessions:<10}") + print(f"{'Total requests :':<40} {total_requests:<10}") + print(f"{'Failed requests :':<40} {failed_requests:<10}") print("=" * 50) @@ -607,7 +925,7 @@ def benchmark_metrics( goodput_config_dict = check_goodput_args(args) metrics, actual_output_lens = calculate_metrics( - input_requests=input_requests, + # input_requests=input_requests, outputs=outputs, dur_s=benchmark_duration, selected_percentiles=selected_percentiles, @@ -803,6 +1121,9 @@ def main(args: argparse.Namespace): np.random.seed(args.seed) backend = args.backend + # 支持多轮对话方式请求,仅支持chat接口 + if args.multi_turn: + backend = "openai-chat-multi-turn" model_id = args.model model_name = args.served_model_name tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model @@ -819,14 +1140,22 @@ def main(args: argparse.Namespace): # For datasets that follow a similar structure, use a mapping. 
dataset_mapping = { - "EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample( + "EB": lambda: EBDataset(random_seed=args.seed, dataset_path=args.dataset_path, shuffle=args.shuffle).sample( num_requests=args.num_prompts, output_len=args.sharegpt_output_len, ), - "EBChat": lambda: EBChatDataset(random_seed=args.seed, dataset_path=args.dataset_path).sample( + "EBChat": lambda: EBChatDataset( + random_seed=args.seed, dataset_path=args.dataset_path, shuffle=args.shuffle + ).sample( num_requests=args.num_prompts, output_len=args.sharegpt_output_len, ), + "random": lambda: RandomTextDataset().sample( + num_requests=args.num_prompts, + random_input_len=args.random_input_len, + random_output_len=args.random_output_len, + random_range_ratio=args.random_range_ratio, + ), } try: @@ -866,6 +1195,15 @@ def main(args: argparse.Namespace): else: hyper_parameters = {} + processed_list = [] + for item in args.ip_list: + if "," in item: + processed_list.extend([x.strip() for x in item.split(",") if x.strip()]) + else: + processed_list.append(item) + + ip_list = processed_list + benchmark_result = asyncio.run( benchmark( backend=backend, @@ -883,10 +1221,13 @@ def main(args: argparse.Namespace): selected_percentile_metrics=args.percentile_metrics.split(","), selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, + debug=args.debug, + pd_metrics=args.pd_metrics, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_body=sampling_params, + ip_list=ip_list, ) ) @@ -951,7 +1292,7 @@ def main(args: argparse.Namespace): if args.result_dir: file_name = os.path.join(args.result_dir, file_name) with open(file_name, "w", encoding="utf-8") as outfile: - json.dump(result_json, outfile) + json.dump(result_json, outfile, ensure_ascii=False) save_to_pytorch_benchmark_format(args, result_json, file_name) @@ -960,7 +1301,7 @@ def main(args: 
argparse.Namespace): parser.add_argument( "--backend", type=str, - default="vllm", + default="openai-chat", choices=list(ASYNC_REQUEST_FUNCS.keys()), ) parser.add_argument( @@ -978,18 +1319,25 @@ def main(args: argparse.Namespace): default="/v1/completions", help="API endpoint.", ) + parser.add_argument( + "--ip-list", + nargs="*", + default=[], + help=( + "List of ip:port. " + "Supports: " + "1) --ip-list 127.0.0.1:8000 --ip-list 127.0.0.1:8001 " + "2) --ip-list 127.0.0.1:8000,127.0.0.1:8001" + ), + ) parser.add_argument( "--dataset-name", type=str, - default="sharegpt", + default="EBChat", choices=[ - "sharegpt", - "burstgpt", - "sonnet", - "random", - "hf", "EB", "EBChat", + "random", ], help="Name of the dataset to benchmark on.", ) @@ -1071,6 +1419,39 @@ def main(args: argparse.Namespace): "results in a more uniform arrival of requests.", ) parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle dataset", + ) + parser.add_argument( + "--pd-metrics", + action="store_true", + help="请求时增加PD分离参数,metrics: True", + ) + parser.add_argument( + "--multi-turn", + action="store_true", + help="按多轮对话方式请求", + ) + parser.add_argument( + "--tokenizer-model", + default="auto", + type=str, + help="使用token_ids请求时指定,多轮对话tokenizer模型类型,'eb': ErnieBotTokenizer, 'eb5': Ernie5Tokenizer, 'eb_mm': Ernie4_5Tokenizer", + ) + parser.add_argument( + "--tokenizer-path", + type=str, + default=None, + help="使用token_ids请求时指定,模型tokenizer路径", + ) + parser.add_argument( + "--drop-ratio", + type=float, + default=0.0, + help="Drop ratio of the outputs. 
[0, 1)", + ) parser.add_argument( "--trust-remote-code", action="store_true", @@ -1091,6 +1472,11 @@ def main(args: argparse.Namespace): action="store_true", help="Specify to save benchmark results to a json file", ) + parser.add_argument( + "--debug", + action="store_true", + help="print debug information (output)", + ) parser.add_argument( "--save-detailed", action="store_true", @@ -1130,7 +1516,7 @@ def main(args: argparse.Namespace): parser.add_argument( "--percentile-metrics", type=str, - default="ttft,tpot,itl", + default="ttft,tpot,itl,reasoning_len", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' @@ -1191,37 +1577,24 @@ def main(args: argparse.Namespace): random_group.add_argument( "--random-input-len", type=int, - default=1024, - help="Number of input tokens per request, used only for random sampling.", + default=200, + help="Number of input English words per request, used only for random-text dataset.", ) random_group.add_argument( "--random-output-len", type=int, - default=128, - help="Number of output tokens per request, used only for random sampling.", + default=1024, + help="Number of output tokens per request, used both for random and random-text datasets.", ) random_group.add_argument( "--random-range-ratio", type=float, - default=0.0, + default=0.1, help="Range ratio for sampling input/output length, " "used only for random sampling. Must be in the range [0, 1) to define " "a symmetric sampling range" "[length * (1 - range_ratio), length * (1 + range_ratio)].", ) - random_group.add_argument( - "--random-prefix-len", - type=int, - default=0, - help=( - "Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]." 
- ), - ) hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.") diff --git a/benchmarks/ernie_tokenizer.py b/benchmarks/ernie_tokenizer.py new file mode 100644 index 00000000000..6141bfe5df3 --- /dev/null +++ b/benchmarks/ernie_tokenizer.py @@ -0,0 +1,1247 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# cipher_token=WjI1fQOvhN # do not edit this line + +import os +import re +from itertools import product +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import paddle +import sentencepiece as spm +from paddleformers.transformers import AddedToken, PretrainedTokenizer +from paddleformers.transformers.tokenizer_utils_base import PaddingStrategy, TextInput +from paddleformers.utils.log import logger + + +class ErnieBotTokenizer(PretrainedTokenizer): + """ + 一个更好用的 `ErnieBotToknizer`, + 能 encode 目前 sft/ppo 阶段的特殊token,也支持多模态。 + """ + + resource_files_names = {"vocab_file": "spm.model"} + pretrained_resource_files_map = {"vocab_file": {"ernie-bot-10b": None}} + pretrained_init_configuration = {"ernie-bot-10b": {}} + model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"] + padding_side = "right" + + def __init__( + self, + vocab_file, + bos_token="", + cls_token="", + eos_token="", + mask_token="", + pad_token="", + 
sep_token="", + unk_token="", + additional_special_tokens=None, + verbose=False, + **kwargs, + ): + """doc""" + if additional_special_tokens is None: + additional_special_tokens = ["", ""] + super().__init__( + bos_token=bos_token, + cls_token=cls_token, + eos_token=eos_token, + mask_token=mask_token, + pad_token=pad_token, + sep_token=sep_token, + unk_token=unk_token, + additional_special_tokens=additional_special_tokens, + verbose=False, + **kwargs, + ) + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + # pre-process map-type all spec token for decode accelerate. + + @property + def space_token(self): + """doc""" + return "" + + @property + def space_token_id(self): + """doc""" + return self.sp_model.piece_to_id("") + + @property + def gend_token(self): + """doc""" + return "" + + @property + def gend_token_id(self): + """doc""" + return self.sp_model.piece_to_id("") + + @property + def im_start_id(self): + """doc""" + return self.sp_model.piece_to_id("<|im_start|>") + + @property + def im_end_id(self): + """doc""" + return self.sp_model.piece_to_id("<|im_end|>") + + @property + def vocab_size(self): + """doc""" + return self.sp_model.vocab_size() + + def get_vocab(self): + """doc""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """doc""" + return self.sp_model.encode_as_pieces(text) + + def _convert_token_to_id(self, token): + """doc""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, id): + """doc""" + return self.sp_model.id_to_piece(id) + + def spec_init(self): + """初始化special tokens""" + if not hasattr(self, "all_spec_tok"): + self.all_spec_tok = set(self.all_special_tokens) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + self.spec_init() + current_sub_tokens = [] + out_string = "" + 
# NOTE(review): the methods below belong to the sentencepiece tokenizer class
# whose `class` statement appears earlier in the file; they are listed at
# method granularity because the hunk cuts the class body.

def convert_tokens_to_string(self, tokens):
    """Convert a sequence of token pieces into a single string."""
    # NOTE(review): the method header is reconstructed from the byte-identical
    # Ernie45Tokenizer duplicate later in this file; the hunk starts mid-method.
    self.spec_init()
    current_sub_tokens = []
    out_string = ""
    for token in tokens:
        # Special tokens are appended verbatim, never decoded by sentencepiece.
        if token in self.all_spec_tok:
            out_string += self.sp_model.decode(current_sub_tokens) + token
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    out_string += self.sp_model.decode(current_sub_tokens)
    return out_string

def prepare_for_model(self, *args, **kwargs):
    """Drop the unsupported `add_special_tokens` flag, then defer to the base class."""
    # This tokenizer inserts its special tokens itself; the flag is silently ignored.
    kwargs.pop("add_special_tokens", None)
    return super().prepare_for_model(*args, **kwargs)

def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Save the vocabulary and special tokens file to a directory.

    Args:
        save_directory (`str`): The directory in which to save the vocabulary.
        filename_prefix (`str`, *optional*): Prefix prepended to the file name.

    Returns:
        `Tuple(str)`: Paths to the files saved, or None when `save_directory`
        is not a directory (an error is logged in that case).
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return
    out_vocab_file = os.path.join(
        save_directory,
        (filename_prefix + "-" if filename_prefix else "") + self.resource_files_names["vocab_file"],
    )
    if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
        # Tokenizer was loaded from a file on disk: copy it as-is.
        copyfile(self.vocab_file, out_vocab_file)
    elif not os.path.isfile(self.vocab_file):
        # No backing file: serialize the in-memory sentencepiece model instead.
        with open(out_vocab_file, "wb") as fi:
            fi.write(self.sp_model.serialized_model_proto())
    return (out_vocab_file,)

def tokenize(self, text, **kwargs) -> List[str]:
    """
    Convert a string into a sequence of tokens, taking care of added/special
    tokens (BPE/SentencePiece/WordPiece sub-words for the rest).

    Args:
        text (`str`): The sequence to be encoded.
        **kwargs: Passed along to `prepare_for_tokenization`.

    Returns:
        `List[str]`: The list of tokens.
    """
    self.spec_init()
    text, kwargs = self.prepare_for_tokenization(text, **kwargs)

    # TODO: should this be in the base class?
    if hasattr(self, "do_lower_case") and self.do_lower_case:
        # Lowercase everything except the special tokens themselves.
        # BUGFIX: `unique_no_split_tokens` is a list while `all_spec_tok` is a
        # set; `list + set` raises TypeError, so both are materialized as lists.
        escaped_special_toks = [
            re.escape(s_tok) for s_tok in (list(self.unique_no_split_tokens) + list(self.all_spec_tok))
        ]
        pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
        text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

    no_split_token = set(self.unique_no_split_tokens)
    tokens = self.tokens_trie.split(text)

    tokenized_text = []
    for token in tokens:
        # Skip eventual empty (fully stripped) fragments.
        if not token:
            continue
        if token in no_split_token:
            tokenized_text.append(token)
        else:
            tokenized_text.extend(self._tokenize(token))
    return tokenized_text

def _decode(self, *args, **kwargs):
    """Decode while forcing raw (non-cleaned) output regardless of caller flags."""
    kwargs.pop("clean_up_tokenization_spaces", None)
    kwargs.pop("spaces_between_special_tokens", None)
    return super()._decode(
        *args, **kwargs, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False
    )
def _pad(
    self,
    encoded_inputs: Dict,
    max_length: Optional[int] = None,
    padding_strategy=PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
) -> dict:
    """Pad `encoded_inputs`, maintaining an attention mask alongside.

    When no mask is supplied a causal (lower-triangular) one is built.
    The mask is padded here and re-attached after the base-class padding pass.
    """
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names

    if return_attention_mask:
        required_input = encoded_inputs[self.model_input_names[0]]
        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            # Round max_length up to the next multiple.
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        if "attention_mask" in encoded_inputs and encoded_inputs["attention_mask"] is not None:
            # Normalize the caller-provided mask to a numpy array.
            mask = encoded_inputs.pop("attention_mask")
            if isinstance(mask, paddle.Tensor):
                mask = mask.numpy()
            elif isinstance(mask, list):
                mask = np.array(mask)
            elif not isinstance(mask, np.ndarray):
                raise ValueError(f"Unexpected type {type(mask)} of attention_mask, ")
        else:
            # No usable mask given: build a batched causal mask.
            seq_len = len(required_input)
            mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int64))
            mask = np.expand_dims(mask, axis=0)

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                head, tail = 0, difference
            elif self.padding_side == "left":
                head, tail = difference, 0
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
            if mask.ndim == 1:
                pad_width = [(head, tail)]
            else:
                # Batched 2-D mask: pad both sequence axes, leave batch alone.
                pad_width = [(0, 0), (head, tail), (head, tail)]
            mask = np.pad(mask, pad_width=pad_width, mode="constant", constant_values=0)

    encoded_inputs = super()._pad(
        encoded_inputs,
        max_length,
        padding_strategy=padding_strategy,
        pad_to_multiple_of=pad_to_multiple_of,
        return_attention_mask=False,
    )
    if return_attention_mask:
        encoded_inputs["attention_mask"] = mask.tolist()
    return encoded_inputs
+ """ + special_tokens = [special_tokens_info["image_placeholder"], special_tokens_info["audio_placeholder"]] + + if use_ocr_specialtoken: + special_tokens.extend(special_tokens_info["ocr_coor"]) + special_tokens.extend(special_tokens_info["ocr_begin_end"]) + + if use_crop_specialtoken: + special_tokens.extend(special_tokens_info["crop"]) + + # add special_tokens + additional_special_tokens = {"additional_special_tokens": special_tokens} + tokenizer.add_special_tokens(additional_special_tokens) + + # check + first_special_tokens = tokenizer.encode(special_tokens[0])["input_ids"] + + assert first_special_tokens[0] == special_token_ids_start, f"[ERROR] first_special_tokens={first_special_tokens}" + assert ( + len(tokenizer.get_vocab()) < special_token_ids_end + ), f"[ERROR] vocab_size = {len(tokenizer.get_vocab())} >= {special_token_ids_end} 增加过多special token了!" + + +class Ernie45Tokenizer(PretrainedTokenizer): + """ + 一个更好用的 `ErnieBotToknizer`, + 能 encode 目前 sft/ppo 阶段的特殊token,也支持多模态。 + """ + + resource_files_names = {"vocab_file": "tokenizer.model"} + pretrained_resource_files_map = {"vocab_file": {"ernie-bot-10b": None}} + pretrained_init_configuration = {"ernie-bot-10b": {}} + model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"] + padding_side = "right" + + def __init__( + self, + vocab_file, + bos_token="", + cls_token="", + eos_token="", + mask_token="", + pad_token="", + sep_token="", + unk_token="", + additional_special_tokens=None, + verbose=False, + **kwargs, + ): + """doc""" + if additional_special_tokens is None: + additional_special_tokens = ["", ""] + super().__init__( + bos_token=bos_token, + cls_token=cls_token, + eos_token=eos_token, + mask_token=mask_token, + pad_token=pad_token, + sep_token=sep_token, + unk_token=unk_token, + additional_special_tokens=additional_special_tokens, + verbose=False, + **kwargs, + ) + self.vocab_file = vocab_file + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(vocab_file) + # 
# NOTE(review): methods of Ernie45Tokenizer; the `class` statement sits on the
# preceding (shared) hunk line, so methods are listed at method granularity.

def __init__(
    self,
    vocab_file,
    bos_token="",
    cls_token="",
    eos_token="",
    mask_token="",
    pad_token="",
    sep_token="",
    unk_token="",
    additional_special_tokens=None,
    verbose=False,
    **kwargs,
):
    """Load the sentencepiece model backing the tokenizer.

    NOTE(review): every special-token default is an empty string in this copy —
    the literals look stripped (likely angle-bracket tokens); verify upstream.
    """
    if additional_special_tokens is None:
        additional_special_tokens = ["", ""]
    super().__init__(
        bos_token=bos_token,
        cls_token=cls_token,
        eos_token=eos_token,
        mask_token=mask_token,
        pad_token=pad_token,
        sep_token=sep_token,
        unk_token=unk_token,
        additional_special_tokens=additional_special_tokens,
        verbose=False,
        **kwargs,
    )
    self.vocab_file = vocab_file
    self.sp_model = spm.SentencePieceProcessor()
    self.sp_model.Load(vocab_file)
    # Special-token set for decode acceleration is built lazily in spec_init().

@property
def space_token(self):
    """Space token literal (appears stripped in this copy — verify upstream)."""
    return ""

@property
def space_token_id(self):
    """Vocabulary id of the space token piece."""
    return self.sp_model.piece_to_id("")

@property
def gend_token(self):
    """Gend token literal (appears stripped in this copy — verify upstream)."""
    return ""

@property
def gend_token_id(self):
    """Vocabulary id of the gend token piece."""
    return self.sp_model.piece_to_id("")

@property
def im_start_id(self):
    """Vocabulary id of the chat-turn start marker."""
    return self.sp_model.piece_to_id("<|im_start|>")

@property
def im_end_id(self):
    """Vocabulary id of the chat-turn end marker."""
    return self.sp_model.piece_to_id("<|im_end|>")

@property
def vocab_size(self):
    """Size of the sentencepiece vocabulary."""
    return self.sp_model.vocab_size()

def get_vocab(self):
    """Return {token: id} covering the base vocab plus added tokens."""
    vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
    vocab.update(self.added_tokens_encoder)
    return vocab

def _tokenize(self, text):
    """Split `text` into sentencepiece pieces."""
    return self.sp_model.encode_as_pieces(text)

def _convert_token_to_id(self, token):
    """Map a piece to its vocabulary id."""
    return self.sp_model.piece_to_id(token)

def _convert_id_to_token(self, index):
    """Map a vocabulary id back to its piece (param renamed from `id`,
    which shadowed the builtin; positional use is unaffected)."""
    return self.sp_model.id_to_piece(index)

def spec_init(self):
    """Cache `all_special_tokens` as a set once, for fast decode-time lookups."""
    if not hasattr(self, "all_spec_tok"):
        self.all_spec_tok = set(self.all_special_tokens)

def convert_tokens_to_string(self, tokens):
    """Converts a sequence of tokens (string) in a single string."""
    self.spec_init()
    current_sub_tokens = []
    out_string = ""
    for token in tokens:
        # Special tokens are appended verbatim, never decoded by sentencepiece.
        if token in self.all_spec_tok:
            out_string += self.sp_model.decode(current_sub_tokens) + token
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
    out_string += self.sp_model.decode(current_sub_tokens)
    return out_string

def prepare_for_model(self, *args, **kwargs):
    """Drop the unsupported `add_special_tokens` flag, then defer to the base class."""
    kwargs.pop("add_special_tokens", None)
    return super().prepare_for_model(*args, **kwargs)

def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Save the vocabulary and special tokens file to a directory.

    Args:
        save_directory (`str`): The directory in which to save the vocabulary.
        filename_prefix (`str`, *optional*): Prefix prepended to the file name.

    Returns:
        `Tuple(str)`: Paths to the files saved, or None when `save_directory`
        is not a directory (an error is logged in that case).
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return
    out_vocab_file = os.path.join(
        save_directory,
        (filename_prefix + "-" if filename_prefix else "") + self.resource_files_names["vocab_file"],
    )
    if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
        # Tokenizer was loaded from a file on disk: copy it as-is.
        copyfile(self.vocab_file, out_vocab_file)
    elif not os.path.isfile(self.vocab_file):
        # No backing file: serialize the in-memory sentencepiece model instead.
        with open(out_vocab_file, "wb") as fi:
            fi.write(self.sp_model.serialized_model_proto())
    return (out_vocab_file,)

def tokenize(self, text, **kwargs) -> List[str]:
    """
    Convert a string into a sequence of tokens, taking care of added/special
    tokens (sub-words for the rest).

    Args:
        text (`str`): The sequence to be encoded.
        **kwargs: Passed along to `prepare_for_tokenization`.

    Returns:
        `List[str]`: The list of tokens.
    """
    self.spec_init()
    text, kwargs = self.prepare_for_tokenization(text, **kwargs)

    # TODO: should this be in the base class?
    if hasattr(self, "do_lower_case") and self.do_lower_case:
        # Lowercase everything except the special tokens themselves.
        # BUGFIX: `unique_no_split_tokens` is a list while `all_spec_tok` is a
        # set; `list + set` raises TypeError, so both are materialized as lists.
        escaped_special_toks = [
            re.escape(s_tok) for s_tok in (list(self.unique_no_split_tokens) + list(self.all_spec_tok))
        ]
        pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
        text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

    no_split_token = set(self.unique_no_split_tokens)
    tokens = self.tokens_trie.split(text)

    tokenized_text = []
    for token in tokens:
        # Skip eventual empty (fully stripped) fragments.
        if not token:
            continue
        if token in no_split_token:
            tokenized_text.append(token)
        else:
            tokenized_text.extend(self._tokenize(token))
    return tokenized_text

def _decode(self, *args, **kwargs):
    """Decode while forcing raw (non-cleaned) output regardless of caller flags."""
    kwargs.pop("clean_up_tokenization_spaces", None)
    kwargs.pop("spaces_between_special_tokens", None)
    return super()._decode(
        *args, **kwargs, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False
    )
def _pad(
    self,
    encoded_inputs: Dict,
    max_length: Optional[int] = None,
    padding_strategy=PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
) -> dict:
    """Pad `encoded_inputs`, maintaining an attention mask alongside.

    When no mask is supplied a causal (lower-triangular) one is built; the mask
    is padded here and re-attached after the base-class padding pass.
    """
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names

    if return_attention_mask:
        required_input = encoded_inputs[self.model_input_names[0]]
        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            # Round max_length up to the next multiple.
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        if "attention_mask" in encoded_inputs and encoded_inputs["attention_mask"] is not None:
            # Normalize the caller-provided mask to a numpy array.
            mask = encoded_inputs.pop("attention_mask")
            if isinstance(mask, paddle.Tensor):
                mask = mask.numpy()
            elif isinstance(mask, list):
                mask = np.array(mask)
            elif not isinstance(mask, np.ndarray):
                raise ValueError(f"Unexpected type {type(mask)} of attention_mask, ")
        else:
            # No usable mask given: build a batched causal mask.
            seq_len = len(required_input)
            mask = np.tril(np.ones((seq_len, seq_len), dtype=np.int64))
            mask = np.expand_dims(mask, axis=0)

        if needs_to_be_padded:
            difference = max_length - len(required_input)
            if self.padding_side == "right":
                head, tail = 0, difference
            elif self.padding_side == "left":
                head, tail = difference, 0
            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
            if mask.ndim == 1:
                pad_width = [(head, tail)]
            else:
                # Batched 2-D mask: pad both sequence axes, leave batch alone.
                pad_width = [(0, 0), (head, tail), (head, tail)]
            mask = np.pad(mask, pad_width=pad_width, mode="constant", constant_values=0)

    encoded_inputs = super()._pad(
        encoded_inputs,
        max_length,
        padding_strategy=padding_strategy,
        pad_to_multiple_of=pad_to_multiple_of,
        return_attention_mask=False,
    )
    if return_attention_mask:
        encoded_inputs["attention_mask"] = mask.tolist()
    return encoded_inputs


# Workaround flag: a lone ASCII byte missing from the vocab would otherwise
# break the UTF-16-BE round trip (see OOVProcess.change_single_ascii_for_utf16be).
hack_uft16_ascii = True
VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}


class OOVProcess:
    """Re-encode out-of-vocabulary byte pieces between UTF-8 and UTF-16-BE."""

    def __init__(self, vocab):
        """
        Args:
            vocab (dict): token -> id mapping for the sentencepiece vocabulary,
                e.g. {'hello': 0, 'world': 1, ...}.
        """
        self.vocab = vocab
        # Lookup tables for the 256 single-byte tokens "<0x00>".."<0xFF>".
        self.b16_token_id_dict, self.b16_id_token_dict, self.bf16_tokens = self.get_b16_dict(self.vocab)
        self.bf16_tokens = set(self.bf16_tokens)
        self.PREFIX = "<0x"
        self.SUFFIX = ">"
"""输入s是字符串,tgt_type是要编码的类型,最终得到十六进制表示字节列表。如输入为s=“魍”, tgt_type=‘utf-16-be’,输出[‘<0x9B>’, ‘<0x4D>’]""" + + # 将字符串编码为指定类型的字节串 + encoded_bytes = s.encode(tgt_type) + # 转换为十六进制表示的字节列表 + hex_list = [f"<0x{byte:02X}>" for byte in encoded_bytes] + return hex_list + + def decode_str(self, byte_16_list, tgt_type="utf-16-be"): + """ + 功能正好相反,输入s是十六进制表示字节列表,tgt_type是编码的类型,输出字符串。 + 如输出byte_16_list=[‘<0x9B>’, ‘<0x4D>’], tgt_type=‘utf-16-be’,输出“ 魍” + """ + + # 去除尖括号和'0x'前缀,并将其转换为字节数组 + byte_array = bytearray(int(byte[3:-1], 16) for byte in byte_16_list) + # 将字节数组解码为字符串 + decoded_str = byte_array.decode(tgt_type) + return decoded_str + + def tgt_type_convert(self, byte_16_list, src_type="utf-8", tgt_type="utf-16-be"): + """ + 输入是byte_16_list是十六进制的列表,src_type是byte_16_list的类型,tgt_type是要转换的类型。输出是类型为tgt_type的十六进制列表。 + 例如输出byte_16_list=[‘<0xE9>’, ‘<0xAD>’, ‘<0x8D>’], src_type=“utf-8”,tgt_type=“utf-16-be”,输出是[‘<0x9B>’, ‘<0x4D>’]。 + """ + # =======编码类型转换=== src_type->tgt_type + + if tgt_type == "utf-8" and src_type == "utf-16-be": + # hack: 针对OOV词被截断的bf16 ascii 字符,编码转化会失败,直接返回原字节列表 + try: + # 使用 decode_str 将字节列表解码成字符串 + decoded_str = self.decode_str(byte_16_list, src_type) + + # 使用 encode_str 将字符串编码为目标类型的字节列表 + encoded_list = self.encode_str(decoded_str, tgt_type) + except UnicodeDecodeError: + logger.warning( + f"UnicodeDecodeError: 被截断的OOV词无法转码,decode明文:{byte_16_list},src_type:{src_type}" + ) + return byte_16_list + else: + # 使用 decode_str 将字节列表解码成字符串 + decoded_str = self.decode_str(byte_16_list, src_type) + + # 使用 encode_str 将字符串编码为目标类型的字节列表 + encoded_list = self.encode_str(decoded_str, tgt_type) + + # # 使用 decode_str 将字节列表解码成字符串 + # decoded_str = self.decode_str(byte_16_list, src_type) + + # # 使用 encode_str 将字符串编码为目标类型的字节列表 + # encoded_list = self.encode_str(decoded_str, tgt_type) + + return encoded_list + + def is_hex_string(self, token): + """16进制判断""" + return token in self.bf16_tokens + + def change_single_ascii_for_utf16be(self, tokens_list_part, byte16_flag): + """TODO: hack 代码 
LGH lgh 某个single ascii 字符没有添加到词表中,因此这个字符不转化,截断导致""" + tokens_list_part_v1 = [] + byte16_flag_v1 = [] + + for t, b in zip(tokens_list_part, byte16_flag): + if b == 0: + tokens_list_part_v1.append(t) + byte16_flag_v1.append(b) + else: + if len(t) % 2 == 0: + tokens_list_part_v1.append(t) + byte16_flag_v1.append(b) + else: + new_t = [t.pop()] + if t != []: + tokens_list_part_v1.append(t) + byte16_flag_v1.append(b) + tokens_list_part_v1.append(new_t) + byte16_flag_v1.append(0) + + tokens_list_part = tokens_list_part_v1 + byte16_flag = byte16_flag_v1 + + # assert False + return tokens_list_part, byte16_flag + + def oov_token_check(self, tokens_list): + """ + 检测 tokens_list 中的 16 进制片段,并进行分组。 + + Args: + tokens_list (list): 输入的 token 列表。 + + Returns: + tuple: 包含两个元素的元组: + - tokens_list_part (list): 分组后的列表。 + - byte16_flag (list): 每组是否为全 16 进制元素的标志位列表(1 表示全为 16 进制,0 表示否)。 + """ + tokens_list_part = [] + byte16_flag = [] + + current_group = [] + is_byte16 = None + + for token in tokens_list: + # 使用正则表达式函数判断是否为十六进制格式的字符串 + if self.is_hex_string(token): + # 当前元素是十六进制字符串 + if is_byte16 is None: + # 初始化当前组类型 + is_byte16 = True + if not is_byte16: + # 切换组,保存之前的组 + tokens_list_part.append(current_group) + byte16_flag.append(0) + current_group = [] + is_byte16 = True + else: + # 当前元素是普通字符串 + if is_byte16 is None: + # 初始化当前组类型 + is_byte16 = False + if is_byte16: + # 切换组,保存之前的组 + tokens_list_part.append(current_group) + byte16_flag.append(1) + current_group = [] + is_byte16 = False + + # 添加当前元素到当前组 + current_group.append(token) + + # 添加最后一个组 + if current_group: + tokens_list_part.append(current_group) + byte16_flag.append(1 if is_byte16 else 0) + + return tokens_list_part, byte16_flag + + def get_b16_dict(self, vocab): + """从vocab中得到16进制的id""" + hex_chars = "0123456789ABCDEF" + bf16_tokens = [f"<0x{''.join(p)}>" for p in product(hex_chars, repeat=2)] + assert len(bf16_tokens) == 256, bf16_tokens + b16_token_id_dict = {} + b16_id_token_dict = {} + + for bf16_t in bf16_tokens: + idx = 
vocab[bf16_t] + b16_token_id_dict[bf16_t] = idx + b16_id_token_dict[idx] = bf16_t + assert len(b16_token_id_dict) == len(b16_id_token_dict) == 256, f"{b16_token_id_dict}\n{b16_id_token_dict}" + return b16_token_id_dict, b16_id_token_dict, bf16_tokens + + def encode_or_tokenize_convert_oov(self, tokens=None, token_ids=None, src_type="utf-8", tgt_type="utf-16-be"): + """目的:tokenizer生成token或者token id过程中,对于oov词会拆成utf-8 的16进制表示,需要将这部分表示转化为utf-16-be 的16进制表示""" + + # token转化 + new_tokens = [] + if tokens is not None: + # 筛选出ovv + tokens_list_part, byte16_flag = self.oov_token_check(tokens) + + # TODO: hack 代码 LGH lgh 某个single ascii 字符没有添加到词表中,因此这个字符不转化。临时解决方案!!后续将这些字符加到词表中!!! ==== + # if tgt_type=="utf-16-be": + if tgt_type == "utf-8" and hack_uft16_ascii: + tokens_list_part, byte16_flag = self.change_single_ascii_for_utf16be(tokens_list_part, byte16_flag) + # ==== + + # 将oov utf-8 的16进制表示转化为utf-16-be 16进制表示 + # ==================原始版本=============== + for token_part, byte16 in zip(tokens_list_part, byte16_flag): + assert byte16 in [0, 1], byte16 + if byte16 == 1: + # utf-8 16进制 => utf-16-be 16进制 + token_part = self.tgt_type_convert(token_part, src_type=src_type, tgt_type=tgt_type) + new_tokens.extend(token_part) + # ========================== + # ## ==========加速版本2 =========== + # new_tokens = [ + # converted_token + # for token_part, byte16 in zip(tokens_list_part, byte16_flag) + # for converted_token in ( + # self.tgt_type_convert(token_part, src_type=src_type, tgt_type=tgt_type) + # if byte16 == 1 + # else token_part + # ) + # ] + # # ========================== + + # ## token id转化 + new_token_ids = [] + if token_ids is not None: + new_token_ids_b16 = [] + assert self.b16_id_token_dict != {}, token_ids + # 将16进制的id转化为token + + token_ids_16 = [self.b16_id_token_dict.get(id_one, id_one) for id_one in token_ids] + # 筛选出ovv + tokens_list_part, byte16_flag = self.oov_token_check(token_ids_16) + + # TODO: hack 代码 LGH lgh 某个single ascii 
字符没有添加到词表中,因此这个字符不转化。临时解决方案!!后续将这些字符加到词表中!!! ==== + # if tgt_type=="utf-16-be": + if tgt_type == "utf-8" and hack_uft16_ascii: + tokens_list_part, byte16_flag = self.change_single_ascii_for_utf16be(tokens_list_part, byte16_flag) + # ==== + + # 将oov utf-8 的16进制表示转化为utf-16-be 16进制表示 + # ==================原始版本=============== + for token_part, byte16 in zip(tokens_list_part, byte16_flag): + assert byte16 in [0, 1], byte16 + if byte16 == 1: + # utf-8 16进制 => utf-16-be 16进制 + token_part = self.tgt_type_convert(token_part, src_type=src_type, tgt_type=tgt_type) + new_token_ids_b16.extend(token_part) + # ======================================== + # ### ==========加速版本2 =========== + # new_token_ids_b16 = [ + # converted_token + # for token_part, byte16 in zip(tokens_list_part, byte16_flag) + # for converted_token in ( + # self.tgt_type_convert(token_part, src_type=src_type, tgt_type=tgt_type) + # if byte16 == 1 + # else token_part + # ) + # ] + # # =========================== + new_token_ids = [self.b16_token_id_dict.get(id_one, id_one) for id_one in new_token_ids_b16] + + return new_tokens, new_token_ids + + def decode_convert_oov(self, tokens=None, token_ids=None, src_type="utf-16-be", tgt_type="utf-8"): + """ + 目的:sentencepiece中的sp.decode(ids)以及sp.decode_pieces(pieces)中的id或pieces中的16进制token都必须是"utf-8"格式, + 但是现在tokenizer.tokenize出来的16进制是“utf-16-be”,因此要转化为"utf-8"格式后,送入sentencepiece解码。 + """ + new_tokens, new_token_ids = self.encode_or_tokenize_convert_oov( + tokens=tokens, token_ids=token_ids, src_type=src_type, tgt_type=tgt_type + ) + return new_tokens, new_token_ids + + def encode_str_and_encode_str(self, s, tgt_type): + """doc""" + pass + # encoded_list = self.encode_str(s, tgt_type) + # decoded_str = self.decode_str(encoded_list, tgt_type) + + @staticmethod + def get_vocab(model_file): + """doc""" + sp = spm.SentencePieceProcessor(model_file=model_file) + vocab_size = sp.vocab_size() + assert sp.vocab_size() == sp.get_piece_size() + vocab = {sp.id_to_piece(i): i for i 
# NOTE(review): methods of Ernie5Tokenizer (byte-level BPE over SentencePiece);
# the `class` statement shares a hunk line, so methods are listed individually.

def __init__(
    self,
    vocab_file,
    unk_token="",
    bos_token="",
    eos_token="",
    pad_token="",
    sp_model_kwargs: Optional[Dict[str, Any]] = None,
    add_bos_token=True,
    add_eos_token=False,
    clean_up_tokenization_spaces=False,
    **kwargs,
):
    """Construct a SentencePiece-backed tokenizer.

    Args:
        vocab_file (str): path to the `.model` SentencePiece file.
        unk_token / bos_token / eos_token / pad_token: special-token literals
            (NOTE(review): they appear stripped to "" in this copy — verify).
        sp_model_kwargs (dict, optional): forwarded to SentencePieceProcessor.
        add_bos_token (bool): prepend BOS to every encoded piece.
        add_eos_token (bool): append EOS to every encoded piece.
        clean_up_tokenization_spaces (bool): whether to clean up spaces.
        **kwargs: forwarded to the parent `PretrainedTokenizer`.
    """
    self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

    def _as_added_token(tok):
        # Plain strings are wrapped so surrounding whitespace is kept verbatim.
        return AddedToken(tok, lstrip=False, rstrip=False) if isinstance(tok, str) else tok

    bos_token = _as_added_token(bos_token)
    eos_token = _as_added_token(eos_token)
    unk_token = _as_added_token(unk_token)
    pad_token = _as_added_token(pad_token)
    super().__init__(
        bos_token=bos_token,
        eos_token=eos_token,
        unk_token=unk_token,
        pad_token=pad_token,
        add_bos_token=add_bos_token,
        add_eos_token=add_eos_token,
        sp_model_kwargs=self.sp_model_kwargs,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        **kwargs,
    )
    self.vocab_file = vocab_file
    self.add_bos_token = add_bos_token
    self.add_eos_token = add_eos_token
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.Load(vocab_file)
    # Short id aliases kept for the eb35 reader.
    self.bos_id = self.bos_token_id
    self.eos_id = self.eos_token_id
    self.sep_id = self.sep_token_id
    self.pad_id = self.pad_token_id
    self.unk_id = self.unk_token_id

    vocab = self.get_vocab()
    self.oov_process = OOVProcess(vocab)  # OOV pieces handled as byte tokens
    self.use_oov_uft_16_be = True  # whether OOV bytes are re-encoded as UTF-16-BE
    logger.info(f">>> UTF_16_BE: self.use_oov_uft_16_be:{self.use_oov_uft_16_be}")

def set_oov_utf_16_be(self, use_oov_uft_16_be=True):
    """Toggle the UTF-16-BE re-encoding of OOV byte pieces."""
    self.use_oov_uft_16_be = use_oov_uft_16_be
    print(f"use_oov_uft_16_be:{self.use_oov_uft_16_be}")

def __getstate__(self):
    """Return picklable state: everything but the live SentencePiece handle."""
    state = self.__dict__.copy()
    state["sp_model"] = None
    return state

def __setstate__(self, d):
    """Restore state and rebuild the SentencePiece processor from disk."""
    self.__dict__ = d
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.Load(self.vocab_file)

@property
def vocab_size(self):
    """Size of the sentencepiece vocabulary."""
    return self.sp_model.get_piece_size()

def get_vocab(self):
    """Return {token: id} covering the base vocab plus added tokens."""
    vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
    vocab.update(self.added_tokens_encoder)
    return vocab

def tokenize(self, text):
    """Return the tokenized form of `text`."""
    return self._tokenize(text)

def encode_oov_uft_16_be(self, tokens):
    """After spm encode / tokenize, re-encode OOV byte pieces (tokens or ids)
    to their UTF-16-BE representation."""
    if not isinstance(tokens, list):
        tokens = [tokens]
    if isinstance(tokens[0], str):
        converted, _ = self.oov_process.encode_or_tokenize_convert_oov(tokens=tokens)
    else:
        assert isinstance(tokens[0], int)
        _, converted = self.oov_process.encode_or_tokenize_convert_oov(token_ids=tokens)
    return converted

def is_empty(self, value):
    """True for None, the empty string, or an empty list."""
    if value is None:
        return True
    if isinstance(value, (str, list)):
        return len(value) == 0
    return False

def _tokenize(self, text):
    """Split `text` into pieces; OOV UTF-8 bytes become UTF-16-BE tokens when enabled."""
    tokens = self.sp_model.encode(text, out_type=str)
    if not self.is_empty(tokens) and self.use_oov_uft_16_be:
        tokens = self.encode_oov_uft_16_be(tokens=tokens)
    return tokens
# NOTE(review): remaining Ernie5Tokenizer methods, listed at method
# granularity (the hunk cuts the class body). The trailing YAML diff fragment
# belongs to a different file in the patch and is untouched.

def decode_oov_uft_16_be(self, tokens):
    """Before spm decode: map UTF-16-BE hex byte tokens/ids back to UTF-8."""
    if not isinstance(tokens, list):
        tokens = [tokens]
    if isinstance(tokens[0], str):
        converted, _ = self.oov_process.decode_convert_oov(tokens=tokens)
    else:
        assert isinstance(tokens[0], int)
        _, converted = self.oov_process.decode_convert_oov(token_ids=tokens)
    return converted

def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False):
    """Decode tokens or ids to text; OOV hex tokens are mapped back to UTF-8 first."""
    if not self.is_empty(tokens) and self.use_oov_uft_16_be:
        tokens = self.decode_oov_uft_16_be(tokens)
    return self.sp_model.decode(tokens)

def _convert_token_to_id(self, token):
    """Map a piece to its vocabulary id."""
    return self.sp_model.piece_to_id(token)

def _convert_id_to_token(self, index):
    """Map a vocabulary id back to its piece."""
    return self.sp_model.IdToPiece(index)

def convert_tokens_to_string(self, tokens):
    """Join token pieces into text; special tokens are emitted verbatim and
    separated from preceding plain text by a single space."""

    def flush(pieces):
        # Decode accumulated plain pieces, converting OOV bytes when enabled.
        if not self.is_empty(pieces) and self.use_oov_uft_16_be:
            pieces = self.decode_oov_uft_16_be(pieces)
        return self.sp_model.decode(pieces)

    out_string = ""
    pending = []
    prev_is_special = False
    for i, token in enumerate(tokens):
        if token in self.all_special_tokens:
            if not prev_is_special and i != 0:
                out_string += " "
            out_string += flush(pending) + token
            pending = []
            prev_is_special = True
        else:
            pending.append(token)
            prev_is_special = False
    out_string += flush(pending)
    return out_string

def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Save the vocabulary and special tokens file to a directory.

    Args:
        save_directory (`str`): The directory in which to save the vocabulary.
        filename_prefix (`str`, *optional*): Prefix prepended to the file name.

    Returns:
        `Tuple(str)`: Paths to the files saved, or None when `save_directory`
        is not a directory (an error is logged in that case).
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return
    out_vocab_file = os.path.join(
        save_directory,
        (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
    )
    if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
        # Tokenizer was loaded from a file on disk: copy it as-is.
        copyfile(self.vocab_file, out_vocab_file)
    elif not os.path.isfile(self.vocab_file):
        # No backing file: serialize the in-memory sentencepiece model instead.
        with open(out_vocab_file, "wb") as fi:
            fi.write(self.sp_model.serialized_model_proto())
    return (out_vocab_file,)

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
    """Wrap one or two id sequences with the configured BOS/EOS ids."""
    bos = [self.bos_token_id] if self.add_bos_token else []
    eos = [self.eos_token_id] if self.add_eos_token else []
    output = bos + token_ids_0 + eos
    if token_ids_1 is not None:
        output += bos + token_ids_1 + eos
    return output

def get_special_tokens_mask(
    self,
    token_ids_0: List[int],
    token_ids_1: Optional[List[int]] = None,
    already_has_special_tokens: bool = False,
) -> List[int]:
    """
    Retrieve a special-token mask for sequences without special tokens added.

    Args:
        token_ids_0 (`List[int]`): first sequence of ids.
        token_ids_1 (`List[int]`, *optional*): second sequence for pairs.
        already_has_special_tokens (`bool`): whether the input is already
            formatted with special tokens (delegates to the base class).

    Returns:
        `List[int]`: 1 marks a special token, 0 a sequence token.
    """
    if already_has_special_tokens:
        return super().get_special_tokens_mask(
            token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
        )
    bos = [1] if self.add_bos_token else []
    eos = [1] if self.add_eos_token else []
    mask = bos + [0] * len(token_ids_0) + eos
    if token_ids_1 is not None:
        mask += bos + [0] * len(token_ids_1) + eos
    return mask

def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Build the sequence-pair token-type mask: 0s for the first sequence
    (including its special tokens), 1s for the second. With a single sequence,
    only the 0s portion is returned.

    Args:
        token_ids_0 (`List[int]`): first sequence of ids.
        token_ids_1 (`List[int]`, *optional*): second sequence for pairs.

    Returns:
        `List[int]`: list of token type ids for the given sequence(s).
    """
    bos = [self.bos_token_id] if self.add_bos_token else []
    eos = [self.eos_token_id] if self.add_eos_token else []
    type_ids = [0] * len(bos + token_ids_0 + eos)
    if token_ids_1 is not None:
        type_ids += [1] * len(bos + token_ids_1 + eos)
    return type_ids
20: "union" # seal + 21: "union" # table + 22: "union" # text + 23: "union" # text + 24: "union" # vision_footnote + VLRecognition: + module_name: vl_recognition + model_name: PaddleOCR-VL-0.9B + model_dir: null + batch_size: 4096 + genai_config: + backend: fastdeploy-server + server_url: http://127.0.0.1:8118/v1 + +SubPipelines: + DocPreprocessor: + pipeline_name: doc_preprocessor + batch_size: 8 + use_doc_orientation_classify: True + use_doc_unwarping: True + SubModules: + DocOrientationClassify: + module_name: doc_text_orientation + model_name: PP-LCNet_x1_0_doc_ori + model_dir: null + batch_size: 8 + DocUnwarping: + module_name: image_unwarping + model_name: UVDoc + model_dir: null diff --git a/benchmarks/paddleocr_vl/README.md b/benchmarks/paddleocr_vl/README.md new file mode 100644 index 00000000000..3dbe96c3898 --- /dev/null +++ b/benchmarks/paddleocr_vl/README.md @@ -0,0 +1,139 @@ +## FastDeploy 服务化性能压测工具(PaddleOCR-VL) + +本文档主要介绍如何对 [PaddleOCR-VL](https://www.paddleocr.ai/latest/version3.x/pipeline_usage/PaddleOCR-VL.html) 进行性能测试。 + +### 数据集: + +下载数据集到本地用于性能测试: + + + + + + + + + + + + + + +
数据集获取地址
OmniDocBench v1 数据集,共 981 个 pdf 文件https://github.com/opendatalab/OmniDocBench
+ +### 使用方式 + +1. 启动 FastDeploy 服务,下面为 A100-80G 测试时使用的参数,可以根据实际情况进行调整: + + ```shell + python -m fastdeploy.entrypoints.openai.api_server \ + --model PaddlePaddle/PaddleOCR-VL \ + --port 8118 \ + --metrics-port 8471 \ + --engine-worker-queue-port 8472 \ + --cache-queue-port 55660 \ + --max-model-len 16384 \ + --max-num-batched-tokens 16384 \ + --gpu-memory-utilization 0.7 \ + --max-num-seqs 256 \ + --workers 2 \ + --graph-optimization-config '{"graph_opt_level":0, "use_cudagraph":true}' + ``` + +2. 在同一环境安装依赖后启动测试脚本: + + ```shell + # 安装依赖 + pip install -U paddlex + # 启动测试脚本 + python benchmark.py ./test_data -b 512 -o ./benchmark.json --paddlex_config_path ./PaddleOCR-VL.yaml --gpu_ids 0 + ``` + + 测试脚本参数说明: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
参数说明
input_dirs输入的目录路径,会自动识别到目录下的 pdf 或图片。可以提供一个或多个。
-b, --batch_size推理时使用的批处理大小。
-o, --output_path输出结果文件的路径。
--paddlex_config_pathPaddleX 的 YAML 配置文件路径。
--gpu_ids指定要使用的 GPU 设备 ID,可提供一个或多个。
+ +3. 测试结束后,会输出类似于下面的结果: + + ```text + Throughput (file): 1.3961 files per second + Average latency (batch): 351.0812 seconds + Processed pages: 981 + Throughput (page): 1.3961 pages per second + Generated tokens: 1510337 + Throughput (token): 2149.5 tokens per second + GPU utilization (%): 100.0, 0.0, 68.1 + GPU memory usage (MB): 77664.8, 58802.8, 74402.7 + ``` + + 输出结果说明: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
参数说明
Throughput (file)每秒处理的文件数量
Average latency (batch)每批次处理的平均延迟时间,单位为秒
Processed pages已处理的页面总数
Throughput (page)每秒处理的页面数量
Generated tokens生成的token总数
Throughput (token)每秒生成的token数量
GPU utilization (%)GPU 的最大、最小、平均利用率
GPU memory usage (MB)GPU 的最大、最小、平均显存占用,单位为 MB
diff --git a/benchmarks/paddleocr_vl/benchmark.py b/benchmarks/paddleocr_vl/benchmark.py new file mode 100644 index 00000000000..c09d91c5360 --- /dev/null +++ b/benchmarks/paddleocr_vl/benchmark.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python + +import argparse +import glob +import json +import os +import sys +import time +import uuid +from operator import itemgetter +from threading import Thread + +import pynvml +import tiktoken +from tqdm import tqdm + +shutdown = False + +encoding = tiktoken.get_encoding("cl100k_base") + + +class Predictor(object): + def predict(self, task_info, batch_data): + task_info["start_time"] = get_curr_time() + try: + markdown, num_pages = self._predict(batch_data) + except Exception as e: + task_info["successful"] = False + print(e) + raise + finally: + task_info["end_time"] = get_curr_time() + task_info["successful"] = True + task_info["processed_pages"] = num_pages + task_info["generated_tokens"] = len(encoding.encode(markdown)) + return markdown + + def _predict(self, batch_data): + raise NotImplementedError + + def close(self): + pass + + +class PaddleXPredictor(Predictor): + def __init__(self, config_path): + from paddlex import create_pipeline + + super().__init__() + self.pipeline = create_pipeline(config_path) + + def _predict(self, batch_data): + results = list(self.pipeline.predict(batch_data)) + return "\n\n".join(res._to_markdown(pretty=False)["markdown_texts"] for res in results), len(results) + + def close(self): + self.pipeline.close() + + +def monitor_device(gpu_ids, gpu_metrics_list): + try: + pynvml.nvmlInit() + handles = [pynvml.nvmlDeviceGetHandleByIndex(gpu_id) for gpu_id in gpu_ids] + + time.sleep(5) + while not shutdown: + try: + gpu_util = 0 + mem_bytes = 0 + + for handle in handles: + gpu_util += pynvml.nvmlDeviceGetUtilizationRates(handle).gpu + mem_bytes += pynvml.nvmlDeviceGetMemoryInfo(handle).used + + gpu_metrics_list.append( + { + "utilization": gpu_util, + "memory": mem_bytes, + } + ) + except Exception as 
e: + print(f"Error monitoring GPUs: {e}") + + time.sleep(0.5) + + except Exception as e: + print(f"Error initializing the GPU monitor: {e}") + finally: + try: + pynvml.nvmlShutdown() + except: + pass + + +def get_curr_time(): + return time.perf_counter() + + +def new_task_info(): + task_info = {} + task_info["id"] = uuid.uuid4().hex + return task_info + + +def create_and_submit_new_task(executor, requestor, task_info_dict, input_path): + task_info = new_task_info() + task = executor.submit( + requestor.make_request, + task_info, + input_path, + ) + task_info_dict[task] = task_info + + return task + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input_dirs", type=str, nargs="+", metavar="INPUT_DIR") + parser.add_argument("-b", "--batch_size", type=int, default=1) + parser.add_argument("-o", "--output_path", type=str, default="benchmark.json") + parser.add_argument("--paddlex_config_path", type=str, default="PaddleOCR-VL.yaml") + parser.add_argument("--gpu_ids", type=int, nargs="+", default=[0]) + args = parser.parse_args() + + task_info_list = [] + + all_input_paths = [] + for input_dir in args.input_dirs: + all_input_paths += glob.glob(os.path.join(input_dir, "*")) + all_input_paths.sort() + if len(all_input_paths) == 0: + print("No valid data") + sys.exit(1) + + predictor = PaddleXPredictor(args.paddlex_config_path) + + if args.batch_size < 1: + print("Invalid batch size") + sys.exit(2) + + gpu_metrics_list = [] + thread_device_monitor = Thread( + target=monitor_device, + args=(args.gpu_ids, gpu_metrics_list), + ) + thread_device_monitor.start() + + try: + start_time = get_curr_time() + batch_data = [] + with open("generated_markdown.md", "w", encoding="utf-8") as f: + for i, input_path in tqdm(enumerate(all_input_paths), total=len(all_input_paths)): + batch_data.append(input_path) + if len(batch_data) == args.batch_size or i == len(all_input_paths) - 1: + task_info = new_task_info() + try: + markdown = 
predictor.predict(task_info, batch_data) + f.write(markdown) + f.write("\n\n") + except Exception as e: + print(e) + continue + task_info_list.append(task_info) + batch_data.clear() + end_time = get_curr_time() + finally: + shutdown = True + thread_device_monitor.join() + predictor.close() + + total_files = len(all_input_paths) + throughput_file = total_files / (end_time - start_time) + print(f"Throughput (file): {throughput_file:.4f} files per second") + duration_list_batch = [info["end_time"] - info["start_time"] for info in task_info_list] + avg_latency_batch = sum(duration_list_batch) / len(duration_list_batch) + print(f"Average latency (batch): {avg_latency_batch:.4f} seconds") + + successful_files = sum(map(lambda x: x["successful"], task_info_list)) + if successful_files: + processed_pages = sum(info.get("processed_pages", 0) for info in task_info_list) + throughput_page = processed_pages / (end_time - start_time) + print(f"Processed pages: {processed_pages}") + print(f"Throughput (page): {throughput_page:.4f} pages per second") + generated_tokens = sum(info.get("generated_tokens", 0) for info in task_info_list) + throughput_token = generated_tokens / (end_time - start_time) + print(f"Generated tokens: {generated_tokens}") + print(f"Throughput (token): {throughput_token:.1f} tokens per second") + else: + processed_pages = None + throughput_page = None + generated_tokens = None + throughput_token = None + + if gpu_metrics_list: + gpu_util_list = list(map(itemgetter("utilization"), gpu_metrics_list)) + print( + f"GPU utilization (%): {max(gpu_util_list):.1f}, {min(gpu_util_list):.1f}, {sum(gpu_util_list) / len(gpu_util_list):.1f}" + ) + gpu_mem_list = list(map(itemgetter("memory"), gpu_metrics_list)) + print( + f"GPU memory usage (MB): {max(gpu_mem_list) / 1024**2:.1f}, {min(gpu_mem_list) / 1024**2:.1f}, {sum(gpu_mem_list) / len(gpu_mem_list) / 1024**2:.1f}" + ) + + dic = { + "input_dirs": args.input_dirs, + "batch_size": args.batch_size, + "total_files": 
total_files, + "throughput_file": throughput_file, + "avg_latency_batch": avg_latency_batch, + "duration_list": duration_list_batch, + "successful_files": successful_files, + "processed_pages": processed_pages, + "throughput_page": throughput_page, + "generated_tokens": generated_tokens, + "throughput_token": throughput_token, + "gpu_metrics_list": gpu_metrics_list, + } + with open(args.output_path, "w", encoding="utf-8") as f: + json.dump( + dic, + f, + ensure_ascii=False, + indent=2, + ) + print(f"Config and results saved to {args.output_path}") diff --git a/benchmarks/requirements_tokenizer.txt b/benchmarks/requirements_tokenizer.txt new file mode 100644 index 00000000000..82dd8f3aaaa --- /dev/null +++ b/benchmarks/requirements_tokenizer.txt @@ -0,0 +1,9 @@ +aiohttp +tqdm +numpy +Pillow +pyyaml +requests +paddle +paddleformers +fastdeploy diff --git a/benchmarks/yaml/GLM45-air-32k-bf16-mtp-updatemodel.yaml b/benchmarks/yaml/GLM45-air-32k-bf16-mtp-updatemodel.yaml new file mode 100644 index 00000000000..69e9fd1823d --- /dev/null +++ b/benchmarks/yaml/GLM45-air-32k-bf16-mtp-updatemodel.yaml @@ -0,0 +1,10 @@ +max_model_len: 32768 +max_num_seqs: 128 +tensor_parallel_size: 4 +graph_optimization_config: + use_cudagraph: True + draft_model_use_cudagraph: True +load_choices: "default_v1" +dynamic_load_weight: True +load_strategy: ipc_snapshot +shutdown_comm_group_if_worker_idle: False diff --git a/benchmarks/yaml/GLM45-air-32k-bf16-mtp.yaml b/benchmarks/yaml/GLM45-air-32k-bf16-mtp.yaml new file mode 100644 index 00000000000..6273f6c7514 --- /dev/null +++ b/benchmarks/yaml/GLM45-air-32k-bf16-mtp.yaml @@ -0,0 +1,7 @@ +max_model_len: 32768 +max_num_seqs: 128 +tensor_parallel_size: 4 +graph_optimization_config: + use_cudagraph: True + draft_model_use_cudagraph: True +load_choices: "default_v1" diff --git a/benchmarks/yaml/GLM45-air-32k-bf16-rl.yaml b/benchmarks/yaml/GLM45-air-32k-bf16-rl.yaml new file mode 100644 index 00000000000..93813c9bd2a --- /dev/null +++ 
b/benchmarks/yaml/GLM45-air-32k-bf16-rl.yaml @@ -0,0 +1,10 @@ +tensor_parallel_size: 8 +max_num_seqs: 32 +gpu_memory_utilization: 0.8 +load_choices: default_v1 +enable_prefix_caching: True +graph_optimization_config: '{"use_cudagraph":true}' +max_model_len: 66560 +enable_logprob: True +enable_custom_all_reduce: False +worker: 2 diff --git a/benchmarks/yaml/GLM45-air-32k-bf16.yaml b/benchmarks/yaml/GLM45-air-32k-bf16.yaml new file mode 100644 index 00000000000..b14dce761ea --- /dev/null +++ b/benchmarks/yaml/GLM45-air-32k-bf16.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 128 +tensor_parallel_size: 4 +use_cudagraph: True +load_choices: "default_v1" diff --git a/benchmarks/yaml/GLM45-air-32k-wfp8afp8.yaml b/benchmarks/yaml/GLM45-air-32k-wfp8afp8.yaml new file mode 100644 index 00000000000..5e4afe79e8e --- /dev/null +++ b/benchmarks/yaml/GLM45-air-32k-wfp8afp8.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +tensor_parallel_size: 4 +use_cudagraph: True +load_choices: "default_v1" +quantization: wfp8afp8 diff --git a/benchmarks/yaml/deepseek-32k-tp8-wint4.yaml b/benchmarks/yaml/deepseek-32k-tp8-wint4.yaml new file mode 100644 index 00000000000..a09349f044a --- /dev/null +++ b/benchmarks/yaml/deepseek-32k-tp8-wint4.yaml @@ -0,0 +1,9 @@ +quantization: wint4 +load_choices: "default_v1" +graph_optimization_config: + use_cudagraph: True + use_unique_memory_pool: True +enable_prefix_caching: False +max_num_seqs: 256 +max_model_len: 32768 +tensor_parallel_size: 8 diff --git a/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml index 280f8e336c0..3667361e018 100644 --- a/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml +++ b/benchmarks/yaml/eb45-128k-wint4-a800-tp8.yaml @@ -6,3 +6,4 @@ tensor_parallel_size: 8 max_num_batched_tokens: 4096 max_num_partial_prefills: 3 max_long_partial_prefills: 3 +quantization: wint4 diff --git a/benchmarks/yaml/eb45-128k-wint4-tp1-plas.yaml 
b/benchmarks/yaml/eb45-128k-wint4-tp1-plas.yaml new file mode 100644 index 00000000000..6ec412b1871 --- /dev/null +++ b/benchmarks/yaml/eb45-128k-wint4-tp1-plas.yaml @@ -0,0 +1,6 @@ +tensor_parallel_size: 1 +max_model_len: 131072 +max_num_seqs: 32 +quantization: wint4 +max_num_batched_tokens: 8192 +plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}' diff --git a/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml index 280f8e336c0..bc458d1a53a 100644 --- a/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml +++ b/benchmarks/yaml/eb45-128k-wint8-a800-tp8.yaml @@ -6,3 +6,4 @@ tensor_parallel_size: 8 max_num_batched_tokens: 4096 max_num_partial_prefills: 3 max_long_partial_prefills: 3 +quantization: wint8 diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-bf16-tp2-mooncake.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-bf16-tp2-mooncake.yaml new file mode 100644 index 00000000000..0021a04f56c --- /dev/null +++ b/benchmarks/yaml/eb45-21b-a3b-32k-bf16-tp2-mooncake.yaml @@ -0,0 +1,5 @@ +max_model_len: 131072 +max_num_seqs: 256 +tensor_parallel_size: 2 +kvcache_storage_backend: "mooncake" +enable_output_caching: True diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-wint8-cpu-cache.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-wint8-cpu-cache.yaml new file mode 100644 index 00000000000..d26d89229db --- /dev/null +++ b/benchmarks/yaml/eb45-21b-a3b-32k-wint8-cpu-cache.yaml @@ -0,0 +1,7 @@ +max_model_len: 32768 +max_num_seqs: 128 +kv_cache_ratio: 0.75 +tensor_parallel_size: 1 +max_num_batched_tokens: 32768 +quantization: wint8 +swap_space: 100 diff --git a/benchmarks/yaml/eb45-32k-wint2-tp4.yaml b/benchmarks/yaml/eb45-32k-wint2-tp4.yaml new file mode 100644 index 00000000000..c82ea744d06 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint2-tp4.yaml @@ -0,0 +1,5 @@ +max_model_len: 32768 +max_num_seqs: 256 +kv_cache_ratio: 0.75 +tensor_parallel_size: 4 
+gpu_memory_utilization: 0.9 diff --git a/benchmarks/yaml/eb45-32k-wint4-a800-tp4-cinn.yaml b/benchmarks/yaml/eb45-32k-wint4-a800-tp4-cinn.yaml new file mode 100644 index 00000000000..cd6867a5834 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-a800-tp4-cinn.yaml @@ -0,0 +1,9 @@ +max_model_len: 32768 +max_num_seqs: 96 +gpu_memory_utilization: 0.85 +kv_cache_ratio: 0.71 +tensor_parallel_size: 4 +quantization: wint4 +graph_optimization_config: + use_cudagraph: True + graph_opt_level: 2 diff --git a/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml index 6ac9a218875..974c2eaf776 100644 --- a/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml +++ b/benchmarks/yaml/eb45-32k-wint4-a800-tp4.yaml @@ -1,5 +1,6 @@ max_model_len: 32768 max_num_seqs: 96 -gpu_memory_utilization: 0.9 +gpu_memory_utilization: 0.85 kv_cache_ratio: 0.71 tensor_parallel_size: 4 +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-ep4-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-ep4-tp4.yaml new file mode 100644 index 00000000000..d05375caa19 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-ep4-tp4.yaml @@ -0,0 +1,7 @@ +num_gpu_blocks_override: 1024 +max_model_len: 8192 +max_num_seqs: 64 +data_parallel_size: 4 +tensor_parallel_size: 1 +enable_expert_parallel: True +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml b/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml index c609fba495b..c71c247ee9f 100644 --- a/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml +++ b/benchmarks/yaml/eb45-32k-wint4-mtp-h800-tp4.yaml @@ -1,6 +1,6 @@ max_model_len: 32768 max_num_seqs: 96 -gpu_memory_utilization: 0.9 +gpu_memory_utilization: 0.8 kv_cache_ratio: 0.71 tensor_parallel_size: 4 quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-tp1-dp4_ep.yaml b/benchmarks/yaml/eb45-32k-wint4-tp1-dp4_ep.yaml new file mode 100644 index 00000000000..d05375caa19 --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-tp1-dp4_ep.yaml @@ -0,0 +1,7 
@@ +num_gpu_blocks_override: 1024 +max_model_len: 8192 +max_num_seqs: 64 +data_parallel_size: 4 +tensor_parallel_size: 1 +enable_expert_parallel: True +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml index 985ef7a34d2..34de7cd762f 100644 --- a/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml +++ b/benchmarks/yaml/eb45-32k-wint4-tp4_decode.yaml @@ -13,3 +13,4 @@ pd_comm_port: "2334" max_num_batched_tokens: 384 max_num_partial_prefills: 3 max_long_partial_prefills: 3 +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_decode_router.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_decode_router.yaml new file mode 100644 index 00000000000..34de7cd762f --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-tp4_decode_router.yaml @@ -0,0 +1,16 @@ +max_model_len: 32768 +max_num_seqs: 256 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +tensor_parallel_size: 4 +cache_queue_port: 55663 +enable_chunked_prefill: True +splitwise_role: decode +engine_worker_queue_port: 6678 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7671,7672,7673,7674" +pd_comm_port: "2334" +max_num_batched_tokens: 384 +max_num_partial_prefills: 3 +max_long_partial_prefills: 3 +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml index 2831838fd3e..cf4b4a51ddb 100644 --- a/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml +++ b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill.yaml @@ -10,3 +10,4 @@ engine_worker_queue_port: 6677 cache_transfer_protocol: "rdma,ipc" rdma_comm_ports: "7675,7676,7677,7678" pd_comm_port: "2333" +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint4-tp4_prefill_router.yaml b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill_router.yaml new file mode 100644 index 00000000000..cf4b4a51ddb --- /dev/null +++ b/benchmarks/yaml/eb45-32k-wint4-tp4_prefill_router.yaml @@ -0,0 +1,13 @@ +max_model_len: 32768 +max_num_seqs: 
16 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.9 +tensor_parallel_size: 4 +splitwise_role: prefill +enable_prefix_caching: True +cache_queue_port: 55664 +engine_worker_queue_port: 6677 +cache_transfer_protocol: "rdma,ipc" +rdma_comm_ports: "7675,7676,7677,7678" +pd_comm_port: "2333" +quantization: wint4 diff --git a/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml index a8a51c08663..86e08d343e5 100644 --- a/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml +++ b/benchmarks/yaml/eb45-32k-wint8-a800-tp8.yaml @@ -1,5 +1,6 @@ max_model_len: 32768 max_num_seqs: 96 -gpu_memory_utilization: 0.9 +gpu_memory_utilization: 0.85 kv_cache_ratio: 0.71 tensor_parallel_size: 8 +quantization: wint8 diff --git a/benchmarks/yaml/eb45-8k-fp8-tp1-dp8_ep.yaml b/benchmarks/yaml/eb45-8k-fp8-tp1-dp8_ep.yaml new file mode 100644 index 00000000000..a65fc42e6de --- /dev/null +++ b/benchmarks/yaml/eb45-8k-fp8-tp1-dp8_ep.yaml @@ -0,0 +1,6 @@ +num_gpu_blocks_override: 1024 +max_model_len: 8192 +max_num_seqs: 64 +data_parallel_size: 8 +tensor_parallel_size: 1 +enable_expert_parallel: True diff --git a/benchmarks/yaml/eb45-vl-128k-wint4-h800-tp8.yaml b/benchmarks/yaml/eb45-vl-128k-wint4-h800-tp8.yaml new file mode 100644 index 00000000000..0c5f0449485 --- /dev/null +++ b/benchmarks/yaml/eb45-vl-128k-wint4-h800-tp8.yaml @@ -0,0 +1,11 @@ +enable_mm: True +max_model_len: 131072 +max_num_seqs: 56 +gpu_memory_utilization: 0.8 +kv_cache_ratio: 0.8 +tensor_parallel_size: 8 +quantization: wint4 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +enable_chunked_prefill: True +max_num_batched_tokens: 384 +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-28b-thinking-128k-wint8.yaml b/benchmarks/yaml/eb45-vl-28b-thinking-128k-wint8.yaml new file mode 100644 index 00000000000..de436dc323f --- /dev/null +++ b/benchmarks/yaml/eb45-vl-28b-thinking-128k-wint8.yaml @@ -0,0 +1,8 @@ +max_model_len: 131072 +tensor_parallel_size: 1 +quantization: wint8 
+max_num_seqs: 32 +reasoning_parser: ernie-45-vl-thinking +tool_call_parser: ernie-45-vl-thinking +load_choices: "default_v1" +mm-processor-kwargs: '{"image_max_pixels": 12845056 }' diff --git a/benchmarks/yaml/eb45-vl-28b-thinking-32k-wint8.yaml b/benchmarks/yaml/eb45-vl-28b-thinking-32k-wint8.yaml new file mode 100644 index 00000000000..2be1e0e866c --- /dev/null +++ b/benchmarks/yaml/eb45-vl-28b-thinking-32k-wint8.yaml @@ -0,0 +1,8 @@ +max_model_len: 32768 +tensor_parallel_size: 1 +quantization: wint8 +max_num_seqs: 32 +reasoning_parser: ernie-45-vl-thinking +tool_call_parser: ernie-45-vl-thinking +load_choices: "default_v1" +mm-processor-kwargs: '{"image_max_pixels": 12845056 }' diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8-cinn.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8-cinn.yaml new file mode 100644 index 00000000000..755b567b72f --- /dev/null +++ b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8-cinn.yaml @@ -0,0 +1,11 @@ +max_model_len: 32768 +max_num_seqs: 56 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.8 +tensor_parallel_size: 8 +quantization: wint4 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +reasoning_parser: ernie-45-vl +graph_optimization_config: + use_cudagraph: True + graph_opt_level: 2 diff --git a/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml index 1a53f9b9a9a..751bd70e07d 100644 --- a/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml +++ b/benchmarks/yaml/eb45-vl-32k-wint4-a800-tp8.yaml @@ -7,3 +7,4 @@ tensor_parallel_size: 8 quantization: wint4 limit_mm_per_prompt: '{"image": 100, "video": 100}' reasoning_parser: ernie-45-vl +max_num_batched_tokens: 4096 diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml index 3c803e662a9..75e2df417a3 100644 --- a/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml +++ b/benchmarks/yaml/eb45-vl-32k-wint8-a800-tp8.yaml @@ -1,7 +1,7 @@ enable_mm: True max_model_len: 32768 max_num_seqs: 
36 -gpu_memory_utilization: 0.95 +gpu_memory_utilization: 0.9 kv_cache_ratio: 0.8 tensor_parallel_size: 8 quantization: wint8 diff --git a/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml b/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml index ff9611f5dfd..41d7f1869f5 100644 --- a/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml +++ b/benchmarks/yaml/eb45-vl-32k-wint8-h800-tp8.yaml @@ -1,7 +1,7 @@ enable_mm: True max_model_len: 32768 max_num_seqs: 36 -gpu_memory_utilization: 0.8 +gpu_memory_utilization: 0.85 kv_cache_ratio: 0.8 tensor_parallel_size: 8 quantization: wint8 diff --git a/benchmarks/yaml/eb45-vl-lite-32k-bf16-a800-tp1.yaml b/benchmarks/yaml/eb45-vl-lite-32k-bf16-a800-tp1.yaml new file mode 100644 index 00000000000..2a1b9148eea --- /dev/null +++ b/benchmarks/yaml/eb45-vl-lite-32k-bf16-a800-tp1.yaml @@ -0,0 +1,9 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 1 +enable_chunked_prefill: True +max_num_batched_tokens: 384 +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-lite-32k-wint4-a800-tp1.yaml b/benchmarks/yaml/eb45-vl-lite-32k-wint4-a800-tp1.yaml new file mode 100644 index 00000000000..ffa5ceac34b --- /dev/null +++ b/benchmarks/yaml/eb45-vl-lite-32k-wint4-a800-tp1.yaml @@ -0,0 +1,10 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 1 +enable_chunked_prefill: True +max_num_batched_tokens: 384 +quantization: wint4 +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/eb45-vl-lite-32k-wint8-a800-tp1.yaml b/benchmarks/yaml/eb45-vl-lite-32k-wint8-a800-tp1.yaml new file mode 100644 index 00000000000..7a0d4a0c4db --- /dev/null +++ b/benchmarks/yaml/eb45-vl-lite-32k-wint8-a800-tp1.yaml @@ -0,0 +1,10 @@ +enable_mm: True +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.9 +kv_cache_ratio: 0.71 +tensor_parallel_size: 1 +enable_chunked_prefill: True 
+max_num_batched_tokens: 384 +quantization: wint8 +reasoning_parser: ernie-45-vl diff --git a/benchmarks/yaml/paddleocr-vl-16k-bf16.yaml b/benchmarks/yaml/paddleocr-vl-16k-bf16.yaml new file mode 100644 index 00000000000..a5794f4337e --- /dev/null +++ b/benchmarks/yaml/paddleocr-vl-16k-bf16.yaml @@ -0,0 +1,6 @@ +max_model_len: 16384 +max_num_seqs: 256 +max_num_batched_tokens: 16384 +tensor_parallel_size: 1 +gpu_memory_utilization: 0.7 +workers: 4 diff --git a/benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml b/benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml new file mode 100644 index 00000000000..a946c0f9859 --- /dev/null +++ b/benchmarks/yaml/qwen25_7b-vl-32k-bf16.yaml @@ -0,0 +1,6 @@ +max_model_len: 32768 +max_num_seqs: 128 +gpu_memory_utilization: 0.85 +tensor_parallel_size: 1 +limit_mm_per_prompt: '{"image": 100, "video": 100}' +enable_mm: True diff --git a/benchmarks/yaml/qwen3-235b-32k-fp8-tp1-dp4_decode.yaml b/benchmarks/yaml/qwen3-235b-32k-fp8-tp1-dp4_decode.yaml new file mode 100644 index 00000000000..28acf99944a --- /dev/null +++ b/benchmarks/yaml/qwen3-235b-32k-fp8-tp1-dp4_decode.yaml @@ -0,0 +1,13 @@ +max_model_len: 32768 +max_num_seqs: 32 +data_parallel_size: 4 +tensor_parallel_size: 1 +enable_expert_parallel: True +enable_prefix_caching: False +splitwise_role: decode +cache_transfer_protocol: "rdma" +rdma_comm_ports: "7671,7672,7673,7674" +pd_comm_port: "2335" +engine_worker_queue_port: "4582,4583,4584,4585" +graph_optimization_config: + use_cudagraph: False diff --git a/benchmarks/yaml/qwen3-235b-32k-fp8-tp1-dp4_prefill.yaml b/benchmarks/yaml/qwen3-235b-32k-fp8-tp1-dp4_prefill.yaml new file mode 100644 index 00000000000..d9e879e7267 --- /dev/null +++ b/benchmarks/yaml/qwen3-235b-32k-fp8-tp1-dp4_prefill.yaml @@ -0,0 +1,13 @@ +max_model_len: 32768 +max_num_seqs: 32 +data_parallel_size: 4 +tensor_parallel_size: 1 +enable_expert_parallel: True +enable_prefix_caching: False +splitwise_role: prefill +cache_transfer_protocol: "rdma" +rdma_comm_ports: 
"7675,7676,7677,7678" +pd_comm_port: "2334" +engine_worker_queue_port: "4368,4369,4360,4361" +graph_optimization_config: + use_cudagraph: False diff --git a/benchmarks/yaml/qwen3-vl-64k-bf16-tp1.yaml b/benchmarks/yaml/qwen3-vl-64k-bf16-tp1.yaml new file mode 100644 index 00000000000..196b800b5bf --- /dev/null +++ b/benchmarks/yaml/qwen3-vl-64k-bf16-tp1.yaml @@ -0,0 +1,4 @@ +max_model_len: 65536 +max_num_seqs: 128 +tensor_parallel_size: 1 +limit_mm_per_prompt: '{"image": 100, "video": 100}' diff --git a/benchmarks/yaml/qwen3-vl-64k-bf16-tp2.yaml b/benchmarks/yaml/qwen3-vl-64k-bf16-tp2.yaml new file mode 100644 index 00000000000..6a0474ea4f6 --- /dev/null +++ b/benchmarks/yaml/qwen3-vl-64k-bf16-tp2.yaml @@ -0,0 +1,4 @@ +max_model_len: 65536 +max_num_seqs: 128 +tensor_parallel_size: 2 +limit_mm_per_prompt: '{"image": 100, "video": 100}' diff --git a/benchmarks/yaml/request_yaml/GLM-32k-tool-call.yaml b/benchmarks/yaml/request_yaml/GLM-32k-tool-call.yaml new file mode 100644 index 00000000000..12974a2f380 --- /dev/null +++ b/benchmarks/yaml/request_yaml/GLM-32k-tool-call.yaml @@ -0,0 +1,29 @@ +max_tokens: 32768 +tools: + - type: function + function: + name: local_knowledge_base_retrieval + description: Perform a search on a knowledge source. Returns top-5 hits with docid, score, and snippet. + parameters: + type: object + properties: + user_query: + type: string + description: Query to search the local knowledge base for relevant information + required: + - user_query + additionalProperties: false + - type: function + function: + name: get_document + description: Retrieve a full document by its docid. 
+ parameters: + type: object + properties: + docid: + type: string + description: Document ID to retrieve + required: + - docid + additionalProperties: false +tool_choice: auto diff --git a/benchmarks/yaml/request_yaml/GLM-32k.yaml b/benchmarks/yaml/request_yaml/GLM-32k.yaml new file mode 100644 index 00000000000..c70bb5af625 --- /dev/null +++ b/benchmarks/yaml/request_yaml/GLM-32k.yaml @@ -0,0 +1,8 @@ +top_p: 0.95 +temperature: 0.6 +metadata: + min_tokens: 1 +max_tokens: 12288 +repetition_penalty: 1.0 +frequency_penalty: 0 +presence_penalty: 0 diff --git a/benchmarks/yaml/request_yaml/vLLM_default.yaml b/benchmarks/yaml/request_yaml/deepseek-32k.yaml similarity index 53% rename from benchmarks/yaml/request_yaml/vLLM_default.yaml rename to benchmarks/yaml/request_yaml/deepseek-32k.yaml index a6385823b5f..12d1198a6f9 100644 --- a/benchmarks/yaml/request_yaml/vLLM_default.yaml +++ b/benchmarks/yaml/request_yaml/deepseek-32k.yaml @@ -1,11 +1,10 @@ -top_p: 1.0 -temperature: 1.0 -metadata: - min_tokens: 1 -max_tokens: 30721 +temperature: 0.8 +top_p: 0.8 +presence_penalty: 0 repetition_penalty: 1.0 frequency_penalty: 0 -presence_penalty: 0 -skip_special_tokens: false +max_tokens: 12288 +metadata: + min_tokens: 1 chat_template_kwargs: - enable_thinking: true + enable_thinking: false diff --git a/benchmarks/yaml/request_yaml/eb45-vl-128k.yaml b/benchmarks/yaml/request_yaml/eb45-vl-128k.yaml new file mode 100644 index 00000000000..2c6a5eb7497 --- /dev/null +++ b/benchmarks/yaml/request_yaml/eb45-vl-128k.yaml @@ -0,0 +1 @@ +max_tokens: 131071 diff --git a/benchmarks/yaml/request_yaml/eb45-vl-32k.yaml b/benchmarks/yaml/request_yaml/eb45-vl-32k.yaml new file mode 100644 index 00000000000..e2fb432b979 --- /dev/null +++ b/benchmarks/yaml/request_yaml/eb45-vl-32k.yaml @@ -0,0 +1 @@ +max_tokens: 12288 diff --git a/benchmarks/yaml/request_yaml/qwen25-vl-32k.yaml b/benchmarks/yaml/request_yaml/qwen25-vl-32k.yaml new file mode 100644 index 00000000000..b26e6874970 --- /dev/null +++ 
b/benchmarks/yaml/request_yaml/qwen25-vl-32k.yaml @@ -0,0 +1,8 @@ +top_p: 0.8 +temperature: 0.7 +metadata: + min_tokens: 1 +max_tokens: 32768 +repetition_penalty: 1.05 +frequency_penalty: 0 +presence_penalty: 0 diff --git a/benchmarks/yaml/request_yaml/qwen3-vl-32k.yaml b/benchmarks/yaml/request_yaml/qwen3-vl-32k.yaml new file mode 100644 index 00000000000..c2197a63e54 --- /dev/null +++ b/benchmarks/yaml/request_yaml/qwen3-vl-32k.yaml @@ -0,0 +1,3 @@ +top_p: 0.8 +temperature: 0.7 +max_tokens: 32768 diff --git a/benchmarks/yaml/request_yaml/request.yaml b/benchmarks/yaml/request_yaml/request.yaml new file mode 100644 index 00000000000..9fc603354b6 --- /dev/null +++ b/benchmarks/yaml/request_yaml/request.yaml @@ -0,0 +1,11 @@ +top_p: 0.8 +temperature: 0.8 +max_tokens: 12288 +repetition_penalty: 1.0 +frequency_penalty: 0 +presence_penalty: 0 +metadata: + enable_thinking: false + min_tokens: 1 +chat_template_kwargs: + enable_thinking: false diff --git a/benchmarks/yaml/request_yaml/x1-128k.yaml b/benchmarks/yaml/request_yaml/x1-128k.yaml new file mode 100644 index 00000000000..e02e466c7b0 --- /dev/null +++ b/benchmarks/yaml/request_yaml/x1-128k.yaml @@ -0,0 +1,8 @@ +top_p: 0.95 +temperature: 0.6 +metadata: + min_tokens: 1 +max_tokens: 131071 +repetition_penalty: 1.0 +frequency_penalty: 0 +presence_penalty: 0 diff --git a/benchmarks/yaml/x1-64k-w4a8c8-tp4.yaml b/benchmarks/yaml/x1-64k-w4a8c8-tp4.yaml new file mode 100644 index 00000000000..a5bb750ba90 --- /dev/null +++ b/benchmarks/yaml/x1-64k-w4a8c8-tp4.yaml @@ -0,0 +1,10 @@ +reasoning-parser: ernie-x1 +tool_call_parser: ernie-x1 +tensor_parallel_size: 4 +max_model_len: 65536 +max_num_seqs: 128 +enable_prefix_caching: True +enable_chunked_prefill: True +gpu_memory_utilization: 0.85 +graph_optimization_config: + use_cudagraph: True diff --git a/benchmarks/yaml/x1-a3b-128k-wint8-h800-tp1.yaml b/benchmarks/yaml/x1-a3b-128k-wint8-h800-tp1.yaml new file mode 100644 index 00000000000..4476a55a9fd --- /dev/null +++ 
b/benchmarks/yaml/x1-a3b-128k-wint8-h800-tp1.yaml @@ -0,0 +1,7 @@ +tensor_parallel_size: 1 +max_model_len: 131072 +max_num_seqs: 32 +reasoning_parser: ernie-x1 +tool_call_parser: ernie-x1 +load_choices: "default_v1" +quantization: wint8 diff --git a/build.sh b/build.sh index aa7f40ef847..8e830ba71c7 100644 --- a/build.sh +++ b/build.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +function show_help() { + echo "Usage: bash build.sh [BUILD_WHEEL] [PYTHON] [FD_CPU_USE_BF16] [FD_BUILDING_ARCS] [FD_USE_PRECOMPILED] [FD_COMMIT_ID]" + echo "" + echo "BUILD_WHEEL modes:" + echo " 0 Build custom ops only (no wheel packaging or pip install)" + echo " 1 Full build: compile C++ ops + build wheel + pip install (default)" + echo "" + echo "Arguments:" + echo " PYTHON Python executable (default: python)" + echo " FD_CPU_USE_BF16 Enable CPU BF16 ops: true/false (default: false)" + echo " FD_BUILDING_ARCS Target CUDA architectures, e.g. \"[80, 90, 100]\"" + echo " FD_USE_PRECOMPILED Use precompiled ops: 0=source, 1=precompiled (default: 0)" + echo " FD_COMMIT_ID Commit ID for precompiled wheel lookup" + echo "" + echo "Examples:" + echo " bash build.sh 1 python false \"[90]\" # Full build for SM90" + echo " bash build.sh 0 python false \"[80,90]\" # Build ops only" + exit 0 +} + +if [ "${1}" = "-h" ] || [ "${1}" = "--help" ]; then + show_help +fi + BUILD_WHEEL=${1:-1} PYTHON_VERSION=${2:-"python"} export python=$PYTHON_VERSION @@ -22,7 +46,13 @@ FD_CPU_USE_BF16=${3:-"false"} # For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100. 
# These will be translated to 90a / 100a in setup_ops.py for specific features. FD_BUILDING_ARCS=${4:-""} - +# FD_USE_PRECOMPILED: Specify whether to use precompiled custom ops. +# 0 = build ops from source (default) +# 1 = use precompiled ops +FD_USE_PRECOMPILED=${5:-0} +# FD_COMMIT_ID: Specify the commit ID for locating precompiled wheel packages. +# If not provided, the current git commit ID will be used automatically. +FD_COMMIT_ID=${6:-""} # paddle distributed use to set archs unset PADDLE_CUDA_ARCH_LIST @@ -31,16 +61,17 @@ unset PADDLE_CUDA_ARCH_LIST DIST_DIR="dist" BUILD_DIR="build" EGG_DIR="fastdeploy.egg-info" +PRE_WHEEL_DIR="pre_wheel" # custom_ops directory config OPS_SRC_DIR="custom_ops" -OPS_TMP_DIR_BASE="tmp_base" OPS_TMP_DIR="tmp" # command line log config RED='\033[0;31m' BLUE='\033[0;34m' GREEN='\033[1;32m' +YELLOW='\033[1;33m' BOLD='\033[1m' NONE='\033[0m' @@ -58,45 +89,73 @@ function python_version_check() { function init() { echo -e "${BLUE}[init]${NONE} removing building directory..." - rm -rf $DIST_DIR $BUILD_DIR $EGG_DIR + rm -rf $BUILD_DIR $EGG_DIR $PRE_WHEEL_DIR ${python} -m pip install setuptools_scm echo -e "${BLUE}[init]${NONE} ${GREEN}init success\n" } - function copy_ops(){ + local tmp_dir=${1:-$OPS_TMP_DIR} OPS_VERSION="0.0.0" PY_MAIN_VERSION=`${python} -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'` PY_SUB_VERSION=`${python} -V 2>&1 | awk '{print $2}' | awk -F '.' 
'{print $2}'` PY_VERSION="py${PY_MAIN_VERSION}.${PY_SUB_VERSION}" SYSTEM_VERSION=`${python} -c "import platform; print(platform.system().lower())"` PROCESSOR_VERSION=`${python} -c "import platform; print(platform.processor())"` - WHEEL_BASE_NAME="fastdeploy_base_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg" - WHEEL_NAME="fastdeploy_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg" - WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg" + EGG_NAME="fastdeploy_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg" + EGG_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg" + + # Add compatibility for modern python packaging methods + LEGACY_PACKAGE_DIR="${tmp_dir}/${EGG_NAME}" + MODERN_PACKAGE_DIR="${tmp_dir}/fastdeploy_ops" + LEGACY_PACKAGE_DIR_CPU="${tmp_dir}/${EGG_CPU_NAME}" + MODERN_PACKAGE_DIR_CPU="${tmp_dir}/fastdeploy_cpu_ops" + + # Handle GPU ops directory compatibility between modern and legacy naming + if [ -d "${MODERN_PACKAGE_DIR}" ]; then + echo -e "${GREEN}[Info]${NONE} Ready to copy ops from modern directory ${WHEEL_MODERN_NAME} to target directory" + TMP_PACKAGE_DIR="${tmp_dir}" + # If modern directory doesn't exist, check for legacy directory, this branch should be removed in the future + elif [ -d "${LEGACY_PACKAGE_DIR}" ]; then + echo -e "${YELLOW}[Warning]${NONE} ${EGG_NAME} directory exists. This is a legacy packaging and distribution method." + TMP_PACKAGE_DIR="${LEGACY_PACKAGE_DIR}" + else + echo -e "${RED}[Error]${NONE} Neither modern nor legacy directory for gpu ops found in ${tmp_dir}" + echo -e "${BLUE}[Info]${NONE} Maybe the compilation failed, please clean the build directory (currently ${BUILD_DIR}) and egg directory (currently ${EGG_DIR}) and try again." 
+ echo -e "${BLUE}[Info]${NONE} If the build still fails, please try to use a clean FastDeploy code and a clean environment to compile again." + exit 1 + fi + + # Handle CPU ops directory compatibility between modern and legacy naming + if [ -d "${MODERN_PACKAGE_DIR_CPU}" ]; then + echo -e "${GREEN}[Info]${NONE} Ready to copy ops from modern directory ${WHEEL_MODERN_CPU_NAME} to target directory" + TMP_PACKAGE_DIR_BASE="${tmp_dir}" + # If modern directory doesn't exist, check for legacy directory, this branch should be removed in the future + elif [ -d "${LEGACY_PACKAGE_DIR_CPU}" ]; then + echo -e "${YELLOW}[Warning]${NONE} ${EGG_CPU_NAME} directory exists. This is a legacy packaging and distribution method." + TMP_PACKAGE_DIR_BASE="${LEGACY_PACKAGE_DIR_CPU}" + else + echo -e "${YELLOW}[Warning]${NONE} Neither modern nor legacy directory for cpu ops found in ${tmp_dir}" + fi is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"` if [ "$is_rocm" = "True" ]; then DEVICE_TYPE="rocm" - mkdir -p ../fastdeploy/model_executor/ops/base - cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base - cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu - echo -e "BASE and ROCM ops have been copy to fastdeploy" + cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gpu + echo -e "ROCM ops have been copy to fastdeploy" return fi - mkdir -p ../fastdeploy/model_executor/ops/base is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"` if [ "$is_cuda" = "True" ]; then DEVICE_TYPE="gpu" - cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base - cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu - echo -e "BASE and CUDA ops have been copy to fastdeploy" + cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gpu + echo -e "CUDA ops have been copy to fastdeploy" return fi is_xpu=`$python -c "import paddle; 
print(paddle.is_compiled_with_xpu())"` if [ "$is_xpu" = "True" ]; then DEVICE_TYPE="xpu" - cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/xpu + cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/xpu echo -e "xpu ops have been copy to fastdeploy" return fi @@ -104,7 +163,7 @@ function copy_ops(){ is_npu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('npu'))"` if [ "$is_npu" = "True" ]; then DEVICE_TYPE="npu" - cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/npu + cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/npu echo -e "npu ops have been copy to fastdeploy" return fi @@ -112,55 +171,176 @@ function copy_ops(){ if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"` if [ "$if_corex" = "True" ]; then DEVICE_TYPE="iluvatar-gpu" - cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base - cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar - echo -e "BASE and Iluvatar ops have been copy to fastdeploy" + cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/iluvatar + echo -e "Iluvatar ops have been copy to fastdeploy" return fi is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"` if [ "$is_gcu" = "True" ]; then DEVICE_TYPE="gcu" - cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu + cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gcu echo -e "gcu ops have been copy to fastdeploy" return fi + is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"` + if [ "$is_maca" = "True" ]; then + DEVICE_TYPE="metax_gpu" + cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gpu + echo -e "MACA ops have been copy to fastdeploy" + return + fi + + is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"` + if [ 
"$is_intel_hpu" = "True" ]; then + DEVICE_TYPE="intel-hpu" + echo -e "intel_hpu ops have been copy to fastdeploy" + return + fi + DEVICE_TYPE="cpu" - cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base cd ../../../../ - cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu - echo -e "BASE and CPU ops have been copy to fastdeploy" + cp -r ${tmp_dir}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu + echo -e "CPU ops have been copy to fastdeploy" return } +function extract_ops_from_precompiled_wheel() { + local WHL_NAME="fastdeploy_gpu-0.0.0-py3-none-any.whl" + if [ -z "$FD_COMMIT_ID" ]; then + if git rev-parse HEAD >/dev/null 2>&1; then + FD_COMMIT_ID=$(git rev-parse HEAD) + echo -e "${BLUE}[init]${NONE} Using current repo commit ID: ${GREEN}${FD_COMMIT_ID}${NONE}" + else + echo -e "${RED}[ERROR]${NONE} Cannot determine commit ID (not a git repo). Please provide manually." + exit 1 + fi + fi + + CUDA_VERSION=$(nvcc --version | grep "release" | sed -E 's/.*release ([0-9]+)\.([0-9]+).*/\1\2/') + echo -e "${BLUE}[info]${NONE} Detected CUDA version: ${GREEN}cu${CUDA_VERSION}${NONE}" + + GPU_ARCH_STR=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader \ + | awk '{printf("%d\n",$1*10)}' | sort -u | awk '{printf("SM_%s_",$1)}' | sed 's/_$//') + echo -e "${BLUE}[info]${NONE} Detected GPU arch: ${GREEN}${GPU_ARCH_STR}${NONE}" + + local WHL_PATH="${PRE_WHEEL_DIR}/${WHL_NAME}" + local REMOTE_URL="https://paddle-qa.bj.bcebos.com/paddle-pipeline/FastDeploy_ActionCE/cu${CUDA_VERSION}/${GPU_ARCH_STR}/develop/${FD_COMMIT_ID}/${WHL_NAME}" + + mkdir -p "${PRE_WHEEL_DIR}" + + if [ ! -f "$WHL_PATH" ]; then + echo -e "${BLUE}[precompiled]${NONE} Local wheel not found, downloading from: ${REMOTE_URL}" + wget --no-check-certificate -O "$WHL_PATH" "$REMOTE_URL" || { + echo -e "${YELLOW}[WARNING]${NONE} Failed to download wheel." 
+ return 1 + } + echo -e "${GREEN}[SUCCESS]${NONE} Downloaded precompiled wheel to ${WHL_PATH}" + else + echo -e "${BLUE}[precompiled]${NONE} Found local wheel: ${WHL_PATH}" + if ! unzip -t "$WHL_PATH" >/dev/null 2>&1; then + echo -e "${BLUE}[WARNING]${NONE} Local wheel seems invalid." + echo -e "${BLUE}[fallback]${NONE} Falling back to source compilation..." + return 1 + fi + fi + + local TMP_DIR="${PRE_WHEEL_DIR}/tmp_whl_unpack" + rm -rf "$TMP_DIR" + mkdir -p "$TMP_DIR" + + echo -e "${BLUE}[precompiled]${NONE} Unpacking wheel..." + ${python} -m zipfile -e "$WHL_PATH" "$TMP_DIR" + + local DATA_DIR + DATA_DIR=$(find "$TMP_DIR" -maxdepth 1 -type d -name "*.data" | head -n 1) + if [ -z "$DATA_DIR" ]; then + echo -e "${RED}[ERROR]${NONE} Cannot find *.data directory in unpacked wheel." + rm -rf "$TMP_DIR" + echo -e "${YELLOW}[fallback]${NONE} Falling back to source compilation..." + FD_USE_PRECOMPILED=0 + return 1 + fi + + local PLATLIB_DIR="${DATA_DIR}/platlib" + local SRC_DIR="${PLATLIB_DIR}/fastdeploy/model_executor/ops/gpu" + local DST_DIR="fastdeploy/model_executor/ops/gpu" + + if [ ! -d "$SRC_DIR" ]; then + echo -e "${RED}[ERROR]${NONE} GPU ops directory not found in wheel: $SRC_DIR" + rm -rf "$TMP_DIR" + echo -e "${YELLOW}[fallback]${NONE} Falling back to source compilation..." + FD_USE_PRECOMPILED=0 + return 1 + fi + + echo -e "${BLUE}[precompiled]${NONE} Copying GPU precompiled contents..." 
+ mkdir -p "$DST_DIR" + cp -r "$SRC_DIR/deep_gemm" "$DST_DIR/" 2>/dev/null || true + # Check for modern Python packaging approach (fastdeploy_ops directory) + # If exists, copy the entire directory; otherwise, fall back to legacy method (individual files) + if [ -d "$SRC_DIR/fastdeploy_ops" ]; then + cp -r "$SRC_DIR/fastdeploy_ops" "$DST_DIR/" 2>/dev/null || true + else + cp -r "$SRC_DIR/fastdeploy_ops.py" "$DST_DIR/" 2>/dev/null || true + cp -f "$SRC_DIR/"fastdeploy_ops_*.so "$DST_DIR/" 2>/dev/null || true + fi + cp -f "$SRC_DIR/version.txt" "$DST_DIR/" 2>/dev/null || true + + echo -e "${GREEN}[SUCCESS]${NONE} Installed FastDeploy using precompiled wheel." + rm -rf "${PRE_WHEEL_DIR}" +} + +function build_custom_ops() { + if [ "$FD_UNIFY_BUILD" ]; then + mkdir -p ${OPS_SRC_DIR}/${OPS_TMP_DIR} + + custom_ops_dir=${OPS_TMP_DIR}/fastdeploy_ops_86 + build_and_install_ops "[86]" "$custom_ops_dir" + + custom_ops_dir=${OPS_TMP_DIR}/fastdeploy_ops_89 + build_and_install_ops "[89]" "$custom_ops_dir" + + build_and_install_ops "[80, 90]" "${OPS_TMP_DIR}" + cp -r $OPS_SRC_DIR/$OPS_TMP_DIR/* ./fastdeploy/model_executor/ops/gpu + else + build_and_install_ops "$FD_BUILDING_ARCS" "$OPS_TMP_DIR" + cd $OPS_SRC_DIR + copy_ops $OPS_TMP_DIR + cd .. + fi +} + function build_and_install_ops() { + local building_arcs=${1:-$FD_BUILDING_ARCS} + local tmp_dir=${2:-$OPS_TMP_DIR} + echo "BUILD CUSTOM OPS: ${building_arcs}, ${tmp_dir}" cd $OPS_SRC_DIR export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy} - echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..." - ${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE} - find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \; echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..." 
- TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}` + TMP_DIR_REAL_PATH=`readlink -f ${tmp_dir}` is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"` if [ "$is_xpu" = "True" ]; then - cd xpu_ops/src + cd xpu_ops bash build.sh ${TMP_DIR_REAL_PATH} - cd ../.. + cd .. elif [ "$FD_CPU_USE_BF16" == "true" ]; then - if [ "$FD_BUILDING_ARCS" == "" ]; then - FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} + if [ "$building_arcs" == "" ]; then + FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${tmp_dir} else - FD_BUILDING_ARCS=${FD_BUILDING_ARCS} FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} + FD_BUILDING_ARCS=${building_arcs} FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${tmp_dir} fi - find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \; + find ${tmp_dir} -type f -name "*.o" -exec rm -f {} \; elif [ "$FD_CPU_USE_BF16" == "false" ]; then - if [ "$FD_BUILDING_ARCS" == "" ]; then - ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} + if [ "$building_arcs" == "" ]; then + ${python} setup_ops.py install --install-lib ${tmp_dir} else - FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} + FD_BUILDING_ARCS=${building_arcs} ${python} setup_ops.py install --install-lib ${tmp_dir} + fi + if [ -d "${tmp_dir}" ]; then + find ${tmp_dir} -type f -name "*.o" -exec rm -f {} \; fi - find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \; else echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false." exit 1 @@ -171,8 +351,6 @@ function build_and_install_ops() { fi echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy_ops success ${NONE}" - copy_ops - cd .. 
} @@ -213,7 +391,6 @@ function cleanup() { fi rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR - rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR } @@ -223,7 +400,7 @@ function abort() { cur_dir=`basename "$pwd"` - rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR + rm -rf $BUILD_DIR $EGG_DIR ${python} -m pip uninstall -y fastdeploy-${DEVICE_TYPE} rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR @@ -237,9 +414,44 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then init version_info - build_and_install_ops - build_and_install - cleanup + # Whether to enable precompiled wheel + if [ "$FD_USE_PRECOMPILED" -eq 1 ]; then + echo -e "${BLUE}[MODE]${NONE} Using precompiled .whl" + if extract_ops_from_precompiled_wheel; then + echo -e "${GREEN}[DONE]${NONE} Precompiled wheel installed successfully." + echo -e "${BLUE}[MODE]${NONE} Building wheel package from installed files..." + build_and_install + echo -e "${BLUE}[MODE]${NONE} Installing newly built FastDeploy wheel..." + ${python} -m pip install ./dist/fastdeploy*.whl + # get Paddle version + PADDLE_VERSION=`${python} -c "import paddle; print(paddle.version.full_version)"` + PADDLE_COMMIT=`${python} -c "import paddle; print(paddle.version.commit)"` + # get FastDeploy info + EFFLLM_BRANCH=`git rev-parse --abbrev-ref HEAD` + EFFLLM_COMMIT=`git rev-parse --short HEAD` + # get Python version + PYTHON_VERSION=`${python} -c "import platform; print(platform.python_version())"` + echo -e "\n${GREEN}fastdeploy wheel packaged successfully${NONE} + ${BLUE}Python version:${NONE} $PYTHON_VERSION + ${BLUE}Paddle version:${NONE} $PADDLE_VERSION ($PADDLE_COMMIT) + ${BLUE}fastdeploy branch:${NONE} $EFFLLM_BRANCH ($EFFLLM_COMMIT)\n" + echo -e "${GREEN}wheel saved under${NONE} ${RED}${BOLD}./dist${NONE}" + cleanup + trap : 0 + exit 0 + else + echo -e "${BLUE}[fallback]${NONE} ${YELLOW}Precompiled .whl unavailable, switching to source build." 
+ FD_USE_PRECOMPILED=0 + fi + fi + + if [ "$FD_USE_PRECOMPILED" -eq 0 ]; then + echo -e "${BLUE}[MODE]${NONE} Building from source (ops)..." + build_custom_ops + echo -e "${BLUE}[MODE]${NONE} Building full wheel from source..." + build_and_install + cleanup + fi # get Paddle version PADDLE_VERSION=`${python} -c "import paddle; print(paddle.version.full_version)"` @@ -264,10 +476,10 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then echo -e "${GREEN}wheel install success${NONE}\n" trap : 0 -else +elif [ "$BUILD_WHEEL" -eq 0 ]; then init - build_and_install_ops + build_custom_ops version_info - rm -rf $BUILD_DIR $EGG_DIR $DIST_DIR + rm -rf $BUILD_DIR $EGG_DIR rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR fi diff --git a/custom_ops/0001-DeepGEMM-95e81b3.patch b/custom_ops/0001-DeepGEMM-95e81b3.patch index c3f409c1483..eb828a1b57f 100644 --- a/custom_ops/0001-DeepGEMM-95e81b3.patch +++ b/custom_ops/0001-DeepGEMM-95e81b3.patch @@ -1,22 +1,22 @@ -From 5112002c155dceecc5e5983cdb67157e4f5400e2 Mon Sep 17 00:00:00 2001 -From: minghaipeng -Date: Wed, 25 Jun 2025 15:05:24 +0800 -Subject: [PATCH] DeepGEMM 95e81b3 +From 7008a3c8b7fe833c952f27a5ab3848c485f02b5d Mon Sep 17 00:00:00 2001 +From: K11OntheBoat <“ruianmaidanglao@163.com”> +Date: Thu, 27 Nov 2025 14:38:47 +0800 +Subject: [PATCH] Remove extra H2D in DeepGemm --- - deep_gemm/__init__.py | 2 +- - deep_gemm/include/deep_gemm/scheduler.cuh | 2 +- - deep_gemm/jit/compiler.py | 2 +- - deep_gemm/jit/interleave_ffma.py | 2 +- - deep_gemm/jit/runtime.py | 4 +- - deep_gemm/jit/template.py | 34 ++++---- - deep_gemm/jit_kernels/gemm.py | 44 +++++------ - deep_gemm/jit_kernels/m_grouped_gemm.py | 96 +++++++++++------------ - deep_gemm/jit_kernels/tuner.py | 10 +-- - deep_gemm/jit_kernels/utils.py | 18 +++-- - deep_gemm/paddle_utils.py | 20 +++++ - deep_gemm/utils.py | 30 +++---- - 12 files changed, 143 insertions(+), 121 deletions(-) + deep_gemm/__init__.py | 2 +- + deep_gemm/include/deep_gemm/scheduler.cuh | 2 +- + 
deep_gemm/jit/compiler.py | 2 +- + deep_gemm/jit/interleave_ffma.py | 2 +- + deep_gemm/jit/runtime.py | 4 +- + deep_gemm/jit/template.py | 34 +++---- + deep_gemm/jit_kernels/gemm.py | 44 ++++----- + deep_gemm/jit_kernels/m_grouped_gemm.py | 104 +++++++++++----------- + deep_gemm/jit_kernels/tuner.py | 10 +-- + deep_gemm/jit_kernels/utils.py | 18 ++-- + deep_gemm/paddle_utils.py | 20 +++++ + deep_gemm/utils.py | 30 +++---- + 12 files changed, 147 insertions(+), 125 deletions(-) create mode 100644 deep_gemm/paddle_utils.py diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py @@ -257,7 +257,7 @@ index cb438b7..44aa0ed 100644 args=args ) diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py -index 3b518c9..ba776bd 100644 +index 3b518c9..b94e65d 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -1,4 +1,4 @@ @@ -299,8 +299,14 @@ index 3b518c9..ba776bd 100644 `m_indices[i]` records the group which the i-th row of the LHS belong to, which means that the i-th row of the LHS matrix will be multiplied with `rhs[m_indices[i]]`. Values of `m_indices` in every-m-alignment-block must also be the same. -@@ -68,19 +68,19 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten - m__ = m_indices.numel() +@@ -64,23 +64,23 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten + rhs, rhs_scales = rhs + m, k = lhs.shape + num_groups, n, k_ = rhs.shape +- m_, n_ = out.shape +- m__ = m_indices.numel() ++ # m_, n_ = out.shape ++ # m__ = m_indices.numel() # Type and shape checks - assert m == m_ == m__ and k == k_ and n == n_ @@ -384,8 +390,14 @@ index 3b518c9..ba776bd 100644 the second element is an FP32 128x128 scaling tensor for RHS of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`. out: the BF16 output tensor of shape `[num_groups, m_max, n]`, representing the result. 
masked_m: a tensor of shape `[num_groups]`, `masked_m[i]` records actual rows of the `lhs[i]` matrix to compute -@@ -149,21 +149,21 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] - num_groups___ = masked_m.numel() +@@ -145,25 +145,25 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] + rhs, rhs_scales = rhs + num_groups, m, k = lhs.shape + num_groups_, n, k_ = rhs.shape +- num_groups__, m_, n_ = out.shape +- num_groups___ = masked_m.numel() ++ # num_groups__, m_, n_ = out.shape ++ # num_groups___ = masked_m.numel() # Type and shape checks - assert num_groups == num_groups_ == num_groups__ == num_groups___ @@ -563,7 +575,7 @@ index 0000000..2326807 +CUDA_HOME = get_cuda_home() \ No newline at end of file diff --git a/deep_gemm/utils.py b/deep_gemm/utils.py -index d5cdd01..5237f09 100644 +index d5cdd01..011f14a 100644 --- a/deep_gemm/utils.py +++ b/deep_gemm/utils.py @@ -1,15 +1,15 @@ diff --git a/custom_ops/cpu_ops/avx_weight_only_fake.cc b/custom_ops/cpu_ops/avx_weight_only_fake.cc index 2150669af95..d117e660685 100644 --- a/custom_ops/cpu_ops/avx_weight_only_fake.cc +++ b/custom_ops/cpu_ops/avx_weight_only_fake.cc @@ -19,28 +19,28 @@ std::vector InvokeAvxWeightOnly(const paddle::Tensor &x, const paddle::Tensor &w_bias, const std::string &alog, bool trans) { - auto out_shape = x.shape(); - out_shape[out_shape.size() - 1] = weight.shape()[1]; - auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace()); - return {out}; + auto out_shape = x.shape(); + out_shape[out_shape.size() - 1] = weight.shape()[1]; + auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace()); + return {out}; } std::vector> AvxWeightOnlyInferShape( std::vector x_shape, std::vector weigh_shape, std::vector weigh_bias_shape) { - int m = 1; - for (int i = 0; i < x_shape.size() - 1; i++) { - m = m * x_shape[i]; - } - return {std::vector{m, weigh_shape[1]}}; + int m = 1; + for (int i = 0; i < x_shape.size() - 1; i++) 
{ + m = m * x_shape[i]; + } + return {std::vector{m, weigh_shape[1]}}; } std::vector AvxWeightOnlyInferDtype( paddle::DataType x_dtype, paddle::DataType weight_dtype, paddle::DataType weight_bias_dtype) { - return {x_dtype}; + return {x_dtype}; } PD_BUILD_STATIC_OP(avx_weight_only) diff --git a/custom_ops/cpu_ops/get_padding_offset.cc b/custom_ops/cpu_ops/get_padding_offset.cc index 8fe73bc8e4f..50af5a2951d 100644 --- a/custom_ops/cpu_ops/get_padding_offset.cc +++ b/custom_ops/cpu_ops/get_padding_offset.cc @@ -20,13 +20,13 @@ void remove_padding(int64_t *output_data, const int *cum_offsets, const int sequence_length, const int bsz) { - for (int bi = 0; bi < bsz; ++bi) { - for (int i = 0; i < seq_lens[bi]; ++i) { - const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i; - const int src_seq_id = bi * sequence_length + i; - output_data[tgt_seq_id] = input_data[src_seq_id]; - } + for (int bi = 0; bi < bsz; ++bi) { + for (int i = 0; i < seq_lens[bi]; ++i) { + const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i; + const int src_seq_id = bi * sequence_length + i; + output_data[tgt_seq_id] = input_data[src_seq_id]; } + } } void get_padding_offset_kernel(int *padding_offset, @@ -37,57 +37,53 @@ void get_padding_offset_kernel(int *padding_offset, const int *seq_lens, const int max_seq_len, const int bsz) { - for (int bi = 0; bi < bsz; ++bi) { - int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1]; - auto seq_len_now = seq_lens[bi]; - for (int i = 0; i < seq_len_now; ++i) { - padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; - } - cum_offsets_out[bi] = cum_offset; - int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; - cu_seqlens_q[bi + 1] = cum_seq_len; - cu_seqlens_k[bi + 1] = cum_seq_len; + for (int bi = 0; bi < bsz; ++bi) { + int cum_offset = bi == 0 ? 
0 : cum_offsets[bi - 1]; + auto seq_len_now = seq_lens[bi]; + for (int i = 0; i < seq_len_now; ++i) { + padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; } + cum_offsets_out[bi] = cum_offset; + int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; + cu_seqlens_q[bi + 1] = cum_seq_len; + cu_seqlens_k[bi + 1] = cum_seq_len; + } } std::vector GetPaddingOffset(const paddle::Tensor &input_ids, const paddle::Tensor &cum_offsets, const paddle::Tensor &token_num, const paddle::Tensor &seq_len) { - std::vector input_ids_shape = input_ids.shape(); - const int bsz = seq_len.shape()[0]; - const int seq_length = input_ids_shape[1]; - auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false); - auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); + std::vector input_ids_shape = input_ids.shape(); + const int bsz = seq_len.shape()[0]; + const int seq_length = input_ids_shape[1]; + auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false); + auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); - const int token_num_data = cpu_token_num.data()[0]; - auto x_remove_padding = paddle::empty( - {token_num_data}, paddle::DataType::INT64, input_ids.place()); - auto padding_offset = paddle::empty( - {token_num_data}, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_q = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - auto cu_seqlens_k = - paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); - get_padding_offset_kernel(padding_offset.data(), - cum_offsets_out.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - cum_offsets.data(), - seq_len.data(), - seq_length, - bsz); - remove_padding(x_remove_padding.data(), - input_ids.data(), - seq_len.data(), - cum_offsets_out.data(), - seq_length, - bsz); - return {x_remove_padding, - cum_offsets_out, - padding_offset, - cu_seqlens_q, - cu_seqlens_k}; + const int token_num_data = cpu_token_num.data()[0]; + auto 
x_remove_padding = paddle::empty( + {token_num_data}, paddle::DataType::INT64, input_ids.place()); + auto padding_offset = paddle::empty( + {token_num_data}, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_q = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_k = + paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + get_padding_offset_kernel(padding_offset.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + cum_offsets.data(), + seq_len.data(), + seq_length, + bsz); + remove_padding(x_remove_padding.data(), + input_ids.data(), + seq_len.data(), + cum_offsets_out.data(), + seq_length, + bsz); + return {x_remove_padding, padding_offset, cu_seqlens_q, cu_seqlens_k}; } std::vector> GetPaddingOffsetInferShape( @@ -95,9 +91,9 @@ std::vector> GetPaddingOffsetInferShape( const std::vector &cum_offsets_shape, const std::vector &token_num_shape, const std::vector &seq_len_shape) { - int64_t bsz = seq_len_shape[0]; - int64_t seq_len = input_ids_shape[1]; - return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; + int64_t bsz = seq_len_shape[0]; + int64_t seq_len = input_ids_shape[1]; + return {{-1}, {-1}, {bsz + 1}, {bsz + 1}}; } std::vector GetPaddingOffsetInferDtype( @@ -105,20 +101,13 @@ std::vector GetPaddingOffsetInferDtype( const paddle::DataType &cum_offsets_dtype, const paddle::DataType &token_num_dtype, const paddle::DataType &seq_len_dtype) { - return {input_ids_dtype, - seq_len_dtype, - seq_len_dtype, - seq_len_dtype, - seq_len_dtype}; + return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype}; } PD_BUILD_STATIC_OP(get_padding_offset_cpu) .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) - .Outputs({"x_remove_padding", - "cum_offsets_out", - "padding_offset", - "cu_seqlens_q", - "cu_seqlens_k"}) + .Outputs( + {"x_remove_padding", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"}) .SetKernelFn(PD_KERNEL(GetPaddingOffset)) 
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype)); diff --git a/custom_ops/cpu_ops/rebuild_padding.cc b/custom_ops/cpu_ops/rebuild_padding.cc index 8ce533d041b..9e4627dfb7b 100644 --- a/custom_ops/cpu_ops/rebuild_padding.cc +++ b/custom_ops/cpu_ops/rebuild_padding.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,39 +22,40 @@ template void RebuildPaddingCPUImpl(T *output_data, const T *input_data, - const int *cum_offsets_data, + const int *cu_seqlens_q_data, const int *seq_len_this_time_data, const int *seq_lens_decoder_data, const int *seq_lens_encoder_data, int max_input_length, int dim_embed, const int elem_nums) { - for (int i = 0; i < elem_nums; ++i) { - const int bi = i / dim_embed; - const int bias_idx = i % dim_embed; - int seq_id = 0; + for (int i = 0; i < elem_nums; ++i) { + const int bi = i / dim_embed; + const int bias_idx = i % dim_embed; + int seq_id = 0; - if (seq_len_this_time_data[bi] == 0) { - continue; - } - if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) { - continue; - } - if (seq_lens_encoder_data[bi] > 0) { - seq_id = seq_lens_encoder_data[bi] - 1; - } - const int ori_token_idx = - bi * max_input_length - cum_offsets_data[bi] + seq_id; - const int src_offset = ori_token_idx * dim_embed + bias_idx; + if (seq_len_this_time_data[bi] == 0) { + continue; + } + if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) { + continue; + } - output_data[i] = input_data[src_offset]; + if (seq_lens_encoder_data[bi] > 0) { + seq_id = seq_lens_encoder_data[bi] - 1; } + + const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id; + const int src_offset = ori_token_idx * dim_embed + bias_idx; + + output_data[i] = 
input_data[src_offset]; + } } template void RebuildAppendPaddingCPUImpl(T *output_data, const T *input_data, - const int *cum_offsets_data, + const int *cu_seqlens_q_data, const int *seq_len_this_time_data, const int *seq_lens_decoder_data, const int *seq_lens_encoder_data, @@ -62,201 +63,199 @@ void RebuildAppendPaddingCPUImpl(T *output_data, const int max_input_length, const int dim_embed, const int64_t output_elem_nums) { - for (int i = 0; i < output_elem_nums; ++i) { - int out_token_id = i / dim_embed; - int ori_token_id = - out_token_id + output_padding_offset_data[out_token_id]; - int bi = ori_token_id / max_input_length; - if (seq_len_this_time_data[bi] == 0 || - (seq_lens_decoder_data[bi] == 0 && - seq_lens_encoder_data[bi] == 0)) { - continue; - } - int seq_id = 0; - if (seq_lens_encoder_data[bi] > 0) { - seq_id = seq_lens_encoder_data[bi] - 1; - } - int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id; - int bias_idx = i % dim_embed; - int src_offset = input_token_id * dim_embed + bias_idx; - output_data[i] = input_data[src_offset]; + for (int i = 0; i < output_elem_nums; ++i) { + int out_token_id = i / dim_embed; + int ori_token_id = out_token_id + output_padding_offset_data[out_token_id]; + int bi = ori_token_id / max_input_length; + if (seq_len_this_time_data[bi] == 0 || + (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0)) { + continue; + } + int seq_id = 0; + + if (seq_lens_encoder_data[bi] > 0) { + seq_id = seq_lens_encoder_data[bi] - 1; } + int input_token_id = cu_seqlens_q_data[bi] + seq_id; + int bias_idx = i % dim_embed; + int src_offset = input_token_id * dim_embed + bias_idx; + + output_data[i] = input_data[src_offset]; + } } std::vector RebuildPaddingCPU( const paddle::Tensor &tmp_out, - const paddle::Tensor &cum_offsets, + const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &seq_len_this_time, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional 
&output_padding_offset, int max_input_length) { - auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true); - auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true); - auto seq_len_this_time_cpu = - seq_len_this_time.copy_to(paddle::CPUPlace(), true); - auto seq_lens_decoder_cpu = - seq_lens_decoder.copy_to(paddle::CPUPlace(), true); - auto seq_lens_encoder_cpu = - seq_lens_encoder.copy_to(paddle::CPUPlace(), true); - paddle::optional output_padding_offset_cpu; - if (output_padding_offset) { - output_padding_offset_cpu = - output_padding_offset->copy_to(paddle::CPUPlace(), true); - } + auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true); + auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true); + auto seq_len_this_time_cpu = + seq_len_this_time.copy_to(paddle::CPUPlace(), true); + auto seq_lens_decoder_cpu = + seq_lens_decoder.copy_to(paddle::CPUPlace(), true); + auto seq_lens_encoder_cpu = + seq_lens_encoder.copy_to(paddle::CPUPlace(), true); + paddle::optional output_padding_offset_cpu; + if (output_padding_offset) { + output_padding_offset_cpu = + output_padding_offset->copy_to(paddle::CPUPlace(), true); + } - int token_num = tmp_out_cpu.shape()[0]; - int dim_embed = tmp_out_cpu.shape()[1]; - int bsz = cum_offsets_cpu.shape()[0]; + int token_num = tmp_out_cpu.shape()[0]; + int dim_embed = tmp_out_cpu.shape()[1]; + int bsz = cu_seqlens_q_cpu.shape()[0] - 1; - paddle::Tensor out; - if (output_padding_offset_cpu) { - int need_delete_token_num = 0; - for (int i = 0; i < bsz; ++i) { - if (seq_lens_encoder_cpu.data()[i] > 0) { - need_delete_token_num += - seq_lens_encoder_cpu.data()[i] - 1; - } - } - int output_token_num = token_num - need_delete_token_num; - out = paddle::full({output_token_num, dim_embed}, - 0, - tmp_out_cpu.dtype(), - paddle::CPUPlace()); - } else { - out = paddle::full( - {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); + paddle::Tensor out; + if (output_padding_offset_cpu) { + int 
need_delete_token_num = 0; + for (int i = 0; i < bsz; ++i) { + if (seq_lens_encoder_cpu.data()[i] > 0) { + need_delete_token_num += seq_lens_encoder_cpu.data()[i] - 1; + } } + int output_token_num = token_num - need_delete_token_num; + out = paddle::full({output_token_num, dim_embed}, + 0, + tmp_out_cpu.dtype(), + paddle::CPUPlace()); + } else { + out = paddle::full( + {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); + } - const int *cum_offsets_data = cum_offsets_cpu.data(); - const int *seq_len_this_time_data = seq_len_this_time_cpu.data(); - const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data(); - const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data(); - int elem_nums = out.numel(); - - if (output_padding_offset_cpu) { - const int *output_padding_offset_data = - output_padding_offset_cpu->data(); - switch (tmp_out_cpu.dtype()) { - case paddle::DataType::FLOAT32: - RebuildAppendPaddingCPUImpl(out.data(), - tmp_out_cpu.data(), - cum_offsets_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - output_padding_offset_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::FLOAT16: - RebuildAppendPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cum_offsets_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - output_padding_offset_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::BFLOAT16: - RebuildAppendPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cum_offsets_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - output_padding_offset_data, - max_input_length, - dim_embed, - elem_nums); - break; - default: - PD_THROW( - "Unsupported data type for rebuild_padding_cpu. 
" - "Only float32, float16, and bfloat16 are supported."); - } - } else { - switch (tmp_out_cpu.dtype()) { - case paddle::DataType::FLOAT32: - RebuildPaddingCPUImpl(out.data(), - tmp_out_cpu.data(), - cum_offsets_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::FLOAT16: - RebuildPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cum_offsets_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - max_input_length, - dim_embed, - elem_nums); - break; - case paddle::DataType::BFLOAT16: + const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data(); + const int *seq_len_this_time_data = seq_len_this_time_cpu.data(); + const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data(); + const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data(); + int elem_nums = out.numel(); - RebuildPaddingCPUImpl( - out.data(), - tmp_out_cpu.data(), - cum_offsets_data, - seq_len_this_time_data, - seq_lens_decoder_data, - seq_lens_encoder_data, - max_input_length, - dim_embed, - elem_nums); - break; - default: - PD_THROW( - "Unsupported data type for rebuild_padding_cpu. 
" - "Only float32, float16, and bfloat16 are supported."); - } + if (output_padding_offset_cpu) { + const int *output_padding_offset_data = + output_padding_offset_cpu->data(); + switch (tmp_out_cpu.dtype()) { + case paddle::DataType::FLOAT32: + RebuildAppendPaddingCPUImpl(out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::FLOAT16: + RebuildAppendPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::BFLOAT16: + RebuildAppendPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + output_padding_offset_data, + max_input_length, + dim_embed, + elem_nums); + break; + default: + PD_THROW( + "Unsupported data type for rebuild_padding_cpu. 
" + "Only float32, float16, and bfloat16 are supported."); } - return {out}; + } else { + switch (tmp_out_cpu.dtype()) { + case paddle::DataType::FLOAT32: + RebuildPaddingCPUImpl(out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::FLOAT16: + RebuildPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + case paddle::DataType::BFLOAT16: + RebuildPaddingCPUImpl( + out.data(), + tmp_out_cpu.data(), + cu_seqlens_q_data, + seq_len_this_time_data, + seq_lens_decoder_data, + seq_lens_encoder_data, + max_input_length, + dim_embed, + elem_nums); + break; + default: + PD_THROW( + "Unsupported data type for rebuild_padding_cpu. " + "Only float32, float16, and bfloat16 are supported."); + } + } + return {out}; } std::vector> RebuildPaddingInferShape( const std::vector &tmp_out_shape, - const std::vector &cum_offsets_shape, + const std::vector &cu_seqlens_q_shape, const std::vector &seq_len_this_time_shape, const std::vector &seq_lens_decoder_shape, const std::vector &seq_lens_encoder_shape, const paddle::optional> &output_padding_offset_shape) { - int64_t dim_embed = tmp_out_shape[1]; - if (output_padding_offset_shape) { - return {{-1, dim_embed}}; - } else { - int64_t bsz = cum_offsets_shape[0]; - return {{bsz, dim_embed}}; - } + int64_t dim_embed = tmp_out_shape[1]; + if (output_padding_offset_shape) { + return {{-1, dim_embed}}; + } else { + int64_t bsz = cu_seqlens_q_shape[0] - 1; + return {{bsz, dim_embed}}; + } } std::vector RebuildPaddingInferDtype( const paddle::DataType &tmp_out_dtype, - const paddle::DataType &cum_offsets_dtype, + const paddle::DataType &cu_seqlens_q_dtype, const paddle::DataType &seq_len_this_time_dtype, const paddle::DataType 
&seq_lens_decoder_dtype, const paddle::DataType &seq_lens_encoder_dtype, const paddle::optional &output_padding_offset_dtype) { - return {tmp_out_dtype}; + return {tmp_out_dtype}; } PD_BUILD_STATIC_OP(rebuild_padding_cpu) .Inputs({"tmp_out", - "cum_offsets", + "cu_seqlens_q", "seq_len_this_time", "seq_lens_decoder", "seq_lens_encoder", diff --git a/custom_ops/cpu_ops/set_value_by_flags.cc b/custom_ops/cpu_ops/set_value_by_flags.cc index 9f9a2b4163e..1266afa1e4c 100644 --- a/custom_ops/cpu_ops/set_value_by_flags.cc +++ b/custom_ops/cpu_ops/set_value_by_flags.cc @@ -14,28 +14,28 @@ #include "paddle/extension.h" -void set_value_by_flag_and_id(const bool *stop_flags, - int64_t *pre_ids_all, - const int64_t *input_ids, - const int *seq_lens_encoder, - const int *seq_lens_decoder, - const int64_t *step_idx, - int bs, - int length, - int length_input_ids) { - for (int bi = 0; bi < bs; bi++) { - if (!stop_flags[bi]) { - const int seq_len_dec = seq_lens_decoder[bi]; - const int seq_len_enc = seq_lens_encoder[bi]; - int64_t *pre_ids_all_now = pre_ids_all + bi * length; - const int64_t *input_ids_now = input_ids + bi * length_input_ids; - if (seq_len_dec == 0) { - pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1]; - } else { - pre_ids_all_now[step_idx[bi]] = input_ids_now[0]; - } - } +void set_value_by_flags_and_idx(const bool *stop_flags, + int64_t *pre_ids_all, + const int64_t *input_ids, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int length_input_ids) { + for (int bi = 0; bi < bs; bi++) { + if (!stop_flags[bi]) { + const int seq_len_dec = seq_lens_decoder[bi]; + const int seq_len_enc = seq_lens_encoder[bi]; + int64_t *pre_ids_all_now = pre_ids_all + bi * length; + const int64_t *input_ids_now = input_ids + bi * length_input_ids; + if (seq_len_dec == 0) { + pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1]; + } else { + pre_ids_all_now[step_idx[bi]] = input_ids_now[0]; + } } + 
} } void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, @@ -45,12 +45,12 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags) { - std::vector pre_ids_all_shape = pre_ids_all.shape(); - int bs = seq_lens_this_time.shape()[0]; - int length = pre_ids_all_shape[1]; - int length_input_ids = input_ids.shape()[1]; + std::vector pre_ids_all_shape = pre_ids_all.shape(); + int bs = seq_lens_this_time.shape()[0]; + int length = pre_ids_all_shape[1]; + int length_input_ids = input_ids.shape()[1]; - set_value_by_flag_and_id(stop_flags.data(), + set_value_by_flags_and_idx(stop_flags.data(), const_cast(pre_ids_all.data()), input_ids.data(), seq_lens_encoder.data(), diff --git a/custom_ops/cpu_ops/simd_sort.cc b/custom_ops/cpu_ops/simd_sort.cc index 581ee406957..857875a41f5 100644 --- a/custom_ops/cpu_ops/simd_sort.cc +++ b/custom_ops/cpu_ops/simd_sort.cc @@ -21,45 +21,45 @@ void probs_sort(const float *probs, float *ProbsVals, int vocab_size, int bsz) { - float cursum = 0; - std::vector elementsIds(vocab_size); - std::vector elementsProbs(vocab_size); + float cursum = 0; + std::vector elementsIds(vocab_size); + std::vector elementsProbs(vocab_size); #pragma omp parallel for - for (int j = 0; j < vocab_size; j++) { - elementsIds[j] = j; - elementsProbs[j] = probs[j]; - } - x86simdsortStatic::keyvalue_qsort( - elementsProbs.data(), elementsIds.data(), vocab_size, false, true); + for (int j = 0; j < vocab_size; j++) { + elementsIds[j] = j; + elementsProbs[j] = probs[j]; + } + x86simdsortStatic::keyvalue_qsort( + elementsProbs.data(), elementsIds.data(), vocab_size, false, true); #pragma omp parallel for - for (int j = 0; j < vocab_size; ++j) { - ProbsVals[j] = elementsProbs[j]; - ProbsIds[j] = elementsIds[j]; - } + for (int j = 0; j < vocab_size; ++j) { + ProbsVals[j] = elementsProbs[j]; + ProbsIds[j] = elementsIds[j]; + } } std::vector SimdSort(const 
paddle::Tensor &probs) { - const int bsz = probs.shape()[0]; - const int vocab_size = probs.shape()[1]; - auto sorted_indices = paddle::empty( - {bsz, vocab_size}, paddle::DataType::INT64, probs.place()); - auto sorted_probs = paddle::empty( - {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); - probs_sort(probs.data(), - const_cast(sorted_indices.data()), - const_cast(sorted_probs.data()), - vocab_size, - bsz); - return {sorted_indices, sorted_probs}; + const int bsz = probs.shape()[0]; + const int vocab_size = probs.shape()[1]; + auto sorted_indices = + paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place()); + auto sorted_probs = paddle::empty( + {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); + probs_sort(probs.data(), + const_cast(sorted_indices.data()), + const_cast(sorted_probs.data()), + vocab_size, + bsz); + return {sorted_indices, sorted_probs}; } std::vector> SimdSortInferShape( const std::vector &probs_shape) { - int64_t bsz = probs_shape[0]; - int64_t vocab_size = probs_shape[1]; - return {{bsz, vocab_size}, {bsz, vocab_size}}; + int64_t bsz = probs_shape[0]; + int64_t vocab_size = probs_shape[1]; + return {{bsz, vocab_size}, {bsz, vocab_size}}; } std::vector SimdSortInferDtype( const paddle::DataType &probs_dtype) { - return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; + return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; } PD_BUILD_STATIC_OP(simd_sort) .Inputs({"probs"}) diff --git a/custom_ops/cpu_ops/simd_sort_fake.cc b/custom_ops/cpu_ops/simd_sort_fake.cc index 82cb1af1ccf..514ff1fa9f6 100644 --- a/custom_ops/cpu_ops/simd_sort_fake.cc +++ b/custom_ops/cpu_ops/simd_sort_fake.cc @@ -16,23 +16,23 @@ #include "paddle/extension.h" std::vector SimdSort(const paddle::Tensor &probs) { - const int bsz = probs.shape()[0]; - const int vocab_size = probs.shape()[1]; - auto sorted_indices = paddle::empty( - {bsz, vocab_size}, paddle::DataType::INT64, probs.place()); - auto sorted_probs = 
paddle::empty( - {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); - return {sorted_indices, sorted_probs}; + const int bsz = probs.shape()[0]; + const int vocab_size = probs.shape()[1]; + auto sorted_indices = + paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place()); + auto sorted_probs = paddle::empty( + {bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place()); + return {sorted_indices, sorted_probs}; } std::vector> SimdSortInferShape( const std::vector &probs_shape) { - int64_t bsz = probs_shape[0]; - int64_t vocab_size = probs_shape[1]; - return {{bsz, vocab_size}, {bsz, vocab_size}}; + int64_t bsz = probs_shape[0]; + int64_t vocab_size = probs_shape[1]; + return {{bsz, vocab_size}, {bsz, vocab_size}}; } std::vector SimdSortInferDtype( const paddle::DataType &probs_dtype) { - return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; + return {paddle::DataType::INT64, paddle::DataType::FLOAT32}; } PD_BUILD_STATIC_OP(simd_sort) .Inputs({"probs"}) diff --git a/custom_ops/cpu_ops/stop_generation_multi_ends.cc b/custom_ops/cpu_ops/stop_generation_multi_ends.cc index 7669cfa51d0..cd4c9323a81 100644 --- a/custom_ops/cpu_ops/stop_generation_multi_ends.cc +++ b/custom_ops/cpu_ops/stop_generation_multi_ends.cc @@ -18,14 +18,18 @@ #include #include "paddle/extension.h" +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + bool is_in_end(const int64_t id, const int64_t *end_ids, int length) { - bool flag = false; - for (int i = 0; i < length; i++) { - if (id == end_ids[i]) { - return true; - } + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; } - return flag; + } + return flag; } void set_value_by_flags(bool *stop_flags, @@ -36,21 +40,23 @@ void set_value_by_flags(bool *stop_flags, const int bs, const int end_length, bool beam_search) { - for (int bi = 0; bi < bs; bi++) { - if (stop_flags[bi]) { - if ((seq_lens[bi] == 0)) { - topk_ids[bi] = 
-1; - } else { - topk_ids[bi] = end_ids[0]; - next_tokens[bi] = end_ids[0]; - } - } else { - next_tokens[bi] = topk_ids[bi]; - } - if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) { - stop_flags[bi] = true; - } + for (int bi = 0; bi < bs; bi++) { + if (stop_flags[bi]) { + if ((seq_lens[bi] == 0)) { + topk_ids[bi] = -1; + } else { + topk_ids[bi] = end_ids[0]; + next_tokens[bi] = end_ids[0]; + } + } else { + next_tokens[bi] = topk_ids[bi]; + } + if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) { + stop_flags[bi] = true; + topk_ids[bi] = end_ids[0]; + next_tokens[bi] = end_ids[0]; } + } } void GetStopFlagsMulti(const paddle::Tensor &topk_ids, @@ -59,17 +65,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &end_ids, const paddle::Tensor &next_tokens, const bool beam_search) { - std::vector shape = topk_ids.shape(); - int64_t bs_now = shape[0]; - int64_t end_length = end_ids.shape()[0]; - set_value_by_flags(const_cast(stop_flags.data()), - const_cast(topk_ids.data()), - const_cast(next_tokens.data()), - end_ids.data(), - seq_lens.data(), - bs_now, - end_length, - false); + std::vector shape = topk_ids.shape(); + int64_t bs_now = shape[0]; + int64_t end_length = end_ids.shape()[0]; + set_value_by_flags(const_cast(stop_flags.data()), + const_cast(topk_ids.data()), + const_cast(next_tokens.data()), + end_ids.data(), + seq_lens.data(), + bs_now, + end_length, + false); } PD_BUILD_STATIC_OP(set_stop_value_multi_ends_cpu) diff --git a/custom_ops/cpu_ops/token_penalty_multi_scores.cc b/custom_ops/cpu_ops/token_penalty_multi_scores.cc index fdcd56eb6da..81b0bed1986 100644 --- a/custom_ops/cpu_ops/token_penalty_multi_scores.cc +++ b/custom_ops/cpu_ops/token_penalty_multi_scores.cc @@ -23,16 +23,16 @@ void min_length_logits_process(float *logits, const int64_t bs, const int64_t length, const int64_t end_length) { - for (int bi = 0; bi < bs; ++bi) { - if (cur_len[bi] < 0) { - continue; - } - if (cur_len[bi] < 
min_len[bi]) { - for (int i = 0; i < end_length; ++i) { - logits[bi * length + eos_token_id[i]] = -1e10; - } - } + for (int bi = 0; bi < bs; ++bi) { + if (cur_len[bi] < 0) { + continue; } + if (cur_len[bi] < min_len[bi]) { + for (int i = 0; i < end_length; ++i) { + logits[bi * length + eos_token_id[i]] = -1e10; + } + } + } } void update_repeat_times(const int64_t *pre_ids, @@ -41,20 +41,20 @@ void update_repeat_times(const int64_t *pre_ids, const int64_t bs, const int64_t length, const int64_t length_id) { - for (int bi = 0; bi < bs; ++bi) { - if (cur_len[bi] < 0) { - continue; - } - const int64_t *pre_ids_now = pre_ids + bi * length_id; - int *repeat_times_now = repeat_times + bi * length; - for (int i = 0; i < length_id; i++) { - int64_t id = pre_ids_now[i]; - if (id < 0) { - break; - } - repeat_times_now[id] += 1; - } + for (int bi = 0; bi < bs; ++bi) { + if (cur_len[bi] < 0) { + continue; + } + const int64_t *pre_ids_now = pre_ids + bi * length_id; + int *repeat_times_now = repeat_times + bi * length; + for (int i = 0; i < length_id; i++) { + int64_t id = pre_ids_now[i]; + if (id < 0) { + break; + } + repeat_times_now[id] += 1; } + } } void update_value_by_repeat_times(const int *repeat_times, @@ -65,24 +65,22 @@ void update_value_by_repeat_times(const int *repeat_times, float *logits, const int64_t bs, const int64_t length) { - for (int bi = 0; bi < bs; ++bi) { - float *logits_now = logits + bi * length; - const int *repeat_times_now = repeat_times + bi * length; - float alpha = static_cast(penalty_scores[bi]); - float beta = static_cast(frequency_score[bi]); - float gamma = static_cast(presence_score[bi]); - for (int i = 0; i < length; ++i) { - int times = repeat_times_now[i]; - float logit_now = static_cast(logits_now[i]); - if (times == 0) { - logits_now[i] = - static_cast(logit_now / temperatures[bi]); - } - logit_now = logit_now < 0 ? 
logit_now * alpha : logit_now / alpha; - logits_now[i] = - static_cast(logit_now - times * beta - gamma); - } + for (int bi = 0; bi < bs; ++bi) { + float *logits_now = logits + bi * length; + const int *repeat_times_now = repeat_times + bi * length; + float alpha = static_cast(penalty_scores[bi]); + float beta = static_cast(frequency_score[bi]); + float gamma = static_cast(presence_score[bi]); + for (int i = 0; i < length; ++i) { + int times = repeat_times_now[i]; + float logit_now = static_cast(logits_now[i]); + if (times == 0) { + logits_now[i] = static_cast(logit_now / temperatures[bi]); + } + logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha; + logits_now[i] = static_cast(logit_now - times * beta - gamma); } + } } void ban_bad_words(float *logits, @@ -90,15 +88,14 @@ void ban_bad_words(float *logits, const int64_t bs, const int64_t length, const int64_t bad_words_length) { - for (int bi = 0; bi < bs; ++bi) { - float *logits_now = logits + bi * length; - for (int bwid = 0; bwid < bad_words_length; ++bwid) { - const int64_t bad_words_token_id = bad_words_list[bwid]; - if (bad_words_token_id >= length || bad_words_token_id < 0) - continue; - logits_now[bad_words_token_id] = -1e10; - } + for (int bi = 0; bi < bs; ++bi) { + float *logits_now = logits + bi * length; + for (int bwid = 0; bwid < bad_words_length; ++bwid) { + const int64_t bad_words_token_id = bad_words_list[bwid]; + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + logits_now[bad_words_token_id] = -1e10; } + } } template @@ -112,44 +109,44 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids, const paddle::Tensor &cur_len, const paddle::Tensor &min_len, const paddle::Tensor &eos_token_id) { - std::vector shape = logits.shape(); - auto repeat_times = - paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); - int64_t bs = shape[0]; - int64_t length = shape[1]; - int64_t length_id = pre_ids.shape()[1]; - int64_t end_length = 
eos_token_id.shape()[0]; - int64_t length_bad_words = bad_tokens.shape()[0]; + std::vector shape = logits.shape(); + auto repeat_times = + paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); + int64_t bs = shape[0]; + int64_t length = shape[1]; + int64_t length_id = pre_ids.shape()[1]; + int64_t end_length = eos_token_id.shape()[0]; + int64_t length_bad_words = bad_tokens.shape()[0]; - min_length_logits_process(const_cast(logits.data()), - cur_len.data(), - min_len.data(), - eos_token_id.data(), - bs, - length, - end_length); + min_length_logits_process(const_cast(logits.data()), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bs, + length, + end_length); - update_repeat_times(pre_ids.data(), - cur_len.data(), - repeat_times.data(), - bs, - length, - length_id); + update_repeat_times(pre_ids.data(), + cur_len.data(), + repeat_times.data(), + bs, + length, + length_id); - update_value_by_repeat_times(repeat_times.data(), - penalty_scores.data(), - frequency_score.data(), - presence_score.data(), - temperatures.data(), - const_cast(logits.data()), - bs, - length); + update_value_by_repeat_times(repeat_times.data(), + penalty_scores.data(), + frequency_score.data(), + presence_score.data(), + temperatures.data(), + const_cast(logits.data()), + bs, + length); - ban_bad_words(const_cast(logits.data()), - bad_tokens.data(), - bs, - length, - length_bad_words); + ban_bad_words(const_cast(logits.data()), + bad_tokens.data(), + bs, + length, + length_bad_words); } void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, @@ -162,17 +159,17 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids, const paddle::Tensor &cur_len, const paddle::Tensor &min_len, const paddle::Tensor &eos_token_id) { - return token_penalty_multi_scores_kernel( - pre_ids, - logits, - penalty_scores, - frequency_scores, - presence_scores, - temperatures, - bad_tokens, - cur_len, - min_len, - eos_token_id); + return token_penalty_multi_scores_kernel( + pre_ids, + 
logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id); } PD_BUILD_STATIC_OP(get_token_penalty_multi_scores_cpu) diff --git a/custom_ops/cpu_ops/update_inputs.cc b/custom_ops/cpu_ops/update_inputs.cc index 4d1ac1bcda8..c2b74800282 100644 --- a/custom_ops/cpu_ops/update_inputs.cc +++ b/custom_ops/cpu_ops/update_inputs.cc @@ -24,50 +24,50 @@ void update_inputs_kernel(bool *not_need_stop, const int64_t *next_tokens, const int bsz, const int input_ids_stride) { - int64_t stop_sum = 0; - for (int bi = 0; bi < bsz; ++bi) { - bool stop_flag_now = false; - int64_t stop_flag_now_int = 0; - stop_flag_now = stop_flags[bi]; - stop_flag_now_int = static_cast(stop_flag_now); - auto seq_len_this_time = seq_lens_this_time[bi]; - auto seq_len_encoder = seq_lens_encoder[bi]; - auto seq_len_decoder = seq_lens_decoder[bi]; - seq_lens_decoder[bi] = - stop_flag_now ? 0 - : (seq_len_decoder == 0 ? seq_len_encoder - : seq_len_decoder + 1); - seq_lens_this_time[bi] = stop_flag_now ? 0 : 1; - seq_lens_encoder[bi] = 0; - int64_t *input_ids_now = input_ids + bi * input_ids_stride; - input_ids_now[0] = next_tokens[bi]; - stop_sum += stop_flag_now_int; - } - not_need_stop[0] = stop_sum < stop_nums[0]; + int64_t stop_sum = 0; + for (int bi = 0; bi < bsz; ++bi) { + bool stop_flag_now = false; + int64_t stop_flag_now_int = 0; + stop_flag_now = stop_flags[bi]; + stop_flag_now_int = static_cast(stop_flag_now); + auto seq_len_this_time = seq_lens_this_time[bi]; + auto seq_len_encoder = seq_lens_encoder[bi]; + auto seq_len_decoder = seq_lens_decoder[bi]; + seq_lens_decoder[bi] = + stop_flag_now + ? 0 + : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1); + seq_lens_this_time[bi] = stop_flag_now ? 
0 : 1; + seq_lens_encoder[bi] = 0; + int64_t *input_ids_now = input_ids + bi * input_ids_stride; + input_ids_now[0] = next_tokens[bi]; + stop_sum += stop_flag_now_int; + } + not_need_stop[0] = stop_sum < stop_nums[0]; } -void UpdateInputes(const paddle::Tensor &stop_flags, - const paddle::Tensor ¬_need_stop, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &input_ids, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const paddle::Tensor &is_block_step) { - const int bsz = input_ids.shape()[0]; - const int input_ids_stride = input_ids.shape()[1]; - update_inputs_kernel(const_cast(not_need_stop.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(input_ids.data()), - stop_nums.data(), - stop_flags.data(), - is_block_step.data(), - next_tokens.data(), - bsz, - input_ids_stride); +void UpdateInputs(const paddle::Tensor &stop_flags, + const paddle::Tensor ¬_need_stop, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &input_ids, + const paddle::Tensor &stop_nums, + const paddle::Tensor &next_tokens, + const paddle::Tensor &is_block_step) { + const int bsz = input_ids.shape()[0]; + const int input_ids_stride = input_ids.shape()[1]; + update_inputs_kernel(const_cast(not_need_stop.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(input_ids.data()), + stop_nums.data(), + stop_flags.data(), + is_block_step.data(), + next_tokens.data(), + bsz, + input_ids_stride); } PD_BUILD_STATIC_OP(update_inputs_cpu) @@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu) {"seq_lens_encoder", "seq_lens_encoder_out"}, {"seq_lens_decoder", "seq_lens_decoder_out"}, {"input_ids", 
"input_ids_out"}}) - .SetKernelFn(PD_KERNEL(UpdateInputes)); + .SetKernelFn(PD_KERNEL(UpdateInputs)); diff --git a/custom_ops/cpu_ops/xft_all_layer_fake.cc b/custom_ops/cpu_ops/xft_all_layer_fake.cc index aeb20004e6c..ab64d80c8cf 100644 --- a/custom_ops/cpu_ops/xft_all_layer_fake.cc +++ b/custom_ops/cpu_ops/xft_all_layer_fake.cc @@ -45,18 +45,18 @@ std::vector InvokeAllLLaMALayer( int maxPositions, int maxPosEmbed, int intermediateSize) { - auto out = paddle::empty_like(input); - return {out}; + auto out = paddle::empty_like(input); + return {out}; } std::vector> AllLLaMALayerInferShape( std::vector x_shape) { - return {x_shape}; + return {x_shape}; } std::vector AllLLaMALayerInferDtype( paddle::DataType x_dtype) { - return {x_dtype}; + return {x_dtype}; } PD_BUILD_STATIC_OP(xft_llama_all_layer) diff --git a/custom_ops/cpu_ops/xft_greedy_search_fake.cc b/custom_ops/cpu_ops/xft_greedy_search_fake.cc index ecf57a2ab4b..060e7da82b0 100644 --- a/custom_ops/cpu_ops/xft_greedy_search_fake.cc +++ b/custom_ops/cpu_ops/xft_greedy_search_fake.cc @@ -16,20 +16,20 @@ #include "paddle/extension.h" std::vector XftGreedySearch(const paddle::Tensor &probs) { - const int bsz = probs.shape()[0]; - const int vocab_size = probs.shape()[1]; - auto next_tokens = - paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place()); - return {next_tokens}; + const int bsz = probs.shape()[0]; + const int vocab_size = probs.shape()[1]; + auto next_tokens = + paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place()); + return {next_tokens}; } std::vector> XftGreedySearchInferShape( const std::vector &probs_shape) { - int64_t bsz = probs_shape[0]; - return {{bsz, 1}}; + int64_t bsz = probs_shape[0]; + return {{bsz, 1}}; } std::vector XftGreedySearchInferDtype( const paddle::DataType &probs_dtype) { - return {paddle::DataType::INT64}; + return {paddle::DataType::INT64}; } PD_BUILD_STATIC_OP(xft_greedy_search) .Inputs({"probs"}) diff --git a/custom_ops/gpu_ops/append_attention.cu 
b/custom_ops/gpu_ops/append_attention.cu index 2ba7555e7f3..c1586945cc5 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -14,8 +14,8 @@ #include "append_attn/append_attention_kernel.h" #include "append_attn/decoder_write_cache_with_rope_kernel.h" -#include "append_attn/speculate_write_cache_with_rope_kernel.h" #include "append_attn/encoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -26,19 +26,18 @@ class type2value; template <> class type2value { - public: - static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; }; template <> class type2value { - public: - static constexpr paddle::DataType value = paddle::DataType::FLOAT16; + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; }; - template -std::vector AppendAttentionKernel( +void AppendAttentionKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& key_cache, @@ -59,7 +58,7 @@ std::vector AppendAttentionKernel( const paddle::Tensor& decoder_tile_ids_per_batch, const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& set_max_lengths, - const paddle::Tensor& max_len_kv, + paddle::Tensor& fmha_out, const paddle::optional& rotary_embs, const paddle::optional& attn_mask, const paddle::optional& qkv_bias, @@ -73,6 +72,10 @@ std::vector AppendAttentionKernel( const paddle::optional& out_linear_shifts, const paddle::optional& out_linear_smooths, const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& sinks, + const float rms_norm_eps, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, const bool rope_3d, @@ -86,24 +89,26 @@ std::vector 
AppendAttentionKernel( const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, - const bool speculate_decoder) { + const bool speculate_decoder, + const int sliding_window, + const int sink_size = 0) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; - // set_max_lengths: max_len_this_time, max_enc_len_this_time, max_dec_len_this_time, max_enc_dec_len_this_time, - // max_just_dec_len_this_time, max_just_dec_merged_len_this_time, max_system_len, max_just_dec_len_without_system - int max_len_this_time = set_max_lengths.data()[0]; - int max_enc_len_this_time =set_max_lengths.data()[1]; - int max_dec_len_this_time = set_max_lengths.data()[2]; - int max_enc_dec_len_this_time = set_max_lengths.data()[3]; - int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_len_this_time = set_max_lengths.data()[0]; + const int max_enc_len_this_time = set_max_lengths.data()[1]; + const int max_dec_len_this_time = set_max_lengths.data()[2]; + const int max_enc_dec_len_this_time = set_max_lengths.data()[3]; + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_kv_len_this_time = set_max_lengths.data()[5]; auto main_stream = qkv.stream(); static cudaEvent_t main_event; static cudaEvent_t decoder_event; static cudaStream_t decoder_stream; static bool init_flag = false; + bool enforce_fmul_rn = getEnvEnableRL(); if (max_just_dec_len_this_time > 0 && max_enc_len_this_time > 0 && !init_flag) { cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming); @@ -118,74 +123,58 @@ std::vector AppendAttentionKernel( } else { qkv_out = qkv; } - paddle::Tensor fmha_out; - if (out_linear_in_scale > 0.0) { - if (fabs(quant_max_bound - 127.0f) < 0.000001) { - fmha_out = GetEmptyTensor( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - paddle::DataType::INT8, - qkv.place()); - } else if (fabs(quant_max_bound - 448.0f) < 
0.000001) { - fmha_out = GetEmptyTensor( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - paddle::DataType::FLOAT8_E4M3FN, - qkv.place()); - }else{ - PD_THROW("Only supported attr of quant_max_bound in ['127', '448']."); - } - } else { - fmha_out = GetEmptyTensor( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, - D, - qkv.place()); - } - auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args, - const paddle::Tensor& lambda_batch_ids, - const paddle::Tensor& lambda_tile_ids_per_batch, - const int lambda_num_blocks_data, - const int lambda_block_shape_q, - const int lambda_max_dec_len, - const bool lambda_is_decoder, - const bool lambda_enable_prefill, - cudaStream_t& lambda_stream - ) -> void { - CascadeAppendAttentionKernel( - meta_data, - qkv_out, - key_cache, - value_cache, - attn_mask, - cache_k_dequant_scales, - cache_v_dequant_scales, - cache_k_zp, - cache_v_zp, - out_linear_shifts, - out_linear_smooths, - seq_lens_this_time, - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - lambda_batch_ids, - lambda_tile_ids_per_batch, - cache_quant_type_str, - lambda_num_blocks_data, - lambda_block_shape_q, - max_input_length, - lambda_max_dec_len, - quant_max_bound, - quant_min_bound, - out_linear_in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - lambda_is_decoder, - lambda_enable_prefill, - lambda_stream, - &fmha_out); + auto dispatch_CascadeAppendAttentionKernel = + [&](auto temp_args, + const paddle::Tensor& lambda_batch_ids, + const paddle::Tensor& lambda_tile_ids_per_batch, + const int lambda_num_blocks_data, + const int lambda_block_shape_q, + const int lambda_max_dec_len, + const bool lambda_is_decoder, + const bool lambda_enable_prefill, + cudaStream_t& lambda_stream) -> void { + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_quant_type_str == 
"block_wise_fp8" ? cache_k_quant_scales + : cache_k_dequant_scales, + cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales + : cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + lambda_batch_ids, + lambda_tile_ids_per_batch, + cache_quant_type_str, + lambda_num_blocks_data, + lambda_block_shape_q, + max_input_length, + lambda_max_dec_len, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + lambda_is_decoder, + lambda_enable_prefill, + lambda_stream, + &fmha_out, + sliding_window, + sink_size); }; if (max_enc_len_this_time > 0) { @@ -195,35 +184,43 @@ std::vector AppendAttentionKernel( int encoder_num_blocks_data = encoder_num_blocks.data()[0]; int kv_num_blocks_data = kv_num_blocks.data()[0]; - auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void { - EncoderWriteCacheWithRopeKernel( - meta_data, - qkv, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - kv_batch_ids, - kv_tile_ids_per_batch, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - kv_signal_data, - cache_quant_type_str, - kv_num_blocks_data, - max_input_length, - use_neox_rotary_style, - rope_3d, - main_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache)); + auto dispatch_EncoderWriteCacheWithRopeKernel = + [&](auto temp_args) -> void { + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { + EncoderWriteCacheWithRopeKernel( + meta_data, + qkv, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + rotary_embs, + qkv_out_scales, + qkv_bias, 
+ cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + kv_signal_data, + cache_quant_type_str, + kv_num_blocks_data, + max_input_length, + use_neox_rotary_style, + rope_3d, + main_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + }) }; if (qkv_out_scales) { @@ -235,30 +232,55 @@ std::vector AppendAttentionKernel( } if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { - case paddle::DataType::INT8:{ + case paddle::DataType::INT8: { int8_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream); + dispatch_CascadeAppendAttentionKernel(tmp, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_data, + encoder_block_shape_q, + max_enc_dec_len_this_time, + false, + true, + main_stream); break; } - case paddle::DataType::FLOAT8_E4M3FN:{ + case paddle::DataType::FLOAT8_E4M3FN: { phi::dtype::float8_e4m3fn tmp; - dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream); + dispatch_CascadeAppendAttentionKernel(tmp, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_data, + encoder_block_shape_q, + max_enc_dec_len_this_time, + false, + true, + main_stream); break; } - default:{ - PD_THROW("Only supported output fmha_out of quant dtype in ['int8', 'FLOAT8_E4M3FN']."); + default: { + PD_THROW( + "Only supported output fmha_out of quant dtype in ['int8', " + "'FLOAT8_E4M3FN']."); break; } } } else { data_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_data, encoder_block_shape_q, max_enc_dec_len_this_time, false, true, main_stream); + dispatch_CascadeAppendAttentionKernel(tmp, + 
encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_data, + encoder_block_shape_q, + max_enc_dec_len_this_time, + false, + true, + main_stream); } } if (max_just_dec_len_this_time > 0) { int decoder_num_blocks_data = decoder_num_blocks.data()[0]; - int max_len_kv_data = max_len_kv.data()[0]; cudaStream_t exec_stream; if (max_enc_len_this_time > 0) { @@ -267,133 +289,166 @@ std::vector AppendAttentionKernel( } else { exec_stream = main_stream; } - if (speculate_decoder) { - if (qkv_out_scales) { - SpeculateWriteCacheWithRoPEKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache)); - } else { - SpeculateWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache)); - } - } else { - if (qkv_out_scales) { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache)); + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, 
EnforceFmulRN, { + if (speculate_decoder) { + if (qkv_out_scales) { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } else { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } } else { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache)); + if (qkv_out_scales) { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + 
exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } else { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } } - } + }) if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { - case paddle::DataType::INT8:{ - int8_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream); + case paddle::DataType::INT8: { + int8_t tmp; + dispatch_CascadeAppendAttentionKernel(tmp, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_data, + decoder_block_shape_q, + max_kv_len_this_time, + !speculate_decoder, + !speculate_decoder, + exec_stream); break; } - case paddle::DataType::FLOAT8_E4M3FN:{ - phi::dtype::float8_e4m3fn tmp; - dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data, - decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream); + case paddle::DataType::FLOAT8_E4M3FN: { + phi::dtype::float8_e4m3fn tmp; + dispatch_CascadeAppendAttentionKernel(tmp, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_data, + decoder_block_shape_q, + max_kv_len_this_time, + !speculate_decoder, + !speculate_decoder, + exec_stream); break; } } } else { - data_t tmp; - dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, 
decoder_num_blocks_data, - decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream); + data_t tmp; + dispatch_CascadeAppendAttentionKernel(tmp, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_data, + decoder_block_shape_q, + max_kv_len_this_time, + !speculate_decoder, + !speculate_decoder, + exec_stream); } if (max_enc_len_this_time > 0) { cudaEventRecord(decoder_event, exec_stream); cudaStreamWaitEvent(main_stream, decoder_event); } } - - return {fmha_out, qkv_out}; } std::vector AppendAttention( @@ -416,7 +471,6 @@ std::vector AppendAttention( const paddle::Tensor& decoder_tile_ids_per_batch, const paddle::Tensor& decoder_num_blocks, const paddle::Tensor& set_max_lengths, - const paddle::Tensor& max_len_kv, const paddle::optional& rotary_embs, const paddle::optional& attn_mask, const paddle::optional& qkv_bias, @@ -429,7 +483,224 @@ std::vector AppendAttention( const paddle::optional& cache_v_zp, const paddle::optional& out_linear_shifts, const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& sinks, + const float rms_norm_eps, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder, + const int sliding_window, + const int sink_size = 0) { + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_nums = qkv_dims[0]; 
+ meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type_str == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + // template dtype generation + phi::DataType dtype_id; + switch (qkv.dtype()) { + case paddle::DataType::FLOAT16: { + dtype_id = phi::DataType::FLOAT16; + break; + } + case paddle::DataType::BFLOAT16: { + dtype_id = phi::DataType::BFLOAT16; + break; + } + case paddle::DataType::INT32: { + if (compute_dtype == "bf16") { + dtype_id = phi::DataType::BFLOAT16; + break; + } else if (compute_dtype == "fp16") { + dtype_id = phi::DataType::FLOAT16; + break; + } else { + PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); + break; + } + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16 and bfloat16 are supported. 
"); + break; + } + } + + // fmha_out generation, rewrite from AppendAttentionKernel + paddle::Tensor fmha_out; + if (out_linear_in_scale > 0.0) { + if (fabs(quant_max_bound - 127.0f) < 0.000001) { + fmha_out = paddle::zeros( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + paddle::DataType::INT8, + qkv.place()); + } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { + fmha_out = paddle::zeros( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + paddle::DataType::FLOAT8_E4M3FN, + qkv.place()); + } else { + PD_THROW("Only supported attr of quant_max_bound in ['127', '448']."); + } + } else { + fmha_out = paddle::zeros( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + dtype_id, + qkv.place()); + } + + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + + auto dispatch_by_template = [&](auto temp_args) -> void { + AppendAttentionKernel::value>( + meta_data, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + set_max_lengths, + fmha_out, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + kv_signal_data, + q_norm_weight, + k_norm_weight, + sinks, + rms_norm_eps, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + speculate_decoder, + sliding_window, + sink_size); + }; 
+ + phi::dtype::float16 fp16_dtype; + phi::dtype::bfloat16 bp16_dtype; + switch (dtype_id) { + case phi::DataType::FLOAT16: { + dispatch_by_template(fp16_dtype); + return {fmha_out}; + } + case phi::DataType::BFLOAT16: { + dispatch_by_template(bp16_dtype); + return {fmha_out}; + } + default: + PD_THROW( + "NOT supported data type. " + "Only float16 and bfloat16 are supported. "); + break; + } + + return {paddle::Tensor{}}; +} + +std::vector AppendAttentionWithOutput( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& set_max_lengths, + paddle::Tensor& fmha_out, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& 
sinks, + const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, @@ -444,7 +715,9 @@ std::vector AppendAttention( const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, - const bool speculate_decoder) { + const bool speculate_decoder, + const int sliding_window, + const int sink_size = 0) { AppendAttnMetaData meta_data; const auto& qkv_dims = qkv.dims(); @@ -464,70 +737,87 @@ std::vector AppendAttention( meta_data.block_size = key_cache.dims()[2]; meta_data.batch_size = seq_lens_this_time.dims()[0]; - auto dispatch_by_template = [&](auto temp_args) -> std::vector { - return AppendAttentionKernel::value>( - meta_data, - qkv, - key_cache, - value_cache, - seq_lens_encoder, - seq_lens_decoder, - seq_lens_this_time, - batch_id_per_token, - cu_seqlens_q, - block_tables, - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks, - decoder_batch_ids, - decoder_tile_ids_per_batch, - decoder_num_blocks, - set_max_lengths, - max_len_kv, - rotary_embs, - attn_mask, - qkv_bias, - qkv_out_scales, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_dequant_scales, - cache_v_dequant_scales, - cache_k_zp, - cache_v_zp, - out_linear_shifts, - out_linear_smooths, - kv_signal_data, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - quant_max_bound, - quant_min_bound, - out_linear_in_scale, - encoder_block_shape_q, - decoder_block_shape_q, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - speculate_decoder); + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + auto dispatch_by_template = [&](auto temp_args) -> void { + AppendAttentionKernel::value>( + meta_data, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + 
cu_seqlens_q, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks, + set_max_lengths, + fmha_out, + rotary_embs, + attn_mask, + qkv_bias, + qkv_out_scales, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + kv_signal_data, + q_norm_weight, + k_norm_weight, + sinks, + rms_norm_eps, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + encoder_block_shape_q, + decoder_block_shape_q, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + speculate_decoder, + sliding_window, + sink_size); }; phi::dtype::float16 fp16_dtype; phi::dtype::bfloat16 bp16_dtype; switch (qkv.dtype()) { - case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype); - case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype); + case paddle::DataType::FLOAT16: { + dispatch_by_template(fp16_dtype); + break; + } + case paddle::DataType::BFLOAT16: { + dispatch_by_template(bp16_dtype); + break; + } case paddle::DataType::INT32: { if (compute_dtype == "bf16") { - return dispatch_by_template(bp16_dtype); + dispatch_by_template(bp16_dtype); + break; } else if (compute_dtype == "fp16") { - return dispatch_by_template(fp16_dtype); + dispatch_by_template(fp16_dtype); + break; } else { PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); break; @@ -540,7 +830,8 @@ std::vector AppendAttention( break; } } - return {paddle::Tensor{}}; + + return {fmha_out}; } std::vector> AppendAttentionInferShape( @@ -563,7 +854,6 @@ std::vector> AppendAttentionInferShape( const std::vector& decoder_tile_ids_per_batch_shape, const std::vector& decoder_num_blocks_shape, 
const std::vector& set_max_lengths_shape, - const std::vector& max_len_kv_shape, const paddle::optional>& rotary_embs_shape, const paddle::optional>& attn_mask_shape, const paddle::optional>& qkv_bias_shape, @@ -576,7 +866,12 @@ std::vector> AppendAttentionInferShape( const paddle::optional>& cache_v_zp_shape, const paddle::optional>& out_linear_shifts_shape, const paddle::optional>& out_linear_smooths_shape, + const paddle::optional>& mask_offset_shape, const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const paddle::optional>& sinks_shape, + const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, @@ -591,7 +886,9 @@ std::vector> AppendAttentionInferShape( const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, - const bool speculate_decoder) { + const bool speculate_decoder, + const int sliding_window, + const int sink_size) { const int token_num = qkv_shape[0]; const int kv_num_heads = key_cache_shape[1]; int head_dim = key_cache_shape[3]; @@ -600,7 +897,7 @@ std::vector> AppendAttentionInferShape( } const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim; const int num_heads = total_num_head - 2 * kv_num_heads; - return {{token_num, num_heads * head_dim}, qkv_shape}; + return {{token_num, num_heads * head_dim}}; } std::vector AppendAttentionInferDtype( @@ -623,7 +920,6 @@ std::vector AppendAttentionInferDtype( const paddle::DataType& decoder_tile_ids_per_batch_dtype, const paddle::DataType& decoder_num_blocks_dtype, const paddle::DataType& set_max_lengths_dtype, - const paddle::DataType& max_len_kv_dtype, const paddle::optional& rotary_embs_dtype, const paddle::optional& attn_mask_dtype, const paddle::optional& qkv_bias_dtype, @@ -636,7 +932,12 @@ std::vector AppendAttentionInferDtype( const paddle::optional& cache_v_zp_dtype, const 
paddle::optional& out_linear_shifts_dtype, const paddle::optional& out_linear_smooths_dtype, + const paddle::optional& mask_offset_dtype, const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const paddle::optional& sinks_dtype, + const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, @@ -651,36 +952,158 @@ std::vector AppendAttentionInferDtype( const int encoder_max_partition_size, const int speculate_max_draft_token_num, const bool causal, - const bool speculate_decoder) { + const bool speculate_decoder, + const int sliding_window, + const int sink_size) { if (compute_dtype == "bf16") { if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { - return {paddle::DataType::INT8, paddle::DataType::BFLOAT16}; + return {paddle::DataType::INT8}; } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { - return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16}; - }else{ - PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + return {paddle::DataType::FLOAT8_E4M3FN}; + } else { + PD_THROW( + "Only supported attr of quant_max_bound in ['127.0', '448.0']."); } } else { - return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16}; + return {paddle::DataType::BFLOAT16}; } } else if (compute_dtype == "fp16") { if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { - return {paddle::DataType::INT8, paddle::DataType::FLOAT16}; + return {paddle::DataType::INT8}; } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { - return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16}; - }else{ - PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + return {paddle::DataType::FLOAT8_E4M3FN}; + } else { + PD_THROW( + "Only supported attr of quant_max_bound in ['127.0', '448.0']."); } } else { - return 
{paddle::DataType::FLOAT16, paddle::DataType::FLOAT16}; + return {paddle::DataType::FLOAT16}; } } else { PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16']."); } } +std::vector> AppendAttentionWithOutputInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& encoder_batch_ids_shape, + const std::vector& encoder_tile_ids_per_batch_shape, + const std::vector& encoder_num_blocks_shape, + const std::vector& kv_batch_ids_shape, + const std::vector& kv_tile_ids_per_batch_shape, + const std::vector& kv_num_blocks_shape, + const std::vector& decoder_batch_ids_shape, + const std::vector& decoder_tile_ids_per_batch_shape, + const std::vector& decoder_num_blocks_shape, + const std::vector& set_max_lengths_shape, + const std::vector& fmha_out_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& qkv_out_scales_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& out_linear_shifts_shape, + const paddle::optional>& out_linear_smooths_shape, + const paddle::optional>& mask_offset_shape, + const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const paddle::optional>& sinks_shape, + const float rms_norm_eps, + const std::string& 
compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder, + const int sliding_window, + const int sink_size) { + return {fmha_out_shape}; +} + +std::vector AppendAttentionWithOutputInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& encoder_batch_ids_dtype, + const paddle::DataType& encoder_tile_ids_per_batch_dtype, + const paddle::DataType& encoder_num_blocks_dtype, + const paddle::DataType& kv_batch_ids_dtype, + const paddle::DataType& kv_tile_ids_per_batch_dtype, + const paddle::DataType& kv_num_blocks_dtype, + const paddle::DataType& decoder_batch_ids_dtype, + const paddle::DataType& decoder_tile_ids_per_batch_dtype, + const paddle::DataType& decoder_num_blocks_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::DataType& fmha_out_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& qkv_out_scales_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, 
+ const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& out_linear_shifts_dtype, + const paddle::optional& out_linear_smooths_dtype, + const paddle::optional& mask_offset_dtype, + const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const paddle::optional& sinks_dtype, + const float rms_norm_eps, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder, + const int sliding_window, + const int sink_size) { + return {fmha_out_dtype}; +} + PD_BUILD_STATIC_OP(append_attention) .Inputs({"qkv", "key_cache", @@ -701,7 +1124,6 @@ PD_BUILD_STATIC_OP(append_attention) "decoder_tile_ids_per_batch", "decoder_num_blocks", "set_max_lengths", - "max_len_kv", paddle::Optional("rotary_embs"), paddle::Optional("attn_mask"), paddle::Optional("qkv_bias"), @@ -714,25 +1136,96 @@ PD_BUILD_STATIC_OP(append_attention) paddle::Optional("cache_v_zp"), paddle::Optional("out_linear_shifts"), paddle::Optional("out_linear_smooths"), - paddle::Optional("kv_signal_data")}) - .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"}) - .SetInplaceMap({{"key_cache", "key_cache_out"}, - {"value_cache", "value_cache_out"}}) - .Attrs({"compute_type: std::string", - "cache_quant_type: std::string", - "use_neox_rotary_style: bool", - "rope_3d: bool", - "max_input_length: int", - "quant_max_bound: float", - "quant_min_bound: float", - "out_linear_in_scale: float", - "encoder_block_shape_q: int", - 
"decoder_block_shape_q: int", - "max_partition_size: int", - "encoder_max_partition_size: int", - "speculate_max_draft_token_num: int", - "causal: bool", - "speculate_decoder: bool"}) + paddle::Optional("mask_offset"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight"), + paddle::Optional("sinks")}) + .Outputs({"fmha_out"}) + .Attrs({ + "rms_norm_eps: float", + "compute_type: std::string", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "out_linear_in_scale: float", + "encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "max_partition_size: int", + "encoder_max_partition_size: int", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool", + "sliding_window: int", + "sink_size: int", + }) .SetKernelFn(PD_KERNEL(AppendAttention)) .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype)); + +PD_BUILD_STATIC_OP(append_attention_with_output) + .Inputs({"qkv", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks", + "set_max_lengths", + "fmha_out", + paddle::Optional("rotary_embs"), + paddle::Optional("attn_mask"), + paddle::Optional("qkv_bias"), + paddle::Optional("qkv_out_scales"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("out_linear_shifts"), + 
paddle::Optional("out_linear_smooths"), + paddle::Optional("mask_offset"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight"), + paddle::Optional("sinks")}) + .Outputs({"fmha_out_out"}) + .SetInplaceMap({{"fmha_out", "fmha_out_out"}}) + .Attrs({ + "rms_norm_eps: float", + "compute_type: std::string", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "out_linear_in_scale: float", + "encoder_block_shape_q: int", + "decoder_block_shape_q: int", + "max_partition_size: int", + "encoder_max_partition_size: int", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool", + "sliding_window: int", + "sink_size: int", + }) + .SetKernelFn(PD_KERNEL(AppendAttentionWithOutput)) + .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype)); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index b7d8441c685..70329c9366a 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -13,1262 +13,13 @@ // limitations under the License. 
#pragma once -#include "append_attention_func.cuh" -#include "append_attention_kernel.h" - -template -__global__ void multi_query_append_attention_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - T *__restrict__ cache_v, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; - - block_table_now = block_table + batch_id * max_block_num_per_seq; - - const uint32_t q_len = 
seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = - (tile_id * NUM_WARPS + wid) * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if constexpr (partition_kv) { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * 
speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } else { - o_base_ptr_int8 = out + o_offset; - } - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 - load_q_global_smem( - q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); - - smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * - sizeof(T)); - - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? 
(min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : chunk_len) / - (num_frags_z * 16); - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset(tid % 16, tid / 16); - - uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); - - uint32_t kv_idx_base = chunk_start; - int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - const uint32_t const_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - - // s = qk - compute_qk( - &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(q_base_seq_id_this_block, - kv_idx_base, - q_len, - kv_len, - chunk_end, - s_frag); - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += num_frags_z * 16; - block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - if (block_id < 0) { - block_id = 0; - } - cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - 
chunk_end); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); - - __syncthreads(); - cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - if constexpr (!partition_kv) { - normalize_d(o_frag, d_frag); - } - if constexpr (partition_kv) { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? 
q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } - - - if constexpr (partition_kv) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = - (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } -} - -template -__global__ void multi_query_append_attention_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - T *__restrict__ cache_v, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int 
speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); - static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if (num_chunks_this_seq <= 1) { - o_base_ptr_int8 = out + o_offset; - } else { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } - - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, tid / 16); // 16 * 16 - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale_multi_warps( - &qo_smem, scale); - - smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * 
sizeof(T)), - v_smem(smem + (num_frags_x + NUM_WARP_KV * num_frags_z) * 16 * HEAD_DIM * - sizeof(T)); - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - NUM_WARP_KV * num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : chunk_len) / - (NUM_WARP_KV * num_frags_z * 16); - - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_z * 16 + tid % 16, tid / 16); - uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); - - uint32_t kv_idx_base = chunk_start; - int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - const uint32_t const_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - - // s = qk - compute_qk( - &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - 
q_len, - kv_len, - chunk_end, - s_frag); - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += NUM_WARP_KV * num_frags_z * 16; - block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); - if (block_id < 0) { - block_id = 0; - } - cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); - __syncthreads(); - - cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - merge_block_res_v2( - o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); - - if (num_chunks_this_seq <= 1) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if (num_chunks_this_seq <= 1) { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride * num_chunks, - HEAD_DIM); - } - - if (num_chunks_this_seq > 1) { - if (wid == 0) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - 
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + - qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } - } -} - -template -void MultiQueryAppendAttention( - const AppendAttnMetaData &meta_data, - const paddle::Tensor &qkv, - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional &attn_mask, - const paddle::optional &shift_bias, - const paddle::optional &smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const paddle::Tensor &batch_ids, - const paddle::Tensor &tile_ids_per_batch, - const int num_blocks_x_cpu, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool is_decoder, - cudaStream_t &stream, - paddle::Tensor *out) { - using NV_TYPE = typename cascade_attn_type_traits::type; - using OUT_NV_TYPE = typename cascade_attn_type_traits::type; - - auto num_heads = meta_data.q_num_heads; - auto kv_num_heads = meta_data.kv_num_heads; - auto token_num = meta_data.token_nums; - auto bsz = meta_data.batch_size; - auto max_block_num_per_seq = meta_data.max_blocks_per_seq; - - constexpr uint32_t num_warps = 4; - constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; - constexpr uint32_t num_frags_x = 
BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - - auto *allocator = paddle::GetAllocator(qkv.place()); - - const float scale = 1.f / sqrt(HEAD_DIM); - - if constexpr (NUM_WARP_Q == 4) { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; - constexpr uint32_t smem_size = - (num_warps * num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * - HEAD_DIM * sizeof(T); - auto split_kv_kernel = multi_query_append_attention_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - const int num_chunks = div_up(max_dec_len, chunk_size); - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); // 128k is too large - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } else { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; - constexpr uint32_t smem_size = - (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * - sizeof(T); - auto split_kv_kernel = - multi_query_append_attention_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - const int num_chunks = div_up(max_dec_len, chunk_size); - - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - shift_bias ? 
reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } -} +#include "multiquery_attention_c16_kernel.h" template void CascadeAppendAttentionC16Kernel( const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + const paddle::Tensor& + qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& @@ -1285,7 +36,8 @@ void CascadeAppendAttentionC16Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -1308,10 +60,11 @@ void CascadeAppendAttentionC16Kernel( const bool is_decoder, const bool enable_prefill, cudaStream_t& stream, - paddle::Tensor* out) { + paddle::Tensor* out, + const int sliding_window, + const int sink_size = 0) { const auto token_num = meta_data.token_nums; const auto block_size = meta_data.block_size; - const auto bsz = meta_data.batch_size; const auto num_heads = meta_data.q_num_heads; const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; const auto head_dim = meta_data.head_dims; @@ -1349,6 +102,7 @@ void CascadeAppendAttentionC16Kernel( attn_mask, shift_bias, smooth_weight, + sinks, seq_lens_q, seq_lens_kv, seq_lens_encoder, @@ -1368,6 +122,293 @@ void CascadeAppendAttentionC16Kernel( speculate_max_draft_token_num, is_decoder, stream, - out); + out, + sliding_window, + sink_size); })})})})})}) } + +template void +CascadeAppendAttentionC16Kernel( + const 
AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + 
const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // 
[num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + 
const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const 
int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC16Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh index 9f003af88b5..6752ce0a3c9 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh @@ -13,1499 +13,13 @@ // 
limitations under the License. #pragma once -#include "append_attention_func.cuh" -#include "append_attention_kernel.h" - -template -__global__ void multi_query_append_attention_c4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / 2 / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / 2 / num_elems_per_128b(); - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - const uint32_t btid = blockIdx.x, 
kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; - - block_table_now = block_table + batch_id * max_block_num_per_seq; - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - - const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; - const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; - const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; - T *cache_k_scale_smem = reinterpret_cast( - smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); - T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; - T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; - T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; -#pragma unroll - for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { - cache_k_scale_smem[i] = cache_k_scale_now[i]; - cache_k_zero_point_smem[i] = cache_k_zp_now[i]; - cache_v_scale_smem[i] = cache_v_scale_now[i]; - cache_v_zero_point_smem[i] = cache_v_zp_now[i]; - } - - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_b_stride = HEAD_DIM / 2; - const uint32_t kv_d_stride = BLOCK_SIZE / 2; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = - (tile_id * NUM_WARPS + wid) * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * 
num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if constexpr (partition_kv) { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } else { - o_base_ptr_int8 = out + o_offset; - } - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_x * 16 + tid % 16, tid / 16); - load_q_global_smem( - q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); - - T cache_k_scale_frag[num_frags_y][4]; - T cache_k_zp_frag[num_frags_y][4]; - T magic_number; - if constexpr (std::is_same::value) { - magic_number = static_cast(1032.f); - } else { - magic_number = static_cast(136.f); - } -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); - *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + - 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4 + 4); -#pragma unroll - for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { - cache_k_zp_frag[fy][zp_i] += magic_number; // 128 + 8 - } - } - T cache_v_scale_frag[num_frags_y][2]; - T cache_v_zp_frag[num_frags_y][2]; -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - 
cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; - cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; - cache_v_zp_frag[fy][0] = - cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; - cache_v_zp_frag[fy][1] = - cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; - } - - smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); - - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : chunk_len) / - (num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, - tid % - 4); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_b_stride + - tid % 4 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 16 + tid / 2) * kv_d_stride + - tid % 2 * num_elems_per_128b(); - - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - 
block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - - compute_qk_c4( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - s_frag, - cache_k_scale_frag, - cache_k_zp_frag); - - if (iter >= mask_check_iteration) { - mask_s(q_base_seq_id_this_block, - kv_idx_base, - q_len, - kv_len, - chunk_end, - s_frag); - } - - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += num_frags_z * 16; - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - wait_group<1>(); - __syncthreads(); - - compute_sfm_v_c4(&v_smem, - &v_smem_offset_r, - s_frag, - o_frag, - d_frag, - cache_v_scale_frag, - cache_v_zp_frag); - __syncthreads(); - - produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - if constexpr (!partition_kv) { - normalize_d(o_frag, d_frag); - } - - if constexpr (partition_kv) { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? 
q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } - - if constexpr (partition_kv) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = - (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } -} - -template -__global__ void multi_query_append_attention_c4_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t 
chunk_size, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / 2 / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / 2 / num_elems_per_128b(); - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); - static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? 
chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; - const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; - const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; - T *cache_k_scale_smem = reinterpret_cast( - smem + NUM_WARP_Q * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); - T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; - T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; - T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; -#pragma unroll - for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { - cache_k_scale_smem[i] = cache_k_scale_now[i]; - cache_k_zero_point_smem[i] = cache_k_zp_now[i]; - cache_v_scale_smem[i] = cache_v_scale_now[i]; - cache_v_zero_point_smem[i] = cache_v_zp_now[i]; - } - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; - const uint32_t kv_b_stride = HEAD_DIM / 2; - const uint32_t kv_d_stride = BLOCK_SIZE / 2; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * 
q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if (num_chunks_this_seq <= 1) { - o_base_ptr_int8 = out + o_offset; - } else { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } - - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, tid / 16); - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale_multi_warps( - &qo_smem, scale); - - T cache_k_scale_frag[num_frags_y][4]; - T cache_k_zp_frag[num_frags_y][4]; - T magic_number; - if constexpr (std::is_same::value) { - magic_number = static_cast(1032.f); - } else { - magic_number = static_cast(136.f); - } -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); - *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = - *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + - 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4); - *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = - *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + - tid % 4 + 4); -#pragma unroll - for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { - cache_k_zp_frag[fy][zp_i] += magic_number; - } - } - T cache_v_scale_frag[num_frags_y][2]; - T cache_v_zp_frag[num_frags_y][2]; -#pragma unroll - for (uint32_t fy = 
0; fy < num_frags_y; ++fy) { - cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; - cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; - cache_v_zp_frag[fy][0] = - cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; - cache_v_zp_frag[fy][1] = - cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; - } - - smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - NUM_WARP_KV * num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : chunk_len) / - (NUM_WARP_KV * num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - wid * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, - tid % - 4); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 16 + tid / 2, tid % 2); - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_b_stride + - tid % 4 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 16 + tid / 2) * kv_d_stride + - tid % 2 * num_elems_per_128b(); - - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - 
produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - compute_qk_c4( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - s_frag, - cache_k_scale_frag, - cache_k_zp_frag); - if (iter >= mask_check_iteration) { - mask_s(q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - s_frag); - } - - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += NUM_WARP_KV * num_frags_z * 16; - produce_k_blockwise_c4(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v_c4(&v_smem, - &v_smem_offset_r, - s_frag, - o_frag, - d_frag, - cache_v_scale_frag, - cache_v_zp_frag); - __syncthreads(); - - produce_v_blockwise_c4(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - merge_block_res_v2( - o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); - - if (num_chunks_this_seq <= 1) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if (num_chunks_this_seq <= 1) { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - 
shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride * num_chunks, - HEAD_DIM); - } - - if (num_chunks_this_seq > 1) { - if (wid == 0) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + - qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } - } -} - -template -void MultiQueryAppendC4Attention( - const AppendAttnMetaData &meta_data, - const paddle::Tensor &qkv, - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional &attn_mask, - const paddle::Tensor &cache_k_scale, - const paddle::Tensor &cache_v_scale, - const paddle::optional &cache_k_zp, - const paddle::optional &cache_v_zp, - const paddle::optional &shift_bias, - const paddle::optional &smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const paddle::Tensor &batch_ids, - const paddle::Tensor &tile_ids_per_batch, - const int num_blocks_x_cpu, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int 
speculate_max_draft_token_num, - const bool is_decoder, - cudaStream_t &stream, - paddle::Tensor *out) { - using NV_TYPE = typename cascade_attn_type_traits::type; - using OUT_NV_TYPE = typename cascade_attn_type_traits::type; - - auto num_heads = meta_data.q_num_heads; - auto kv_num_heads = meta_data.kv_num_heads; - auto token_num = meta_data.token_nums; - auto bsz = meta_data.batch_size; - auto max_block_num_per_seq = meta_data.max_blocks_per_seq; - - constexpr uint32_t num_warps = 4; - constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; - constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - - auto *allocator = paddle::GetAllocator(qkv.place()); - - const float scale = 1.f / sqrt(HEAD_DIM); - - if constexpr (NUM_WARP_Q == 4) { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; - constexpr uint32_t smem_size = - num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + - HEAD_DIM * 4 * sizeof(T); - auto split_kv_kernel = - multi_query_append_attention_c4_kernel; - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); - assert(act_blocks_per_sm > 1); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; - const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); - const float ratio = static_cast(num_blocks_need) / - static_cast(num_blocks_per_wave); - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - 
} - const int num_chunks = div_up(max_dec_len, chunk_size); - - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_c4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = 
allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } else { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4; - constexpr uint32_t smem_size = - num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + - HEAD_DIM * 4 * sizeof(T); - auto split_kv_kernel = - multi_query_append_attention_c4_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); - assert(act_blocks_per_sm > 1); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; - const int max_num_chunks = 
div_up(num_blocks_per_wave, num_blocks_need); - const float ratio = static_cast(num_blocks_need) / - static_cast(num_blocks_per_wave); - - - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - const int num_chunks = div_up(max_dec_len, chunk_size); - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_c4_warp1_4_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - cache_k_zp ? 
reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - reinterpret_cast(const_cast(cache_v_scale.data())), - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } -} +#include "multiquery_attention_c4_kernel.h" template void CascadeAppendAttentionC4Kernel( const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + const paddle::Tensor& + qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& @@ -1522,7 +36,8 @@ void CascadeAppendAttentionC4Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -1545,10 +60,11 @@ void CascadeAppendAttentionC4Kernel( const bool is_decoder, const bool enable_prefill, 
cudaStream_t& stream, - paddle::Tensor* out) { + paddle::Tensor* out, + const int sliding_window = 0, + const int sink_size = 0) { const auto token_num = meta_data.token_nums; const auto block_size = meta_data.block_size; - const auto bsz = meta_data.batch_size; const auto num_heads = meta_data.q_num_heads; const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; const auto head_dim = meta_data.head_dims; @@ -1590,6 +106,7 @@ void CascadeAppendAttentionC4Kernel( cache_v_zp, shift_bias, smooth_weight, + sinks, seq_lens_q, seq_lens_kv, seq_lens_encoder, @@ -1609,6 +126,293 @@ void CascadeAppendAttentionC4Kernel( speculate_max_draft_token_num, is_decoder, stream, - out); + out, + sliding_window, + sink_size); })})})})})}) } + +template void +CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const 
float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template 
void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + 
cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, 
head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC4Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] + const paddle::Tensor& + cache_k, // [max_block_num, num_heads, block_size, head_dim] + const paddle::Tensor& + cache_v, // [max_block_num, num_heads, head_dim, block_size] + const paddle::optional& attn_mask, + const paddle::optional& + cache_k_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_scale, // [num_kv_heads, head_dim] + const paddle::optional& + cache_k_zp, // [num_kv_heads, head_dim] + const paddle::optional& + cache_v_zp, // [num_kv_heads, head_dim] + const paddle::optional& + shift_bias, // [num_kv_heads, head_dim] + const paddle::optional& + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const 
int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh index 3b72597e025..3ab6b063f3f 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh @@ -13,1472 +13,13 @@ // limitations under the License. #pragma once -#include "append_attention_func.cuh" -#include "append_attention_kernel.h" +#include "multiquery_attention_c8_kernel.h" -template -__global__ void multi_query_append_attention_c8_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads] - const T *__restrict__ cache_v_scale, // [num_kv_heads] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - T *__restrict__ tmp_workspace, // 
split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = - HEAD_DIM / num_elems_per_128b(); // 128 / 8 = 16 - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / num_elems_per_128b(); // 128 / 16 = 8 - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / num_elems_per_128b(); // 64 / 16 = 4 - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = nullptr; - - block_table_now = block_table + batch_id * max_block_num_per_seq; - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - - T cache_k_scale_reg[num_frags_y * 4]; - T cache_v_scale_reg[num_frags_y * 2]; - if (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; - } 
- scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; - } - - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t kv_d_stride = BLOCK_SIZE; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = - (tile_id * NUM_WARPS + wid) * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * 
q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if constexpr (partition_kv) { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } else { - o_base_ptr_int8 = out + o_offset; - } - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 - load_q_global_smem( - q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); - smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? 
(min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : chunk_len) / - (num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 4 + tid / 8, - tid % 8); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64 - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_d_stride + - tid % 4 * num_elems_per_128b(); - - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - // s = qk - compute_qk_c8( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - cache_k_scale_reg, - s_frag); - - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(q_base_seq_id_this_block, - kv_idx_base, - q_len, - kv_len, - chunk_end, - s_frag); - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += num_frags_z * 16; - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, 
- kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm*v - compute_sfm_v_c8( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); - __syncthreads(); - - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - - } - wait_group<0>(); - __syncthreads(); - - if constexpr (!partition_kv) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if constexpr (partition_kv) { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - partition_kv ? 
q_n_stride * num_chunks : q_n_stride, - HEAD_DIM); - } - - - if constexpr (partition_kv) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = - (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } -} - -template -__global__ void multi_query_append_attention_c8_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] - CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, - // head_dim] - CacheT *__restrict__ cache_v, - const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] - const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int *__restrict__ seq_lens, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ batch_ids, - const int *__restrict__ tile_ids_per_batch, - const int *__restrict__ cu_seqlens_q, - const int *__restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const uint32_t chunk_size, - T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, - // num_heads, head_dim] - float *__restrict__ tmp_m, // 
[token_num, num_chunks, num_heads] - float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] - OutT *__restrict__ out, - const int speculate_max_draft_token_num = 5) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM / num_elems_per_128b(); - constexpr uint32_t num_vecs_per_blocksize = - BLOCK_SIZE / num_elems_per_128b(); - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; - constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); - static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); - const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; - const uint32_t tid = threadIdx.x, wid = threadIdx.y; - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_idx = blockIdx.y; - - const uint32_t batch_id = batch_ids[btid]; - const uint32_t tile_id = tile_ids_per_batch[btid]; - const uint32_t num_rows_per_block = num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - T cache_k_scale_reg[num_frags_y * 4]; - T cache_v_scale_reg[num_frags_y * 2]; - if (is_scale_channel_wise) { - int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; - const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; - cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; - cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; - cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; - } - scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; - 
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; - for (int i = 0; i < num_frags_y; ++i) { - const int scale_idx = i * 16; - cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; - cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; - } - } else { - cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; - cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; - } - const uint32_t q_end = - min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } - const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); - if (chunk_idx >= num_chunks_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - float s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - float m_frag[num_frags_x][2]; - float d_frag[num_frags_x][2]; - init_states(o_frag, m_frag, d_frag); - - const uint32_t q_n_stride = q_num_heads * HEAD_DIM; - const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; - const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; - const uint32_t kv_b_stride = HEAD_DIM; - const uint32_t kv_d_stride = BLOCK_SIZE; - const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; - const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; - const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - const uint32_t o_offset = q_start_seq_id * q_n_stride + - q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - T *q_base_ptr = q + 
q_offset; - - T *o_base_ptr_T = nullptr; - OutT *o_base_ptr_int8 = nullptr; - if (num_chunks_this_seq <= 1) { - o_base_ptr_int8 = out + o_offset; - } else { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } - } - - smem_t qo_smem(smem); - - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, tid / 16); // 16 * 16 - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); - commit_group(); - wait_group<0>(); - __syncthreads(); - - q_smem_inplace_multiply_sm_scale_multi_warps( - &qo_smem, scale); - - smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); - - const uint32_t num_iterations = div_up( - CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) - : chunk_len, - NUM_WARP_KV * num_frags_z * 16); - const uint32_t mask_check_iteration = - (CAUSAL ? 
(min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) - : chunk_len) / - (NUM_WARP_KV * num_frags_z * 16); - - uint32_t k_smem_offset_r = - smem_t::get_permuted_offset( - wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - - uint32_t v_smem_offset_r = - smem_t::get_permuted_offset( - (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, - (wid % 2) * num_frags_z + (tid % 16) / 8); - - uint32_t k_smem_offset_w = - smem_t::get_permuted_offset( - wid * 4 + tid / 8, - tid % - 8); - uint32_t v_smem_offset_w = - smem_t::get_permuted_offset( - wid * 8 + tid / 4, tid % 4); - - uint32_t kv_idx_base = chunk_start; - const uint32_t const_k_offset = kv_head_idx * kv_h_stride + - (wid * 4 + tid / 8) * kv_b_stride + - tid % 8 * num_elems_per_128b(); - const uint32_t const_v_offset = kv_head_idx * kv_h_stride + - (wid * 8 + tid / 4) * kv_d_stride + - tid % 4 * num_elems_per_128b(); - - // load BLOCK_SIZE * HEAD_DIM each time - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - wait_group<1>(); - __syncthreads(); - - // s = qk - compute_qk_c8( - &qo_smem, - &q_smem_offset_r, - &k_smem, - &k_smem_offset_r, - cache_k_scale_reg, - s_frag); - // mask according to kv_idx and q_idx - if (iter >= mask_check_iteration) { - mask_s(q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - s_frag); - } - - // update m,d - update_mdo_states( - s_frag, o_frag, m_frag, d_frag); - __syncthreads(); - - kv_idx_base += NUM_WARP_KV * 
num_frags_z * 16; - produce_k_blockwise_c8(k_smem, - &k_smem_offset_w, - cache_k, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_b_stride, - kv_idx_base, - chunk_end, - const_k_offset); - commit_group(); - wait_group<1>(); - __syncthreads(); - - // compute sfm * v - compute_sfm_v_c8_iter_sq_bvec( - &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); - __syncthreads(); - - produce_v_blockwise_c8(v_smem, - &v_smem_offset_w, - cache_v, - block_table_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, - kv_d_stride, - kv_idx_base, - chunk_end, - const_v_offset); - commit_group(); - } - wait_group<0>(); - __syncthreads(); - - merge_block_res_v2( - o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); - - if (num_chunks_this_seq <= 1) { - normalize_d(o_frag, d_frag); - } - - // write o - // [num_frags_x, 16, num_frags_y, 16] - if (num_chunks_this_seq <= 1) { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_int8, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride, - HEAD_DIM); - } else { - write_o_reg_gmem_multi_warps_shift_smooth_quant( - o_frag, - &qo_smem, - o_base_ptr_T, - shift_bias, - smooth_weight, - q_base_seq_id_this_block, - q_head_idx, - quant_max_bound, - quant_min_bound, - in_scale, - q_len, - q_n_stride * num_chunks, - HEAD_DIM); - } - - if (num_chunks_this_seq > 1) { - if (wid == 0) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_idx_now = - q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; - const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; - const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; - if (qo_idx - q_start_seq_id < q_len) { - - uint32_t offset; - if (ENABLE_PREFILL) { - offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + - qo_head_idx; - } 
else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } - tmp_m[offset] = m_frag[fx][j]; - tmp_d[offset] = d_frag[fx][j]; - } - } - } - } - } -} - -template -void MultiQueryAppendC8Attention( - const AppendAttnMetaData &meta_data, - const paddle::Tensor &qkv, - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional &attn_mask, - const paddle::Tensor &cache_k_scale, - const paddle::Tensor &cache_v_scale, - const paddle::optional &shift_bias, - const paddle::optional &smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - const paddle::Tensor &batch_ids, - const paddle::Tensor &tile_ids_per_batch, - const int num_blocks_x_cpu, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool is_decoder, - cudaStream_t &stream, - paddle::Tensor *out) { - using NV_TYPE = typename cascade_attn_type_traits::type; - using OUT_NV_TYPE = typename cascade_attn_type_traits::type; - - auto num_heads = meta_data.q_num_heads; - auto kv_num_heads = meta_data.kv_num_heads; - auto token_num = meta_data.token_nums; - auto bsz = meta_data.batch_size; - auto max_block_num_per_seq = meta_data.max_blocks_per_seq; - - constexpr uint32_t num_warps = 4; - constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; - constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); - constexpr uint32_t num_frags_y = HEAD_DIM / 16; - constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; - - auto *allocator = paddle::GetAllocator(qkv.place()); - - const 
float scale = 1.f / sqrt(HEAD_DIM); - bool is_scale_channel_wise = false; - if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { - is_scale_channel_wise = true; - } - - if constexpr (NUM_WARP_Q == 4) { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; - constexpr uint32_t smem_size = - num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + - num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; - auto split_kv_kernel = - multi_query_append_attention_c8_kernel; - if (is_scale_channel_wise) { - split_kv_kernel = - multi_query_append_attention_c8_kernel; - } - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - const int num_chunks = div_up(max_dec_len, chunk_size); - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_c8_kernel; - if (is_scale_channel_wise) { - nosplit_kv_kernel = - multi_query_append_attention_c8_kernel; - } - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } else { - constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; - constexpr uint32_t smem_size = - num_frags_x * 16 * HEAD_DIM * sizeof(T) + - NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2; - auto split_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - if (is_scale_channel_wise) { - split_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - } - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(split_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - const int dev_id = 0; - int sm_count; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - uint32_t chunk_size = static_cast(max_partition_size); - if (!is_decoder) { - chunk_size = static_cast(encoder_max_partition_size); - } - - const int num_chunks = div_up(max_dec_len, chunk_size); - dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); - dim3 blocks(32, num_warps); - if (num_chunks <= 1) { - auto nosplit_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - if (is_scale_channel_wise) { - nosplit_kv_kernel = - multi_query_append_attention_c8_warp1_4_kernel; - } - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(nosplit_kv_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - } - - nosplit_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - } else { - phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } - split_kv_kernel<<>>( - reinterpret_cast(const_cast(qkv.data())), - const_cast(cache_k.data()), - const_cast(cache_v.data()), - reinterpret_cast(const_cast(cache_k_scale.data())), - 
reinterpret_cast(const_cast(cache_v_scale.data())), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast( - const_cast(smooth_weight.get().data())) - : nullptr, - seq_lens_q.data(), - seq_lens_kv.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - scale, - quant_max_bound, - quant_min_bound, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - reinterpret_cast(out->data()), - speculate_max_draft_token_num); - // merge - constexpr int vec_size = num_elems_per_128b(); - if (is_decoder) { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); - } else { - constexpr int blockx = HEAD_DIM / vec_size; - constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); - dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? 
reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); - } - } - } -} - -template +template void CascadeAppendAttentionC8Kernel( const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + const paddle::Tensor& + qkv, // [token_num, (num_heads + 2* kv_num_head) * head_dim] const paddle::Tensor& cache_k, // [max_block_num, num_heads, block_size, head_dim] const paddle::Tensor& @@ -1495,7 +36,8 @@ void CascadeAppendAttentionC8Kernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -1517,14 +59,17 @@ void CascadeAppendAttentionC8Kernel( const bool causal, const bool is_decoder, const bool enable_prefill, + const std::string& cache_quant_type_str, cudaStream_t& stream, - paddle::Tensor* out) { + paddle::Tensor* out, + const int sliding_window = 0, + const int sink_size = 0) { const auto token_num = meta_data.token_nums; const auto block_size = meta_data.block_size; - const auto bsz = meta_data.batch_size; const auto num_heads = meta_data.q_num_heads; const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; const auto head_dim = meta_data.head_dims; + bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8"; DISPATCH_CAUSAL( causal, @@ -1542,7 +87,10 @@ void CascadeAppendAttentionC8Kernel( block_size, BLOCK_SIZE, {DISPATCH_BLOCKSHAPE_Q( - block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, { + block_shape_q, + 
BLOCK_SHAPE_Q, + NUM_WARP_Q, + {DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, { MultiQueryAppendC8Attention( + ENABLE_PREFILL, + IsFP8, + IsDynamicC8>( meta_data, qkv, cache_k, @@ -1561,6 +111,7 @@ void CascadeAppendAttentionC8Kernel( cache_v_scale.get(), shift_bias, smooth_weight, + sinks, seq_lens_q, seq_lens_kv, seq_lens_encoder, @@ -1580,6 +131,496 @@ void CascadeAppendAttentionC8Kernel( speculate_max_draft_token_num, is_decoder, stream, - out); - })})})})})}) + out, + sliding_window, + sink_size); + })})})})})})}) } + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& 
cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int 
num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const 
paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const 
int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + 
const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const 
int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void +CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& 
attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); + +template void CascadeAppendAttentionC8Kernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& shift_bias, + const paddle::optional& smooth_weight, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& batch_ids, + const paddle::Tensor& tile_ids_per_batch, + const int num_blocks, + const int block_shape_q, + const int max_seq_len, + const 
int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool is_decoder, + const bool enable_prefill, + const std::string& cache_quant_type_str, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window, + const int sink_size); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 8b6802d27d8..16c0cb4f23a 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -77,6 +77,14 @@ struct prefill_softmax_state_t { __device__ __forceinline__ void normalize() { const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } + + __device__ __forceinline__ void normalize(float current_sink) { + const T d_t = static_cast(d + __expf(current_sink - m)); #pragma unroll for (size_t i = 0; i < vec_size; ++i) { o[i] /= d_t; @@ -134,7 +142,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( const uint32_t tx_offset = tx / 8; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; #pragma unroll const int j = ty; @@ -143,8 +150,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( const uint32_t h_offset = offset_now % group_size; T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -163,7 +169,7 @@ template __device__ __forceinline__ void load_q_global_smem( - T* q_ptr_base, + const T* q_ptr_base, smem_t* q_smem, uint32_t 
q_idx_base, const uint32_t qo_upper_bound, @@ -186,10 +192,10 @@ __device__ __forceinline__ void load_q_global_smem( const uint32_t offset_now = base_offset + j * 4; const uint32_t n_offset = offset_now / group_size; const uint32_t h_offset = offset_now % group_size; - T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; + const T* q_ptr = + q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -215,8 +221,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); #pragma unroll - for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; - ++i) { + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { const int offset = i * 1024 + ty * 256 + tx * 8; Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); #pragma unroll @@ -281,11 +286,9 @@ __device__ __forceinline__ void produce_kv_blockwise( const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check #pragma unroll - for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; - ++i) { + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); @@ -324,9 +327,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( block_size / num_elems_per_128b(); // 8 constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; const uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t kv_idx = - 
kv_idx_base + - tx % 4 * num_elems_per_128b(); + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); if constexpr (NUM_WARP_Q == 4) { int block_id = __ldg(&block_table_now[kv_idx / block_size]); if (block_id < 0) block_id = 0; @@ -335,8 +336,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; ++i) { // m (num_frags_y * 16 / (num_warps * 8)) #pragma unroll - for (uint32_t j = 0; j < num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( *smem_offset, j); @@ -361,8 +361,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8( for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; ++i) { // m (num_frags_y * 16 / (num_warps * 8)) #pragma unroll - for (uint32_t j = 0; j < 2 * num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( @@ -384,6 +383,106 @@ __device__ __forceinline__ void produce_v_blockwise_c8( } } +template +__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async( + smem_t kv_scale_smem, + const int* block_table_now, + const T* cache_kv_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tid = ty * 32 + tx; + if constexpr (NUM_WARP_Q == 4) { + // 4 warps shared block_size + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + if (tid < block_size / 8) { + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid * 8; + const int kv_idx_this_thread = kv_idx + tid * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, 
kv_idx_this_thread < chunk_end); + } + } else { + // 1 warp 32 tokens + if (tid < block_size / 8 * 2) { + const uint32_t kv_idx_now = kv_idx + block_size * tid / 8; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const int kv_idx_this_thread = kv_idx + tid * 8; + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid % 8 * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); + } + } +} + +template +__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg( + T* k_smem_scale, T* cache_k_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + if constexpr (NUM_WARP_Q == 4) { + // 4 warps shared block_size + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } + } else { + // 1 warp 32 tokens + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } + } +} + +template +__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg( + T* v_smem_scale, T* cache_v_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + if constexpr (NUM_WARP_Q == 4) { + // 4 warps shared block_size + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; + } + } else { + // 1 warp 32 tokens + const uint32_t row_id = tx % 4 * 2; + for 
(uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; + } + } +} + template (*smem_offset, cache_k_now, true); *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( *smem_offset, j); @@ -499,8 +597,7 @@ __device__ __forceinline__ void produce_v_blockwise_c4( #pragma unroll for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m #pragma unroll - for (uint32_t j = 0; j < num_frags_z / 4; - ++j) { + for (uint32_t j = 0; j < num_frags_z / 4; ++j) { smem.load_128b_async(*smem_offset, cache_v_now, true); *smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>( *smem_offset, j); @@ -556,8 +653,7 @@ __device__ __forceinline__ void produce_k_blockwise_c4( for (uint32_t i = 0; i < num_frags_z * 2 / num_warps; ++i) { // m num_frags_z * 16 / (num_warps * 8) #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 8; - ++j) { + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { smem.load_128b_async(*smem_offset, cache_k_now, true); *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>( *smem_offset, j); @@ -816,12 +912,13 @@ template + bool IsFP8 = false, + bool IsDynamicC8 = false> __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, uint32_t* k_smem_offset_r, - const T *cache_k_scale, + const T* cache_k_scale, float (*s_frag)[num_frags_z][8]) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); @@ -857,23 +954,30 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, #pragma unroll for (uint32_t fy = 0; fy < 2; ++fy) { T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fy * 2]); - convert_c8(b_frag_dq_T + 4, 
b_frag[fy * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); // scale zp - if constexpr (is_scale_channel_wise) { - const int scale_col = (ky * 2 + fy) * 4; - b_frag_dq_T[0] *= cache_k_scale[scale_col]; - b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; - b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; - b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; - b_frag_dq_T[4] *= cache_k_scale[scale_col]; - b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; - b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; - b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[0]; + } + } } else { #pragma unroll for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_k_scale[0]; + b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; } } #pragma unroll @@ -905,12 +1009,17 @@ template -__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, +__device__ __forceinline__ void mask_s(const bool* attn_mask, + const uint32_t qo_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, - float (*s_frag)[num_frags_z][8]) { + const uint32_t attn_mask_len, + float (*s_frag)[num_frags_z][8], + const int* mask_offset = nullptr, + const int sliding_window = 0, + const int sink_size = 0) { const uint32_t tx = threadIdx.x; #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { 
@@ -924,10 +1033,44 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, group_size, kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + 8 * (reg_id / 4) + reg_id % 2; - const bool out_of_boundary = - (causal - ? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end)) - : kv_idx >= chunk_end); + bool out_of_boundary; + if (mask_offset) { + if (sliding_window > 0) { + int swa_part = mask_offset[q_idx * 2 + 1] - sliding_window; + if (swa_part < 0) swa_part = 0; + int sink_part = + mask_offset[q_idx * 2] + sink_size; // sink_size = 128 + out_of_boundary = + q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || + kv_idx < mask_offset[q_idx * 2] || + (kv_idx >= sink_part && kv_idx < swa_part)) + : true; + } else { + out_of_boundary = q_idx < qo_len + ? (kv_idx >= mask_offset[q_idx * 2 + 1] || + kv_idx < mask_offset[q_idx * 2]) + : true; + } + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - + sliding_window; + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + bool mask = attn_mask[mask_idx]; + out_of_boundary |= mask; + } + } + if constexpr (std::is_same::value) { s_frag[fx][fz][reg_id] = out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; @@ -935,6 +1078,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, s_frag[fx][fz][reg_id] = out_of_boundary ? 
-3.0e+30f : s_frag[fx][fz][reg_id]; } + } else { const uint32_t q_idx = qo_idx_base, kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + @@ -1078,14 +1222,16 @@ template + bool is_scale_channel_wise = false, + bool IsFP8 = false, + bool IsDynamicC8 = false> __device__ __forceinline__ void compute_sfm_v_c8( smem_t* v_smem, uint32_t* v_smem_offset_r, float (*s_frag)[num_frags_z][8], float (*o_frag)[num_frags_y][8], float (*d)[2], - const T *cache_v_scale) { + const T* cache_v_scale) { constexpr uint32_t num_vecs_per_blocksize = block_size / num_elems_per_128b(); T s_frag_f16[num_frags_x][num_frags_z][8]; @@ -1117,19 +1263,31 @@ __device__ __forceinline__ void compute_sfm_v_c8( #pragma unroll for (uint32_t fz = 0; fz < 2; ++fz) { T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fz * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp - if constexpr (is_scale_channel_wise) { + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; - } - } else { + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[0]; + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } } + } else { + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; } #pragma unroll for (uint32_t 
fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 @@ -1137,7 +1295,6 @@ __device__ __forceinline__ void compute_sfm_v_c8( o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), b_frag_dq); - } } } @@ -1156,14 +1313,16 @@ template + bool is_scale_channel_wise = false, + bool IsFP8 = false, + bool IsDynamicC8 = false> __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( smem_t* v_smem, uint32_t* v_smem_offset_r, float (*s_frag)[num_frags_z][8], float (*o_frag)[num_frags_y][8], float (*d)[2], - T *cache_v_scale) { + T* cache_v_scale) { constexpr uint32_t num_vecs_per_blocksize = block_size / num_elems_per_128b(); @@ -1197,19 +1356,31 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( for (uint32_t fz = 0; fz < 2; ++fz) { // dequant b_frag -> b_frag_dq T* b_frag_dq_T = reinterpret_cast(b_frag_dq); - convert_c8(b_frag_dq_T, b_frag[fz * 2]); - convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); // scale zp - if constexpr (is_scale_channel_wise) { + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } } } else { - #pragma unroll - for (uint32_t b_i = 0; b_i < 8; ++b_i) { - b_frag_dq_T[b_i] *= cache_v_scale[0]; - } + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] 
*= cache_v_scale[scale_col + 3]; } #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 @@ -1254,8 +1425,7 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, } #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; - ++fz) { + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t b_frag[4]; @@ -1300,6 +1470,33 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], } } +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], + float (*d)[2], + float (*m)[2], + float (*current_sinks)[2]) { + float d_rcp[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[fx][j] = 1.f / (d[fx][j] + __expf(current_sinks[fx][j] - m[fx][j])); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2]; + } + } + } +} + template ((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - fx * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1423,8 +1619,8 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( } __syncthreads(); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * 4 + tx / 8, tx % 8); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); o_idx_base += (tx / 8) / group_size; o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + @@ -1438,8 +1634,7 @@ 
__device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride + ((fx * 16 + j * 4) % group_size) * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (o_idx < qo_upper_bound) { // need write o_smem->store_128b(o_smem_offset_w, o_ptr); @@ -1454,7 +1649,6 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps( } } - template struct StoreFunc { __device__ __forceinline__ void operator()( @@ -1513,7 +1707,6 @@ struct StoreFunc { } }; - template struct StoreFunc { __device__ __forceinline__ void operator()( @@ -1566,10 +1759,9 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - fx * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1582,8 +1774,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( } __syncthreads(); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * 4 + tx / 8, tx % 8); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); const uint32_t tx_offset = tx / 8; #pragma unroll @@ -1600,8 +1792,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant( uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + tx % 8 * num_elems_per_128b(); #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (n_offset < qo_upper_bound) { if constexpr 
(!partition_kv) { Load( @@ -1677,10 +1868,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - (ty * num_frags_x + fx) * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (ty * num_frags_x + fx) * 16 + tx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1693,8 +1882,7 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( __syncthreads(); uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * num_frags_x * 16 + tx / 8, - tx % 8); + ty * num_frags_x * 16 + tx / 8, tx % 8); const uint32_t tx_offset = tx / 8; #pragma unroll @@ -1710,13 +1898,12 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim + tx % 8 * num_elems_per_128b(); #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (n_offset < qo_upper_bound) { if (!partition_kv) { Load( - reinterpret_cast(o_smem->base + o_smem_offset_w), - &ori_out_vec); + reinterpret_cast(o_smem->base + o_smem_offset_w), + &ori_out_vec); if (in_scale > 0.0) { if (shift_bias) { Load(shift_bias + shift_smooth_offset, @@ -1725,16 +1912,16 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant( &smooth_weight_vec); } } - #pragma unroll +#pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { StoreFunc()(ori_out_vec, - shift_bias_vec, - smooth_weight_vec, - out_vec, - quant_max_bound, - quant_min_bound, - in_scale, - i); + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store(out_vec, o_ptr); } else 
{ @@ -1775,10 +1962,8 @@ __device__ __forceinline__ void write_o_reg_gmem( for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; vec_cast((T*)o_frag_f16, o_frag[fx][fy]); - uint32_t o_smem_offset_w = smem_t::get_permuted_offset< - num_vecs_per_head>( - (ty * num_frags_x + fx) * 16 + tx / 4, - fy * 2); + uint32_t o_smem_offset_w = smem_t::get_permuted_offset( + (ty * num_frags_x + fx) * 16 + tx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; @@ -1791,8 +1976,7 @@ __device__ __forceinline__ void write_o_reg_gmem( __syncthreads(); uint32_t o_smem_offset_w = smem_t::get_permuted_offset( - ty * num_frags_x * 16 + tx / 8, - tx % 8); + ty * num_frags_x * 16 + tx / 8, tx % 8); o_idx_base += (tx / 8) / group_size; o_ptr_base += ((tx / 8) / group_size) * qo_n_stride + @@ -1805,8 +1989,7 @@ __device__ __forceinline__ void write_o_reg_gmem( T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride + ((fx * 16 + j * 4) % group_size) * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; - ++fyo) { + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } @@ -1867,6 +2050,7 @@ __global__ void merge_multi_chunks_kernel( const int vid = threadIdx.x, hid = threadIdx.y; const int qid = blockIdx.x; const uint32_t bid = batch_id_per_token[qid]; + if (bid == -1) return; if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) { return; } @@ -1921,7 +2105,6 @@ __global__ void merge_multi_chunks_kernel( &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); } - template __device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], float* md_smem, @@ -2103,17 +2286,18 @@ template __global__ void merge_multi_chunks_decoder_kernel( - const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + const T* 
__restrict__ multi_out, // [token_num, num_chunks, num_heads, // head_dim] - const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] - const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] - const int *__restrict__ seq_lens_q, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ seq_lens_encoder, - const int *__restrict__ cu_seqlens_q, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - OutT *__restrict__ out, + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + OutT* __restrict__ out, const float quant_max_bound, const float quant_min_bound, const float in_scale, @@ -2126,6 +2310,9 @@ __global__ void merge_multi_chunks_decoder_kernel( const int bid = blockIdx.x, hid = blockIdx.y; __shared__ T smem[bdy * HEAD_DIM]; __shared__ float md_smem[bdy * 2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const int start_token_idx = cu_seqlens_q[bid]; const int seq_len_q = seq_lens_q[bid]; if (seq_len_q == 0) return; @@ -2150,17 +2337,11 @@ __global__ void merge_multi_chunks_decoder_kernel( using LoadT = AlignedVector; LoadT load_vec; LoadT res_vec; - if constexpr (std::is_same::value) { -#pragma unroll - for (int i = 0; i < vec_size / 2; ++i) { - *((half2 *)(&res_vec) + i) = make_half2(0, 0); - } - } else { -#pragma unroll - for (int i = 0; i < vec_size / 2; ++i) { - *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); - } + + for (int i = 0; i < vec_size; ++i) 
{ + res_vec[i] = T(0.f); } + float m; float d = 1.f; if constexpr (std::is_same::value) { @@ -2168,6 +2349,10 @@ __global__ void merge_multi_chunks_decoder_kernel( } else if constexpr (std::is_same::value) { m = -3.0e+30f; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + #pragma unroll 2 for (int i = ty; i < num_chunks_this_seq; i += bdy) { uint32_t offset = (bid * num_chunks + i) * num_heads + hid; @@ -2176,8 +2361,7 @@ __global__ void merge_multi_chunks_decoder_kernel( const float m_now = multi_m[offset]; const float d_now = multi_d[offset]; m = max(m_prev, m_now); - offset = (bid * num_chunks * num_heads + i * num_heads + hid) * head_dim + - vid * vec_size; + offset = offset * head_dim + vid * vec_size; Load(&multi_out[offset], &load_vec); const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); const T scale1_T = static_cast(scale1), @@ -2203,7 +2387,12 @@ __global__ void merge_multi_chunks_decoder_kernel( const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; st.merge(load_vec, m_tmp, d_tmp); } - st.normalize(); + if (sinks) { + float current_sink = static_cast(sinks[hid]); + st.normalize(current_sink); + } else { + st.normalize(); + } const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; AlignedVector shift_bias_vec; @@ -2216,13 +2405,22 @@ __global__ void merge_multi_chunks_decoder_kernel( } #pragma unroll for (int i = 0; i < vec_size; ++i) { - StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store( out_vec, &out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template + bool ENABLE_PREFILL = true, + bool DECODE_ONLY = true> __global__ void 
merge_multi_chunks_v2_kernel( - const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads, + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, // head_dim] - const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads] - const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads] - const int *__restrict__ seq_lens_q, - const int *__restrict__ seq_lens_kv, - const int *__restrict__ seq_lens_encoder, - const int *__restrict__ batch_id_per_token, - const int *__restrict__ cu_seqlens_q, - const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - OutT *__restrict__ out, + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + OutT* __restrict__ out, const float quant_max_bound, const float quant_min_bound, const float in_scale, @@ -2258,8 +2458,14 @@ __global__ void merge_multi_chunks_v2_kernel( const int hid = blockIdx.y; __shared__ T smem[bdy * HEAD_DIM]; __shared__ float md_smem[bdy * 2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { const uint32_t bid = batch_id_per_token[qid]; + if (bid == -1) { + continue; + } const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; const int seq_len_q = seq_lens_q[bid]; if (seq_len_q == 0) continue; @@ -2267,15 +2473,16 @@ __global__ void merge_multi_chunks_v2_kernel( if (ENABLE_PREFILL) { seq_len_kv += seq_len_q; if 
(seq_len_kv == 0) continue; - - const int seq_len_enc = seq_lens_encoder[bid]; - if (seq_len_enc <= 0) { - continue; - } } else { if (seq_len_kv == 0) continue; seq_len_kv += seq_len_q; } + if constexpr (DECODE_ONLY) { + const int seq_len_enc = seq_lens_encoder[bid]; + if (seq_len_enc > 0) { + continue; + } + } const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); if (num_chunks_this_seq <= 1) { continue; @@ -2287,12 +2494,12 @@ __global__ void merge_multi_chunks_v2_kernel( if constexpr (std::is_same::value) { #pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((half2 *)(&res_vec) + i) = make_half2(0, 0); + *((half2*)(&res_vec) + i) = make_half2(0, 0); } } else { #pragma unroll for (int i = 0; i < vec_size / 2; ++i) { - *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); } } float m; @@ -2355,7 +2562,13 @@ __global__ void merge_multi_chunks_v2_kernel( const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; st.merge(load_vec, m_tmp, d_tmp); } - st.normalize(); + + if (sinks) { + float current_sink = static_cast(sinks[hid]); + st.normalize(current_sink); + } else { + st.normalize(); + } const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; AlignedVector shift_bias_vec; @@ -2366,14 +2579,24 @@ __global__ void merge_multi_chunks_v2_kernel( Load(smooth_weight + shift_smooth_offset, &smooth_weight_vec); } + #pragma unroll for (int i = 0; i < vec_size; ++i) { - StoreFunc()( - st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i); + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); } Store( out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); } __syncthreads(); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } diff --git 
a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index 8799c0a7051..ca06deeeb75 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -13,144 +13,12 @@ // limitations under the License. #pragma once +#include "append_attention_c16_impl.cuh" +#include "append_attention_c4_impl.cuh" +#include "append_attention_c8_impl.cuh" #include "helper.h" #include "utils.cuh" -template -void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - 
cudaStream_t& stream, - paddle::Tensor* out); - -template -void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - -template -void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, 
head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - template void CascadeAppendAttentionKernel( const AppendAttnMetaData& meta_data, @@ -171,7 +39,8 @@ void CascadeAppendAttentionKernel( const paddle::optional& shift_bias, // [num_kv_heads, head_dim] const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] + smooth_weight, // [num_kv_heads, head_dim] + const paddle::optional& sinks, // [num_heads] const paddle::Tensor& seq_lens_q, const paddle::Tensor& seq_lens_kv, const paddle::Tensor& seq_lens_encoder, @@ -195,150 +64,168 @@ void CascadeAppendAttentionKernel( const bool is_decoder, const bool enable_prefill, cudaStream_t& stream, - paddle::Tensor* out) { - if (cache_quant_type_str == "none") { - CascadeAppendAttentionC16Kernel(meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - 
batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - stream, - out); - } else if (cache_quant_type_str == "cache_int8") { - CascadeAppendAttentionC8Kernel(meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - stream, - out); - } else if (cache_quant_type_str == "cache_fp8") { - CascadeAppendAttentionC8Kernel(meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - stream, - out); - } else if (cache_quant_type_str == "cache_int4_zp") { - CascadeAppendAttentionC4Kernel(meta_data, - qkv, - cache_k, - cache_v, - attn_mask, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - shift_bias, - smooth_weight, - seq_lens_q, - seq_lens_kv, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_table, - batch_ids, - tile_ids_per_batch, - num_blocks, - block_shape_q, - 
max_seq_len, - max_dec_len, - quant_max_bound, - quant_min_bound, - in_scale, - max_partition_size, - encoder_max_partition_size, - speculate_max_draft_token_num, - causal, - is_decoder, - enable_prefill, - stream, - out); - } else { - PD_THROW( - "cache_quant_type_str should be one of [none, cache_int8, " - "cache_int4_zp]"); - } + paddle::Tensor* out, + const int sliding_window = 0, + const int sink_size = 0) { + if (cache_quant_type_str == "none") { + CascadeAppendAttentionC16Kernel(meta_data, + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + stream, + out, + sliding_window, + sink_size); + } else if (cache_quant_type_str == "cache_int8") { + CascadeAppendAttentionC8Kernel( + meta_data, + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + cache_quant_type_str, + stream, + out, + sliding_window, + sink_size); + } else if (cache_quant_type_str == "cache_fp8" or + cache_quant_type_str == "block_wise_fp8") { + CascadeAppendAttentionC8Kernel(meta_data, + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + 
cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + cache_quant_type_str, + stream, + out, + sliding_window, + sink_size); + } else if (cache_quant_type_str == "cache_int4_zp") { + CascadeAppendAttentionC4Kernel(meta_data, + qkv, + cache_k, + cache_v, + attn_mask, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + shift_bias, + smooth_weight, + sinks, + seq_lens_q, + seq_lens_kv, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_table, + batch_ids, + tile_ids_per_batch, + num_blocks, + block_shape_q, + max_seq_len, + max_dec_len, + quant_max_bound, + quant_min_bound, + in_scale, + max_partition_size, + encoder_max_partition_size, + speculate_max_draft_token_num, + causal, + is_decoder, + enable_prefill, + stream, + out, + sliding_window, + sink_size); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, " + "cache_int4_zp]"); + } } diff --git a/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh index 3ac80b6cc0c..8f7b096e6b0 100644 --- a/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/decode_attention_func.cuh @@ -13,8 +13,8 @@ // limitations under the License. 
#pragma once - -#include "multi_head_latent_attention_kernel.h" +#include "helper.h" +#include "utils.cuh" template struct softmax_state_t { @@ -42,13 +42,10 @@ struct softmax_state_t { } } - __device__ __forceinline__ softmax_state_t() { - init(); - } + __device__ __forceinline__ softmax_state_t() { init(); } - __device__ __forceinline__ void merge(const AlignedVector& other_o, - T other_m, - T other_d) { + __device__ __forceinline__ void merge( + const AlignedVector& other_o, T other_m, T other_d) { // using kType = typename cascade_attn_nv_type2_traits::type; T m_prev = m, d_prev = d; m = m_prev > other_m ? m_prev : other_m; @@ -63,13 +60,11 @@ struct softmax_state_t { } __device__ __forceinline__ void normalize() { - #pragma unroll for (size_t i = 0; i < vec_size; ++i) { o[i] /= d; } } - }; template @@ -102,65 +97,79 @@ struct softmax_state_ts { } } - __device__ __forceinline__ softmax_state_ts() { - init(); - } + __device__ __forceinline__ softmax_state_ts() { init(); } __device__ __forceinline__ void normalize(const uint32_t tile_id) { - #pragma unroll for (size_t i = 0; i < vec_size; i++) { o[tile_id][i] /= d; } } - }; -template -__device__ __forceinline__ void produce_kv(CacheT *smem, - CacheT *kv_base_gptr, - const int * block_table_smem, - const uint32_t seq_offset_gmem, - const uint32_t seq_offset_smem, - const uint32_t kv_head_idx, - const uint32_t kv_num_heads, - const uint32_t tidx, - const uint32_t chunk_start, - const uint32_t chunk_end) { +template +__device__ __forceinline__ void produce_kv(CacheT* smem, + CacheT* kv_base_gptr, + const int* block_table_smem, + const uint32_t seq_offset_gmem, + const uint32_t seq_offset_smem, + const uint32_t kv_head_idx, + const uint32_t kv_num_heads, + const uint32_t tidx, + const uint32_t chunk_start, + const uint32_t chunk_end) { int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]); if (block_id < 0) { block_id = 0; } const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE; // 8/16 T/int8 
each time - const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK; + const uint32_t k_offset_base = + ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * + HEAD_DIM_QK; const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK; - for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>( - smem + smem_offset_base + vid * CACHE_VEC_SIZE, - kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE, - seq_offset_gmem < chunk_end - ); + smem + smem_offset_base + vid * CACHE_VEC_SIZE, + kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE, + seq_offset_gmem < chunk_end); } } -template -__device__ __forceinline__ void compute_qk(const T* cu_q_smem, - const CacheT* k_smem, - const uint32_t kv_idx_base, - const uint32_t stage_idx, - const uint32_t iter_base, - const uint32_t iter_bound, - const uint32_t tidx, - const uint32_t gid, - const float scale, - float *s, - softmax_state_ts& st) { +template +__device__ __forceinline__ void compute_qk( + const T* cu_q_smem, + const CacheT* k_smem, + const uint32_t kv_idx_base, + const uint32_t stage_idx, + const uint32_t iter_base, + const uint32_t iter_bound, + const uint32_t tidx, + const uint32_t gid, + const float scale, + float* s, + softmax_state_ts& st) { const CacheT* smem; AlignedVector q_vec; AlignedVector k_vec; float m_prev = st.m; - // smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * HEAD_DIM; + // smem = base_smem + (stage_idx * DEAL_EACH_TIME + zid * tile_size) * + // HEAD_DIM; smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM; #pragma unroll for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) { @@ -171,7 +180,7 @@ __device__ __forceinline__ void compute_qk(const T* cu_q_smem, s[j] = 0.f; } #pragma unroll - for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + for (uint32_t vid = 
tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { Load(cu_q_smem + vid * vec_size, &q_vec); Load(smem + j * HEAD_DIM + vid * vec_size, &k_vec); for (uint32_t i = 0; i < vec_size; ++i) { @@ -211,20 +220,29 @@ __device__ __forceinline__ void compute_qk(const T* cu_q_smem, } } -template -__device__ __forceinline__ void compute_sv(const float *s, - const CacheT *base_v_smem, - const uint32_t stage_idx, - const uint32_t iter_base, - const uint32_t iter_bound, - const uint32_t tidx, - softmax_state_ts& st) { +template +__device__ __forceinline__ void compute_sv( + const float* s, + const CacheT* base_v_smem, + const uint32_t stage_idx, + const uint32_t iter_base, + const uint32_t iter_bound, + const uint32_t tidx, + softmax_state_ts& st) { const CacheT* v_smem; AlignedVector v_vec; #pragma unroll for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) { - v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK; - for(uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { + v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + + j * HEAD_DIM_QK; + for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) { Load(v_smem + vid * vec_size, &v_vec); uint32_t tile_id = vid / bdx; #pragma unroll diff --git a/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu b/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu deleted file mode 100644 index 701ba42df46..00000000000 --- a/custom_ops/gpu_ops/append_attn/decode_attention_kernel.cu +++ /dev/null @@ -1,560 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "decode_attention_func.cuh" - -#define CHECK(call) \ -do \ -{ \ - const cudaError_t error_code = call; \ - if (error_code != cudaSuccess) \ - { \ - printf("CUDA Error:\n"); \ - printf(" File: %s\n", __FILE__); \ - printf(" Line %d:\n", __LINE__); \ - printf(" Error code:%d\n", error_code); \ - printf(" Error text:%s\n", cudaGetErrorString(error_code)); \ - exit(1); \ - } \ -}while(0) - -template -__global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim] - const T * __restrict__ multi_m, // [bsz, num_chunks, num_heads] - const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads] - const int * __restrict__ seq_lens_q, - const int * __restrict__ seq_lens_kv, - const int * __restrict__ cu_seqlens_q, - const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - OutT * __restrict__ out, // [token_num, num_heads, head_dim] - const float in_scale, - const int num_chunks, - const int chunk_size, - const int max_seq_len, - const int num_heads, - const int head_dim) { - const int vid = threadIdx.x, ty = threadIdx.y; - const int qid = blockIdx.x, hid = blockIdx.y; - const int seq_len_q = seq_lens_q[qid]; - if (seq_len_q == 0) return; - int seq_len_kv = seq_lens_kv[qid]; - if (seq_len_kv == 0) return; - seq_len_kv += seq_len_q; - const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); - if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) { - return; - } - __shared__ T smem[bdy * HEAD_DIM]; 
- __shared__ T md_smem[bdy * 2]; - - const int start_token_ids = cu_seqlens_q[qid]; - using LoadT = AlignedVector; - LoadT load_vec; - LoadT res_vec; - if constexpr (std::is_same::value) { -#pragma unroll - for (int i = 0; i < vec_size / 2; ++i) { - *((half2*)(&res_vec) + i) = make_half2(0, 0); - } - } else if constexpr (std::is_same::value) { -#pragma unroll - for (int i = 0; i < vec_size / 2; ++i) { - *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); - } - } - T m; - T d = 1.f; - if constexpr (std::is_same::value) { - m = __float2half(-5e4f); - } else if constexpr (std::is_same::value) { - m = __float2bfloat16(-3.38953e38f); - } - // merge per ty -#pragma unroll 2 - for (int i = ty; i < num_chunks_this_seq; i += bdy) { - uint32_t offset = (qid * num_chunks + i) * num_heads + hid; - T m_prev = m; - T d_prev = d; - const T m_now = multi_m[offset]; - const T d_now = multi_d[offset]; - m = m_prev > m_now ? m_prev : m_now; - offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + vid * vec_size; - Load(&multi_out[offset], &load_vec); - const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m); - d = d * scale1 + d_now * scale2; -#pragma once - for (int j = 0; j < vec_size; j++) { - res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2; - } - } - // store ty res - Store(res_vec, &smem[ty * head_dim + vid * vec_size]); - md_smem[2 * ty] = m; - md_smem[2 * ty + 1] = d; - __syncthreads(); - - // merge bdy - softmax_state_t st{}; - const uint32_t iter_num = min(num_chunks_this_seq, bdy); -#pragma once - for (int i = 0; i < iter_num; i++) { - Load(&smem[i * head_dim + vid * vec_size], &load_vec); - const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; - st.merge(load_vec, m_tmp, d_tmp); - } - st.normalize(); - - AlignedVector out_vec; - -#pragma unroll - for (int i = 0; i < vec_size; ++i) { - out_vec[i] = static_cast(st.o[i]); - } - Store(out_vec, &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]); -} - -template 
-__global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [token_num, num_heads, head_dim] - CacheT * __restrict__ cache_k, // [max_block_num, num_heads, block_size, head_dim] - CacheT * __restrict__ cache_v, - const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] - const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const int * __restrict__ seq_lens_q, - const int * __restrict__ seq_lens_kv, - const int * __restrict__ cu_seqlens_q, - const int * __restrict__ block_table, // [bsz, block_num_per_seq] - const int max_seq_len, - const int max_dec_len, - const int max_block_num_per_seq, - const float scale, - const float in_scale, - const uint32_t chunk_size, - T * __restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, head_dim] - T * __restrict__ tmp_m, // [batch_size, num_chunks, num_heads] - T * __restrict__ tmp_d, // [batch_size, num_chunks, num_heads] - OutT * __restrict__ out) { - const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z; - const uint32_t bid = bidx, gid = threadIdx.y; - const uint32_t tidx = threadIdx.x; - constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE; - constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE; - constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx; - - const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid; - const uint32_t kv_num_heads = gridDim.z; - const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; - - const int *block_table_now = block_table + bid * max_block_num_per_seq; - - const uint32_t num_chunks = gridDim.y; - const uint32_t chunk_id = blockIdx.y; - const uint32_t q_len = seq_lens_q[bid]; - if (q_len <= 0) { - return; - } - uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!! 
- if (kv_len <= 0) { - return; - } - kv_len += q_len; - const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size); - const uint32_t q_start_idx = cu_seqlens_q[bid]; - const uint32_t q_write_idx = cu_seqlens_q[bid]; - if (chunk_id >= num_chunk_this_seq) { - return; - } - - const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0; - const uint32_t chunk_end = partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; - const uint32_t chunk_len = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK; - T *q_smem = reinterpret_cast(smem); // [HEAD_DIM_QK * sizeof(T)] - T *cu_q_smem = q_smem + gid * HEAD_DIM_QK; -#pragma unroll - for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { - ((float4*)(&cu_q_smem[vid * VEC_SIZE]))[0] = ((float4*)(&q_now[vid * VEC_SIZE]))[0]; - - } - __syncthreads(); - using VecT = AlignedVector; - VecT q_vec; -#pragma unroll - for(uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { - Load(cu_q_smem + vid * VEC_SIZE, &q_vec); - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - q_vec[i] *= scale; - } - Store(q_vec, cu_q_smem + vid * VEC_SIZE); - } - - - CacheT *kv_smem = reinterpret_cast(smem + GROUP_SIZE * HEAD_DIM_QK * sizeof(CacheT)); - uint32_t stage_idx = 0; - constexpr int loop_times = DEAL_EACH_TIME / bdy; -#pragma unroll - for (int i = 0; i < NUM_STAGES; ++i) { -#pragma unroll - for (int j = 0; j < loop_times; ++j) { - const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid; - const uint32_t k_seq_id = chunk_start + k_seq_offset; - produce_kv( - kv_smem, - cache_k, - block_table_now, - k_seq_id, - k_seq_offset, - kv_head_idx, - kv_num_heads, - tidx, - chunk_start, - chunk_end - ); - } - commit_group(); - stage_idx = (stage_idx + 1) % NUM_STAGES; - } - - - softmax_state_ts st; - float s[DEAL_EACH_TIME]; - - const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME); - for (int iter = 0; iter < num_iters; 
++iter) { - wait_group(); - __syncthreads(); - // compute qk - compute_qk( - cu_q_smem, - kv_smem, - chunk_start + iter * DEAL_EACH_TIME, - stage_idx, - iter * DEAL_EACH_TIME, - chunk_len, - tidx, - gid, - scale, - s, - st - ); - __syncthreads(); - - // compute sv - compute_sv( - s, - kv_smem, - stage_idx, - iter * DEAL_EACH_TIME, - chunk_len, - tidx, - st - ); - __syncthreads(); - -#pragma unroll - for (int j = 0; j < loop_times; ++j) { - const uint32_t k_seq_offset = j * bdy + gid; - produce_kv( - kv_smem, - cache_k, - block_table_now, - chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME, - stage_idx * DEAL_EACH_TIME + k_seq_offset, - kv_head_idx, - kv_num_heads, - tidx, - chunk_start, - chunk_end - ); - } - commit_group(); - stage_idx = (stage_idx + 1) % NUM_STAGES; - } - wait_group<0>(); - __syncthreads(); - - // normize if not partition_kv - for(uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) { - const uint32_t tile_id = vid / bdx; - if (!partition_kv || num_chunk_this_seq == 1) { - st.normalize(tile_id); - } - if (partition_kv && num_chunk_this_seq > 1) { - const uint32_t head_idx = (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx; - Store(st.o[tile_id], tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE); - tmp_m[head_idx] = st.m; - tmp_d[head_idx] = st.d; - } else { - Store(st.o[tile_id], out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + vid * VEC_SIZE); - } - } -} - - -template -void MultiQueryDecoderAttention( - const AppendAttnMetaData& meta_data, - cudaStream_t &stream, - const paddle::Tensor &q, - const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] - const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor 
&cu_seqlens_q, - const paddle::Tensor &block_table, - const int max_seq_len, - const int max_dec_len, - const float rope_scale, - const float rope_theta, - const float softmax_scale, - const float in_scale, - paddle::Tensor *out) { - using NV_TYPE = typename cascade_attn_type_traits::type; - - auto num_heads = meta_data.q_num_heads; - auto kv_num_heads = meta_data.kv_num_heads; - auto token_num = meta_data.token_nums; - auto bsz = meta_data.batch_size; - auto max_block_num_per_seq = meta_data.max_blocks_per_seq; - constexpr int num_stages = NUM_STAGE; - - constexpr int vec_size = 16 / sizeof(T); // 8 16 32 - constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32 - constexpr int blockxc = HEAD_DIM_QK / cache_vec_size; - constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size; - constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32; - - constexpr int blocky = GROUP_SIZE; - const int gridx = bsz; - - constexpr int num_threads = blockx * blocky; - - auto splitkv_kernel = multi_query_decode_attention_kernel; - uint32_t cache_smem_bytes = 0; - - const T *shift_bias_ptr = shift_bias ? shift_bias.get().data() : nullptr; - const T *smooth_weight_ptr = smooth_weight ? 
smooth_weight.get().data() : nullptr; - cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T); - - const uint32_t chunk_size = get_max_partition_size(bsz); - const int num_chunks = div_up(max_dec_len, chunk_size); - size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T); - - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, splitkv_kernel, num_threads, smem_size); - assert(act_blocks_per_sm > 1); - - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - const int num_blocks_need = gridx * num_chunks * kv_num_heads; - const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); - const float ratio = static_cast(num_blocks_need) / static_cast(num_blocks_per_wave); - - dim3 grids(gridx, num_chunks, kv_num_heads); - dim3 blocks(blockx, blocky); - if (num_chunks <= 1) { - auto no_splitkv_kernel = multi_query_decode_attention_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - no_splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - no_splitkv_kernel<<>>( - reinterpret_cast(const_cast(q.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - reinterpret_cast(const_cast(shift_bias_ptr)), - reinterpret_cast(const_cast(smooth_weight_ptr)), - seq_lens_q.data(), - seq_lens_kv.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - softmax_scale, - in_scale, - chunk_size, - nullptr, - nullptr, - nullptr, - reinterpret_cast(const_cast(out->data())) - ); - - // CHECK(cudaGetLastError()); - // CHECK(cudaDeviceSynchronize()); - } else { - auto *allocator = paddle::GetAllocator(q.place()); 
- phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - tmp_workspace = allocator->Allocate( - phi::SizeOf(q.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM_V)); - tmp_m = allocator->Allocate( - phi::SizeOf(q.dtype()) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(q.dtype()) * - static_cast(bsz * num_chunks * num_heads)); - - splitkv_kernel<<>>( - reinterpret_cast(const_cast(q.data())), - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - reinterpret_cast(const_cast(shift_bias_ptr)), - reinterpret_cast(const_cast(smooth_weight_ptr)), - seq_lens_q.data(), - seq_lens_kv.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_dec_len, - max_block_num_per_seq, - softmax_scale, - in_scale, - chunk_size, - reinterpret_cast(tmp_workspace->ptr()), - reinterpret_cast(tmp_m->ptr()), - reinterpret_cast(tmp_d->ptr()), - reinterpret_cast(const_cast(out->data())) - ); - // CHECK(cudaGetLastError()); - // CHECK(cudaDeviceSynchronize()); - - constexpr int mblockx = HEAD_DIM_V / vec_size; - constexpr int bdy = 256 / mblockx; - dim3 grids_merge(bsz, num_heads); - dim3 blocks_merge(mblockx, bdy); - merge_varlen_multi_chunks_v2_kernel<<>>( - reinterpret_cast(tmp_workspace->ptr()), - reinterpret_cast(tmp_m->ptr()), - reinterpret_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - cu_seqlens_q.data(), - reinterpret_cast(const_cast(shift_bias_ptr)), - reinterpret_cast(const_cast(smooth_weight_ptr)), - reinterpret_cast(const_cast(out->data())), - in_scale, - num_chunks, - chunk_size, - max_seq_len, - num_heads, - HEAD_DIM_V - ); - } - // CHECK(cudaGetLastError()); - // CHECK(cudaDeviceSynchronize()); -} - -template -void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& 
attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out) { - const auto token_num = meta_data.token_nums; - const auto block_size = meta_data.block_size; - const auto bsz = meta_data.batch_size; - const auto num_heads = meta_data.q_num_heads; - const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; - const auto head_dim_qk = meta_data.head_dims; - const auto head_dim_v = meta_data.head_dims_v; - const float rope_scale = 0.0; - const float rope_theta = 0.0; - const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); - const uint32_t num_stage = get_cascade_attention_num_stages(); - const uint32_t num_threads = get_cascade_attention_num_threads(); - - DISPATCH_CAUSAL(causal, CAUSAL, - {DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, - {DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK, - {DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V, - {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, - {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, - {MultiQueryDecoderAttention( - meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q, - block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})}); -} - -template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor 
&seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out); - -template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 - const paddle::Tensor &seq_lens_kv, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_table, - int max_seq_len, - int max_dec_len, - float softmax_scale, - float in_scale, - bool causal, - cudaStream_t &stream, - paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.cu new file mode 100644 index 00000000000..6e2d9eb2ba2 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.cu @@ -0,0 +1,142 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "helper.h" +#include "multiquery_decoder_attention_kernel.h" +#include "utils.cuh" + +template +void DecodeMLAAttentionKernel( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out) { + const auto token_num = meta_data.token_nums; + const auto block_size = meta_data.block_size; + const auto bsz = meta_data.batch_size; + const auto num_heads = meta_data.q_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + const auto head_dim_qk = meta_data.head_dims; + const auto head_dim_v = meta_data.head_dims_v; + const float rope_scale = 0.0; + const float rope_theta = 0.0; + const uint32_t deal_each_time = get_cascade_attention_deal_each_time(); + const uint32_t num_stage = get_cascade_attention_num_stages(); + const uint32_t num_threads = get_cascade_attention_num_threads(); + + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_MLA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_MLA_HEAD_DIM( + head_dim_qk, + HEAD_DIM_QK, + {DISPATCH_MLA_HEAD_DIM( + head_dim_v, + HEAD_DIM_V, + {DISPATCH_BLOCK_SIZE( + block_size, + BLOCK_SIZE, + {DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, { + MultiQueryDecoderAttention( + meta_data, + stream, + q, + cache_k, + cache_v, + attn_mask, + shift_bias, + smooth_weight, + seq_lens_q, + seq_lens_kv, + batch_id_per_token, + cu_seqlens_q, + block_table, + max_seq_len, + max_dec_len, + rope_scale, + rope_theta, 
+ softmax_scale, + in_scale, + out); + })})})})})}); +} + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); + +template void DecodeMLAAttentionKernel( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + int max_seq_len, + int max_dec_len, + float softmax_scale, + float in_scale, + bool causal, + cudaStream_t &stream, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.h similarity index 72% rename from custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h rename to custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.h index 4d81b99a734..1546f376852 100644 --- a/custom_ops/gpu_ops/append_attn/multi_head_latent_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/decoder_mla_attention_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once + #include "helper.h" #include "utils.cuh" template void DecodeMLAAttentionKernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor &q, // [token_num, num_heads, head_dim] + const AppendAttnMetaData &meta_data, + const paddle::Tensor &q, // [token_num, num_heads, head_dim] const paddle::Tensor &cache_k, const paddle::Tensor &cache_v, - const paddle::optional& attn_mask, - const paddle::optional& shift_bias, - const paddle::optional& smooth_weight, - const paddle::Tensor &seq_lens_q, // q_seq_len is 1 + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, // q_seq_len is 1 const paddle::Tensor &seq_lens_kv, const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index 67066efc2c8..7dd4612c529 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -18,17 +18,213 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" -template +// Note(ZKK) +// This function is very easy! +// just make HeadDim data to be new HeadDim data! 
+ +template +__device__ __forceinline__ void apply_rope(const T* input, + const float* cos_emb, + const float* sin_emb, + T* output, + const int thread_id) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT out_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + +#pragma unroll + for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; + head_bias += NUM_THREADS * VecSize) { + Load(&input[head_bias], &src_vec); + const uint32_t emb_idx = head_bias / 2; + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + out_vec[2 * i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + out_vec[2 * i + 1] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); + } + Store(out_vec, &output[head_bias]); + } +} + +template +__global__ void append_decode_cache_T_rope_qk_norm_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, + // head_size] + T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int 
block_size, + const uint32_t elem_cnt, + const int kv_num_heads, + const bool rope_3d, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadKVT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + using LoadFloat = AlignedVector; + LoadT src_vec; + LoadBiasT out_vec; + LoadKVT cache_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + LoadFloat tmp_vec; + LoadFloat q_norm_vec, k_norm_vec; + + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; + int64_t all_head_dim = elem_cnt / head_size; + + const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; + const int half_head_size = head_size / 2; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; + gloabl_hi += all_warp_num) { + int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize; + const int ori_bi = linear_index / hidden_size; + const int bias = linear_index % hidden_size; + const int hi = bias / head_size; // q + k + v + const int h_bias = bias % head_size; + const int start_token_idx = cu_seqlens_q[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + const uint32_t ori_idx = + start_token_idx * hidden_size + hi * head_size + h_bias; + + const int bias_idx = hi * head_size + h_bias; + Load(&quant_qkv[ori_idx], &src_vec); + if (hi < num_heads + kv_num_heads) { + // q k rope + const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2; + 
uint32_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + } + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + if (hi < num_heads + kv_num_heads) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + tmp_vec[2 * i] = tmp1; + tmp_vec[2 * i + 1] = tmp2; + } else { + out_vec[2 * i] = src_vec[2 * i]; + out_vec[2 * i + 1] = src_vec[2 * i + 1]; + } + } + if (hi < (num_heads + kv_num_heads)) { // q k + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = max(warp_m2 / head_size, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + if (hi < num_heads) { // q + Load(&q_norm_weight[threadIdx.x * VecSize], + &q_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + out_vec[i] = static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + } + } else { // k + Load(&k_norm_weight[threadIdx.x * VecSize], + &k_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + out_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + } + } + } + if (hi < num_heads) { + // write q + Store(out_vec, &qkv_out[ori_idx]); + } else { + // quant + write k/v + const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * head_size + + kv_head_idx * block_size * head_size + block_offset * head_size + + h_bias; + if (hi < num_heads + kv_num_heads) { + Store(out_vec, &key_cache[tgt_idx]); + } else { + Store(out_vec, 
&value_cache[tgt_idx]); + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template __global__ void append_decode_cache_T_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] - T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, - // head_size // 2] - T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, - // head_size // 2] + T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -57,6 +253,9 @@ __global__ void append_decode_cache_T_rope_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; // const int64_t offset = 2 * hidden_size; const int half_head_size = head_size / 2; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -83,7 +282,8 @@ __global__ void append_decode_cache_T_rope_kernel( if (hi < num_heads + kv_num_heads) { // q k rope const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2; - uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); } @@ -97,9 +297,11 @@ __global__ void append_decode_cache_T_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; out_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec[2 * i] = src_vec[2 * i]; out_vec[2 * i + 1] = src_vec[2 * i + 1]; @@ -122,10 +324,13 @@ __global__ void append_decode_cache_T_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template -__global__ void append_decode_cache_T_rope_kernel( +template +__global__ void append_decode_cache_T_quant_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, @@ -133,8 +338,8 @@ __global__ void append_decode_cache_T_rope_kernel( T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, // head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -142,8 +347,8 @@ __global__ void append_decode_cache_T_rope_kernel( const float* __restrict__ sin_emb, const float* __restrict__ qkv_out_scales, // [num_head + 2 * // kv_num_heads, dim_head] - const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, - // dim_head] + const T* __restrict__ 
qkv_biases, // [num_head + 2 * kv_num_heads, + // dim_head] const int max_seq_len, const int max_blocks_per_seq, const int num_heads, @@ -169,6 +374,9 @@ __global__ void append_decode_cache_T_rope_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; // const int64_t offset = 2 * hidden_size; const int half_head_size = head_size / 2; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -199,8 +407,10 @@ __global__ void append_decode_cache_T_rope_kernel( if (hi < num_heads + kv_num_heads) { // q k rope const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < HalfVecSize; i++) { @@ -217,9 +427,11 @@ __global__ void append_decode_cache_T_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -242,9 +454,158 @@ __global__ void append_decode_cache_T_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +__global__ void append_decode_cache_T_neox_partial_rope_kernel( + const T* __restrict__ qkv, // [bsz, num_heads + 2 * 
kv_num_heads, + // head_size] + T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, + // head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, + // rotary_dim/2] + const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, + // rotary_dim/2] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int rotary_dim, + const int block_size, + const uint32_t elem_cnt, + const int kv_num_heads, + const bool rope_3d) { + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadKVT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT left_vec, right_vec; + LoadBiasT left_bias_vec, right_bias_vec; + LoadKVT left_cache_vec, right_cache_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_head_size = head_size / 2; + const int half_rotary_dim = rotary_dim / 2; + const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; + const int64_t half_hidden_size = hidden_size / 2; + // const int64_t offset = 2 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int ori_bi = linear_index / half_hidden_size; + const int bias = linear_index % half_hidden_size; + const int hi = bias / half_head_size; // q + k + v + const int h_bias = bias % half_head_size; + if (hi < num_heads && 
h_bias >= half_rotary_dim) { + continue; + } + const int start_token_idx = cu_seqlens_q[ori_bi]; + if (seq_lens_encoder[ori_bi] > 0) return; + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = nullptr; + + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + uint32_t ori_idx_left = + start_token_idx * hidden_size + hi * head_size + h_bias; + uint32_t ori_idx_right = ori_idx_left + half_head_size; + if (hi < num_heads) { + ori_idx_right = ori_idx_left + half_rotary_dim; + } else if (hi < num_heads + kv_num_heads) { + if (h_bias < half_rotary_dim) { + ori_idx_right = ori_idx_left + half_rotary_dim; + } else { + ori_idx_left = ori_idx_left + half_rotary_dim; + ori_idx_right = ori_idx_left + half_rotary_dim; + } + } + + Load(&qkv[ori_idx_left], &left_vec); + Load(&qkv[ori_idx_right], &right_vec); + + if (hi < num_heads + kv_num_heads) { + // q k rope + const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias; + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx; + if (h_bias < half_rotary_dim) { + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + } + } +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // rope + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + right_bias_vec[i] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); + } else { + left_bias_vec[i] = static_cast(input_left); + right_bias_vec[i] = static_cast(input_right); + } + } + if (hi < num_heads) { + // write q + Store(left_bias_vec, &qkv_out[ori_idx_left]); + Store(right_bias_vec, &qkv_out[ori_idx_right]); + } else { + // write k/v + const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads; + uint32_t tgt_idx_left = + block_idx * kv_num_heads * block_size * head_size + + kv_head_idx * block_size * head_size + block_offset * head_size + + h_bias; + uint32_t tgt_idx_right = tgt_idx_left + half_head_size; + if (hi < num_heads + kv_num_heads) { + if (h_bias < half_rotary_dim) { + tgt_idx_right = tgt_idx_left + half_rotary_dim; + } else { + tgt_idx_left = tgt_idx_left + half_rotary_dim; + tgt_idx_right = tgt_idx_left + half_rotary_dim; + } + Store(left_bias_vec, &key_cache[tgt_idx_left]); + Store(right_bias_vec, &key_cache[tgt_idx_right]); + } else { + Store(left_bias_vec, &value_cache[tgt_idx_left]); + Store(right_bias_vec, &value_cache[tgt_idx_right]); + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template +template __global__ void append_decode_cache_T_neox_rope_kernel( const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ 
-253,8 +614,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel( T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, // head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -266,7 +626,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const int head_size, const int block_size, const uint32_t elem_cnt, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { using LoadT = AlignedVector; using LoadBiasT = AlignedVector; using LoadKVT = AlignedVector; @@ -284,7 +645,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; const int64_t half_hidden_size = hidden_size / 2; // const int64_t offset = 2 * hidden_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -313,8 +676,10 @@ __global__ void append_decode_cache_T_neox_rope_kernel( if (hi < num_heads + kv_num_heads) { // q k rope const uint32_t emb_idx = write_seq_id * head_size + h_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -325,9 +690,11 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -354,10 +721,13 @@ __global__ void append_decode_cache_T_neox_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template -__global__ void append_decode_cache_T_neox_rope_kernel( +template +__global__ void append_decode_cache_T_quant_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, @@ -365,8 +735,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel( T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, // head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -374,15 +743,16 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const float* __restrict__ sin_emb, const float* __restrict__ qkv_out_scales, // [num_head + 2 * // kv_num_heads, dim_head] - 
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, - // dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, + // dim_head] const int max_seq_len, const int max_blocks_per_seq, const int num_heads, const int head_size, const int block_size, const uint32_t elem_cnt, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { using LoadT = AlignedVector; using LoadBiasT = AlignedVector; using LoadOutScaleT = AlignedVector; @@ -400,7 +770,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const int half_head_size = head_size / 2; const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; const int64_t half_hidden_size = hidden_size / 2; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -416,79 +788,728 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const int* block_table_now = nullptr; - block_table_now = block_tables + ori_bi * max_blocks_per_seq; - const int block_idx = block_table_now[write_seq_id / block_size]; - const int block_offset = write_seq_id % block_size; - const uint32_t ori_idx_left = - start_token_idx * hidden_size + hi * head_size + h_bias; - const uint32_t ori_idx_right = ori_idx_left + half_head_size; + block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + const uint32_t ori_idx_left = + start_token_idx * hidden_size + hi * head_size + h_bias; + const uint32_t ori_idx_right = ori_idx_left + half_head_size; + + const int bias_idx_left = hi * head_size + h_bias; + const int bias_idx_right = bias_idx_left + half_head_size; + + Load(&quant_qkv[ori_idx_left], &left_vec); + Load(&quant_qkv[ori_idx_right], &right_vec); + if (qkv_biases) { + 
Load(&qkv_biases[bias_idx_left], &left_bias_vec); + Load(&qkv_biases[bias_idx_right], &right_bias_vec); + } + + Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); + Load(&qkv_out_scales[bias_idx_right], &right_out_scale_vec); + + if (hi < num_heads + kv_num_heads) { + // q k rope + const uint32_t emb_idx = write_seq_id * head_size + h_bias; + uint32_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + } +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + input_left = qkv_biases ? input_left * left_out_scale_vec[i] + + static_cast(left_bias_vec[i]) + : input_left * left_out_scale_vec[i]; + input_right = qkv_biases ? input_right * right_out_scale_vec[i] + + static_cast(right_bias_vec[i]) + : input_right * right_out_scale_vec[i]; + if (hi < num_heads + kv_num_heads) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + right_bias_vec[i] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); + } else { + left_bias_vec[i] = static_cast(input_left); + right_bias_vec[i] = static_cast(input_right); + } + } + if (hi < num_heads) { + // write q + Store(left_bias_vec, &qkv_out[ori_idx_left]); + Store(right_bias_vec, &qkv_out[ori_idx_right]); + } else { + // quant + write k/v + const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads; + const uint32_t tgt_idx_left = + block_idx * kv_num_heads * block_size * head_size + + kv_head_idx * block_size * head_size + block_offset * head_size + + h_bias; + const uint32_t tgt_idx_right = tgt_idx_left + half_head_size; + if (hi < num_heads + kv_num_heads) { + Store(left_bias_vec, &key_cache[tgt_idx_left]); + 
Store(right_bias_vec, &key_cache[tgt_idx_right]); + } else { + Store(left_bias_vec, &value_cache[tgt_idx_left]); + Store(right_bias_vec, &value_cache[tgt_idx_right]); + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +__global__ void append_decode_cache_T_int8_neox_rope_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + T* __restrict__ cache_k_scale, + T* __restrict__ cache_v_scale, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int kv_num_heads, + const bool rope_3d, + const float rms_norm_eps) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int start_token_idx = cu_seqlens_q[bid]; + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * 
max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadT src_vec_right; + LoadBiasT out_vec; + LoadBiasT out_vec_right; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < half_head_size; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + Load(&qkv_now[bias_idx + half_head_size], &src_vec_right); + // q rope + const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; + const uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[i]); + float input_right = static_cast(src_vec_right[i]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec[i] = static_cast(tmp1); + out_vec_right[i] = static_cast(tmp2); + } + Store(out_vec, &qkv_out_now[bias_idx]); + Store(out_vec_right, &qkv_out_now[bias_idx + half_head_size]); + } + } else if (head_idx < num_heads + 2 * kv_num_heads) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + kv_num_heads) { + constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store(pad_cache_vec, + &key_cache[tgt_idx + block_i * HeadDim]); + } + } else { + const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int 
block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); + } + } + __syncwarp(); + } + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec1_right, src_vec2, src_vec2_right; + LoadBiasT out_vec1, out_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + T scale = T(1.0f); + const int k_head_idx = head_idx - num_heads; + const int v_head_idx = head_idx - num_heads - kv_num_heads; + if (head_idx < num_heads + kv_num_heads) { + Load( + &qkv_now[head_idx * HeadDim + (head_bias + half_head_size) % HeadDim], + &src_vec1_right); + Load( + &qkv_now[head_idx * HeadDim + + (head_bias + 8 + half_head_size) % HeadDim], + &src_vec2_right); + + const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; + const uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); + } + + if (head_idx < num_heads + kv_num_heads) { + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1_right[0]); + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + float tmp1 = 0; + if (head_bias < half_head_size) { + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + } else { + tmp1 = fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp); + } + out_vec1[0] = static_cast(tmp1); + input_left = static_cast(src_vec1[1]); + input_right = static_cast(src_vec1_right[1]); + cos_tmp = cos_emb_vec1[1]; + sin_tmp = sin_emb_vec1[1]; + if (head_bias < half_head_size) { + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + } else { + tmp1 = fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp); + } + out_vec1[1] = static_cast(tmp1); + } else { + out_vec1[0] = src_vec1[0]; + out_vec1[1] = src_vec1[1]; + } + + // rope + if (head_idx < num_heads + kv_num_heads) { + float input_left = static_cast(src_vec2[0]); + float input_right = static_cast(src_vec2_right[0]); + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + float tmp1 = 0; + if (head_bias < half_head_size) { + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + } else { + tmp1 = fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp); + } + out_vec2[0] = static_cast(tmp1); + input_left = static_cast(src_vec2[1]); + input_right = static_cast(src_vec2_right[1]); + cos_tmp = cos_emb_vec2[1]; + sin_tmp = sin_emb_vec2[1]; + if (head_bias < half_head_size) { + tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + } else { + tmp1 = fmul_func(input_left, cos_tmp) + + 
fmul_func(input_right, sin_tmp); + } + out_vec2[1] = static_cast(tmp1); + } else { + out_vec2[0] = src_vec2[0]; + out_vec2[1] = src_vec2[1]; + } + if constexpr (IsDynamic) { + // reduce max, 1 head per warp + T local_max = -INFINITY; +#pragma unroll + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + local_max = __hmax(local_max, __habs(out_vec1[i])); + local_max = __hmax(local_max, __habs(out_vec2[i])); + } +#pragma unroll + for (int m_offset = 16; m_offset > 0; m_offset /= 2) { + local_max = + __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); + } + scale = __hdiv(448, local_max); + + int cache_offset; + if (head_idx < num_heads) { + cache_offset = 0; + } else if (head_idx < num_heads + 2 * kv_num_heads) { + cache_offset = block_idx * kv_num_heads * block_size + + (head_idx - num_heads) % kv_num_heads * block_size + + block_offset; + } + T* cache_k_scale_now = cache_k_scale + cache_offset; + T* cache_v_scale_now = cache_v_scale + cache_offset; + if (lane_id == 0) { + if (head_idx < num_heads + kv_num_heads) { + cache_k_scale_now[0] = __hdiv(1, scale); + } else { + cache_v_scale_now[0] = __hdiv(1, scale); + } + } + } else { + if (head_idx < num_heads + kv_num_heads) { + scale = __ldg(&cache_k_scale[kv_head_idx]); + } else { + scale = __ldg(&cache_v_scale[kv_head_idx]); + } + } + +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + cache_vec[i] = QuantToC8( + scale, out_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8( + scale, out_vec2[i], max_bound, min_bound); + } + if (head_idx < num_heads + kv_num_heads) { + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * kv_num_heads * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + const uint32_t 
base_tgt_cache_idx = + block_idx * kv_num_heads * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +__global__ void append_decode_cache_int8_rope_qk_norm_kernel( + const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, + // head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + T* __restrict__ cache_k_scale, + T* __restrict__ cache_v_scale, + const float* q_norm_weight, + const float* k_norm_weight, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int kv_num_heads, + const bool rope_3d, + const float rms_norm_eps) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + 
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + const int start_token_idx = cu_seqlens_q[bid]; + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid]; + if (write_seq_id == 0) return; + const int* block_table_now = nullptr; + + block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT out_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + T* qkv_out_now = qkv_out + start_token_idx * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + // q rope + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + const uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec[2 * i] = static_cast(tmp1); + out_vec[2 * i + 1] = static_cast(tmp2); + } + // qk norm + if (q_norm_weight) { + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + LoadOutScaleT q_norm_vec; + Load(&q_norm_weight[lane_id * VecSize], &q_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + out_vec[i] = static_cast(static_cast(out_vec[i]) * + row_inv_var * q_norm_vec[i]); + } + } + Store(out_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * kv_num_heads) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + kv_num_heads) { + constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += 
num_token_each_time) { + Store(pad_cache_vec, + &key_cache[tgt_idx + block_i * HeadDim]); + } + } else { + const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = + (block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]); + } + } + __syncwarp(); + } + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT out_vec1, out_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; - const int bias_idx_left = hi * head_size + h_bias; - const int bias_idx_right = bias_idx_left + half_head_size; + const T* qkv_now = quant_qkv + start_token_idx * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + T scale = T(1.0f); + const int k_head_idx = head_idx - num_heads; + const int v_head_idx = head_idx - num_heads - kv_num_heads; + if (head_idx < num_heads + kv_num_heads) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + const uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); + } - Load(&quant_qkv[ori_idx_left], &left_vec); - Load(&quant_qkv[ori_idx_right], &right_vec); - if (qkv_biases) { - Load(&qkv_biases[bias_idx_left], &left_bias_vec); - Load(&qkv_biases[bias_idx_right], &right_bias_vec); + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + if (head_idx < num_heads + kv_num_heads) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec1[0] = static_cast(tmp1); + out_vec1[1] = static_cast(tmp2); + } else { + out_vec1[0] = src_vec1[0]; + out_vec1[1] = src_vec1[1]; } - Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); - Load(&qkv_out_scales[bias_idx_right], &right_out_scale_vec); + // rope + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + if (head_idx < num_heads + kv_num_heads) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + out_vec2[0] = static_cast(tmp1); + out_vec2[1] = static_cast(tmp2); + } else { + out_vec2[0] = src_vec2[0]; + out_vec2[1] = src_vec2[1]; + } + if (k_norm_weight) { + if (head_idx < num_heads + kv_num_heads) { + LoadOutScaleT k_norm_vec1, k_norm_vec2; + Load(&k_norm_weight[head_bias], &k_norm_vec1); + Load(&k_norm_weight[head_bias + 8], + &k_norm_vec2); + // qk norm + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = 
max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); - if (hi < num_heads + kv_num_heads) { - // q k rope - const uint32_t emb_idx = write_seq_id * head_size + h_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + out_vec1[i] = static_cast(static_cast(out_vec1[i]) * + row_inv_var * k_norm_vec1[i]); + out_vec2[i] = static_cast(static_cast(out_vec2[i]) * + row_inv_var * k_norm_vec2[i]); + } + } } + if constexpr (IsDynamic) { + // reduce max, 1 head per warp + T local_max = -INFINITY; #pragma unroll - for (int i = 0; i < VecSize; i++) { - // dequant + add_bias + rope - float input_left = static_cast(left_vec[i]); - float input_right = static_cast(right_vec[i]); - input_left = qkv_biases ? input_left * left_out_scale_vec[i] + - static_cast(left_bias_vec[i]) - : input_left * left_out_scale_vec[i]; - input_right = qkv_biases ? input_right * right_out_scale_vec[i] + - static_cast(right_bias_vec[i]) - : input_right * right_out_scale_vec[i]; - if (hi < num_heads + kv_num_heads) { - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); - right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - } else { - left_bias_vec[i] = static_cast(input_left); - right_bias_vec[i] = static_cast(input_right); + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + local_max = __hmax(local_max, __habs(out_vec1[i])); + local_max = __hmax(local_max, __habs(out_vec2[i])); + } +#pragma unroll + for (int m_offset = 16; m_offset > 0; m_offset /= 2) { + local_max = + __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); + } + scale = __hdiv(448, local_max); + + int cache_offset; + if (head_idx < num_heads) { + cache_offset = 0; + } else if (head_idx < num_heads + 2 * kv_num_heads) { + cache_offset = block_idx * kv_num_heads * block_size + + 
(head_idx - num_heads) % kv_num_heads * block_size + + block_offset; + } + T* cache_k_scale_now = cache_k_scale + cache_offset; + T* cache_v_scale_now = cache_v_scale + cache_offset; + if (lane_id == 0) { + if (head_idx < num_heads + kv_num_heads) { + cache_k_scale_now[0] = __hdiv(1, scale); + } else { + cache_v_scale_now[0] = __hdiv(1, scale); + } } - } - if (hi < num_heads) { - // write q - Store(left_bias_vec, &qkv_out[ori_idx_left]); - Store(right_bias_vec, &qkv_out[ori_idx_right]); } else { - // quant + write k/v - const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads; - const uint32_t tgt_idx_left = - block_idx * kv_num_heads * block_size * head_size + - kv_head_idx * block_size * head_size + block_offset * head_size + - h_bias; - const uint32_t tgt_idx_right = tgt_idx_left + half_head_size; - if (hi < num_heads + kv_num_heads) { - Store(left_bias_vec, &key_cache[tgt_idx_left]); - Store(right_bias_vec, &key_cache[tgt_idx_right]); + if (head_idx < num_heads + kv_num_heads) { + scale = __ldg(&cache_k_scale[kv_head_idx]); } else { - Store(left_bias_vec, &value_cache[tgt_idx_left]); - Store(right_bias_vec, &value_cache[tgt_idx_right]); + scale = __ldg(&cache_v_scale[kv_head_idx]); } } + +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + cache_vec[i] = QuantToC8( + scale, out_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8( + scale, out_vec2[i], max_bound, min_bound); + } + if (head_idx < num_heads + kv_num_heads) { + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * kv_num_heads * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + const uint32_t base_tgt_cache_idx = + block_idx * kv_num_heads * HeadDim * block_size + + kv_head_idx * HeadDim * block_size 
+ + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template +template __global__ void append_decode_cache_int8_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -497,8 +1518,7 @@ __global__ void append_decode_cache_int8_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -512,7 +1532,8 @@ __global__ void append_decode_cache_int8_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -532,46 +1553,24 @@ __global__ void append_decode_cache_int8_rope_kernel( block_table_now = block_tables + bid * max_blocks_per_seq; const int 
block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q - using LoadT = AlignedVector; - using LoadBiasT = AlignedVector; - using LoadOutScaleT = AlignedVector; - constexpr int HalfVecSize = VecSize / 2; - using LoadEmbT = AlignedVector; - - LoadT src_vec; - LoadBiasT out_vec; - LoadEmbT cos_emb_vec; - LoadEmbT sin_emb_vec; - const T* qkv_now = quant_qkv + start_token_idx * hidden_size; - T* qkv_out_now = qkv_out + start_token_idx * hidden_size; -#pragma unroll - for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; - head_bias += 32 * VecSize) { - const int bias_idx = head_idx * HeadDim + head_bias; - Load(&qkv_now[bias_idx], &src_vec); - - // q rope - const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); -#pragma unroll - for (int i = 0; i < HalfVecSize; i++) { - // dequant + add_bias + rope - float input_left = static_cast(src_vec[2 * i]); - float input_right = static_cast(src_vec[2 * i + 1]); + const T* qkv_now = + quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim; + T* qkv_out_now = + qkv_out + start_token_idx * hidden_size + head_idx * HeadDim; + + uint32_t emb_offset = write_seq_id * half_head_size; + emb_offset += rope_3d ? 
bid * max_seq_len * HeadDim : 0; + apply_rope(qkv_now, + cos_emb + emb_offset, + sin_emb + emb_offset, + qkv_out_now, + lane_id); - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - out_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); - out_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - } - Store(out_vec, &qkv_out_now[bias_idx]); - } } else if (head_idx < num_heads + 2 * kv_num_heads) { // k constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 @@ -629,14 +1628,18 @@ __global__ void append_decode_cache_int8_rope_kernel( T scale = T(1.0f); const int k_head_idx = head_idx - num_heads; const int v_head_idx = head_idx - num_heads - kv_num_heads; - const T *cache_k_scale_cur = cache_k_scale + k_head_idx * HeadDim + head_bias; - const T *cache_v_scale_cur = cache_v_scale + v_head_idx * HeadDim + head_bias; + const T* cache_k_scale_cur = + cache_k_scale + k_head_idx * HeadDim + head_bias; + const T* cache_v_scale_cur = + cache_v_scale + v_head_idx * HeadDim + head_bias; if (head_idx < num_heads + kv_num_heads) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); if constexpr (!is_scale_channel_wise) { scale = __ldg(&cache_k_scale[kv_head_idx]); } @@ -653,9 +1656,11 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; out_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec1[0] = src_vec1[0]; out_vec1[1] = src_vec1[1]; @@ -665,9 +1670,13 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; out_vec1[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[0])); + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * + float(cache_k_scale_cur[0])); out_vec1[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[1])); + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * + float(cache_k_scale_cur[1])); } else { out_vec1[0] = static_cast(input_left * float(cache_v_scale_cur[0])); out_vec1[1] = static_cast(input_right * float(cache_v_scale_cur[1])); @@ -681,9 +1690,11 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; out_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, 
cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec2[0] = src_vec2[0]; out_vec2[1] = src_vec2[1]; @@ -693,9 +1704,13 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; out_vec2[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[8])); + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * + float(cache_k_scale_cur[8])); out_vec2[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[9])); + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * + float(cache_k_scale_cur[9])); } else { out_vec2[0] = static_cast(input_left * float(cache_v_scale_cur[8])); out_vec2[1] = static_cast(input_right * float(cache_v_scale_cur[9])); @@ -703,8 +1718,10 @@ __global__ void append_decode_cache_int8_rope_kernel( } #pragma unroll for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { - cache_vec[i] = QuantToC8(scale, out_vec1[i], max_bound, min_bound); - cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8(scale, out_vec2[i], max_bound, min_bound); + cache_vec[i] = QuantToC8( + scale, out_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8( + scale, out_vec2[i], max_bound, min_bound); } if (head_idx < num_heads + kv_num_heads) { const int start_block_16 = @@ -733,10 +1750,19 @@ __global__ void append_decode_cache_int8_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template -__global__ void append_decode_cache_int8_rope_kernel( +template +__global__ void int_append_decode_cache_int8_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -744,8 +1770,7 @@ __global__ void append_decode_cache_int8_rope_kernel( uint8_t* 
__restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -753,8 +1778,8 @@ __global__ void append_decode_cache_int8_rope_kernel( const float* __restrict__ sin_emb, const float* __restrict__ qkv_out_scales, // [num_head + 2 * // kv_num_heads, dim_head] - const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, - // dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, + // dim_head] const T* __restrict__ cache_k_scales, const T* __restrict__ cache_v_scales, const int max_seq_len, @@ -763,7 +1788,8 @@ __global__ void append_decode_cache_int8_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -784,7 +1810,9 @@ __global__ void append_decode_cache_int8_rope_kernel( block_table_now = block_tables + bid * max_blocks_per_seq; const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -813,9 +1841,11 @@ __global__ void append_decode_cache_int8_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { @@ -831,9 +1861,11 @@ __global__ void append_decode_cache_int8_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -904,14 +1936,18 @@ __global__ void append_decode_cache_int8_rope_kernel( T scale = T(1.0f); const int k_head_idx = head_idx - num_heads; const int v_head_idx = head_idx - num_heads - kv_num_heads; - const T *cache_k_scale_cur = cache_k_scales + k_head_idx * HeadDim + head_bias; - const T *cache_v_scale_cur = cache_v_scales + v_head_idx * HeadDim + head_bias; + const T* cache_k_scale_cur = + cache_k_scales + k_head_idx * HeadDim + head_bias; + const T* cache_v_scale_cur = + cache_v_scales + v_head_idx * HeadDim + head_bias; if (head_idx < num_heads + kv_num_heads) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); if constexpr (!is_scale_channel_wise) { scale = __ldg(&cache_k_scales[kv_head_idx]); } @@ -934,9 +1970,11 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -946,12 +1984,17 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[0])); + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * + float(cache_k_scale_cur[0])); bias_vec1[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[1])); + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * + float(cache_k_scale_cur[1])); } else { bias_vec1[0] = static_cast(input_left * float(cache_v_scale_cur[0])); - bias_vec1[1] = static_cast(input_right * float(cache_v_scale_cur[1])); + bias_vec1[1] = + static_cast(input_right * float(cache_v_scale_cur[1])); } } @@ -968,24 +2011,31 @@ __global__ void append_decode_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, 
sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); } } else { - if (head_idx < num_heads + kv_num_heads) { + if (head_idx < num_heads + kv_num_heads) { float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[8])); + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * + float(cache_k_scale_cur[8])); bias_vec2[1] = - static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[9])); + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * + float(cache_k_scale_cur[9])); } else { bias_vec2[0] = static_cast(input_left * float(cache_v_scale_cur[8])); - bias_vec2[1] = static_cast(input_right * float(cache_v_scale_cur[9])); + bias_vec2[1] = + static_cast(input_right * float(cache_v_scale_cur[9])); } } #pragma unroll @@ -1034,10 +2084,16 @@ __global__ void append_decode_cache_int8_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } - -template +template __global__ void append_decode_cache_int8_neox_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -1046,8 +2102,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ cu_seqlens_q, const int* __restrict__ 
seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1061,7 +2116,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -1082,7 +2138,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( block_table_now = block_tables + bid * max_blocks_per_seq; const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -1109,8 +2167,10 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -1120,9 +2180,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -1191,10 +2253,12 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( T scale; const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); scale = __ldg(&cache_k_scales[kv_head_idx]); #pragma unroll for (int i = 0; i < HALF_K_VEC_SIZE; i++) { @@ -1204,18 +2268,22 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( float cos_tmp = cos_emb_vec1[i]; float sin_tmp = sin_emb_vec1[i]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); float quant_value1 = static_cast(scale * left_bias_vec1[i]); float quant_value2 = static_cast(scale * left_bias_vec2[i]); @@ -1334,10 +2402,17 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template -__global__ void append_decode_cache_int8_neox_rope_kernel( +template +__global__ void int_append_decode_cache_int8_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -1345,8 +2420,8 @@ __global__ void 
append_decode_cache_int8_neox_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1354,8 +2429,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( const float* __restrict__ sin_emb, const float* __restrict__ qkv_out_scales, // [num_head + 2 * // kv_num_heads, dim_head] - const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, - // dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, + // dim_head] const T* __restrict__ cache_k_scales, const T* __restrict__ cache_v_scales, const int max_seq_len, @@ -1364,7 +2439,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -1386,7 +2462,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( block_table_now = block_tables + bid * max_blocks_per_seq; const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -1424,8 +2502,10 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - 
Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -1441,9 +2521,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -1533,10 +2615,12 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( T scale; const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); scale = __ldg(&cache_k_scales[kv_head_idx]); #pragma unroll for (int i = 0; i < HALF_K_VEC_SIZE; i++) { @@ -1552,9 +2636,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( float cos_tmp = cos_emb_vec1[i]; float sin_tmp = sin_emb_vec1[i]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); @@ -1567,9 +2653,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); float quant_value1 = static_cast(scale * left_bias_vec1[i]); float quant_value2 = static_cast(scale * left_bias_vec2[i]); @@ -1726,10 +2814,16 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } - -template +template __global__ void append_decode_cache_int4_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -1738,8 +2832,8 @@ __global__ void append_decode_cache_int4_rope_kernel( uint8_t* 
__restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -1755,7 +2849,8 @@ __global__ void append_decode_cache_int4_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -1776,46 +2871,24 @@ __global__ void append_decode_cache_int4_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q - using LoadT = AlignedVector; - using LoadBiasT = AlignedVector; - using LoadOutScaleT = AlignedVector; - constexpr int HalfVecSize = VecSize / 2; - using LoadEmbT = AlignedVector; - - LoadT src_vec; - LoadBiasT out_vec; - LoadEmbT cos_emb_vec; - LoadEmbT sin_emb_vec; - const T* qkv_now = quant_qkv + start_token_idx * hidden_size; - T* qkv_out_now = qkv_out + start_token_idx * hidden_size; -#pragma unroll - for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; - head_bias += 32 * VecSize) { - const int bias_idx = head_idx * HeadDim + head_bias; - Load(&qkv_now[bias_idx], &src_vec); - - // q rope - const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); -#pragma unroll - for (int i = 0; i < HalfVecSize; 
i++) { - // dequant + add_bias + rope - float input_left = static_cast(src_vec[2 * i]); - float input_right = static_cast(src_vec[2 * i + 1]); + const T* qkv_now = + quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim; + T* qkv_out_now = + qkv_out + start_token_idx * hidden_size + head_idx * HeadDim; + + uint32_t emb_offset = write_seq_id * half_head_size; + emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0; + apply_rope(qkv_now, + cos_emb + emb_offset, + sin_emb + emb_offset, + qkv_out_now, + lane_id); - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - out_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); - out_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - } - Store(out_vec, &qkv_out_now[bias_idx]); - } } else if (head_idx < num_heads + 2 * kv_num_heads) { // k constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 @@ -1874,10 +2947,12 @@ __global__ void append_decode_cache_int4_rope_kernel( Load(&qkv_now[bias_idx + 8], &src_vec2); if (head_idx < num_heads + kv_num_heads) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); Load(&cache_k_scale[cache_idx], &scale_vec1); Load(&cache_k_scale[cache_idx + 8], &scale_vec2); Load(&cache_k_zero_points[cache_idx], &zp_vec1); @@ -1895,9 +2970,11 @@ __global__ void append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; out_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec1[0] = src_vec1[0]; out_vec1[1] = src_vec1[1]; @@ -1909,9 +2986,11 @@ __global__ void append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; out_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); out_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { out_vec2[0] = src_vec2[0]; out_vec2[1] = src_vec2[1]; @@ -2022,10 +3101,17 @@ __global__ void append_decode_cache_int4_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template -__global__ void append_decode_cache_int4_rope_kernel( +template +__global__ void int_append_decode_cache_int4_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -2033,8 +3119,8 @@ __global__ 
void append_decode_cache_int4_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2042,8 +3128,8 @@ __global__ void append_decode_cache_int4_rope_kernel( const float* __restrict__ sin_emb, const float* __restrict__ qkv_out_scales, // [num_head + 2 * // kv_num_heads, dim_head] - const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, - // dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, + // dim_head] const T* __restrict__ cache_k_scale, const T* __restrict__ cache_v_scale, const T* __restrict__ cache_k_zero_points, @@ -2054,7 +3140,8 @@ __global__ void append_decode_cache_int4_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -2076,7 +3163,9 @@ __global__ void append_decode_cache_int4_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -2103,8 +3192,10 @@ __global__ void append_decode_cache_int4_rope_kernel( Load(&qkv_out_scales[bias_idx], &out_scale_vec); // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - 
Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { // dequant + add_bias + rope @@ -2119,9 +3210,11 @@ __global__ void append_decode_cache_int4_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -2191,10 +3284,12 @@ __global__ void append_decode_cache_int4_rope_kernel( &out_scale_vec2); if (head_idx < num_heads + kv_num_heads) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); Load(&cache_k_scale[cache_idx], &scale_vec1); Load(&cache_k_scale[cache_idx + 8], &scale_vec2); Load(&cache_k_zero_points[cache_idx], &zp_vec1); @@ -2218,9 +3313,11 @@ __global__ void append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -2238,9 +3335,11 @@ __global__ void append_decode_cache_int4_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); @@ -2350,9 +3449,16 @@ __global__ void append_decode_cache_int4_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template +template __global__ void append_decode_cache_int4_neox_rope_kernel( const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] @@ -2361,8 +3467,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( uint8_t* 
__restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2378,7 +3484,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -2399,7 +3506,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -2425,8 +3534,10 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { // dequant + add_bias + rope @@ -2436,9 +3547,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_out_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_out_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_out_vec, &qkv_out_now[bias_idx_left]); Store(right_out_vec, &qkv_out_now[bias_idx_right]); @@ -2507,10 +3620,12 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( Load(&qkv_now[right_bias_idx], &right_src_vec1); Load(&qkv_now[right_bias_idx + 8], &right_src_vec2); const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); Load(&cache_k_scale[left_cache_idx], &left_scale_vec1); Load(&cache_k_scale[left_cache_idx + 8], @@ -2534,19 +3649,22 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; left_out_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_out_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_out_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_out_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); // quant + write k } LoadKVResT left_cache_vec, right_cache_vec; @@ -2720,10 +3838,17 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template -__global__ void append_decode_cache_int4_neox_rope_kernel( +template +__global__ void int_append_decode_cache_int4_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -2731,8 +3856,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( uint8_t* 
__restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -2740,8 +3865,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const float* __restrict__ sin_emb, const float* __restrict__ qkv_out_scales, // [num_head + 2 * // kv_num_heads, dim_head] - const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, - // dim_head] + const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads, + // dim_head] const T* __restrict__ cache_k_scale, const T* __restrict__ cache_v_scale, const T* __restrict__ cache_k_zero_points, @@ -2752,7 +3877,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int kv_num_heads) { + const int kv_num_heads, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -2774,7 +3900,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -2810,8 +3938,10 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( &right_out_scale_vec); // q rope const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { // dequant + add_bias + rope @@ -2826,9 +3956,11 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -2920,10 +4052,12 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( &right_out_scale_vec2); const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); Load(&cache_k_scale[left_cache_idx], &left_scale_vec1); Load(&cache_k_scale[left_cache_idx + 8], @@ -2953,19 +4087,22 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); // quant + write k } @@ -3169,4 +4306,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index fe72d120a4a..963ccfa23d9 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -15,13 +15,78 @@ #include "decoder_write_cache_with_rope_kernel.h" #include 
"utils.cuh" -template +template +void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, + T* key_cache, + T* value_cache, + T* qkv_out, + const int* block_tables, + const int* cu_seqlens_q, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int dim_head, + const int block_size, + const int bsz, + const cudaStream_t& stream, + const bool use_neox_style, + const bool rope_3d, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps) { + const uint32_t elem_nums = + use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2 + : bsz * (num_heads + 2 * kv_num_heads) * dim_head; + constexpr int HEAD_DIM = 128; + + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1); + launchWithPdlWhenEnabled( + append_decode_cache_T_rope_qk_norm_kernel, + grid_size, + block_dim, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d, + q_norm_weight, + k_norm_weight, + rms_norm_eps); +} + +template void append_decode_cache_rope(const QKV_TYPE* qkv, T* key_cache, T* value_cache, T* qkv_out, const int* block_tables, - const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -34,6 +99,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, const int num_heads, const int kv_num_heads, const int dim_head, + const int rotary_dim, const int block_size, const int bsz, const 
cudaStream_t& stream, @@ -50,106 +116,160 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, GetNumBlocks<128>(pack_num, &grid_size); if (use_neox_style) { if (qkv_out_scales) { - append_decode_cache_T_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads); + launchWithPdlWhenEnabled( + append_decode_cache_T_quant_neox_rope_kernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } else { - append_decode_cache_T_neox_rope_kernel - <<>>(reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads); + if (rotary_dim < dim_head) { + auto* kernelFn = + append_decode_cache_T_neox_partial_rope_kernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + rotary_dim, + block_size, + elem_nums, + kv_num_heads, + rope_3d); + } else { + auto* kernelFn = + append_decode_cache_T_neox_rope_kernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + 
seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); + } } } else { if (qkv_out_scales) { - append_decode_cache_T_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + append_decode_cache_T_quant_rope_kernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } else { - append_decode_cache_T_rope_kernel - <<>>(reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads, - rope_3d); + auto* kernelFn = + append_decode_cache_T_rope_kernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } } } -template +template void append_decode_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ 
-175,114 +295,149 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, dim3 grids(bsz, all_warps / num_warps); if (use_neox_style) { if (qkv_out_scales) { - append_decode_cache_int8_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + int_append_decode_cache_int8_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int8_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + append_decode_cache_int8_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } } else { if (qkv_out_scales) { - append_decode_cache_int8_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - 
seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + int_append_decode_cache_int8_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int8_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + append_decode_cache_int8_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } } } -template +template void append_decode_cache_int4_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, T* qkv_out, const int* block_tables, - const int* batch_id_per_token, const int* cu_seqlens_q, const int* seq_lens, const int* seq_lens_encoder, @@ -310,121 +465,144 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, dim3 grids(bsz, all_warps / num_warps); if (use_neox_style) { if (qkv_out_scales) { - append_decode_cache_int4_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - 
batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + int_append_decode_cache_int4_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int4_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + append_decode_cache_int4_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } } else { if (qkv_out_scales) { - append_decode_cache_int4_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 
7.0f, - -8.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + int_append_decode_cache_int4_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int4_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads); + launchWithPdlWhenEnabled( + append_decode_cache_int4_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } } } -template +template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -441,7 +619,10 @@ void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out) { + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps) { typedef 
cascade_attn_type_traits traits_; typedef cascade_attn_type_traits qkt_nv_type_; typedef typename traits_::type DataType_; @@ -458,85 +639,265 @@ void DecoderWriteCacheWithRoPEKernel( const float* cos_emb = rotary_embs ? rotary_embs.get().data() : nullptr; const float* sin_emb; + int rotary_dim = dim_head; if (rotary_embs) { sin_emb = use_neox_rotary_style ? rotary_embs.get().data() + max_seq_len * dim_head : rotary_embs.get().data() + max_seq_len * dim_head / 2; - } - if (cache_quant_type_str == "none") { - append_decode_cache_rope( - reinterpret_cast(qkv_ptr), - reinterpret_cast(key_cache_out->data()), - reinterpret_cast(value_cache_out->data()), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - stream, - use_neox_rotary_style, - rope_3d); - } else if (cache_quant_type_str == "cache_int8") { - bool is_scale_channel_wise = false; - if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) { - is_scale_channel_wise = true; + rotary_dim = + rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2; + if (rotary_dim < dim_head) { + if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || + k_norm_weight || cache_quant_type_str != "none") { + PADDLE_THROW(phi::errors::Fatal( + "partial_rotary_factor < 1.0 only supports neox_rotary_style=True, " + "qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and " + "cache_quant_type_str is 'none'.")); + } + sin_emb = rotary_embs.get().data() + max_seq_len * rotary_dim / 2; } - if (is_scale_channel_wise) { - append_decode_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - 
value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - stream, - use_neox_rotary_style, - rope_3d); + } + + if (q_norm_weight && k_norm_weight) { + if (cache_quant_type_str == "none") { + append_decode_cache_rope_qk_norm( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d, + q_norm_weight ? q_norm_weight.get().data() : nullptr, + k_norm_weight ? 
k_norm_weight.get().data() : nullptr, + rms_norm_eps); + } else if (cache_quant_type_str == "block_wise_fp8") { + constexpr int num_warps = 4; + const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / + num_warps * num_warps; + dim3 grids(bsz, all_warps / num_warps); + launchWithPdlWhenEnabled( + append_decode_cache_int8_rope_qk_norm_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + (cache_v_scale.get().data()))), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); + } else if ((cache_quant_type_str == "cache_fp8")) { + constexpr int num_warps = 4; + const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / + num_warps * num_warps; + dim3 grids(bsz, all_warps / num_warps); + launchWithPdlWhenEnabled( + append_decode_cache_int8_rope_qk_norm_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + (cache_v_scale.get().data()))), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); } else { - append_decode_cache_int8_rope( + PD_THROW( + "append_decode_cache_rope_qk_norm just supports cache_quant_type " + 
"none/block_wise_fp8/cache_fp8"); + } + } else { + if (cache_quant_type_str == "none") { + append_decode_cache_rope( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + rotary_dim, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_int8") { + bool is_scale_channel_wise = false; + if (cache_k_scale && + cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) { + is_scale_channel_wise = true; + } + if (is_scale_channel_wise) { + append_decode_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d); + } else { + append_decode_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? 
qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + stream, + use_neox_rotary_style, + rope_3d); + } + } else if (cache_quant_type_str == "cache_fp8") { + append_decode_cache_int8_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), reinterpret_cast(qkv_out->data()), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -544,8 +905,8 @@ void DecoderWriteCacheWithRoPEKernel( sin_emb, qkv_out_scales ? qkv_out_scales.get().data() : nullptr, qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, + const_cast(qkv_biases.get().data())) + : nullptr, cache_k_scale ? 
reinterpret_cast( const_cast(cache_k_scale.get().data())) : nullptr, @@ -562,15 +923,95 @@ void DecoderWriteCacheWithRoPEKernel( stream, use_neox_rotary_style, rope_3d); - } - } else if (cache_quant_type_str == "cache_fp8") { - append_decode_cache_int8_rope( + } else if (cache_quant_type_str == "block_wise_fp8") { + constexpr int num_warps = 4; + const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / + num_warps * num_warps; + dim3 grids(bsz, all_warps / num_warps); + if (use_neox_rotary_style) { + launchWithPdlWhenEnabled( + append_decode_cache_T_int8_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + (cache_v_scale.get().data()))), + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); + } else { + launchWithPdlWhenEnabled( + append_decode_cache_int8_rope_qk_norm_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + (cache_v_scale.get().data()))), + nullptr, + nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); + } + } else if (cache_quant_type_str == "cache_int4_zp") { + append_decode_cache_int4_rope( reinterpret_cast(qkv_ptr), key_cache_out->data(), value_cache_out->data(), - reinterpret_cast(qkv_out->data()), + 
reinterpret_cast(const_cast(qkv_out->data())), block_tables.data(), - batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), @@ -578,14 +1019,20 @@ void DecoderWriteCacheWithRoPEKernel( sin_emb, qkv_out_scales ? qkv_out_scales.get().data() : nullptr, qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, + const_cast(qkv_biases.get().data())) + : nullptr, cache_k_scale ? reinterpret_cast( const_cast(cache_k_scale.get().data())) : nullptr, cache_v_scale ? reinterpret_cast( const_cast(cache_v_scale.get().data())) : nullptr, + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, max_seq_len, max_blocks_per_seq, num_heads, @@ -596,53 +1043,14 @@ void DecoderWriteCacheWithRoPEKernel( stream, use_neox_rotary_style, rope_3d); - } else if (cache_quant_type_str == "cache_int4_zp") { - append_decode_cache_int4_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(const_cast(qkv_out->data())), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - cache_v_zp ? 
reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - stream, - use_neox_rotary_style, - rope_3d); - } else { - PD_THROW( - "cache_quant_type_str should be one of [none, cache_int8, cache_fp8 " - "cache_int4_zp]"); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, cache_fp8 " + "cache_int4_zp]"); + } } } - template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& @@ -650,7 +1058,6 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -667,7 +1074,10 @@ template void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void DecoderWriteCacheWithRoPEKernel( @@ -677,7 +1087,6 @@ DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -694,7 +1103,10 @@ DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, @@ -703,7 +1115,6 
@@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -720,7 +1131,10 @@ template void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, @@ -729,7 +1143,6 @@ template void DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -746,4 +1159,121 @@ template void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); + +template void DecoderWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // kv_num_heads, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + 
const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); + +template void +DecoderWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // kv_num_heads, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); + +template void DecoderWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // kv_num_heads, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const 
paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); + +template void +DecoderWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // kv_num_heads, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h index b3fe75b2cd1..2acb4f8293b 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.h @@ -15,7 +15,7 @@ #include "decoder_write_cache_with_rope_impl.cuh" -template +template void DecoderWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& @@ -23,7 +23,6 @@ void 
DecoderWriteCacheWithRoPEKernel( // kv_num_heads, head_dim] if GQA) const paddle::Tensor& seq_lens, const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const paddle::optional& rotary_embs, @@ -40,4 +39,7 @@ void DecoderWriteCacheWithRoPEKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu new file mode 100644 index 00000000000..dfb3e285093 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu @@ -0,0 +1,411 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * DeepSeek Kv3.2 (DsKv3.2) Attention WriteCache Implementation + * + * This file implements writecache operations for DeepSeek MLA (Multi-head + * Latent Attention) with FP8 quantization support, migrated from vLLM. + * + * Key features: + * 1. DS MLA FP8 cache format (656 bytes per token): + * - 512 bytes: quantized NoPE part (fp8_e4m3) + * - 16 bytes: scale factors (4 x float32) + * - 128 bytes: RoPE part (64 x bf16, unquantized) + * + * 2. 
Standard MLA cache format (kv_lora_rank + pe_dim elements) + * + * 3. Indexer K quantization and cache operations + */ + +#include "ds_mla_cache_kernel.cuh" +#include "helper.h" +#include "remote_cache_kv_ipc.h" + +//============================================================================== +// DS MLA FP8 WriteCache Implementation +//============================================================================== + +/** + * Prefill stage: Write KV cache with DS MLA FP8 format + */ +template +std::vector DSMLAWriteCacheFP8( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& slot_mapping, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto num_tokens = slot_mapping.dims()[0]; + auto kv_lora_rank = 512; // DS MLA uses 512 + auto pe_dim = 64; // DS MLA uses 64 + auto block_size = meta_data.block_size; + const int entry_size = 656; + + // Launch kernel with 96 threads (64 for NoPE, 32 for RoPE) + dim3 grid(num_tokens); + dim3 block(96); + + const auto& kv_cache_dims = kv_cache->dims(); + int block_stride = kv_cache->strides()[0]; + int entry_stride = entry_size; + int kv_c_stride = kv_nope.strides()[0]; + int k_pe_stride = kv_pe.strides()[0]; + + ds_mla::concat_and_cache_ds_mla_kernel<<>>( + reinterpret_cast(const_cast(kv_nope.data())), + reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast(kv_cache->data()), + slot_mapping.data(), + block_stride, + entry_stride, + kv_c_stride, + k_pe_stride, + kv_lora_rank, + pe_dim, + block_size); + return {}; +} + +//============================================================================== +// Standard MLA WriteCache Implementation +//============================================================================== + +/** + * Prefill stage: Write KV cache with standard MLA format + */ +template +std::vector 
PrefillDSMLAWriteCache( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& slot_mapping, + const float* scale, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto num_tokens = slot_mapping.dims()[0]; + auto kv_lora_rank = meta_data.head_dims_v; + auto pe_dim = meta_data.head_dims - meta_data.head_dims_v; + auto block_size = meta_data.block_size; + + const auto& kv_cache_dims = kv_cache->dims(); + int block_stride = kv_cache->strides()[0]; + int entry_stride = kv_cache->strides()[1]; + int kv_c_stride = kv_nope.strides()[0]; + int k_pe_stride = kv_pe.strides()[0]; + + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + + ds_mla::concat_and_cache_mla_kernel + <<>>( + reinterpret_cast( + const_cast(kv_nope.data())), + reinterpret_cast( + const_cast(kv_pe.data())), + reinterpret_cast(kv_cache->data()), + slot_mapping.data(), + block_stride, + entry_stride, + kv_c_stride, + k_pe_stride, + kv_lora_rank, + pe_dim, + block_size, + scale); + + return {}; +} + +//============================================================================== +// Indexer K Quantization and Cache Operations +//============================================================================== + +/** + * Quantize K tensor to FP8 and write to cache + */ +template +std::vector IndexerKQuantAndCache( + const paddle::Tensor& k, + const paddle::Tensor& slot_mapping, + const int head_dim, + const int quant_block_size, + const int cache_block_size, + const int cache_stride, + const bool use_ue8m0, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + int num_tokens = k.dims()[0]; + + constexpr int vec_size = 4; + dim3 grid(num_tokens, + (head_dim + quant_block_size * 
vec_size - 1) / + (quant_block_size * vec_size)); + dim3 block(32, vec_size); + + ds_mla::indexer_k_quant_and_cache_kernel + <<>>( + reinterpret_cast(const_cast(k.data())), + reinterpret_cast(kv_cache->data()), + slot_mapping.data(), + head_dim, + quant_block_size, + cache_block_size, + cache_stride, + use_ue8m0); + + return {}; +} + +/** + * Gather K from quantized cache + */ +void CpGatherIndexerKQuantCache(const paddle::Tensor& kv_cache, + paddle::Tensor& dst_k, + paddle::Tensor& dst_scale, + const paddle::Tensor& block_table, + const paddle::Tensor& cu_seq_lens, + cudaStream_t& stream) { + int batch_size = block_table.dims()[0]; + int num_tokens = dst_k.dims()[0]; + int head_dim = dst_k.dims()[1]; + int quant_block_size = head_dim * 4 / dst_scale.dims()[1]; + + constexpr int vec_size = 16; + +#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \ + ds_mla::cp_gather_indexer_k_quant_cache_kernel \ + <<>>(reinterpret_cast(kv_cache.data()), \ + reinterpret_cast(dst_k.data()), \ + reinterpret_cast(dst_scale.data()), \ + block_table.data(), \ + cu_seq_lens.data(), \ + batch_size, \ + dst_k.strides()[0], \ + dst_k.dims()[1], \ + kv_cache.strides()[0], \ + kv_cache.strides()[1], \ + kv_cache.dims()[1], \ + block_table.dims()[1], \ + num_tokens, \ + quant_block_size); + + if (num_tokens < 32) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1); + } else if (num_tokens < 64) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2); + } else if (num_tokens < 128) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4); + } else if (num_tokens < 256) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8); + } else if (num_tokens < 512) { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16); + } else { + CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); + } + +#undef CALL_CP_GATHER_INDEXER_K_QUANT_CACHE +} + +//============================================================================== +// Kernel Entry Points +//============================================================================== + +/** + * DS MLA WriteCache entry 
point - supports both FP8 and standard formats + */ +std::vector DSMLAWriteCacheKernel( + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& kv_cache, + const paddle::Tensor& slot_mapping, + const paddle::optional& scale, + const std::string& cache_quant_type_str) { + cudaStream_t stream = kv_pe.stream(); + AppendAttnMetaData meta_data; + + const auto& kv_nope_dims = kv_nope.dims(); + const auto& kv_pe_dims = kv_pe.dims(); + const auto& kv_cache_dims = kv_cache.dims(); + + meta_data.kv_num_heads = kv_cache_dims[1]; + const auto nope_size = + kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; + meta_data.token_nums = kv_nope_dims[0]; + meta_data.head_dims = kv_cache_dims[3]; + meta_data.head_dims_v = nope_size; + meta_data.block_size = kv_cache_dims[2]; + + const float* scale_ptr = scale ? scale.get().data() : nullptr; + + if (cache_quant_type_str == "fp8_ds_mla") { + // FP8 DS MLA format + switch (kv_pe.dtype()) { + case paddle::DataType::BFLOAT16: { + return DSMLAWriteCacheFP8( + meta_data, + kv_nope, + kv_pe, + slot_mapping, + stream, + const_cast(&kv_cache)); + } + case paddle::DataType::FLOAT16: { + return DSMLAWriteCacheFP8( + meta_data, + kv_nope, + kv_pe, + slot_mapping, + stream, + const_cast(&kv_cache)); + } + default: + PD_THROW("Unsupported dtype for DS MLA FP8 cache"); + } + } else { + // Standard MLA format (auto/bf16/fp16) + switch (kv_pe.dtype()) { + case paddle::DataType::BFLOAT16: { + return PrefillDSMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + slot_mapping, + scale_ptr, + stream, + const_cast(&kv_cache)); + } + case paddle::DataType::FLOAT16: { + return PrefillDSMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + slot_mapping, + scale_ptr, + stream, + const_cast(&kv_cache)); + } + default: + PD_THROW("Unsupported dtype for DS MLA cache"); + } + } + return {}; +} + +/** + * Indexer K Quant and Cache entry point + */ +std::vector IndexerKQuantAndCacheKernel( + const paddle::Tensor& k, + const 
paddle::Tensor& kv_cache, + const paddle::Tensor& slot_mapping, + const int64_t quant_block_size, + const std::string& scale_fmt) { + cudaStream_t stream = k.stream(); + int num_tokens = k.dims()[0]; + int head_dim = k.dims()[1]; + int cache_block_size = kv_cache.dims()[1]; + int cache_stride = kv_cache.dims()[2]; + bool use_ue8m0 = scale_fmt == "ue8m0"; + + switch (k.dtype()) { + case paddle::DataType::BFLOAT16: { + return IndexerKQuantAndCache( + k, + slot_mapping, + head_dim, + quant_block_size, + cache_block_size, + cache_stride, + use_ue8m0, + stream, + const_cast(&kv_cache)); + } + case paddle::DataType::FLOAT16: { + return IndexerKQuantAndCache( + k, + slot_mapping, + head_dim, + quant_block_size, + cache_block_size, + cache_stride, + use_ue8m0, + stream, + const_cast(&kv_cache)); + } + default: + PD_THROW("Unsupported dtype for Indexer K Quant"); + } + return {}; +} + +/** + * Gather Indexer K from Quant Cache entry point + */ +std::vector CpGatherIndexerKQuantCacheKernel( + const paddle::Tensor& kv_cache, + paddle::Tensor& dst_k, + paddle::Tensor& dst_scale, + const paddle::Tensor& block_table, + const paddle::Tensor& cu_seq_lens) { + cudaStream_t stream = kv_cache.stream(); + CpGatherIndexerKQuantCache( + kv_cache, dst_k, dst_scale, block_table, cu_seq_lens, stream); + return {}; +} + +//============================================================================== +// Paddle Custom Operator Registration +//============================================================================== + +PD_BUILD_STATIC_OP(ds_mla_write_cache) + .Inputs({"kv_nope", + "kv_pe", + "kv_cache", + "slot_mapping", + paddle::Optional("scale")}) + .Outputs({"kv_cache_out"}) + .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) + .Attrs({"cache_quant_type_str: std::string"}) + .SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel)); + +PD_BUILD_STATIC_OP(indexer_k_quant_and_cache) + .Inputs({"k", "kv_cache", "slot_mapping"}) + .Outputs({"kv_cache_out"}) + .SetInplaceMap({{"kv_cache", 
"kv_cache_out"}}) + .Attrs({"quant_block_size: int64_t", "scale_fmt: std::string"}) + .SetKernelFn(PD_KERNEL(IndexerKQuantAndCacheKernel)); + +PD_BUILD_STATIC_OP(cp_gather_indexer_k_quant_cache) + .Inputs({"kv_cache", "dst_k", "dst_scale", "block_table", "cu_seq_lens"}) + .Outputs({"dst_k_out", "dst_scale_out"}) + .SetInplaceMap({{"dst_k", "dst_k_out"}, {"dst_scale", "dst_scale_out"}}) + .SetKernelFn(PD_KERNEL(CpGatherIndexerKQuantCacheKernel)); diff --git a/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cuh b/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cuh new file mode 100644 index 00000000000..b0f8e00ccad --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cuh @@ -0,0 +1,548 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include +#include +#include +#include "helper.h" +#include "mem_util.cuh" +#include "utils.cuh" + +// FP8 scale divisor constant (for SM90+) +#if defined(__gfx942__) +constexpr float kFp8ScaleDivisorDS = 224.f; +#else +constexpr float kFp8ScaleDivisorDS = 448.f; +#endif + +namespace ds_mla { + +/** + * FP8 scaled conversion utilities + */ +template +__device__ __forceinline__ OutT fp8_scaled_convert(InT src, float scale) { + return static_cast(static_cast(src) / scale); +} + +template <> +__device__ __forceinline__ uint8_t +fp8_scaled_convert(__nv_bfloat16 src, float scale) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + float val = __bfloat162float(src) / scale; + val = fminf(fmaxf(val, -448.0f), 448.0f); + __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val); + return *reinterpret_cast(&fp8_val); +#else + return 0; +#endif +} + +template <> +__device__ __forceinline__ uint8_t +fp8_scaled_convert(half src, float scale) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + float val = __half2float(src) / scale; + val = fminf(fmaxf(val, -448.0f), 448.0f); + __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val); + return *reinterpret_cast(&fp8_val); +#else + return 0; +#endif +} + +/** + * DeepSeek MLA FP8 Cache Write Kernel + * + * Cache format (fp8_ds_mla - 656 bytes per token): + * - First 512 bytes: quantized NoPE part (512 x fp8_e4m3) + * - Next 16 bytes: scale factors (4 x float32, one per 128 fp8 values) + * - Last 128 bytes: RoPE part (64 x bfloat16, not quantized) + * + * Thread organization: + * - First 2 warps (64 threads): handle NoPE FP8 quantization + * - Last 1 warp (32 threads): handle RoPE copy + * - Total: 96 threads per block + */ +template +__global__ void concat_and_cache_ds_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + uint8_t* __restrict__ kv_cache, // [num_blocks, block_size, + // 
cache_entry_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // stride per block in cache + const int entry_stride, // stride per token entry in cache + const int kv_c_stride, // stride for kv_c input + const int k_pe_stride, // stride for k_pe input + const int kv_lora_rank, // 512 for DS MLA + const int pe_dim, // 64 for DS MLA + const int block_size // number of tokens per cache block +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int64_t dst_idx_start = + block_idx * block_stride + block_offset * entry_stride; + + // Cast kv_cache to 16-bit for RoPE values + scalar_t* kv_cache_16bit = + reinterpret_cast(&kv_cache[dst_idx_start]); + + // The last warp handles the RoPE part + if (threadIdx.x >= 64) { + // Each thread handles two elements of RoPE + const int8_t pe_idx_start = (threadIdx.x - 64) * 2; + const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start; + + // Vectorized load of two 16-bit values, performed as one 32-bit load + const int32_t vals = *reinterpret_cast(&k_pe[src_idx]); + + // RoPE values start after the packed 8-bit NoPE values and the 32-bit + // scales Position: kv_lora_rank/2 (256 bytes in 16-bit units) + 8 (16 bytes + // of scales in 16-bit units) + const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start; + + // Vectorized store of two 16-bit values + *reinterpret_cast(&kv_cache_16bit[dst_idx]) = vals; + return; + } + + // The first two warps handle the NoPE part + const int8_t warp_idx = threadIdx.x >> 5; + const int8_t lane_idx = threadIdx.x & 31; + const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4); + + // Each thread handles 8 elements of NoPE + const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8); + + // 
Vectorized load of eight 16-bit values + const int4 vals_i4 = *reinterpret_cast(&kv_c[src_idx_start]); + const scalar_t* vals = reinterpret_cast(&vals_i4); + + // Max absolute value of this thread's elements + float max_abs = fmaxf(fmaxf(fmaxf(fabsf(static_cast(vals[0])), + fabsf(static_cast(vals[1]))), + fmaxf(fabsf(static_cast(vals[2])), + fabsf(static_cast(vals[3])))), + fmaxf(fmaxf(fabsf(static_cast(vals[4])), + fabsf(static_cast(vals[5]))), + fmaxf(fabsf(static_cast(vals[6])), + fabsf(static_cast(vals[7]))))); + + // Warp-level reduction to find the max absolute value in each half-warp +#pragma unroll + for (int offset = 8; offset > 0; offset /= 2) { + max_abs = fmaxf(max_abs, __shfl_xor_sync(0xFFFF, max_abs, offset, 16)); + } + + // Compute the scale for the tile + float tile_scale = fmaxf(max_abs / kFp8ScaleDivisorDS, FLT_MIN); + + // The first lane of each half-warp writes the scale to kv_cache + if ((lane_idx == 0) || (lane_idx == 16)) { + float* kv_cache_32bit = reinterpret_cast(&kv_cache[dst_idx_start]); + const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx; + kv_cache_32bit[dst_idx] = tile_scale; + } + + // Now all threads in the block scale and write their elements + const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8); + + uint8_t result[8]; +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = fp8_scaled_convert(vals[i], tile_scale); + } + + // Store as aligned 64-bit writes + *reinterpret_cast(&kv_cache[dst_idx_base]) = + *reinterpret_cast(result); +} + +/** + * Standard MLA Cache Write Kernel (non-FP8) + * + * For auto/bf16/fp16 cache types + */ +template +__global__ void concat_and_cache_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + + // pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, + const int entry_stride, + 
const int kv_c_stride, + const int k_pe_stride, + const int kv_lora_rank, + const int pe_dim, + const int block_size, + const float* scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + + if (slot_idx < 0) { + return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + // Copy kv_c (NoPE part) + for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) { + const int64_t src_idx = token_idx * kv_c_stride + i; + const int64_t dst_idx = + block_idx * block_stride + block_offset * entry_stride + i; + kv_cache[dst_idx] = static_cast(kv_c[src_idx]); + } + + // Copy k_pe (RoPE part) + for (int i = threadIdx.x; i < pe_dim; i += blockDim.x) { + const int64_t src_idx = token_idx * k_pe_stride + i; + const int64_t dst_idx = block_idx * block_stride + + block_offset * entry_stride + kv_lora_rank + i; + kv_cache[dst_idx] = static_cast(k_pe[src_idx]); + } +} + +/** + * Indexer K Quantization and Cache Kernel + * + * Quantizes K values to FP8 and stores them in cache with scale factors + * Cache layout: [quantized_k (head_dim bytes)] + [scales + * (head_dim/quant_block_size * 4 bytes)] + */ +template +__global__ void indexer_k_quant_and_cache_kernel( + const scalar_t* __restrict__ k, // [num_tokens, head_dim] + uint8_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int head_dim, + const int quant_block_size, // typically 128 + const int cache_block_size, + const int cache_stride, + const bool use_ue8m0 // use ue8m0 scale format +) { + constexpr int VEC_SIZE = 4; + const int64_t token_idx = blockIdx.x; + const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x) * + VEC_SIZE; + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / cache_block_size; + const int64_t block_offset = slot_idx % 
cache_block_size; + + if (slot_idx < 0 || head_dim_idx >= head_dim) { + return; + } + + // Load 4 values at once using float2 (for bf16/fp16) + float2 k_val = reinterpret_cast( + k)[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; + scalar_t* k_val_ptr = reinterpret_cast(&k_val); + + float amax = 0.0f; + for (int i = 0; i < VEC_SIZE; i++) { + amax = fmaxf(amax, fabsf(static_cast(k_val_ptr[i]))); + } + + // Warp reduction to find max across quant_block_size elements + for (int mask = 16; mask > 0; mask /= 2) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask)); + } + + float scale = fmaxf(amax, 1e-4f) / kFp8ScaleDivisorDS; + + if (use_ue8m0) { + scale = exp2f(ceilf(log2f(scale))); + } + + const int64_t dst_offset = block_idx * cache_block_size * cache_stride + + block_offset * head_dim + head_dim_idx; + + for (int i = 0; i < VEC_SIZE; i++) { + kv_cache[dst_offset + i] = + fp8_scaled_convert(k_val_ptr[i], scale); + } + + // First thread in warp writes the scale + if (threadIdx.x == 0) { + const int64_t dst_scale_idx = + block_idx * cache_block_size * cache_stride + + cache_block_size * head_dim + + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; + reinterpret_cast(kv_cache)[dst_scale_idx / 4] = scale; + } +} + +/** + * Gather Indexer K from Quantized Cache Kernel + * + * Gathers and dequantizes K values from the cache + */ +template +__global__ void cp_gather_indexer_k_quant_cache_kernel( + const char* __restrict__ kv_cache, // [num_blocks, block_size, + // cache_stride] + char* __restrict__ dst_k, // [num_tokens, head_dim] + char* __restrict__ dst_scale, // [num_tokens, head_dim/quant_block_size*4] + const int* __restrict__ block_table, // [batch_size, num_blocks] + const int* __restrict__ cu_seq_lens, // [batch_size + 1] + const int batch_size, + const int64_t token_stride, + const int64_t head_dim, + const int64_t block_stride, + const int64_t cache_token_stride, + const int64_t cache_block_size, + const int num_blocks, + const int 
num_tokens, + const int quant_block_size) { + constexpr int VEC_SIZE = sizeof(float4) / sizeof(char); + const int token_idx = blockIdx.x * blockDim.y + threadIdx.y; + const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + + // Find batch index within a block + __shared__ int batch_idx[BLOCK_Y_SIZE]; + for (int iter = 0; iter < (batch_size + blockDim.x - 1) / blockDim.x; + iter++) { + int tid = iter * blockDim.x + threadIdx.x; + if (tid < batch_size) { + const int seq_start = cu_seq_lens[tid]; + const int seq_end = cu_seq_lens[tid + 1]; + if (token_idx >= seq_start && token_idx < seq_end) { + batch_idx[threadIdx.y] = tid; + } + } + } + + __syncwarp(); + + if (head_idx >= head_dim || token_idx >= num_tokens) { + return; + } + + const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]]; + const int block_id = block_table[batch_idx[threadIdx.y] * num_blocks + + inbatch_seq_idx / cache_block_size]; + const int64_t src_block_offset = block_id * block_stride; + const int64_t cache_inblock_offset = + (inbatch_seq_idx % cache_block_size) * head_dim + head_idx; + const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset; + const int64_t dst_inblock_offset = token_idx * token_stride + head_idx; + + reinterpret_cast(dst_k)[dst_inblock_offset / VEC_SIZE] = + reinterpret_cast(kv_cache)[src_inblock_offset / VEC_SIZE]; + + if (threadIdx.x == 0) { + const int64_t src_scale_offset = + src_block_offset + cache_block_size * head_dim + + cache_inblock_offset * 4 / quant_block_size; + reinterpret_cast(dst_scale)[dst_inblock_offset / quant_block_size] = + reinterpret_cast(kv_cache)[src_scale_offset / 4]; + } +} + +/** + * Prefill DS MLA Write Cache Kernel + * + * Writes prefill KV data to DS MLA cache format + */ +template +__global__ void prefill_ds_mla_cache_kernel( + const T* __restrict__ kv_nope, // [num_tokens, kv_num_heads * nope_size] + const T* __restrict__ kv_pe, // [num_tokens, kv_num_heads * pe_size] + uint8_t* 
__restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // entry_size] + const int* __restrict__ block_tables, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, + const int* __restrict__ seq_lens_decoder, + const int max_seq_len, + const int max_blocks_per_seq, + const int kv_num_heads, + const int nope_size, // 512 for DS MLA + const int pe_size, // 64 for DS MLA + const int block_size, + const int entry_size, // 656 for DS MLA FP8 + const uint32_t elem_cnt) { + using LoadT = AlignedVector; + LoadT src_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t nope_hidden_size = kv_num_heads * nope_size; + const uint32_t pe_hidden_size = kv_num_heads * pe_size; + const int64_t hidden_size = nope_hidden_size + pe_hidden_size; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const uint32_t token_idx = linear_index / hidden_size; + const uint32_t bias = linear_index % hidden_size; + const uint32_t ori_bi = batch_id_per_token[token_idx]; + + if (seq_lens[ori_bi] == 0) continue; + + const uint32_t ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + + const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const uint32_t block_idx = block_table_now[ori_seq_id / block_size]; + const uint32_t block_offset = ori_seq_id % block_size; + + if (bias < nope_hidden_size) { + const uint32_t inner_bias = bias; + const uint32_t hi = inner_bias / nope_size; + const uint32_t h_bias = inner_bias % nope_size; + + // For DS MLA FP8, NoPE part goes to first 512 bytes + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * entry_size + + hi * block_size * entry_size + block_offset * entry_size + h_bias; + const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias; + + Load(&kv_nope[ori_idx], 
&src_vec); + + // Convert to FP8 and store + for (int i = 0; i < VecSize; i++) { + float val = static_cast(src_vec.val[i]); + val = fminf(fmaxf(val, -448.0f), 448.0f); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val); + kv_cache[tgt_idx + i] = *reinterpret_cast(&fp8_val); +#endif + } + } else { + const uint32_t inner_bias = bias - nope_hidden_size; + const uint32_t hi = inner_bias / pe_size; + const uint32_t h_bias = inner_bias % pe_size; + + // RoPE part goes after NoPE (512 bytes) + scales (16 bytes) + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * entry_size + + hi * block_size * entry_size + block_offset * entry_size + nope_size + + 16 + h_bias * 2; // *2 for bf16 + const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias; + + Load(&kv_pe[ori_idx], &src_vec); + + // Copy RoPE without quantization (as bf16/fp16) + T* tgt_ptr = reinterpret_cast(&kv_cache[tgt_idx]); + for (int i = 0; i < VecSize; i++) { + tgt_ptr[i] = src_vec.val[i]; + } + } + } +} + +/** + * Decode DS MLA Write Cache Kernel + */ +template +__global__ void decode_ds_mla_cache_kernel( + const T* __restrict__ kv_nope, + const T* __restrict__ kv_pe, + uint8_t* __restrict__ kv_cache, + const int* __restrict__ block_tables, + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, + const int* __restrict__ seq_lens_encoder, + const int max_seq_len, + const int max_blocks_per_seq, + const int kv_num_heads, + const int nope_size, + const int pe_size, + const int block_size, + const int entry_size, + const uint32_t elem_cnt) { + using LoadT = AlignedVector; + LoadT src_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t nope_hidden_size = kv_num_heads * nope_size; + const uint32_t pe_hidden_size = kv_num_heads * pe_size; + const int64_t hidden_size = nope_hidden_size + pe_hidden_size; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = 
gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int ori_bi = linear_index / hidden_size; + const int bias = linear_index % hidden_size; + const int start_token_idx = cu_seqlens_q[ori_bi]; + + if (seq_lens_encoder[ori_bi] > 0) return; + + const int write_seq_id = seq_lens[ori_bi]; + if (write_seq_id == 0) continue; + + const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + const int block_offset = write_seq_id % block_size; + + if (bias < nope_hidden_size) { + const uint32_t inner_bias = bias; + const uint32_t hi = inner_bias / nope_size; + const uint32_t h_bias = inner_bias % nope_size; + + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * entry_size + + hi * block_size * entry_size + block_offset * entry_size + h_bias; + const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias; + + Load(&kv_nope[ori_idx], &src_vec); + + for (int i = 0; i < VecSize; i++) { + float val = static_cast(src_vec.val[i]); + val = fminf(fmaxf(val, -448.0f), 448.0f); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + __nv_fp8_e4m3 fp8_val = static_cast<__nv_fp8_e4m3>(val); + kv_cache[tgt_idx + i] = *reinterpret_cast(&fp8_val); +#endif + } + } else { + const uint32_t inner_bias = bias - nope_hidden_size; + const uint32_t hi = inner_bias / pe_size; + const uint32_t h_bias = inner_bias % pe_size; + + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * entry_size + + hi * block_size * entry_size + block_offset * entry_size + nope_size + + 16 + h_bias * 2; + const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias; + + Load(&kv_pe[ori_idx], &src_vec); + + T* tgt_ptr = reinterpret_cast(&kv_cache[tgt_idx]); + for (int i = 0; i < VecSize; i++) { + tgt_ptr[i] = src_vec.val[i]; + } + } + } +} + +} // namespace ds_mla diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh 
b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index a2da51ef125..0cdea537327 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -18,8 +18,8 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" -template -__global__ void VariableLengthRotaryKernel( +template +__global__ void IntVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -33,7 +33,8 @@ __global__ void VariableLengthRotaryKernel( const int64_t elem_cnt, const int num_head, const int seq_len, - const int last_dim) { + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; using LoadBiasT = AlignedVector; using LoadScaleT = AlignedVector; @@ -48,12 +49,16 @@ __global__ void VariableLengthRotaryKernel( const int half_lastdim = last_dim / 2; const int hidden_size = num_head * last_dim; const int offset = 3 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -61,9 +66,11 @@ __global__ void VariableLengthRotaryKernel( const int hi = qkv_bias / last_dim; const int h_bias = qkv_bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + int new_emb_idx = rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias; const int64_t base_idx = token_idx * 3 * hidden_size + bias_idx; Load(&qkv[base_idx], &src_vec); @@ -72,8 +79,8 @@ __global__ void VariableLengthRotaryKernel( } Load(&qkv_out_scales[bias_idx], &out_scale_vec); if (qkv_id < 2) { - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < HalfVecSize; i++) { @@ -90,9 +97,11 @@ __global__ void VariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -100,9 +109,12 @@ __global__ void VariableLengthRotaryKernel( } Store(bias_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template +template __global__ void VariableLengthRotaryKernel( const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -115,7 +127,8 @@ __global__ void VariableLengthRotaryKernel( const int64_t elem_cnt, const int num_head, const int seq_len, - const int last_dim) { + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; @@ -126,12 +139,16 @@ __global__ void VariableLengthRotaryKernel( const int half_lastdim = last_dim / 2; const int hidden_size = num_head * last_dim; const int offset = 2 * hidden_size; +#if (defined(__CUDA_ARCH__) && 
(__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -139,14 +156,16 @@ __global__ void VariableLengthRotaryKernel( const int hi = qkv_bias / last_dim; const int h_bias = qkv_bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t base_idx = token_idx * 3 * hidden_size + qkv_id * hidden_size + hi * last_dim + h_bias; Load(&qkv[base_idx], &src_vec); - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { const float input_left = static_cast(src_vec[2 * i]); @@ -154,16 +173,21 @@ __global__ void VariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(src_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template -__global__ void 
NeoxVariableLengthRotaryKernel( +template +__global__ void IntNeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -177,7 +201,8 @@ __global__ void NeoxVariableLengthRotaryKernel( const int64_t elem_cnt, const int num_head, const int seq_len, - const int last_dim) { + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; using LoadBiasT = AlignedVector; using LoadScaleT = AlignedVector; @@ -195,12 +220,16 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hidden_size = num_head * half_lastdim; const int full_hidden_size = num_head * last_dim; const int offset = 3 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -208,9 +237,12 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hi = qkv_bias / half_lastdim; const int h_bias = qkv_bias % half_lastdim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; + int new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; const int bias_idx_left = qkv_id * full_hidden_size + hi * last_dim + h_bias; const int bias_idx_right = bias_idx_left + half_lastdim; @@ -225,8 +257,8 @@ __global__ void NeoxVariableLengthRotaryKernel( Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); Load(&qkv_out_scales[bias_idx_right], &right_out_scale_vec); if (qkv_id < 2) { - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -243,9 +275,11 @@ __global__ void NeoxVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -254,9 +288,12 @@ __global__ void NeoxVariableLengthRotaryKernel( Store(left_bias_vec, &qkv_out[base_idx_left]); Store(right_bias_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template +template __global__ void NeoxVariableLengthRotaryKernel( const T *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] @@ -269,7 +306,8 @@ __global__ void NeoxVariableLengthRotaryKernel( const int64_t elem_cnt, const int num_head, const int seq_len, - const int last_dim) { + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; using LoadEmbT = AlignedVector; LoadT left_vec; @@ -281,12 +319,16 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hidden_size = num_head * half_lastdim; const 
int full_hidden_size = num_head * last_dim; const int offset = 2 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int qkv_id = bias / hidden_size; @@ -294,9 +336,12 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hi = qkv_bias / half_lastdim; const int h_bias = qkv_bias % half_lastdim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; + int new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; const int base_idx_left = token_idx * 3 * full_hidden_size + qkv_id * full_hidden_size + hi * last_dim + h_bias; @@ -304,8 +349,8 @@ __global__ void NeoxVariableLengthRotaryKernel( Load(&qkv[base_idx_left], &left_vec); Load(&qkv[base_idx_right], &right_vec); - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { const float input_left = static_cast(left_vec[i]); @@ -313,17 +358,22 @@ __global__ void NeoxVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, 
cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template -__global__ void GQAVariableLengthRotaryKernel( +template +__global__ void IntGQAVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -353,20 +403,27 @@ __global__ void GQAVariableLengthRotaryKernel( int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int half_lastdim = last_dim / 2; const int offset = (q_num_head + 2 * kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_bi = batch_id_per_token[token_idx];; + const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t bias_idx = hi * last_dim + h_bias; const int64_t base_idx = token_idx * offset + bias_idx; Load(&qkv[base_idx], &src_vec); @@ -375,8 +432,8 @@ __global__ void GQAVariableLengthRotaryKernel( } Load(&qkv_out_scales[bias_idx], &out_scale_vec); if (hi < q_num_head + kv_num_head) { - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < HalfVecSize; i++) { @@ -393,9 +450,11 @@ __global__ void GQAVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -403,10 +462,13 @@ __global__ void GQAVariableLengthRotaryKernel( } Store(bias_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template -__global__ void GQAVariableLengthRotaryKernel( +template +__global__ void GQAVariableLengthRotaryQKNormKernel( const T *qkv, const float *cos_emb, const float *sin_emb, @@ -420,7 +482,105 @@ __global__ void GQAVariableLengthRotaryKernel( const int kv_num_head, const int seq_len, const int last_dim, - const bool rope_3d) { + const bool rope_3d, + const float *q_norm_weight, + const float *k_norm_weight, + const float rms_norm_eps) { + using LoadT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + using LoadFloat = AlignedVector; + LoadT src_vec; + LoadEmbT cos_emb_vec; + LoadEmbT 
sin_emb_vec; + LoadFloat tmp_vec; + LoadFloat q_norm_vec, k_norm_vec; + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; + const int half_lastdim = last_dim / 2; + const int offset = (q_num_head + kv_num_head) * last_dim; + const int all_head_num = elem_cnt / last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + for (int global_hi = global_warp_idx; global_hi < all_head_num; + global_hi += all_warp_num) { + int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize; + const int token_idx = linear_index / offset; + const int ori_bi = batch_id_per_token[token_idx]; + if (seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / last_dim; + const int h_bias = bias % last_dim; + + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + const int64_t base_idx = + token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + + h_bias; + Load(&qkv[base_idx], &src_vec); + + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + tmp_vec[2 * i] = tmp1; + tmp_vec[2 * i + 1] = tmp2; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + } + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = max(warp_m2 / last_dim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + + if (hi < q_num_head) { + Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + } + } else { + Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); + for (int i = 0; i < VecSize; i++) { + src_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + } + } + Store(src_vec, &qkv_out[base_idx]); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif +} + +template +__global__ void GQAVariableLengthRotaryKernel(const T *qkv, + const float *cos_emb, + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; @@ -430,18 +590,24 @@ __global__ void GQAVariableLengthRotaryKernel( int64_t 
global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int half_lastdim = last_dim / 2; const int offset = (q_num_head + kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_bi = batch_id_per_token[token_idx];; + const int ori_bi = batch_id_per_token[token_idx]; + ; + if (ori_bi == -1) continue; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int64_t base_idx = @@ -449,7 +615,8 @@ __global__ void GQAVariableLengthRotaryKernel( h_bias; Load(&qkv[base_idx], &src_vec); - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll @@ -459,33 +626,39 @@ __global__ void GQAVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(src_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template -__global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, - const float *cos_emb, // [1, 1, seq_len, dim_head / 2] - const float *sin_emb, - const float *qkv_out_scales, - const int *batch_id_per_token, - const int *cu_seqlens_q, - const int *seq_lens, - const int *seq_lens_decoder, - const T *qkv_biases, - const T *cache_k_scales, - const T *cache_v_scales, - T *qkv_out, - const int64_t elem_cnt, - const int q_num_head, - const int kv_num_head, - const int seq_len, - const int last_dim, - const bool rope_3d) { +template +__global__ void IntGQAVariableLengthRotaryQuantKVKernel( + const int *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const float *qkv_out_scales, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const T *qkv_biases, + const T *cache_k_scales, + const T *cache_v_scales, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { using LoadIn = AlignedVector; using LoadBiasT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; @@ -500,20 +673,27 @@ __global__ void 
GQAVariableLengthRotaryQuantKVKernel(const int *qkv, const int half_lastdim = last_dim / 2; // const int hidden_size = num_head * last_dim; const int offset = (q_num_head + 2 * kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t bias_idx = hi * last_dim + h_bias; const int64_t base_idx = token_idx * offset + bias_idx; Load(&qkv[base_idx], &src_vec); @@ -521,8 +701,8 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, Load(&qkv_biases[bias_idx], &bias_vec); } Load(&qkv_out_scales[bias_idx], &out_scale_vec); - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { float input_left = static_cast(src_vec[2 * i]); @@ -533,47 +713,63 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, input_right = qkv_biases ? 
input_right * out_scale_vec[2 * i + 1] + static_cast(bias_vec[2 * i + 1]) : input_right * out_scale_vec[2 * i + 1]; - if (hi < q_num_head) { // qk rope + if (hi < q_num_head) { // qk rope const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - bias_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * sin_tmp); - bias_vec[2 * i + 1] = static_cast(input_right * cos_tmp + input_left * sin_tmp); + bias_vec[2 * i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + bias_vec[2 * i + 1] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else if (hi < q_num_head + kv_num_head) { int k_hi = hi - q_num_head; const int scale_idx = k_hi * last_dim + h_bias; const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - bias_vec[2 * i] = static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i])); - bias_vec[2 * i + 1] = static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i + 1])); + bias_vec[2 * i] = + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * + float(cache_k_scales[scale_idx + 2 * i])); + bias_vec[2 * i + 1] = + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * + float(cache_k_scales[scale_idx + 2 * i + 1])); } else { int v_hi = hi - q_num_head - kv_num_head; const int scale_idx = v_hi * last_dim + h_bias; - bias_vec[2 * i] = static_cast(input_left * float(cache_v_scales[scale_idx + 2 * i])); - bias_vec[2 * i + 1] = static_cast(input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); + bias_vec[2 * i] = static_cast( + input_left * float(cache_v_scales[scale_idx + 2 * i])); + bias_vec[2 * i + 1] = static_cast( + input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); } } Store(bias_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + 
cudaGridDependencySynchronize(); +#endif } -template -__global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv, - const float *cos_emb, // [1, 1, seq_len, dim_head / 2] - const float *sin_emb, - const int *batch_id_per_token, - const int *cu_seqlens_q, - const int *seq_lens, - const int *seq_lens_decoder, - const T *qkv_biases, - const T *cache_k_scales, - const T *cache_v_scales, - T *qkv_out, - const int64_t elem_cnt, - const int q_num_head, - const int kv_num_head, - const int seq_len, - const int last_dim, - const bool rope_3d) { +template +__global__ void GQAVariableLengthRotaryQuantKVKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const T *qkv_biases, + const T *cache_k_scales, + const T *cache_v_scales, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; @@ -585,61 +781,90 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv, const int half_lastdim = last_dim / 2; // const int hidden_size = num_head * last_dim; const int offset = (q_num_head + 2 * kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + 
int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t bias_idx = hi * last_dim + h_bias; const int64_t base_idx = token_idx * offset + bias_idx; Load(&qkv[base_idx], &src_vec); if (qkv_biases) { Load(&qkv_biases[bias_idx], &bias_vec); } - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { - const float input_left = qkv_biases ? static_cast(src_vec[2 * i]+ bias_vec[2 * i]) : static_cast(src_vec[2 * i]); - const float input_right = qkv_biases ? static_cast(src_vec[2 * i + 1] + bias_vec[2 * i + 1]) : static_cast(src_vec[2 * i + 1]); + const float input_left = + qkv_biases ? static_cast(src_vec[2 * i] + bias_vec[2 * i]) + : static_cast(src_vec[2 * i]); + const float input_right = + qkv_biases + ? 
static_cast(src_vec[2 * i + 1] + bias_vec[2 * i + 1]) + : static_cast(src_vec[2 * i + 1]); // const float cos_tmp = cos_emb_vec[i]; // const float sin_tmp = sin_emb_vec[i]; - // src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * sin_tmp); - // src_vec[2 * i + 1] = static_cast(input_right * cos_tmp + input_left * sin_tmp); - if (hi < q_num_head) { // qk rope + // src_vec[2 * i] = static_cast(fmul_func(input_left, + // cos_tmp) - input_right * sin_tmp); src_vec[2 * i + 1] = + // static_cast(fmul_func(input_right, cos_tmp) + + // fmul_func(input_left, sin_tmp)); + if (hi < q_num_head) { // qk rope const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * sin_tmp); - src_vec[2 * i + 1] = static_cast(input_right * cos_tmp + input_left * sin_tmp); + src_vec[2 * i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + src_vec[2 * i + 1] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else if (hi < q_num_head + kv_num_head) { int k_hi = hi - q_num_head; const int scale_idx = k_hi * last_dim + h_bias; const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - src_vec[2 * i] = static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i])); - src_vec[2 * i + 1] = static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i + 1])); + src_vec[2 * i] = + static_cast((fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)) * + float(cache_k_scales[scale_idx + 2 * i])); + src_vec[2 * i + 1] = + static_cast((fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)) * + float(cache_k_scales[scale_idx + 2 * i + 1])); } else { int v_hi = hi - q_num_head - kv_num_head; const int scale_idx = v_hi * last_dim + h_bias; - src_vec[2 * i] = static_cast(input_left * float(cache_v_scales[scale_idx + 2 * 
i])); - src_vec[2 * i + 1] = static_cast(input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); + src_vec[2 * i] = static_cast( + input_left * float(cache_v_scales[scale_idx + 2 * i])); + src_vec[2 * i + 1] = static_cast( + input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); } } Store(src_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template -__global__ void GQANeoxVariableLengthRotaryKernel( +template +__global__ void IntGQANeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -654,7 +879,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const int q_num_head, const int kv_num_head, const int seq_len, - const int last_dim) { + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; using LoadBiasT = AlignedVector; using LoadScaleT = AlignedVector; @@ -670,20 +896,27 @@ __global__ void GQANeoxVariableLengthRotaryKernel( int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int half_lastdim = last_dim / 2; const int offset = (q_num_head + 2 * kv_num_head) * half_lastdim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / half_lastdim; const int h_bias = bias % half_lastdim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; + int new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; const int bias_idx_left = hi * last_dim + h_bias; const int bias_idx_right = bias_idx_left + half_lastdim; const int base_idx_left = @@ -698,8 +931,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel( Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); Load(&qkv_out_scales[bias_idx_right], &right_out_scale_vec); if (hi < (q_num_head + kv_num_head)) { - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -716,9 +949,11 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -727,10 +962,90 @@ __global__ void GQANeoxVariableLengthRotaryKernel( Store(left_bias_vec, &qkv_out[base_idx_left]); Store(right_bias_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template -__global__ void GQANeoxVariableLengthRotaryKernel( +template +__global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv, + const float *cos_emb, + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const float *qkv_out_scales, + const T *qkv_biases, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { + 
using LoadT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT left_vec; + LoadT right_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_lastdim = last_dim / 2; + const int offset = (q_num_head + kv_num_head) * half_lastdim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + for (int64_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; + if (seq_lens && seq_lens[ori_bi] == 0) continue; + const int bias = linear_index % offset; + const int hi = bias / half_lastdim; + const int h_bias = bias % half_lastdim; + + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + + const int emb_idx = ori_seq_id * last_dim + h_bias; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; + const int base_idx_left = + token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + + h_bias; + const int base_idx_right = base_idx_left + half_lastdim; + + Load(&qkv[base_idx_left], &left_vec); + Load(&qkv[base_idx_right], &right_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + const float input_left = static_cast(left_vec[i]); + const float input_right = static_cast(right_vec[i]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_vec[i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + right_vec[i] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); + } + Store(left_vec, &qkv_out[base_idx_left]); + Store(right_vec, &qkv_out[base_idx_right]); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif +} + +template +__global__ void GQANeoxVariableLengthPartialRotaryKernel( const T *qkv, const float *cos_emb, const float *sin_emb, @@ -745,7 +1060,9 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const int q_num_head, const int kv_num_head, const int seq_len, - const int last_dim) { + const int head_dim, + const int rotary_dim, + const bool rope_3d) { using LoadT = AlignedVector; using LoadEmbT = AlignedVector; LoadT left_vec; @@ -753,31 +1070,38 @@ __global__ void GQANeoxVariableLengthRotaryKernel( LoadEmbT cos_emb_vec; LoadEmbT sin_emb_vec; int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; - const int half_lastdim = last_dim / 2; - const int offset = (q_num_head + kv_num_head) * half_lastdim; + const int rotary_dim_half = rotary_dim / 2; + const int offset = (q_num_head + kv_num_head) * rotary_dim_half; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t 
linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens && seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; - const int hi = bias / half_lastdim; - const int h_bias = bias % half_lastdim; + const int hi = bias / rotary_dim_half; + const int h_bias = bias % rotary_dim_half; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; - const int emb_idx = ori_seq_id * last_dim + h_bias; + const int emb_idx = ori_seq_id * rotary_dim_half + h_bias; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx; const int base_idx_left = - token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + + token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim + h_bias; - const int base_idx_right = base_idx_left + half_lastdim; + const int base_idx_right = base_idx_left + rotary_dim_half; Load(&qkv[base_idx_left], &left_vec); Load(&qkv[base_idx_right], &right_vec); - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { const float input_left = static_cast(left_vec[i]); @@ -785,16 +1109,21 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + 
fmul_func(input_left, sin_tmp)); } Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template +template __global__ void cache_kernel( const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads, // head_size] @@ -802,11 +1131,11 @@ __global__ void cache_kernel( // head_size] T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, // head_size] - const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int *__restrict__ batch_id_per_token, // [num_tokens] - const int *__restrict__ cu_seqlens_q, // [bsz] - const int *__restrict__ seq_lens, // [bsz] - const int *__restrict__ seq_lens_decoder, // [bsz] + const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int *__restrict__ batch_id_per_token, // [num_tokens] + const int *__restrict__ cu_seqlens_q, // [bsz] + const int *__restrict__ seq_lens, // [bsz] + const int *__restrict__ seq_lens_decoder, // [bsz] const int max_seq_len, const int max_blocks_per_seq, const int num_heads, @@ -820,6 +1149,9 @@ __global__ void cache_kernel( uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const uint32_t hidden_size = kv_num_heads * head_size; const uint32_t offset = 2 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (uint32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -830,9 +1162,11 @@ __global__ void cache_kernel( const uint32_t qkv_bias = bias % hidden_size; const uint32_t hi = qkv_bias / head_size; const uint32_t h_bias = qkv_bias % head_size; - const uint32_t ori_bi = batch_id_per_token[token_idx]; + const int32_t ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; // skip batch_id_per_token[token_idx]=-1 if (seq_lens[ori_bi] == 0) continue; - const 
uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const uint32_t ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int32_t *block_table_now = nullptr; @@ -854,9 +1188,11 @@ __global__ void cache_kernel( Store(src_vec, &value_cache[tgt_idx]); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } - template + bool IsFP8 = false> __global__ void append_write_cache_kv_c8_qkv( uint8_t *__restrict__ cache_k, uint8_t *__restrict__ cache_v, @@ -888,6 +1224,9 @@ __global__ void append_write_cache_kv_c8_qkv( const T cache_k_scale = cache_k_scales[kv_head_idx]; const T cache_v_scale = cache_v_scales[kv_head_idx]; const uint32_t tid = threadIdx.x, wid = threadIdx.y; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids[btid]; const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; @@ -906,6 +1245,7 @@ __global__ void append_write_cache_kv_c8_qkv( const uint32_t end_len = start_len + seq_len_this_time; const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block; + int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8; const uint32_t start_token_idx = cu_seqlens_q[batch_id]; @@ -913,7 +1253,36 @@ __global__ void append_write_cache_kv_c8_qkv( const uint32_t kv_h_stride = HEAD_DIM; __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM]; __shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM]; + if (tile_start >= start_len) { + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + // int lane_id = wid * 32 + tid; + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + // reset k + constexpr int num_vecs_per_head_k 
= HEAD_DIM / KV_VEC_SIZE; + constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; + uint32_t tgt_idx = + (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM + + tid % num_vecs_per_head_k * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_k; block_i < BLOCK_SIZE; + block_i += num_token_each_time_k) { + Store(pad_cache_vec, + &cache_k[tgt_idx + block_i * HEAD_DIM]); + } + // reset v + const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE; + const int num_token_each_time_v = 32 / num_vecs_per_head_v; + tgt_idx = (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + + tid % num_vecs_per_head_v * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; + block_i += num_token_each_time_v) { + Store(pad_cache_vec, + &cache_v[tgt_idx + block_i * BLOCK_SIZE]); + } + } smem_t k_smem(k_smem_ori); smem_t v_smem(v_smem_ori); @@ -976,7 +1345,418 @@ __global__ void append_write_cache_kv_c8_qkv( uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4; uint32_t kv_frag[4]; + const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t write_b_stride = HEAD_DIM; + const uint32_t write_d_stride = BLOCK_SIZE; + uint32_t k_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_z * 16 + tid / 4) * write_b_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t k_write_idx_now = k_write_idx_now_z + + fy % 2 * 8 * write_b_stride + + fy / 2 * 32; // + fy % 2 * 16; + // load + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); + // quant + T *k_frag_T = reinterpret_cast(kv_frag); + if (bf_pad_len != 0) { + Load(cache_k + k_write_idx_now, &cache_vec1); + Load(cache_k + 
k_write_idx_now + 16, &cache_vec2); + } +#pragma unroll + for (uint32_t v_id = 0; v_id < 8; ++v_id) { + uint8_t uint_quant_value; + if (chunk_start_k + (v_id / 4) * 8 >= start_len && + chunk_start_k + (v_id / 4) * 8 < end_len) { + uint_quant_value = QuantToC8( + cache_k_scale, k_frag_T[v_id], 127.0f, -127.0f); + } else { + uint_quant_value = 0; + } + if (bf_pad_len != 0) { + if (v_id < 4) { + cache_vec1[v_id] |= uint_quant_value; + } else { + cache_vec2[v_id % 4] |= uint_quant_value; + } + } else { + if (v_id < 4) { + cache_vec1[v_id] = uint_quant_value; + } else { + cache_vec2[v_id - 4] = uint_quant_value; + } + } + } + // store + Store(cache_vec1, cache_k + k_write_idx_now); + Store(cache_vec2, cache_k + k_write_idx_now + 16); + k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) - + 2 * num_frags_y; + chunk_start_k += 16; + } + + uint32_t chunk_start_v = tile_start + tid % 4 * 2; + uint32_t v_write_idx = block_id * write_n_stride + + kv_head_idx * write_h_stride + + (wid * num_frags_v * 16 + tid / 4) * write_d_stride + + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit + const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_v; ++fy) { + uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) { + uint32_t v_write_idx_now = v_write_idx_now_v + + fz % 2 * 8 * write_d_stride + + fz / 2 * 32; // + fz % 2 * 16; + // load + v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag); + // quant + T *v_frag_T = reinterpret_cast(kv_frag); + if (bf_pad_len != 0) { + Load(cache_v + v_write_idx_now, &cache_vec1); + Load(cache_v + v_write_idx_now + 16, &cache_vec2); + } +#pragma unroll + for (uint32_t v_id = 0; v_id < 8; ++v_id) { + uint8_t uint_quant_value; + if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len && + 
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) { + uint_quant_value = QuantToC8( + cache_v_scale, v_frag_T[v_id], 127.0f, -127.0f); + // store now + } else { + uint_quant_value = 0; + } + if (bf_pad_len != 0) { + if (v_id < 4) { + cache_vec1[v_id] |= uint_quant_value; + } else { + cache_vec2[v_id % 4] |= uint_quant_value; + } + } else { + if (v_id < 4) { + cache_vec1[v_id] = uint_quant_value; + } else { + cache_vec2[v_id % 4] = uint_quant_value; + } + } + } + // store + Store(cache_vec1, cache_v + v_write_idx_now); + Store(cache_vec2, cache_v + v_write_idx_now + 16); + chunk_start_v += 16; + v_smem_offset_r = + k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r); + } + v_smem_offset_r = k_smem.advance_offset_by_column<2>( + v_smem_offset_r, wid * num_frags_v + fy) - + 16 * num_frags_z_v * num_vecs_per_head; + chunk_start_v -= 16 * num_frags_z_v; + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif +} + +template +__global__ void append_write_cache_kv_c8_qkv_dynamic( + uint8_t *__restrict__ cache_k, + uint8_t *__restrict__ cache_v, + const T *__restrict__ qkv_input, + T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size] + T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size] + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ batch_id_per_token, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_tables, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t pad_len = BLOCK_SIZE; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const T cache_k_scale = cache_k_scales[kv_head_idx]; + const T cache_v_scale = cache_v_scales[kv_head_idx]; + const 
uint32_t tid = threadIdx.x, wid = threadIdx.y; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids[btid]; + const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; + if (seq_len_this_time <= 0) { + return; + } + const int *block_table_now = nullptr; + + block_table_now = block_tables + batch_id * max_blocks_per_seq; + + const uint32_t num_rows_per_block = + NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE + const uint32_t start_len = seq_lens_decoder[batch_id]; + const uint32_t bf_pad_len = start_len % pad_len; + const uint32_t start_len_pad = start_len - bf_pad_len; + const uint32_t end_len = start_len + seq_len_this_time; + + const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block; int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); + uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8; + + const uint32_t start_token_idx = cu_seqlens_q[batch_id]; + const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM; + const uint32_t kv_h_stride = HEAD_DIM; + __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM]; + __shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM]; + __shared__ T v_scale_smem[BLOCK_SIZE]; + if (tile_start >= start_len) { + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + // reset k + constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE; + constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; + uint32_t tgt_idx = + (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM + + tid % num_vecs_per_head_k * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_k; block_i < BLOCK_SIZE; + block_i += num_token_each_time_k) { + Store(pad_cache_vec, + 
&cache_k[tgt_idx + block_i * HEAD_DIM]); + } + + // reset v + const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE; + const int num_token_each_time_v = 32 / num_vecs_per_head_v; + tgt_idx = (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + + tid % num_vecs_per_head_v * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; + block_i += num_token_each_time_v) { + Store(pad_cache_vec, + &cache_v[tgt_idx + block_i * BLOCK_SIZE]); + } + } + smem_t k_smem(k_smem_ori); + smem_t v_smem(v_smem_ori); + + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp + + /* + 0 | 1 + 2 | 3 + */ + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS; + /* + 0 | 2 + 1 | 3 + */ + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, wid * num_frags_v * 2 + tid / 16); + + // load kv gmem to smem + const uint32_t real_start_token_idx = start_token_idx - bf_pad_len + + tile_id * num_rows_per_block + + wid * num_frags_z * 16 + tid / 8; + uint32_t k_read_idx = real_start_token_idx * kv_batch_stride + + (num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); + uint32_t v_read_idx = real_start_token_idx * kv_batch_stride + + (num_heads + kv_num_heads + kv_head_idx) * kv_h_stride + + tid % 8 * num_elems_per_128b(); +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y / 4; + ++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) + if (chunk_start >= start_len && chunk_start < end_len) { + k_smem.load_128b_async( + kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len); + v_smem.load_128b_async( + kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len); + } + kv_smem_offset_w = 
+ k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy); + k_read_idx += 8 * num_elems_per_128b(); + v_read_idx += 8 * num_elems_per_128b(); + } + kv_smem_offset_w = + k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) - + 2 * num_frags_y; + chunk_start += 4; + k_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + v_read_idx += + 4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b(); + } + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // reduce scale + // 16 rows per warp + uint32_t kv_reduce_frag[4]; + T *kv_reduce_frag_T = reinterpret_cast(kv_reduce_frag); + + T k_local_max_value[num_frags_z * 2]; + T v_local_max_value[num_frags_z * 2]; +#pragma unroll + for (int i = 0; i < num_frags_z * 2; i++) { + k_local_max_value[i] = -INFINITY; + } +#pragma unroll + for (int i = 0; i < num_frags_z * 2; i++) { + v_local_max_value[i] = -INFINITY; + } + const int num_kv_heads = gridDim.z; + const int scale_offset = + block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE; + T *cache_k_scale_now = cache_k_scales + scale_offset; + T *cache_v_scale_now = cache_v_scales + scale_offset; + // k scale +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // reduce per thread, 4 threads each row + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); +#pragma unroll + for (int i = 0; i < 4; i++) { + k_local_max_value[fz * 2] = + __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), + k_local_max_value[fz * 2 + 1]); + } + k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + // reduce per row + for (int i = 0; i < 2; i++) { + T local_max_value = __habs(k_local_max_value[fz * 2 + i]); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, 
local_max_value, 2)); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, local_max_value, 1)); + // used for quant + k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); + } + // store + if (tid % 4 == 0) { + const int offset_now = wid * num_frags_z * 16 + tid / 4; + // used for dequant + if (tile_start + offset_now >= start_len) { + if (tile_start + offset_now < end_len) { + cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]); + } else { + cache_k_scale_now[offset_now] = 0; + } + } + if (tile_start + offset_now + 8 >= start_len) { + if (tile_start + offset_now + 8 < end_len) { + cache_k_scale_now[offset_now + 8] = + __hdiv(1, k_local_max_value[fz * 2 + 1]); + } else { + cache_k_scale_now[offset_now + 8] = 0; + } + } + } + __syncthreads(); + k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 + } +// v scale +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + // reduce per thread, 4 threads each row + v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); +#pragma unroll + for (int i = 0; i < 4; i++) { + v_local_max_value[fz * 2] = + __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]); + } +#pragma unroll + for (int i = 0; i < 4; i++) { + v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), + v_local_max_value[fz * 2 + 1]); + } + k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); + } + // reduce per row + for (int i = 0; i < 2; i++) { + T local_max_value = __habs(v_local_max_value[fz * 2 + i]); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, local_max_value, 2)); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, local_max_value, 1)); + v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); + } + // store + if (tid % 4 == 0) { + const int offset_now = wid * num_frags_z * 16 + tid / 4; + // used for dequant + if (tile_start 
+ offset_now >= start_len) { + if (tile_start + offset_now < end_len) { + cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]); + v_scale_smem[offset_now] = v_local_max_value[fz * 2]; + } else { + cache_v_scale_now[offset_now] = 0; + v_scale_smem[offset_now] = 0; + } + } + if (tile_start + offset_now + 8 >= start_len) { + if (tile_start + offset_now + 8 < end_len) { + cache_v_scale_now[offset_now + 8] = + __hdiv(1, v_local_max_value[fz * 2 + 1]); + v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1]; + } else { + cache_v_scale_now[offset_now + 8] = 0; + v_scale_smem[offset_now + 8] = 0; + } + } + } + __syncthreads(); + k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 + } + __syncthreads(); + + // mask, quant, store + using LoadKVT = AlignedVector; + LoadKVT cache_vec1; + LoadKVT cache_vec2; + + uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4; + uint32_t kv_frag[4]; const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM; const uint32_t write_b_stride = HEAD_DIM; @@ -1006,7 +1786,11 @@ __global__ void append_write_cache_kv_c8_qkv( uint8_t uint_quant_value; if (chunk_start_k + (v_id / 4) * 8 >= start_len && chunk_start_k + (v_id / 4) * 8 < end_len) { - uint_quant_value = QuantToC8(cache_k_scale, k_frag_T[v_id], 127.0f, -127.0f); + uint_quant_value = QuantToC8( + k_local_max_value[fz * 2 + v_id / 4], + k_frag_T[v_id], + 127.0f, + -127.0f); } else { uint_quant_value = 0; } @@ -1041,6 +1825,16 @@ __global__ void append_write_cache_kv_c8_qkv( (wid * num_frags_v * 16 + tid / 4) * write_d_stride + tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS; + T v_scales[num_frags_z_v * 4]; + for (int v_i = 0; v_i < num_frags_z_v; v_i++) { + const int offset = v_i * 16; + const int t_offset = tid % 4 * 2; + v_scales[v_i * 4] = v_scale_smem[offset + t_offset]; + v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset 
+ 1]; + v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8]; + v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9]; + } + #pragma unroll for (uint32_t fy = 0; fy < num_frags_v; ++fy) { uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride; @@ -1062,7 +1856,8 @@ __global__ void append_write_cache_kv_c8_qkv( uint8_t uint_quant_value; if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len && chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) { - uint_quant_value = QuantToC8(cache_v_scale, v_frag_T[v_id], 127.0f, -127.0f); + uint_quant_value = QuantToC8( + v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f); // store now } else { uint_quant_value = 0; @@ -1093,6 +1888,9 @@ __global__ void append_write_cache_kv_c8_qkv( 16 * num_frags_z_v * num_vecs_per_head; chunk_start_v -= 16 * num_frags_z_v; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } // Write Cache KV in Append @@ -1125,6 +1923,9 @@ __global__ void append_write_cache_kv_c4_qkv( constexpr uint32_t pad_len = BLOCK_SIZE; const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids[btid]; const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; @@ -1147,6 +1948,42 @@ __global__ void append_write_cache_kv_c4_qkv( const uint32_t start_token_idx = cu_seqlens_q[batch_id]; const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM; const uint32_t kv_h_stride = HEAD_DIM; + int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); + + const uint32_t HEAD_DIM_HALF = HEAD_DIM / 2; + const uint32_t BLOCK_SIZE_HALF = BLOCK_SIZE / 2; + + if (tile_start >= start_len) { + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + // 
pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + // reset k + constexpr int num_vecs_per_head_k = HEAD_DIM_HALF / KV_VEC_SIZE; // 4 + constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; // 8 + uint32_t tgt_idx = + (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM_HALF + + tid % num_vecs_per_head_k * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_k; block_i < BLOCK_SIZE; + block_i += num_token_each_time_k) { + Store(pad_cache_vec, + &cache_k[tgt_idx + block_i * HEAD_DIM_HALF]); + } + + // reset v + const int num_vecs_per_head_v = BLOCK_SIZE_HALF / KV_VEC_SIZE; // 2 + const int num_token_each_time_v = 32 / num_vecs_per_head_v; // 16 + tgt_idx = + (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE_HALF + + tid % num_vecs_per_head_v * KV_VEC_SIZE; + for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; + block_i += num_token_each_time_v) { + Store( + pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE_HALF]); + } + } + __shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM]; __shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM]; __shared__ T k_scale_smem[HEAD_DIM]; @@ -1197,16 +2034,10 @@ __global__ void append_write_cache_kv_c4_qkv( for (uint32_t fy = 0; fy < num_frags_y / 4; ++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) if (chunk_start >= start_len && chunk_start < end_len) { - k_smem - .load_128b_async( - kv_smem_offset_w, - qkv_input + k_read_idx, - chunk_start < end_len); - v_smem - .load_128b_async( - kv_smem_offset_w, - qkv_input + v_read_idx, - chunk_start < end_len); + k_smem.load_128b_async( + kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len); + v_smem.load_128b_async( + kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len); } kv_smem_offset_w = k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy); @@ -1257,7 +2088,6 @@ __global__ void 
append_write_cache_kv_c4_qkv( uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4; uint32_t kv_frag[4]; - int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]); const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM / 2; const uint32_t write_b_stride = HEAD_DIM / 2; @@ -1393,9 +2223,12 @@ __global__ void append_write_cache_kv_c4_qkv( 16 * num_frags_z_v * num_vecs_per_head; chunk_start_v -= 16 * num_frags_z_v; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } -template +template void rotary_qk_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] const QKV_TYPE *qkv_input, // qkv @@ -1414,9 +2247,8 @@ void rotary_qk_variable( const cudaStream_t &stream, bool use_neox_style = false, bool rope_3d = false) { - int64_t elem_nums = - qkv_out_scales ? token_num * 3 * head_num * dim_head - : token_num * 2 * head_num * dim_head; + int64_t elem_nums = qkv_out_scales ? 
token_num * 3 * head_num * dim_head + : token_num * 2 * head_num * dim_head; if (use_neox_style) { elem_nums /= 2; } @@ -1430,78 +2262,165 @@ void rotary_qk_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (qkv_out_scales) { - VariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head); + launchWithPdlWhenEnabled( + IntVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } else { - VariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head); + launchWithPdlWhenEnabled( + VariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } } else { const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head; if (qkv_out_scales) { - NeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head); + launchWithPdlWhenEnabled( + IntNeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + 
seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } else { - NeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head); + launchWithPdlWhenEnabled( + NeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } } } -template +template +void gqa_rotary_qk_norm_variable( + T *qkv_out, // [token_num, 3, num_head, dim_head] + const QKV_TYPE *qkv_input, // qkv + const float *qkv_out_scales, // [3, num_head, dim_head] + const T *qkv_bias, + const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const int token_num, + const int num_heads, + const int kv_num_heads, + const int seq_len, + const int input_output_len, + const int dim_head, + const cudaStream_t &stream, + bool use_neox_style = false, + bool rope_3d = false, + const float *q_norm_weight = nullptr, + const float *k_norm_weight = nullptr, + const float rms_norm_eps = 1e-6) { + int64_t elem_nums = + qkv_out_scales + ? 
token_num * (num_heads + 2 * kv_num_heads) * dim_head + : token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v + if (dim_head != 128) { + PADDLE_THROW( + "gqa rotary with qk norm only support head_dim=128, but got %d.", + dim_head); + } + constexpr int HEAD_DIM = 128; + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + dim3 Block_Size(kWarpSize, blocksize / kWarpSize, 1); + + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; + launchWithPdlWhenEnabled( + GQAVariableLengthRotaryQKNormKernel, + grid_size, + Block_Size, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d, + q_norm_weight, + k_norm_weight, + rms_norm_eps); +} + +template void gqa_rotary_qk_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] const QKV_TYPE *qkv_input, // qkv @@ -1518,6 +2437,7 @@ void gqa_rotary_qk_variable( const int seq_len, const int input_output_len, const int dim_head, + const int rotary_dim, const cudaStream_t &stream, bool use_neox_style = false, bool rope_3d = false) { @@ -1539,86 +2459,146 @@ void gqa_rotary_qk_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (qkv_out_scales) { - GQAVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + IntGQAVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, 
+ cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } else { - GQAVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + auto *kernelFn = + GQAVariableLengthRotaryKernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } } else { const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head; if (qkv_out_scales) { - GQANeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head); + launchWithPdlWhenEnabled( + IntGQANeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } else { - GQANeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head); + if (rotary_dim < dim_head) { + PD_CHECK((rotary_dim / 2) % PackSize == 0); + elem_nums = + qkv_out_scales + ? 
token_num * (num_heads + 2 * kv_num_heads) * rotary_dim + : token_num * (num_heads + kv_num_heads) * + rotary_dim; // for all q k v + if (use_neox_style) { + elem_nums /= 2; + } + const int pack_num_new = elem_nums / PackSize; + GetNumBlocks<128>(pack_num_new, &grid_size); + auto *kernelFn = + GQANeoxVariableLengthPartialRotaryKernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + rotary_emb + input_output_len * rotary_dim / 2, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rotary_dim, + rope_3d); + } else { + auto *kernelFn = + GQANeoxVariableLengthRotaryKernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); + } } } } -template +template void gqa_rotary_qk_quant_variable( T *qkv_out, // [token_num, 3, num_head, dim_head] const QKV_TYPE *qkv_input, // qkv @@ -1651,49 +2631,57 @@ void gqa_rotary_qk_quant_variable( int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); const float *cos_emb = rotary_emb; - const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; + const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (!use_neox_style) { if (qkv_out_scales) { - GQAVariableLengthRotaryQuantKVKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - qkv_out_scales, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_bias, - cache_k_scales, - cache_v_scales, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + IntGQAVariableLengthRotaryQuantKVKernel, + grid_size, + 
blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + qkv_out_scales, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_bias, + cache_k_scales, + cache_v_scales, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } else { - GQAVariableLengthRotaryQuantKVKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_bias, - cache_k_scales, - cache_v_scales, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + GQAVariableLengthRotaryQuantKVKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_bias, + cache_k_scales, + cache_v_scales, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } } else { PADDLE_THROW("Use_neox_style mode isn't implemented yet"); @@ -1722,14 +2710,18 @@ void CascadeAppendWriteCacheKVQKV( auto head_dim = meta_data.head_dims; auto block_size = meta_data.block_size; - const uint32_t elem_nums = - num_tokens * 2 * kv_num_heads * head_dim; + const uint32_t elem_nums = num_tokens * 2 * kv_num_heads * head_dim; constexpr int PackSize = 16 / sizeof(T); const int pack_num = elem_nums / PackSize; const int blocksize = 128; int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); - cache_kernel<<>>( + launchWithPdlWhenEnabled( + cache_kernel, + grid_size, + blocksize, + 0, + stream, reinterpret_cast(const_cast(qkv.data())), reinterpret_cast(key_cache_out->data()), reinterpret_cast(value_cache_out->data()), @@ -1767,10 +2759,11 @@ void CascadeAppendWriteCacheKVC8QKV( int num_blocks_x_cpu, int max_seq_len, bool is_scale_channel_wise, - const bool is_fp8, + const std::string &cache_quant_type, cudaStream_t &stream, paddle::Tensor *cache_k_out, paddle::Tensor 
*cache_v_out) { + using NV_TYPE = typename cascade_attn_type_traits::type; auto max_blocks_per_seq = meta_data.max_blocks_per_seq; auto num_tokens = meta_data.token_nums; auto num_heads = meta_data.q_num_heads; @@ -1788,49 +2781,93 @@ void CascadeAppendWriteCacheKVC8QKV( dim3 blocks(32, num_warps); const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2; - auto kernel_fn = append_write_cache_kv_c8_qkv; - if (is_fp8) { - kernel_fn = append_write_cache_kv_c8_qkv; - } - if (is_scale_channel_wise) { - kernel_fn = append_write_cache_kv_c8_qkv; + if (cache_quant_type != "block_wise_fp8") { + auto kernel_fn = append_write_cache_kv_c8_qkv; + if (cache_quant_type == "cache_fp8") { + kernel_fn = append_write_cache_kv_c8_qkv; + } + if (is_scale_channel_wise) { + kernel_fn = append_write_cache_kv_c8_qkv; + } + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + launchWithPdlWhenEnabled(kernel_fn, + grids, + blocks, + 0, + stream, + cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); + } else { + auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic; + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + launchWithPdlWhenEnabled( + kernel_fn, + grids, + blocks, + 0, + stream, + cache_k_out->data(), + cache_v_out->data(), + reinterpret_cast(qkv.data()), + const_cast( + reinterpret_cast(cache_k_scale.data())), + const_cast( + reinterpret_cast(cache_v_scale.data())), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + 
kv_num_heads); } - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - qkv.data(), - cache_k_scale.data(), - cache_v_scale.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); } template @@ -1883,22 +2920,27 @@ void CascadeAppendWriteCacheKVC4QKV( num_warps>; cudaFuncSetAttribute( kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - qkv.data(), - cache_k_scale.data(), - cache_v_scale.data(), - cache_k_zp.data(), - cache_v_zp.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); + launchWithPdlWhenEnabled(kernel_fn, + grids, + blocks, + 0, + stream, + cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + cache_k_zp.data(), + cache_v_zp.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); } diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h index 5eb238216f7..23969aa429f 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_kernel.h @@ -16,7 +16,7 @@ #include "encoder_write_cache_with_rope_impl.cuh" #include "remote_cache_kv_ipc.h" -template +template 
void EncoderWriteCacheWithRopeKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& @@ -46,80 +46,130 @@ void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out) { + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps) { auto token_num = meta_data.token_nums; auto num_heads = meta_data.q_num_heads; auto kv_num_heads = meta_data.kv_num_heads; auto head_dim = meta_data.head_dims; bool is_scale_channel_wise = false; - if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) { + int rotary_dim = head_dim; + if (cache_k_scale && + cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) { is_scale_channel_wise = true; } + if (rotary_embs) { + rotary_dim = + rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2; + if (rotary_dim < head_dim) { + if (!use_neox_style || q_norm_weight || k_norm_weight || + num_heads == kv_num_heads || is_scale_channel_wise) { + PADDLE_THROW(phi::errors::Fatal( + "partial_rotary_factor < 1.0 only supports " + "use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, " + "GQA and is_scale_channel_wise=false.")); + } + } + } - if (num_heads == kv_num_heads) { - rotary_qk_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - max_seq_len, - rotary_embs.get().dims()[2], - head_dim, - stream, - use_neox_style, - rope_3d); + if (q_norm_weight && k_norm_weight) { + if (num_heads != kv_num_heads && !is_scale_channel_wise && + !use_neox_style) { + gqa_rotary_qk_norm_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales ? 
qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d, + q_norm_weight ? q_norm_weight.get().data() : nullptr, + k_norm_weight ? k_norm_weight.get().data() : nullptr, + rms_norm_eps); + } else { + PD_THROW( + "gqa_rotary_qk_norm_variable only support gqa mode. channel wise " + "scale and neox style are not supported"); + } } else { - if (!is_scale_channel_wise) { - gqa_rotary_qk_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], - head_dim, - stream, - use_neox_style, - rope_3d); + if (num_heads == kv_num_heads) { + rotary_qk_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + max_seq_len, + rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d); } else { - gqa_rotary_qk_quant_variable( - qkv_out->data(), - qkv.data(), - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? qkv_biases.get().data() : nullptr, - cache_k_scale ? cache_k_scale.get().data() : nullptr, - cache_v_scale ? 
cache_v_scale.get().data() : nullptr, - rotary_embs.get().data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rotary_embs.get().dims()[2], - head_dim, - stream, - use_neox_style, - rope_3d); + if (!is_scale_channel_wise) { + gqa_rotary_qk_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2], + head_dim, + rotary_dim, + stream, + use_neox_style, + rope_3d); + } else { + gqa_rotary_qk_quant_variable( + qkv_out->data(), + qkv.data(), + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? qkv_biases.get().data() : nullptr, + cache_k_scale ? cache_k_scale.get().data() : nullptr, + cache_v_scale ? 
cache_v_scale.get().data() : nullptr, + rotary_embs.get().data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rotary_embs.get().dims()[2], + head_dim, + stream, + use_neox_style, + rope_3d); + } } - } const uint32_t block_size = meta_data.block_size; if (cache_quant_type_str == "none") { @@ -134,7 +184,9 @@ void EncoderWriteCacheWithRopeKernel( stream, key_cache_out, value_cache_out); - } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") { + } else if (cache_quant_type_str == "cache_int8" or + cache_quant_type_str == "cache_fp8" or + cache_quant_type_str == "block_wise_fp8") { DISPATCH_HEAD_DIM( head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { CascadeAppendWriteCacheKVC8QKV( @@ -154,7 +206,7 @@ void EncoderWriteCacheWithRopeKernel( num_blocks, max_seq_len, is_scale_channel_wise, - cache_quant_type_str == "cache_fp8", + cache_quant_type_str, stream, key_cache_out, value_cache_out); @@ -190,23 +242,29 @@ void EncoderWriteCacheWithRopeKernel( "cache_int4_zp]"); } - const char* fmt_write_cache_completed_signal_str = std::getenv("FLAGS_fmt_write_cache_completed_signal"); - const char* FLAGS_use_pd_disaggregation_per_chunk = std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); + const char* fmt_write_cache_completed_signal_str = + std::getenv("FLAGS_fmt_write_cache_completed_signal"); + const char* FLAGS_use_pd_disaggregation_per_chunk = + std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); if (fmt_write_cache_completed_signal_str && (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { - if (FLAGS_use_pd_disaggregation_per_chunk && - (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || - std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { - cudaLaunchHostFunc(qkv.stream(), - 
&(RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query), - (void*)nullptr); - } else { - if (kv_signal_data) { - cudaLaunchHostFunc(qkv.stream(), - &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, - (void*)(const_cast(kv_signal_data.get().data()))); - } + if (FLAGS_use_pd_disaggregation_per_chunk && + (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || + std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { + cudaLaunchHostFunc( + qkv.stream(), + &(RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_per_query), + (void*)nullptr); + } else { + if (kv_signal_data) { + cudaLaunchHostFunc( + qkv.stream(), + &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, + (void*)(const_cast( + kv_signal_data.get().data()))); } + } } } diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index e438380e286..f94e8493f7f 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -11,18 +11,21 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
- +#include "cute/tensor.hpp" #include "helper.h" #include "paddle/extension.h" +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU +#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h" #include "paddle/phi/core/memory/memcpy.h" +#endif +#include "utils.cuh" template -__global__ void -GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, - const int *seq_lens_encoder, - const int *seq_lens_this_time_merged, - const int *seq_lens_encoder_merged, const int *seq_mapping, - const int *system_lens, int *max_lens, const int batch_size) { +__global__ void GetMaxLenKernel(const int *seq_lens_decoder, + const int *seq_lens_this_time, + const int *seq_lens_encoder, + int *max_lens, + const int batch_size) { const int tid = threadIdx.x; typedef cub::BlockReduce BlockReduce; @@ -33,44 +36,27 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, int max_len_decoder_this_thread = 0; int max_len_this_thread = 0; int max_just_dec_len_this_thread = 0; - int max_just_dec_merged_len_this_time_this_thread = 0; - int max_system_len_this_thread = 0; - int max_dec_len_without_system_this_thread = 0; + int max_len_kv_this_thread = 0; for (int i = tid; i < batch_size; i += blockDim.x) { const int seq_len_this_time = seq_lens_this_time[i]; + const int seq_len_decoder = seq_lens_decoder[i]; max_len_this_time_this_thread = max(seq_len_this_time, max_len_this_time_this_thread); max_len_encoder_this_thread = max(seq_lens_encoder[i], max_len_encoder_this_thread); - max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread); - if (seq_len_this_time <= 0) - continue; - const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i]; + max_len_decoder_this_thread = + max(seq_len_decoder, max_len_decoder_this_thread); + if (seq_len_this_time <= 0) continue; + const int max_just_dec_len_now = + seq_lens_encoder[i] > 0 ? 
0 : seq_len_decoder; max_len_this_thread = - max(seq_lens[i] + seq_len_this_time, max_len_this_thread); + max(seq_len_decoder + seq_len_this_time, max_len_this_thread); max_just_dec_len_this_thread = max(max_just_dec_len_this_thread, max_just_dec_len_now); - if (system_lens) { - const int real_bid = seq_mapping[i]; - const int system_len_now = system_lens[real_bid]; - max_system_len_this_thread = - max(max_system_len_this_thread, system_len_now); - max_dec_len_without_system_this_thread = - max(max_dec_len_without_system_this_thread, - max_just_dec_len_now - system_len_now); - } - } - if (system_lens) { - for (int i = tid; i < batch_size; i += blockDim.x) { - const int ori_seq_len_this_time = seq_lens_this_time_merged[i]; - if (ori_seq_len_this_time <= 0) - continue; - const int max_just_dec_merged_len_this_time_now = - seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time; - max_just_dec_merged_len_this_time_this_thread = - max(max_just_dec_merged_len_this_time_this_thread, - max_just_dec_merged_len_this_time_now); - } + + if (seq_len_decoder == 0) continue; + max_len_kv_this_thread = + max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); } int total_max_len_this_time = BlockReduce(temp_storage) @@ -85,60 +71,161 @@ GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time, BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); int total_just_dec = BlockReduce(temp_storage) .Reduce(max_just_dec_len_this_thread, MaxOp()); - int total_just_dec_merged = - BlockReduce(temp_storage) - .Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp()); - int total_system_len = BlockReduce(temp_storage) - .Reduce(max_system_len_this_thread, MaxOp()); - int total_dec_len_without_system = - BlockReduce(temp_storage) - .Reduce(max_dec_len_without_system_this_thread, MaxOp()); + int total_max_len_kv = + BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); if (tid == 0) { max_lens[0] = total_max_len_this_time; max_lens[1] = 
total_max_len_encoder; max_lens[2] = total_max_len_decoder; max_lens[3] = total; max_lens[4] = total_just_dec; - max_lens[5] = total_just_dec_merged; - max_lens[6] = total_system_len; - max_lens[7] = total_dec_len_without_system; + max_lens[5] = total_max_len_kv; + } +} + +template +__global__ void search_chunk_size_for_mla( + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ seq_lens_decoder, + int *__restrict__ num_blocks_x, + int *__restrict__ res_chunk_size, + const int bsz, + const int set_chunk_size, + const int block_size, + const int sm_cout) { + const uint32_t conf_id = threadIdx.x; + int gridx = 0; + if (set_chunk_size > 0 && conf_id == 0) { + for (uint32_t bid = 0; bid < bsz; bid++) { + int seq_len = seq_lens_q[bid]; + int seq_len_encoder = seq_lens_encoder[bid]; + int seq_len_decoder = seq_lens_decoder[bid] + seq_len; + if (seq_len == 0 || seq_len_encoder > 0) continue; + + int loop_times; + loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size); + gridx += loop_times; + } + *num_blocks_x = gridx; + *res_chunk_size = set_chunk_size; + } else if (conf_id < config_size) { + __shared__ int gridx_shared[config_size]; + // chunk_size is a multiple of 64 + const int chunk_size = block_size << conf_id; + for (uint32_t bid = 0; bid < bsz; bid++) { + int seq_len = seq_lens_q[bid]; + int seq_len_encoder = seq_lens_encoder[bid]; + int seq_len_decoder = seq_lens_decoder[bid] + seq_len; + if (seq_len == 0 || seq_len_encoder > 0) continue; + + int loop_times; + loop_times = cute::ceil_div(seq_len_decoder, chunk_size); + gridx += loop_times; + } + gridx_shared[conf_id] = gridx; + __syncthreads(); + if (threadIdx.x == 0) { + uint32_t res_id = 0; + uint32_t max_last_wave_block = 0; + for (uint32_t i = 1; i < config_size; i++) { + uint32_t last_wave_block = gridx_shared[i] % sm_cout; + if (last_wave_block >= max_last_wave_block) { + res_id = i; + max_last_wave_block = last_wave_block; + } + } + *num_blocks_x = 
gridx_shared[res_id]; + *res_chunk_size = block_size << res_id; + } } } -void GetMaxLen(const paddle::Tensor &seq_lens_tensor, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - paddle::Tensor &max_len_tensor, const int batch_size) { - constexpr int blockSize = 1024; - GetMaxLenKernel<<<1, blockSize, 0, seq_lens_encoder.stream()>>>( - seq_lens_tensor.data(), seq_lens_this_time.data(), - seq_lens_encoder.data(), nullptr, nullptr, nullptr, nullptr, - max_len_tensor.data(), batch_size); +__global__ void split_block_for_mla( + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ seq_lens_decoder, + int *__restrict__ batch_ids, + int *__restrict__ tile_ids_per_batch, + const int bsz, + int *__restrict__ decoder_chunk_size_device) { + const int chunk_size = decoder_chunk_size_device[0]; + + if (threadIdx.x == 0) { + int index = 0; + for (uint32_t bid = 0; bid < bsz; bid++) { + int seq_len = seq_lens_q[bid]; + int seq_len_encoder = seq_lens_encoder[bid]; + int seq_len_decoder = seq_lens_decoder[bid] + seq_len; + + if (seq_len == 0) continue; + + int loop_times; + loop_times = cute::ceil_div(seq_len_decoder, chunk_size); + if (seq_len_encoder > 0) { + loop_times = 0; + } + for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { + batch_ids[index] = bid; + tile_ids_per_batch[index++] = tile_id; + } + } + } } __global__ void split_q_block(const int *__restrict__ seq_lens_q, const int *__restrict__ seq_lens_encoder, int *__restrict__ batch_ids, int *__restrict__ tile_ids_per_batch, - int *__restrict__ num_blocks_x, const int bsz, + int *__restrict__ num_blocks_x, + const int bsz, const int num_rows_per_block, const int group_size) { - if (threadIdx.x == 0) { - int gridx = 0; - int index = 0; - for (uint32_t bid = 0; bid < bsz; bid++) { + // one block one warp + const int lane_id = threadIdx.x % WARP_SIZE; + int prev_offset = 0; + + // loop on warp tile:[base, base+32) + for (int 
base = 0; base < bsz; base += WARP_SIZE) { + const int bid = base + lane_id; + + // calculate loop_times for bid + int loop_times = 0; + if (bid < bsz) { int seq_len = seq_lens_q[bid]; if (seq_lens_encoder && seq_lens_encoder[bid] > 0) { seq_len = 0; } - const int loop_times = div_up(seq_len * group_size, num_rows_per_block); - for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { - batch_ids[index] = bid; - tile_ids_per_batch[index++] = tile_id; + loop_times = div_up(seq_len * group_size, num_rows_per_block); + } + + // prefix sum for each lane, get the start offset in this tile + // inclusive scan + int x = loop_times; + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int y = __shfl_up_sync(0xffffffff, x, offset); + if (lane_id >= offset) x += y; + } + // exclusive prefix sum + int bid_offset = x - loop_times; + int tile_sum = __shfl_sync(0xffffffff, x, WARP_SIZE - 1); + + // write batch_ids and tile_ids_per_batch + if (bid < bsz && loop_times > 0) { + int write_base = prev_offset + bid_offset; + for (int t = 0; t < loop_times; ++t) { + int pos = write_base + t; + batch_ids[pos] = bid; + tile_ids_per_batch[pos] = t; } - gridx += loop_times; } - *num_blocks_x = gridx; + + // for next warp tile + prev_offset += tile_sum; + } + + if (threadIdx.x == 0) { + *num_blocks_x = prev_offset; } } @@ -146,8 +233,10 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder, const int *__restrict__ seq_lens_encoder, int *__restrict__ batch_ids, int *__restrict__ tile_ids_per_batch, - int *__restrict__ num_blocks_x, const int bsz, - const int pad_len, const int num_row_per_block) { + int *__restrict__ num_blocks_x, + const int bsz, + const int pad_len, + const int num_row_per_block) { if (threadIdx.x == 0) { int gridx = 0; int index = 0; @@ -168,215 +257,233 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder, } } -template -__global__ void -get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time, - 
const int *seq_lens_decoder, const int batch_size) { - const int tid = threadIdx.x; - - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - int max_len_this_thread = 0; - for (int i = tid; i < batch_size; i += blockDim.x) { - if (seq_lens_decoder[i] == 0) - continue; - max_len_this_thread = - max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread); - } - int total = - BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); - if (tid == 0) { - *max_seq_lens_out = total; - } -} - -std::vector GetBlockShapeAndSplitKVBlock( +void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, - const int encoder_block_shape_q, const int decoder_block_shape_q, - const int group_size, const int block_size, - const int decoder_step_token_num) { + paddle::Tensor &decoder_batch_ids, // Inplace + paddle::Tensor &decoder_tile_ids_per_batch, // Inplace + paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory + paddle::Tensor &decoder_num_blocks_device, // Inplace + paddle::Tensor &decoder_chunk_size_device, // Inplace + paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU + paddle::Tensor &encoder_batch_ids, // Inplace + paddle::Tensor &encoder_tile_ids_per_batch, // Inplace + paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU + paddle::Tensor &kv_batch_ids, // Inplace + paddle::Tensor &kv_tile_ids_per_batch, // Inplace + paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int group_size, + const int block_size) { auto stream = seq_lens_encoder.stream(); - int bsz = seq_lens_encoder.shape()[0]; - auto max_len_tensor = - GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place()); - GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder, - max_len_tensor, bsz); - - // max_len_this_time, 
max_enc_len_this_time, max_dec_len_this_time, - // max_enc_dec_len_this_time, max_just_dec_len_this_time, - // max_just_dec_merged_len_this_time, max_system_len, - // max_just_dec_len_without_system - auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false); - auto max_len_cpu_ptr = max_len_cpu.data(); + int bsz = seq_lens_this_time.shape()[0]; + + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); + // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data + // is only for branching in attention. +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + + auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; int max_enc_len_this_time = max_len_cpu_ptr[1]; int max_dec_len_this_time = max_len_cpu_ptr[2]; int max_enc_dec_len_this_time = max_len_cpu_ptr[3]; int max_just_dec_len_this_time = max_len_cpu_ptr[4]; - int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5]; - int max_system_len = max_len_cpu_ptr[6]; - int max_just_dec_len_without_system = max_len_cpu_ptr[7]; - - paddle::Tensor encoder_batch_ids; - paddle::Tensor encoder_tile_ids_per_batch; - paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor kv_batch_ids; - paddle::Tensor kv_tile_ids_per_batch; - paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor decoder_batch_ids; - paddle::Tensor decoder_tile_ids_per_batch; - paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/ - paddle::Tensor max_len_kv_cpu; /*cpu*/ - - auto max_len_kv = - GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place()); - 
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>( - max_len_kv.data(), seq_lens_this_time.data(), - seq_lens_decoder.data(), bsz); - - max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false); + int max_kv_len_this_time = max_len_cpu_ptr[5]; + + const uint32_t decoder_batch_ele_num = decoder_batch_ids.shape()[0]; + + const bool mla_backend = checkAttentionBackend(); + + // decoder + if (max_dec_len_this_time > 0) { + if (mla_backend) { + const int set_chunk_size = get_mla_dec_chunk_size(bsz); + + CUDA_CHECK(cudaMemsetAsync( + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + + CUDA_CHECK(cudaMemsetAsync( + decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK(cudaDeviceGetAttribute( + &sm_cout, cudaDevAttrMultiProcessorCount, device)); + constexpr int config_size = + 12; // search space for chunk size:[64, 128, 256, ... 131072] + + search_chunk_size_for_mla + <<<1, 32, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + decoder_num_blocks_device.data(), + decoder_chunk_size_device.data(), + bsz, + set_chunk_size, + block_size, + sm_cout); + + CUDA_CHECK(cudaMemsetAsync(decoder_batch_ids.data(), + 0, + decoder_batch_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK(cudaMemsetAsync(decoder_tile_ids_per_batch.data(), + 0, + decoder_batch_ele_num * sizeof(int32_t), + stream)); + + split_block_for_mla<<<1, 32, 0, stream>>>( + seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + decoder_batch_ids.data(), + decoder_tile_ids_per_batch.data(), + bsz, + decoder_chunk_size_device.data()); + } else { + CUDA_CHECK(cudaMemsetAsync(decoder_batch_ids.data(), + 0xFF, + decoder_batch_ele_num * sizeof(int32_t), + stream)); + split_q_block<<<1, 32, 0, stream>>>( + seq_lens_this_time.data(), + seq_lens_encoder.data(), + decoder_batch_ids.data(), + decoder_tile_ids_per_batch.data(), + 
decoder_num_blocks_device.data(), + bsz, + decoder_block_shape_q, + group_size); + // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU + // data is only for branching in attention. +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + decoder_num_blocks_cpu.copy_( + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + } + } + // mla_backend not need run the following code. + if (mla_backend) return; + + // encoder if (max_enc_len_this_time > 0) { const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); - kv_batch_ids = - GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, - seq_lens_encoder.place()); - kv_tile_ids_per_batch = - GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32, - seq_lens_encoder.place()); + const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; + CUDA_CHECK(cudaMemsetAsync( + kv_batch_ids.data(), 0, kv_batch_shape * sizeof(int32_t), stream)); + CUDA_CHECK(cudaMemsetAsync(kv_tile_ids_per_batch.data(), + 0, + kv_batch_shape * sizeof(int32_t), + stream)); auto kv_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>( seq_lens_decoder.data(), - // sequence_lengths->data(), - seq_lens_encoder.data(), kv_batch_ids.data(), - kv_tile_ids_per_batch.data(), kv_num_blocks_x.data(), bsz, - block_size, block_size); - - kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false); + seq_lens_encoder.data(), + kv_batch_ids.data(), + kv_tile_ids_per_batch.data(), + kv_num_blocks_x.data(), + bsz, + block_size, + block_size); + kv_num_blocks_x_cpu.copy_( + kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); + // Clear buffer const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); - encoder_batch_ids = - GetEmptyTensor({bsz * 
encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_tile_ids_per_batch = - GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); + const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q; + CUDA_CHECK(cudaMemsetAsync(encoder_batch_ids.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); + CUDA_CHECK(cudaMemsetAsync(encoder_tile_ids_per_batch.data(), + 0, + encoder_batch_shape * sizeof(int32_t), + stream)); auto encoder_num_blocks_x = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); - split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), nullptr, + split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data(), + nullptr, encoder_batch_ids.data(), encoder_tile_ids_per_batch.data(), - encoder_num_blocks_x.data(), bsz, - encoder_block_shape_q, group_size); - encoder_num_blocks_x_cpu = - encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); - } else { - encoder_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - encoder_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); - kv_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - kv_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - kv_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - } - if (max_just_dec_len_this_time > 0) { - const uint32_t decoder_max_tile_size_per_bs_q = - div_up((decoder_step_token_num * group_size), decoder_block_shape_q); - - decoder_batch_ids = - GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, - paddle::DataType::INT32, seq_lens_encoder.place()); - decoder_tile_ids_per_batch = - GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q}, - 
paddle::DataType::INT32, seq_lens_encoder.place()); - auto decoder_num_blocks_x = - GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); - split_q_block<<<1, 32, 0, stream>>>( - seq_lens_this_time.data(), seq_lens_encoder.data(), - decoder_batch_ids.data(), decoder_tile_ids_per_batch.data(), - decoder_num_blocks_x.data(), bsz, decoder_block_shape_q, - group_size); - decoder_num_blocks_x_cpu = - decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false); - } else { - decoder_batch_ids = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - decoder_tile_ids_per_batch = - GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place()); - decoder_num_blocks_x_cpu = - GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace()); + encoder_num_blocks_x.data(), + bsz, + encoder_block_shape_q, + group_size); + encoder_num_blocks_x_cpu.copy_( + encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); } - - return {encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks_x_cpu, /*cpu*/ - kv_batch_ids, - kv_tile_ids_per_batch, - kv_num_blocks_x_cpu, /*cpu*/ - decoder_batch_ids, - decoder_tile_ids_per_batch, - decoder_num_blocks_x_cpu, /*cpu*/ - max_len_kv_cpu /*cpu*/, - max_len_cpu}; } -std::vector GetBlockShapeAndSplitKVBlockInferDtype( - const paddle::DataType &seq_lens_encoder_dtype, - const paddle::DataType &seq_lens_decoder_dtype, - const paddle::DataType &seq_lens_this_time_dtype) { - return { - paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, - paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, - paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32, - paddle::DataType::INT32, paddle::DataType::INT32}; +std::vector> GetBlockShapeAndSplitKVBlockInferShape( + const std::vector &seq_lens_encoder, + const std::vector &seq_lens_decoder, + const std::vector &seq_lens_this_time, + const int encoder_block_shape_q, + const int 
decoder_block_shape_q, + const int group_size, + const int block_size) { + return {}; } -std::vector> GetBlockShapeAndSplitKVBlockInferShape( - const std::vector &seq_lens_encoder_shape, - const std::vector &seq_lens_decoder_shape, - const std::vector &seq_lens_this_time_shape) { - std::vector dynamic_shape = {-1}; - - return {dynamic_shape, - dynamic_shape, - {1}, - dynamic_shape, - dynamic_shape, - {1}, - dynamic_shape, - dynamic_shape, - {1}, - {1}, - {8}}; +std::vector GetBlockShapeAndSplitKVBlockInferDtype( + const paddle::DataType &seq_lens_encoder, + const paddle::DataType &seq_lens_decoder, + const paddle::DataType &seq_lens_this_time, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int group_size, + const int block_size) { + return {}; } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) - .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"}) - .Outputs({paddle::Optional("encoder_batch_ids"), - paddle::Optional("encoder_tile_ids_per_batch"), - paddle::Optional("encoder_num_blocks"), - paddle::Optional("kv_batch_ids"), - paddle::Optional("kv_tile_ids_per_batch"), - paddle::Optional("kv_num_blocks"), - paddle::Optional("decoder_batch_ids"), - paddle::Optional("decoder_tile_ids_per_batch"), - paddle::Optional("decoder_num_blocks"), - paddle::Optional("max_len_kv"), "set_max_lengths"}) - .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int", - "group_size: int", "block_size: int", - "decoder_step_token_num: int"}) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "decoder_batch_ids", + "decoder_tile_ids_per_batch", + "decoder_num_blocks_cpu", + "decoder_num_blocks_device", + "decoder_chunk_size_device", + "max_len_tensor_cpu", + "encoder_batch_ids", + "encoder_tile_ids_per_batch", + "encoder_num_blocks_x_cpu", + "kv_batch_ids", + "kv_tile_ids_per_batch", + "kv_num_blocks_x_cpu", + }) + .Outputs({ + + }) + .Attrs({"encoder_block_shape_q: int", + 
"decoder_block_shape_q: int", + "group_size: int", + "block_size: int"}) .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock)) .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype)); diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 2cba8d547a2..e4d0554fea6 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -12,22 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "encoder_write_cache_with_rope_impl.cuh" #include "helper.h" #include "paddle/extension.h" -#include "paddle/phi/core/memory/memcpy.h" -#include "encoder_write_cache_with_rope_impl.cuh" -#include "paddle/phi/kernels/gpu/flash_attn_v3_kernel.h" #include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/core/memory/memcpy.h" +#include "qwen3_rope.h" #include "remote_cache_kv_ipc.h" -template +template __global__ void GQAVariableLengthRotarySplitKernel( const T *qkv, const float *cos_emb, const float *sin_emb, + const float *q_norm_weight, + const float *k_norm_weight, const int *batch_id_per_token, const int *cu_seqlens_q, - const int *seq_lens, + const int *seq_lens_encoder, const int *seq_lens_decoder, const int *cu_seqlens_k, T *qkv_out, @@ -37,77 +39,328 @@ __global__ void GQAVariableLengthRotarySplitKernel( const int64_t elem_cnt, const int q_num_head, const int kv_num_head, - const int seq_len, - const int last_dim) { + const int max_model_len, + const int head_dim, + const bool rope_3d, + const float rms_norm_eps) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; + using LoadFloat = AlignedVector; LoadT src_vec; LoadEmbT cos_emb_vec; LoadEmbT sin_emb_vec; - int64_t global_thread_idx = blockDim.x * blockIdx.x + 
threadIdx.x; - const int half_lastdim = last_dim / 2; - const int offset = (q_num_head + kv_num_head * 2) * last_dim; - for (int64_t linear_index = global_thread_idx * VecSize, - step = gridDim.x * blockDim.x * VecSize; - linear_index < elem_cnt; - linear_index += step) { - const int token_idx = linear_index / offset; - const int ori_bi = batch_id_per_token[token_idx]; - if (seq_lens[ori_bi] == 0) continue; + LoadFloat tmp_vec; + LoadFloat q_norm_vec, k_norm_vec; + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; + const int half_headdim = head_dim / 2; + const int offset = + (q_num_head + kv_num_head * 2) * head_dim; // for all q,k,v + const int all_head_num = elem_cnt / head_dim; + for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; + gloabl_hi += all_warp_num) { + int64_t linear_index = + gloabl_hi * head_dim + threadIdx.x * VecSize; // 全局index + const int token_idx = + linear_index / offset; // token id(第几个token,不分qkv) + const int ori_bi = batch_id_per_token[token_idx]; // 第几个batch + if (ori_bi == -1) continue; + + int cache_kv_len = seq_lens_decoder[ori_bi]; + // 这里其实是不需要处理的,但是由于FA3的bug,所以必须! 
+ if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0; + const int bias = linear_index % offset; - const int hi = bias / last_dim; - const int h_bias = bias % last_dim; + const int hi = bias / head_dim; + const int h_bias = bias % head_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + + cache_kv_len; // 在当前seq中的id(拼接了seq到一个batch的情况下有效) + const int64_t emb_idx = + ori_seq_id * half_headdim + h_bias / 2; // embedding的id + const int64_t base_idx = + token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim + + h_bias; + Load(&qkv[base_idx], &src_vec); const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id; + int64_t base_split_idx; + T *out_p = nullptr; + if (hi < q_num_head) { + base_split_idx = + token_idx * q_num_head * head_dim + hi * head_dim + h_bias; + out_p = q; + } else if (hi < q_num_head + kv_num_head) { + base_split_idx = kv_write_idx * kv_num_head * head_dim + + (hi - q_num_head) * head_dim + h_bias; + out_p = k; + } else { + out_p = v; + base_split_idx = kv_write_idx * kv_num_head * head_dim + + (hi - q_num_head - kv_num_head) * head_dim + h_bias; + } - const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; + // TODO check this correct or not + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * head_dim * max_model_len : emb_idx; + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + + if (q_norm_weight && k_norm_weight) { + if (hi < q_num_head + kv_num_head) { // only q and k need rope + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + tmp_vec[2 * i] = tmp1; + tmp_vec[2 * i + 1] = tmp2; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + } + } + WelfordWarpAllReduce(thread_m2, &warp_m2); // 单个head的标准差 + + if (hi < q_num_head + kv_num_head) { // only q and k need norm + float row_variance = max(warp_m2 / head_dim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + if (hi < q_num_head) { + Load(&q_norm_weight[threadIdx.x * VecSize], + &q_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + } + } else { + Load(&k_norm_weight[threadIdx.x * VecSize], + &k_norm_vec); + for (int i = 0; i < VecSize; i++) { + src_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + } + } + } + } else { + if (hi < q_num_head + kv_num_head) { + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + src_vec[2 * i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + src_vec[2 * i + 1] = + 
static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); + } + } + } + Store(src_vec, &qkv_out[base_idx]); + Store(src_vec, &out_p[base_split_idx]); + } +} + +template +void gqa_rotary_qk_split_variable( + T *qkv_out, // [token_num, 3, num_head, head_dim] + T *q, + T *k, + T *v, + const T *qkv_input, + const float *rotary_emb, // [2, 1, seq_len, 1, head_dim / 2] + const float *q_norm_weight, + const float *k_norm_weight, + const int *batch_id_per_token, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int *cu_seqlens_q, + const int *cu_seqlens_k, + const int token_num, + const int num_heads, + const int kv_num_heads, + const int max_model_len, + const int input_output_len, + const int head_dim, + const bool rope_3d, + const float rms_norm_eps, + const cudaStream_t &stream) { + assert(head_dim == 128 && "head_dim must be 128"); + int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim; + + constexpr int HEAD_DIM = 128; + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + dim3 block_size(kWarpSize, blocksize / kWarpSize); + + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + input_output_len * head_dim / 2; + launchWithPdlWhenEnabled( + GQAVariableLengthRotarySplitKernel, + grid_size, + block_size, + 0, + stream, + qkv_input, + cos_emb, + sin_emb, + q_norm_weight, + k_norm_weight, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_k, + qkv_out, + q, + k, + v, + elem_nums, + num_heads, + kv_num_heads, + max_model_len, + head_dim, + rope_3d, + rms_norm_eps); +} + +template +__global__ void GQAVariableLengthNeoxPartialRotarySplitKernel( + const T *qkv, + const float *cos_emb, + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens_encoder, + const int 
*seq_lens_decoder, + const int *cu_seqlens_k, + T *qkv_out, + T *q, + T *k, + T *v, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int max_model_len, + const int head_dim, + const int rotary_dim) { + using LoadT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadT src_vec; + LoadT src_vec_right; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; + const int half_rotary_dim = rotary_dim / 2; + const int half_headdim = head_dim / 2; + const int offset = + (q_num_head + kv_num_head * 2) * head_dim; // for all q,k,v + const int all_head_num = elem_cnt / head_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; + gloabl_hi += all_warp_num) { + int64_t linear_index = + gloabl_hi * head_dim + threadIdx.x * VecSize; // 全局index + const int token_idx = + linear_index / offset; // token id(第几个token,不分qkv) + const int ori_bi = batch_id_per_token[token_idx]; // 第几个batch + + int cache_kv_len = seq_lens_decoder[ori_bi]; + // 这里其实是不需要处理的,但是由于FA3的bug,所以必须! 
+ if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0; + + const int bias = linear_index % offset; + const int hi = bias / head_dim; + const int h_bias = bias % head_dim; + + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + + cache_kv_len; // 在当前seq中的id(拼接了seq到一个batch的情况下有效) const int64_t base_idx = - token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + + token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim + h_bias; + Load(&qkv[base_idx], &src_vec); + const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id; int64_t base_split_idx; T *out_p = nullptr; if (hi < q_num_head) { - base_split_idx = token_idx * q_num_head * last_dim + hi * last_dim + h_bias; + base_split_idx = + token_idx * q_num_head * head_dim + hi * head_dim + h_bias; out_p = q; } else if (hi < q_num_head + kv_num_head) { - base_split_idx = kv_write_idx * kv_num_head * last_dim + (hi - q_num_head) * last_dim + h_bias; + base_split_idx = kv_write_idx * kv_num_head * head_dim + + (hi - q_num_head) * head_dim + h_bias; out_p = k; } else { out_p = v; - base_split_idx = kv_write_idx * kv_num_head * last_dim + (hi - q_num_head - kv_num_head) * last_dim + h_bias; + base_split_idx = kv_write_idx * kv_num_head * head_dim + + (hi - q_num_head - kv_num_head) * head_dim + h_bias; } - Load(&qkv[base_idx], &src_vec); - // do rope + if (hi < q_num_head + kv_num_head) { - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + if (h_bias < rotary_dim) { + int64_t emb_idx = ori_seq_id * half_rotary_dim; + if (h_bias < half_rotary_dim) { + Load(&qkv[base_idx + half_rotary_dim], &src_vec_right); + emb_idx += h_bias; + } else { + Load(&qkv[base_idx - half_rotary_dim], &src_vec_right); + emb_idx += h_bias - half_rotary_dim; + } + Load(&cos_emb[emb_idx], &cos_emb_vec); + Load(&sin_emb[emb_idx], &sin_emb_vec); #pragma unroll - for (int i = 0; i < HalfVecSize; i++) { - const float input_left = static_cast(src_vec[2 * i]); - const float input_right = 
static_cast(src_vec[2 * i + 1]); - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); - src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + for (int i = 0; i < VecSize; i++) { + const float input_left = static_cast(src_vec[i]); + const float input_right = static_cast(src_vec_right[i]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + if (h_bias < half_rotary_dim) { + src_vec[i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + } else { + src_vec[i] = + static_cast(fmul_func(input_left, cos_tmp) + + fmul_func(input_right, sin_tmp)); + } + } } } + Store(src_vec, &qkv_out[base_idx]); Store(src_vec, &out_p[base_split_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } -template -void gqa_rotary_qk_split_variable( - T *qkv_out, // [token_num, 3, num_head, dim_head] +template +void gqa_neox_partial_rotary_qk_split_variable( + T *qkv_out, // [token_num, 3, num_head, head_dim] T *q, T *k, T *v, const T *qkv_input, - const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] + const float *rotary_emb, // [2, 1, seq_len, 1, head_dim / 4] const int *batch_id_per_token, const int *seq_lens_encoder, const int *seq_lens_decoder, @@ -116,71 +369,281 @@ void gqa_rotary_qk_split_variable( const int token_num, const int num_heads, const int kv_num_heads, - const int seq_len, - const int input_output_len, - const int dim_head, + const int max_model_len, + const int head_dim, + const int rotary_dim, const cudaStream_t &stream) { - int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head; - constexpr int PackSize = 16 / sizeof(T); + assert(head_dim == 128 && "head_dim must be 128"); + int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim; + + constexpr int HEAD_DIM = 128; + 
constexpr int PackSize = HEAD_DIM / kWarpSize; + assert(rotary_dim / 2 % PackSize == 0); const int pack_num = elem_nums / PackSize; const int blocksize = 128; int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); + dim3 block_size(kWarpSize, blocksize / kWarpSize); const float *cos_emb = rotary_emb; - const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; - GQAVariableLengthRotarySplitKernel - <<>>( - qkv_input, - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens_encoder, - seq_lens_decoder, - cu_seqlens_k, - qkv_out, - q, - k, - v, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head); + const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2; + launchWithPdlWhenEnabled( + GQAVariableLengthNeoxPartialRotarySplitKernel, + grid_size, + block_size, + 0, + stream, + qkv_input, + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_k, + qkv_out, + q, + k, + v, + elem_nums, + num_heads, + kv_num_heads, + max_model_len, + head_dim, + rotary_dim); } template -__global__ void append_dequant_cache_kv_c8( - const CacheT *__restrict__ cache_k, - const CacheT *__restrict__ cache_v, - T *__restrict__ k_out, - T *__restrict__ v_out, - const T *__restrict__ cache_k_dequant_scales, - const T *__restrict__ cache_v_dequant_scales, - const int *__restrict__ seq_lens_this_time, - const int *__restrict__ seq_lens_decoder, - const int *__restrict__ cu_seqlens_k, - const int *__restrict__ block_tables, - const int *batch_ids, - const int *tile_ids_per_batch, - const int max_blocks_per_seq, - const int kv_num_heads) { - // start_kv_idx: 每个block的起始kv_idx - // batch_id:每个block属于的batch - // TODO: 1.scale预取 2.frag_dq_T复用 3.流水线编排 4.store访存合并 5.cacheT支持(int8/fp8) + uint32_t NUM_WARPS = 4> +__global__ void append_cache_kv_c16(const T *__restrict__ cache_k, + const T *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const int *__restrict__ 
seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { + // start_kv_idx: start kv_idx current block + // batch_id:block's batch_id + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT + // with template(int8/fp8) + const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + const uint32_t batch_id = batch_ids[tile_idx]; + const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; + const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; + if (seq_lens_this_time[batch_id] <= 0) { + return; + } + + const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq; + uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE]; + // cache_kv idx + uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + uint32_t block_stride = kv_num_heads * kv_h_stride; + const CacheT *cur_cache_k = + cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = + cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + + // k_out v_out idx + uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; + T *k_write_ptr = + k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = + v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + + uint32_t kv_frag[4]; + T *frag_dq_T = reinterpret_cast(kv_frag); + + constexpr uint32_t num_vecs_per_head = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t inv_kv_stride = 8 / num_vecs_per_head; + + extern __shared__ uint8_t smem[]; + smem_t k_smem(smem); + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, 
(tid % 16) / 8); + + uint32_t k_read_idx = + (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); + + // load k_smem 64 rows 128 cols + for (int fz = 0; fz < 4; + fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter + k_smem.load_128b_async( + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = k_smem.advance_offset_by_column<8, num_vecs_per_head>( + k_smem_offset_w, fy); + k_read_idx += 8 * num_elems_per_128b(); + } + k_smem_offset_w = + k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>( + k_smem_offset_w) - + 16; + k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b(); + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal k_smem 64 rows 128 cols + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + uint32_t row_idx = wid * 16 + tid / 4; + for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter + uint32_t col_idx = fy * 16 + tid % 4 * 2; + k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); + // layout + /*** + r0c0,r0c1, r0c8,r0c9 + r8c0,r8c1, r8c8,r8c9 + ***/ + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; + + if (row_idx < end_idx) { + k_tile_ptr0[0] = frag_dq_T[0]; + k_tile_ptr0[1] = frag_dq_T[1]; + k_tile_ptr0[8] = frag_dq_T[2]; + k_tile_ptr0[9] = frag_dq_T[3]; + } + + if (row_idx + 8 < end_idx) { + k_tile_ptr1[0] = frag_dq_T[4]; + k_tile_ptr1[1] = frag_dq_T[5]; + k_tile_ptr1[8] = frag_dq_T[6]; + k_tile_ptr1[9] = frag_dq_T[7]; + } + k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>( + k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>( + k_smem_offset_r) - + 16; + } + + // ================v================ + smem_t v_smem(smem + BLOCK_SIZE * 
HEAD_DIM * sizeof(CacheT)); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_read_idx = + (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); + + // load v_smem 64 rows 128 cols + for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 + // warps once, need 4 iter + for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter + v_smem.load_128b_async( + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w = v_smem.advance_offset_by_column<8, num_vecs_per_head>( + v_smem_offset_w, fy); + v_read_idx += 8 * num_elems_per_128b(); + } + v_smem_offset_w = + v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>( + v_smem_offset_w) - + 16; + v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b(); + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal v_smem 64 rows 128 cols + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + uint32_t row_idx = wid * 16 + tid / 4; + for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter + uint32_t col_idx = fy * 16 + tid % 4 * 2; + v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag); + // layout + /*** + r0c0,r0c1, r0c8,r0c9 + r8c0,r8c1, r8c8,r8c9 + ***/ + T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; + T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride; + + if (row_idx < end_idx) { + v_tile_ptr0[0] = frag_dq_T[0]; + v_tile_ptr0[1] = frag_dq_T[1]; + v_tile_ptr0[8] = frag_dq_T[2]; + v_tile_ptr0[9] = frag_dq_T[3]; + } + + if (row_idx + 8 < end_idx) { + v_tile_ptr1[0] = frag_dq_T[4]; + v_tile_ptr1[1] = frag_dq_T[5]; + v_tile_ptr1[8] = frag_dq_T[6]; + v_tile_ptr1[9] = frag_dq_T[7]; + } + v_smem_offset_r = v_smem.advance_offset_by_column<2, 
num_vecs_per_head>( + v_smem_offset_r, fy); + } + v_smem_offset_r = + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>( + v_smem_offset_r) - + 16; + } +} + +template +__global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k, + const CacheT *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const T *__restrict__ cache_k_quant_scales, + const T *__restrict__ cache_v_quant_scales, + const T *__restrict__ cache_k_dequant_scales, + const T *__restrict__ cache_v_dequant_scales, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { + // start_kv_idx: start kv_idx current block + // batch_id:block's batch_id + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT + // with template(int8/fp8) const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; const uint32_t batch_id = batch_ids[tile_idx]; const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; - if (seq_lens_this_time <= 0) { + if (seq_lens_this_time[batch_id] <= 0) { return; } @@ -189,18 +652,33 @@ __global__ void append_dequant_cache_kv_c8( // cache_kv idx uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; uint32_t block_stride = kv_num_heads * kv_h_stride; - const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; - const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_k = + cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = + cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + const T *cur_cache_k_scales; + const T 
*cur_cache_v_scales; + T cache_k_scale = 0; + T cache_v_scale = 0; + if constexpr (dynamic_quant) { + cur_cache_k_scales = cache_k_quant_scales + + (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE; + cur_cache_v_scales = cache_v_quant_scales + + (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE; + } else { + cache_k_scale = cache_k_dequant_scales[kv_head_idx]; + cache_v_scale = cache_v_dequant_scales[kv_head_idx]; + } // k_out v_out idx uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; - T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前k block起始指针 - T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; // 当前v block起始指针 + T *k_write_ptr = + k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = + v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; uint32_t k_frag[4], v_frag[4], frag_dq[4]; T *frag_dq_T = reinterpret_cast(frag_dq); - T cache_k_scale = cache_k_dequant_scales[kv_head_idx]; - T cache_v_scale = cache_v_dequant_scales[kv_head_idx]; constexpr uint32_t num_vecs_per_head_k = HEAD_DIM / num_elems_per_128b(); @@ -211,94 +689,116 @@ __global__ void append_dequant_cache_kv_c8( extern __shared__ uint8_t smem[]; smem_t k_smem(smem); - uint32_t k_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + - tid % 8 * num_elems_per_128b(); + uint32_t k_read_idx = + (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); - // load k_smem 行是64 列是128 - for (int fz = 0; fz < 4; fz++) { // 每个warp1次4行,循环4次16行,4个warp64行 - for (int fy = 0; 
fy < 1; fy++) { // 一次8个128b = 128个uint8 + // load v_smem 64 rows, 128 cols + for (int fz = 0; fz < 4; + fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 1; + fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter k_smem.load_128b_async( - k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); - k_smem_offset_w = - k_smem.advance_offset_by_column<8, num_vecs_per_head_k>(k_smem_offset_w, fy); + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = k_smem.advance_offset_by_column<8, num_vecs_per_head_k>( + k_smem_offset_w, fy); k_read_idx += 8 * num_elems_per_128b(); } k_smem_offset_w = - k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_w) - 8; + k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_w) - + 8; k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b(); } commit_group(); wait_group<0>(); __syncthreads(); - // deal k_smem 行是64 列是128 - for (int fz = 0; fz < 1; fz++) { // 每个warp1次16行,4个warp64行 + // deal k_smem 64 rows, 128 cols + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 4; fy++) { // 1次2个128b(32个uint8),4次循环8个128b(128个uint8) + for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter uint32_t col_idx = fy * 32 + tid % 4 * 2; k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag); - // 反量化 存储 + // layout /*** r0c0,r0c1,r0c8,r0c9, r8c0,r8c1,r8c8,r8c9 r0c16,r0c17,r0c24,r0c25, r8c16,r8c17,r8c24,r8c25 ***/ for (int i = 0; i < 4 / 2; i++) { - T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; + T cache_k_scale_0 = cache_k_scale; + T cache_k_scale_1 = cache_k_scale; + if constexpr (dynamic_quant) { + cache_k_scale_0 = 
cur_cache_k_scales[row_idx]; + cache_k_scale_1 = cur_cache_k_scales[row_idx + 8]; + } if (row_idx < end_idx) { - convert_c8(frag_dq_T,k_frag[2 * i]); // 4个uint8/fp8 -> 4个T - - k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale; - k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale; - k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale; - k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale; + convert_c8(frag_dq_T, + k_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T + k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_0; + k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_0; + k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_0; + k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_0; } if (row_idx + 8 < end_idx) { - convert_c8(frag_dq_T + 4,k_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T - - k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale; - k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale; - k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale; - k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale; + convert_c8(frag_dq_T + 4, + k_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T + k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_1; + k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_1; + k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_1; + k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_1; } col_idx += 16; } k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head_k>( - k_smem_offset_r, fy); + k_smem_offset_r, fy); } k_smem_offset_r = - k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 8; + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_r) - + 8; } - // ================v================ + // ================v================ smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); - uint32_t v_smem_offset_w = smem_t::get_permuted_offset( - wid * 8 + tid / 4, tid % 4); // 4 * 8 per warp + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 4 * 8 per warp - uint32_t v_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * 
(tid / 16) + tid % 8, (tid % 16) / 8); + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - uint32_t v_read_idx = (wid * 8 + tid / 4) * BLOCK_SIZE + - tid % 4 * num_elems_per_128b(); - // load v_smem 行是128 列是64 - for (int fy = 0; fy < 4; fy++) { // 每个warp1次8行,循环4次32行,4个warp128行 - for (int fz = 0; fz < 1; fz++) { // 一次4个128b = 64个uint8 + uint32_t v_read_idx = + (wid * 8 + tid / 4) * BLOCK_SIZE + tid % 4 * num_elems_per_128b(); + // load v_smem 128 rows 64 cols + for (int fy = 0; fy < 4; + fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter + for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter v_smem.load_128b_async( - v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); v_smem_offset_w = - v_smem.advance_offset_by_column<4, num_vecs_per_blocksize>(v_smem_offset_w, fz); + v_smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + v_smem_offset_w, fz); v_read_idx += 4 * num_elems_per_128b(); } v_smem_offset_w = - v_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_w) - 4; + v_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_w) - + 4; v_read_idx += 8 * NUM_WARPS * BLOCK_SIZE - 4 * num_elems_per_128b(); } @@ -306,156 +806,618 @@ __global__ void append_dequant_cache_kv_c8( wait_group<0>(); __syncthreads(); - // deal v_smem 行是128 列是64 row_idx是head_dim, col_idx是block_size - for (int fy = 0; fy < 2; fy++) { // 每个warp1次16行,循环2次32行,4个warp128行 + // deal v_smem 128 rows 64 cols + for (int fy = 0; fy < 2; + fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4; - for (int fz = 0; fz < 2; fz++) { // 1次2个128b(32个uint8),2次循环4个128b(64个uint8) + for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter uint32_t kv_idx = fz * 32 + tid % 4 * 2; 
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag); - // 反量化 存储 + // layout for (int i = 0; i < 4 / 2; i++) { - T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx; + T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + dim_idx; T *v_tile_ptr1 = v_tile_ptr0 + 8; + T cache_v_scale_0 = cache_v_scale; + T cache_v_scale_1 = cache_v_scale; + T cache_v_scale_2 = cache_v_scale; + T cache_v_scale_3 = cache_v_scale; + + if constexpr (dynamic_quant) { + cache_v_scale_0 = cur_cache_v_scales[kv_idx]; + cache_v_scale_1 = cur_cache_v_scales[kv_idx + 1]; + cache_v_scale_2 = cur_cache_v_scales[kv_idx + 2]; + cache_v_scale_3 = cur_cache_v_scales[kv_idx + 3]; + } + convert_c8(frag_dq_T, + v_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T + convert_c8(frag_dq_T + 4, + v_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T if (kv_idx < end_idx) { - convert_c8(frag_dq_T, v_frag[2 * i]); // 4个uint8/fp8 -> 4个T -#ifdef C8_DEBUG - if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) { - printf("1.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n", - fy, fz, kv_idx, dim_idx, static_cast(frag_dq_T[0]), static_cast(frag_dq_T[1]), - static_cast(frag_dq_T[2]), static_cast(frag_dq_T[3])); - } -#endif - v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale; - v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale; - v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale; - v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale; - - - convert_c8(frag_dq_T + 4, v_frag[2 * i + 1]); // 4个uint8/fp8 -> 4个T -#ifdef C8_DEBUG - if (tid == 0 && wid == 0 && tile_idx == 0 && kv_head_idx == 0) { - printf("2.fy: %d, fz:%d, row_idx: %d, col_idx: %d, v_frag: %.f, %.f, %.f, %.f \n", - fy, fz, kv_idx, dim_idx + 8, static_cast(frag_dq_T[4]), static_cast(frag_dq_T[5]), - static_cast(frag_dq_T[6]), static_cast(frag_dq_T[7])); - } -#endif - v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale; - v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * 
cache_v_scale; - v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale; - v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale; + v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_0; + v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_0; + } + if (kv_idx + 1 < end_idx) { + v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_1; + v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_1; + } + if (kv_idx + 8 < end_idx) { + v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_2; + v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_2; + } + if (kv_idx + 9 < end_idx) { + v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_3; + v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_3; } kv_idx += 16; } - v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( - v_smem_offset_r, fz); + v_smem_offset_r = + v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + v_smem_offset_r, fz); } v_smem_offset_r = - v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_r) - 4; + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_r) - + 4; + } +} + +template +__global__ void append_cache_kv_c4(const CacheT *__restrict__ cache_k, + const CacheT *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const T *__restrict__ cache_k_dequant_scales, + const T *__restrict__ cache_v_dequant_scales, + const T *__restrict__ cache_k_zero_point, + const T *__restrict__ cache_v_zero_point, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { + // start_kv_idx: start kv_idx current block + // batch_id:block's batch_id + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT + // with 
template(int8/fp8) + const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + const uint32_t batch_id = batch_ids[tile_idx]; + const uint32_t start_kv_idx = tile_ids_per_batch[tile_idx] * BLOCK_SIZE; + const uint32_t end_idx = seq_lens_decoder[batch_id] - start_kv_idx; + if (seq_lens_this_time[batch_id] <= 0) { + return; + } + + const int *cur_block_table = block_tables + batch_id * max_blocks_per_seq; + uint32_t block_id = cur_block_table[start_kv_idx / BLOCK_SIZE]; + if (block_id < 0) block_id = 0; + + constexpr uint32_t HEAD_DIM_HALF = HEAD_DIM / 2; + constexpr uint32_t BLOCK_SIZE_HALF = BLOCK_SIZE / 2; + // cache_kv idx + uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM_HALF; + uint32_t block_stride = kv_num_heads * kv_h_stride; + const CacheT *cur_cache_k = + cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = + cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + + // k_out v_out idx + uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; + T *k_write_ptr = + k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = + v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + + extern __shared__ uint8_t smem[]; + + uint32_t k_frag[4], v_frag[4], frag_dq[8]; + T *frag_dq_T = reinterpret_cast(frag_dq); + + // load dequant scales and zero points + const T *cache_k_scale_now = cache_k_dequant_scales + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_dequant_scales + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + T *cache_k_scale_smem = + reinterpret_cast(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; 
+#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i] + static_cast(136.f); + cache_v_scale_smem[i] = cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i] + static_cast(136.f); + } + + smem_t k_smem(smem); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM_HALF / num_elems_per_128b(); // 2 + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE_HALF / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; // 4 + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 2(iter) * 4(warp) * 8 row per warp + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); // + + uint32_t k_read_idx = (wid * 8 + tid / 4) * HEAD_DIM / 2 + + tid % 4 * num_elems_per_128b(); + + // load k_smem 64 rows 128 cols + for (int fz = 0; fz < 2; + fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter + k_smem.load_128b_async( + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = k_smem.advance_offset_by_column<4, num_vecs_per_head_k>( + k_smem_offset_w, fy); + k_read_idx += 4 * num_elems_per_128b(); + } + k_smem_offset_w = + k_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_w) - + 4; + k_read_idx += + 8 * NUM_WARPS * HEAD_DIM / 2 - 4 * num_elems_per_128b(); + } + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal k_smem 64 rows 128 cols + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + uint32_t row_idx = wid * 16 + tid / 4; + for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter + uint32_t col_idx = fy * 64 + tid % 4 * 2; + 
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag); + + for (int i = 0; i < 2; i++) { + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; + convert_int4(frag_dq_T, k_frag[2 * i]); + convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]); + + if (row_idx < end_idx) { + k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * + cache_k_scale_smem[col_idx]; + k_tile_ptr0[1] = + (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * + cache_k_scale_smem[col_idx + 1]; + k_tile_ptr0[8] = + (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * + cache_k_scale_smem[col_idx + 8]; + k_tile_ptr0[9] = + (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * + cache_k_scale_smem[col_idx + 9]; + k_tile_ptr0[16] = + (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * + cache_k_scale_smem[col_idx + 16]; + k_tile_ptr0[17] = + (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * + cache_k_scale_smem[col_idx + 17]; + k_tile_ptr0[24] = + (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * + cache_k_scale_smem[col_idx + 24]; + k_tile_ptr0[25] = + (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * + cache_k_scale_smem[col_idx + 25]; + } + + if (row_idx + 8 < end_idx) { + k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * + cache_k_scale_smem[col_idx]; + k_tile_ptr1[1] = + (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * + cache_k_scale_smem[col_idx + 1]; + k_tile_ptr1[8] = + (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * + cache_k_scale_smem[col_idx + 8]; + k_tile_ptr1[9] = + (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * + cache_k_scale_smem[col_idx + 9]; + k_tile_ptr1[16] = + (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * + cache_k_scale_smem[col_idx + 16]; + k_tile_ptr1[17] = + (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * + cache_k_scale_smem[col_idx + 17]; + k_tile_ptr1[24] = + (frag_dq_T[14] - 
cache_k_zero_point_smem[col_idx + 24]) * + cache_k_scale_smem[col_idx + 24]; + k_tile_ptr1[25] = + (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * + cache_k_scale_smem[col_idx + 25]; + } + col_idx += 32; + } + k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head_k>( + k_smem_offset_r, fy); + } + k_smem_offset_r = + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_r) - + 4; + } + + // ================v================ + smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT) / 2); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); // 4 * 8 per warp + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_read_idx = (wid * 16 + tid / 2) * BLOCK_SIZE_HALF + + tid % 2 * num_elems_per_128b(); + // load v_smem 128 rows 64 rows + for (int fy = 0; fy < 2; + fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter + for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter + v_smem.load_128b_async( + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w = + v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + v_smem_offset_w, fz); + v_read_idx += 2 * num_elems_per_128b(); + } + v_smem_offset_w = + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_w) - + 2; + v_read_idx += + 16 * NUM_WARPS * BLOCK_SIZE_HALF - 2 * num_elems_per_128b(); + } + + commit_group(); + wait_group<0>(); + __syncthreads(); + + // deal v_smem 128 rows 64 cols + for (int fy = 0; fy < 2; + fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter + uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4; + for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter + uint32_t kv_idx = fz * 64 + tid % 4 * 2; + v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag); + // layout + for (int i = 0; i < 2; 
i++) { + T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + dim_idx; + T *v_tile_ptr1 = v_tile_ptr0 + 8; + + convert_int4(frag_dq_T, v_frag[2 * i]); + convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]); + if (kv_idx < end_idx) { + v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[0] = + (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + if (kv_idx + 1 < end_idx) { + v_tile_ptr0[kv_t_stride] = + (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[kv_t_stride] = + (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + if (kv_idx + 8 < end_idx) { + v_tile_ptr0[8 * kv_t_stride] = + (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[8 * kv_t_stride] = + (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + if (kv_idx + 9 < end_idx) { + v_tile_ptr0[9 * kv_t_stride] = + (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[9 * kv_t_stride] = + (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + if (kv_idx + 16 < end_idx) { + v_tile_ptr0[16 * kv_t_stride] = + (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[16 * kv_t_stride] = + (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + if (kv_idx + 17 < end_idx) { + v_tile_ptr0[17 * kv_t_stride] = + (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[17 * kv_t_stride] = + (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + if (kv_idx + 24 < end_idx) { + v_tile_ptr0[24 * kv_t_stride] = + (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * + 
cache_v_scale_smem[dim_idx]; + v_tile_ptr1[24 * kv_t_stride] = + (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + if (kv_idx + 25 < end_idx) { + v_tile_ptr0[25 * kv_t_stride] = + (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[25 * kv_t_stride] = + (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; + } + kv_idx += 32; + } + v_smem_offset_r = + v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + v_smem_offset_r, fz); + } + v_smem_offset_r = + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_r) - + 2; } } template -void AppendDequantCache( - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::Tensor &cache_k_dequant_scales, - const paddle::Tensor &cache_v_dequant_scales, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &cu_seqlens_k, - const paddle::Tensor &block_tables, - const paddle::Tensor &cache_batch_ids, - const paddle::Tensor &cache_tile_ids_per_batch, - const paddle::Tensor &cache_num_blocks_x, - const int max_blocks_per_seq, - const int kv_num_heads, - const std::string &cache_quant_type, - paddle::Tensor *k_out, - paddle::Tensor *v_out, - const cudaStream_t& stream -) { +void AppendCacheKV( + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &cache_k_quant_scales, + const paddle::optional &cache_v_quant_scales, + const paddle::optional &cache_k_dequant_scales, + const paddle::optional &cache_v_dequant_scales, + const paddle::Tensor &cache_k_zp, + const paddle::Tensor &cache_v_zp, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &cu_seqlens_k, + const paddle::Tensor &block_tables, + const paddle::Tensor &cache_batch_ids, + const paddle::Tensor &cache_tile_ids_per_batch, + const 
paddle::Tensor &cache_num_blocks_x, + const int max_blocks_per_seq, + const int kv_num_heads, + const std::string &cache_quant_type, + paddle::Tensor *k_out, + paddle::Tensor *v_out, + const cudaStream_t &stream) { using NV_TYPE = typename cascade_attn_type_traits::type; - if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { - constexpr int NUM_WARPS = 4; - int block_num = cache_num_blocks_x.data()[0]; - dim3 grids(block_num, 1, kv_num_heads); - dim3 blocks(32, NUM_WARPS); + constexpr int NUM_WARPS = 4; + int block_num = cache_num_blocks_x.data()[0]; + dim3 grids(block_num, 1, kv_num_heads); + dim3 blocks(32, NUM_WARPS); + if (cache_quant_type == "none") { + const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2; + auto kernel_func = + append_cache_kv_c16; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + launchWithPdlWhenEnabled( + kernel_func, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads); + } else if (cache_quant_type == "cache_int8" || + cache_quant_type == "cache_fp8" || + cache_quant_type == "block_wise_fp8") { const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2; - auto kernel_func = append_dequant_cache_kv_c8; + auto kernel_func = append_cache_kv_c8; if (cache_quant_type == "cache_fp8") { - kernel_func = append_dequant_cache_kv_c8; + kernel_func = append_cache_kv_c8; + } else if (cache_quant_type == "block_wise_fp8") { + kernel_func = append_cache_kv_c8; } if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(kernel_func, - cudaFuncAttributeMaxDynamicSharedMemorySize, - 
smem_size); + cudaFuncSetAttribute( + kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - kernel_func<<>>( - cache_k.data(), - cache_v.data(), - reinterpret_cast(k_out->data()), - reinterpret_cast(v_out->data()), - reinterpret_cast(const_cast(cache_k_dequant_scales.data())), - reinterpret_cast(const_cast(cache_v_dequant_scales.data())), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - cu_seqlens_k.data(), - block_tables.data(), - cache_batch_ids.data(), - cache_tile_ids_per_batch.data(), - max_blocks_per_seq, - kv_num_heads - ); + launchWithPdlWhenEnabled( + kernel_func, + grids, + blocks, + smem_size, + stream, + cache_k.data(), + cache_v.data(), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + cache_k_quant_scales ? reinterpret_cast(const_cast( + cache_k_quant_scales.get().data())) + : nullptr, + cache_v_quant_scales ? reinterpret_cast(const_cast( + cache_v_quant_scales.get().data())) + : nullptr, + cache_k_dequant_scales ? reinterpret_cast(const_cast( + cache_k_dequant_scales.get().data())) + : nullptr, + cache_v_dequant_scales ? 
reinterpret_cast(const_cast( + cache_v_dequant_scales.get().data())) + : nullptr, + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads); + } else if (cache_quant_type == "cache_int4_zp") { + const uint32_t smem_size = + BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) + 4 * HEAD_DIM * sizeof(T); + + auto kernel_func = + append_cache_kv_c4; + + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + launchWithPdlWhenEnabled( + kernel_func, + grids, + blocks, + smem_size, + stream, + cache_k.data(), + cache_v.data(), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + reinterpret_cast( + const_cast(cache_k_dequant_scales.get().data())), + reinterpret_cast( + const_cast(cache_v_dequant_scales.get().data())), + reinterpret_cast(const_cast(cache_k_zp.data())), + reinterpret_cast(const_cast(cache_v_zp.data())), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads); } else { PADDLE_THROW("%s mode isn't implemented yet", cache_quant_type.c_str()); } } std::vector GQARopeWriteCacheKernel( - const paddle::Tensor& qkv, - const paddle::Tensor& key_cache, - const paddle::Tensor& value_cache, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& cu_seqlens_k, - const paddle::Tensor& rotary_embs, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& block_tables, - const paddle::Tensor& kv_batch_ids, - const paddle::Tensor& kv_tile_ids, - const paddle::Tensor& kv_num_blocks, - const paddle::Tensor& cache_batch_ids, - const paddle::Tensor& cache_tile_ids, - 
const paddle::Tensor& cache_num_blocks, - const paddle::optional& cache_k_quant_scales, - const paddle::optional& cache_v_quant_scales, - const paddle::optional& cache_k_dequant_scales, - const paddle::optional& cache_v_dequant_scales, - const paddle::optional& cache_k_zp, - const paddle::optional& cache_v_zp, - const paddle::optional& kv_signal_data, + const paddle::Tensor &qkv, + const paddle::Tensor &key_cache, + const paddle::Tensor &value_cache, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &cu_seqlens_k, + const paddle::Tensor &rotary_embs, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &block_tables, + const paddle::Tensor &kv_batch_ids, + const paddle::Tensor &kv_tile_ids, + const paddle::Tensor &kv_num_blocks, + const paddle::Tensor &cache_batch_ids, + const paddle::Tensor &cache_tile_ids, + const paddle::Tensor &cache_num_blocks, + const paddle::optional &q_norm_weight, + const paddle::optional &k_norm_weight, + const paddle::optional &cache_k_quant_scales, + const paddle::optional &cache_v_quant_scales, + const paddle::optional &cache_k_dequant_scales, + const paddle::optional &cache_v_dequant_scales, + const paddle::optional &cache_k_zp, + const paddle::optional &cache_v_zp, + const paddle::optional &kv_signal_data, const int kv_token_num, const int max_seq_len, - const std::string& cache_quant_type) { + const float rms_norm_eps, + const bool use_neox_rotary_style, + const std::string &cache_quant_type, + const bool rope_3d) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; const int kv_num_blocks_data = kv_num_blocks.data()[0]; - const auto& qkv_dims = qkv.dims(); - const auto& key_cache_dims = key_cache.dims(); + const auto &qkv_dims = qkv.dims(); + const auto &key_cache_dims = key_cache.dims(); const int token_num = 
qkv_dims[0]; const int max_blocks_per_seq = block_tables.dims()[1]; const int block_size = key_cache.dims()[2]; const int batch_size = seq_lens_this_time.dims()[0]; const int kv_num_heads = key_cache_dims[1]; - const int head_dim = key_cache_dims[3]; - const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads; + const int head_dim = cache_quant_type == "cache_int4_zp" + ? key_cache_dims[3] * 2 + : key_cache_dims[3]; + const int num_heads = + qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads; const float softmax_scale = 1.f / sqrt(head_dim); + int rotary_dim = head_dim; + + PADDLE_ENFORCE_EQ(batch_id_per_token.dims().size(), 1); + PADDLE_ENFORCE_EQ(batch_id_per_token.dims()[0], token_num); + + if (!rope_3d) { + PADDLE_ENFORCE_EQ(rotary_embs.dims().size(), 5); + PADDLE_ENFORCE_EQ(rotary_embs.dims()[0], 2); + PADDLE_ENFORCE_EQ(rotary_embs.dims()[1], 1); + PADDLE_ENFORCE_EQ(rotary_embs.dims()[2], max_seq_len); + PADDLE_ENFORCE_EQ(rotary_embs.dims()[3], 1); + if (use_neox_rotary_style) { + // Note(ZKK) Qwen3 like model + // the [0,head_dim/2), [head_dim/2,head_dim) data are totally same! 
+ if (rotary_embs.dims()[4] == head_dim) { + rotary_dim = head_dim; + } else { + // for glm partial rotary style + PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 4); + rotary_dim = head_dim / 2; + } + } else { + PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 2); + } + } AppendAttnMetaData meta_data; meta_data.token_nums = token_num; @@ -466,68 +1428,157 @@ std::vector GQARopeWriteCacheKernel( meta_data.block_size = block_size; meta_data.batch_size = seq_lens_this_time.dims()[0]; - phi::GPUContext* dev_ctx = static_cast(phi::DeviceContextPool::Instance().Get(qkv.place())); - auto stream = qkv.stream(); - paddle::Tensor qkv_out = GetEmptyTensor( - qkv.dims(), - qkv.dtype(), - qkv.place()); + paddle::Tensor qkv_out = GetEmptyTensor(qkv.dims(), qkv.dtype(), qkv.place()); paddle::Tensor q = GetEmptyTensor( - {token_num, num_heads, head_dim}, - qkv.dtype(), - qkv.place()); + {token_num, num_heads, head_dim}, qkv.dtype(), qkv.place()); paddle::Tensor k = GetEmptyTensor( - {kv_token_num, kv_num_heads, head_dim}, - qkv.dtype(), - qkv.place()); + {kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place()); paddle::Tensor v = GetEmptyTensor( - {kv_token_num, kv_num_heads, head_dim}, - qkv.dtype(), - qkv.place()); - - // rope - gqa_rotary_qk_split_variable( - qkv_out.data(), - q.data(), - k.data(), - v.data(), - qkv.data(), - rotary_embs.data(), - batch_id_per_token.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rotary_embs.dims()[2], - head_dim, - stream); + {kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place()); + + bool enforce_fmul_rn = getEnvEnableRL(); + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { + if (use_neox_rotary_style) { + if (rotary_dim == head_dim) { + gqa_rotary_qk_split_variable_qwen3( + qkv_out.data(), + q.data(), + k.data(), + v.data(), + qkv.data(), + rotary_embs.data(), + batch_id_per_token.data(), + 
seq_lens_encoder.data(), + seq_lens_decoder.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + token_num, + num_heads, + kv_num_heads, + rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2], + head_dim, + rope_3d, + stream); + } else { + gqa_neox_partial_rotary_qk_split_variable( + qkv_out.data(), + q.data(), + k.data(), + v.data(), + qkv.data(), + rotary_embs.data(), + batch_id_per_token.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + head_dim, + rotary_dim, + stream); + } + } else { + gqa_rotary_qk_split_variable( + qkv_out.data(), + q.data(), + k.data(), + v.data(), + qkv.data(), + rotary_embs.data(), + q_norm_weight ? q_norm_weight.get().data() : nullptr, + k_norm_weight ? k_norm_weight.get().data() : nullptr, + batch_id_per_token.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rope_3d ? 
rotary_embs.dims()[3] : rotary_embs.dims()[2], + head_dim, + rope_3d, + rms_norm_eps, + stream); + } + }) + + if (token_num < kv_token_num) { + AppendCacheKV(key_cache, + value_cache, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp.get(), + cache_v_zp.get(), + seq_lens_this_time, + seq_lens_decoder, + cu_seqlens_k, + block_tables, + cache_batch_ids, + cache_tile_ids, + cache_num_blocks, + max_blocks_per_seq, + kv_num_heads, + cache_quant_type, + &k, + &v, + stream); + } // write cache if (cache_quant_type == "none") { CascadeAppendWriteCacheKVQKV( - meta_data, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens_encoder, - seq_lens_decoder, - max_seq_len, - stream, - const_cast(&key_cache), - const_cast(&value_cache)); - } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { + meta_data, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + max_seq_len, + stream, + const_cast(&key_cache), + const_cast(&value_cache)); + } else if (cache_quant_type == "cache_int8" || + cache_quant_type == "cache_fp8" || + cache_quant_type == "block_wise_fp8") { CascadeAppendWriteCacheKVC8QKV( meta_data, - *const_cast(&key_cache), - *const_cast(&value_cache), + *const_cast(&key_cache), + *const_cast(&value_cache), + qkv_out, + cache_k_quant_scales.get(), + cache_v_quant_scales.get(), + seq_lens_this_time, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_batch_ids, + kv_tile_ids, + kv_num_blocks_data, + max_seq_len, + false, // is_scale_channel_wise + cache_quant_type, + stream, + const_cast(&key_cache), + const_cast(&value_cache)); + } else if (cache_quant_type == "cache_int4_zp") { + CascadeAppendWriteCacheKVC4QKV( + meta_data, + *const_cast(&key_cache), + *const_cast(&value_cache), qkv_out, cache_k_quant_scales.get(), cache_v_quant_scales.get(), + cache_k_zp.get(), + cache_v_zp.get(), 
seq_lens_this_time, seq_lens_decoder, batch_id_per_token, @@ -537,52 +1588,38 @@ std::vector GQARopeWriteCacheKernel( kv_tile_ids, kv_num_blocks_data, max_seq_len, - false, // is_scale_channel_wise - cache_quant_type == "cache_fp8", // is_fp8 stream, - const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&key_cache), + const_cast(&value_cache)); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, cache_fp8, " + "cache_int4_zp]"); } - const char* fmt_write_cache_completed_signal_str = std::getenv("FLAGS_fmt_write_cache_completed_signal"); - const char* FLAGS_use_pd_disaggregation_per_chunk = std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); + const char *fmt_write_cache_completed_signal_str = + std::getenv("FLAGS_fmt_write_cache_completed_signal"); + const char *FLAGS_use_pd_disaggregation_per_chunk = + std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); if (fmt_write_cache_completed_signal_str && (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { - if (FLAGS_use_pd_disaggregation_per_chunk && - (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || - std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { - cudaLaunchHostFunc(qkv.stream(), - &(RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query), - (void*)nullptr); - } else { - if (kv_signal_data) { - cudaLaunchHostFunc(qkv.stream(), - &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, - (void*)(const_cast(kv_signal_data.get().data()))); - } + if (FLAGS_use_pd_disaggregation_per_chunk && + (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || + std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { + cudaLaunchHostFunc( + qkv.stream(), + &(RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_per_query), + (void *)nullptr); + } else { + if (kv_signal_data) { + cudaLaunchHostFunc( + qkv.stream(), + 
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, + (void *)(const_cast( + kv_signal_data.get().data()))); } - } - - if (token_num < kv_token_num) { - AppendDequantCache( - key_cache, - value_cache, - cache_k_dequant_scales.get(), - cache_v_dequant_scales.get(), - seq_lens_this_time, - seq_lens_decoder, - cu_seqlens_k, - block_tables, - cache_batch_ids, - cache_tile_ids, - cache_num_blocks, - max_blocks_per_seq, - kv_num_heads, - cache_quant_type, - &k, - &v, - stream - ); + } } return {q, k, v, qkv_out}; } @@ -605,6 +1642,8 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache) "cache_batch_ids", "cache_tile_ids_per_batch", "cache_num_blocks", + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight"), paddle::Optional("cache_k_quant_scales"), paddle::Optional("cache_v_quant_scales"), paddle::Optional("cache_k_dequant_scales"), @@ -612,15 +1651,13 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache) paddle::Optional("cache_k_zp"), paddle::Optional("cache_v_zp"), paddle::Optional("kv_signal_data")}) - .Outputs({"q", - "k", - "v", - "qkv_out", - "key_cache_out", - "value_cache_out"}) + .Outputs({"q", "k", "v", "qkv_out", "key_cache_out", "value_cache_out"}) .SetInplaceMap({{"key_cache", "key_cache_out"}, {"value_cache", "value_cache_out"}}) .Attrs({"kv_token_num: int", "max_seq_len: int", - "cache_quant_type: std::string"}) + "rms_norm_eps: float", + "use_neox_rotary_style: bool", + "cache_quant_type: std::string", + "rope_3d: bool"}) .SetKernelFn(PD_KERNEL(GQARopeWriteCacheKernel)); diff --git a/custom_ops/gpu_ops/append_attn/mem_util.cuh b/custom_ops/gpu_ops/append_attn/mem_util.cuh index 89b65992d6a..25c7a623881 100644 --- a/custom_ops/gpu_ops/append_attn/mem_util.cuh +++ b/custom_ops/gpu_ops/append_attn/mem_util.cuh @@ -15,6 +15,7 @@ #include #include +#include enum class SharedMemFillMode { kFillZero, kNoFill }; @@ -42,18 +43,35 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R, } __device__ __forceinline__ void commit_group() { 
+#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + {} +#else asm volatile("cp.async.commit_group;\n" ::); +#endif } template __device__ __forceinline__ void wait_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + cooperative_groups::wait(cooperative_groups::this_thread_block()); +#else asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif } template __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } +#else if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { asm volatile( "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( @@ -68,6 +86,7 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { "n"(16), "r"(16)); } +#endif } template @@ -76,6 +95,32 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, bool predicate) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 
16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } + } +#else if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 16 : 0; if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { @@ -115,6 +160,7 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, "n"(16)); } } +#endif } template @@ -123,6 +169,18 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr, bool predicate) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 8); + } + } +#else if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 
8 : 0; asm volatile( @@ -141,6 +199,7 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr, "l"(gmem_ptr), "n"(8)); } +#endif } template @@ -149,6 +208,18 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr, bool predicate) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 4); + } + } +#else if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 4 : 0; asm volatile( @@ -167,6 +238,7 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr, "l"(gmem_ptr), "n"(4)); } +#endif } template @@ -209,7 +281,6 @@ struct smem_t { template __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {} - template static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu index d2ee6bd732e..b582c862c38 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu @@ -14,20 +14,23 @@ #pragma once #include "mla_cache_kernel.cuh" +#include "helper.h" +#include "remote_cache_kv_ipc.h" template std::vector PrefillMLAWriteCache( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& kv_nope, - const paddle::Tensor& kv_pe, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const int max_seq_len, - cudaStream_t& stream, - paddle::Tensor* kv_cache) { + const 
AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& kv_signal_data, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -49,8 +52,10 @@ std::vector PrefillMLAWriteCache( prefill_absorb_cache_kernel <<>>( - reinterpret_cast(const_cast(kv_nope.data())), - reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast( + const_cast(kv_nope.data())), + reinterpret_cast( + const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), batch_id_per_token.data(), @@ -64,6 +69,33 @@ std::vector PrefillMLAWriteCache( pe_size, block_size, elem_nums); + + const char* fmt_write_cache_completed_signal_str = + std::getenv("FLAGS_fmt_write_cache_completed_signal"); + const char* FLAGS_use_pd_disaggregation_per_chunk = + std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); + + if (fmt_write_cache_completed_signal_str && + (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || + std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { + if (FLAGS_use_pd_disaggregation_per_chunk && + (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || + std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { + cudaLaunchHostFunc( + stream, + &(RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_per_query), + (void*)nullptr); + } else { + if (kv_signal_data) { + cudaLaunchHostFunc( + stream, + &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, + (void*)(const_cast( + kv_signal_data.get().data()))); + } + } + } return {}; } @@ -76,6 +108,7 @@ std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& 
batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::optional& kv_signal_data, const std::string& cache_quant_type_str, const int max_seq_len) { cudaStream_t stream = kv_pe.stream(); @@ -84,40 +117,45 @@ std::vector PrefillMLAWriteCacheKernel( const auto& kv_pe_dims = kv_pe.dims(); const auto& kv_cache_dims = kv_cache.dims(); meta_data.kv_num_heads = kv_cache_dims[1]; - const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; + const auto nope_size = + kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; meta_data.token_nums = kv_nope_dims[0]; meta_data.head_dims = kv_cache_dims[3]; meta_data.head_dims_v = nope_size; meta_data.max_blocks_per_seq = block_tables.dims()[1]; meta_data.block_size = kv_cache_dims[2]; - meta_data.batch_size = cu_seqlens_q.dims()[0]; + meta_data.batch_size = seq_lens_decoder.dims()[0]; switch (kv_pe.dtype()) { case paddle::DataType::BFLOAT16: { - return PrefillMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - stream, - const_cast(&kv_cache)); + return PrefillMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_signal_data, + max_seq_len, + stream, + const_cast(&kv_cache)); } case paddle::DataType::FLOAT16: { - return PrefillMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - stream, - const_cast(&kv_cache)); + return PrefillMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_signal_data, + max_seq_len, + stream, + const_cast(&kv_cache)); } } return {}; @@ -125,18 +163,18 @@ std::vector PrefillMLAWriteCacheKernel( template std::vector DecodeMLAWriteCache( - const AppendAttnMetaData& 
meta_data, - const paddle::Tensor& kv_nope, - const paddle::Tensor& kv_pe, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const int max_seq_len, - const bool speculate_decoder, - cudaStream_t& stream, - paddle::Tensor* kv_cache) { + const AppendAttnMetaData& meta_data, + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const int max_seq_len, + const bool speculate_decoder, + cudaStream_t& stream, + paddle::Tensor* kv_cache) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -153,15 +191,16 @@ std::vector DecodeMLAWriteCache( const int blocksize = 128; int grid_size = 1; - if (speculate_decoder) { const uint32_t elem_nums = token_num * kv_num_heads * all_size; const int pack_num = elem_nums / PackSize; GetNumBlocks<128>(pack_num, &grid_size); speculate_decode_absorb_cache_kernel <<>>( - reinterpret_cast(const_cast(kv_nope.data())), - reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast( + const_cast(kv_nope.data())), + reinterpret_cast( + const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), batch_id_per_token.data(), @@ -181,8 +220,10 @@ std::vector DecodeMLAWriteCache( GetNumBlocks<128>(pack_num, &grid_size); decode_absorb_cache_kernel <<>>( - reinterpret_cast(const_cast(kv_nope.data())), - reinterpret_cast(const_cast(kv_pe.data())), + reinterpret_cast( + const_cast(kv_nope.data())), + reinterpret_cast( + const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), cu_seqlens_q.data(), @@ -217,49 +258,51 @@ std::vector DecodeMLAWriteCacheKernel( const auto& kv_pe_dims = 
kv_pe.dims(); const auto& kv_cache_dims = kv_cache.dims(); meta_data.kv_num_heads = kv_cache_dims[1]; - const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; + const auto nope_size = + kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads; meta_data.token_nums = kv_nope_dims[0]; meta_data.head_dims = kv_cache_dims[3]; meta_data.head_dims_v = nope_size; meta_data.max_blocks_per_seq = block_tables.dims()[1]; meta_data.block_size = kv_cache_dims[2]; - meta_data.batch_size = cu_seqlens_q.dims()[0]; + meta_data.batch_size = seq_lens_encoder.dims()[0]; switch (kv_pe.dtype()) { case paddle::DataType::BFLOAT16: { - return DecodeMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - speculate_decoder, - stream, - const_cast(&kv_cache)); + return DecodeMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + max_seq_len, + speculate_decoder, + stream, + const_cast(&kv_cache)); } case paddle::DataType::FLOAT16: { - return DecodeMLAWriteCache(meta_data, - kv_nope, - kv_pe, - seq_lens, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - speculate_decoder, - stream, - const_cast(&kv_cache)); + return DecodeMLAWriteCache( + meta_data, + kv_nope, + kv_pe, + seq_lens, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + max_seq_len, + speculate_decoder, + stream, + const_cast(&kv_cache)); } } return {}; } - -PD_BUILD_OP(prefill_mla_write_cache) +PD_BUILD_STATIC_OP(prefill_mla_write_cache) .Inputs({"kv_nope", "kv_pe", "kv_cache", @@ -267,14 +310,14 @@ PD_BUILD_OP(prefill_mla_write_cache) "seq_lens_decoder", "batch_id_per_token", "cu_seqlens_q", - "block_tables"}) + "block_tables", + paddle::Optional("kv_signal_data")}) .Outputs({"kv_cache_out"}) .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) - 
.Attrs({"cache_quant_type_str: std::string", - "max_seq_len: int"}) + .Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"}) .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel)); -PD_BUILD_OP(decode_mla_write_cache) +PD_BUILD_STATIC_OP(decode_mla_write_cache) .Inputs({"kv_nope", "kv_pe", "kv_cache", diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh index 2efcb7a8c61..ec5b428bda1 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh @@ -20,10 +20,10 @@ template __global__ void decode_absorb_cache_kernel( const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 - const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 - T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, - // nope_size] - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] @@ -62,26 +62,25 @@ __global__ void decode_absorb_cache_kernel( const int block_idx = block_table_now[write_seq_id / block_size]; const int block_offset = write_seq_id % block_size; - if (bias < nope_hidden_size) { // pe + if (bias < nope_hidden_size) { // pe const uint32_t inner_bias = bias; const uint32_t hi = inner_bias / nope_size; const uint32_t h_bias = inner_bias % nope_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + h_bias; - const uint32_t ori_idx = - start_token_idx * nope_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * 
block_size * all_size + block_offset * all_size + h_bias; + const uint32_t ori_idx = start_token_idx * nope_hidden_size + inner_bias; Load(&kv_nope[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } else { const uint32_t inner_bias = bias - nope_hidden_size; const uint32_t hi = inner_bias / pe_size; const uint32_t h_bias = inner_bias % pe_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + nope_size + h_bias; - const uint32_t ori_idx = - start_token_idx * pe_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + nope_size + + h_bias; + const uint32_t ori_idx = start_token_idx * pe_hidden_size + inner_bias; Load(&kv_pe[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } @@ -91,10 +90,10 @@ __global__ void decode_absorb_cache_kernel( template __global__ void speculate_decode_absorb_cache_kernel( const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 - const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 - T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, - // nope_size] - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -125,8 +124,7 @@ __global__ void speculate_decode_absorb_cache_kernel( if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % hidden_size; const int start_token_idx = cu_seqlens_q[ori_bi]; - const int write_seq_id = - seq_lens[ori_bi] + token_id - start_token_idx; + const int write_seq_id = seq_lens[ori_bi] + token_id - 
start_token_idx; if (write_seq_id == 0) continue; const int* block_table_now = nullptr; @@ -145,26 +143,25 @@ __global__ void speculate_decode_absorb_cache_kernel( token_id, cu_seqlens_q[ori_bi]); } - if (bias < nope_hidden_size) { // pe + if (bias < nope_hidden_size) { // pe const uint32_t inner_bias = bias; const uint32_t hi = inner_bias / nope_size; const uint32_t h_bias = inner_bias % nope_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + h_bias; - const uint32_t ori_idx = - token_id * nope_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + h_bias; + const uint32_t ori_idx = token_id * nope_hidden_size + inner_bias; Load(&kv_nope[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } else { const uint32_t inner_bias = bias - nope_hidden_size; const uint32_t hi = inner_bias / pe_size; const uint32_t h_bias = inner_bias % pe_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + nope_size + h_bias; - const uint32_t ori_idx = - token_id * pe_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + nope_size + + h_bias; + const uint32_t ori_idx = token_id * pe_hidden_size + inner_bias; Load(&kv_pe[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } @@ -174,10 +171,10 @@ __global__ void speculate_decode_absorb_cache_kernel( template __global__ void prefill_absorb_cache_kernel( const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 - const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 - T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, - // nope_size] - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + 
const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 + T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + // nope_size] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -206,33 +203,33 @@ __global__ void prefill_absorb_cache_kernel( const uint32_t bias = linear_index % hidden_size; const uint32_t ori_bi = batch_id_per_token[token_idx]; if (seq_lens[ori_bi] == 0) continue; - const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const uint32_t ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int* block_table_now = nullptr; block_table_now = block_tables + ori_bi * max_blocks_per_seq; const uint32_t block_idx = block_table_now[ori_seq_id / block_size]; const uint32_t block_offset = ori_seq_id % block_size; - if (bias < nope_hidden_size) { // pe + if (bias < nope_hidden_size) { // pe const uint32_t inner_bias = bias; const uint32_t hi = inner_bias / nope_size; const uint32_t h_bias = inner_bias % nope_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + h_bias; - const uint32_t ori_idx = - token_idx * nope_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + h_bias; + const uint32_t ori_idx = token_idx * nope_hidden_size + inner_bias; Load(&kv_nope[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } else { const uint32_t inner_bias = bias - nope_hidden_size; const uint32_t hi = inner_bias / pe_size; const uint32_t h_bias = inner_bias % pe_size; - const uint32_t tgt_idx = block_idx * kv_num_heads * block_size * all_size + - hi * block_size * all_size + - block_offset * all_size + nope_size + h_bias; - const uint32_t 
ori_idx = - token_idx * pe_hidden_size + inner_bias; + const uint32_t tgt_idx = + block_idx * kv_num_heads * block_size * all_size + + hi * block_size * all_size + block_offset * all_size + nope_size + + h_bias; + const uint32_t ori_idx = token_idx * pe_hidden_size + inner_bias; Load(&kv_pe[ori_idx], &src_vec); Store(src_vec, &kv_cache[tgt_idx]); } diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh new file mode 100644 index 00000000000..cf283a617d2 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -0,0 +1,1418 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "helper.h" // For getBoolEnv +#include "multiquery_attention_c16_kernel.h" + +template +__global__ void multi_query_append_attention_kernel( + const T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * + // head_dim] + const T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + const T *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T *__restrict__ sinks, // [q_num_heads] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const int max_seq_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const int sliding_window = 0, + const int sink_size = 0) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t 
num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + static_assert(num_frags_y * 16 == HEAD_DIM); + static_assert(num_frags_z * 16 == BLOCK_SIZE); + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const T *q_base_ptr = q + q_offset; + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; 
+ if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_len, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); + + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * + sizeof(T)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 
0 + : chunk_len) / + (num_frags_z * 16); + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx_base = chunk_start; + int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + const T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + // s = qk + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration || sliding_window > 0) { + mask_s(nullptr, + q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + -1, + s_frag, + mask_offset_this_seq, + sliding_window, + sink_size); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += num_frags_z * 16; + block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; + } + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + 
commit_group(); + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + + __syncthreads(); + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + if constexpr (!partition_kv) { + if (sinks) { + float current_sinks[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); + } + } + normalize_d( + o_frag, d_frag, m_frag, current_sinks); + } else { + normalize_d(o_frag, d_frag); + } + } + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? 
q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +__global__ void multi_query_append_attention_warp1_4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + T *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T *__restrict__ sinks, // [q_num_heads] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const int max_seq_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t 
chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1, + const int sliding_window = 0, + const int sink_size = 0) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const int32_t batch_id = batch_ids[btid]; + if (batch_id == -1) return; + + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const int seq_len_enc = seq_lens_encoder[batch_id]; + if (seq_len_enc > 0) { + return; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + // When partition_kv=false (nosplit), always write to out directly, + // even if num_chunks_this_seq > 1 (tmp_workspace may be nullptr). + if (!partition_kv || num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } + const int *mask_offset_this_seq = + mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_len, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); + + static_assert(num_rows_per_block == num_frags_x * 16); + static_assert(BLOCK_SIZE == NUM_WARP_KV * num_frags_z * 16); + smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + BLOCK_SIZE); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len, chunk_start))) + : mask_offset ? 
0 + : chunk_len) / + (BLOCK_SIZE); + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid % 16, tid / 16); + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx_base = chunk_start; + int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + // s = qk + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration || sliding_window > 0) { + mask_s( + attn_mask ? 
attn_mask + batch_id * attn_mask_len * attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq, + sliding_window, + sink_size); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += BLOCK_SIZE; + block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; + } + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + __syncthreads(); + + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + // nosplit: always normalize (partition_kv=false means no merge step later) + if (!partition_kv || num_chunks_this_seq <= 1) { + if (sinks) { + float current_sinks[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); + } + } + normalize_d( + o_frag, d_frag, m_frag, current_sinks); + } else { + normalize_d(o_frag, d_frag); + } + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + // nosplit: always write directly to out (not tmp_workspace) + if (!partition_kv || num_chunks_this_seq <= 1) { 
+ write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + } + + // nosplit: skip tmp_m/tmp_d write (no merge step, tmp_m/tmp_d may be nullptr) + if (partition_kv && num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +void MultiQueryAppendAttention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const 
paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window, + const int sink_size = 0) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; + constexpr uint32_t smem_size = + (num_warps * num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * + HEAD_DIM * sizeof(T); + auto split_kv_kernel = multi_query_append_attention_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + uint32_t chunk_size = static_cast(encoder_max_partition_size); + const int num_chunks = div_up(max_dec_len, 
chunk_size); + // Deterministic mode: force use nosplit kernel to ensure consistent + // floating-point accumulation order across all sequence lengths + const bool force_no_partition = getEnvDeterministicMode(); + + // Debug log for determinism verification + if (getEnvDeterministicDebug()) { + printf( + "[DET_DEBUG] num_chunks=%d, chunk_size=%u, max_dec_len=%d, " + "force_no_partition=%d\n", + num_chunks, + chunk_size, + max_dec_len, + force_no_partition); + } + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1 || force_no_partition) { + auto nosplit_kv_kernel = + multi_query_append_attention_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + sliding_window, + sink_size); + + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + sliding_window, + sink_size); + // merge + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), + num_heads); // 128k is too large + dim3 blocks_merge(blockx, blocky); + auto *kernelFn = merge_multi_chunks_v2_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; + constexpr uint32_t smem_size = + (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * + sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + uint32_t chunk_size = static_cast(max_partition_size); + + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + const int num_chunks = div_up(max_seq_len, chunk_size); + // Deterministic mode: force nosplit kernel with gridDim.y=1 to ensure + // consistent floating-point accumulation order across all sequence lengths. + // NOTE: the warp1_4 nosplit kernel uses runtime num_chunks_this_seq check + // (not constexpr partition_kv), so we MUST set gridDim.y=1 to avoid + // nullptr write to tmp_workspace when num_chunks_this_seq > 1. + const bool force_no_partition = getEnvDeterministicMode(); + const int grid_chunks = force_no_partition ? 
1 : num_chunks; + dim3 grids(num_blocks_x_cpu, grid_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + // before it's deadcode: num_chunks <= 0 + // now it's only used for determinism + if (force_no_partition) { + auto nosplit_kv_kernel = + multi_query_append_attention_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? 
const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len, + sliding_window, + sink_size); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + shift_bias ? 
reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len, + sliding_window, + sink_size); + + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + auto *kernelFn = merge_multi_chunks_decoder_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + auto *kernelFn = merge_multi_chunks_v2_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h new file mode 100644 index 00000000000..9fe215be66b --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h @@ -0,0 +1,57 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "append_attention_func.cuh" + +template +void MultiQueryAppendAttention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window, + const int sink_size); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh new file mode 100644 index 00000000000..e0fa2b229c8 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -0,0 +1,1632 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "multiquery_attention_c4_kernel.h" + +template +__global__ void multi_query_append_attention_c4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T *__restrict__ sinks, // [q_num_heads] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, 
head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const int sliding_window = 0, + const int sink_size = 0) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / 2 / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / 2 / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? 
chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + T *cache_k_scale_smem = reinterpret_cast( + smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i]; + cache_v_scale_smem[i] = cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i]; + } + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_b_stride = HEAD_DIM / 2; + const uint32_t kv_d_stride = BLOCK_SIZE / 2; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + 
q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); + + T cache_k_scale_frag[num_frags_y][4]; + T cache_k_zp_frag[num_frags_y][4]; + T magic_number; + if constexpr (std::is_same::value) { + magic_number = static_cast(1032.f); + } else { + magic_number = static_cast(136.f); + } +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + + 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4 + 4); +#pragma unroll + for 
(uint32_t zp_i = 0; zp_i < 4; ++zp_i) { + cache_k_zp_frag[fy][zp_i] += magic_number; // 128 + 8 + } + } + T cache_v_scale_frag[num_frags_y][2]; + T cache_v_zp_frag[num_frags_y][2]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; + cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; + cache_v_zp_frag[fy][0] = + cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; + cache_v_zp_frag[fy][1] = + cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; + } + + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 
0 + : chunk_len) / + (num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_b_stride + + tid % 4 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 16 + tid / 2) * kv_d_stride + + tid % 2 * num_elems_per_128b(); + + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + compute_qk_c4( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + s_frag, + cache_k_scale_frag, + cache_k_zp_frag); + + if (iter >= mask_check_iteration || sliding_window > 0) { + mask_s(nullptr, + q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + -1, + s_frag, + mask_offset_this_seq, + sliding_window, + sink_size); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += num_frags_z * 16; + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + 
commit_group(); + wait_group<1>(); + __syncthreads(); + + compute_sfm_v_c4(&v_smem, + &v_smem_offset_r, + s_frag, + o_frag, + d_frag, + cache_v_scale_frag, + cache_v_zp_frag); + __syncthreads(); + + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + if constexpr (!partition_kv) { + if (sinks) { + float current_sinks[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); + } + } + normalize_d( + o_frag, d_frag, m_frag, current_sinks); + } else { + normalize_d(o_frag, d_frag); + } + } + + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? 
q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +__global__ void multi_query_append_attention_c4_warp1_4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_k_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim] + const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T *__restrict__ sinks, // [q_num_heads] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int 
*__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1, + const int sliding_window = 0, + const int sink_size = 0) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / 2 / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / 2 / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const int32_t batch_id = batch_ids[btid]; + if (batch_id == -1) return; + + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { + return; + } + + const 
uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const int seq_len_enc = seq_lens_encoder[batch_id]; + if (seq_len_enc > 0) { + return; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + + const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; + const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_now = cache_v_scale + kv_head_idx * HEAD_DIM; + const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; + T *cache_k_scale_smem = reinterpret_cast( + smem + NUM_WARP_Q * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT) * 2); + T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; + T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; + T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; +#pragma unroll + for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) { + cache_k_scale_smem[i] = cache_k_scale_now[i]; + cache_k_zero_point_smem[i] = cache_k_zp_now[i]; + cache_v_scale_smem[i] = 
cache_v_scale_now[i]; + cache_v_zero_point_smem[i] = cache_v_zp_now[i]; + } + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2; + const uint32_t kv_b_stride = HEAD_DIM / 2; + const uint32_t kv_d_stride = BLOCK_SIZE / 2; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } + const int *mask_offset_this_seq = + mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); + + T cache_k_scale_frag[num_frags_y][4]; + T cache_k_zp_frag[num_frags_y][4]; + T magic_number; + if constexpr (std::is_same::value) { + magic_number = static_cast(1032.f); + } else { + magic_number = static_cast(136.f); + } +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + *(reinterpret_cast(&cache_k_scale_frag[fy][0])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4); + *(reinterpret_cast(&cache_k_scale_frag[fy][2])) = + *(reinterpret_cast(&cache_k_scale_smem[fy * 16]) + tid % 4 + + 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][0])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4); + *(reinterpret_cast(&cache_k_zp_frag[fy][2])) = + *(reinterpret_cast(&cache_k_zero_point_smem[fy * 16]) + + tid % 4 + 4); +#pragma unroll + for (uint32_t zp_i = 0; zp_i < 4; ++zp_i) { + cache_k_zp_frag[fy][zp_i] += magic_number; + } + } + T cache_v_scale_frag[num_frags_y][2]; + T cache_v_zp_frag[num_frags_y][2]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + cache_v_scale_frag[fy][0] = cache_v_scale_smem[fy * 16 + tid / 4]; + cache_v_scale_frag[fy][1] = cache_v_scale_smem[fy * 16 + tid / 4 + 8]; + cache_v_zp_frag[fy][0] = + cache_v_zero_point_smem[fy * 16 + tid / 4] + magic_number; + cache_v_zp_frag[fy][1] = + cache_v_zero_point_smem[fy * 16 + tid / 4 + 8] + magic_number; + } + + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); + + const uint32_t num_iterations = div_up( + CAUSAL 
+ ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len, chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_b_stride + + tid % 4 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 16 + tid / 2) * kv_d_stride + + tid % 2 * num_elems_per_128b(); + + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + compute_qk_c4( + &qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + s_frag, + cache_k_scale_frag, + cache_k_zp_frag); + if (iter >= mask_check_iteration || sliding_window > 0) { + mask_s( + attn_mask ? 
attn_mask + batch_id * attn_mask_len * attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq, + sliding_window, + sink_size); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c4(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + commit_group(); + wait_group<1>(); + __syncthreads(); + + // compute sfm*v + compute_sfm_v_c4(&v_smem, + &v_smem_offset_r, + s_frag, + o_frag, + d_frag, + cache_v_scale_frag, + cache_v_zp_frag); + __syncthreads(); + + produce_v_blockwise_c4(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + if (num_chunks_this_seq <= 1) { + if (sinks) { + float current_sinks[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); + } + } + normalize_d( + o_frag, d_frag, m_frag, current_sinks); + } else { + normalize_d(o_frag, d_frag); + } + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride, + HEAD_DIM); + } else { + 
write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + } + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template +void MultiQueryAppendC4Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &cache_k_zp, + const paddle::optional &cache_v_zp, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int 
num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window = 0, + const int sink_size = 0) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; + constexpr uint32_t smem_size = + num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + + HEAD_DIM * 4 * sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_c4_kernel; + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); + assert(act_blocks_per_sm > 1); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; + 
const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); + + uint32_t chunk_size = static_cast(encoder_max_partition_size); + const int num_chunks = div_up(max_dec_len, chunk_size); + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + sliding_window, + sink_size); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? 
reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + sliding_window, + sink_size); + // merge + constexpr int vec_size = num_elems_per_128b(); + + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4; + constexpr uint32_t smem_size = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM / 2 * sizeof(uint8_t) * 2 + + HEAD_DIM * 4 * sizeof(T); + auto split_kv_kernel = + multi_query_append_attention_c4_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, split_kv_kernel, num_warps * 32, smem_size); + assert(act_blocks_per_sm > 1); + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = num_blocks_x_cpu * kv_num_heads; + const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); + + uint32_t chunk_size = static_cast(max_partition_size); + + const int num_chunks = div_up(max_seq_len, chunk_size); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 0) { + auto nosplit_kv_kernel = + multi_query_append_attention_c4_warp1_4_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, + 
reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len, + sliding_window, + sink_size); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) 
* + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + reinterpret_cast(const_cast(cache_v_scale.data())), + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? 
const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len, + sliding_window, + sink_size); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_multi_chunks_decoder_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? 
reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h new file mode 100644 index 00000000000..5e00b2d3a19 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_kernel.h @@ -0,0 +1,61 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "append_attention_func.cuh" + +template +void MultiQueryAppendC4Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &cache_k_zp, + const paddle::optional &cache_v_zp, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window, + const int sink_size); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh new file mode 100644 index 00000000000..3d6ea9f8fd7 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -0,0 +1,1779 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "multiquery_attention_c8_kernel.h" + +template +__global__ void multi_query_append_attention_c8_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T *__restrict__ sinks, // [q_num_heads] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + 
const int sliding_window = 0, + const int sink_size = 0) { + constexpr uint32_t num_vecs_per_head = + HEAD_DIM / num_elems_per_128b(); // 128 / 8 = 16 + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); // 128 / 16 = 8 + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); // 64 / 16 = 4 + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; + + const uint32_t batch_id = batch_ids[btid]; + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; + const int *block_table_now = nullptr; + + block_table_now = block_table + batch_id * max_block_num_per_seq; + + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + + T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; + T cache_v_scale_reg[IsDynamicC8 ? 
num_frags_z * 4 : num_frags_y * 2]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = + (tile_id * NUM_WARPS + wid) * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if constexpr (partition_kv) { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } else { + o_base_ptr_int8 = out + o_offset; + } + const int *mask_offset_this_seq = + mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale(&qo_smem, + scale); + smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + T *k_smem_scale_ptr = nullptr; + T *v_smem_scale_ptr = nullptr; + smem_t k_scale_smem; + smem_t v_scale_smem; + if constexpr (IsDynamicC8) { + k_smem_scale_ptr = reinterpret_cast( + smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); + } + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 
0 + : chunk_len) / + (num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64 + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg( + k_smem_scale_ptr, cache_k_scale_reg); + } + // s = qk + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration || 
sliding_window > 0) { + mask_s(nullptr, + q_base_seq_id_this_block, + kv_idx_base, + q_len, + kv_len, + chunk_end, + -1, + s_frag, + mask_offset_this_seq, + sliding_window, + sink_size); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + const int ori_kv_idx_base = kv_idx_base; + kv_idx_base += num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + wait_group<1>(); + __syncthreads(); + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg( + v_smem_scale_ptr, cache_v_scale_reg); + } + + // compute sfm*v + compute_sfm_v_c8( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + if constexpr (!partition_kv) { + if (sinks) { + float current_sinks[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); + } + } + normalize_d( + o_frag, d_frag, m_frag, current_sinks); + } else { + normalize_d(o_frag, d_frag); + } + } + + // write o + // 
[num_frags_x, 16, num_frags_y, 16] + if constexpr (partition_kv) { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + partition_kv ? q_n_stride * num_chunks : q_n_stride, + HEAD_DIM); + } + + if constexpr (partition_kv) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } +} + +template +__global__ void multi_query_append_attention_c8_warp1_4_kernel( + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] + const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, + // num_heads, block_size] + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ 
smooth_weight, // [q_num_heads * HEAD_DIM] + const T *__restrict__ sinks, // [q_num_heads] + const int *__restrict__ seq_lens, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ batch_ids, + const int *__restrict__ tile_ids_per_batch, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ mask_offset, + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const uint32_t chunk_size, + const int num_blocks_x_cpu, + T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, + // num_heads, head_dim] + float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] + float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads] + OutT *__restrict__ out, + const int speculate_max_draft_token_num = 5, + const uint32_t attn_mask_len = -1, + const int sliding_window = 0, + const int sink_size = 0) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); + static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); + const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE; + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_idx = blockIdx.y; 
+ + const int32_t batch_id = batch_ids[btid]; + if (batch_id == -1) return; + + const uint32_t tile_id = tile_ids_per_batch[btid]; + const uint32_t num_rows_per_block = num_frags_x * 16; + const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { + return; + } + + const uint32_t q_len = seq_lens[batch_id]; + if (q_len <= 0) { + return; + } + const int seq_len_enc = seq_lens_encoder[batch_id]; + if (seq_len_enc > 0) { + return; + } + T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4]; + T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + const uint32_t q_end = + min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); + uint32_t kv_len = seq_lens_kv[batch_id]; + if (ENABLE_PREFILL) { + kv_len += q_len; + if (kv_len <= 0) { + return; + } + } else { + if (kv_len <= 0) { + 
return; + } + kv_len += q_len; + } + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + if (chunk_idx >= num_chunks_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_idx * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + const uint32_t q_start_seq_id = cu_seqlens_q[batch_id]; + const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const uint32_t o_offset = q_start_seq_id * q_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T *q_base_ptr = q + q_offset; + + T *o_base_ptr_T = nullptr; + OutT *o_base_ptr_int8 = nullptr; + if (num_chunks_this_seq <= 1) { + o_base_ptr_int8 = out + o_offset; + } else { + if (ENABLE_PREFILL) { + o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } else { + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + } + } + const int *mask_offset_this_seq = + mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; + smem_t qo_smem(smem); + + uint32_t q_smem_offset_r = smem_t::get_permuted_offset( + tid % 16, tid / 16); // 16 * 16 + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, scale); + + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + T *k_smem_scale_ptr = nullptr; + T *v_smem_scale_ptr = nullptr; + smem_t k_scale_smem; + smem_t v_scale_smem; + if constexpr (IsDynamicC8) { + k_smem_scale_ptr = reinterpret_cast( + smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); + } + + const uint32_t num_iterations = div_up( + CAUSAL + ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 
0 + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + // load BLOCK_SIZE * HEAD_DIM each time + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg( + k_smem_scale_ptr, cache_k_scale_reg); + } + + // s = qk + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + 
cache_k_scale_reg, + s_frag); + // mask according to kv_idx and q_idx + if (iter >= mask_check_iteration || sliding_window > 0) { + mask_s( + attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq, + sliding_window, + sink_size); + } + + // update m,d + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + const uint32_t ori_kv_idx_base = kv_idx_base; + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + wait_group<1>(); + __syncthreads(); + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg( + v_smem_scale_ptr, cache_v_scale_reg); + } + + // compute sfm * v + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + } + wait_group<0>(); + __syncthreads(); + + merge_block_res_v2( + o_frag, reinterpret_cast(smem), m_frag, d_frag, wid, tid); + + if (num_chunks_this_seq <= 1) { + if (sinks) { + float current_sinks[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for 
(uint32_t j = 0; j < 2; ++j) { + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); + } + } + normalize_d( + o_frag, d_frag, m_frag, current_sinks); + } else { + normalize_d(o_frag, d_frag); + } + } + + // write o + // [num_frags_x, 16, num_frags_y, 16] + if (num_chunks_this_seq <= 1) { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_int8, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride, + HEAD_DIM); + } else { + write_o_reg_gmem_multi_warps_shift_smooth_quant( + o_frag, + &qo_smem, + o_base_ptr_T, + shift_bias, + smooth_weight, + q_base_seq_id_this_block, + q_head_idx, + quant_max_bound, + quant_min_bound, + in_scale, + q_len, + q_n_stride * num_chunks, + HEAD_DIM); + } + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + if (ENABLE_PREFILL) { + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + + qo_head_idx; + } else { + offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + } + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } +} + +template +void MultiQueryAppendC8Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + 
const paddle::Tensor &cache_v_scale, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window, + const int sink_size = 0) { + using NV_TYPE = typename cascade_attn_type_traits::type; + using OUT_NV_TYPE = typename cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t num_warps = 4; + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; + + auto *allocator = paddle::GetAllocator(qkv.place()); + + const float scale = 1.f / sqrt(HEAD_DIM); + bool is_scale_channel_wise = false; + if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { + is_scale_channel_wise = true; + } + + if constexpr (NUM_WARP_Q == 4) { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; + constexpr uint32_t smem_size = + num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) + + num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + num_frags_z * 16 
* sizeof(T) * 2; + auto split_kv_kernel = + multi_query_append_attention_c8_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = multi_query_append_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + uint32_t chunk_size = static_cast(encoder_max_partition_size); + const int num_chunks = div_up(max_dec_len, chunk_size); + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 1) { + auto nosplit_kv_kernel = + multi_query_append_attention_c8_kernel; + if (is_scale_channel_wise) { + nosplit_kv_kernel = + multi_query_append_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + sliding_window, + sink_size); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (ENABLE_PREFILL) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + sliding_window, + sink_size); + // merge + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } else { + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; + constexpr uint32_t smem_size = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; + auto split_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + uint32_t chunk_size = static_cast(max_partition_size); + + const int num_chunks = div_up(max_seq_len, chunk_size); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); + dim3 blocks(32, num_warps); + if (num_chunks <= 0) { + auto nosplit_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + if (is_scale_channel_wise) { + nosplit_kv_kernel = + multi_query_append_attention_c8_warp1_4_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(nosplit_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? 
reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + nullptr, + nullptr, + nullptr, + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len, + sliding_window, + sink_size); + } else { + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + if (is_decoder) { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(bsz * num_chunks * num_heads)); + } else { + if (ENABLE_PREFILL) { + tmp_workspace = + allocator->Allocate(phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + } else { + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + tmp_d = 
allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * + num_chunks * num_heads)); + } + } + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(qkv.data())), + const_cast(cache_k.data()), + const_cast(cache_v.data()), + reinterpret_cast(const_cast(cache_k_scale.data())), + reinterpret_cast(const_cast(cache_v_scale.data())), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + cu_seqlens_q.data(), + block_table.data(), + meta_data.mask_offset, + attn_mask ? const_cast(attn_mask.get().data()) + : nullptr, + max_seq_len, + max_dec_len, + max_block_num_per_seq, + scale, + quant_max_bound, + quant_min_bound, + in_scale, + chunk_size, + num_blocks_x_cpu, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + reinterpret_cast(out->data()), + speculate_max_draft_token_num, + attn_mask_len, + sliding_window, + sink_size); + // merge + constexpr int vec_size = num_elems_per_128b(); + if (is_decoder) { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(bsz, num_heads); + dim3 blocks_merge(blockx, blocky); + auto *kernelFn = merge_multi_chunks_decoder_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? 
reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); + } else { + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); + } + } + } +} diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h new file mode 100644 index 00000000000..ed782765826 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_kernel.h @@ -0,0 +1,61 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "append_attention_func.cuh" + +template +void MultiQueryAppendC8Attention( + const AppendAttnMetaData &meta_data, + const paddle::Tensor &qkv, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::optional &attn_mask, + const paddle::Tensor &cache_k_scale, + const paddle::Tensor &cache_v_scale, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::optional &sinks, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const paddle::Tensor &batch_ids, + const paddle::Tensor &tile_ids_per_batch, + const int num_blocks_x_cpu, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool is_decoder, + cudaStream_t &stream, + paddle::Tensor *out, + const int sliding_window, + const int sink_size); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_impl.cuh new file mode 100644 index 00000000000..2eb6d6bde52 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_impl.cuh @@ -0,0 +1,550 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "decode_attention_func.cuh" +#include "multiquery_decoder_attention_kernel.h" + +#define CHECK(call) \ + do { \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) { \ + printf("CUDA Error:\n"); \ + printf(" File: %s\n", __FILE__); \ + printf(" Line %d:\n", __LINE__); \ + printf(" Error code:%d\n", error_code); \ + printf(" Error text:%s\n", cudaGetErrorString(error_code)); \ + exit(1); \ + } \ + } while (0) + +template +__global__ void merge_varlen_multi_chunks_v2_kernel( + const T *__restrict__ multi_out, // [bsz, num_chunks, num_heads, head_dim] + const T *__restrict__ multi_m, // [bsz, num_chunks, num_heads] + const T *__restrict__ multi_d, // [bsz, num_chunks, num_heads] + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ cu_seqlens_q, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + OutT *__restrict__ out, // [token_num, num_heads, head_dim] + const float in_scale, + const int num_chunks, + const int chunk_size, + const int max_seq_len, + const int num_heads, + const int head_dim) { + const int vid = threadIdx.x, ty = threadIdx.y; + const int qid = blockIdx.x, hid = blockIdx.y; + const int seq_len_q = seq_lens_q[qid]; + if (seq_len_q == 0) return; + int seq_len_kv = seq_lens_kv[qid]; + if (seq_len_kv == 0) return; + 
seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); + if (num_chunks_this_seq == 1 || ty >= num_chunks_this_seq) { + return; + } + __shared__ T smem[bdy * HEAD_DIM]; + __shared__ T md_smem[bdy * 2]; + + const int start_token_ids = cu_seqlens_q[qid]; + using LoadT = AlignedVector; + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2 *)(&res_vec) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + T m; + T d = 1.f; + if constexpr (std::is_same::value) { + m = __float2half(-5e4f); + } else if constexpr (std::is_same::value) { + m = __float2bfloat16(-3.38953e38f); + } + // merge per ty +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset = (qid * num_chunks + i) * num_heads + hid; + T m_prev = m; + T d_prev = d; + const T m_now = multi_m[offset]; + const T d_now = multi_d[offset]; + m = m_prev > m_now ? 
m_prev : m_now; + offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const T scale1 = hexp(m_prev - m), scale2 = hexp(m_now - m); + d = d * scale1 + d_now * scale2; +#pragma once + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1 + load_vec[j] * scale2; + } + } + // store ty res + Store(res_vec, &smem[ty * head_dim + vid * vec_size]); + md_smem[2 * ty] = m; + md_smem[2 * ty + 1] = d; + __syncthreads(); + + // merge bdy + softmax_state_t st{}; + const uint32_t iter_num = min(num_chunks_this_seq, bdy); +#pragma once + for (int i = 0; i < iter_num; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const T m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + st.normalize(); + + AlignedVector out_vec; + +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + out_vec[i] = static_cast(st.o[i]); + } + Store( + out_vec, + &out[(start_token_ids * num_heads + hid) * head_dim + vid * vec_size]); +} + +template +__global__ void multi_query_decode_attention_kernel( + T *__restrict__ q, // [token_num, num_heads, head_dim] + CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, + // head_dim] + CacheT *__restrict__ cache_v, + const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_kv, + const int *__restrict__ cu_seqlens_q, + const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int max_seq_len, + const int max_dec_len, + const int max_block_num_per_seq, + const float scale, + const float in_scale, + const uint32_t chunk_size, + T *__restrict__ tmp_workspace, // [batch_size, num_chunks, num_heads, + // head_dim] + T *__restrict__ tmp_m, // [batch_size, num_chunks, num_heads] + T *__restrict__ tmp_d, // [batch_size, num_chunks, num_heads] + 
OutT *__restrict__ out) { + const uint32_t bidx = blockIdx.x, kv_head_idx = blockIdx.z; + const uint32_t bid = bidx, gid = threadIdx.y; + const uint32_t tidx = threadIdx.x; + constexpr uint32_t num_vec_per_head_qk = HEAD_DIM_QK / VEC_SIZE; + constexpr uint32_t num_vec_per_head_v = HEAD_DIM_V / VEC_SIZE; + constexpr uint32_t num_tile_v = (num_vec_per_head_v + bdx - 1) / bdx; + + const uint32_t q_head_idx = kv_head_idx * GROUP_SIZE + gid; + const uint32_t kv_num_heads = gridDim.z; + const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; + + const int *block_table_now = block_table + bid * max_block_num_per_seq; + + const uint32_t num_chunks = gridDim.y; + const uint32_t chunk_id = blockIdx.y; + const uint32_t q_len = seq_lens_q[bid]; + if (q_len <= 0) { + return; + } + uint32_t kv_len = seq_lens_kv[bid]; // !!!!!!!! + if (kv_len <= 0) { + return; + } + kv_len += q_len; + const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size); + const uint32_t q_start_idx = cu_seqlens_q[bid]; + const uint32_t q_write_idx = cu_seqlens_q[bid]; + if (chunk_id >= num_chunk_this_seq) { + return; + } + + const uint32_t chunk_start = partition_kv ? chunk_id * chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min(kv_len, chunk_start + chunk_size) : kv_len; + const uint32_t chunk_len = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + const T *q_now = q + (q_start_idx * q_num_heads + q_head_idx) * HEAD_DIM_QK; + T *q_smem = reinterpret_cast(smem); // [HEAD_DIM_QK * sizeof(T)] + T *cu_q_smem = q_smem + gid * HEAD_DIM_QK; +#pragma unroll + for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { + ((float4 *)(&cu_q_smem[vid * VEC_SIZE]))[0] = + ((float4 *)(&q_now[vid * VEC_SIZE]))[0]; + } + __syncthreads(); + using VecT = AlignedVector; + VecT q_vec; +#pragma unroll + for (uint32_t vid = tidx; vid < num_vec_per_head_qk; vid += bdx) { + Load(cu_q_smem + vid * VEC_SIZE, &q_vec); + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + q_vec[i] *= scale; + } + Store(q_vec, cu_q_smem + vid * VEC_SIZE); + } + + CacheT *kv_smem = reinterpret_cast(smem + GROUP_SIZE * HEAD_DIM_QK * + sizeof(CacheT)); + uint32_t stage_idx = 0; + constexpr int loop_times = DEAL_EACH_TIME / bdy; +#pragma unroll + for (int i = 0; i < NUM_STAGES; ++i) { +#pragma unroll + for (int j = 0; j < loop_times; ++j) { + const uint32_t k_seq_offset = i * DEAL_EACH_TIME + j * bdy + gid; + const uint32_t k_seq_id = chunk_start + k_seq_offset; + produce_kv(kv_smem, + cache_k, + block_table_now, + k_seq_id, + k_seq_offset, + kv_head_idx, + kv_num_heads, + tidx, + chunk_start, + chunk_end); + } + commit_group(); + stage_idx = (stage_idx + 1) % NUM_STAGES; + } + + softmax_state_ts st; + float s[DEAL_EACH_TIME]; + + const uint32_t num_iters = div_up(chunk_len, DEAL_EACH_TIME); + for (int iter = 0; iter < num_iters; ++iter) { + wait_group(); + __syncthreads(); + // compute qk + compute_qk(cu_q_smem, + kv_smem, + chunk_start + iter * DEAL_EACH_TIME, + stage_idx, + iter * DEAL_EACH_TIME, + chunk_len, + tidx, + gid, + scale, + s, + st); + __syncthreads(); + + // compute sv + compute_sv( + s, kv_smem, stage_idx, iter * DEAL_EACH_TIME, chunk_len, tidx, st); + __syncthreads(); + +#pragma unroll + for (int j = 0; 
j < loop_times; ++j) { + const uint32_t k_seq_offset = j * bdy + gid; + produce_kv( + kv_smem, + cache_k, + block_table_now, + chunk_start + k_seq_offset + (iter + NUM_STAGES) * DEAL_EACH_TIME, + stage_idx * DEAL_EACH_TIME + k_seq_offset, + kv_head_idx, + kv_num_heads, + tidx, + chunk_start, + chunk_end); + } + commit_group(); + stage_idx = (stage_idx + 1) % NUM_STAGES; + } + wait_group<0>(); + __syncthreads(); + + // normize if not partition_kv + for (uint32_t vid = tidx; vid < num_vec_per_head_v; vid += bdx) { + const uint32_t tile_id = vid / bdx; + if (!partition_kv || num_chunk_this_seq == 1) { + st.normalize(tile_id); + } + if (partition_kv && num_chunk_this_seq > 1) { + const uint32_t head_idx = + (bid * num_chunks + chunk_id) * q_num_heads + q_head_idx; + Store( + st.o[tile_id], + tmp_workspace + head_idx * HEAD_DIM_V + vid * VEC_SIZE); + tmp_m[head_idx] = st.m; + tmp_d[head_idx] = st.d; + } else { + Store( + st.o[tile_id], + out + (q_write_idx * q_num_heads + q_head_idx) * HEAD_DIM_V + + vid * VEC_SIZE); + } + } +} + +template +void MultiQueryDecoderAttention( + const AppendAttnMetaData &meta_data, + cudaStream_t &stream, + const paddle::Tensor &q, + const paddle::Tensor + &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const int max_seq_len, + const int max_dec_len, + const float rope_scale, + const float rope_theta, + const float softmax_scale, + const float in_scale, + paddle::Tensor *out) { + using NV_TYPE = typename cascade_attn_type_traits::type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = 
meta_data.token_nums; + auto bsz = meta_data.batch_size; + auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + constexpr int num_stages = NUM_STAGE; + + constexpr int vec_size = 16 / sizeof(T); // 8 16 32 + constexpr int cache_vec_size = 128 / cache_bytes; // 8 16 32 + constexpr int blockxc = HEAD_DIM_QK / cache_vec_size; + constexpr int num_vec_per_head = HEAD_DIM_QK / vec_size; + constexpr int blockx = num_vec_per_head < 32 ? num_vec_per_head : 32; + + constexpr int blocky = GROUP_SIZE; + const int gridx = bsz; + + constexpr int num_threads = blockx * blocky; + + auto splitkv_kernel = multi_query_decode_attention_kernel; + uint32_t cache_smem_bytes = 0; + + const T *shift_bias_ptr = shift_bias ? shift_bias.get().data() : nullptr; + const T *smooth_weight_ptr = + smooth_weight ? smooth_weight.get().data() : nullptr; + cache_smem_bytes = num_stages * DEAL_EACH_TIME * HEAD_DIM_QK * sizeof(T); + + const uint32_t chunk_size = get_max_partition_size(bsz); + const int num_chunks = div_up(max_dec_len, chunk_size); + size_t smem_size = cache_smem_bytes + GROUP_SIZE * HEAD_DIM_QK * sizeof(T); + + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + splitkv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + const int dev_id = 0; + int sm_count; + int act_blocks_per_sm; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, splitkv_kernel, num_threads, smem_size); + assert(act_blocks_per_sm > 1); + + const int num_blocks_per_wave = sm_count * act_blocks_per_sm; + const int num_blocks_need = gridx * num_chunks * kv_num_heads; + const int max_num_chunks = div_up(num_blocks_per_wave, num_blocks_need); + const float ratio = static_cast(num_blocks_need) / + static_cast(num_blocks_per_wave); + + dim3 grids(gridx, num_chunks, kv_num_heads); + dim3 blocks(blockx, blocky); + if (num_chunks <= 1) { + auto no_splitkv_kernel = multi_query_decode_attention_kernel; 
+ if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(no_splitkv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + no_splitkv_kernel<<>>( + reinterpret_cast(const_cast(q.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + seq_lens_q.data(), + seq_lens_kv.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + softmax_scale, + in_scale, + chunk_size, + nullptr, + nullptr, + nullptr, + reinterpret_cast(const_cast(out->data()))); + + // CHECK(cudaGetLastError()); + // CHECK(cudaDeviceSynchronize()); + } else { + auto *allocator = paddle::GetAllocator(q.place()); + phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; + tmp_workspace = allocator->Allocate( + phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads * HEAD_DIM_V)); + tmp_m = + allocator->Allocate(phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads)); + tmp_d = + allocator->Allocate(phi::SizeOf(q.dtype()) * + static_cast(bsz * num_chunks * num_heads)); + + splitkv_kernel<<>>( + reinterpret_cast(const_cast(q.data())), + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + seq_lens_q.data(), + seq_lens_kv.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_dec_len, + max_block_num_per_seq, + softmax_scale, + in_scale, + chunk_size, + reinterpret_cast(tmp_workspace->ptr()), + reinterpret_cast(tmp_m->ptr()), + reinterpret_cast(tmp_d->ptr()), + reinterpret_cast(const_cast(out->data()))); + // CHECK(cudaGetLastError()); + // CHECK(cudaDeviceSynchronize()); + + constexpr int mblockx = HEAD_DIM_V / vec_size; + constexpr int bdy = 256 / mblockx; + dim3 grids_merge(bsz, num_heads); + dim3 
blocks_merge(mblockx, bdy); + merge_varlen_multi_chunks_v2_kernel + <<>>( + reinterpret_cast(tmp_workspace->ptr()), + reinterpret_cast(tmp_m->ptr()), + reinterpret_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + cu_seqlens_q.data(), + reinterpret_cast(const_cast(shift_bias_ptr)), + reinterpret_cast(const_cast(smooth_weight_ptr)), + reinterpret_cast(const_cast(out->data())), + in_scale, + num_chunks, + chunk_size, + max_seq_len, + num_heads, + HEAD_DIM_V); + } + // CHECK(cudaGetLastError()); + // CHECK(cudaDeviceSynchronize()); +} diff --git a/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h new file mode 100644 index 00000000000..457f383e5ea --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/multiquery_decoder_attention_kernel.h @@ -0,0 +1,48 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "decode_attention_func.cuh" + +template +void MultiQueryDecoderAttention( + const AppendAttnMetaData &meta_data, + cudaStream_t &stream, + const paddle::Tensor &q, + const paddle::Tensor + &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim] + const paddle::Tensor &cache_v, // [num_kv_heads, head_dim] + const paddle::optional &attn_mask, + const paddle::optional &shift_bias, + const paddle::optional &smooth_weight, + const paddle::Tensor &seq_lens_q, + const paddle::Tensor &seq_lens_kv, + const paddle::Tensor &batch_id_per_token, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &block_table, + const int max_seq_len, + const int max_dec_len, + const float rope_scale, + const float rope_theta, + const float softmax_scale, + const float in_scale, + paddle::Tensor *out); diff --git a/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu b/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu index 15da09e081c..492b3a26647 100644 --- a/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu +++ b/custom_ops/gpu_ops/append_attn/pre_cache_len_concat.cu @@ -16,25 +16,26 @@ #include "paddle/extension.h" #include "paddle/phi/core/memory/memcpy.h" -__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder, - const int* __restrict__ seq_lens_this_time, - int* __restrict__ cu_seqlens_k, - int* __restrict__ batch_ids, - int* __restrict__ tile_ids_per_batch, - int* __restrict__ num_blocks_x, - int* __restrict__ kv_token_num, - const int bsz, - const int num_row_per_block) { +__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + const int* __restrict__ seq_lens_this_time, + int* __restrict__ cu_seqlens_k, + int* __restrict__ batch_ids, + int* __restrict__ tile_ids_per_batch, + int* __restrict__ num_blocks_x, + int* __restrict__ kv_token_num, + const int bsz, + const int num_row_per_block) { if (threadIdx.x == 0) { int gridx = 0; int index = 
0; int total_tokens = 0; cu_seqlens_k[0] = 0; for (uint32_t bid = 0; bid < bsz; bid++) { - int cache_len = seq_lens_decoder[bid]; - const int q_len = seq_lens_this_time[bid]; - if (q_len <= 0) { - cache_len = 0; + int cache_len = 0; + if (seq_lens_encoder[bid] > 0) { + // only deal with chunked prefill case. + cache_len = seq_lens_decoder[bid]; } const int loop_times = div_up(cache_len, num_row_per_block); for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { @@ -42,6 +43,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder, tile_ids_per_batch[index++] = tile_id; } gridx += loop_times; + const int q_len = seq_lens_this_time[bid]; total_tokens += (cache_len + q_len); cu_seqlens_k[bid + 1] = total_tokens; } @@ -51,6 +53,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder, } std::vector PreCacheLenConcat( + const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const int max_dec_len, @@ -58,45 +61,43 @@ std::vector PreCacheLenConcat( auto stream = seq_lens_decoder.stream(); auto place = seq_lens_decoder.place(); int bsz = seq_lens_this_time.shape()[0]; - const uint32_t max_tile_size_per_bs_pre_cache = div_up(max_dec_len, block_size); + const uint32_t max_tile_size_per_bs_pre_cache = + div_up(max_dec_len, block_size); - paddle::Tensor cu_seqlens_k = GetEmptyTensor( - {bsz + 1}, - paddle::DataType::INT32, - place); + paddle::Tensor cu_seqlens_k = + GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, place); paddle::Tensor pre_cache_batch_ids = GetEmptyTensor( - {bsz * max_tile_size_per_bs_pre_cache}, - paddle::DataType::INT32, - place); + {bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place); paddle::Tensor pre_cache_tile_ids_per_batch = GetEmptyTensor( - {bsz * max_tile_size_per_bs_pre_cache}, - paddle::DataType::INT32, - place); + {bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place); paddle::Tensor 
pre_cache_num_blocks = - GetEmptyTensor({1}, paddle::DataType::INT32, place); + GetEmptyTensor({1}, paddle::DataType::INT32, place); paddle::Tensor kv_token_num = - GetEmptyTensor({1}, paddle::DataType::INT32, place); + GetEmptyTensor({1}, paddle::DataType::INT32, place); pre_cache_len_concat<<<1, 32, 0, stream>>>( - seq_lens_decoder.data(), - seq_lens_this_time.data(), - cu_seqlens_k.data(), - pre_cache_batch_ids.data(), - pre_cache_tile_ids_per_batch.data(), - pre_cache_num_blocks.data(), - kv_token_num.data(), - bsz, - block_size - ); - paddle::Tensor pre_cache_num_blocks_cpu = pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false); - paddle::Tensor kv_token_num_cpu = kv_token_num.copy_to(paddle::CPUPlace(), false); + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + cu_seqlens_k.data(), + pre_cache_batch_ids.data(), + pre_cache_tile_ids_per_batch.data(), + pre_cache_num_blocks.data(), + kv_token_num.data(), + bsz, + block_size); + paddle::Tensor pre_cache_num_blocks_cpu = + pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false); + paddle::Tensor kv_token_num_cpu = + kv_token_num.copy_to(paddle::CPUPlace(), false); - return {cu_seqlens_k, - pre_cache_batch_ids, - pre_cache_tile_ids_per_batch, - pre_cache_num_blocks_cpu, /*cpu*/ - kv_token_num_cpu /*cpu*/ - }; + return { + cu_seqlens_k, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, /*cpu*/ + kv_token_num_cpu /*cpu*/ + }; } std::vector PreCacheLenConcatInferDtype( @@ -121,15 +122,13 @@ std::vector> PreCacheLenConcatInferShape( } PD_BUILD_STATIC_OP(pre_cache_len_concat) - .Inputs({"seq_lens_decoder", - "seq_lens_this_time"}) + .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"}) .Outputs({"cu_seqlens_k", "pre_cache_batch_ids", "pre_cache_tile_ids_per_batch", "pre_cache_num_blocks_cpu", /*cpu*/ - "kv_token_num_cpu"}) /*cpu*/ - .Attrs({"max_dec_len: int", - "block_size: int"}) + "kv_token_num_cpu"}) /*cpu*/ + 
.Attrs({"max_dec_len: int", "block_size: int"}) .SetKernelFn(PD_KERNEL(PreCacheLenConcat)) .SetInferShapeFn(PD_INFER_SHAPE(PreCacheLenConcatInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(PreCacheLenConcatInferDtype)); diff --git a/custom_ops/gpu_ops/append_attn/qwen3_rope.h b/custom_ops/gpu_ops/append_attn/qwen3_rope.h new file mode 100644 index 00000000000..42017e42151 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/qwen3_rope.h @@ -0,0 +1,173 @@ +#include "encoder_write_cache_with_rope_impl.cuh" +#include "helper.h" +#include "paddle/extension.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/core/memory/memcpy.h" +#include "remote_cache_kv_ipc.h" + +template +__global__ void GQAVariableLengthRotarySplitKernel_Qwen3( + const T *qkv, + const float *cos_emb, + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int *cu_seqlens_k, + T *qkv_out, + T *q, + T *k, + T *v, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int max_model_len, + const int head_dim, + const bool rope_3d) { + using LoadT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + const int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int offset = (q_num_head + kv_num_head * 2) * (head_dim / 2); + const int64_t loop_times = elem_cnt / 2; + + for (int64_t linear_index = global_thread_idx * VecSize; + linear_index < loop_times; + linear_index += gridDim.x * blockDim.x * VecSize) { + const int token_idx = linear_index / offset; + + const int ori_bi = batch_id_per_token[token_idx]; // 第几个batch + + int cache_kv_len = seq_lens_decoder[ori_bi]; + // 这里其实是不需要处理的,但是由于FA3的bug,所以必须! 
+ if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0; + + const int bias = linear_index % offset; + const int hi = bias / (head_dim / 2); + const int h_bias = bias % (head_dim / 2); + // we should handle token_idx, hi 头 的 h_bias 部分! + + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + + cache_kv_len; // 在当前seq中的id(拼接了seq到一个batch的情况下有效) + + const int half_headdim = head_dim / 2; + const int64_t emb_idx = ori_seq_id * head_dim + h_bias; // embedding的id + + const int64_t read_idx = + token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim + + h_bias; + + LoadT src_vec0; + LoadT src_vec1; + + Load(&qkv[read_idx], &src_vec0); + Load(&qkv[read_idx + 64], &src_vec1); + + const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id; + int64_t base_split_idx; + T *out_p = nullptr; + if (hi < q_num_head) { + base_split_idx = + token_idx * q_num_head * head_dim + hi * head_dim + h_bias; + out_p = q; + } else if (hi < q_num_head + kv_num_head) { + base_split_idx = kv_write_idx * kv_num_head * head_dim + + (hi - q_num_head) * head_dim + h_bias; + out_p = k; + } else { + out_p = v; + base_split_idx = kv_write_idx * kv_num_head * head_dim + + (hi - q_num_head - kv_num_head) * head_dim + h_bias; + } + + // TODO check this correct or not + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * 2 * max_model_len * head_dim : emb_idx; + + if (hi < q_num_head + kv_num_head) { + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float input_left = static_cast(src_vec0[i]); + float input_right = static_cast(src_vec1[i]); + + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + src_vec0[i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + src_vec1[i] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); + } + } + Store(src_vec0, &qkv_out[read_idx]); + Store(src_vec0, &out_p[base_split_idx]); + Store(src_vec1, &qkv_out[read_idx + 64]); + Store(src_vec1, &out_p[base_split_idx + 64]); + } +} + +template +void gqa_rotary_qk_split_variable_qwen3(T *qkv_out, + T *q, + T *k, + T *v, + const T *qkv_input, + const float *rotary_emb, + const int *batch_id_per_token, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int *cu_seqlens_q, + const int *cu_seqlens_k, + const int token_num, + const int num_heads, + const int kv_num_heads, + const int max_model_len, + const int head_dim, + const bool rope_3d, + const cudaStream_t &stream) { + assert(head_dim == 128 && "head_dim must be 128"); + + int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim; + + constexpr int HEAD_DIM = 128; + constexpr int PackSize = 8; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + dim3 block_size(128); + + const float *cos_emb = rotary_emb; + const float *sin_emb = rotary_emb + max_model_len * head_dim; + launchWithPdlWhenEnabled( + GQAVariableLengthRotarySplitKernel_Qwen3, + grid_size, + block_size, + 0, + stream, + qkv_input, + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_k, + qkv_out, + q, + k, + v, 
+ elem_nums, + num_heads, + kv_num_heads, + max_model_len, + head_dim, + rope_3d); +} diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 936d88e8701..9e63cf4e351 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -18,6 +18,173 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" +template +__global__ void append_speculate_cache_T_rope_qk_norm_kernel( + const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ q_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens_decoder, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* + qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size] + const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int output_inner_dim, + const int head_size, + const int block_size, + const int elem_cnt, + const int gqa_group_size, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps, + const bool rope_3d) { + using LoadT = AlignedVector; + using LoadFloat = AlignedVector; + using LoadInT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadInT src_vec; + LoadFloat scale_vec; + LoadT bias_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + 
LoadFloat tmp_vec; + LoadFloat q_norm_vec; + LoadFloat k_norm_vec; + + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; + int64_t all_head_dim = elem_cnt / head_size; + + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + const int half_head_size = head_size / 2; + for (int global_hi = global_warp_idx; global_hi < all_head_dim; + global_hi += all_warp_num) { + int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize; + const int token_id = linear_index / hidden_size; + + const int ori_bi = batch_id_per_token[token_id]; + if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding + if (seq_lens_encoder[ori_bi] > 0) continue; + const int bias = linear_index % hidden_size; + const int hi = bias / head_size; // q + k + v + const int h_bias = bias % head_size; + const int start_token_idx = cu_seqlens_q[ori_bi]; + const int write_seq_id = + seq_lens_decoder[ori_bi] + token_id - start_token_idx; + if (write_seq_id == 0) continue; + + const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + if (block_idx < 0) { + continue; // NOTE(gongshaotian): For CUDAGraph padding + } + const int block_offset = write_seq_id % block_size; + + const int write_q_idx = + token_id * output_inner_dim * head_size + hi * head_size + h_bias; + + const int bias_idx = hi * head_size + h_bias; + Load(&qkv[linear_index], &src_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + if (qkv_out_scales) { + Load(&qkv_out_scales[bias_idx], &scale_vec); + } + if (hi < num_heads + gqa_group_size) { + // q k rope + const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + } + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + if (qkv_out_scales) { + input_left *= scale_vec[2 * i]; + input_right *= scale_vec[2 * i + 1]; + } + if (qkv_biases) { + input_left = input_left + static_cast(bias_vec[2 * i]); + input_right = input_right + static_cast(bias_vec[2 * i + 1]); + } + if (hi < num_heads + gqa_group_size) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + tmp_vec[2 * i] = tmp1; + tmp_vec[2 * i + 1] = tmp2; + } else { + bias_vec[2 * i] = static_cast(input_left); + bias_vec[2 * i + 1] = static_cast(input_right); + } + } + if (hi < (num_heads + gqa_group_size)) { + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = max(warp_m2 / head_size, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + if (hi < num_heads) { + Load(&q_norm_weight[threadIdx.x * VecSize], + &q_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + } + } else { + Load(&k_norm_weight[threadIdx.x * VecSize], + &k_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + } + } + } + if (hi < num_heads) { + // write q + Store(bias_vec, &q_out[write_q_idx]); + } else { + // write k/v + const int kv_head_idx = (hi - num_heads) % gqa_group_size; + const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size + + 
kv_head_idx * block_size * head_size + + block_offset * head_size + h_bias); + // write + if (hi < num_heads + gqa_group_size) { + Store(bias_vec, &key_cache[tgt_idx]); + } else { + Store(bias_vec, &value_cache[tgt_idx]); + } + } + } +} + template __global__ void append_clear_cache_int8_block( uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, @@ -25,7 +192,7 @@ __global__ void append_clear_cache_int8_block( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] const int* __restrict__ seq_lens, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_encoder, // [bsz] @@ -43,6 +210,7 @@ __global__ void append_clear_cache_int8_block( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -91,7 +259,6 @@ __global__ void append_clear_cache_int8_block( } } - template __global__ void append_clear_cache_int4_block( uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, @@ -99,7 +266,7 @@ __global__ void append_clear_cache_int4_block( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] const int* __restrict__ seq_lens, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_encoder, // [bsz] @@ -117,6 +284,7 @@ __global__ void append_clear_cache_int4_block( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) 
return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -168,7 +336,10 @@ __global__ void append_clear_cache_int4_block( } } -template +template __global__ void append_speculate_cache_rope_kernel( const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, // head_size] @@ -177,10 +348,11 @@ __global__ void append_speculate_cache_rope_kernel( T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, // head_size // 2] T* __restrict__ q_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] const float* __restrict__ cos_emb, const float* __restrict__ sin_emb, const float* @@ -193,7 +365,8 @@ __global__ void append_speculate_cache_rope_kernel( const int head_size, const int block_size, const int elem_cnt, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { using LoadT = AlignedVector; using LoadFloat = AlignedVector; using LoadInT = AlignedVector; @@ -215,7 +388,9 @@ __global__ void append_speculate_cache_rope_kernel( linear_index += step) { const int token_id = linear_index / hidden_size; const int ori_bi = batch_id_per_token[token_id]; - if (seq_lens_decoder[ori_bi] == 0) continue; + if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding + + if (seq_lens_encoder[ori_bi] > 0) continue; const int bias = linear_index % hidden_size; const int hi = bias / head_size; // q + k + v const int h_bias = bias % head_size; @@ -227,15 +402,7 @@ __global__ void append_speculate_cache_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / 
block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + continue; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -253,8 +420,10 @@ __global__ void append_speculate_cache_rope_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < HalfVecSize; i++) { @@ -273,9 +442,11 @@ __global__ void append_speculate_cache_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec[2 * i] = static_cast(input_left); bias_vec[2 * i + 1] = static_cast(input_right); @@ -301,7 +472,10 @@ __global__ void append_speculate_cache_rope_kernel( } } -template +template __global__ void append_speculate_cache_neox_rope_kernel( const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, // head_size] @@ -310,10 +484,11 @@ __global__ void append_speculate_cache_neox_rope_kernel( T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, // head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, 
max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens_decoder, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] const float* __restrict__ cos_emb, const float* __restrict__ sin_emb, const float* @@ -326,7 +501,8 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int head_size, const int block_size, const int elem_cnt, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { using LoadT = AlignedVector; using LoadFloat = AlignedVector; using LoadInT = AlignedVector; @@ -348,7 +524,8 @@ __global__ void append_speculate_cache_neox_rope_kernel( linear_index += step) { const int token_id = linear_index / half_hidden_size; const int ori_bi = batch_id_per_token[token_id]; - if (seq_lens_decoder[ori_bi] == 0) continue; + if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding + if (seq_lens_encoder[ori_bi] > 0) continue; const int bias = linear_index % half_hidden_size; const int hi = bias / half_head_size; // q + k + v const int h_bias = bias % half_head_size; @@ -360,15 +537,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + continue; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -390,8 +559,10 @@ __global__ void append_speculate_cache_neox_rope_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * head_size + h_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -410,9 +581,11 @@ __global__ void append_speculate_cache_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { left_bias_vec[i] = static_cast(input_left); right_bias_vec[i] = static_cast(input_right); @@ -443,12 +616,464 @@ __global__ void append_speculate_cache_neox_rope_kernel( } } +template +__global__ void append_speculate_cache_neox_partial_rope_kernel( + const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens_decoder, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* + qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size] + const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int output_inner_dim, + const int head_size, + const int rotary_dim, + const int block_size, + const int elem_cnt, + const int gqa_group_size, + const bool 
rope_3d) { + using LoadT = AlignedVector; + using LoadFloat = AlignedVector; + using LoadInT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadInT left_vec, right_vec; + LoadT left_bias_vec, right_bias_vec; + LoadFloat left_out_scale_vec, right_out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + const int half_head_size = head_size / 2; + const int half_rotary_dim = rotary_dim / 2; + const int64_t half_hidden_size = hidden_size / 2; + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_id = linear_index / half_hidden_size; + const int ori_bi = batch_id_per_token[token_id]; + if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding + if (seq_lens_encoder[ori_bi] > 0) continue; + const int bias = linear_index % half_hidden_size; + const int hi = bias / half_head_size; // q + k + v + const int h_bias = bias % half_head_size; + if (hi < num_heads && h_bias >= half_rotary_dim) { + continue; + } + const int start_token_idx = cu_seqlens_q[ori_bi]; + const int write_seq_id = + seq_lens_decoder[ori_bi] + token_id - start_token_idx; + if (write_seq_id == 0) continue; + + const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + if (block_idx < 0) { + continue; // NOTE(gongshaotian): For CUDAGraph padding + } + const int block_offset = write_seq_id % block_size; + + const int bias_idx_left = hi * head_size + h_bias; + const int bias_idx_right = bias_idx_left + half_head_size; + int ori_idx_left = token_id * hidden_size + hi * head_size + h_bias; + int ori_idx_right = ori_idx_left + half_head_size; + if (hi < num_heads) { + ori_idx_right = ori_idx_left + 
half_rotary_dim; + } else if (hi < num_heads + gqa_group_size) { + if (h_bias < half_rotary_dim) { + ori_idx_right = ori_idx_left + half_rotary_dim; + } else { + ori_idx_left = ori_idx_left + half_rotary_dim; + ori_idx_right = ori_idx_left + half_rotary_dim; + } + } + Load(&qkv[ori_idx_left], &left_vec); + Load(&qkv[ori_idx_right], &right_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx_left], &left_bias_vec); + Load(&qkv_biases[bias_idx_right], &right_bias_vec); + } + if (qkv_out_scales) { + Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); + Load(&qkv_out_scales[bias_idx_right], + &right_out_scale_vec); + } + if (hi < num_heads + gqa_group_size) { + // q k rope + const int64_t emb_idx = write_seq_id * half_rotary_dim + h_bias; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx; + if (h_bias < half_rotary_dim) { + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + } + } +#pragma unroll + for (int i = 0; i < VecSize; i++) { + // add_bias + rope + float input_left = static_cast(left_vec[i]); + float input_right = static_cast(right_vec[i]); + if (qkv_out_scales) { + input_left *= left_out_scale_vec[i]; + input_right *= right_out_scale_vec[i]; + } + if (qkv_biases) { + input_left = input_left + static_cast(left_bias_vec[i]); + input_right = input_right + static_cast(right_bias_vec[i]); + } + if (hi < num_heads + gqa_group_size && h_bias < half_rotary_dim) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + left_bias_vec[i] = + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); + right_bias_vec[i] = + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); + } else { + left_bias_vec[i] = static_cast(input_left); + right_bias_vec[i] = static_cast(input_right); + } + } + if (hi < num_heads) { + // write q + Store(left_bias_vec, &qkv_out[ori_idx_left]); + Store(right_bias_vec, 
&qkv_out[ori_idx_right]); + } else { + // write k/v + const int kv_head_idx = (hi - num_heads) % gqa_group_size; + int tgt_idx_left = (block_idx * gqa_group_size * block_size * head_size + + kv_head_idx * block_size * head_size + + block_offset * head_size + h_bias); + uint32_t tgt_idx_right = tgt_idx_left + half_head_size; + // write + if (hi < num_heads + gqa_group_size) { + if (h_bias < half_rotary_dim) { + tgt_idx_right = tgt_idx_left + half_rotary_dim; + } else { + tgt_idx_left = tgt_idx_left + half_rotary_dim; + tgt_idx_right = tgt_idx_left + half_rotary_dim; + } + Store(left_bias_vec, &key_cache[tgt_idx_left]); + Store(right_bias_vec, &key_cache[tgt_idx_right]); + } else { + Store(left_bias_vec, &value_cache[tgt_idx_left]); + Store(right_bias_vec, &value_cache[tgt_idx_right]); + } + } + } +} + +template +__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel( + const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 * + // gqa_group_size, head_size] + uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, + // block_size, head_size // 2] + T* __restrict__ qkv_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens, // [bsz] + const int* __restrict__ seq_lens_encoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + T* __restrict__ cache_k_scale, + T* __restrict__ cache_v_scale, + const float* q_norm_weight, + const float* k_norm_weight, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int block_size, + const float max_bound, + const float min_bound, + const int gqa_group_size, + const bool rope_3d, + const float rms_norm_eps) { + static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); + 
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); + constexpr int NUM_WARPS = 4; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane_id = tid % 32; + const int token_id = blockIdx.x; + + const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding + + const int start_token_idx = cu_seqlens_q[bid]; + const int head_idx = blockIdx.y * NUM_WARPS + wid; + int q_head_idx, k_head_idx, v_idx; + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim; + constexpr int half_head_size = HeadDim / 2; + if (seq_lens_encoder[bid] > 0) return; + const int write_seq_id = seq_lens[bid] + token_id - start_token_idx; + if (write_seq_id == 0) return; + const int* block_table_now = block_tables + bid * max_blocks_per_seq; + const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); + const int block_offset = write_seq_id % block_size; + + int cache_offset; + if (head_idx < num_heads) { + cache_offset = 0; + } else if (head_idx < num_heads + 2 * gqa_group_size) { + cache_offset = block_idx * gqa_group_size * block_size + + (head_idx - num_heads) % gqa_group_size * block_size + + block_offset; + } + + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + + if (head_idx < num_heads) { + // q + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + + LoadT src_vec; + LoadBiasT bias_vec; + LoadOutScaleT out_scale_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + const T* qkv_now = quant_qkv + token_id * hidden_size; + T* qkv_out_now = qkv_out + token_id * hidden_size; +#pragma unroll + for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim; + head_bias += 32 * VecSize) { + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec); + + // q rope + const uint32_t emb_idx = write_seq_id * 
half_head_size + head_bias / 2; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // dequant + add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + bias_vec[2 * i] = static_cast(tmp1); + bias_vec[2 * i + 1] = static_cast(tmp2); + } + // qk norm + if (q_norm_weight) { + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + LoadOutScaleT q_norm_vec; + Load(&q_norm_weight[lane_id * VecSize], &q_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = static_cast(static_cast(bias_vec[i]) * + row_inv_var * q_norm_vec[i]); + } + } + Store(bias_vec, &qkv_out_now[bias_idx]); + } + } else if (head_idx < num_heads + 2 * gqa_group_size) { + // k + constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16 + using LoadPadKVT = AlignedVector; + const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + + constexpr int K_VEC_SIZE = 4; + constexpr int HALF_K_VEC_SIZE = 2; + using LoadKVResT = AlignedVector; + using LoadKVT = AlignedVector; + using LoadT = AlignedVector; + using LoadBiasT = AlignedVector; + using LoadOutScaleT = AlignedVector; + using LoadEmbT = AlignedVector; + LoadKVResT cache_vec; + LoadT src_vec1, src_vec2; + LoadBiasT bias_vec1, bias_vec2; + LoadOutScaleT out_scale_vec1, out_scale_vec2; + LoadEmbT cos_emb_vec1, cos_emb_vec2; + LoadEmbT sin_emb_vec1, sin_emb_vec2; + + const T* qkv_now = 
quant_qkv + token_id * hidden_size; + const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; + const int bias_idx = head_idx * HeadDim + head_bias; + Load(&qkv_now[bias_idx], &src_vec1); + Load(&qkv_now[bias_idx + 8], &src_vec2); + T scale = T(1.0f); + const int k_head_idx = head_idx - num_heads; + const int v_head_idx = head_idx - num_heads - gqa_group_size; + if (head_idx < num_heads + gqa_group_size) { + const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); + } + + float input_left = static_cast(src_vec1[0]); + float input_right = static_cast(src_vec1[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec1[0]; + float sin_tmp = sin_emb_vec1[0]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + bias_vec1[0] = static_cast(tmp1); + bias_vec1[1] = static_cast(tmp2); + } else { + bias_vec1[0] = static_cast(input_left); + bias_vec1[1] = static_cast(input_right); + } + + input_left = static_cast(src_vec2[0]); + input_right = static_cast(src_vec2[1]); + if (head_idx < num_heads + gqa_group_size) { + float cos_tmp = cos_emb_vec2[0]; + float sin_tmp = sin_emb_vec2[0]; + float tmp1 = fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp); + float tmp2 = fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp); + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + bias_vec2[0] = static_cast(tmp1); + bias_vec2[1] = static_cast(tmp2); + } else { + bias_vec2[0] = static_cast(input_left); + bias_vec2[1] = static_cast(input_right); + } + if (k_norm_weight) { + if (head_idx < num_heads + 
gqa_group_size) { + LoadOutScaleT k_norm_vec1, k_norm_vec2; + Load(&k_norm_weight[head_bias], &k_norm_vec1); + Load(&k_norm_weight[head_bias + 8], + &k_norm_vec2); + // qk norm + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = max(warp_m2 / HeadDim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + bias_vec1[i] = static_cast(static_cast(bias_vec1[i]) * + row_inv_var * k_norm_vec1[i]); + bias_vec2[i] = static_cast(static_cast(bias_vec2[i]) * + row_inv_var * k_norm_vec2[i]); + } + } + } + // reduce max, 1 head per warp + if constexpr (IsDynamic) { + T local_max = -INFINITY; +#pragma unroll + for (int i = 0; i < HALF_K_VEC_SIZE; i++) { + local_max = __hmax(local_max, __habs(bias_vec1[i])); + local_max = __hmax(local_max, __habs(bias_vec2[i])); + } +#pragma unroll + for (int m_offset = 16; m_offset > 0; m_offset /= 2) { + local_max = + __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset)); + } + + scale = __hdiv(448, local_max); + T* cache_k_scale_now = cache_k_scale + cache_offset; + T* cache_v_scale_now = cache_v_scale + cache_offset; + if (lane_id == 0) { + if (head_idx < num_heads + gqa_group_size) { + cache_k_scale_now[0] = __hdiv(1, scale); + } else { + cache_v_scale_now[0] = __hdiv(1, scale); + } + } + } else { + if (head_idx < num_heads + gqa_group_size) { + scale = __ldg(&cache_k_scale[kv_head_idx]); + } else { + scale = __ldg(&cache_v_scale[kv_head_idx]); + } + } + +#pragma unroll + for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { + cache_vec[i] = QuantToC8( + scale, bias_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8( + scale, bias_vec2[i], max_bound, min_bound); + } + if (head_idx < num_heads + gqa_group_size) { + const int start_block_16 = + block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8; + const uint32_t tgt_cache_idx = + block_idx * gqa_group_size * block_size * HeadDim + + kv_head_idx * block_size * HeadDim + 
start_block_16 * HeadDim + + lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4; + Store(cache_vec, &key_cache[tgt_cache_idx]); + } else { + const uint32_t base_tgt_cache_idx = + block_idx * gqa_group_size * HeadDim * block_size + + kv_head_idx * HeadDim * block_size + + (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size + + block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32; + const uint32_t tgt_cache_idx1 = base_tgt_cache_idx + + block_offset % 8 / 2 * 4 // per 4 + + block_offset % 16 / 8 * 2 // per 2 + + block_offset % 2; // per 1 + const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size; + const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16; + const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size; + value_cache[tgt_cache_idx1] = cache_vec[0]; + value_cache[tgt_cache_idx2] = cache_vec[1]; + value_cache[tgt_cache_idx3] = cache_vec[2]; + value_cache[tgt_cache_idx4] = cache_vec[3]; + } + } +} + template + bool IsFP8 = false, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_int8_rope_kernel( const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 * // gqa_group_size, head_size] @@ -457,7 +1082,7 @@ __global__ void append_speculate_cache_int8_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -476,7 +1101,8 @@ __global__ void append_speculate_cache_int8_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 
now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -486,6 +1112,7 @@ __global__ void append_speculate_cache_int8_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -522,8 +1149,10 @@ __global__ void append_speculate_cache_int8_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); if (qkv_out_scales) { Load(&qkv_out_scales[bias_idx], &out_scale_vec); } @@ -548,9 +1177,11 @@ __global__ void append_speculate_cache_int8_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -583,10 +1214,12 @@ __global__ void append_speculate_cache_int8_rope_kernel( T scale; if (head_idx < num_heads + gqa_group_size) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); scale = __ldg(&cache_k_scales[kv_head_idx]); } else { scale = __ldg(&cache_v_scales[kv_head_idx]); @@ -612,9 +1245,11 @@ __global__ void append_speculate_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -635,17 +1270,21 @@ __global__ void append_speculate_cache_int8_rope_kernel( float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); } #pragma unroll for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) { - cache_vec[i] = QuantToC8(scale, bias_vec1[i], max_bound, min_bound); - cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8(scale, bias_vec2[i], max_bound, min_bound); + cache_vec[i] = QuantToC8( + scale, bias_vec1[i], max_bound, min_bound); + cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8( + scale, bias_vec2[i], max_bound, min_bound); } if (head_idx < num_heads + gqa_group_size) { const int start_block_16 = @@ -680,7 +1319,8 @@ template + typename InT = int, + bool EnforceFmulRN = false> 
__global__ void append_speculate_cache_int8_neox_rope_kernel( const InT* __restrict__ quant_qkv, // [num_head, num_heads + 2 * // gqa_group_size, head_size] @@ -689,7 +1329,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -708,7 +1348,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -718,6 +1359,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -757,8 +1399,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); if (qkv_out_scales) { Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); @@ -787,9 +1431,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -853,10 +1499,12 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( T scale; const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); scale = __ldg(&cache_k_scales[kv_head_idx]); #pragma unroll for (int i = 0; i < HALF_K_VEC_SIZE; i++) { @@ -874,9 +1522,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( float cos_tmp = cos_emb_vec1[i]; float sin_tmp = sin_emb_vec1[i]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); @@ -889,9 +1539,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); float quant_value1 = static_cast(scale * left_bias_vec1[i]); float quant_value2 = static_cast(scale * left_bias_vec2[i]); @@ -1058,7 +1710,8 @@ template + typename InT = int, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_int4_rope_kernel( const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, // head_size] @@ -1067,7 +1720,7 @@ __global__ void append_speculate_cache_int4_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, 
max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -1088,7 +1741,8 @@ __global__ void append_speculate_cache_int4_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -1099,6 +1753,7 @@ __global__ void append_speculate_cache_int4_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -1130,6 +1785,10 @@ __global__ void append_speculate_cache_int4_rope_kernel( LoadOutScaleT out_scale_vec; LoadEmbT cos_emb_vec; LoadEmbT sin_emb_vec; +#pragma unroll + for (int v_i = 0; v_i < VecSize; v_i++) { + bias_vec[v_i] = 0; + } const InT* qkv_now = quant_qkv + token_id * hidden_size; T* qkv_out_now = qkv_out + token_id * hidden_size; #pragma unroll @@ -1137,27 +1796,31 @@ __global__ void append_speculate_cache_int4_rope_kernel( head_bias += 32 * VecSize) { const int bias_idx = head_idx * HeadDim + head_bias; Load(&qkv_now[bias_idx], &src_vec); - Load(&qkv_biases[bias_idx], &bias_vec); - Load(&qkv_out_scales[bias_idx], &out_scale_vec); + // Load(&qkv_biases[bias_idx], &bias_vec); + // Load(&qkv_out_scales[bias_idx], &out_scale_vec); // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { // dequant + add_bias + rope float input_left = static_cast(src_vec[2 * i]); float input_right = static_cast(src_vec[2 * i + 1]); - input_left = input_left * out_scale_vec[2 * i] + - static_cast(bias_vec[2 * i]); - input_right = input_right * out_scale_vec[2 * i + 1] + - static_cast(bias_vec[2 * i + 1]); + // input_left = input_left * out_scale_vec[2 * i] + + // static_cast(bias_vec[2 * i]); + // input_right = input_right * out_scale_vec[2 * i + 1] + + // static_cast(bias_vec[2 * i + 1]); const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; bias_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(bias_vec, &qkv_out_now[bias_idx]); } @@ -1167,6 +1830,35 @@ __global__ void append_speculate_cache_int4_rope_kernel( using LoadPadKVT = AlignedVector; const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size; + if (block_offset == 0) { + // pad zero for this kv_head_idx for this block + LoadPadKVT pad_cache_vec; + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + if (head_idx < num_heads + gqa_group_size) { + constexpr int num_vecs_per_head_dim = half_head_size / KV_VEC_SIZE; + constexpr int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) * + block_size * half_head_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; + block_i < block_size; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &key_cache[tgt_idx + block_i * 
half_head_size]); + } + } else { + const int num_vecs_per_head_dim = half_block_size / KV_VEC_SIZE; + const int num_token_each_time = 32 / num_vecs_per_head_dim; + const uint32_t tgt_idx = (block_idx * gqa_group_size + kv_head_idx) * + HeadDim * half_block_size + + lane_id % num_vecs_per_head_dim * KV_VEC_SIZE; + for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim; + block_i += num_token_each_time) { + Store( + pad_cache_vec, &value_cache[tgt_idx + block_i * half_block_size]); + } + } + } constexpr int K_VEC_SIZE = 4; constexpr int HALF_K_VEC_SIZE = 2; using LoadKVResT = AlignedVector; @@ -1182,7 +1874,11 @@ __global__ void append_speculate_cache_int4_rope_kernel( LoadScaleT zp_vec1, zp_vec2; LoadEmbT cos_emb_vec1, cos_emb_vec2; LoadEmbT sin_emb_vec1, sin_emb_vec2; - +#pragma unroll + for (int v_i = 0; v_i < HALF_K_VEC_SIZE; v_i++) { + bias_vec1[v_i] = 0; + bias_vec2[v_i] = 0; + } const InT* qkv_now = quant_qkv + token_id * hidden_size; const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2; ////////// @@ -1191,17 +1887,19 @@ __global__ void append_speculate_cache_int4_rope_kernel( Load(&qkv_now[bias_idx], &src_vec1); Load(&qkv_now[bias_idx + 8], &src_vec2); ///// - Load(&qkv_biases[bias_idx], &bias_vec1); - Load(&qkv_biases[bias_idx + 8], &bias_vec2); - Load(&qkv_out_scales[bias_idx], &out_scale_vec1); - Load(&qkv_out_scales[bias_idx + 8], - &out_scale_vec2); + // Load(&qkv_biases[bias_idx], &bias_vec1); + // Load(&qkv_biases[bias_idx + 8], &bias_vec2); + // Load(&qkv_out_scales[bias_idx], &out_scale_vec1); + // Load(&qkv_out_scales[bias_idx + 8], + // &out_scale_vec2); if (head_idx < num_heads + gqa_group_size) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); Load(&cache_k_scales[cache_idx], &scale_vec1); Load(&cache_k_scales[cache_idx + 8], &scale_vec2); Load(&cache_k_zero_points[cache_idx], &zp_vec1); @@ -1215,17 +1913,19 @@ __global__ void append_speculate_cache_int4_rope_kernel( float input_left = static_cast(src_vec1[0]); float input_right = static_cast(src_vec1[1]); - input_left = - input_left * out_scale_vec1[0] + static_cast(bias_vec1[0]); - input_right = - input_right * out_scale_vec1[1] + static_cast(bias_vec1[1]); + // input_left = + // input_left * out_scale_vec1[0] + static_cast(bias_vec1[0]); + // input_right = + // input_right * out_scale_vec1[1] + static_cast(bias_vec1[1]); if (head_idx < num_heads + gqa_group_size) { float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; bias_vec1[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec1[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec1[0] = static_cast(input_left); bias_vec1[1] = static_cast(input_right); @@ -1233,17 +1933,19 @@ __global__ void append_speculate_cache_int4_rope_kernel( input_left = static_cast(src_vec2[0]); input_right = static_cast(src_vec2[1]); - input_left = - input_left * out_scale_vec2[0] + static_cast(bias_vec2[0]); - input_right = - input_right * out_scale_vec2[1] + static_cast(bias_vec2[1]); + // input_left = + // input_left * out_scale_vec2[0] + static_cast(bias_vec2[0]); + // input_right = + // input_right * out_scale_vec2[1] + static_cast(bias_vec2[1]); if (head_idx < num_heads + gqa_group_size) { float cos_tmp = cos_emb_vec2[0]; float sin_tmp = sin_emb_vec2[0]; 
bias_vec2[0] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); bias_vec2[1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } else { bias_vec2[0] = static_cast(input_left); bias_vec2[1] = static_cast(input_right); @@ -1304,7 +2006,6 @@ __global__ void append_speculate_cache_int4_rope_kernel( } Store(cache_vec, &key_cache[tgt_cache_idx]); } else { - const uint32_t base_tgt_cache_idx = block_idx * gqa_group_size * HeadDim * half_block_size + kv_head_idx * HeadDim * half_block_size + @@ -1364,7 +2065,8 @@ template + typename InT = int, + bool EnforceFmulRN = false> __global__ void append_speculate_cache_int4_neox_rope_kernel( const InT* __restrict__ quant_qkv, // [bsz, num_heads + 2 * gqa_group_size, // head_size] @@ -1373,7 +2075,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size, // block_size, head_size // 2] T* __restrict__ qkv_out, - const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] const int* __restrict__ batch_id_per_token, // [num_tokens] const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] @@ -1394,7 +2096,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -1405,6 +2108,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( const int token_id = blockIdx.x; const int bid = batch_id_per_token[token_id]; + if (bid == -1) return; // 
NOTE(gongshaotian): For CUDAGraph padding const int start_token_idx = cu_seqlens_q[bid]; const int head_idx = blockIdx.y * NUM_WARPS + wid; @@ -1478,9 +2182,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; left_bias_vec[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); } Store(left_bias_vec, &qkv_out_now[bias_idx_left]); Store(right_bias_vec, &qkv_out_now[bias_idx_right]); @@ -1544,10 +2250,12 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( &right_out_scale_vec2); const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = + rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); Load(&cache_k_scales[left_cache_idx], &left_scale_vec1); Load(&cache_k_scales[left_cache_idx + 8], @@ -1577,19 +2285,22 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( float cos_tmp = cos_emb_vec1[0]; float sin_tmp = sin_emb_vec1[0]; left_bias_vec1[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec1[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); - + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); input_left = static_cast(left_src_vec2[i]); input_right = static_cast(right_src_vec2[i]); cos_tmp = cos_emb_vec2[i]; sin_tmp = sin_emb_vec2[i]; left_bias_vec2[i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + static_cast(fmul_func(input_left, cos_tmp) - + fmul_func(input_right, sin_tmp)); right_bias_vec2[i] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + static_cast(fmul_func(input_right, cos_tmp) + + fmul_func(input_left, sin_tmp)); // quant + write k } LoadKVResT left_cache_vec, right_cache_vec; diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index fb6a24fefab..e87289a74ec 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -15,8 +15,83 @@ #include "speculate_write_cache_with_rope_kernel.h" #include "utils.cuh" +template +void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, + T* key_cache, + T* value_cache, + T* qkv_out, + const int* block_tables, + const int* batch_id_per_token, + const 
int* cu_seqlens_q, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int dim_head, + const int block_size, + const int bsz, + const int token_num, + const cudaStream_t& stream, + const bool use_neox_style, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps, + const bool rope_3d) { + int output_inner_dim = num_heads + 2 * kv_num_heads; + const uint32_t elem_nums = + use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2 + : token_num * (num_heads + 2 * kv_num_heads) * dim_head; + constexpr int HEAD_DIM = 128; + + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + if (use_neox_style) { + PD_THROW("append_speculate_cache_rope_qk_norm not support neox rope yet"); + } else { + dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1); + append_speculate_cache_T_rope_qk_norm_kernel + <<>>(qkv, + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + output_inner_dim, + dim_head, + block_size, + elem_nums, + kv_num_heads, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + rope_3d); + } +} + // rope + write -template +template void append_speculate_cache_rope(const QKV_TYPE* qkv, T* key_cache, T* value_cache, @@ -35,11 +110,13 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, const int num_heads, const int kv_num_heads, const int dim_head, + const int rotary_dim, const int block_size, const int bsz, const int token_num, const cudaStream_t& stream, - const bool use_neox_style) { + const bool 
use_neox_style, + const bool rope_3d) { int output_inner_dim = num_heads + 2 * kv_num_heads; const uint32_t elem_nums = @@ -52,30 +129,66 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, int grid_size = 1; GetNumBlocks(pack_num, &grid_size); if (use_neox_style) { - append_speculate_cache_neox_rope_kernel - <<>>( - qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] - key_cache, - value_cache, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] - max_seq_len, - max_blocks_per_seq, - num_heads, - output_inner_dim, - dim_head, - block_size, - elem_nums, - kv_num_heads); + if (rotary_dim < dim_head) { + append_speculate_cache_neox_partial_rope_kernel + <<>>( + qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + max_seq_len, + max_blocks_per_seq, + num_heads, + output_inner_dim, + dim_head, + rotary_dim, + block_size, + elem_nums, + kv_num_heads, + rope_3d); + } else { + append_speculate_cache_neox_rope_kernel + <<>>( + qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + max_seq_len, + max_blocks_per_seq, + num_heads, + output_inner_dim, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); + } } else { - append_speculate_cache_rope_kernel + append_speculate_cache_rope_kernel <<>>( qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] key_cache, @@ -85,6 +198,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, batch_id_per_token, 
cu_seqlens_q, seq_lens, + seq_lens_encoder, cos_emb, sin_emb, qkv_out_scales, @@ -96,11 +210,93 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, dim_head, block_size, elem_nums, - kv_num_heads); + kv_num_heads, + rope_3d); } } -template +template +void append_speculate_cache_fp8_rope(const T* qkv, + uint8_t* key_cache, + uint8_t* value_cache, + T* qkv_out, + const int* block_tables, + const int* batch_id_per_token, + const int* cu_seqlens_q, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + T* cache_k_scale, + T* cache_v_scale, + const float* q_norm_weight, + const float* k_norm_weight, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int dim_head, + const int block_size, + const int bsz, + const int token_num, + const cudaStream_t& stream, + const bool rope_3d, + const float rms_norm_eps) { + constexpr int num_warps = 4; + const int all_warps = + ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; + dim3 grids(token_num, all_warps / num_warps); + + append_clear_cache_int8_block<4, 128> + <<>>(key_cache, + value_cache, + seq_lens, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + kv_num_heads); + append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel + <<>>(qkv, + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + q_norm_weight, + k_norm_weight, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); +} + +template void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, @@ -125,13 +321,14 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, const int bsz, const int token_num, 
const cudaStream_t& stream, - const bool use_neox_style) { + const bool use_neox_style, + const bool rope_3d) { constexpr int num_warps = 4; const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(token_num, all_warps / num_warps); - append_clear_cache_int8_block<4> + append_clear_cache_int8_block<4, 128> <<>>(key_cache, value_cache, seq_lens, @@ -145,7 +342,12 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, block_size, kv_num_heads); if (use_neox_style) { - append_speculate_cache_int8_neox_rope_kernel + append_speculate_cache_int8_neox_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -167,9 +369,16 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, block_size, 127.0f, -127.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } else { - append_speculate_cache_int8_rope_kernel + append_speculate_cache_int8_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -191,11 +400,12 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, block_size, 127.0f, -127.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } } -template +template void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, uint8_t* key_cache, uint8_t* value_cache, @@ -222,13 +432,14 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, const int bsz, const int token_num, const cudaStream_t& stream, - const bool use_neox_style) { + const bool use_neox_style, + const bool rope_3d) { constexpr int num_warps = 4; const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(token_num, all_warps / num_warps); - append_clear_cache_int4_block<4> + append_clear_cache_int4_block<4, 128> <<>>(key_cache, value_cache, seq_lens, @@ -242,7 +453,12 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, block_size, kv_num_heads); if (use_neox_style) { - append_speculate_cache_int4_neox_rope_kernel + append_speculate_cache_int4_neox_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -266,9 
+482,15 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, block_size, 7.0f, -8.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } else { - append_speculate_cache_int4_rope_kernel + append_speculate_cache_int4_rope_kernel <<>>(qkv, key_cache, value_cache, @@ -292,10 +514,11 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, block_size, 7.0f, -8.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } } -template +template void SpeculateWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& qkv, @@ -313,11 +536,15 @@ void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out) { + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps) { typedef cascade_attn_type_traits traits_; typedef cascade_attn_type_traits qkt_nv_type_; typedef typename traits_::type DataType_; @@ -332,152 +559,306 @@ void SpeculateWriteCacheWithRoPEKernel( auto num_heads = meta_data.q_num_heads; auto kv_num_heads = meta_data.kv_num_heads; - const float* cos_emb = rotary_embs ? rotary_embs.get().data() : nullptr; const float* sin_emb; + int rotary_dim = dim_head; if (rotary_embs) { sin_emb = use_neox_rotary_style ? 
rotary_embs.get().data() + max_seq_len * dim_head : rotary_embs.get().data() + max_seq_len * dim_head / 2; + rotary_dim = + rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2; + if (rotary_dim < dim_head) { + if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || + k_norm_weight || cache_quant_type_str != "none") { + PADDLE_THROW(phi::errors::Fatal( + "partial_rotary_factor < 1.0 only supports neox_rotary_style=True, " + "qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and " + "cache_quant_type_str is 'none'.")); + } + sin_emb = rotary_embs.get().data() + max_seq_len * rotary_dim / 2; + } } - if (cache_quant_type_str == "none") { - append_speculate_cache_rope( - reinterpret_cast(qkv_ptr), - reinterpret_cast(key_cache_out->data()), - reinterpret_cast(value_cache_out->data()), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_int8") { - append_speculate_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? 
reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_fp8") { - append_speculate_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_int4_zp") { - append_speculate_cache_int4_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(const_cast(qkv_out->data())), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - cache_v_zp ? 
reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); + if (q_norm_weight && k_norm_weight) { + if (cache_quant_type_str == "none") { + append_speculate_cache_rope_qk_norm( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + reinterpret_cast(q_norm_weight.get().data()), + reinterpret_cast(k_norm_weight.get().data()), + rms_norm_eps, + rope_3d); + } else if (cache_quant_type_str == "block_wise_fp8") { + append_speculate_cache_fp8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + cache_v_scale.get().data())), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + rope_3d, + rms_norm_eps); + } else if (cache_quant_type_str == "cache_fp8") { + append_speculate_cache_fp8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + 
cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + cache_v_scale.get().data())), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + rope_3d, + rms_norm_eps); + } else { + PD_THROW( + "speculate_append_decode_cache_rope_qk_norm just supports " + "cache_quant_type " + "none/block_wise_fp8/cache_fp8"); + } + } else { - PD_THROW( - "cache_quant_type_str should be one of [none, cache_int8, " - "cache_int4_zp]"); + if (cache_quant_type_str == "none") { + append_speculate_cache_rope( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + rotary_dim, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_int8") { + append_speculate_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? 
reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_fp8") { + append_speculate_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "block_wise_fp8") { + append_speculate_cache_fp8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + cache_v_scale.get().data())), + nullptr, // q_norm_weight + nullptr, // k_norm_weight + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + rope_3d, + rms_norm_eps); + } else if (cache_quant_type_str == "cache_int4_zp") { + append_speculate_cache_int4_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(const_cast(qkv_out->data())), + block_tables.data(), + 
batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + cache_v_zp ? reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, " + "cache_int4_zp]"); + } } } @@ -500,11 +881,15 @@ template void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void SpeculateWriteCacheWithRoPEKernel( @@ -526,11 +911,15 @@ SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void SpeculateWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, @@ -551,12 
+940,15 @@ template void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); - + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void SpeculateWriteCacheWithRoPEKernel( @@ -578,8 +970,130 @@ SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); + +template void SpeculateWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, 
+ const float rms_norm_eps); + +template void +SpeculateWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); + +template void SpeculateWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_seq_len, + cudaStream_t& stream, + paddle::Tensor* qkv_out, + paddle::Tensor* key_cache_out, + 
paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); + +template void +SpeculateWriteCacheWithRoPEKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& + qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 * + // gqa_group_size, head_dim] if GQA) + const paddle::Tensor& seq_lens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_out_scales, + const paddle::optional& qkv_biases, + const paddle::optional& cache_k_scale, + const paddle::optional& cache_v_scale, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h index 40ab34e05a6..c9c3ff9e0b9 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h @@ -15,7 +15,7 @@ #include "speculate_write_cache_with_rope_impl.cuh" -template +template void SpeculateWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, const paddle::Tensor& @@ -35,8 +35,12 @@ void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& 
stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/template_config.json b/custom_ops/gpu_ops/append_attn/template_config.json new file mode 100644 index 00000000000..22eb9d18e19 --- /dev/null +++ b/custom_ops/gpu_ops/append_attn/template_config.json @@ -0,0 +1,144 @@ +{ + "multiquery_attention_c8": { + "name": "multiquery_attention_c8", + "function_name": "MultiQueryAppendC8Attention", + "impl_file": "multiquery_attention_c8_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM", + "BLOCK_SIZE", + "CAUSAL", + "BLOCK_SHAPE_Q", + "NUM_WARP_Q", + "OutT", + "ENABLE_PREFILL", + "IsFP8", + "IsDynamicC8" + ], + "dispatch_params": { + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "BLOCK_SHAPE_Q": [16, 32, 64, 128], + "ENABLE_PREFILL": [0, 1], + "IsFP8": [0, 1], + "IsDynamicC8": [0, 1] + }, + "data_types": [ + ["paddle::float16", "paddle::float16", "float16_float16"], + ["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"], + ["paddle::float16", "int8_t", "float16_int8"], + ["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"], + ["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"], + ["paddle::bfloat16", "int8_t", "bfloat16_int8"] + ], + "max_instances_per_file": 80, + "file_prefix": "multiquery_attention_c8_", + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional 
&sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window,\n const int sink_size);\n\n" + }, + "multiquery_attention_c4": { + "name": "multiquery_attention_c4", + "function_name": "MultiQueryAppendC4Attention", + "impl_file": "multiquery_attention_c4_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM", + "BLOCK_SIZE", + "CAUSAL", + "BLOCK_SHAPE_Q", + "NUM_WARP_Q", + "OutT", + "ENABLE_PREFILL" + ], + "dispatch_params": { + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "BLOCK_SHAPE_Q": [16, 32, 64, 128], + "ENABLE_PREFILL": [0, 1] + }, + "data_types": [ + ["paddle::float16", "paddle::float16", "float16_float16"], + ["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"], + ["paddle::float16", "int8_t", "float16_int8"], + ["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"], + ["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"], + ["paddle::bfloat16", "int8_t", "bfloat16_int8"] + ], + "max_instances_per_file": 160, + "file_prefix": "multiquery_attention_c4_", + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::Tensor 
&cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional &cache_k_zp,\n const paddle::optional &cache_v_zp,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window,\n const int sink_size);\n\n" + }, + "multiquery_attention_c16": { + "name": "multiquery_attention_c16", + "function_name": "MultiQueryAppendAttention", + "impl_file": "multiquery_attention_c16_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM", + "BLOCK_SIZE", + "CAUSAL", + "BLOCK_SHAPE_Q", + "NUM_WARP_Q", + "OutT", + "ENABLE_PREFILL" + ], + "dispatch_params": { + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "HEAD_DIM": [64,128], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "BLOCK_SHAPE_Q": [16, 32, 64, 128], + "ENABLE_PREFILL": [0, 1] + }, + "data_types": [ + ["paddle::float16", "paddle::float16", "float16_float16"], + ["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"], + ["paddle::float16", "int8_t", "float16_int8"], + ["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"], + ["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"], + ["paddle::bfloat16", "int8_t", "bfloat16_int8"] + ], + "max_instances_per_file": 160, + "file_prefix": "multiquery_attention_c16_", + "function_signature": "template 
void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional &attn_mask,\n const paddle::optional &shift_bias,\n const paddle::optional &smooth_weight,\n const paddle::optional &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window,\n const int sink_size);\n\n" + }, + "multiquery_decoder_attention": { + "name": "multiquery_decoder_attention", + "function_name": "MultiQueryDecoderAttention", + "impl_file": "multiquery_decoder_attention_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "HEAD_DIM_QK", + "HEAD_DIM_V", + "BLOCK_SIZE", + "CAUSAL", + "NUM_STAGE", + "cache_bytes", + "DEAL_EACH_TIME" + ], + "dispatch_params": { + "GROUP_SIZE": [8, 16, 128], + "HEAD_DIM_QK": [128, 192, 512, 576], + "HEAD_DIM_V": [128, 192, 512, 576], + "BLOCK_SIZE": [64], + "CAUSAL": [0, 1], + "NUM_STAGE": [2], + "cache_bytes": [16], + "DEAL_EACH_TIME": [32, 64] + }, + "data_types": [ + ["paddle::float16", "", "float16"], + ["paddle::bfloat16", "", "bfloat16"] + ], + "max_instances_per_file": 60, + "file_prefix": "multiquery_decoder_attention_", + "function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData& meta_data,\n cudaStream_t &stream,\n const paddle::Tensor &q,\n const 
paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional& attn_mask,\n const paddle::optional& shift_bias,\n const paddle::optional& smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const int max_seq_len,\n const int max_dec_len,\n const float rope_scale,\n const float rope_theta,\n const float softmax_scale,\n const float in_scale,\n paddle::Tensor *out);\n\n" + } +} diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu deleted file mode 100644 index 93db7851312..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c16_impl.cuh" - - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu deleted file mode 100644 index 57370364814..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 
PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int 
encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu deleted file mode 100644 index 077a5764ea8..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_bfloat16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu deleted file mode 100644 index 43625023812..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 
PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int 
encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu deleted file mode 100644 index daaad4de62c..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu deleted file mode 100644 index 549f1cec25f..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c16_float16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 
PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c16_impl.cuh" - -template void CascadeAppendAttentionC16Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int 
encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu deleted file mode 100644 index 923f9b0d392..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu deleted file mode 100644 index 888c410bbb5..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle 
Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - 
const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu deleted file mode 100644 index fcef546ea6e..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_bfloat16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu deleted file mode 100644 index 65637493715..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024 
PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int 
encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu deleted file mode 100644 index fba62df2bd6..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu deleted file mode 100644 index 7a6e21fa758..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c4_float16_int8_kernel.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle 
Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c4_impl.cuh" - -template void CascadeAppendAttentionC4Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - 
const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu deleted file mode 100644 index e860a046265..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c8_impl.cuh" - - -template void -CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - - - -template void -CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, 
head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu deleted file mode 100644 index 3b61ecd16bd..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - -template 
void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu deleted file mode 100644 index e864722b560..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_bfloat16_int8_kernel.cu +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int 
speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu deleted file mode 100644 index 4d7b11d99cc..00000000000 --- 
a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int 
max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu 
b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu deleted file mode 100644 index d03d618b23f..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const 
paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - 
paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_int8_kerne.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_int8_kerne.cu deleted file mode 100644 index 1ab83eb52db..00000000000 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/append_attention_c8_float16_int8_kerne.cu +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "../append_attention_c8_impl.cuh" - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); - - -template void CascadeAppendAttentionC8Kernel( - const AppendAttnMetaData& meta_data, - const paddle::Tensor& qkv, // [token_num, num_heads, head_dim] - const paddle::Tensor& - cache_k, // [max_block_num, num_heads, block_size, head_dim] - const paddle::Tensor& - cache_v, // [max_block_num, num_heads, head_dim, block_size] - const paddle::optional& attn_mask, - const paddle::optional& - cache_k_scale, // [num_kv_heads, head_dim] 
- const paddle::optional& - cache_v_scale, // [num_kv_heads, head_dim] - const paddle::optional& - cache_k_zp, // [num_kv_heads, head_dim] - const paddle::optional& - cache_v_zp, // [num_kv_heads, head_dim] - const paddle::optional& - shift_bias, // [num_kv_heads, head_dim] - const paddle::optional& - smooth_weight, // [num_kv_heads, head_dim] - const paddle::Tensor& seq_lens_q, - const paddle::Tensor& seq_lens_kv, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_table, - const paddle::Tensor& batch_ids, - const paddle::Tensor& tile_ids_per_batch, - const int num_blocks, - const int block_shape_q, - const int max_seq_len, - const int max_dec_len, - const float quant_max_bound, - const float quant_min_bound, - const float in_scale, - const int max_partition_size, - const int encoder_max_partition_size, - const int speculate_max_draft_token_num, - const bool causal, - const bool is_decoder, - const bool enable_prefill, - cudaStream_t& stream, - paddle::Tensor* out); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu index 8d786ce5838..915039908dc 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_bfloat16_kernel.cu @@ -43,4 +43,7 @@ EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git 
a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu index a34da825823..3f3539b8a6e 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_bfloat16_int_kernel.cu @@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu index 42f07ee8b75..a559ec77f37 100644 --- a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_float16_kernel.cu @@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu index ef3d3832e4e..3318a36472b 100644 --- 
a/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/template_instantiation/encoder_write_cache_with_rope_float16_int_kernel.cu @@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel( cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index 05f500126cb..fa5b3bca178 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -27,6 +27,7 @@ struct AppendAttnMetaData { int head_dims; int head_dims_v; int max_blocks_per_seq; + const int* mask_offset = nullptr; }; __forceinline__ __host__ __device__ int div_up(int a, int b) { @@ -109,29 +110,33 @@ __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, /******************************FASTER CAST*********************************/ -inline __device__ static void convert_fp8(__nv_bfloat16* result, const uint32_t& source) { - +inline __device__ static void convert_fp8(__nv_bfloat16* result, + const uint32_t& source) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) - uint32_t dest0; - uint32_t dest1; - asm volatile( \ - "{\n" \ - ".reg .b16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %2;\n" \ - "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ - "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ - "}\n" : "=r"(dest0), "=r"(dest1) : "r"(source)); - - ((nv_bfloat162*)(result))[0] = __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); - ((nv_bfloat162*)(result))[1] = __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); + uint32_t dest0; + uint32_t dest1; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "mov.b32 {lo, hi}, %2;\n" + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" + "cvt.rn.f16x2.e4m3x2 %1, 
hi;\n" + "}\n" + : "=r"(dest0), "=r"(dest1) + : "r"(source)); + + ((nv_bfloat162*)(result))[0] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); + ((nv_bfloat162*)(result))[1] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); #else - printf("Do not support fp8 in arch < 890\n"); - asm("trap;"); + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); #endif - } -inline __device__ static void convert_fp8(half* result, const uint32_t& source) { +inline __device__ static void convert_fp8(half* result, + const uint32_t& source) { printf("Do not support fp8 to half although it's very easy.\n"); } @@ -300,6 +305,11 @@ __forceinline__ __host__ __device__ void vec_cast( #define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ switch (head_dim) { \ + case 64: { \ + constexpr size_t HEAD_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ case 128: { \ constexpr size_t HEAD_DIM = 128; \ __VA_ARGS__ \ @@ -379,9 +389,8 @@ __forceinline__ __host__ __device__ void vec_cast( PD_THROW("not support the cache_type: ", cache_type); \ } - #define DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME, ...) 
\ - if (deal_each_time == 32) { \ + if (deal_each_time == 32) { \ constexpr size_t DEAL_EACH_TIME = 32; \ __VA_ARGS__ \ } else if (deal_each_time == 64) { \ @@ -398,7 +407,7 @@ __forceinline__ __host__ __device__ void vec_cast( } else if (num_threads == 256) { \ constexpr size_t NUM_THREADS = 256; \ __VA_ARGS__ \ - } else { \ + } else { \ PD_THROW("not support the num_threads", num_threads); \ } @@ -430,6 +439,9 @@ __forceinline__ __host__ __device__ void vec_cast( } else if (group_size == 12) { \ constexpr size_t GROUP_SIZE = 12; \ __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ } else if (group_size == 16) { \ constexpr size_t GROUP_SIZE = 16; \ __VA_ARGS__ \ @@ -437,8 +449,17 @@ __forceinline__ __host__ __device__ void vec_cast( PD_THROW("not support the group_size", group_size); \ } +#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \ + if (is_dynamic_cfp8) { \ + constexpr bool IsDynamicC8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IsDynamicC8 = false; \ + __VA_ARGS__ \ + } + #define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ - if (group_size == 8) { \ + if (group_size == 8) { \ constexpr size_t GROUP_SIZE = 8; \ __VA_ARGS__ \ } else if (group_size == 16) { \ @@ -474,6 +495,9 @@ __forceinline__ __host__ __device__ void vec_cast( if (causal) { \ constexpr bool CAUSAL = true; \ __VA_ARGS__ \ + } else { \ + constexpr bool CAUSAL = false; \ + __VA_ARGS__ \ } #define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) 
\ @@ -517,9 +541,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) { : xUpper); } - template -__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T value, const float max_bound, const float min_bound) { +__host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, + const T value, + const float max_bound, + const float min_bound) { uint8_t eight_bits; float quant_value; if constexpr (is_need_kv_quant) { @@ -551,11 +577,45 @@ __host__ __device__ __forceinline__ uint8_t QuantToC8(const T scale, const T val return eight_bits; } - -template inline __device__ static void convert_c8(T * result, const uint32_t& source){ +template +inline __device__ static void convert_c8(T* result, const uint32_t& source) { if constexpr (IsFP8) { convert_fp8(result, source); } else { convert_int8(result, source); } } + +constexpr int kWarpSize = 32; + +template +inline __device__ void WelfordCombine1(T b_m2, T* m2) { + *m2 += b_m2; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { + *m2 = thread_m2; + for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { + T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask); + WelfordCombine1(b_m2, m2); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { + WelfordWarpReduce(thread_m2, m2); +} + +template +__inline__ __device__ T Rsqrt(T x); + +template <> +__inline__ __device__ float Rsqrt(float x) { + return rsqrt(x); +} + +template <> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} diff --git a/custom_ops/gpu_ops/beam_search_softmax.cu b/custom_ops/gpu_ops/beam_search_softmax.cu index a12f2ca17be..59500dbe33f 100644 --- a/custom_ops/gpu_ops/beam_search_softmax.cu +++ b/custom_ops/gpu_ops/beam_search_softmax.cu @@ -28,16 +28,17 @@ namespace cub = hipcub; #include #include #include -#include "stdint.h" #include "helper.h" +#include "stdint.h" +#include "cccl_compat.h" // CCCL 3.0 compatibility #define FLT_MAX 1e38 static 
constexpr int kBlockSizeForSmallBeamWidth = 256; static constexpr int kMaxVocabPartForStage1FastKernel = 128; -#define CASE_K(K) \ - case K: \ +#define CASE_K(K) \ + case K: \ invokeTopKSoftMaxLauncher( \ params, beam_group_idx, stream); \ break @@ -368,7 +369,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ using KVPair = cub::KeyValuePair; KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL}; - cub::ArgMax argmax; + fd_cub_compat::ArgMax argmax; T const *local_logits = logits + beam_batch_id * vocab_size; #pragma unroll 1 @@ -595,7 +596,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ typename BlockReduceMD::TempStorage md; } smemReduceBuffer; - cub::ArgMax argmax; + fd_cub_compat::ArgMax argmax; MD partial_md{-MAX_T_VAL, 0.0f}; KVPair topKVPair{vocab_size - 1, -MAX_T_VAL}; @@ -1336,24 +1337,25 @@ adding while op and without affecting the speed. Use a 'fake inplace' method here. Not elegant but useful ︸_︸. *****/ -std::vector BeamSearchSoftmax(const paddle::Tensor &logits, - const paddle::Tensor &seq_lens, - const paddle::Tensor &stop_flags, // inplace - const paddle::Tensor &end_ids, - const paddle::Tensor &step_ids, - const paddle::Tensor &max_dec_lens, - const paddle::Tensor &block_tables, // inplace - const paddle::Tensor &cum_scores, // inplace - const paddle::Tensor &beam_cache_ids, // inplace - const paddle::Tensor &beam_hyps, // inplace - const paddle::Tensor &beam_hyps_score, // inplace - const paddle::Tensor &beam_finished, // inplace - const paddle::Tensor &beam_width, - const paddle::Tensor &beam_group_num, - const paddle::Tensor &length_penalty, - const paddle::Tensor &diversity_penalty, - bool fuse_softmax, - bool early_stop) { +std::vector BeamSearchSoftmax( + const paddle::Tensor &logits, + const paddle::Tensor &seq_lens, + const paddle::Tensor &stop_flags, // inplace + const paddle::Tensor &end_ids, + const paddle::Tensor &step_ids, + const paddle::Tensor &max_dec_lens, + const paddle::Tensor &block_tables, // inplace + const 
paddle::Tensor &cum_scores, // inplace + const paddle::Tensor &beam_cache_ids, // inplace + const paddle::Tensor &beam_hyps, // inplace + const paddle::Tensor &beam_hyps_score, // inplace + const paddle::Tensor &beam_finished, // inplace + const paddle::Tensor &beam_width, + const paddle::Tensor &beam_group_num, + const paddle::Tensor &length_penalty, + const paddle::Tensor &diversity_penalty, + bool fuse_softmax, + bool early_stop) { std::vector logits_shape = logits.shape(); // logits_shape auto cu_stream = logits.stream(); @@ -1371,6 +1373,9 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, cudaMemcpyDeviceToHost, cu_stream); + // Must synchronize before using host values copied from device + cudaStreamSynchronize(cu_stream); + int beam_batch_size = logits_shape[0]; int batch_size = beam_batch_size / beam_width_scalar; int vocab_size = logits_shape[1]; @@ -1380,43 +1385,43 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, const int end_ids_len = end_ids.dims()[0]; const int beam_group_size = beam_width_scalar / beam_group_num_scalar; - auto next_tokens = paddle::full({logits_shape[0], 1}, 0, end_ids.type(), - paddle::GPUPlace()); + auto next_tokens = + paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace()); - auto parent_ids = paddle::full({logits_shape[0], 1}, 0, end_ids.type(), - paddle::GPUPlace()); + auto parent_ids = + paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace()); - auto cum_scores_ori = paddle::empty(cum_scores.shape(), logits.type(), - paddle::GPUPlace()); + auto cum_scores_ori = + paddle::empty(cum_scores.shape(), logits.type(), paddle::GPUPlace()); - auto beam_cache_ids_ori = paddle::empty(beam_cache_ids.shape(), end_ids.type(), - paddle::GPUPlace()); + auto beam_cache_ids_ori = + paddle::empty(beam_cache_ids.shape(), end_ids.type(), paddle::GPUPlace()); - auto block_tables_ori = paddle::empty(block_tables.shape(), end_ids.type(), - paddle::GPUPlace()); + auto block_tables_ori = + 
paddle::empty(block_tables.shape(), end_ids.type(), paddle::GPUPlace()); cudaMemcpyAsync(cum_scores_ori.mutable_data(), cum_scores.data(), - sizeof(float)*cum_scores.numel(), + sizeof(float) * cum_scores.numel(), cudaMemcpyDeviceToDevice, cu_stream); cudaMemcpyAsync(beam_cache_ids_ori.mutable_data(), beam_cache_ids.data(), - sizeof(int)*beam_cache_ids.numel(), + sizeof(int) * beam_cache_ids.numel(), cudaMemcpyDeviceToDevice, cu_stream); cudaMemcpyAsync(block_tables_ori.mutable_data(), block_tables.data(), - sizeof(int)*block_tables.numel(), + sizeof(int) * block_tables.numel(), cudaMemcpyDeviceToDevice, cu_stream); const int tmp_size = batch_size * beam_group_size * beam_group_size * 2; - auto tmp_topk_id = paddle::full({tmp_size}, 0, end_ids.type(), - paddle::GPUPlace()); + auto tmp_topk_id = + paddle::full({tmp_size}, 0, end_ids.type(), paddle::GPUPlace()); - auto tmp_topk_val = paddle::full({tmp_size}, 0.0, logits.type(), - paddle::GPUPlace()); + auto tmp_topk_val = + paddle::full({tmp_size}, 0.0, logits.type(), paddle::GPUPlace()); BeamSearchParams params; params.batch_size = batch_size; @@ -1449,7 +1454,8 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, params.block_tables_out = const_cast(block_tables.data()); params.cum_scores_out = const_cast(cum_scores.data()); params.beam_hyps_out = const_cast(beam_hyps.data()); - params.beam_hyps_score_out = const_cast(beam_hyps_score.data()); + params.beam_hyps_score_out = + const_cast(beam_hyps_score.data()); params.beam_finished = const_cast(beam_finished.data()); params.stop_flags = const_cast(stop_flags.data()); @@ -1470,8 +1476,8 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, const int workspace_size = tmp_id_val_size * 2 + tmp_stage1_to_stage2_size; - auto wsp_buffer_tensor = paddle::full({workspace_size}, 0, logits.type(), - paddle::GPUPlace()); + auto wsp_buffer_tensor = + paddle::full({workspace_size}, 0, logits.type(), paddle::GPUPlace()); params.tmp_ids = 
reinterpret_cast(wsp_buffer_tensor.data()); params.tmp_vals = wsp_buffer_tensor.data() + tmp_id_val_size; @@ -1480,11 +1486,9 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, for (int beam_group_idx = 0; beam_group_idx < beam_group_num_scalar; ++beam_group_idx) { if (beam_group_num_scalar == 1) { - invokeTopkSoftMax( - ¶ms, beam_group_idx, cu_stream); + invokeTopkSoftMax(¶ms, beam_group_idx, cu_stream); } else { - invokeTopkSoftMax( - ¶ms, beam_group_idx, cu_stream); + invokeTopkSoftMax(¶ms, beam_group_idx, cu_stream); } } updateBeamSearchParams(¶ms, cu_stream); @@ -1492,54 +1496,66 @@ std::vector BeamSearchSoftmax(const paddle::Tensor &logits, } std::vector> BeamSearchSoftmaxShape( - const std::vector &logits, - const std::vector &seq_lens, - const std::vector &stop_flags, // inplace - const std::vector &end_ids, - const std::vector &step_ids, - const std::vector &max_dec_lens, - const std::vector &block_tables, // inplace - const std::vector &cum_scores, // inplace - const std::vector &beam_cache_ids, // inplace - const std::vector &beam_hyps, // inplace - const std::vector &beam_hyps_score, // inplace - const std::vector &beam_finished, // inplace - const std::vector &beam_width, - const std::vector &beam_group_num, - const std::vector &length_penalty, - const std::vector &diversity_penalty) { - std::vector next_tokens = {logits[0],1}; - std::vector parent_ids = {logits[0],1}; - return {next_tokens,parent_ids}; + const std::vector &logits, + const std::vector &seq_lens, + const std::vector &stop_flags, // inplace + const std::vector &end_ids, + const std::vector &step_ids, + const std::vector &max_dec_lens, + const std::vector &block_tables, // inplace + const std::vector &cum_scores, // inplace + const std::vector &beam_cache_ids, // inplace + const std::vector &beam_hyps, // inplace + const std::vector &beam_hyps_score, // inplace + const std::vector &beam_finished, // inplace + const std::vector &beam_width, + const std::vector &beam_group_num, 
+ const std::vector &length_penalty, + const std::vector &diversity_penalty) { + std::vector next_tokens = {logits[0], 1}; + std::vector parent_ids = {logits[0], 1}; + return {next_tokens, parent_ids}; } std::vector BeamSearchSoftmaxDtype( - const paddle::DataType &logits, - const paddle::DataType &seq_lens, - const paddle::DataType &stop_flags, // inplace - const paddle::DataType &end_ids, - const paddle::DataType &step_ids, - const paddle::DataType &max_dec_lens, - const paddle::DataType &block_tables, // inplace - const paddle::DataType &cum_scores, // inplace - const paddle::DataType &beam_cache_ids, // inplace - const paddle::DataType &beam_hyps, // inplace - const paddle::DataType &beam_hyps_score, // inplace - const paddle::DataType &beam_finished, // inplace - const paddle::DataType &beam_width, - const paddle::DataType &beam_group_num, - const paddle::DataType &length_penalty, - const paddle::DataType &diversity_penalty) { - return {paddle::DataType::INT32, paddle::DataType::INT32}; + const paddle::DataType &logits, + const paddle::DataType &seq_lens, + const paddle::DataType &stop_flags, // inplace + const paddle::DataType &end_ids, + const paddle::DataType &step_ids, + const paddle::DataType &max_dec_lens, + const paddle::DataType &block_tables, // inplace + const paddle::DataType &cum_scores, // inplace + const paddle::DataType &beam_cache_ids, // inplace + const paddle::DataType &beam_hyps, // inplace + const paddle::DataType &beam_hyps_score, // inplace + const paddle::DataType &beam_finished, // inplace + const paddle::DataType &beam_width, + const paddle::DataType &beam_group_num, + const paddle::DataType &length_penalty, + const paddle::DataType &diversity_penalty) { + return {paddle::DataType::INT32, paddle::DataType::INT32}; } PD_BUILD_STATIC_OP(beam_search_softmax) - .Inputs({"logits", "seq_lens", "stop_flags", "end_ids", "step_ids", "max_dec_lens", "block_tables" - , "cum_scores", "beam_cache_ids", "beam_hyps", "beam_hyps_score", 
"beam_finished" - , "beam_width", "beam_group_num", "length_penalty", "diversity_penalty"}) + .Inputs({"logits", + "seq_lens", + "stop_flags", + "end_ids", + "step_ids", + "max_dec_lens", + "block_tables", + "cum_scores", + "beam_cache_ids", + "beam_hyps", + "beam_hyps_score", + "beam_finished", + "beam_width", + "beam_group_num", + "length_penalty", + "diversity_penalty"}) .Outputs({"next_tokens", "parent_ids"}) - .Attrs({"fuse_softmax: bool", - "early_stop: bool"}) + .Attrs({"fuse_softmax: bool", "early_stop: bool"}) .SetKernelFn(PD_KERNEL(BeamSearchSoftmax)) .SetInferShapeFn(PD_INFER_SHAPE(BeamSearchSoftmaxShape)) .SetInferDtypeFn(PD_INFER_DTYPE(BeamSearchSoftmaxDtype)); diff --git a/custom_ops/gpu_ops/cccl_compat.h b/custom_ops/gpu_ops/cccl_compat.h new file mode 100644 index 00000000000..2a9f9627ef5 --- /dev/null +++ b/custom_ops/gpu_ops/cccl_compat.h @@ -0,0 +1,151 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// CCCL 3.0 compatibility header for CUDA 13.0+ +// In CCCL 3.0, cub::Sum, cub::Max, cub::Min are removed from the cub namespace. +// This header provides compatible implementations that work with both old and +// new versions. 
+ +// Include cub headers based on platform +#ifdef PADDLE_WITH_HIP +#include +#else +#include +#endif + +// Detect CUDA 13.0+ (CCCL 3.0) +// __CUDACC_VER_MAJOR__ >= 13 indicates CUDA 13.0 or later +#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 13 +#define FD_CCCL_V3 1 +#endif + +namespace fd_cub_compat { + +// ============================================================================ +// Sum, Max, Min functors +// ============================================================================ + +#ifdef FD_CCCL_V3 +// CUDA 13.0+ (CCCL 3.0): Use custom implementations since cub::Sum/Max/Min are +// removed + +/// Functor for computing the sum of two values +struct Sum { + /// Apply the sum operation + template + __host__ __device__ __forceinline__ T operator()(const T &a, + const T &b) const { + return a + b; + } +}; + +/// Functor for computing the maximum of two values +struct Max { + /// Apply the max operation + template + __host__ __device__ __forceinline__ T operator()(const T &a, + const T &b) const { + return (b > a) ? b : a; + } +}; + +/// Functor for computing the minimum of two values +struct Min { + /// Apply the min operation + template + __host__ __device__ __forceinline__ T operator()(const T &a, + const T &b) const { + return (b < a) ? 
b : a; + } +}; + +#else +// CUDA 12.x and earlier: Use native cub implementations + +#ifdef PADDLE_WITH_HIP +using Sum = hipcub::Sum; +using Max = hipcub::Max; +using Min = hipcub::Min; +#else +using Sum = cub::Sum; +using Max = cub::Max; +using Min = cub::Min; +#endif + +#endif // FD_CCCL_V3 + +// ============================================================================ +// ArgMax, ArgMin functors +// These are also removed in CCCL 3.0 +// ============================================================================ + +#ifdef FD_CCCL_V3 +// CUDA 13.0+ (CCCL 3.0): Use custom implementations since cub::ArgMax/ArgMin +// are removed + +/// Functor for computing the ArgMax of two values (for cub::BlockReduce with +/// KeyValuePair) Returns the key-value pair with the larger value +struct ArgMax { + /// Apply ArgMax operation (returns pair with max value and its key/index) + template + __host__ __device__ __forceinline__ KeyValuePair + operator()(const KeyValuePair &a, const KeyValuePair &b) const { + return (b.value > a.value) ? b : a; + } +}; + +/// Functor for computing the ArgMin of two values (for cub::BlockReduce with +/// KeyValuePair) Returns the key-value pair with the smaller value +struct ArgMin { + /// Apply ArgMin operation (returns pair with min value and its key/index) + template + __host__ __device__ __forceinline__ KeyValuePair + operator()(const KeyValuePair &a, const KeyValuePair &b) const { + return (b.value < a.value) ? 
b : a; + } +}; + +#else +// CUDA 12.x and earlier: Use native cub implementations + +#ifdef PADDLE_WITH_HIP +using ArgMax = hipcub::ArgMax; +using ArgMin = hipcub::ArgMin; +#else +// For older CUDA versions, wrap the native cub::ArgMax/ArgMin +struct ArgMax { + template + __host__ __device__ __forceinline__ KeyValuePair + operator()(const KeyValuePair &a, const KeyValuePair &b) const { + cub::ArgMax argmax; + return argmax(a, b); + } +}; + +struct ArgMin { + template + __host__ __device__ __forceinline__ KeyValuePair + operator()(const KeyValuePair &a, const KeyValuePair &b) const { + cub::ArgMin argmin; + return argmin(a, b); + } +}; + +#endif // PADDLE_WITH_HIP + +#endif // FD_CCCL_V3 + +} // namespace fd_cub_compat diff --git a/custom_ops/gpu_ops/common/configManager.h b/custom_ops/gpu_ops/common/configManager.h index d0bb751e976..960df1259b4 100644 --- a/custom_ops/gpu_ops/common/configManager.h +++ b/custom_ops/gpu_ops/common/configManager.h @@ -22,87 +22,96 @@ #include class ConfigManager { -public: - static ConfigManager& get_instance(const std::string& config_path = "fastdeploy_op_configs.json") { - static ConfigManager instance(config_path); - return instance; - } + public: + static ConfigManager& get_instance( + const std::string& config_path = "fastdeploy_op_configs.json") { + static ConfigManager instance(config_path); + return instance; + } - std::string get_best_config(const std::string& op_name, const size_t m, const size_t n, const size_t k) { - initialize(); - std::string mnk_string = op_name + "-" + - std::to_string(update_m(m)) + "x" + std::to_string(n) + "x" + std::to_string(k); - if (configs_.contains(mnk_string)) { - return configs_.at(mnk_string); - } - return ""; + std::string get_best_config(const std::string& op_name, + const size_t m, + const size_t n, + const size_t k) { + initialize(); + std::string mnk_string = op_name + "-" + std::to_string(update_m(m)) + "x" + + std::to_string(n) + "x" + std::to_string(k); + if 
(configs_.contains(mnk_string)) { + return configs_.at(mnk_string); } + return ""; + } - int64_t update_m(const size_t m) { - size_t new_m = m; - if (m < 4) { - return m; - } else if (m < 16) { - return (m + 3) / 4 * 4; - } else if (m < 64) { - return (m + 15) / 16 * 16; - } else if (m < 256) { - return (m + 31) / 32 * 32; - } else if (m < 512) { - return (m + 63) / 64 * 64; - } else if (m < 1024) { - return (m + 127) / 128 * 128; - } else if (m < 8192) { - return (m + 1023) / 1024 * 1024; - } else if (m < 32768) { - return (m + 4095) / 4096 * 4096; - } else { - return 32768; - } + int64_t update_m(const size_t m) { + size_t new_m = m; + if (m < 4) { + return m; + } else if (m < 16) { + return (m + 3) / 4 * 4; + } else if (m < 64) { + return (m + 15) / 16 * 16; + } else if (m < 256) { + return (m + 31) / 32 * 32; + } else if (m < 512) { + return (m + 63) / 64 * 64; + } else if (m < 1024) { + return (m + 127) / 128 * 128; + } else if (m < 8192) { + return (m + 1023) / 1024 * 1024; + } else if (m < 32768) { + return (m + 4095) / 4096 * 4096; + } else { + return 32768; } + } - void update(const std::string& op_name, const size_t m, const size_t n, const size_t k, const std::string& config) { - initialize(); - std::string mnk_string = op_name + "-" + - std::to_string(update_m(m)) + "x" + std::to_string(n) + "x" + std::to_string(k); - configs_[mnk_string] = config; - } + void update(const std::string& op_name, + const size_t m, + const size_t n, + const size_t k, + const std::string& config) { + initialize(); + std::string mnk_string = op_name + "-" + std::to_string(update_m(m)) + "x" + + std::to_string(n) + "x" + std::to_string(k); + configs_[mnk_string] = config; + } - void print() const { - std::cout << configs_.dump(4) << std::endl; // Pretty print with 4 spaces - } + void print() const { + std::cout << configs_.dump(4) << std::endl; // Pretty print with 4 spaces + } - ~ConfigManager() { - std::ofstream file(config_path_); - if (file.is_open()) { - file << 
configs_.dump(4); // Pretty print with 4 spaces - file.close(); - } + ~ConfigManager() { + std::ofstream file(config_path_); + if (file.is_open()) { + file << configs_.dump(4); // Pretty print with 4 spaces + file.close(); } + } -private: - void initialize() { - if (initialized_) return; - std::ifstream file(config_path_); - if (file.is_open()) { - try { - file >> configs_; - } catch (const std::exception& e) { - std::cerr << "Error reading configs from " << config_path_ << " : " << e.what() << std::endl; - configs_ = nlohmann::json::object(); // Create an empty JSON object - } - file.close(); - } else { - configs_ = nlohmann::json::object(); // Create an empty JSON object - } - initialized_ = true; + private: + void initialize() { + if (initialized_) return; + std::ifstream file(config_path_); + if (file.is_open()) { + try { + file >> configs_; + } catch (const std::exception& e) { + std::cerr << "Error reading configs from " << config_path_ << " : " + << e.what() << std::endl; + configs_ = nlohmann::json::object(); // Create an empty JSON object + } + file.close(); + } else { + configs_ = nlohmann::json::object(); // Create an empty JSON object } + initialized_ = true; + } - ConfigManager(const std::string& config_path) : config_path_(config_path) {} - ConfigManager(const ConfigManager&) = delete; - ConfigManager& operator=(const ConfigManager&) = delete; + ConfigManager(const std::string& config_path) : config_path_(config_path) {} + ConfigManager(const ConfigManager&) = delete; + ConfigManager& operator=(const ConfigManager&) = delete; - nlohmann::json configs_; - std::string config_path_; - bool initialized_{false}; + nlohmann::json configs_; + std::string config_path_; + bool initialized_{false}; }; diff --git a/custom_ops/gpu_ops/common/cudaUtils.h b/custom_ops/gpu_ops/common/cudaUtils.h index 9bbd1f6e801..7123e33ebbc 100644 --- a/custom_ops/gpu_ops/common/cudaUtils.h +++ b/custom_ops/gpu_ops/common/cudaUtils.h @@ -16,18 +16,18 @@ #include #include 
"paddle/phi/core/enforce.h" -namespace common -{ +namespace common { -inline int getSMVersion() -{ - int device{-1}; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; +inline int getSMVersion() { + int device{-1}; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; } -} +} // namespace common diff --git a/custom_ops/gpu_ops/common/quantization.h b/custom_ops/gpu_ops/common/quantization.h index e6a74760bb0..433e7953111 100644 --- a/custom_ops/gpu_ops/common/quantization.h +++ b/custom_ops/gpu_ops/common/quantization.h @@ -20,312 +20,240 @@ #include #include -namespace common -{ - -class QuantMode -{ - // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py -public: - using BaseType = std::uint32_t; - - explicit constexpr QuantMode(BaseType value) noexcept - : mValue{value} - { - } +namespace common { - QuantMode() noexcept = default; +class QuantMode { + // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH + // tensorrt_llm/quantization/mode.py + public: + using BaseType = std::uint32_t; - constexpr QuantMode(QuantMode const&) noexcept = default; + explicit constexpr QuantMode(BaseType value) noexcept : mValue{value} {} - constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; + QuantMode() noexcept = default; - static constexpr QuantMode none() noexcept - { - return QuantMode(BaseType(0)); - } + constexpr QuantMode(QuantMode const&) noexcept = 
default; - static constexpr QuantMode int4Weights() noexcept - { - return QuantMode(BaseType(1u) << 0); - } + constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; - static constexpr QuantMode int8Weights() noexcept - { - return QuantMode(BaseType(1u) << 1); - } + static constexpr QuantMode none() noexcept { return QuantMode(BaseType(0)); } - static constexpr QuantMode activations() noexcept - { - return QuantMode(BaseType(1u) << 2); - } + static constexpr QuantMode int4Weights() noexcept { + return QuantMode(BaseType(1u) << 0); + } - static constexpr QuantMode perChannelScaling() noexcept - { - return QuantMode(BaseType(1u) << 3); - } + static constexpr QuantMode int8Weights() noexcept { + return QuantMode(BaseType(1u) << 1); + } - static constexpr QuantMode perTokenScaling() noexcept - { - return QuantMode(BaseType(1u) << 4); - } + static constexpr QuantMode activations() noexcept { + return QuantMode(BaseType(1u) << 2); + } - static constexpr QuantMode perGroupScaling() noexcept - { - return QuantMode(BaseType(1u) << 5); - } + static constexpr QuantMode perChannelScaling() noexcept { + return QuantMode(BaseType(1u) << 3); + } - static constexpr QuantMode int8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 6); - } + static constexpr QuantMode perTokenScaling() noexcept { + return QuantMode(BaseType(1u) << 4); + } - static constexpr QuantMode fp8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 7); - } + static constexpr QuantMode perGroupScaling() noexcept { + return QuantMode(BaseType(1u) << 5); + } - static constexpr QuantMode fp8Qdq() noexcept - { - return QuantMode(BaseType(1u) << 8); - } + static constexpr QuantMode int8KvCache() noexcept { + return QuantMode(BaseType(1u) << 6); + } - static constexpr QuantMode fp8RowWise() noexcept - { - return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); - } + static constexpr QuantMode fp8KvCache() noexcept { + return QuantMode(BaseType(1u) << 7); + } - 
constexpr BaseType value() const noexcept - { - return mValue; - } + static constexpr QuantMode fp8Qdq() noexcept { + return QuantMode(BaseType(1u) << 8); + } - constexpr bool isSet(QuantMode const& mode) const noexcept - { - return (mValue & mode.value()) == mode.value(); - } + static constexpr QuantMode fp8RowWise() noexcept { + return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); + } - constexpr bool hasInt4Weights() const noexcept - { - return isSet(int4Weights()); - } + constexpr BaseType value() const noexcept { return mValue; } - constexpr bool hasInt8Weights() const noexcept - { - return isSet(int8Weights()); - } + constexpr bool isSet(QuantMode const& mode) const noexcept { + return (mValue & mode.value()) == mode.value(); + } - constexpr bool hasActivations() const noexcept - { - return isSet(activations()); - } + constexpr bool hasInt4Weights() const noexcept { + return isSet(int4Weights()); + } - constexpr bool hasPerChannelScaling() const noexcept - { - return isSet(perChannelScaling()); - } + constexpr bool hasInt8Weights() const noexcept { + return isSet(int8Weights()); + } - constexpr bool hasPerTokenScaling() const noexcept - { - return isSet(perTokenScaling()); - } + constexpr bool hasActivations() const noexcept { + return isSet(activations()); + } - constexpr bool hasPerGroupScaling() const noexcept - { - return isSet(perGroupScaling()); - } + constexpr bool hasPerChannelScaling() const noexcept { + return isSet(perChannelScaling()); + } - constexpr bool hasStaticActivationScaling() const noexcept - { - return !hasPerTokenScaling(); - } + constexpr bool hasPerTokenScaling() const noexcept { + return isSet(perTokenScaling()); + } - constexpr bool hasInt8KvCache() const noexcept - { - return isSet(int8KvCache()); - } + constexpr bool hasPerGroupScaling() const noexcept { + return isSet(perGroupScaling()); + } - constexpr bool hasFp8KvCache() const noexcept - { - return isSet(fp8KvCache()); - } + constexpr bool 
hasStaticActivationScaling() const noexcept { + return !hasPerTokenScaling(); + } - constexpr bool hasFp8Qdq() const noexcept - { - return isSet(fp8Qdq()); - } + constexpr bool hasInt8KvCache() const noexcept { + return isSet(int8KvCache()); + } - constexpr bool hasFp8RowWise() const noexcept - { - return isSet(fp8RowWise()); - } + constexpr bool hasFp8KvCache() const noexcept { return isSet(fp8KvCache()); } - constexpr bool hasKvCacheQuant() const noexcept - { - return hasInt8KvCache() || hasFp8KvCache(); - } + constexpr bool hasFp8Qdq() const noexcept { return isSet(fp8Qdq()); } - static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false, - bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false, - bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false) - { - QuantMode quantMode{}; - if (quantizeWeights) - { - if (useInt4Weights) - quantMode += int4Weights(); - else - quantMode += int8Weights(); - } - - if (quantizeActivations) - { - quantMode += activations(); - } - - if (perChannel) - { - quantMode += QuantMode::perChannelScaling(); - } - if (perToken) - { - quantMode += QuantMode::perTokenScaling(); - } - if (perGroup) - { - quantMode += QuantMode::perGroupScaling(); - } - - if (useInt8KvCache) - { - quantMode += int8KvCache(); - } - - if (useFp8KvCache) - { - quantMode += fp8KvCache(); - } - - if (useFp8Qdq) - { - quantMode += fp8Qdq(); - } - - if (useFp8RowWise) - { - quantMode += fp8RowWise(); - } - - return quantMode; - } + constexpr bool hasFp8RowWise() const noexcept { return isSet(fp8RowWise()); } - static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) - { - return fromDescription(true, true, perToken, perChannel); - } + constexpr bool hasKvCacheQuant() const noexcept { + return hasInt8KvCache() || hasFp8KvCache(); + } - static constexpr QuantMode useWeightOnly(bool 
useInt4Weights = false, bool perGroup = false) - { - return fromDescription(true, false, false, false, perGroup, useInt4Weights); + static constexpr QuantMode fromDescription(bool quantizeWeights = false, + bool quantizeActivations = false, + bool perToken = false, + bool perChannel = false, + bool perGroup = false, + bool useInt4Weights = false, + bool useInt8KvCache = false, + bool useFp8KvCache = false, + bool useFp8Qdq = false, + bool useFp8RowWise = false) { + QuantMode quantMode{}; + if (quantizeWeights) { + if (useInt4Weights) + quantMode += int4Weights(); + else + quantMode += int8Weights(); } - static const QuantMode fromQuantAlgo( - std::optional quantAlgo = std::nullopt, std::optional kvCacheQuantAlgo = std::nullopt) - { - QuantMode quantMode{}; - if (quantAlgo == "W8A16") - { - quantMode = useWeightOnly(false, false); - } - else if (quantAlgo == "W4A16") - { - quantMode = useWeightOnly(true, false); - } - else if (quantAlgo == "W4A16_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A8_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A16_GPTQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, false); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, false); - } - else if (quantAlgo == "FP8") - { - quantMode = fromDescription(false, false, false, false, false, false, false, false, true); - } - else if (quantAlgo == "FP8_ROWWISE") - { - quantMode = fromDescription(false, false, true, true, false, false, false, false, false, 
true); - } - - if (kvCacheQuantAlgo == "INT8") - { - quantMode += int8KvCache(); - } - else if (kvCacheQuantAlgo == "FP8") - { - quantMode += fp8KvCache(); - } - - return quantMode; + if (quantizeActivations) { + quantMode += activations(); } - constexpr QuantMode operator+(QuantMode const& other) const noexcept - { - return QuantMode(mValue | other.mValue); + if (perChannel) { + quantMode += QuantMode::perChannelScaling(); + } + if (perToken) { + quantMode += QuantMode::perTokenScaling(); + } + if (perGroup) { + quantMode += QuantMode::perGroupScaling(); } - constexpr QuantMode& operator+=(QuantMode const& other) noexcept - { - return *this = *this + other; + if (useInt8KvCache) { + quantMode += int8KvCache(); } - constexpr QuantMode operator-(QuantMode const& other) const noexcept - { - return QuantMode(mValue & ~other.mValue); + if (useFp8KvCache) { + quantMode += fp8KvCache(); } - constexpr QuantMode& operator-=(QuantMode const& other) noexcept - { - return *this = *this - other; + if (useFp8Qdq) { + quantMode += fp8Qdq(); } - constexpr bool operator==(QuantMode const& other) const noexcept - { - return mValue == other.mValue; + if (useFp8RowWise) { + quantMode += fp8RowWise(); } - constexpr bool operator!=(QuantMode const& other) const noexcept - { - return !(*this == other); + return quantMode; + } + + static constexpr QuantMode useSmoothQuant(bool perToken = false, + bool perChannel = false) { + return fromDescription(true, true, perToken, perChannel); + } + + static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, + bool perGroup = false) { + return fromDescription(true, false, false, false, perGroup, useInt4Weights); + } + + static const QuantMode fromQuantAlgo( + std::optional quantAlgo = std::nullopt, + std::optional kvCacheQuantAlgo = std::nullopt) { + QuantMode quantMode{}; + if (quantAlgo == "W8A16") { + quantMode = useWeightOnly(false, false); + } else if (quantAlgo == "W4A16") { + quantMode = useWeightOnly(true, false); + } else if 
(quantAlgo == "W4A16_AWQ") { + quantMode = useWeightOnly(true, true); + } else if (quantAlgo == "W4A8_AWQ") { + quantMode = useWeightOnly(true, true); + } else if (quantAlgo == "W4A16_GPTQ") { + quantMode = useWeightOnly(true, true); + } else if (quantAlgo == "W8A8_SQ_PER_CHANNEL") { + quantMode = useSmoothQuant(false, true); + } else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN") { + quantMode = useSmoothQuant(false, false); + } else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN") { + quantMode = useSmoothQuant(true, true); + } else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN") { + quantMode = useSmoothQuant(false, true); + } else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN") { + quantMode = useSmoothQuant(true, false); + } else if (quantAlgo == "FP8") { + quantMode = fromDescription( + false, false, false, false, false, false, false, false, true); + } else if (quantAlgo == "FP8_ROWWISE") { + quantMode = fromDescription( + false, false, true, true, false, false, false, false, false, true); } -private: - BaseType mValue{0}; + if (kvCacheQuantAlgo == "INT8") { + quantMode += int8KvCache(); + } else if (kvCacheQuantAlgo == "FP8") { + quantMode += fp8KvCache(); + } + + return quantMode; + } + + constexpr QuantMode operator+(QuantMode const& other) const noexcept { + return QuantMode(mValue | other.mValue); + } + + constexpr QuantMode& operator+=(QuantMode const& other) noexcept { + return *this = *this + other; + } + + constexpr QuantMode operator-(QuantMode const& other) const noexcept { + return QuantMode(mValue & ~other.mValue); + } + + constexpr QuantMode& operator-=(QuantMode const& other) noexcept { + return *this = *this - other; + } + + constexpr bool operator==(QuantMode const& other) const noexcept { + return mValue == other.mValue; + } + + constexpr bool operator!=(QuantMode const& other) const noexcept { + return !(*this == other); + } + + private: + BaseType mValue{0}; }; -} // namespace common +} // namespace common diff 
--git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index e1d48f41ca2..40898434bf1 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -13,19 +13,20 @@ // limitations under the License. #include "paddle/extension.h" +#include "pybind11/numpy.h" #include "pybind11/pybind11.h" namespace py = pybind11; // 自定义异常类,用于处理CUDA错误 class CudaError : public std::exception { -public: + public: explicit CudaError(cudaError_t error) : error_(error) {} - const char *what() const noexcept override { + const char* what() const noexcept override { return cudaGetErrorString(error_); } -private: + private: cudaError_t error_; }; @@ -39,156 +40,320 @@ void check_cuda_error(cudaError_t error) { // 封装cudaHostAlloc的Python函数 uintptr_t cuda_host_alloc(size_t size, unsigned int flags = cudaHostAllocDefault) { - void *ptr = nullptr; + void* ptr = nullptr; check_cuda_error(cudaHostAlloc(&ptr, size, flags)); return reinterpret_cast(ptr); } // 封装cudaFreeHost的Python函数 void cuda_host_free(uintptr_t ptr) { - check_cuda_error(cudaFreeHost(reinterpret_cast(ptr))); + check_cuda_error(cudaFreeHost(reinterpret_cast(ptr))); } +paddle::Tensor CustomNumpyToTensor(py::array numpy_array, + paddle::Tensor tensor) { + py::buffer_info buf_info = numpy_array.request(); + void* numpy_data = buf_info.ptr; + size_t data_size = buf_info.size * buf_info.itemsize; + auto stream = tensor.stream(); + cudaMemcpyAsync((void*)(tensor.data()), + numpy_data, + data_size, + cudaMemcpyHostToDevice, + stream); + return tensor; +} + +void FlashAttentionMask(const paddle::Tensor& q_input, + const paddle::Tensor& k_input, + const paddle::Tensor& v_input, + const paddle::Tensor& cu_seq_q, + const paddle::Tensor& cu_seq_k, + const paddle::Tensor& seq_len_encoder, + const paddle::Tensor& attn_out, + const paddle::optional& mask, + const int head_num, + const int kv_head_num, + const int head_dim); + std::vector AppendAttention( - const paddle::Tensor &qkv, 
const paddle::Tensor &key_cache, - const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids, - const paddle::Tensor &encoder_tile_ids_per_batch, - const paddle::Tensor &encoder_num_blocks, - const paddle::Tensor &kv_batch_ids, - const paddle::Tensor &kv_tile_ids_per_batch, - const paddle::Tensor &kv_num_blocks, - const paddle::Tensor &decoder_batch_ids, - const paddle::Tensor &decoder_tile_ids_per_batch, - const paddle::Tensor &decoder_num_blocks, - const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv, - const paddle::optional &rotary_embs, - const paddle::optional &attn_mask, - const paddle::optional &qkv_bias, - const paddle::optional &qkv_out_scales, - const paddle::optional &cache_k_quant_scales, - const paddle::optional &cache_v_quant_scales, - const paddle::optional &cache_k_dequant_scales, - const paddle::optional &cache_v_dequant_scales, - const paddle::optional &cache_k_zp, - const paddle::optional &cache_v_zp, - const paddle::optional &out_linear_shifts, - const paddle::optional &out_linear_smooths, - const paddle::optional &kv_signal_data, - const std::string &compute_dtype, const std::string &cache_quant_type_str, - const bool use_neox_rotary_style, const bool rope_3d, - const int max_input_length, const float quant_max_bound, - const float quant_min_bound, const float out_linear_in_scale, - const int encoder_block_shape_q, const int decoder_block_shape_q, - const int max_partition_size, const int encoder_max_partition_size, - const int speculate_max_draft_token_num, const bool causal, - const bool speculate_decoder); + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const 
paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks_cpu, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& sinks, + const float rms_norm_eps, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder, + const int sliding_window, + const int sink_size); + +std::vector AppendAttentionWithOutput( + const 
paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks_cpu, + const paddle::Tensor& set_max_lengths, + paddle::Tensor& fmha_out, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const paddle::optional& mask_offset, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& sinks, + const float rms_norm_eps, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int 
speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder, + const int sliding_window, + const int sink_size); std::vector GQARopeWriteCacheKernel( - const paddle::Tensor &qkv, const paddle::Tensor &key_cache, - const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_k, const paddle::Tensor &rotary_embs, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &batch_id_per_token, - const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids, - const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks, - const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids, - const paddle::Tensor &cache_num_blocks, - const paddle::optional &cache_k_quant_scales, - const paddle::optional &cache_v_quant_scales, - const paddle::optional &cache_k_dequant_scales, - const paddle::optional &cache_v_dequant_scales, - const paddle::optional &cache_k_zp, - const paddle::optional &cache_v_zp, - const paddle::optional &kv_signal_data, - const int kv_token_num, const int max_seq_len, - const std::string &cache_quant_type); - -std::vector -PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - const int max_dec_len, const int block_size); + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& cu_seqlens_k, + const paddle::Tensor& rotary_embs, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& block_tables, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& cache_batch_ids, + const paddle::Tensor& 
cache_tile_ids, + const paddle::Tensor& cache_num_blocks, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const int kv_token_num, + const int max_seq_len, + const float rms_norm_eps, + const bool use_neox_rotary_style, + const std::string& cache_quant_type, + const bool rope_3d); + +std::vector PreCacheLenConcat( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const int max_dec_len, + const int block_size); paddle::Tensor FusedExpertMoeFunc( - const paddle::Tensor &input, const paddle::Tensor &gate_weight, - const paddle::Tensor &up_gate_proj_weight, const paddle::Tensor &down_proj_weight, - const paddle::optional &up_gate_proj_bias, - const paddle::optional &up_gate_proj_scale, - const paddle::optional &down_proj_bias, - const paddle::optional &down_proj_scale, - const std::string &quant_method, const int moe_topk, - const bool norm_topk_prob, const bool group_moe); + const paddle::Tensor& input, + const paddle::Tensor& gate_weight, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_bias, + const paddle::optional& down_proj_scale, + const std::string& quant_method, + const int moe_topk, + const bool norm_topk_prob, + const bool group_moe); + +std::vector MacheteMMKernel( + paddle::Tensor const& A, + paddle::Tensor const& B, + paddle::optional const& maybe_group_scales, + paddle::optional const& maybe_group_zeros, + paddle::optional const& maybe_channel_scales, + paddle::optional 
const& maybe_token_scales, + std::string const& b_type_str, + std::string const& maybe_out_type_str, + int64_t const& maybe_group_size, + std::string const& maybe_schedule); + +std::vector MachetePrepackBKernel( + paddle::Tensor const& B, + std::string const& a_type_str, + std::string const& b_type_str, + std::string const& maybe_group_scales_type_str); + +std::vector MacheteSupportedSchedules( + std::string const& a_type_str, std::string const& b_type_str); std::vector MoeExpertDispatch( - const paddle::Tensor &input, const paddle::Tensor &gating_output, - const paddle::optional &gating_correction_bias, - const paddle::optional &w4a8_in_scale, const int moe_topk, - const bool group_moe, const bool topk_only_mode); - -std::vector -MoETopKSelectKernel(const paddle::Tensor &gating_logits, - const paddle::optional &bias, - const int moe_topk, const bool apply_norm_weight, - const bool enable_softmax_top_k_fused); - -std::vector -MoERedundantTopKSelectKernel(const paddle::Tensor &gating_logits, - const paddle::Tensor &expert_id_to_ep_rank_array, - const paddle::Tensor &expert_in_rank_num_list, - paddle::Tensor &tokens_per_expert_stats_list, - const paddle::optional &bias, - const int moe_topk, const bool apply_norm_weight, - const bool enable_softmax_top_k_fused, - const int redundant_ep_rank_num_plus_one); - -std::vector -EPMoeExpertDispatch(const paddle::Tensor &input, const paddle::Tensor &topk_ids, - const paddle::Tensor &topk_weights, - const paddle::optional &up_gate_proj_in_scale, - const std::vector &token_nums_per_expert, - const int token_nums_this_rank, - const std::string &moe_quant_type); + const paddle::Tensor& input, + const paddle::Tensor& gating_output, + const paddle::optional& gating_correction_bias, + const paddle::optional& w4a8_in_scale, + const int moe_topk, + const bool group_moe, + const std::string& moe_quant_type, + const bool topk_only_mode); + +std::vector MoETopKSelectKernel( + const paddle::Tensor& gating_logits, + const paddle::optional& 
bias, + const int moe_topk, + const bool apply_norm_weight, + const bool enable_softmax_top_k_fused); + +std::vector MoERedundantTopKSelectKernel( + const paddle::Tensor& gating_logits, + const paddle::Tensor& expert_id_to_ep_rank_array, + const paddle::Tensor& expert_in_rank_num_list, + paddle::Tensor& tokens_per_expert_stats_list, + const paddle::optional& bias, + const int moe_topk, + const bool apply_norm_weight, + const bool enable_softmax_top_k_fused, + const int redundant_ep_rank_num_plus_one); + +std::vector EPMoeExpertDispatch( + const paddle::Tensor& input, + const paddle::Tensor& topk_ids, + const paddle::Tensor& topk_weights, + const paddle::optional& up_gate_proj_in_scale, + const std::vector& token_nums_per_expert, + const int token_nums_this_rank, + const std::string& moe_quant_type); std::vector EPMoeExpertDispatchFP8( - const paddle::Tensor &input, const paddle::Tensor &scale, - const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights, - const paddle::Tensor &token_nums_per_expert, - const paddle::Tensor &token_nums_per_expert_padded, - const bool use_in_ep, const int token_nums_this_rank_padded); - -std::vector PerTokenQuant(paddle::Tensor &input, - const int block_size); -std::vector PerTokenQuantPadding(paddle::Tensor &input, - const int block_size); -std::vector -MaskedPerTokenQuant(paddle::Tensor &input, paddle::Tensor &recv_expert_count, - const int block_size); + const paddle::Tensor& input, + const paddle::Tensor& scale, + const paddle::Tensor& topk_ids, + const paddle::Tensor& topk_weights, + const paddle::Tensor& token_nums_per_expert, + const paddle::Tensor& token_nums_per_expert_padded, + const bool use_in_ep, + const int token_nums_this_rank_padded); + +std::vector PerTokenQuant(paddle::Tensor& input, + const int block_size, + const bool use_ue8m0); +std::vector PerTokenQuantPadding(paddle::Tensor& input, + const int block_size, + const bool use_ue8m0); + +std::vector FusedMaskSwigluFP8Quant( + paddle::Tensor& input, + 
paddle::Tensor& token_nums_per_expert, + const int block_size, + const bool use_ue8m0); std::vector EPMoeExpertCombine( - const paddle::Tensor &ffn_out, const paddle::Tensor &expert_scales_float, - const paddle::Tensor &permute_indices_per_token, - const paddle::Tensor &top_k_indices, - const paddle::optional &down_proj_bias, - const bool norm_topk_prob, const float routed_scaling_factor); - -std::vector> GetExpertTokenNum(const paddle::Tensor &topk_ids, + const paddle::Tensor& ffn_out, + const paddle::Tensor& expert_scales_float, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& down_proj_bias, + const bool norm_topk_prob, + const float routed_scaling_factor); + +std::vector> GetExpertTokenNum(const paddle::Tensor& topk_ids, const int num_experts); paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, - const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, const paddle::optional& down_proj_in_scale, const paddle::optional& expert_idx_per_token, - const std::string& quant_method, const bool used_in_ep_low_latency); + const paddle::optional& max_tokens_per_expert, + const std::string& quant_method, + const bool used_in_ep_low_latency, + const int estimate_total_token_nums, + const int hadamard_block_size, + const std::string& activation); paddle::Tensor MoeExpertFFNWint2Func( const paddle::Tensor& permute_input, @@ -207,116 +372,172 @@ paddle::Tensor MoeExpertFFNWint2Func( const bool used_in_ep_low_latency); paddle::Tensor MoeExpertReduceFunc( - const paddle::Tensor &ffn_out, const paddle::Tensor &top_k_weight, - const paddle::Tensor 
&permute_indices_per_token, - const paddle::Tensor &top_k_indices, - const paddle::optional &down_proj_bias, - const bool norm_topk_prob, const float routed_scaling_factor); - -void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor, - const paddle::Tensor &seq_lens_this_time_tensor, - const paddle::Tensor &seq_lens_decoder_tensor, - const int rank, const int num_layers); - -void GetOutputKVSignal(const paddle::Tensor &x, int64_t rank_id, + const paddle::Tensor& ffn_out, + const paddle::Tensor& top_k_weight, + const paddle::Tensor& permute_indices_per_token, + const paddle::Tensor& top_k_indices, + const paddle::optional& down_proj_bias, + const bool norm_topk_prob, + const float routed_scaling_factor); + +void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor, + const paddle::Tensor& seq_lens_this_time_tensor, + const paddle::Tensor& seq_lens_decoder_tensor, + const int rank, + const int num_layers); + +void GetOutputKVSignal(const paddle::Tensor& x, + int64_t rank_id, bool wait_flag); -paddle::Tensor DequantInt8Func(const paddle::Tensor &input, - const paddle::Tensor &out_scale, +paddle::Tensor DequantInt8Func(const paddle::Tensor& input, + const paddle::Tensor& out_scale, std::string dtype); -paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id, +paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, + const int device_id, const bool keep_pd_step_flag); -paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata, +paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor& kv_signal_metadata, const int layer_id); -std::vector GetBlockShapeAndSplitKVBlock( - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_this_time, - const int encoder_block_shape_q, const int decoder_block_shape_q, - const int group_size, const int block_size, - const int decoder_step_token_num); - -std::vector GetPaddingOffset(const 
paddle::Tensor &input_ids, - const paddle::Tensor &cum_offsets, - const paddle::Tensor &token_num, - const paddle::Tensor &seq_len); - -void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, - const paddle::Tensor &input_ids, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_idx, - const paddle::Tensor &stop_flags); +void GetBlockShapeAndSplitKVBlock( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& decoder_batch_ids, // Inplace + paddle::Tensor& decoder_tile_ids_per_batch, // Inplace + paddle::Tensor& decoder_num_blocks_cpu, // Inplace, Pinned Memory + paddle::Tensor& decoder_num_blocks_device, // Inplace + paddle::Tensor& decoder_chunk_size_device, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, Pinned Memory + paddle::Tensor& encoder_batch_ids, // Inplace + paddle::Tensor& encoder_tile_ids_per_batch, // Inplace + paddle::Tensor& encoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor& kv_batch_ids, // Inplace + paddle::Tensor& kv_tile_ids_per_batch, // Inplace + paddle::Tensor& kv_num_blocks_x_cpu, // Inplace, Pinned Memory + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int group_size, + const int block_size); + +std::vector GetPaddingOffset( + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_len, + const paddle::optional& draft_tokens, + const paddle::optional& seq_lens_encoder, + const int64_t token_num_cpu); + +void SetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all, + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags); paddle::Tensor RebuildPaddingFunc( - const 
paddle::Tensor &tmp_out, // [token_num, dim_embed] - const paddle::Tensor &cum_offsets, // [bsz, 1] - const paddle::Tensor &seq_len_this_time, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &seq_lens_encoder, - const paddle::optional &output_padding_offset, - int max_input_length); - -void GetStopFlagsMulti(const paddle::Tensor &topk_ids, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens, - const paddle::Tensor &end_ids, - const paddle::Tensor &next_tokens, + const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& seq_len_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::optional& batch_id_per_token_output, + const paddle::optional& cu_seqlens_q_output, + const paddle::optional& first_token_out, + bool enable_logprob); + +void GetStopFlagsMulti(const paddle::Tensor& topk_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens, + const paddle::Tensor& end_ids, + const paddle::Tensor& next_tokens, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_seqs, + const paddle::Tensor& stop_seqs_len, + const paddle::Tensor& min_tokens, const bool beam_search); -void GetStopFlagsMultiSeqs( - const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs, - const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids); - -void UpdateInputes(const paddle::Tensor &stop_flags, - const paddle::Tensor &not_need_stop, // only on cpu - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &input_ids, - const paddle::Tensor &stop_nums, - const paddle::Tensor &next_tokens, - const
paddle::Tensor &is_block_step); - -paddle::Tensor -GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor, - const paddle::Tensor &token_nums_per_expert); +void UpdateInputs(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // on device + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step); + +void UpdateInputsV1(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // on device + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& topk_ids, + const paddle::Tensor& input_ids, + const paddle::Tensor& block_tables, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step, + const int block_size); + +void RecoverDecodeTask( + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& block_tables, + const paddle::Tensor& is_block_step, + const paddle::optional& draft_tokens, + const paddle::optional& step_draft_tokens, + const paddle::optional& step_seq_lens_this_time, + const int block_size, + const int max_draft_tokens); + +paddle::Tensor GroupSwigluWithMasked( + const paddle::Tensor& fc1_out_tensor, + const paddle::Tensor& token_nums_per_expert); std::vector ExtractTextTokenOutput( - const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index, - const paddle::Tensor &mm_token_num_len, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text); + const paddle::Tensor& max_seq_len, + const 
paddle::Tensor& max_seq_len_index, + const paddle::Tensor& mm_token_num_len, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& hidden_states); -std::vector MoEDeepGEMMPermute(const paddle::Tensor &x, - const paddle::Tensor &topk_idx, +std::vector MoEDeepGEMMPermute(const paddle::Tensor& x, + const paddle::Tensor& topk_idx, const int num_experts, const int max_tokens_per_expert); std::vector MoEDeepGEMMDePermute( - const paddle::Tensor - &ffn_out, // [num_experts, max_tokens_per_expert, hidden] - const paddle::Tensor &permute_indices_per_token, // [token_num, topk}] - const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights); - -void TextImageIndexOut(const paddle::Tensor &token_type_ids, - const paddle::Tensor &text_input, - const paddle::Tensor &image_input); - -void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input, - paddle::Tensor &image_input, - paddle::Tensor &token_type_ids, - paddle::Tensor &text_index, - paddle::Tensor &image_index, const bool is_scatter); - -paddle::Tensor count_tokens_per_expert_func(const paddle::Tensor &topk_ids, - int64_t num_experts); + const paddle::Tensor& + ffn_out, // [num_experts, max_tokens_per_expert, hidden] + const paddle::Tensor& permute_indices_per_token, // [token_num, topk] + const paddle::Tensor& topk_idx, + const paddle::Tensor& topk_weights); + +void TextImageIndexOut(const paddle::Tensor& token_type_ids, + paddle::Tensor& text_input, + paddle::Tensor& image_input); + +std::vector TextImageGatherScatter( + paddle::Tensor& input, + paddle::Tensor& text_input, + paddle::Tensor& image_input, + paddle::Tensor& token_type_ids, + paddle::Tensor& text_index, + paddle::Tensor& image_index, + const bool is_scatter); + +std::vector count_tokens_per_expert_func( + const paddle::Tensor& topk_ids, int64_t num_experts); void GetPositionIdsAndMaskEncoderBatch( const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, @@
-337,7 +558,7 @@ std::vector DecodeMLAWriteCacheKernel( const int max_seq_len, const bool speculate_decoder); - std::vector PrefillMLAWriteCacheKernel( +std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& kv_cache, @@ -346,10 +567,10 @@ std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::optional& kv_signal_data, const std::string& cache_quant_type_str, const int max_seq_len); - void FusedRotaryPositionEncoding( paddle::Tensor& query, // [num_tokens, num_heads, head_size] or // [num_tokens, num_heads * head_size] @@ -365,23 +586,18 @@ std::vector MultiHeadLatentAttention( const paddle::Tensor& query, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, - const paddle::Tensor& encoder_batch_ids, - const paddle::Tensor& encoder_tile_ids_per_batch, - const paddle::Tensor& encoder_num_blocks, const paddle::Tensor& kv_batch_ids, const paddle::Tensor& kv_tile_ids_per_batch, const paddle::Tensor& kv_num_blocks, const paddle::Tensor& decoder_batch_ids, const paddle::Tensor& decoder_tile_ids_per_batch, - const paddle::Tensor& decoder_num_blocks, - const paddle::Tensor& decoder_num_blocks_cpu, - const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& decoder_num_blocks_device, + const paddle::Tensor& decoder_chunk_size_device, const paddle::Tensor& max_dec_len_this_time, const paddle::Tensor& max_len_kv, const paddle::optional& attn_mask, @@ -407,9 +623,10 @@ std::vector MultiHeadLatentAttention( const bool causal, const bool speculate_decoder); - -std::vector tritonmoe_preprocess_kernel(const paddle::Tensor& topk_ids, int64_t 
num_experts, int64_t GEMM_BLOCK_SIZE_M); - +std::vector tritonmoe_preprocess_kernel( + const paddle::Tensor& topk_ids, + int64_t num_experts, + int64_t GEMM_BLOCK_SIZE_M); std::vector MoeWna16MarlinGemmApi( const paddle::Tensor& a, @@ -437,36 +654,55 @@ std::vector MoeWna16MarlinGemmApi( bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float); -void CutlassScaledMm(paddle::Tensor &c, paddle::Tensor const &a, - paddle::Tensor const &b, paddle::Tensor const &a_scales, - paddle::Tensor const &b_scales, - paddle::optional const &bias); - -void CutlassScaledMmAzp(paddle::Tensor& c, paddle::Tensor const& a, - paddle::Tensor const& b, - paddle::Tensor const& a_scales, - paddle::Tensor const& b_scales, - paddle::Tensor const& azp_adj, - paddle::optional const& azp, - paddle::optional const& bias); - -void StaticScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input, - paddle::Tensor const &scale); - -void DynamicScaledFp8Quant(paddle::Tensor &out, paddle::Tensor const &input, - paddle::Tensor &scale); - -void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, - paddle::Tensor const &input, - paddle::Tensor &scales, float scale_ub); - -std::vector NoauxTc( - paddle::Tensor& scores, - paddle::Tensor& scores_with_bias, - int n_group, - int topk_group, - int topk, - float routed_scaling_factor); +void CutlassScaledMm(paddle::Tensor& c, + paddle::Tensor const& a, + paddle::Tensor const& b, + paddle::Tensor const& a_scales, + paddle::Tensor const& b_scales, + paddle::optional const& bias); + +void CutlassScaledMmAzp(paddle::Tensor& c, + paddle::Tensor const& a, + paddle::Tensor const& b, + paddle::Tensor const& a_scales, + paddle::Tensor const& b_scales, + paddle::Tensor const& azp_adj, + paddle::optional const& azp, + paddle::optional const& bias); + +void StaticScaledFp8Quant(paddle::Tensor& out, + paddle::Tensor const& input, + paddle::Tensor const& scale); + +void DynamicScaledFp8Quant(paddle::Tensor& out, + paddle::Tensor const& input, + paddle::Tensor& 
scale); + +void DynamicPerTokenScaledFp8Quant(paddle::Tensor& out, + paddle::Tensor const& input, + paddle::Tensor& scales, + float scale_ub); + +std::vector NoauxTc(paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor); + +std::vector NoauxTcRedundant( + paddle::Tensor& scores, + paddle::Tensor& scores_with_bias, + paddle::Tensor& expert_id_to_ep_rank_array, + paddle::Tensor& expert_in_rank_num_list, + paddle::Tensor& tokens_per_expert_stats_list, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor, + int redundant_ep_rank_num_plus_one); #ifdef ENABLE_FP8 paddle::Tensor cutlass_fp8_fp8_half_gemm_func( @@ -479,24 +715,27 @@ paddle::Tensor cutlass_fp8_fp8_half_gemm_func( std::string output_dtype, std::string activation_type); -paddle::Tensor MoeFusedHadamardQuantFp8Func( - const paddle::Tensor &input, - const paddle::Tensor &scale, - const paddle::Tensor &topk_ids, - const int top_k, - const int intermediate_size, - const bool tiled); - -paddle::Tensor FusedHadamardQuantFp8Func( - const paddle::Tensor &input, - const float scale); +paddle::Tensor MoeFusedHadamardQuantFp8Func(const paddle::Tensor& input, + const paddle::Tensor& scale, + const paddle::Tensor& topk_ids, + const int top_k, + const int intermediate_size, + const bool tiled); + +paddle::Tensor FusedHadamardQuantFp8Func(const paddle::Tensor& input, + const float scale); #endif int64_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, - paddle::Tensor& rank_data, int64_t rank, bool full_nvlink); + paddle::Tensor& rank_data, + int64_t rank, + bool full_nvlink); -void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out, - int64_t reg_buffer, int64_t reg_buffer_sz_bytes); +void all_reduce(paddle::Tensor& inp, + paddle::Tensor& out, + int64_t _fa, + int64_t reg_buffer, + int64_t reg_buffer_sz_bytes); void dispose(int64_t _fa); @@ -504,7 +743,8 @@ 
int64_t meta_size(); void register_buffer(int64_t _fa, const std::vector& fake_ipc_ptrs); -std::tuple, std::vector> get_graph_buffer_ipc_meta(int64_t _fa); +std::tuple, std::vector> +get_graph_buffer_ipc_meta(int64_t _fa); void register_graph_buffers(int64_t _fa, const std::vector>& handles, @@ -517,111 +757,207 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); -// speculative decoding Kernel -std::vector SpeculateGetPaddingOffset( - const paddle::Tensor& input_ids, - const paddle::Tensor& draft_tokens, - const paddle::Tensor& cum_offsets, - const paddle::Tensor& token_num, - const paddle::Tensor& seq_len, - const paddle::Tensor& seq_lens_encoder); +void clear_ipc_handles(int64_t _fa); std::vector SpeculateGetSeqLensOutput( const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder); -std::vector SpeculateGetOutputPaddingOffset( - const paddle::Tensor& output_cum_offsets_tmp, - const paddle::Tensor& out_token_num, - const paddle::Tensor& seq_lens_output, - const int max_seq_len); +std::vector SpeculatePreProcess( + const int64_t cpu_token_num, + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_len, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder); +std::vector BuildSamplingParams( + const paddle::Tensor& top_p, + const paddle::Tensor& top_k, + paddle::Tensor& infer_seed, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seqlens_q_output, + const int64_t token_num_output_cpu, + const int64_t increment_value); + +void SpecTokenPenaltyMultiScores( + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const 
paddle::Tensor& bad_tokens_len, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, + const int max_seq_len); -void SpecTokenPenaltyMultiScores(const paddle::Tensor &pre_ids, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_scores, - const paddle::Tensor &presence_scores, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &output_padding_offset, - const paddle::Tensor &output_cum_offsets, - const int max_seq_len); - -void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens, - const paddle::Tensor &accept_num, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens, - const paddle::Tensor &stop_seqs, - const paddle::Tensor &stop_seqs_len, - const paddle::Tensor &end_ids); - - -void SpeculateVerify( - const paddle::Tensor &accept_tokens, const paddle::Tensor &accept_num, - const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &draft_tokens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &verify_tokens, const paddle::Tensor &verify_scores, - const paddle::Tensor &max_dec_len, const paddle::Tensor &end_tokens, - const paddle::Tensor &is_block_step, - const paddle::Tensor &output_cum_offsets, - const paddle::Tensor &actual_candidate_len, - const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp, - int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode); - -void 
SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &not_need_stop, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &actual_draft_token_nums, - const paddle::Tensor &accept_tokens, - const paddle::Tensor &accept_num, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &is_block_step, - const paddle::Tensor &stop_nums); - -void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all, - const paddle::Tensor &accept_tokens, - const paddle::Tensor &accept_num, - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &step_idx); +void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens, + const paddle::Tensor& stop_seqs, + const paddle::Tensor& stop_seqs_len, + const paddle::Tensor& end_ids, + const paddle::Tensor& min_tokens); + +void VerifyDraftTokens(const paddle::Tensor& step_output_ids, + const paddle::Tensor& step_output_len, + const paddle::Tensor& step_input_ids, + const paddle::optional& target_tokens, + const paddle::optional& candidate_ids, + const paddle::optional& candidate_scores, + const paddle::optional& candidate_lens, + const paddle::Tensor& topp, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& end_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& reasoning_status, + const paddle::Tensor& max_dec_len, + const paddle::Tensor& step_idx, + int max_seq_len, + int verify_window, + int verify_strategy, + bool
reject_all, + bool accept_all); + +void SpeculateVerify(const paddle::Tensor& sampled_token_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& verify_tokens, + const paddle::Tensor& verify_scores, + const paddle::Tensor& max_dec_len, + const paddle::Tensor& end_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& actual_candidate_len, + const paddle::Tensor& actual_draft_token_nums, + const paddle::Tensor& topp, + const paddle::Tensor& reasoning_status, + int max_seq_len, + int verify_window, + bool enable_topp, + bool benchmark_mode, + bool accept_all_drafts); + +void SpeculateUpdate(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& actual_draft_token_nums, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& is_block_step, + const paddle::Tensor& mask_rollback); + +void UnifiedUpdateModelStatus(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& has_running_seqs, + const paddle::Tensor& step_input_ids, + const paddle::Tensor& step_output_ids, + const paddle::Tensor& step_output_len, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& is_paused, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& end_tokens, + const paddle::Tensor& max_dec_len); + +void 
NaiveUpdateModelStatus(const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& next_tokens, + const paddle::Tensor& cu_seqlens_q_output); + +void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx); void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, const paddle::Tensor& not_need_stop, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& preempted_idx, int64_t rank_id, - bool save_each_rank); - + bool save_each_rank, + bool skip_prefill); void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, const paddle::Tensor& seq_lens_decoder); -void NgramMatch(const paddle::Tensor &input_ids, - const paddle::Tensor &input_ids_len, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &draft_token_num, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &max_dec_len, - const int max_ngram_size, - const int max_draft_tokens); - +void SpeculateScheduleCache(const paddle::Tensor& draft_tokens, + const paddle::Tensor& block_tables, + const paddle::Tensor& stop_flags, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_seq_lens_decoder, + const paddle::Tensor& step_draft_tokens, + const paddle::Tensor& step_seq_lens_this_time, + const paddle::Tensor& 
accept_num, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& not_need_stop, + const int block_size, + const int max_draft_tokens); + +void NgramMatch(const paddle::Tensor& input_ids, + const paddle::Tensor& input_ids_len, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& draft_token_num, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& max_dec_len, + const int max_ngram_size, + const int max_draft_tokens); + +void HybridMtpNgram(const paddle::Tensor& input_ids, + const paddle::Tensor& input_ids_len, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& draft_token_num, + const paddle::Tensor& draft_tokens, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& max_dec_len, + const int max_ngram_size, + const int min_ngram_size, + const int max_draft_tokens); // MTP void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, @@ -629,7 +965,6 @@ void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens, const paddle::Tensor& base_model_seq_lens_encoder, const paddle::Tensor& base_model_stop_flags); - void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& input_ids, const paddle::Tensor& stop_flags, @@ -638,19 +973,17 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, const paddle::Tensor& not_need_stop, - const paddle::Tensor& batch_drop, + const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_num, - const paddle::Tensor& base_model_seq_lens_encoder, - const paddle::Tensor& base_model_seq_lens_decoder, - const 
paddle::Tensor& base_model_step_idx, - const paddle::Tensor& base_model_stop_flags, - const paddle::Tensor& base_model_is_block_step, - const paddle::Tensor& base_model_draft_tokens, - const int max_draft_token, - const bool truncate_first_token, - const bool splitwise_prefill); - + const paddle::Tensor& target_model_seq_lens_encoder, + const paddle::Tensor& target_model_seq_lens_decoder, + const paddle::Tensor& target_model_step_idx, + const paddle::Tensor& target_model_stop_flags, + const paddle::Tensor& max_dec_len, + const paddle::Tensor& target_model_draft_tokens, + const int num_model_step, + const bool is_splitwise_prefill); void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const paddle::Tensor& draft_tokens, @@ -659,7 +992,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, - const paddle::Tensor& output_cum_offsets, + const paddle::Tensor& cu_seqlens_q_output, const paddle::Tensor& stop_flags, const paddle::Tensor& not_need_stop, const paddle::Tensor& max_dec_len, @@ -668,127 +1001,350 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, const int max_seq_len, const int substep); - - std::vector EagleGetHiddenStates( - const paddle::Tensor& input, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& stop_flags, - const paddle::Tensor& accept_nums, - const paddle::Tensor& base_model_seq_lens_this_time, - const paddle::Tensor& base_model_seq_lens_encoder, - const int actual_draft_token_num); + const paddle::Tensor& input, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& stop_flags, + const paddle::Tensor& accept_nums, + const paddle::Tensor& base_model_seq_lens_this_time, + const paddle::Tensor& 
base_model_seq_lens_encoder, + const int actual_draft_token_num); + +std::vector EagleGetSelfHiddenStates( + const paddle::Tensor& input, + const paddle::Tensor& last_seq_lens_this_time, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder); + +std::vector EagleGatherHiddenStates( + const paddle::Tensor& input, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& real_output_token_num); void MTPStepPaddle( - const paddle::Tensor &base_model_stop_flags, - const paddle::Tensor &stop_flags, - const paddle::Tensor &batch_drop, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] - const paddle::Tensor &encoder_block_lens, - const paddle::Tensor &used_list_len, - const paddle::Tensor &free_list, - const paddle::Tensor &free_list_len, + const paddle::Tensor& base_model_stop_flags, + const paddle::Tensor& stop_flags, + const paddle::Tensor& batch_drop, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, const int block_size, const int max_draft_tokens); void SpeculateStepPaddle( - const paddle::Tensor &stop_flags, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &ori_seq_lens_encoder, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] - const paddle::Tensor 
&encoder_block_lens, - const paddle::Tensor &is_block_step, - const paddle::Tensor &step_block_list, - const paddle::Tensor &step_lens, - const paddle::Tensor &recover_block_list, - const paddle::Tensor &recover_lens, - const paddle::Tensor &need_block_list, - const paddle::Tensor &need_block_len, - const paddle::Tensor &used_list_len, - const paddle::Tensor &free_list, - const paddle::Tensor &free_list_len, - const paddle::Tensor &input_ids, - const paddle::Tensor &pre_ids, - const paddle::Tensor &step_idx, - const paddle::Tensor &next_tokens, - const paddle::Tensor &first_token_ids, - const paddle::Tensor &accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& ori_seq_lens_encoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& step_block_list, + const paddle::Tensor& step_lens, + const paddle::Tensor& recover_block_list, + const paddle::Tensor& recover_lens, + const paddle::Tensor& need_block_list, + const paddle::Tensor& need_block_len, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, + const paddle::Tensor& input_ids, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& next_tokens, + const paddle::Tensor& first_token_ids, + const paddle::Tensor& accept_num, const int block_size, const int encoder_decoder_block_num, const int max_draft_tokens); -PYBIND11_MODULE(fastdeploy_ops, m) { +void MergePrefillDecodeOutput(const paddle::Tensor& encoder_res, + const paddle::Tensor& decoder_res, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& cu_seq_q, + const int head_num, + const int head_dim, + const 
int max_token); + +std::vector TopPSamplingReject( + const paddle::Tensor& probs, + const paddle::Tensor& top_p, + const paddle::optional& top_k, + int64_t seed); + +std::vector TopKRenorm(const paddle::Tensor& probs, + const paddle::Tensor& top_k); + +std::vector MinPSamplingFromProbs(const paddle::Tensor& probs, + const paddle::Tensor& min_p); + +void SaveOutMmsgStatic(const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& preempted_idx, + int64_t rank_id, + bool save_each_rank); + +void LimitThinkingContentLength(const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& max_reply_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_status, + const paddle::Tensor& stop_flags, + const paddle::Tensor& eos_token_ids, + const paddle::Tensor& inject_token_ids, + const int64_t think_end_id, + const bool splitwise_role_is_decode); + +void SpeculateLimitThinkingContentLength(const paddle::Tensor& next_tokens, + const paddle::Tensor& max_think_lens, + const paddle::Tensor& max_reply_lens, + const paddle::Tensor& step_idx, + const paddle::Tensor& limit_status, + const paddle::Tensor& accept_num, + const paddle::Tensor& stop_flags, + const paddle::Tensor& eos_token_ids, + const paddle::Tensor& inject_token_ids, + const int64_t think_end_id, + const bool splitwise_role_is_decode); + +void SpeculateGetLogits(const paddle::Tensor& draft_logits, + const paddle::Tensor& next_token_num, + const paddle::Tensor& batch_token_num, + const paddle::Tensor& cu_next_token_offset, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& logits, + const paddle::Tensor& first_token_logits, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder); + +void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& next_tokens, + const paddle::Tensor& cu_next_token_offset, + const paddle::Tensor& 
cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder); + +void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& ori_cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num); + +std::vector UpdateAttnMaskOffsets( + const paddle::Tensor& ids_remove_padding, + const paddle::Tensor& seq_lens_this_time, // only on cpu + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& attn_mask_offsets_full, + const paddle::Tensor& is_block_step, + const paddle::Tensor& decode_states); + +std::vector FusedNeoxRopeEmbedding( + const paddle::Tensor& qkv, + const paddle::Tensor& cos_emb, + const paddle::Tensor& sin_emb, + const int num_heads, + const int head_dim); + +std::vector GeluTanh(paddle::Tensor& input); + +void ReasoningPhaseTokenConstraint( + const paddle::Tensor& logits, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& allowed_tokens, + const paddle::Tensor& reasoning_status, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& enable_thinking, + int64_t think_end_id, + int64_t line_break_id); + +std::vector get_attn_mask_q( + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& cu_seqlens_k, + const paddle::optional& attn_mask_kv, + const int kv_token_num); - m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), - py::arg("num_experts"), "get expert token num"); +std::vector PrefillPermuteToMaskedGemm( + const paddle::Tensor& x, + const 
paddle::Tensor& scale, + const paddle::Tensor& topk_ids, + const int num_local_experts, + const int max_token_num); + +std::vector DepermutePrefillCombine( + const paddle::Tensor& x, + const paddle::Tensor& indice_map, + const paddle::Tensor& topk_weights, + const int num_worst_tokens); + +void RadixTopkRaggedTransform( + paddle::Tensor& input, + paddle::Tensor& output_indices, + const paddle::Tensor& offsets, + paddle::Tensor& lengths, + paddle::optional& seq_len_decoder, + paddle::optional& batch_id_per_token, + paddle::optional& block_tables, + paddle::optional& maybe_row_states_buffer, + int max_block_num, + int top_k, + int q_num_heads = 0); + +std::vector DSMLAWriteCacheKernel( + const paddle::Tensor& kv_nope, + const paddle::Tensor& kv_pe, + const paddle::Tensor& kv_cache, + const paddle::Tensor& slot_mapping, + const paddle::optional& scale, + const std::string& cache_quant_type_str); + +std::vector IndexerKQuantAndCacheKernel( + const paddle::Tensor& k, + const paddle::Tensor& kv_cache, + const paddle::Tensor& slot_mapping, + const int64_t quant_block_size, + const std::string& scale_fmt); + +std::vector CpGatherIndexerKQuantCacheKernel( + const paddle::Tensor& kv_cache, + paddle::Tensor& dst_k, + paddle::Tensor& dst_scale, + const paddle::Tensor& block_table, + const paddle::Tensor& cu_seq_lens); + +void PerTokenGroupQuantFp8(const paddle::Tensor& input, + paddle::Tensor& output_q, + paddle::Tensor& output_s, + int64_t group_size, + double eps, + double fp8_min, + double fp8_max, + bool scale_ue8m0); + +PYBIND11_MODULE(fastdeploy_ops, m) { +#ifdef ENABLE_SM80_EXT_OPS + m.def("get_expert_token_num", + &GetExpertTokenNum, + py::arg("topk_ids"), + py::arg("num_experts"), + "get expert token num"); /** * moe/fused_moe/moe_redundant_topk_select.cu * moe_redundant_topk_select */ - m.def("f_moe_redundant_topk_select", &MoERedundantTopKSelectKernel, - py::arg("gating_logits"), py::arg("expert_id_to_ep_rank_array"), + m.def("moe_redundant_topk_select", + 
&MoERedundantTopKSelectKernel, + py::arg("gating_logits"), + py::arg("expert_id_to_ep_rank_array"), py::arg("expert_in_rank_num_list"), - py::arg("tokens_per_expert_stats_list"), py::arg("bias"), - py::arg("moe_topk"), py::arg("apply_norm_weight"), + py::arg("tokens_per_expert_stats_list"), + py::arg("bias"), + py::arg("moe_topk"), + py::arg("apply_norm_weight"), py::arg("enable_softmax_top_k_fused"), py::arg("redundant_ep_rank_num_plus_one"), "moe export RedundantTopKSelect function"); +#endif /** * open_shm_and_get_meta_signal.cc * InitKVSignalPerQuery */ - m.def("init_kv_signal_per_query", &InitKVSignalPerQuery, + m.def("init_kv_signal_per_query", + &InitKVSignalPerQuery, py::arg("seq_lens_encoder_tensor"), py::arg("seq_lens_this_time_tensor"), - py::arg("seq_lens_decoder_tensor"), py::arg("rank"), - py::arg("num_layers"), "init_kv_signal_per_query function"); + py::arg("seq_lens_decoder_tensor"), + py::arg("rank"), + py::arg("num_layers"), + "init_kv_signal_per_query function"); /** * GetOutputKVSignal */ - m.def("get_output_kv_signal", &GetOutputKVSignal, py::arg("x"), - py::arg("rank_id"), py::arg("wait_flag"), + m.def("get_output_kv_signal", + &GetOutputKVSignal, + py::call_guard(), + py::arg("x"), + py::arg("rank_id"), + py::arg("wait_flag"), "get_output_kv_signal function"); +#ifdef ENABLE_SM75_EXT_OPS m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute"); - m.def("moe_deepgemm_depermute", &MoEDeepGEMMDePermute, - "MoEDeepGEMMDePermute"); + m.def( + "moe_deepgemm_depermute", &MoEDeepGEMMDePermute, "MoEDeepGEMMDePermute"); +#endif /** * alloc_cache_pinned.cc * cuda_host_alloc * cuda_host_free */ - m.def("cuda_host_alloc", &cuda_host_alloc, "Allocate pinned memory", - py::arg("size"), py::arg("flags") = cudaHostAllocDefault); - m.def("cuda_host_free", &cuda_host_free, "Free pinned memory", - py::arg("ptr")); + m.def("cuda_host_alloc", + &cuda_host_alloc, + "Allocate pinned memory", + py::arg("size"), + py::arg("flags") = 
cudaHostAllocDefault); + m.def( + "cuda_host_free", &cuda_host_free, "Free pinned memory", py::arg("ptr")); py::register_exception(m, "CudaError"); +#ifdef ENABLE_SM80_EXT_OPS /** * append_attention.cu * append_attention */ m.def("append_attention", &AppendAttention, "append attention function"); + m.def("append_attention_with_output", + &AppendAttentionWithOutput, + "append attention with output function"); +#endif + +#ifdef ENABLE_FLASH_MASK_ATTENTION + m.def("flash_mask_attention", &FlashAttentionMask, "flash_mask_attention"); +#endif + +#ifdef ENABLE_SM80_EXT_OPS /** * gqa_rope_write_cache.cu * gqa_rope_write_cache */ - m.def("gqa_rope_write_cache", &GQARopeWriteCacheKernel, + m.def("gqa_rope_write_cache", + &GQARopeWriteCacheKernel, "gqa rope write cache function"); /** * pre_cache_len_concat.cu * pre_cache_len_concat */ - m.def("pre_cache_len_concat", &PreCacheLenConcat, + m.def("pre_cache_len_concat", + &PreCacheLenConcat, "pre_cache len concat function"); + /** * moe/fused_moe/fused_moe.cu * fused_moe @@ -805,45 +1361,113 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * moe/fused_moe/moe_dispatch.cu * moe_expert_dispatch */ - m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"), - py::arg("gating_output"), py::arg("gating_correction_bias"), - py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"), - py::arg("topk_only_mode"), "moe export dispatch function"); + m.def("moe_expert_dispatch", + &MoeExpertDispatch, + py::arg("input"), + py::arg("gating_output"), + py::arg("gating_correction_bias"), + py::arg("w4a8_in_scale"), + py::arg("moe_topk"), + py::arg("group_moe"), + py::arg("moe_quant_type"), + py::arg("topk_only_mode"), + "moe export dispatch function"); /** * moe/fused_moe/ep_moe_prefill_func.cu * ep_moe_dispatch */ - m.def("ep_moe_expert_dispatch", &EPMoeExpertDispatch, py::arg("input"), - py::arg("topk_ids"), py::arg("topk_weights"), py::arg("up_gate_proj_in_scale"), - py::arg("token_nums_per_expert"), 
py::arg("token_nums_this_rank"), - py::arg("moe_quant_type"), "ep moe export dispatch function"); + m.def("ep_moe_expert_dispatch", + &EPMoeExpertDispatch, + py::arg("input"), + py::arg("topk_ids"), + py::arg("topk_weights"), + py::arg("up_gate_proj_in_scale"), + py::arg("token_nums_per_expert"), + py::arg("token_nums_this_rank"), + py::arg("moe_quant_type"), + "ep moe export dispatch function"); m.def("ep_moe_expert_dispatch_fp8", &EPMoeExpertDispatchFP8); - m.def("ep_moe_expert_combine", &EPMoeExpertCombine, py::arg("ffn_out"), - py::arg("expert_scales_float"), py::arg("permute_indices_per_token"), - py::arg("top_k_indices"), py::arg("down_proj_bias"), - py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"), + m.def("ep_moe_expert_combine", + &EPMoeExpertCombine, + py::arg("ffn_out"), + py::arg("expert_scales_float"), + py::arg("permute_indices_per_token"), + py::arg("top_k_indices"), + py::arg("down_proj_bias"), + py::arg("norm_topk_prob"), + py::arg("routed_scaling_factor"), "ep moe export combine function"); +#endif - m.def("per_token_quant", &PerTokenQuant, py::arg("input"), - py::arg("block_size"), "per token per block quant"); + m.def("per_token_quant", + &PerTokenQuant, + py::arg("input"), + py::arg("block_size"), + py::arg("use_ue8m0"), + "per token per block quant"); - m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"), + m.def("per_token_quant_padding", + &PerTokenQuantPadding, + py::arg("input"), py::arg("block_size"), - "per token per block quant and padding tranpose scale"); + py::arg("use_ue8m0"), + "per token per block quant and padding transpose scale"); - m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"), - py::arg("recv_expert_count"), py::arg("block_size"), - "per token per block quant"); + m.def("fused_mask_swiglu_fp8_quant", + &FusedMaskSwigluFP8Quant, + py::arg("input"), + py::arg("token_nums_per_expert"), + py::arg("block_size"), + py::arg("use_ue8m0") = false, + "fused mask swiglu and fp8 
quant"); + +#ifdef ENABLE_MACHETE + /*machete/machete_mm.cu + * machete_mm + */ + m.def("machete_mm", + &MacheteMMKernel, + py::arg("A"), + py::arg("B"), + py::arg("maybe_group_scale"), + py::arg("maybe_group_zeros"), + py::arg("maybe_channel_scales"), + py::arg("maybe_token_scales"), + py::arg("b_type_str"), + py::arg("maybe_out_type_str"), + py::arg("maybe_group_size"), + py::arg("maybe_schedule"), + "machete mm function"); + + /*machete/machete_prepack_B.cu + * machete_prepack_B + */ + m.def("machete_prepack_B", + &MachetePrepackBKernel, + "machete prepacked B function"); + /*machete/machete_supported_schedules.cu + * machete_supported_schedules + */ + m.def("machete_supported_schedules", + &MacheteSupportedSchedules, + "machete supported schedules function"); +#endif + +#ifdef ENABLE_SM80_EXT_OPS /** * moe/fused_moe/moe_topk_select.cu * moe_topk_select */ - m.def("moe_topk_select", &MoETopKSelectKernel, py::arg("gating_logits"), - py::arg("bias"), py::arg("moe_topk"), py::arg("apply_norm_weight"), + m.def("moe_topk_select", + &MoETopKSelectKernel, + py::arg("gating_logits"), + py::arg("bias"), + py::arg("moe_topk"), + py::arg("apply_norm_weight"), py::arg("enable_softmax_top_k_fused"), "moe export TopKSelect function"); @@ -854,20 +1478,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function"); /** - * moe/fused_moe/moe_ffn_wint2.cu + * moe/fused_moe/moe_expert_ffn_wint2.cu * moe_expert_ffn_wint2 */ - m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function"); + m.def("moe_expert_ffn_wint2", + &MoeExpertFFNWint2Func, + "moe export ffn wint2 function"); /** * moe/fused_moe/moe_expert_reduce.cu * moe_expert_reduce */ - m.def("moe_expert_reduce", &MoeExpertReduceFunc, py::arg("ffn_out"), - py::arg("top_k_weight"), py::arg("permute_indices_per_token"), - py::arg("top_k_indices"), py::arg("down_proj_bias"), - py::arg("norm_topk_prob"), py::arg("routed_scaling_factor"), + 
m.def("moe_expert_reduce", + &MoeExpertReduceFunc, + py::arg("ffn_out"), + py::arg("top_k_weight"), + py::arg("permute_indices_per_token"), + py::arg("top_k_indices"), + py::arg("down_proj_bias"), + py::arg("norm_topk_prob"), + py::arg("routed_scaling_factor"), "moe export reduce function"); +#endif /** * dequant_int8.cu @@ -879,22 +1511,27 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * init_signal_layerwise.cc * init_signal_layerwise */ - m.def("init_signal_layerwise", &InitSignalLayerwiseFunc, + m.def("init_signal_layerwise", + &InitSignalLayerwiseFunc, "init_signal_layerwise function"); /** * open_shm_and_get_meta_signal.cc * open_shm_and_get_meta_signal */ - m.def("open_shm_and_get_meta_signal", &OpenShmAndGetMetaSignalFunc, + m.def("open_shm_and_get_meta_signal", + &OpenShmAndGetMetaSignalFunc, "open_shm_and_get_meta_signal function"); +#ifdef ENABLE_SM80_EXT_OPS /** * append_attn/get_block_shape_and_split_kv_block.cu * get_block_shape_and_split_kv_block */ m.def("get_block_shape_and_split_kv_block", - &GetBlockShapeAndSplitKVBlock, "get_block_shape_and_split_kv_block function"); + &GetBlockShapeAndSplitKVBlock, + "get_block_shape_and_split_kv_block function"); +#endif /** * get_padding_offset.cu @@ -906,7 +1543,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * get_padding_offset.cu * get_padding_offset */ - m.def("set_value_by_flags_and_idx", &SetValueByFlagsAndIdx, + m.def("set_value_by_flags_and_idx", + &SetValueByFlagsAndIdx, "SetValueByFlagsAndIdx"); /** @@ -919,51 +1557,82 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * stop_generation_multi_ends.cu * set_stop_value_multi_ends */ - m.def("set_stop_value_multi_ends", &GetStopFlagsMulti, + m.def("set_stop_value_multi_ends", + &GetStopFlagsMulti, "update_inputs function"); /** - * stop_generation_multi_stop_seqs.cu - * set_stop_value_multi_seqs + * update_inputs.cu + * update_inputs */ - m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs, - "update_inputs function"); + m.def("update_inputs", &UpdateInputs, 
"update_inputs function"); /** - * update_inputs.cu - * update_inputs + * update_inputs_v1.cu + * update_inputs_v1 */ - m.def("update_inputs", &UpdateInputes, "update_inputs function"); + m.def("update_inputs_v1", + &UpdateInputsV1, + "update inputs for scheduler v1 function"); /** - * extract_text_token_output.cu - * extract_text_token_output + * recover_decode_task.cu + * recover_decode_task */ - m.def("extract_text_token_output", &ExtractTextTokenOutput, - "extract_text_token_output function"); + m.def("recover_decode_task", + &RecoverDecodeTask, + "recover decode task for scheduler v1 function"); - m.def("group_swiglu_with_masked", &GroupSwigluWithMasked, +#ifdef ENABLE_SM80_EXT_OPS + m.def("group_swiglu_with_masked", + &GroupSwigluWithMasked, "group_swiglu_with_masked function"); +#endif - m.def("text_image_index_out", &TextImageIndexOut, + m.def("text_image_index_out", + &TextImageIndexOut, "text_image_index_out function"); - m.def("text_image_gather_scatter", &TextImageGatherScatter, + m.def("text_image_gather_scatter", + &TextImageGatherScatter, "text_image_gather_scatter function"); +#ifdef ENABLE_SM80_EXT_OPS m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func); + m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel); - m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi, - py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"), - py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"), - py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"), - py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"), - py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"), - py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"), - py::arg("use_fp32_reduce"), py::arg("is_zp_float")); + m.def("MoeWna16MarlinGemmApi", + 
&MoeWna16MarlinGemmApi, + py::arg("a"), + py::arg("c_or_none"), + py::arg("b_q_weight"), + py::arg("b_scales"), + py::arg("global_scale_or_none"), + py::arg("b_zeros_or_none"), + py::arg("g_idx_or_none"), + py::arg("perm_or_none"), + py::arg("workspace"), + py::arg("sorted_token_ids"), + py::arg("expert_ids"), + py::arg("num_tokens_post_padded"), + py::arg("topk_weights"), + py::arg("moe_block_size"), + py::arg("top_k"), + py::arg("mul_topk_weights"), + py::arg("is_ep"), + py::arg("b_q_type_str"), + py::arg("size_m"), + py::arg("size_n"), + py::arg("size_k"), + py::arg("is_k_full"), + py::arg("use_atomic_add"), + py::arg("use_fp32_reduce"), + py::arg("is_zp_float")); +#endif - m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch, + m.def("get_position_ids_and_mask_encoder_batch", + &GetPositionIdsAndMaskEncoderBatch, "get_position_ids_and_mask_encoder_batch function"); /** @@ -972,7 +1641,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * cutlass_scaled_mm_azp */ m.def("cutlass_scaled_mm", &CutlassScaledMm, "cutlass_scaled_mm function"); - m.def("cutlass_scaled_mm_azp", &CutlassScaledMmAzp, "cutlass_scaled_mm_azp function"); + m.def("cutlass_scaled_mm_azp", + &CutlassScaledMmAzp, + "cutlass_scaled_mm_azp function"); /** * quantization/common.cu @@ -980,39 +1651,84 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * dynamic_scaled_fp8_quant * dynamic_per_token_scaled_fp8_quant */ - m.def("static_scaled_fp8_quant", &StaticScaledFp8Quant, "static_scaled_fp8_quant function", - py::arg("out"), py::arg("input"), py::arg("scale")); - - m.def("dynamic_scaled_fp8_quant", &DynamicScaledFp8Quant, + m.def("static_scaled_fp8_quant", + &StaticScaledFp8Quant, + "static_scaled_fp8_quant function", + py::arg("out"), + py::arg("input"), + py::arg("scale")); + + m.def("dynamic_scaled_fp8_quant", + &DynamicScaledFp8Quant, "dynamic_scaled_fp8_quant function", - py::arg("out"), py::arg("input"), py::arg("scale")); + py::arg("out"), + py::arg("input"), + py::arg("scale")); 
- m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant, + m.def("dynamic_per_token_scaled_fp8_quant", + &DynamicPerTokenScaledFp8Quant, "dynamic_per_token_scaled_fp8_quant function", - py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub")); - m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function"); + py::arg("out"), + py::arg("input"), + py::arg("scales"), + py::arg("scale_ub")); +#ifdef ENABLE_SM80_EXT_OPS + m.def("decode_mla_write_cache", + &DecodeMLAWriteCacheKernel, + "decode_mla_write_cache function"); + + m.def("prefill_mla_write_cache", + &PrefillMLAWriteCacheKernel, + "prefill_mla_write_cache function"); +#endif - m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function"); + m.def("fused_rotary_position_encoding", + &FusedRotaryPositionEncoding, + "fused_rotary_position_encoding function"); - m.def("fused_rotary_position_encoding", &FusedRotaryPositionEncoding, "fused_rotary_position_encoding function"); +#ifdef ENABLE_SM80_EXT_OPS + m.def("multi_head_latent_attention", + &MultiHeadLatentAttention, + "multi_head_latent_attention function"); +#endif - m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function"); + m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); - m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("noaux_tc_redundant", + &NoauxTcRedundant, + "noaux_tc_redundant for MoE compute"); #ifdef ENABLE_FP8 - m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func, - py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"), - py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"), - py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function"); - m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func, - py::arg("input"), py::arg("scale"), py::arg("topk_ids"), - py::arg("top_k"), 
py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function"); - m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func, - py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function"); + m.def("cutlass_fp8_fp8_half_gemm_fused", + &cutlass_fp8_fp8_half_gemm_func, + py::arg("x"), + py::arg("y"), + py::arg("bias"), + py::arg("transpose_x"), + py::arg("transpose_y"), + py::arg("scale"), + py::arg("output_dtype"), + py::arg("activation_type"), + "cutlass_fp8_fp8_half_gemm_fused function"); + m.def("moe_fused_hadamard_quant_fp8", + &MoeFusedHadamardQuantFp8Func, + py::arg("input"), + py::arg("scale"), + py::arg("topk_ids"), + py::arg("top_k"), + py::arg("intermediate_size"), + py::arg("tiled"), + "moe_fused_hadamard_quant_fp8 function"); + m.def("fused_hadamard_quant_fp8", + &FusedHadamardQuantFp8Func, + py::arg("input"), + py::arg("scale"), + "fused_hadamard_quant_fp8 function"); #endif - m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function"); + m.def("init_custom_all_reduce", + &init_custom_all_reduce, + "init all reduce class function"); m.def("all_reduce", &all_reduce, "all reduce function"); @@ -1022,48 +1738,196 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("register_buffer", ®ister_buffer, "register ipc buffer"); - m.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); + m.def("register_graph_buffers", + ®ister_graph_buffers, + "register_graph_buffers"); - m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle"); + m.def("allocate_shared_buffer_and_handle", + &allocate_shared_buffer_and_handle, + "allocate_shared_buffer_and_handle"); m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer"); + m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles"); + m.def("open_mem_handle", &open_mem_handle, "open_mem_handle"); - m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, 
"get_graph_buffer_ipc_meta"); + m.def("get_graph_buffer_ipc_meta", + &get_graph_buffer_ipc_meta, + "get_graph_buffer_ipc_meta"); - // speculative decoding Kernel - m.def("speculate_get_padding_offset", &SpeculateGetPaddingOffset, "speculate_get_padding_offset function"); +#ifdef ENABLE_SM80_EXT_OPS + m.def("speculate_get_seq_lens_output", + &SpeculateGetSeqLensOutput, + "speculate_get_seq_lens_output function"); - m.def("speculate_get_seq_lens_output", &SpeculateGetSeqLensOutput, "speculate_get_seq_lens_output function"); + m.def("speculate_pre_process", + &SpeculatePreProcess, + "speculate_pre_process function"); - m.def("speculate_get_output_padding_offset",&SpeculateGetOutputPaddingOffset, "speculate_get_output_padding_offset function"); + m.def("build_sampling_params", + &BuildSamplingParams, + "build_sampling_params function"); - m.def("speculate_get_token_penalty_multi_scores",&SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function"); + m.def("speculate_get_token_penalty_multi_scores", + &SpecTokenPenaltyMultiScores, + "speculate_get_token_penalty_multi_scores function"); - m.def("speculate_set_stop_value_multi_seqs",&SpecGetStopFlagsMultiSeqs, "speculate_set_stop_value_multi_seqs function"); + m.def("speculate_set_stop_value_multi_seqs", + &SpecGetStopFlagsMultiSeqs, + "speculate_set_stop_value_multi_seqs function"); + m.def("speculate_verify", &SpeculateVerify, "speculate_verify function"); - m.def("speculate_verify",&SpeculateVerify, "speculate_verify function"); + m.def("verify_draft_tokens", + &VerifyDraftTokens, + "verify_draft_tokens function"); - m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function"); + m.def("speculate_update", &SpeculateUpdate, "Speculate Update Kernel"); - m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function"); + m.def("unified_update_model_status", + &UnifiedUpdateModelStatus, + 
"unified_update_model_status function"); - m.def("speculate_save_output", &SpeculateSaveWithOutputMsgStatic, "speculate_save_output function"); + m.def("naive_update_model_status", + &NaiveUpdateModelStatus, + "naive_update_model_status function"); - m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function"); + m.def("speculate_set_value_by_flags_and_idx", + &SpeculateSetValueByFlagsAndIdx, + "speculate_set_value_by_flags_and_idx function"); + + m.def("speculate_save_output", + &SpeculateSaveWithOutputMsgStatic, + "speculate_save_output function"); + + m.def("speculate_clear_accept_nums", + &SpeculateClearAcceptNums, + "speculate_clear_accept_nums function"); + + m.def("speculate_schedule_cache", + &SpeculateScheduleCache, + "SpeculateScheduleCache function"); m.def("ngram_match", &NgramMatch, "ngram_match function"); - m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function"); + m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function"); + + m.def("draft_model_postprocess", + &DraftModelPostprocess, + "draft_model_postprocess function"); + + m.def("draft_model_preprocess", + &DraftModelPreprocess, + "draft_model_preprocess function"); + + m.def("draft_model_update", &DraftModelUpdate, "draft_model_update function"); + + m.def("eagle_get_hidden_states", + &EagleGetHiddenStates, + "eagle_get_hidden_states function"); - m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function"); + m.def("eagle_get_self_hidden_states", + &EagleGetSelfHiddenStates, + "eagle_get_self_hidden_states function"); - m.def("draft_model_update",&DraftModelUpdate, "draft_model_update function"); + m.def("eagle_gather_hidden_states", + &EagleGatherHiddenStates, + "eagle_gather_hidden_states function"); - m.def("eagle_get_hidden_states",&EagleGetHiddenStates, "eagle_get_hidden_states function"); + m.def("mtp_step_paddle", &MTPStepPaddle, "mtp_step_paddle function"); - 
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function"); + m.def("speculate_step_paddle", + &SpeculateStepPaddle, + "speculate_step_paddle function"); + + m.def("merge_prefill_decode_output", + &MergePrefillDecodeOutput, + "merge_prefill_decode_output function"); + + m.def("rejection_top_p_sampling", + &TopPSamplingReject, + "rejection_top_p_sampling function"); + + m.def("top_k_renorm_probs", &TopKRenorm, "top_k_renorm_probs function"); + + m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function"); + + m.def("save_output", &SaveOutMmsgStatic, "save_output function"); + + m.def("limit_thinking_content_length", + &LimitThinkingContentLength, + "limit_thinking_content_length function"); + + m.def("speculate_limit_thinking_content_length", + &SpeculateLimitThinkingContentLength, + "speculate limit thinking content length function"); + + m.def("speculate_get_logits", + &SpeculateGetLogits, + "speculate_get_logits function"); + + m.def("speculate_insert_first_token", + &SpeculateInsertFirstToken, + "speculate_insert_first_token function"); + + m.def("speculate_get_target_logits", + &SpeculateGetTargetLogits, + "speculate_get_target_logits function"); +#endif + + m.def("update_attn_mask_offsets", + &UpdateAttnMaskOffsets, + "update attention mask"); + + m.def("fused_neox_rope_embedding", + &FusedNeoxRopeEmbedding, + "fused_neox_rope_embedding function"); + +#ifndef DISABLE_GELU_TANH_OP + m.def("gelu_tanh", &GeluTanh, "gelu_tanh function"); +#endif - m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function"); + m.def("reasoning_phase_token_constraint", + &ReasoningPhaseTokenConstraint, + "reasoning_phase_token_constraint function"); + + m.def("get_attn_mask_q", &get_attn_mask_q, "get_attn_mask_q function"); + + m.def("custom_numpy_to_tensor", + &CustomNumpyToTensor, + "custom_numpy_to_tensor function"); + m.def("prefill_permute_to_masked_gemm", + &PrefillPermuteToMaskedGemm, + py::arg("x"), + py::arg("scale"), + 
py::arg("topk_ids"), + py::arg("num_local_experts"), + py::arg("max_token_num"), + "Prefill permute to masked GEMM for MoE"); + + m.def("depermute_prefill_combine", + &DepermutePrefillCombine, + py::arg("x"), + py::arg("indice_map"), + py::arg("topk_weights"), + py::arg("num_worst_tokens"), + "Depermute and combine expert outputs for MoE prefill"); + + m.def("radix_topk_ragged_transform", + &RadixTopkRaggedTransform, + "radix_topk_ragged_transform function"); + + m.def("dsk_attn_write_cache", &DSMLAWriteCacheKernel, "dsk_attn_write_cache"); + + m.def("indexer_k_quant_and_cache", + &IndexerKQuantAndCacheKernel, + "indexer_k_quant_and_cache"); + + m.def("cp_gather_indexer_k_quant_cache", + &CpGatherIndexerKQuantCacheKernel, + "cp_gather_indexer_k_quant_cache"); + + m.def("per_token_group_fp8_quant", + &PerTokenGroupQuantFp8, + "per_token_group_quant_fp8"); } diff --git a/custom_ops/gpu_ops/cuda_multiprocess.h b/custom_ops/gpu_ops/cuda_multiprocess.h index c4b3c841094..a001b601f45 100644 --- a/custom_ops/gpu_ops/cuda_multiprocess.h +++ b/custom_ops/gpu_ops/cuda_multiprocess.h @@ -41,6 +41,8 @@ #include #include #endif +#include +#include #include #ifdef PADDLE_WITH_HIP @@ -52,35 +54,34 @@ namespace cub = hipcub; #define GPU(str) cuda##str #endif -#define checkCudaErrors(call) \ - do { \ - GPU(Error_t) err = call; \ - if (err != GPU(Success)) { \ - printf("CUDA error at %s %d: %s\n", \ - __FILE__, \ - __LINE__, \ - GPU(GetErrorString)(err)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) +#define checkCudaErrors(call) \ + do { \ + GPU(Error_t) err = call; \ + if (err != GPU(Success)) { \ + throw std::runtime_error(std::string("CUDA error at ") + __FILE__ + \ + ":" + std::to_string(__LINE__) + " '" + \ + GPU(GetErrorString)(err) + "'"); \ + } \ + } while (0) typedef struct shmStruct_st { - size_t nprocesses; - GPU(IpcMemHandle_t) memHandle; + size_t nprocesses; + GPU(IpcMemHandle_t) memHandle; } shmStruct; typedef struct sharedMemoryInfo_st { - void *addr; - size_t size; 
+ void *addr; + size_t size; #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) - HANDLE shmHandle; + HANDLE shmHandle; #else - int shmFd; + int shmFd; #endif } sharedMemoryInfo; - -inline int sharedMemoryOpen(const char *name, size_t sz, sharedMemoryInfo *info) { +inline int sharedMemoryOpen(const char *name, + size_t sz, + sharedMemoryInfo *info) { #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) info->size = sz; diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu index 7c6d4cec793..f3143c423af 100644 --- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu +++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // @@ -14,8 +15,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "helper.h" #include "all_reduce.cuh" +#include "helper.h" // Fake pointer type, must match fptr_t type in ops.h. // We use this type alias to indicate when pointers are passed in as int64_t. 
@@ -23,8 +24,9 @@ using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, - paddle::Tensor& rank_data, int64_t rank, - bool full_nvlink) { + paddle::Tensor& rank_data, + int64_t rank, + bool full_nvlink) { int world_size = fake_ipc_ptrs.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); @@ -37,9 +39,71 @@ fptr_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, for (int i = 0; i < world_size; i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(), - rank_data.numel(), rank, world_size, - full_nvlink); + return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, + rank_data.data(), + rank_data.numel(), + rank, + world_size, + full_nvlink); +} + +/** + * alltoall and transpose in decode. + */ +void decode_alltoall_transpose(paddle::Tensor& inp, + paddle::Tensor& out, + fptr_t _fa, + fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes) { + auto fa = reinterpret_cast(_fa); + auto stream = inp.stream(); + + auto input_size = inp.numel() * phi::SizeOf(inp.dtype()); + auto token_num = inp.shape()[0]; + auto hidden_size = inp.shape()[1]; + auto reg_buffer = reinterpret_cast(_reg_buffer); + if (reg_buffer) { + CUDACHECK(cudaMemcpyAsync( + reg_buffer, inp.data(), input_size, cudaMemcpyDeviceToDevice, stream)); + } else { + reg_buffer = inp.data(); + } + switch (out.dtype()) { + case phi::DataType::FLOAT32: { + fa->decode_alltoall_transpose(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + token_num, + hidden_size, + out.numel()); + break; + } + case phi::DataType::FLOAT16: { + fa->decode_alltoall_transpose(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + token_num, + hidden_size, + out.numel()); + break; + } +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) + case phi::DataType::BFLOAT16: { + fa->decode_alltoall_transpose( + 
stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + token_num, + hidden_size, + out.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "decode_alltoall_transpose only supports float32, float16 and " + "bfloat16"); + } } /** @@ -49,36 +113,43 @@ fptr_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * copied into _reg_buffer. */ -void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out, - fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { +void all_reduce(paddle::Tensor& inp, + paddle::Tensor& out, + fptr_t _fa, + fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes) { auto fa = reinterpret_cast(_fa); auto stream = inp.stream(); - auto input_size = inp.numel() * 2; + auto input_size = inp.numel() * phi::SizeOf(inp.dtype()); auto reg_buffer = reinterpret_cast(_reg_buffer); if (reg_buffer) { - cudaMemcpyAsync(reg_buffer, inp.data(), input_size, - cudaMemcpyDeviceToDevice, stream); + CUDACHECK(cudaMemcpyAsync( + reg_buffer, inp.data(), input_size, cudaMemcpyDeviceToDevice, stream)); } else { reg_buffer = inp.data(); } switch (out.dtype()) { case phi::DataType::FLOAT32: { - fa->allreduce(stream, reinterpret_cast(reg_buffer), + fa->allreduce(stream, + reinterpret_cast(reg_buffer), reinterpret_cast(out.data()), out.numel()); break; } case phi::DataType::FLOAT16: { - fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data()), out.numel()); + fa->allreduce(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + out.numel()); break; } #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) case phi::DataType::BFLOAT16: { - fa->allreduce( - stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data()), out.numel()); + fa->allreduce(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + out.numel()); break; } #endif @@ -122,17 +193,21 @@ void register_graph_buffers(fptr_t _fa, 
for (int i = 0; i < handles.size(); i++) { bytes.emplace_back(handles[i].begin(), handles[i].end()); } - bytes.reserve(handles.size()); fa->register_graph_buffers(bytes, offsets); } +void clear_ipc_handles(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + fa->clear_ipc_handles(); +} + std::tuple allocate_shared_buffer_and_handle( int64_t size) { - auto device_index = phi::backends::gpu::GetCurrentDeviceId(); void* buffer; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; - auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream(); + auto stream = + paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream(); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); // Allocate buffer @@ -144,22 +219,41 @@ std::tuple allocate_shared_buffer_and_handle( // Create IPC memhandle for the allocated buffer. // Will use it in open_mem_handle. auto handle = - paddle::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index)); - CUDACHECK( - cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer)); + paddle::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, + paddle::DataType::UINT8, + paddle::GPUPlace(device_index)); + CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer)); return std::make_tuple(reinterpret_cast(buffer), handle); } - fptr_t open_mem_handle(paddle::Tensor& mem_handle) { void* ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle( - (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()), - cudaIpcMemLazyEnablePeerAccess)); + CUDACHECK( + cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)mem_handle.data()), + cudaIpcMemLazyEnablePeerAccess)); return reinterpret_cast(ipc_ptr); } void free_shared_buffer(fptr_t buffer) { CUDACHECK(cudaFree(reinterpret_cast(buffer))); } + +PD_BUILD_STATIC_OP(decode_alltoall_transpose) + .Inputs({"inp", "out"}) + .Outputs({"new_out"}) + .Attrs({"_fa: int64_t", + "_reg_buffer: int64_t", + 
"reg_buffer_sz_bytes: int64_t"}) + .SetInplaceMap({{"out", "new_out"}}) + .SetKernelFn(PD_KERNEL(decode_alltoall_transpose)); + +PD_BUILD_STATIC_OP(all_reduce) + .Inputs({"inp", "out"}) + .Outputs({"new_out"}) + .Attrs({"_fa: int64_t", + "_reg_buffer: int64_t", + "reg_buffer_sz_bytes: int64_t"}) + .SetInplaceMap({{"out", "new_out"}}) + .SetKernelFn(PD_KERNEL(all_reduce)); diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh index 2dd52871a9c..cb4c25bcf4e 100644 --- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh +++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh @@ -18,21 +18,23 @@ #include #include -#include #include +#include #include #include +#include +#include #include #include -#define CUDACHECK(cmd) \ - do { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ - cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + throw std::runtime_error(std::string("CUDA error at ") + __FILE__ + \ + ":" + std::to_string(__LINE__) + " '" + \ + cudaGetErrorString(e) + "'"); \ + } \ } while (0) namespace paddle { @@ -188,7 +190,8 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { // semantic is used to enforce memory access order before and after this // barrier. 
template -DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, +DINLINE void multi_gpu_barrier(const RankSignals& sg, + Signal* self_sg, int rank) { if constexpr (!is_start) __syncthreads(); static_assert( @@ -205,10 +208,12 @@ DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; if constexpr (need_fence) { st_flag_release(peer_counter_ptr, val); - while (ld_flag_acquire(self_counter_ptr) != val); + while (ld_flag_acquire(self_counter_ptr) != val) { + } } else { st_flag_volatile(peer_counter_ptr, val); - while (ld_flag_volatile(self_counter_ptr) != val); + while (ld_flag_volatile(self_counter_ptr) != val) { + } } } if constexpr (is_start || need_fence) __syncthreads(); @@ -224,10 +229,46 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { return downcast

(tmp); } +template +__global__ void __launch_bounds__(512, 1) decode_alltoall_transpose_kernel( + RankData* _dp, // [tp_size, m / tp_size, part_hidden_size] + RankSignals sg, + Signal* self_sg, + T* __restrict__ result, // [m / tp_size, part_hidden_size * tp_size] + const int rank, + const int token_num, + const int hidden_size, + const int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + const int hidden_size_p = hidden_size / packed_t::P::size; + const int part_hidden_size_p = hidden_size_p / ngpus; + const int rank_token_id = token_num / ngpus * rank; + auto dp = *_dp; + multi_gpu_barrier(sg, self_sg, rank); + // alltoall and transpose + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + const int token_idx = idx / hidden_size_p; + const int src_token_idx = token_idx + rank_token_id; + const int src_rank = (idx % hidden_size_p) / part_hidden_size_p; + const int src_idx = + src_token_idx * part_hidden_size_p + (idx % part_hidden_size_p); + ((P*)result)[idx] = ((const P**)&dp.ptrs[0])[src_rank][src_idx]; + } + multi_gpu_barrier(sg, self_sg, rank); +} + template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, - T* __restrict__ result, int rank, int size) { + cross_device_reduce_1stage(RankData* _dp, + RankSignals sg, + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same @@ -249,8 +290,12 @@ DINLINE P* get_tmp_buf(Signal* sg) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, - T* __restrict__ result, int rank, int size) { + 
cross_device_reduce_2stage(RankData* _dp, + RankSignals sg, + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -303,7 +348,7 @@ class CustomAllreduce { bool full_nvlink_; RankSignals sg_; - // Stores an map from a pointer to its peer pointters from all ranks. + // Stores an map from a pointer to its peer pointers from all ranks. std::unordered_map buffers_; Signal* self_sg_; @@ -323,7 +368,7 @@ class CustomAllreduce { // 3. (In Python) all gather the IPC handles. // 4. Obtain the peer pointers by opening the IPC handles, and store them in // the rank data array at corresponding positions. - RankData *d_rank_data_base_, *d_rank_data_end_; + RankData *d_rank_data_base_origin_, *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers std::map ipc_handles_; @@ -338,8 +383,12 @@ class CustomAllreduce { * Note: this class does not own any device memory. Any required buffers * are passed in from the constructor. 
*/ - CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, - int rank, int world_size, bool full_nvlink = true) + CustomAllreduce(Signal** signals, + void* rank_data, + size_t rank_data_sz, + int rank, + int world_size, + bool full_nvlink = true) : rank_(rank), world_size_(world_size), full_nvlink_(full_nvlink), @@ -349,6 +398,7 @@ class CustomAllreduce { for (int i = 0; i < world_size_; i++) { sg_.signals[i] = signals[i]; } + d_rank_data_base_origin_ = d_rank_data_base_; } char* open_ipc_handle(const void* ipc_handle) { @@ -405,6 +455,7 @@ class CustomAllreduce { CUDACHECK( cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); buffers_[ptrs[rank_]] = d_data; + d_rank_data_base_origin_ = d_rank_data_base_; } // Note: when registering graph buffers, we intentionally choose to not @@ -434,13 +485,95 @@ class CustomAllreduce { } } } - CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + CUDACHECK(cudaMemcpy(d_rank_data_base_, + rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice)); d_rank_data_base_ += num_buffers; graph_unreg_buffers_.clear(); } + /** + * alltoall and transpose in decode. 
+ */ + template + void decode_alltoall_transpose(cudaStream_t stream, + T* input, + T* output, + int token_num, + int part_hidden_size, + int size, + int threads = 512, + int block_limit = 36) { + auto d = packed_t::P::size; + int hidden_size = part_hidden_size * world_size_; + if (size % d != 0) + throw std::runtime_error( + "custom decode_alltoall_transpose currently requires input length to " + "be multiple " + "of " + + std::to_string(d)); + if (size / d % world_size_ != 0) + throw std::runtime_error( + "custom decode_alltoall_transpose currently requires input length to " + "be multiple " + "of " + + std::to_string(d) + " and " + std::to_string(world_size_)); + if (token_num % world_size_ != 0) + throw std::runtime_error( + "custom decode_alltoall_transpose currently requires input token_num " + "to be multiple " + "of " + + std::to_string(world_size_)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name<<>>( \ + ptrs, sg_, self_sg_, output, rank_, token_num, hidden_size, size); + +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + KL(ngpus, decode_alltoall_transpose_kernel); \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw 
std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + /** * Performs allreduce, assuming input has already been registered. * @@ -451,8 +584,12 @@ class CustomAllreduce { * guess is that too many SMs will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T* input, T* output, int size, - int threads = 512, int block_limit = 36) { + void allreduce(cudaStream_t stream, + T* input, + T* output, + int size, + int threads = 512, + int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) throw std::runtime_error( @@ -483,9 +620,9 @@ class CustomAllreduce { size /= d; auto bytes = size * sizeof(typename packed_t::P); int blocks = std::min(block_limit, (size + threads - 1) / threads); -#define KL(ngpus, name) \ - name<<>>(ptrs, sg_, self_sg_, output, \ - rank_, size); +#define KL(ngpus, name) \ + name<<>>( \ + ptrs, sg_, self_sg_, output, rank_, size); #define REDUCE_CASE(ngpus) \ case ngpus: { \ @@ -517,10 +654,15 @@ class CustomAllreduce { #undef KL } - ~CustomAllreduce() { + void clear_ipc_handles() { for (auto [_, ptr] : ipc_handles_) { CUDACHECK(cudaIpcCloseMemHandle(ptr)); } + + ipc_handles_.clear(); + d_rank_data_base_ = d_rank_data_base_origin_; } + + ~CustomAllreduce() { clear_ipc_handles(); } }; } // namespace paddle diff --git a/custom_ops/gpu_ops/custom_ftok.h b/custom_ops/gpu_ops/custom_ftok.h new file mode 100644 index 00000000000..302061baf63 --- /dev/null +++ b/custom_ops/gpu_ops/custom_ftok.h @@ -0,0 +1,37 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +// Custom ftok that uses the low 20 bits of id instead of only 8 bits. +// This avoids dependency on filesystem paths while preserving queue separation. +inline key_t custom_ftok(const char* path, int id) { + struct stat st; + if (stat(path, &st) < 0) { + fprintf(stderr, + "[custom_ftok] stat(\"%s\") failed (errno=%d), " + "msg queue key will be invalid!\n", + path, + errno); + return static_cast(-1); + } + // low 4 bits of st_dev | low 8 bits of st_ino | low 20 bits of id + return static_cast(((st.st_dev & 0x0f) << 28) | + ((st.st_ino & 0xff) << 20) | (id & 0xfffff)); +} diff --git a/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp b/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp index 61a41031bfb..4b99d9651b0 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once @@ -38,315 +39,331 @@ // Config -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && \ + (__CUDACC_VER_MAJOR__ >= 10)) #define CUTE_ARCH_RED_F16_SM70_ENABLED #endif -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + (__CUDACC_VER_MAJOR__ >= 12)) #define CUTE_ARCH_RED_VEC_SM90_ENABLED #define CUTE_ARCH_RED_BF16_SM90_ENABLED #endif -namespace cute -{ +namespace cute { ////////////////////////////////// // Wrapper around CUDA's atomicAdd ////////////////////////////////// template -struct TypedAtomicAdd -{ - using SRegisters = T[1]; - using DRegisters = T[1]; - - CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) - { - atomicAdd(&dst, src); - } +struct TypedAtomicAdd { + using SRegisters = T[1]; + using DRegisters = T[1]; + + CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) { + atomicAdd(&dst, src); + } }; template -struct Copy_Traits> -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout::value>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout::value>>>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; +struct Copy_Traits> { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// // F16 ADD PTX ////////////////////////////////// -struct SM70_RED_ADD_NOFTZ_F16 -{ - using SRegisters = uint16_t[1]; - using 
DRegisters = uint16_t[1]; +struct SM70_RED_ADD_NOFTZ_F16 { + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) { #if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); + asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), + "h"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; -struct SM70_RED_ADD_NOFTZ_F16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; +struct SM70_RED_ADD_NOFTZ_F16x2 { + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) { #if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); + asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), + "r"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use 
red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; -struct SM90_RED_ADD_NOFTZ_F16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; +struct SM90_RED_ADD_NOFTZ_F16x2_V2 { + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint64_t& gmem_dst) { #if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); + asm volatile( + "red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), + "r"(src0), + "r"(src1)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = 
Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; -struct SM90_RED_ADD_NOFTZ_F16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; +struct SM90_RED_ADD_NOFTZ_F16x2_V4 { + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint32_t const& src2, + uint32_t const& src3, + uint128_t& gmem_dst) { #if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); + asm volatile( + "red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"( + &gmem_dst), + "r"(src0), + "r"(src1), + "r"(src2), + "r"(src3)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + 
// Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// // BF16 ADD PTX ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16 -{ - using SRegisters = uint16_t[1]; - using DRegisters = uint16_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16 { + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); + asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), + "h"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16x2 { + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, 
uint32_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); + asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), + "r"(src0)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16x2_V2 { + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint64_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); + asm volatile( + "red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), + "r"(src0), + "r"(src1)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use 
red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -struct SM90_RED_ADD_NOFTZ_BF16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; +struct SM90_RED_ADD_NOFTZ_BF16x2_V4 { + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, + uint32_t const& src1, + uint32_t const& src2, + uint32_t const& src3, + uint128_t& gmem_dst) { #if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); + asm volatile( + "red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"( + &gmem_dst), + "r"(src0), + "r"(src1), + "r"(src2), + "r"(src3)); #else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); + CUTE_INVALID_CONTROL_PATH( + "Trying to use red.global.bf16 without " + "CUTE_ARCH_RED_BF16_SM90_ENABLED."); #endif - } + } }; template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; +struct Copy_Traits { + // Logical 
thread id to thread idx (one-thread) + using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; }; ////////////////////////////////// -} // end namespace cute +} // end namespace cute diff --git a/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h b/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h index a9975c01382..0e9fa1b1510 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h +++ b/custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -59,8 +60,9 @@ template < bool GlobalToShared = true> struct copy; -/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate -/// the entire transfer, zeros are written to SMEM if the guard predicate is false. 
+/// Initiates an asynchronous copy from global memory to shared memory. Rather +/// than predicate the entire transfer, zeros are written to SMEM if the guard +/// predicate is false. /// /// cp.async /// @@ -72,7 +74,8 @@ template < bool GlobalToShared = true> struct copy_zfill; -/// Blocks until all but previous cp.async.commit_group operations have committed. +/// Blocks until all but previous cp.async.commit_group operations have +/// committed. /// /// cp.async /// @@ -86,11 +89,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - cp_async(smem_ptr, global_ptr, pred_guard); + cp_async( + smem_ptr, global_ptr, pred_guard); } }; @@ -99,15 +102,15 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - using AccessType = Array; + using AccessType = Array; - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } + if (pred_guard) { + *static_cast(smem_ptr) = + *static_cast(global_ptr); + } } }; @@ -116,11 +119,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { - cp_async_zfill(smem_ptr, global_ptr, pred_guard); + cp_async_zfill( + smem_ptr, global_ptr, pred_guard); } }; @@ -129,20 +132,19 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { - using AccessType = Array; - - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - else { - AccessType zeros; - zeros.clear(); - *static_cast(smem_ptr) = zeros; - } + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = + 
*static_cast(global_ptr); + } else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } } }; @@ -153,11 +155,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - cp_async(smem_ptr, global_ptr, pred_guard); + cp_async( + smem_ptr, global_ptr, pred_guard); } }; @@ -166,15 +168,15 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy { - /// Copy CUTLASS_DEVICE copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - using AccessType = Array; + using AccessType = Array; - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } + if (pred_guard) { + *static_cast(smem_ptr) = + *static_cast(global_ptr); + } } }; @@ -183,11 +185,11 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - cp_async_zfill(smem_ptr, global_ptr, pred_guard); + cp_async_zfill( + smem_ptr, global_ptr, pred_guard); } }; @@ -196,31 +198,29 @@ template < /// Size of the access in bytes int SizeInBytes> struct copy_zfill { - /// Copy with zero fill CUTLASS_DEVICE copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { - using AccessType = Array; - - if (pred_guard) { - *static_cast(smem_ptr) = *static_cast(global_ptr); - } - else { - AccessType zeros; - zeros.clear(); - *static_cast(smem_ptr) = zeros; - } + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = + *static_cast(global_ptr); + } else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } } }; -/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does +/// not block. 
template -CUTLASS_DEVICE -void copy_fence() {} +CUTLASS_DEVICE void copy_fence() {} template <> -CUTLASS_DEVICE -void copy_fence() { +CUTLASS_DEVICE void copy_fence() { cp_async_fence(); } @@ -229,7 +229,6 @@ void copy_fence() { /// Partial specialization template struct copy_wait { - CUTLASS_DEVICE copy_wait() {} }; @@ -237,7 +236,6 @@ struct copy_wait { /// Partial specialization template struct copy_wait { - CUTLASS_DEVICE copy_wait() { cp_async_wait(); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h b/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h index 2362da4f7f2..2ab2981518d 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/arch/mma.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -37,10 +38,8 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace arch -{ +namespace cutlass { +namespace arch { // Tag which triggers MMA which will trigger struct OpMultiplyAddDequantizeInterleavedBToA; @@ -52,8 +51,8 @@ struct OpMultiplyAddDequantizeInterleavedBToA; split out the template below into OpMultiplyAddDequantizeInterleavedBToA along with the quantization op before instantiating the GEMM pieces. - Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of - code we need to duplicate. 
+ Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount + of code we need to duplicate. */ struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; @@ -61,60 +60,59 @@ struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; // The default just forwards the original operator template -struct TagOperator -{ - using TaggedOperator = MmaOp; +struct TagOperator { + using TaggedOperator = MmaOp; }; // Specializations below attach more information to the operator template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; }; template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; }; template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; }; -// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original -// operator + the extra information. If no extra info was tagged, the dequant op per column scaling -// as a default. +// Here we instantiate some structs to "detag" the tagged operator. It splits it +// back to the original operator + the extra information. If no extra info was +// tagged, the dequant op per column scaling as a default. 
template -struct DetagOperator -{ - using Operator = TaggedMmaOp; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = + WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; }; -} // namespace arch -} // namespace cutlass +} // namespace arch +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h b/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h index 29ee9766918..7f7cce414e5 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h +++ b/custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h @@ -20,66 +20,65 @@ #include "cutlass/device_kernel.h" #include "common/cudaUtils.h" -namespace cutlass_extensions -{ +namespace cutlass_extensions { template -inline int compute_occupancy_for_kernel() -{ 
+inline int compute_occupancy_for_kernel() { + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size > (48 << 10)) - { - cudaFuncAttributes attr; - int device = 0; - int max_smem_per_block = 0; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); - PADDLE_ENFORCE_GPU_SUCCESS( - cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - if constexpr (enable_cutlass_3x) - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); - } - else - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncGetAttributes(&attr, cutlass::Kernel)); - } - if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) - { - // This should mean that - // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) - // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this - // configuration. 
- return 0; - } - - if constexpr (enable_cutlass_3x) - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncSetAttribute( - cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - else - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncSetAttribute( - cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } + if (smem_size > (48 << 10)) { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + if constexpr (enable_cutlass_3x) { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncGetAttributes(&attr, cutlass::device_kernel)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncGetAttributes(&attr, cutlass::Kernel)); } - - int max_active_blocks = -1; - if constexpr (enable_cutlass_3x) - { - PADDLE_ENFORCE_GPU_SUCCESS( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, - 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + if (smem_size + attr.sharedSizeBytes >= + static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, + // cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) wouldn't work. + // In that case, we return an occupancy of 0. This will cause the + // heuristic to ignore this configuration. 
+ return 0; } - else - { - PADDLE_ENFORCE_GPU_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + + if constexpr (enable_cutlass_3x) { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncSetAttribute(cutlass::device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), + smem_size)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + cutlass::Kernel, + GemmKernel::kThreadCount, + smem_size)); + } - return max_active_blocks; + return max_active_blocks; } -} // namespace cutlass_extensions +} // namespace cutlass_extensions diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp index 3e5aa4b038e..60c0dc6302e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -47,7 +47,8 @@ // breaks when moving scales to the CPU. 
// -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp #pragma once diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp index fa1df1fb1e2..dbc6a718f3a 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -47,7 +47,8 @@ // breaks when moving scales to the CPU. // -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp #pragma once diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 7b56c3c1ac8..aee2fbaeb29 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -47,7 +47,8 @@ // breaks when moving scales to the CPU. 
// -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp #pragma once diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index 513d3741fbc..42db1bd158b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp #pragma once @@ -24,31 +25,41 @@ using namespace cute; */ template struct ScaledEpilogueBase { -protected: + protected: using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; template using ColOrScalarLoad = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<0>, Int<0>>>; template using RowOrScalarLoad = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<1>, Int<0>>>; template using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<0>, Int<0>>>; template using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + OutputTileThreadMap, + T, + 
Stride, Int<1>, Int<0>>>; template using RowOrZeroLoad = cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + OutputTileThreadMap, + T, + Stride, Int<1>, Int<0>>>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or @@ -56,15 +67,11 @@ struct ScaledEpilogueBase { template static auto args_from_tensor(paddle::Tensor const &tensor) { using Arguments = typename Descriptor::Arguments; - auto *data_ptr = static_cast(const_cast( - tensor.data())); - if constexpr (std::is_same_v> || - std::is_same_v>) { + auto *data_ptr = static_cast(const_cast(tensor.data())); + if constexpr (std::is_same_v> || + std::is_same_v>) { return Arguments{data_ptr, tensor.numel() != 1}; - } - else { + } else { // it would technically work but no use case as data_ptr is never nullptr static_assert(!std::is_same_v>); return Arguments{data_ptr}; @@ -102,24 +109,28 @@ struct ScaledEpilogueBase { template struct ScaledEpilogue : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT; using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, + cutlass::multiplies, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; @@ -146,26 +157,30 @@ struct ScaledEpilogue template struct ScaledEpilogueBias : protected ScaledEpilogueBase { 
-protected: + protected: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; using Bias = typename SUPER::template RowLoad; using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT; using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + public: + using EVTCompute = cutlass::epilogue::threadblock:: + Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; static ArgumentType prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, @@ -190,7 +205,7 @@ struct ScaledEpilogueBias template struct ScaledEpilogueBiasAzp : protected ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -202,35 +217,40 @@ struct ScaledEpilogueBiasAzp // Compute float(accum - azp_adj), both operands are int32_t using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = cutlass::epilogue::threadblock::Sm80EVT; using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; + using EVTComputeScaleB = cutlass::epilogue::threadblock:: + Sm80EVT; using ComputeScaleBiasA = 
cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; + public: + using EVTCompute = cutlass::epilogue::threadblock:: + Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -257,7 +277,7 @@ struct ScaledEpilogueBiasAzp template struct ScaledEpilogueBiasAzpToken : protected ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -272,7 +292,9 @@ struct ScaledEpilogueBiasAzpToken // Compute azp * azp_adj using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, int32_t, int32_t, + cutlass::multiplies, + int32_t, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = @@ -280,35 +302,41 @@ struct ScaledEpilogueBiasAzpToken // Compute float(accum - azp*azp_adj), all operands are int32_t using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAcc = cutlass::epilogue::threadblock::Sm80EVT; using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, 
cutlass::FloatRoundStyle::round_to_nearest>; - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; + using EVTComputeScaleB = cutlass::epilogue::threadblock:: + Sm80EVT; using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; + public: + using EVTCompute = cutlass::epilogue::threadblock:: + Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, paddle::Tensor const &azp, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::Tensor const &azp, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -324,4 +352,4 @@ struct ScaledEpilogueBiasAzpToken } }; -}; // namespace fastdeploy::c2x +}; // namespace fastdeploy::c2x diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 38a51d91457..abb73ce84c5 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp 
#pragma once @@ -6,6 +7,8 @@ // clang-format off #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" + +#include "helper.h" // clang-format on /* @@ -22,24 +25,28 @@ namespace fastdeploy::c3x { using namespace cute; -template struct identity { +template +struct identity { CUTLASS_HOST_DEVICE T operator()(T lhs) const { return lhs; } }; template struct TrivialEpilogue { -private: + private: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; using Compute = cutlass::epilogue::fusion::Sm90Compute< - cutlass::epilogue::thread::Identity, ElementD, ElementAcc, + cutlass::epilogue::thread::Identity, + ElementD, + ElementAcc, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - template static ArgumentType prepare_args(Args... args) { + template + static ArgumentType prepare_args(Args... 
args) { return {}; } }; @@ -50,38 +57,60 @@ struct TrivialEpilogue { */ template struct ScaledEpilogueBase { -protected: + protected: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; template using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<0>, Int<0>>>; template using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<1>, Int<0>>>; // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, - 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, + TileShape, + T, + T, + Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, + EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, - 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, + TileShape, + T, + T, + Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, + EnableNullPtr>; template using ColOrScalarLoadArray = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray< - 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<0>, Int<0>>>; template using RowOrScalarLoadArray = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray< - 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, + TileShape, + T, + Stride, Int<1>, Int<0>>>; // This utility function constructs the arguments for the load descriptors // from a tensor. 
It can handle both row and column, as well as row/column or @@ -140,24 +169,28 @@ struct ScaledEpilogueBase { template struct ScaledEpilogue : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, + cutlass::multiplies, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; @@ -184,7 +217,7 @@ struct ScaledEpilogue template struct ScaledEpilogueBias : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -192,17 +225,21 @@ struct ScaledEpilogueBias using Bias = typename SUPER::template RowLoad; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; @@ -227,7 +264,7 @@ struct ScaledEpilogueBias template struct ScaledEpilogueColumnBias : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using 
ScaleA = typename SUPER::template ColOrScalarLoad; @@ -235,17 +272,21 @@ struct ScaledEpilogueColumnBias using Bias = typename SUPER::template ColLoad; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; @@ -273,7 +314,7 @@ struct ScaledEpilogueColumnBias template struct ScaledEpilogueBiasAzp : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -285,33 +326,39 @@ struct ScaledEpilogueBiasAzp // Compute float(accum - azp_adj), both operands are int32_t using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeScaleB = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; + public: + using EVTCompute = cutlass::epilogue::fusion:: + Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const 
&b_scales, - paddle::Tensor const &azp_adj, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -338,7 +385,7 @@ struct ScaledEpilogueBiasAzp template struct ScaledEpilogueBiasAzpToken : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; @@ -353,7 +400,9 @@ struct ScaledEpilogueBiasAzpToken // Compute azp * azp_adj using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, + cutlass::multiplies, + int32_t, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAzp = @@ -361,33 +410,40 @@ struct ScaledEpilogueBiasAzpToken // Compute float(accum - azp*azp_adj), all operands are int32_t using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, + cutlass::minus, + float, + int32_t, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeAcc = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeScaleB = cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::multiply_add, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; + public: + using EVTCompute = cutlass::epilogue::fusion:: + Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType - 
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, paddle::Tensor const &azp, - paddle::optional const &bias) { + static ArgumentType prepare_args( + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::Tensor const &azp, + paddle::optional const &bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -412,24 +468,28 @@ struct ScaledEpilogueBiasAzpToken template struct ScaledEpilogueArray : private ScaledEpilogueBase { -private: + private: using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoadArray; using ScaleB = typename SUPER::template RowOrScalarLoadArray; using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, + cutlass::multiplies, + float, + float, cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, + cutlass::multiplies, + ElementD, + float, cutlass::FloatRoundStyle::round_to_nearest>; -public: + public: using EVTCompute = cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; @@ -439,7 +499,8 @@ struct ScaledEpilogueArray static ArgumentType prepare_args(float const *const *a_scales_ptr, float const *const *b_scales_ptr, - bool a_col_broadcast, bool b_row_broadcast) { + bool a_col_broadcast, + bool b_row_broadcast) { auto a_args = SUPER::template args_from_tensor( a_scales_ptr, a_col_broadcast); auto b_args = SUPER::template args_from_tensor( @@ -450,4 +511,4 @@ struct ScaledEpilogueArray } }; -}; // namespace fastdeploy::c3x +}; // namespace fastdeploy::c3x diff --git 
a/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h index f3c622b88a5..0aef590a10c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/thread/fused_activations.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Functor performing linear combination with a maximum operation used by epilogues. + \brief Functor performing linear combination with a maximum operation used by + epilogues. 
*/ #pragma once @@ -46,60 +48,53 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace epilogue -{ -namespace thread -{ +namespace cutlass { +namespace epilogue { +namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// -__forceinline__ __device__ float copysignf_pos(float a, float b) -{ - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; } -__forceinline__ __device__ float tanh_opt(float x) -{ +__forceinline__ __device__ float tanh_opt(float x) { #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - float const exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); #else - return fast_tanh(x); + return fast_tanh(x); #endif } ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -struct GELU_taylor -{ - static bool const kIsHeavy = true; +struct GELU_taylor { + static bool const kIsHeavy = true; - CUTLASS_DEVICE - float operator()(float const& z) const - { + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); - float k0 = float(0.7978845608028654); - float k1 = float(0.044715); + return float( + cutlass::constants::half() * z * + (cutlass::constants::one() + + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } - return float(cutlass::constants::half() * z - * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } + using Params = 
LinearCombinationGenericParams; - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const - { - return this->operator()(scalar); - } + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { + return this->operator()(scalar); + } }; -} // namespace thread -} // namespace epilogue -} // namespace cutlass +} // namespace thread +} // namespace epilogue +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h index aeec5e5d0b5..031f3e70485 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,20 +18,23 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one + scaling factor per row, and one per column. 
- original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + original file: + 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h */ @@ -46,305 +49,312 @@ #include "cutlass/numeric_conversion.h" #include "common/quantization.h" -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ - -template -class EpilogueVisitorPerRowPerCol -{ -public: - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; - - using ScaleTileIterator = ScaleTileIterator_; - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; - - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; - - using AlphaScaleElementType = typename ScaleTileIterator::Element; - - using ElementCompute = ElementCompute_; - using AccumulatorFragment = Array; - using ComputeFragment = Array; - using OutputVector = Array; - - static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; - static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); - - /// Argument structure - struct Arguments - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - - // - // Methods - // - Arguments() - : batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_) - : elementwise(elementwise_) - , batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, - int64_t batch_stride_C_, int64_t 
batch_stride_D_) - : elementwise(elementwise_) - , batch_stride_alpha(batch_stride_alpha_) - , batch_stride_C(batch_stride_C_) - , batch_stride_D(batch_stride_D_) - { - } - }; - - struct Params - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Arguments const& args) - : elementwise(args.elementwise) - , batch_stride_alpha(args.batch_stride_alpha) - , batch_stride_C(args.batch_stride_C) - , batch_stride_D(args.batch_stride_D) - { - } - }; - - /// Shared storage - struct SharedStorage - { - }; - -private: - Params const& params_; - SharedStorage& shared_storage_; - MatrixCoord extent_; - MatrixCoord extent_real_; - ElementwiseFunctor elementwise_; - - bool const per_token_quant_; - bool const per_channel_quant_; - - AlphaScaleElementType* ptr_alpha_row_; - AlphaScaleElementType* ptr_alpha_col_; - ScaleTileIterator iterator_alpha_col_; - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; - - AlphaScaleElementType element_alpha_row_ = 1.0f; - AlphaScaleElementType element_alpha_col_ = 1.0f; - typename ScaleTileIterator::Fragment fragment_alpha_col_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; - - ElementAccumulator beta_; - - int column_offset_; - - MatrixCoord thread_offset_; - -public: - CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, common::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* 
ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) - : params_(params) - , shared_storage_(shared_storage) - , extent_(problem_size) - , elementwise_(params.elementwise) - , per_token_quant_(quant_option.hasPerTokenScaling()) - , per_channel_quant_(quant_option.hasPerChannelScaling()) - , ptr_alpha_row_(ptr_alpha_row) - , ptr_alpha_col_(ptr_alpha_col) - , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) - , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) - , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) - , extent_real_(problem_size_real) - { - beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) - { - iterator_C_.clear_mask(); - } - - if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) - { - element_alpha_col_ = *ptr_alpha_col_; - } - - if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) - { - element_alpha_row_ = *ptr_alpha_row_; - } +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using 
AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = + OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = + (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), + batch_stride_alpha(0), + batch_stride_C(0), + batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_alpha_, + int64_t batch_stride_C_, + int64_t batch_stride_D_) + : elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; + + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) {} + }; + + /// Shared storage + struct SharedStorage {}; + + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + 
AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol( + Params const& params, + SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + common::QuantMode quant_option, + AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, + 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, + 0)) + : params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + per_token_quant_(quant_option.hasPerTokenScaling()), + per_channel_quant_(quant_option.hasPerChannelScaling()), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, + ptr_alpha_col, + problem_size, + thread_idx, + threadblock_offset), + iterator_C_( + params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_( + params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) { + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr + : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); } - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) - { ///< Total number of split-K slices + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) { + element_alpha_col_ = *ptr_alpha_col_; } - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) - { - iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) { + element_alpha_row_ = *ptr_alpha_row_; } - - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() - { - if (per_channel_quant_) - { - iterator_alpha_col_.load(fragment_alpha_col_); - } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K + ///< partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * + params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator + /// slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); } - - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int 
step_idx) - { - fragment_D_.clear(); - fragment_C_.clear(); - - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) - { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; } - - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) - { - // load alpha_row in begin_step only when per token(row) scaling is used - if (per_token_quant_) - { - int thread_offset_row - = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); - - arch::global_load( - element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); - } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, + ptr_alpha_row_ + thread_offset_row, + thread_offset_row < extent_.row()); } - - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) - { - - NumericArrayConverter source_converter; - - ComputeFragment result = source_converter(accum); - if (per_channel_quant_) - { - ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; - result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); - } - else - { - result = per_token_scale_accumulator_(result, element_alpha_col_, 
element_alpha_row_); - } - - // Convert to the output - NumericArrayConverter output_converter; - OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const& accum) { + NumericArrayConverter + source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = + reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_( + result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_( + result, element_alpha_col_, element_alpha_row_); } - /// Called at the end of a row - CUTLASS_DEVICE - void end_row(int row_idx) {} - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) - { - - iterator_D_.store(fragment_D_); - ++iterator_D_; + // Convert to the output + NumericArrayConverter + output_converter; + OutputVector& output = + reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, + ComputeFragment const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); } - /// Called after all steps have been 
completed - CUTLASS_DEVICE - void end_epilogue() {} - -private: - CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_( - ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col[i] * scale_row); - } - - return result; + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, + AlphaScaleElementType const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); } - CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_( - ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col * scale_row); - } - - return result; - } + return result; + } }; -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h index 6f26d790170..2f89cb3f215 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,23 +18,26 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + original file: + 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h */ @@ -80,35 +83,45 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ +namespace cutlass { +namespace epilogue { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -namespace detail -{ - -/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. 
-template -struct DefaultIteratorsTensorOp -{ - using WarpTileIterator - = cutlass::epilogue::warp::TileIteratorTensorOpMixed; - - using SharedLoadIterator - = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; - - static int const kFragmentsPerIteration = 2; +namespace detail { + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared +/// memory bank conflicts. +template +struct DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOpMixed; + + using SharedLoadIterator = cutlass::epilogue::threadblock:: + SharedLoadIteratorMixed; + + static int const kFragmentsPerIteration = 2; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace detail +} // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -116,167 +129,159 @@ struct DefaultIteratorsTensorOp -class SharedLoadIteratorMixed -{ -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = int32_t; - - using Layout = layout::RowMajor; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; - - static int const kThreads = ThreadMap::kThreads; - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - /// Vector type used for SMEM loads - using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), - const_min(16, kAlignment)>; - - static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; - -private: - // - // Data members - // - - /// Byte-level pointer - LoadType 
const* pointers_[kLoadsPerAccess]; - - /// Stride along adjacent rows in units of LoadType - int stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - SharedLoadIteratorMixed(TensorRef ref, int thread_idx) - : stride_((ref.stride(0) / LoadType::kElements)) - { - - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); - - // Initialize pointers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] = reinterpret_cast(ref.data()); - - int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; - int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; - - col_idx += (bank_offset + i) % kLoadsPerAccess; - - pointers_[i] += thread_offset.row() * stride_ + col_idx; - } +template +class SharedLoadIteratorMixed { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = + ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = + Array; + + /// Memory access size + using AccessType = + AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, + ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = + AccessType::kElements / LoadType::kElements; + + private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + + public: + // + // Methods + // + + /// Constructor 
+ CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = + (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = + (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] += pointer_offset / LoadType::kElements; - } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] - += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; - } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * Shape::kRow * stride_ + + offset.column() * Shape::kColumn / LoadType::kElements; } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const - { - + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < 
ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) - { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) - { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) - { - - int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ - + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ - + pointer_offset / LoadType::kElements; + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup * stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; - int frag_row_idx - = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); - LoadType* frag_ptr = reinterpret_cast(&frag); + LoadType* frag_ptr = reinterpret_cast(&frag); - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) - { + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = + frag_row_idx * ThreadMap::Iterations::kColumn + column; - int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kLoadsPerAccess; ++v) - { - - int vector_idx - = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + int vector_idx = (column * ThreadMap::Delta::kColumn / + kElementsPerAccess * kLoadsPerAccess); - LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + 
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; - frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; - } - } - } + frag_ptr[frag_idx * kLoadsPerAccess + v] = + memory_pointer[vector_idx]; } + } } + } } + } - /// Loads a fragment - CUTLASS_DEVICE - void load(Fragment& frag) const - { - - load_with_pointer_offset(frag, 0); - } + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h b/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h index 6ed5b9b920c..ca8209d59c7 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h +++ b/custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,10 @@ /** * @file epilogue_helpers.h * - * This file includes types for the epilogues. The empty structs exist so we can signal to template - * code the type of epilogue we want to run, and let the underlying code specify the details such as - * element types, accumulator type and elements per vector access. + * This file includes types for the epilogues. 
The empty structs exist so we can + * signal to template code the type of epilogue we want to run, and let the + * underlying code specify the details such as element types, accumulator type + * and elements per vector access. * */ @@ -33,107 +34,161 @@ // #include "cutlass/epilogue/fusion/operations.hpp" -namespace cutlass_extensions -{ - -struct EpilogueOpBiasSilu -{ -}; - -struct EpilogueOpBiasReLU -{ -}; - -struct EpilogueOpBiasFtGelu -{ -}; - -struct EpilogueOpBias -{ -}; - -struct EpilogueOpDefaultSilu -{ -}; - -struct EpilogueOpDefaultReLU -{ -}; - -struct EpilogueOpDefaultFtGelu -{ -}; - -struct EpilogueOpDefault -{ -}; - -template -struct Epilogue -{ - static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); -}; - -constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; +namespace cutlass_extensions { + +struct EpilogueOpBiasSilu {}; + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpDefaultSilu {}; + +struct EpilogueOpDefaultReLU {}; + +struct EpilogueOpDefaultFtGelu {}; + +struct EpilogueOpDefault {}; + +template +struct Epilogue { + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); +}; + +constexpr auto BiasScaleMode = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = 
cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + BiasScaleMode, + cutlass::FloatRoundStyle::round_to_nearest, + true>; +}; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombination; }; constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; -}; - -} // namespace cutlass_extensions +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + DefaultScaleMode, + cutlass::FloatRoundStyle::round_to_nearest, + true>; +}; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp index d327eb18ae7..5d5b99b5b27 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder.hpp @@ -21,7 +21,6 @@ #include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" - 
///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { @@ -29,19 +28,17 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// // GMMA_TMA_WS_SS (BlockScaled Builders) -template < - class ElementA, - class GmemLayoutATag, - int AlignmentA, - class ElementB, - class GmemLayoutBTag, - int AlignmentB, - class ElementAccumulator, - class TileShape_MNK, - class ClusterShape_MNK, - class StageCountType, - int ScaleGranularityM -> +template struct CollectiveBuilder< arch::Sm90, arch::OpClassTensorOp, @@ -55,82 +52,124 @@ struct CollectiveBuilder< TileShape_MNK, ClusterShape_MNK, StageCountType, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum, - cute::enable_if_t< - not detail::is_use_rmem_A()> -> { - using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>, + cute::enable_if_t()>> { + using KernelScheduleType = + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>; static_assert(is_static::value); static_assert(is_static::value); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, + "Unsupported Toolkit for SM90 Collective Builder\n"); #endif - static_assert(detail::is_aligned(), + static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); - static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsArrayOfPointersGemm = + (cute::is_any_of_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); static_assert((!IsFP8Input || !IsArrayOfPointersGemm), - "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with 
FP8 Blocked Scaled version right now."); + "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is " + "only compatible with FP8 Blocked Scaled version right now."); // For fp32 types, map to tf32 MMA value type - using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; - using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; - - static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - static constexpr bool IsCooperative = cute::is_any_of_v>; + using ElementAMma = cute:: + conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute:: + conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = + detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = + detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsCooperative = cute::is_any_of_v< + KernelScheduleType, + KernelTmaWarpSpecializedCooperative, + KernelPtrArrayTmaWarpSpecializedCooperative, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum< + ScaleGranularityM>>; using AtomLayoutMNK = cute::conditional_t>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< - ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector< - GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector< - GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr size_t TensorMapStorage 
= IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + Layout>, + Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom( + shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom( + shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = + decltype(detail::ss_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>( + TileShape_MNK{}))>()); + using SmemLayoutAtomB = + decltype(detail::ss_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>( + TileShape_MNK{}))>()); + + static constexpr size_t TensorMapStorage = + IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ + : 0; static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + static constexpr int PipelineStages = + detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8< + PipelineStages, + ClusterShape_MNK, + KernelScheduleType, + ScaleGranularityM>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; - using CollectiveOp = CollectiveMma< - DispatchPolicy, - TileShape_MNK, - ElementA, - TagToStrideA_t, - ElementB, - TagToStrideB_t, - TiledMma, - GmemTiledCopyA, - SmemLayoutAtomA, - SmemCopyAtomA, - cute::identity, - GmemTiledCopyB, - SmemLayoutAtomB, - SmemCopyAtomB, - cute::identity - >; + using CollectiveOp = CollectiveMma, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity>; }; - 
///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm::collective diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp index 227aee50fe1..e3f7e03ef9c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_builder_gated.hpp @@ -39,12 +39,23 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template class Activation, - bool SwapAB = false, class Enable = void> +template + class Activation, + bool SwapAB = false, + class Enable = void> struct CollectiveBuilderGated { static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); @@ -52,7 +63,7 @@ struct CollectiveBuilderGated { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp index 56849ee56f9..d2d06fed312 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/collective_mma_gated.hpp @@ -39,12 +39,23 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template class Activation, + template + class Activation, bool SwapAB = false> struct CollectiveMmaGated { static_assert(cutlass::detail::dependent_false, @@ -53,7 
+64,7 @@ struct CollectiveMmaGated { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp index 0a530e5c141..b22492fd9e5 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/fp8_accumulation.hpp @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp +// adapted from: +// https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -34,14 +35,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ @@ -61,24 +63,26 @@ namespace cutlass::gemm::collective { -template < - class EngineAccum, - class LayoutAccum> +template struct GmmaFP8AccumulationWithScale { using TensorAccum = cute::Tensor; using ElementAccumulator = typename EngineAccum::value_type; - static_assert(is_static::value, "Accumulator Layout should be static"); - static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + static_assert(is_static::value, + "Accumulator Layout should be static"); + static_assert(is_rmem::value, + "Accumulator tensor must be rmem resident."); -private: - TensorAccum& accum_; + private: + TensorAccum &accum_; TensorAccum accum_temp_; - uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. - uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop - uint32_t mma_count_; // current executed MMAs - uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs + // after which accum should be promoted. + uint32_t + mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. // promote or `add` the partial accumulators to main accumulator (FADD). CUTLASS_DEVICE @@ -90,18 +94,20 @@ struct GmmaFP8AccumulationWithScale { } } - // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_core(const cute::Tensor &scale) { + // `multiply` scale the partial accumulators and `add` to main accumulator + // (FFMA). 
+ template + CUTLASS_DEVICE void scale_core( + const cute::Tensor &scale) { using TensorScale = cute::Tensor; - static_assert(is_static::value, "Scale Layout should be static"); - static_assert(is_rmem::value , "Scale tensor must be rmem resident."); + static_assert(is_static::value, + "Scale Layout should be static"); + static_assert(is_rmem::value, + "Scale tensor must be rmem resident."); - static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape."); + static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), + "Accumulator and scale must have same shape."); warpgroup_wait<0>(); CUTLASS_PRAGMA_UNROLL @@ -110,18 +116,16 @@ struct GmmaFP8AccumulationWithScale { } } -public: + public: CUTLASS_DEVICE - GmmaFP8AccumulationWithScale( - TensorAccum &accum, - uint32_t accum_promotion_interval, - uint32_t mma_count_per_mainloop_iteration) + GmmaFP8AccumulationWithScale(TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) : accum_(accum), accum_promotion_interval_(accum_promotion_interval), mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), mma_count_(0), - reset_accum_flag_(0) - { + reset_accum_flag_(0) { accum_temp_ = cute::make_fragment_like(accum); } @@ -130,32 +134,31 @@ struct GmmaFP8AccumulationWithScale { // CUTLASS_DEVICE - TensorAccum& operator()() { - return accum_temp_; - } + TensorAccum &operator()() { return accum_temp_; } /// prepare the MMA accumulators when initialization or zeroing is required. CUTLASS_DEVICE - bool prepare_if_needed() { - return reset_accum_flag_; - } + bool prepare_if_needed() { return reset_accum_flag_; } // // Methods (for FADD version) // - /// promote (add) the results from the MMA accumulators to main accumulator if needed. + /// promote (add) the results from the MMA accumulators to main accumulator if + /// needed. 
CUTLASS_DEVICE void promote_if_needed() { mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + reset_accum_flag_ = + __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { promote_core(); mma_count_ = 0; } } - /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + /// promote (add) the residue results from the MMA accumulators to main + /// accumulator if needed. CUTLASS_DEVICE void promote_residue_if_needed() { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { @@ -167,30 +170,29 @@ struct GmmaFP8AccumulationWithScale { // Methods (for FFMA version) // - /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_if_needed(const cute::Tensor &scale) { + /// scale (multiply_add) the results from the MMA accumulators to main + /// accumulator if needed. + template + CUTLASS_DEVICE void scale_if_needed( + const cute::Tensor &scale) { mma_count_ += mma_count_per_mainloop_iteration_; - reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + reset_accum_flag_ = + __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); if (reset_accum_flag_) { scale_core(scale); mma_count_ = 0; } } - /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. - template < - class EngineScale, - class LayoutScale> - CUTLASS_DEVICE - void scale_residue_if_needed(const cute::Tensor &scale) { + /// scale (multiply_add) the residue results from the MMA accumulators to main + /// accumulator if needed. 
+ template + CUTLASS_DEVICE void scale_residue_if_needed( + const cute::Tensor &scale) { if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { scale_core(scale); } } }; -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp index 8ff14a2a49b..f335ec2d399 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp @@ -53,18 +53,43 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> + template + class Activation_, + bool SwapAB_> struct CollectiveMmaGated< MainloopSm90TmaGmmaWarpSpecialized, - TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, - GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_, + Activation_, SwapAB_> { static constexpr bool isGated = true; static constexpr bool SwapAB = SwapAB_; @@ -93,7 +118,8 @@ struct CollectiveMmaGated< using Activation = Activation_; using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; using MainloopPipeline = cutlass::PipelineTmaAsync; @@ -118,16 +144,20 @@ struct CollectiveMmaGated< // Tile along modes in a way that maximizes the TMA box size. 
using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<0>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<1>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutAux = cute::conditional_t; static_assert(DispatchPolicy::Stages >= 2, @@ -151,10 +181,12 @@ struct CollectiveMmaGated< static constexpr bool ConvertF32toTF32A = cute::is_same_v; static constexpr bool ConvertF32toTF32B = cute::is_same_v; using InternalElementA = - cute::conditional_t>>; using InternalElementB = - cute::conditional_t>>; using InternalElementAux = cute::conditional_t; @@ -195,18 +227,22 @@ struct CollectiveMmaGated< using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}), + repeat_like(StrideA{}, int32_t(0)), + StrideA{}), SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + size<1>( + ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}), + repeat_like(StrideB{}, int32_t(0)), + StrideB{}), SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + size<0>( 
+ ClusterShape{}))); // mcast along M mode for this N load, if any using TMA_Aux = cute::conditional_t; TMA_A tma_load_a; TMA_B tma_load_b; @@ -220,9 +256,10 @@ struct CollectiveMmaGated< // template - static constexpr Params - to_underlying_arguments(ProblemShape const &problem_shape, - Arguments const &args, void *workspace) { + static constexpr Params to_underlying_arguments( + ProblemShape const &problem_shape, + Arguments const &args, + void *workspace) { (void)workspace; // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is @@ -238,36 +275,44 @@ struct CollectiveMmaGated< Tensor tensor_b = make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_A tma_load_a = make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any typename Params::TMA_B tma_load_b = make_tma_copy( - GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any if constexpr (SwapAB) { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>( - ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, 
tma_load_aux, args.scale_d0, - args.scale_d1}; + ClusterShape{})); // mcast along N mode for this M load, if any + return { + tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; } else { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>( - ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, - args.scale_d1}; + ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; } } @@ -293,8 +338,9 @@ struct CollectiveMmaGated< cute::make_shape(N, K, L), StrideB{}); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the " - "minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the " + "minimum alignment requirements for TMA.\n"); } return implementable; } @@ -342,49 +388,64 @@ struct CollectiveMmaGated< // TMA requires special handling of strides to deal with coord codomain // mapping Represent the full tensors -- get these from TMA Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) + make_shape(M, K, L)); // (m,k,l) Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) + make_shape(N, K, L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = 
local_tile(mA_mkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) if constexpr (SwapAB) { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } else { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective - template - CUTLASS_DEVICE void - load(Params const &mainloop_params, MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const &load_inputs, - BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage &shared_tensors) { + CUTLASS_DEVICE void load( + Params const &mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const &load_inputs, + BlockCoord const &blk_coord, + KTileIterator k_tile_iter, + int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage &shared_tensors) { int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + 
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -412,17 +473,17 @@ struct CollectiveMmaGated< cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) Tensor tAuxgAux = block_tma_aux.partition_S(gAux); Tensor tAuxsAux = block_tma_aux.partition_D(sAux); @@ -435,18 +496,18 @@ struct CollectiveMmaGated< // Maps the tile -> block, value if constexpr (cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, - n, Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout( + cluster_local_block_id.x, n, Int<0>{})); } } if constexpr 
(cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) << block_layout( m, cluster_local_block_id.y, Int<0>{})); @@ -475,11 +536,14 @@ struct CollectiveMmaGated< int write_stage = smem_pipe_write.index(); copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), - tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); + tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), - tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); + tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), - tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage)); + tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); ++k_tile_iter; // Advance smem_pipe_write @@ -508,10 +572,14 @@ struct CollectiveMmaGated< /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective template - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, - FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx, - TensorStorage &shared_tensors, Params const &mainloop_params) { + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC &accum0, + FrgTensorC &accum1, + int k_tile_count, + int thread_idx, + TensorStorage &shared_tensors, + Params const &mainloop_params) { static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, @@ -528,9 +596,9 @@ struct CollectiveMmaGated< "smem sourced instructions."); Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - 
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -541,12 +609,12 @@ struct CollectiveMmaGated< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) auto tCsAux = [&]() -> auto { if constexpr (SwapAB) { @@ -554,34 +622,36 @@ struct CollectiveMmaGated< } else { return thread_mma.partition_B(sAux); } - }(); + } + (); auto tCrAux = [&]() -> auto { if constexpr (SwapAB) { return thread_mma.make_fragment_A(tCsAux); } else { return thread_mma.make_fragment_B(tCsAux); } - }(); + } + (); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE if constexpr (SwapAB) { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - 
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE } else { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == - size<2>(sAux)); // PIPE + size<2>(sAux)); // PIPE // // PIPELINED MAIN LOOP @@ -613,14 +683,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum0); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum0); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum1); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accum1); } 
tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -654,14 +730,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum0); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum0); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accum1); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accum1); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accum1); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -699,8 +781,9 @@ struct CollectiveMmaGated< warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, - // done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it ++smem_pipe_release; } } @@ -708,6 +791,6 @@ struct CollectiveMmaGated< ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp index 76ffbdb2e62..c34ad242e25 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp +++ 
b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp @@ -55,18 +55,43 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> + template + class Activation_, + bool SwapAB_> struct CollectiveMmaGated< MainloopSm90TmaGmmaWarpSpecializedFP8, - TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, - GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_, + Activation_, SwapAB_> { static constexpr bool isGated = true; static constexpr bool SwapAB = SwapAB_; @@ -74,9 +99,9 @@ struct CollectiveMmaGated< // // Type Aliases // - using DispatchPolicy = - MainloopSm90TmaGmmaWarpSpecializedFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -96,7 +121,8 @@ struct CollectiveMmaGated< using Activation = Activation_; using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; using MainloopPipeline = cutlass::PipelineTmaAsync; @@ -121,16 +147,20 @@ struct CollectiveMmaGated< // Tile along modes in a way that maximizes the TMA box size. 
using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<0>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), + make_shape(shape<1>(TileShape{}), + shape<2>(TileShape{}), Int{}), conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutAux = cute::conditional_t; static_assert(DispatchPolicy::Stages >= 2, @@ -184,18 +214,22 @@ struct CollectiveMmaGated< using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}), + repeat_like(StrideA{}, int32_t(0)), + StrideA{}), SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + size<1>( + ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}), + repeat_like(StrideB{}, int32_t(0)), + StrideB{}), SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + size<0>( + ClusterShape{}))); // mcast along M mode for this N load, if any using TMA_Aux = cute::conditional_t; TMA_A tma_load_a; TMA_B tma_load_b; @@ -210,9 +244,10 @@ struct CollectiveMmaGated< // template - static constexpr Params - to_underlying_arguments(ProblemShape const &problem_shape, - Arguments const &args, void *workspace) { + static 
constexpr Params to_underlying_arguments( + ProblemShape const &problem_shape, + Arguments const &args, + void *workspace) { (void)workspace; // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is @@ -228,35 +263,51 @@ struct CollectiveMmaGated< Tensor tensor_b = make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_A tma_load_a = make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any typename Params::TMA_B tma_load_b = make_tma_copy( - GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any if constexpr (SwapAB) { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}), + GmemTiledCopyA{}, + tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>( - ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, tma_load_aux, - args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, + tma_load_b, + tma_load_aux, + args.scale_d0, + args.scale_d1, + args.mma_promotion_interval}; } else { auto ptr_Aux = reinterpret_cast(args.ptr_B1); Tensor tensor_aux = 
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); typename Params::TMA_Aux tma_load_aux = make_tma_copy( - GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}), + GmemTiledCopyB{}, + tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>( - ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, - args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, + tma_load_b, + tma_load_aux, + args.scale_d0, + args.scale_d1, + args.mma_promotion_interval}; } } @@ -285,8 +336,9 @@ struct CollectiveMmaGated< implementable = implementable && (args.mma_promotion_interval % 4 == 0); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the " - "minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the " + "minimum alignment requirements for TMA.\n"); } return implementable; } @@ -333,49 +385,64 @@ struct CollectiveMmaGated< // TMA requires special handling of strides to deal with coord codomain // mapping Represent the full tensors -- get these from TMA Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) + make_shape(M, K, L)); // (m,k,l) Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) + make_shape(N, K, L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, + TileShape{}, + make_coord(_, _, _), + 
Step{}); // (BLK_N,BLK_K,n,k,l) if constexpr (SwapAB) { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } else { Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor( - make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), - Step{}); // (BLK_N,BLK_K,n,k,l) + make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl = local_tile(mAux_xkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); } } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective - template - CUTLASS_DEVICE void - load(Params const &mainloop_params, MainloopPipeline pipeline, - PipelineState smem_pipe_write, - cute::tuple const &load_inputs, - BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count, - int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage &shared_tensors) { + CUTLASS_DEVICE void load( + Params const &mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const &load_inputs, + BlockCoord const &blk_coord, + KTileIterator k_tile_iter, + int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage &shared_tensors) { int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // 
(BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -403,17 +470,17 @@ struct CollectiveMmaGated< // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) Tensor tAuxgAux = block_tma_aux.partition_S(gAux); Tensor tAuxsAux = block_tma_aux.partition_D(sAux); @@ -426,18 +493,18 @@ struct CollectiveMmaGated< // Maps the tile -> block, value if constexpr (cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, - n, Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout( + cluster_local_block_id.x, n, Int<0>{})); } } if constexpr (cute::is_same_v) { auto block_layout = - Layout{}; // (m,n) -> - // block_id + Layout{}; // (m,n) -> + // block_id for (int m = 0; m < size<0>(block_layout); ++m) { mcast_mask_b |= (uint16_t(1) 
<< block_layout( m, cluster_local_block_id.y, Int<0>{})); @@ -466,11 +533,14 @@ struct CollectiveMmaGated< int write_stage = smem_pipe_write.index(); copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), - tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage)); + tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), - tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage)); + tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), - tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage)); + tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); ++k_tile_iter; // Advance smem_pipe_write @@ -499,11 +569,14 @@ struct CollectiveMmaGated< /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective template - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, - FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx, - TensorStorage &shared_tensors, Params const &mainloop_params) { - + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC &accum0, + FrgTensorC &accum1, + int k_tile_count, + int thread_idx, + TensorStorage &shared_tensors, + Params const &mainloop_params) { static_assert(is_rmem::value, "C tensor must be rmem resident."); static_assert(cute::rank(SmemLayoutA{}) == 3, @@ -518,9 +591,9 @@ struct CollectiveMmaGated< "smem sourced instructions."); Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), - SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), - SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); @@ -531,12 +604,12 @@ struct 
CollectiveMmaGated< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) auto tCsAux = [&]() -> auto { if constexpr (SwapAB) { @@ -544,34 +617,36 @@ struct CollectiveMmaGated< } else { return thread_mma.partition_B(sAux); } - }(); + } + (); auto tCrAux = [&]() -> auto { if constexpr (SwapAB) { return thread_mma.make_fragment_A(tCsAux); } else { return thread_mma.make_fragment_B(tCsAux); } - }(); + } + (); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE if constexpr (SwapAB) { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + 
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE } else { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == - size<2>(sAux)); // PIPE + size<2>(sAux)); // PIPE // // PIPELINED MAIN LOOP @@ -611,14 +686,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation0()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation0()); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation1()); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accumulation1()); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -659,14 +740,20 @@ struct CollectiveMmaGated< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; 
k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation0()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation0()); if constexpr (SwapAB) { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), - tCrB(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrAux(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation1()); } else { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), - tCrAux(_, _, k_block, read_stage), accumulation1()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrAux(_, _, k_block, read_stage), + accumulation1()); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } @@ -681,8 +768,9 @@ struct CollectiveMmaGated< accumulation0.promote_if_needed(); accumulation1.promote_if_needed(); - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, - // done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it // Advance smem_pipe_read and smem_pipe_release ++smem_pipe_read; @@ -710,8 +798,9 @@ struct CollectiveMmaGated< warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, - // done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, + // done _computing_ on it ++smem_pipe_release; } } @@ -719,6 +808,6 @@ struct CollectiveMmaGated< ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index be1f9747e77..837d65ae54d 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.2/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp -// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +// Adapt from +// https://github.com/vllm-project/vllm/blob/v0.7.2/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +// Adapted (Heavily) from: +// https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -35,14 +37,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ @@ -73,46 +76,52 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop -template < - int Stages, - class ClusterShape, - class KernelSchedule, - int ScaleGranularityM_, - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ +template +struct CollectiveMma, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8; + using DispatchPolicy = + MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8< + Stages, + ClusterShape, + KernelSchedule, + ScaleGranularityM_>; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -139,55 +148,91 @@ struct CollectiveMma< // Two threads per CTA are producers (1 for operand tile and 32 for scales) static constexpr int NumProducerThreadEvents = 33; - static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? 
size<0>(TileShape{}) : ScaleGranularityM_; - static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + static constexpr int ScaleGranularityM = + ScaleGranularityM_ == 0 ? 
size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleMsPerTile = + size<0>(TileShape{}) / ScaleGranularityM; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, + "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert( + (size<0>(TileShape{}) % ScaleGranularityM) == 0, + "FP8 scaling granularity must evenly divide tile shape along M."); // Tile along modes in a way that maximizes the TMA box size. 
using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + make_shape(shape<0>(TileShape{}), + shape<2>(TileShape{}), + Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + make_shape(shape<1>(TileShape{}), + shape<2>(TileShape{}), + Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), + Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); // Block scaling gmem-to-smem copy atom - using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; - using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomA = + Copy_Atom, + ElementBlockScale>; + using SmemBlockScalingCopyAtomB = + Copy_Atom, + ElementBlockScale>; // Block scaling smem layout - using SmemLayoutScaleA = Layout, Int>>; - using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. 
- - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert(cute::is_same_v, - "ElementAccumulator and ElementBlockScale should be same datatype"); - - struct SharedStorage - { + using SmemLayoutScaleA = + Layout, Int>>; + using SmemLayoutScaleB = Layout>, + Stride<_1>>; // `ScaleNsPerTile` is always 1. + + static_assert(DispatchPolicy::Stages >= 2, + "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for " + "this mainloop."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || + cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage { struct TensorStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_A; // mxk - cute::array_aligned> smem_B; // nxk - cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k - cute::array_aligned> smem_scale_B; // 1xk + cute::array_aligned> + smem_A; // mxk + cute::array_aligned> + smem_B; // nxk + cute::array_aligned> + smem_scale_A; // ScaleMsPerTile x k + cute::array_aligned> + smem_scale_B; // 1xk } tensors; using PipelineStorage = typename MainloopPipeline::SharedStorage; @@ -211,15 +256,19 @@ struct CollectiveMma< // Assumption: StrideA is congruent with Problem_MK using TMA_A = 
decltype(make_tma_copy_A_sm90( GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_,_,0), + make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), + StrideA{}), + SmemLayoutA{}(_, _, 0), TileShape{}, ClusterShape{})); // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy_B_sm90( GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_,_,0), + make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), + StrideB{}), + SmemLayoutB{}(_, _, 0), TileShape{}, ClusterShape{})); TMA_A tma_load_a; @@ -237,103 +286,128 @@ struct CollectiveMma< // template - static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + (void)workspace; - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is + // only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; auto ptr_A = reinterpret_cast(args.ptr_A); auto ptr_B = reinterpret_cast(args.ptr_B); - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( - GmemTiledCopyA{}, - tensor_a, - SmemLayoutA{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); - typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( - GmemTiledCopyB{}, - tensor_b, - SmemLayoutB{}(_,_,cute::Int<0>{}), - TileShape{}, - ClusterShape{}); + Tensor tensor_a = + make_tensor(ptr_A, 
make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = + make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = + make_tma_copy_A_sm90(GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = + make_tma_copy_B_sm90(GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), + TileShape{}, + ClusterShape{}); uint32_t transaction_bytes_mk = TmaTransactionBytesMK; uint32_t transaction_bytes_nk = TmaTransactionBytesNK; uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; - return { - tma_load_a, - tma_load_b, - transaction_bytes, - transaction_bytes_mk, - transaction_bytes_nk, - args.ptr_scale_A, - args.ptr_scale_B - }; + return {tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.ptr_scale_A, + args.ptr_scale_B}; } - template - static bool - can_implement( - ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { + template + static bool can_implement(ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + constexpr int min_tma_aligned_elements_A = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(M, K, L), StrideA{}); + 
constexpr int min_tma_aligned_elements_B = + tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = + implementable && + cutlass::detail::check_alignment( + cute::make_shape(N, K, L), StrideB{}); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " + "requirements for TMA.\n"); } return implementable; } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytesMK = - cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytesNK = - cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); - static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + static constexpr uint32_t TmaTransactionBytesMK = cutlass::bits_to_bytes( + size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * + static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = cutlass::bits_to_bytes( + size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * + static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = + TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best + /// performance CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor( + 
mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_b.get_tma_descriptor()); } /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// Returns a tuple of tensors. The collective and the kernel layer have the + /// contract Returned tuple must contain at least two elements, with the first + /// two elements being: gA_mkl - The tma tensor, A after a local tile so it + /// has shape (BLK_M,BLK_K,m,k,l) gB_nkl - The tma tensor, B after a local + /// tile so it has shape (BLK_N,BLK_K,n,k,l) template - CUTLASS_DEVICE auto - load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params) const { using X = Underscore; // Separate out problem shape for convenience - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + // TMA requires special handling of strides to deal with coord codomain + // mapping Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor( + make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor( + make_shape(N, K, L)); // (n,k,l) // Make tiled views, defer the slice - 
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, + TileShape{}, + make_coord(_, _, _), + Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, + TileShape{}, + make_coord(_, _, _), + Step{}); // (BLK_N,BLK_K,n,k,l) constexpr auto scales_m = Int{}; auto tM = get<2>(gA_mkl.shape()); @@ -341,84 +415,103 @@ struct CollectiveMma< auto tK = get<3>(gA_mkl.shape()); // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) - auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); - auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) + auto scaleA_shape = + make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l) auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{}); - // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and - // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. - Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l) - Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the + // `m` host and gScaleA_mkl and gScaleB_nkl in `g` global memory are same as + // mScaleA_mkl and mScaleB_nkl. 
+ Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), + scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), + scaleB_layout); // (n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective - template < - class TensorA, class TensorB, - class TensorScaleA, class TensorScaleB, - class KTileIterator, class BlockCoord - > - CUTLASS_DEVICE void - load( + template + CUTLASS_DEVICE void load( Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& load_inputs, + cute::tuple const& + load_inputs, BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, + KTileIterator k_tile_iter, + int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { int lane_predicate = cute::elect_one_sync(); // Blockscaling: Tma loads for load_input and CpAsync for load_scale - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = + make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = + make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + SmemLayoutScaleB{}); // (k) // 
// Prepare the TMA loads for A and B // constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, + block_rank_in_cluster / cluster_shape_x}; Tensor gA_mkl = get<0>(load_inputs); Tensor gB_nkl = get<1>(load_inputs); - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_a = + mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = + mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) - // Block scaling: load_scale has scaling tensors in global memory which are not tiled + // Block scaling: load_scale has scaling tensors in global memory which are + // not tiled Tensor mScaleA_mkl = get<2>(load_inputs); Tensor mScaleB_nkl = get<3>(load_inputs); auto scales_m = get<0>(mScaleA_mkl.shape()); Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape()); - Tensor gScaleA = local_tile( - mScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) - Tensor cScaleA = local_tile( - cScaleA_mkl, make_tile(Int{}), - make_coord(m_coord,_,l_coord)); - Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) + Tensor gScaleA = + local_tile(mScaleA_mkl, + make_tile(Int{}), + make_coord(m_coord, _, l_coord)); // (ScaleMsPerTile,k,1) + Tensor cScaleA = local_tile(cScaleA_mkl, + make_tile(Int{}), + 
make_coord(m_coord, _, l_coord)); + Tensor gScaleB = mScaleB_nkl(n_coord, _, l_coord); // (1,k,1) // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128 - TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, - Layout>{}, Layout>{}); // (1,1,1) + TiledCopy scale_copy_a = + make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, + Layout>{}); // (1,1,1) TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, - Layout>{}, Layout>{}); // (1,1,1) + Layout>{}, + Layout>{}); // (1,1,1) ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); @@ -430,11 +523,11 @@ struct CollectiveMma< Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) uint16_t mcast_mask_a = 0; uint16_t mcast_mask_b = 0; @@ -442,30 +535,34 @@ struct CollectiveMma< // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors // Maps the tile -> block, value if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id + auto block_layout = + Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + mcast_mask_a |= (uint16_t(1) << block_layout( + cluster_local_block_id.x, n, Int<0>{})); } } if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> 
block_id + auto block_layout = + Layout{}; // (m,n) -> block_id for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + mcast_mask_b |= (uint16_t(1) << block_layout( + m, cluster_local_block_id.y, Int<0>{})); } } // Allocate predicate tensors for a_scales (since we can't guarantee that // all scales are valid, since we could have a partial tiles along M) - Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_,_,0))); - #pragma unroll + Tensor tApA_ScaleA = make_tensor(shape(tAsA_ScaleA(_, _, 0))); +#pragma unroll for (int i = 0; i < size(tApA_ScaleA); ++i) { tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m; } // Mainloop CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { + for (; k_tile_count > 0; --k_tile_count) { // LOCK smem_pipe_write for _writing_ pipeline.producer_acquire(smem_pipe_write); @@ -477,13 +574,25 @@ struct CollectiveMma< BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); // Copy operands A and B from global memory to shared memory - if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + if (lane_predicate) + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), + tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + if (lane_predicate) + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), + tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); // Copy scale tensors from global memory to shared memory - copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); - copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); - pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); 
+ copy_if(scale_copy_a, + tApA_ScaleA, + tAgA_ScaleA(_, _, *k_tile_iter), + tAsA_ScaleA(_, _, write_stage)); + copy(scale_copy_b, + tBgB_ScaleB(_, *k_tile_iter), + tBsB_ScaleB(_, write_stage)); + pipeline.producer_commit(smem_pipe_write, + cutlass::arch::cpasync_barrier_arrive_noinc); ++k_tile_iter; @@ -493,10 +602,8 @@ struct CollectiveMma< } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void - load_tail( - MainloopPipeline pipeline, - PipelineState smem_pipe_write) { + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_write) { int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits @@ -513,37 +620,46 @@ struct CollectiveMma< /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective - template < - class FrgTensorC - > - CUTLASS_DEVICE void - mma(MainloopPipeline pipeline, - PipelineState smem_pipe_read, - FrgTensorC& accum, - int k_tile_count, - int thread_idx, - TensorStorage& shared_tensors, - Params const& mainloop_params) { - - - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, + "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, + "Smem layout must be rank 3."); static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); 
static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + "SM90 GMMA mainloops cannot have a non-void copy atom for " + "smem sourced instructions."); - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), + SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // Block scaling - Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), - Layout< - Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, - Stride, _0, Int> - >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + Tensor sScaleAViewAsC = make_tensor( + cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, + cute::tuple_element_t<1, TileShape>, + Int>, + Stride< + Stride<_0, _1>, + _0, + Int>>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) + Tensor sScaleB = + make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + SmemLayoutScaleB{}); // (k) // // Define C accumulators and A/B partitioning @@ -551,52 +667,68 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and - stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, - "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + static_assert( + stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) 
== 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be " + "NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, - Int{}); + Layout warp_group_thread_layout = + make_layout(Int{}, Int{}); - int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + int warp_group_idx = + __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + auto thread_mma = + tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); - Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + Tensor tCsScaleAViewAsC = + tiled_mma.get_slice(thread_idx) + .partition_C( + sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above + // is correct when partitioning A and B, but + // it is not correct when partitioning C. 
- Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE // // PIPELINED MAIN LOOP // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), - "ERROR : Incorrect number of MMAs in flight"); + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); // We release buffers to producer warps(dma load) with some mmas in flight PipelineState smem_pipe_release = smem_pipe_read; // Per block scale values for operand A and B - using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. 
- using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above - - Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) + using RegLayoutScaleAViewAsC = decltype(make_layout_like( + tCsScaleAViewAsC(_, _, _, 0) + .layout())); // `make_layout_like` makes a compact layout. + using RegLayoutScaleAEssential = decltype(filter_zeros( + RegLayoutScaleAViewAsC{}.stride(), + RegLayoutScaleAViewAsC{} + .shape())); // an interface to traverse the underlying storage for + // the compact layout mentioned above + + Tensor tCrScaleAViewAsC = make_tensor( + RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) ElementBlockScale scale_b; // Prologue GMMAs @@ -604,12 +736,16 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA)); + GmmaFP8AccumulationWithScale accumulation( + accum, + size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), + size<2>(tCrA)); warpgroup_fence_operand(accumulation()); CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; + --k_tile_prologue) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -623,11 +759,16 @@ struct CollectiveMma< scale_b = sScaleB[read_stage]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, 
read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC( + _, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); } if constexpr (ScaleMsPerTile == 1) { static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. + tCrScaleAViewAsC.data()[0] = + __shfl_sync(0xffffffff, + tCrScaleAViewAsC.data()[0] * scale_b, + 0); // `tCrScaleAViewAsC.data()[0]` are all same in a + // warp group when `ScaleMsPerTile == 1`. } else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { @@ -640,7 +781,10 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation()); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } warpgroup_commit_batch(); @@ -656,9 +800,9 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + for (; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_read until its data are available (phase bit flips + // from rdPhaseBit value) auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -668,15 +812,21 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); - // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) + // Load per block scale values from shared memory to registers (at most + // twice 
per block along M and exactly once per block along N) scale_b = sScaleB[read_stage]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC( + _, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); } if constexpr (ScaleMsPerTile == 1) { static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. + tCrScaleAViewAsC.data()[0] = + __shfl_sync(0xffffffff, + tCrScaleAViewAsC.data()[0] * scale_b, + 0); // `tCrScaleAViewAsC.data()[0]` are all same in a + // warp group when `ScaleMsPerTile == 1`. } else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { @@ -694,19 +844,25 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + cute::gemm(tiled_mma, + tCrA(_, _, k_block, read_stage), + tCrB(_, _, k_block, read_stage), + accumulation()); tiled_mma.accumulate_ = GMMA::ScaleOut::One; } warpgroup_commit_batch(); - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to + /// ensure smem_pipe_write is consumed warpgroup_wait(); warpgroup_fence_operand(accumulation()); // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` accumulation.scale_if_needed(tCrScaleAViewAsC); - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on 
+ // it // Advance smem_pipe_read and smem_pipe_release ++smem_pipe_read; @@ -719,8 +875,9 @@ struct CollectiveMma< } /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void - mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, + PipelineState smem_pipe_release, + int k_tile_count) { // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); k_tile_count -= prologue_mma_count; @@ -731,7 +888,9 @@ struct CollectiveMma< warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release( + smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on + // it ++smem_pipe_release; } } @@ -739,6 +898,6 @@ struct CollectiveMma< ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::collective +} // namespace cutlass::gemm::collective ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h index 2edd5a228b4..ae5ff6decc7 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,20 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. + \brief The universal GEMM accommodates serial reductions, parallel reductions, + batched strided, and batched array variants. */ #pragma once @@ -54,385 +55,363 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace device -{ +namespace cutlass { +namespace gemm { +namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// /* - This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) - It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs - and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + This is the device layer from CUTLASS 2.10 (SHA - + cc85b64cf676c45f98a17e3a47c0aafcf817f088) It is replicated here since we + needed to duplicate kernel level APIs for mixed dtype GEMMs and SmoothQuant. + The newer device layer is not compatible with these older kernel level APIs. - Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support - that feature at the moment. 
+ Note: While CUTLASS 3.x supports stream-k, none of the kernels in the + extensions folder support that feature at the moment. */ template -class GemmUniversalBaseCompat -{ -public: - using GemmKernel = GemmKernel_; - using ThreadblockShape = typename GemmKernel::Mma::Shape; - - using ElementA = typename GemmKernel::ElementA; - using LayoutA = typename GemmKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = GemmKernel::kTransformA; - - using ElementB = typename GemmKernel::ElementB; - using LayoutB = typename GemmKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = GemmKernel::kTransformB; - - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename GemmKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using Operator = typename GemmKernel::Operator; - - /// Argument structure - using Arguments = typename GemmKernel::Arguments; - -protected: - /// Kernel parameters object - typename GemmKernel::Params params_; - -protected: - /// Private helper to obtain the grid dimensions with fix-up for split-K - static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) - { - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - - gemm_k_size = args.problem_size.k(); - - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = 
round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = + typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, + int& gemm_k_size, + Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || + args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, + 128 / 
sizeof_bits::value), + 1); + + gemm_k_size = + round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } } + } -public: - /// Constructs the GEMM. - GemmUniversalBaseCompat() {} - - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args) - { - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - ThreadblockSwizzle threadblock_swizzle; - dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) - { + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); - return Status::kErrorInvalidProblem; - } - - return GemmKernel::can_implement(args); + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; } - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); - - size_t workspace_bytes = 0; - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + return GemmKernel::can_implement(args); + } - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) - { + /// Gets the workspace size + static 
size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); - // Split-K parallel always requires a temporary workspace - workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); - } - else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) - { - - // Serial split-K only requires a temporary workspace if the number of partitions along the - // GEMM K dimension is greater than one. - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); - } + size_t workspace_bytes = 0; - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - return workspace_bytes; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * + size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && + grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of + // partitions along the GEMM K dimension is greater than one. 
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * + size_t(grid_tiled_shape.n()); } - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); - - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + workspace_bytes += + GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); - CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); + return workspace_bytes; + } - return result; - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + ThreadblockSwizzle threadblock_swizzle; - int max_active_blocks = -1; - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - if (smem_size <= (48 << 10)) - { + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result + << "}"); - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + return result; + } - if (result == cudaSuccess) - { - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - } 
- else - { + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); - // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - if (result != cudaSuccess) - { + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + GemmKernel::kThreadCount, + smem_size); - return -1; - } + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on + // SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); - if (smem_capacity < 0) - { - int device_idx = 0; - result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); - if (result != cudaSuccess) - { - return -1; - } + return -1; + } - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) - { - return -1; - } + if (result != cudaSuccess) { + return -1; + } - smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); - } + cudaDeviceProp properties; + result = 
cudaGetDeviceProperties(&properties, device_idx); - int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + if (result != cudaSuccess) { + return -1; + } - CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } - return occupancy; - } + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); - CUTLASS_TRACE_HOST(" returning internal error"); + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); - return -1; + return occupancy; } - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + CUTLASS_TRACE_HOST(" returning internal error"); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + return -1; + } - size_t workspace_bytes = get_workspace_size(args); + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace + << ", stream: " << (stream ? 
"non-null" : "null")); - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + size_t workspace_bytes = get_workspace_size(args); - if (workspace_bytes) - { + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - if (!workspace) - { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - return Status::kErrorWorkspaceNull; - } + return Status::kErrorWorkspaceNull; + } - if (args.mode == GemmUniversalMode::kGemm) - { - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = + cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " + << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } + return Status::kErrorInternal; } + } + } - // Get CUDA grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - // Initialize the Params structure - params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - // Specify shared memory capacity for kernel. 
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + // Initialize the Params structure + params_ = typename GemmKernel::Params( + args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); - return Status::kSuccess; + if (result != cudaSuccess) { + return Status::kErrorInternal; + } } - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + return Status::kSuccess; + } - size_t workspace_bytes = get_workspace_size(args); + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST( + "GemmUniversalBaseCompat()::update() - workspace: " << workspace); - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - params_.update(args, workspace); + size_t workspace_bytes = get_workspace_size(args); - return Status::kSuccess; + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; } - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + params_.update(args, workspace); - // - // Configure grid and block dimensions - // + return Status::kSuccess; + } - ThreadblockSwizzle threadblock_swizzle; + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(GemmKernel::kThreadCount, 1, 1); + // + // Configure grid and block dimensions + // - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + ThreadblockSwizzle threadblock_swizzle; - // - // Launch kernel - // + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); - CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - // Launch - cutlass::Kernel<<>>(params_); + // + // Launch kernel + // - // - // Query for errors - // - cudaError_t result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block + << "), SMEM: " << smem_size << " bytes"); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } + // Launch + cutlass::Kernel<<>>(params_); - return Status::kSuccess; - } + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; } - /// Runs the kernel using initialized state. - Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + return Status::kSuccess; + } - Status status = initialize(args, workspace, stream); + /// Runs the kernel using initialized state. 
+ Status operator()(cudaStream_t stream = nullptr) { return run(stream); } - if (status == Status::kSuccess) - { - status = run(stream); - } + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); - return status; + if (status == Status::kSuccess) { + status = run(stream); } + + return status; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h index bfd3666b9c1..02859ee0d11 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
@@ -55,488 +56,479 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace device -{ +namespace cutlass { +namespace gemm { +namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, - int64_t* splitk_buffer_offsets) -{ - // in_tensor: [problem_idx, k_partition, hidden_size] - // Note that different requests of in_tensor might have different hidden_size (=m*n) - // so, we need to use splitk_buffer_offsets. - // out_tensor: problem_idx * [hidden_size] - - int const problem_idx = blockIdx.y; - GemmCoord problem = problem_sizes[problem_idx]; - int const hidden_size = problem.m() * problem.n(); - const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; - T_OUT* out_tensor_ = out_tensor[problem_idx]; - - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) - { - float sum = 0.0f; - for (int k_idx = 0; k_idx < splitk; k_idx++) - { - sum += (float) in_tensor_[k_idx * hidden_size + i]; - } - out_tensor_[i] = (T_OUT) (sum); +__global__ void splitkReduction(T_OUT** out_tensor, + const T_IN* in_tensor, + GemmCoord const* problem_sizes, + int splitk, + int64_t* splitk_buffer_offsets) { + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different + // hidden_size (=m*n) so, we need to use splitk_buffer_offsets. 
+ // out_tensor: problem_idx * [hidden_size] + + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = + in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; + + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; + i += blockDim.x * gridDim.x) { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) { + sum += (float)in_tensor_[k_idx * hidden_size + i]; } + out_tensor_[i] = (T_OUT)(sum); + } } /// GEMM Grouped template -class BaseSplitkGrouped -{ -public: - using BaseKernel = BaseKernel_; - - using ElementA = typename BaseKernel::ElementA; - using LayoutA = typename BaseKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = BaseKernel::kTransformA; - static int const kAlignmentA = BaseKernel::kAlignmentA; - - using ElementB = typename BaseKernel::ElementB; - using LayoutB = typename BaseKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = BaseKernel::kTransformB; - static int const kAlignmentB = BaseKernel::kAlignmentB; - - using ElementC = typename BaseKernel::ElementC; - using LayoutC = typename BaseKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - static int const kAlignmentC = BaseKernel::kAlignmentC; - - using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; - - using Operator = typename BaseKernel::Operator; - using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename WarpMmaOperator::MathOperator; - using OperatorClass = typename WarpMmaOperator::OperatorClass; - 
using ArchTag = typename WarpMmaOperator::ArchTag; - using ThreadblockShape = typename BaseKernel::Mma::Shape; - using WarpShape = typename BaseKernel::WarpShape; - using InstructionShape = typename BaseKernel::InstructionShape; - static int const kStages = BaseKernel::Mma::kStages; - - /// Argument structure - using Arguments = typename BaseKernel::Arguments; - - using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; - -protected: - /// Kernel parameters object - typename BaseKernel::Params gemm_params_; - -private: - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) - { - int32_t tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; - BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); - tiles += problem_tile_count(problem); - } - return tiles; +class BaseSplitkGrouped { + public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = + typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = + typename 
threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + + protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; + + private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count( + cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = + cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " + << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; } - /// Copy from `data` to `workspace` - Status copy_to_workspace(void* workspace, void* data, size_t bytes) - { - cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); - if 
(cuda_error != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - cuda_error = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); - return Status::kErrorInternal; - } - - return Status::kSuccess; + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, + int32_t tile_count, + void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + (void*)host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) { + // For now, simply create a copy of the data and then copy over to the + // original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; } - /// Precomputes scheduling information for the grouped GEMM - Status precompute(Arguments const& args, int32_t tile_count, void* workspace) - { - size_t workspace_bytes = get_workspace_size(args); - std::vector host_workspace(workspace_bytes); - BaseKernel::ProblemVisitor::host_precompute( - args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); - return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + + public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const& args) { + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; } - /// Reorder `data` according to `indices` - template - static void reorder_array(T* data, std::vector const& indices) - { - // For now, simply create a copy of the data and then copy over to the original. - std::vector copy(indices.size()); - for (size_t i = 0; i < indices.size(); ++i) - { - copy.at(i) = data[indices[i]]; - } - - memcpy(data, copy.data(), indices.size() * sizeof(T)); + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) { + total_mn += + args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); } + size_t workSpaceSize = + total_mn * sizeof(ElementAccumulator) * args.split_k_slices; -public: - /// Constructs the GEMM. 
- BaseSplitkGrouped() {} + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } + return workSpaceSize; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } + } - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args) - { + int max_active_blocks = -1; + result = + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + Kernel, + BaseKernel::kThreadCount, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; + } - return BaseKernel::can_implement(args); + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
+ static void sort_problems(int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_ptr, + int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, + int64_t* ldc_host_ptr, + int64_t* ldd_host_ptr, + int64_t* offset_A_ptr, + int64_t* offset_B_ptr, + int64_t* offset_C_ptr, + int64_t* offset_D_ptr) { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), + indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > + problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, + int problem_count = 0, + int available_sm_count = -1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. 
+ int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(result)); + return 0; } - /// Get the number of tiles in a problem - static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) - { - auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); - return BaseKernel::ProblemVisitor::tile_count(grid); + int multiprocessor_count; + result = cudaDeviceGetAttribute( + &multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " + << cudaGetErrorString(result)); + return 0; } - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(Arguments const& args) - { - if (args.host_problem_sizes == nullptr) - { - CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); - return -1; - } + bool override_sm_count = + (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) { + available_sm_count = multiprocessor_count; + } - return group_tile_count(args.host_problem_sizes, args.problem_count); + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; } - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - size_t total_mn = 0; - for (int i = 0; i < args.problem_count; i++) - { - total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); - } - size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( - args.host_problem_sizes, args.problem_count, args.threadblock_count); - } - return workSpaceSize; + int 
occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; } - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - return dim3(args.threadblock_count, 1, 1); + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return + // total_tiles unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; } - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating + // through problem sizes to determine that they have no work to do. This + // competes for cycles with those threadblocks that are assigned tiles to + // compute. + return std::min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace + << ", stream: " << (stream ? 
"non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } - CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) - { - result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + gemm_params_ = typename BaseKernel::Params(args, workspace); } - /// Sorts each pointer passed in according to the indices that sort - /// `problem_sizes_ptr` in descending order of problem-K dimension. 
- static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, - int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, - int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) - { - std::vector indices(problem_count); - std::iota(indices.begin(), indices.end(), 0); - std::stable_sort(indices.begin(), indices.end(), - [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); - - reorder_array(problem_sizes_ptr, indices); - reorder_array(lda_host_ptr, indices); - reorder_array(ldb_host_ptr, indices); - reorder_array(ldc_host_ptr, indices); - reorder_array(ldd_host_ptr, indices); - reorder_array(offset_A_ptr, indices); - reorder_array(offset_B_ptr, indices); - reorder_array(offset_C_ptr, indices); - reorder_array(offset_D_ptr, indices); + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } } - /// Computes the number of threadblocks to launch for the grouped kernel - static int sufficient( - cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) - { - // Determine the number of blocks that would be launched to fill up a single - // wave on the GPU with each SM having maximum occupancy. 
- int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); - return 0; - } - - int multiprocessor_count; - result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); - return 0; - } - - bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); - if (override_sm_count) - { - available_sm_count = multiprocessor_count; - } - - int max_active_blocks = maximum_active_blocks(); - if (max_active_blocks <= 0) - { - return 0; - } - - int occupancy_based_block_count = available_sm_count * max_active_blocks; - - if (problem_sizes_ptr == nullptr || problem_count == 0) - { - return occupancy_based_block_count; - } - - int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - - // If the group contains a single problem, launching the exact number of - // threadblocks needed to cover the problem minimizes the work performed - // per threadblock in finding the next tile to compute. We return total_tiles - // unless the user has provided the SM count. - if (problem_count == 1 && override_sm_count) - { - return total_tiles; - } - - // Choose between the full wave of threadblocks and the tile count. If there - // are fewer tiles in the group than threadblocks in the full wave, only - // some threadblocks will be assigned tiles. Those threadblocks - // which are not assigned tiles still need to perform the work of iterating through - // problem sizes to determine that they have no work to do. This competes for cycles - // with those threadblocks that are assigned tiles to compute. 
- return std::min(total_tiles, occupancy_based_block_count); + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; } - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } - CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - // Workspace - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); - } - else - { - gemm_params_ = typename BaseKernel::Params(args, workspace); - } - - // Specify shared memory capacity for kernel. 
- int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; + gemm_params_.update(args, workspace, tile_count); + } else { + gemm_params_.update(args, workspace); } - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { + return Status::kSuccess; + } - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_.update(args, workspace, tile_count); - } - else - { - gemm_params_.update(args, workspace); - } - - return Status::kSuccess; + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + if (!gemm_params_.problem_visitor.problem_count) { + return Status::kSuccess; } - /// Runs the kernel using initialized state. 
- Status run(cudaStream_t stream = nullptr) + // + // Launch kernel + // + + // Launch splitk grouped gemm { - if (!gemm_params_.problem_visitor.problem_count) - { - return Status::kSuccess; - } - - // - // Launch kernel - // - - // Launch splitk grouped gemm - { - dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); - dim3 block(BaseKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - cutlass::Kernel<<>>(gemm_params_); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - // Launch splitkReduction - { - dim3 grid(32, gemm_params_.problem_visitor.problem_count); - dim3 block(256); - splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, - gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, - gemm_params_.splitk_buffer_offsets); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - return Status::kSuccess; + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel + <<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } } - /// Runs the kernel using initialized state. 
- Status operator()(cudaStream_t stream = nullptr) + // Launch splitkReduction { - return run(stream); + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>( + gemm_params_.ptr_D, + gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, + gemm_params_.split_k_slices, + gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } } - /// Initializes and runs the kernel. - Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) - { + return Status::kSuccess; + } - Status status = initialize(args, workspace, stream); + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } - if (status == Status::kSuccess) - { - status = run(stream); - } + /// Initializes and runs the kernel. 
+ Status operator()(Arguments const& args, + void* workspace, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); - return status; + if (status == Status::kSuccess) { + status = run(stream); } + + return status; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM Grouped template -class SplitkGemmGrouped : public BaseSplitkGrouped -{ -public: - using GemmKernel = GemmKernel_; +class SplitkGemmGrouped : public BaseSplitkGrouped { + public: + using GemmKernel = GemmKernel_; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp index f4cf0bf4200..c276c4e9c08 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/dispatch_policy.hpp @@ -30,7 +30,8 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp // specialized dynamic schedule For FP8 kernels with Block Scaling -template , +template , class KernelSchedule = KernelTmaWarpSpecialized, int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, @@ -38,7 +39,8 @@ template , // granularity is `size<0>(TileShape_MNK{})` along M. 
> struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 - : MainloopSm90TmaGmmaWarpSpecialized { static_assert( cute::is_same_v< diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index 3f834702709..0d25460e667 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,164 +27,190 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { template -struct MixedGemmArchTraits -{ - static_assert(dependent_false, "Unrecognised parameterization"); +struct MixedGemmArchTraits { + static_assert(dependent_false, "Unrecognised parameterization"); }; template -struct MixedGemmArchTraits -{ - static constexpr int Stages = 2; - using OperatorClass = cutlass::arch::OpClassSimt; - using AccType = float; - using LayoutB = cutlass::layout::ColumnMajor; - - static constexpr int ElementsPerAccessA = 1; - static constexpr int ElementsPerAccessB = 1; - static constexpr int ElementsPerAccessC = 1; - static constexpr int ThreadblockK = 8; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - - using Operator = cutlass::arch::OpMultiplyAdd; +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = 
cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; }; // ========================= Volta Traits =========================== // Volta will always dequantize after the global memory load. // This will instantiate any HMMA tensorcore kernels for Volta. -// Note that volta does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. +// Note that volta does not have native bfloat support so weights and +// activations will be casted to fp16 and compute will happen in fp16 then will +// be converted for bf16 output. template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - - using Operator = typename LayoutDetails::Operator; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm70, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = 
LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = typename LayoutDetails::Operator; }; // ======================= Turing Traits ============================== -// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. +// Note that turing does not have native bfloat support so weights and +// activations will be casted to fp16 and compute will happen in fp16 then will +// be converted for bf16 output. template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - - using Operator = typename LayoutDetails::Operator; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm75, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = 
LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; }; // ======================= Ampere Traits ============================== template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - - using Operator = typename LayoutDetails::Operator; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm80, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / 
cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; }; // ======================= Ada Traits ============================== template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - - using Operator = typename LayoutDetails::Operator; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm89, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; }; // FP8 A/B = fp8, C/D = fp32 template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - 
-public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t - using TypeC = __nv_bfloat16; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - - using Operator = typename LayoutDetails::Operator; +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm89, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with + // HopperGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h index 
3fd722994e2..beab7fd253d 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,36 +22,30 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/layout/matrix.h" -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { template -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassSimt; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; }; // ======================= Turing Traits ============================== template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; }; // ======================= Ampere Traits ============================== template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace 
kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index b83f6d76e00..f8ee5bdab9d 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. + \brief Template for a pipelined GEMM kernel. Does not compute batching or + support split-K. 
*/ #pragma once @@ -46,523 +48,560 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace detail -{ +namespace detail { template inline constexpr bool dependent_false_v = false; } -template -struct GemmFpAIntB -{ - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; - - // Type definitions about the mainloop. 
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - - /// Parameters structure - struct Arguments - { - GemmUniversalMode mode = GemmUniversalMode::kGemm; - - cutlass::gemm::GemmCoord problem_size; - int group_size; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - - // Control serial split-k - int batch_count; - - typename EpilogueOutputOp::Params output_op; - - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // Included so we can use Gemm Universal - int batch_stride_D = 0; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Arguments() {} - - CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, - typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, - typename 
Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, - typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), - int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, - int const* scatter_D_indices = nullptr) - : problem_size(problem_size) - , group_size(group_size) - , ref_A(ref_A) - , ref_B(ref_B) - , ref_scale(ref_scale) - , ref_zero(ref_zero) - , ref_C(ref_C) - , ref_D(ref_D) - , batch_count(serial_split_k_factor) - , output_op(output_op) - , gather_A_indices(gather_A_indices) - , gather_B_indices(gather_B_indices) - , scatter_D_indices(scatter_D_indices) - { - } - }; +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; - /// Parameters structure - struct Params - { - cutlass::gemm::GemmCoord problem_size; - int group_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::Params 
params_scale; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename EpilogueOutputOp::Params output_op; - int* semaphore; - int gemm_k_size; - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , semaphore(0) - , gemm_k_size(0) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, - void* workspace = nullptr) - : problem_size(args.problem_size) - , group_size(args.group_size) - , grid_tiled_shape(grid_tiled_shape) - , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) - , params_A(args.ref_A.layout()) - , ref_A(args.ref_A) - , params_B(args.ref_B.layout()) - , ref_B(args.ref_B) - , params_scale(args.ref_scale.layout()) - , ref_scale(args.ref_scale) - , ref_zero(args.ref_zero) - , params_C(args.ref_C.layout()) - , ref_C(args.ref_C) - , params_D(args.ref_D.layout()) - , ref_D(args.ref_D) - , output_op(args.output_op) - , semaphore(static_cast(workspace)) - , gemm_k_size(gemm_k_size) - , gather_A_indices(args.gather_A_indices) - , gather_B_indices(args.gather_B_indices) - , scatter_D_indices(args.scatter_D_indices) - { - } - }; + // + // Methods + // - /// Shared memory storage structure - union SharedStorage - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, + int const group_size, + typename Mma::IteratorA::TensorRef ref_A, + typename 
Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, + typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = + typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, + int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), + group_size(group_size), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_zero(ref_zero), + ref_C(ref_C), + ref_D(ref_D), + batch_count(serial_split_k_factor), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; // // Methods // CUTLASS_HOST_DEVICE - GemmFpAIntB() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(Arguments const& args) - { - static int const kAlignmentA - = 
(platform::is_same>::value) ? 32 - : (platform::is_same>::value) + Params() : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape, + int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), + group_size(args.group_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.ref_A.layout()), + ref_A(args.ref_A), + params_B(args.ref_B.layout()), + ref_B(args.ref_B), + params_scale(args.ref_scale.layout()), + ref_scale(args.ref_scale), + ref_zero(args.ref_zero), + params_C(args.ref_C.layout()), + ref_C(args.ref_C), + params_D(args.ref_D.layout()), + ref_D(args.ref_D), + output_op(args.output_op), + semaphore(static_cast(workspace)), + gemm_k_size(gemm_k_size), + gather_A_indices(args.gather_A_indices), + gather_B_indices(args.gather_B_indices), + scatter_D_indices(args.scatter_D_indices) {} + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) { + static int const kAlignmentA = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) ? 64 : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) + static int const kAlignmentB = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) ? 
64 : Mma::IteratorB::AccessType::kElements; - static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + static int const kAlignmentScale = + Mma::IteratorScale::AccessType::kElements; - static int const kAlignmentC = (platform::is_same>::value) + static int const kAlignmentC = + (platform::is_same>::value) ? 32 - : (platform::is_same>::value) + : (platform::is_same>::value) ? 64 : Epilogue::OutputTileIterator::kElementsPerAccess; - if (!TensorRef_aligned(args.ref_A, kAlignmentA)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_B, kAlignmentB)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_C, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_D, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!args.ref_scale.good()) - { - return Status::kErrorNotSupported; - } - - if constexpr (hasZero(Mma::QuantOp)) - { - if (!args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - else - { - if (args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - - if constexpr (isFinegrained(Mma::QuantOp)) - { - if (args.group_size != 64 && args.group_size != 128) - { - return Status::kErrorNotSupported; - } - } - - return Status::kSuccess; + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; } - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; } - // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator - // has a different constructor signature than a regular cutlass iterator - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; } - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; } - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - using LayoutB = typename Mma::IteratorB::Layout; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { - - return; - } - - // Compute initial location in logical 
coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, - threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; - - typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; - typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; - cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); - - typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, - params.gather_B_indices); - - typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; - typename Mma::IteratorScale iterator_scale = initialize_scale( - params.params_scale, params.ref_scale.data(), params.ref_zero.data(), - {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (!kSplitKSerial || gemm_k_iterations > 0) - { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } - - // - // Epilogue - // + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } - EpilogueOutputOp output_op(params.output_op); + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } - // - // Masked tile iterators constructed from members - // + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + return Status::kSuccess; + } + + static size_t get_extra_workspace_size( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. 
Needed since the fine + // grained iterator has a different constructor signature than a regular + // cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, + pointer_scale, + pointer_zero, + extent, + thread_id, + threadblock_offset, + group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale( + typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, + int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale( + params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } - // Construct the semaphore. 
- Semaphore semaphore(params.semaphore + block_idx, thread_idx); + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = + threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = + isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{ + scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = + min(params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / + Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = + isFinegrained(Mma::QuantOp) ? 
problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = + initialize_scale( + params.params_scale, + params.ref_scale.data(), + params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, + thread_idx, + tb_offset_scale, + params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, + params.group_size, + thread_idx, + warp_idx, + lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_scale, + accumulators); + } - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } + // + // Epilogue + // - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); + EpilogueOutputOp output_op(params.output_op); - // Tile iterator writing to destination tensor. 
- typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); + // + // Masked tile iterators constructed from members + // - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) - { - iterator_C = iterator_D; - } + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - semaphore.wait(threadblock_tile_offset.k()); - } + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); - // - // Release the semaphore - // + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { + // Tile iterator loading from source tensor. 
+ typename Epilogue::OutputTileIterator iterator_C(params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) - { + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else - { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } + // + // Release the semaphore + // - semaphore.release(lock); - } + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, + SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { + } + + /* + To improve compilation speed, we do not compile the device operator if the + CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel + operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ == 890) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 900) - CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use + // CUTLASS 3.x kernels. 
#else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); + static_assert(false, + "Invalid architecture being compiled. Only Volta+ supported " + "in weight-only quantization kernels."); #endif #else - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_NOT_IMPLEMENTED(); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp index 15faad26ee7..cc143751bcc 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp @@ -54,15 +54,16 @@ namespace cutlass::gemm::kernel { * 2.x API type argument order. Template arguments without two names * belong to the 3.x API only. **/ -template + class TileScheduler_ = void, + class Enable = void> class GemmUniversalGated; //////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::kernel +} // namespace cutlass::gemm::kernel //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h index 6c4c578c9c9..be80a3feec7 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,25 +18,27 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief GEMM kernel to support the epilogue visitor model for customized softmax partial reduction epilogue fusion. - This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once - its usage has been stabilized. For now, it is included in this example to demonstrate - some basic output fusion options. + This source file will likely be moved to `include/cutlass/gemm/kernel/` in + the future once its usage has been stabilized. For now, it is included in + this example to demonstrate some basic output fusion options. 
- original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h + original file: + 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h */ #pragma once @@ -55,533 +57,527 @@ namespace tk = common; ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct GemmWithEpilogueVisitor -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueVisitor = typename Epilogue::Visitor; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using TensorRefA = TensorRef; - - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using TensorRefB = TensorRef; - - using ElementCompute = typename EpilogueVisitor::ElementCompute; - using LayoutAlphaCol = cutlass::layout::RowMajor; - using LayoutAlphaRow = cutlass::layout::ColumnMajor; - using TensorRefAlphaCol = TensorRef; - using TensorRefAlphaRow = TensorRef; - - using ElementC = typename EpilogueVisitor::ElementOutput; - using LayoutC = typename Epilogue::Layout; - using TensorRefC = TensorRef; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; - - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - using EpilogueOutputOp = - typename Epilogue::Visitor::ElementwiseFunctor; // Define type so 
GemmUniversalBase doesn't complain - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment - = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); - +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = typename Epilogue::Visitor:: + 
ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { // - // Structures + // Data members // - /// Argument structure - struct Arguments - { + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; - // - // Data members - // + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; - TensorRefA ref_A; - TensorRefB ref_B; - tk::QuantMode quant_option; - TensorRefAlphaCol ref_alpha_col; - TensorRefAlphaRow ref_alpha_row; - TensorRefC ref_C; - TensorRefC ref_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_D; - - typename EpilogueVisitor::Arguments epilogue_visitor; - - // - // Methods - // - - Arguments() - : mode(GemmUniversalMode::kGemm) - , batch_count(1) - { - } - - /// constructs an arguments structure - Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, - TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, - TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, - int64_t batch_stride_B_, 
typename EpilogueVisitor::Arguments epilogue_visitor_) - : mode(mode_) - , problem_size(problem_size_) - , batch_count(batch_count_) - , ref_A(ref_A_) - , ref_B(ref_B_) - , quant_option(quant_option_) - , ref_alpha_col(ref_alpha_col_) - , ref_alpha_row(ref_alpha_row_) - , ref_C(ref_C_) - , ref_D(ref_D_) - , batch_stride_A(batch_stride_A_) - , batch_stride_B(batch_stride_B_) - , batch_stride_D(0) - , epilogue_visitor(epilogue_visitor_) - { - } - }; + typename EpilogueVisitor::Arguments epilogue_visitor; // - // Structure for precomputing values in host memory and passing to kernels + // Methods // - /// Parameters structure - struct Params - { - - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; - typename EpilogueVisitor::OutputTileIterator::Params params_C; - typename EpilogueVisitor::OutputTileIterator::Params params_D; - - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - - void* ptr_A; - void* ptr_B; - tk::QuantMode quant_option; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; - ElementC* ptr_C; - ElementC* ptr_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - - typename EpilogueVisitor::Params epilogue_visitor; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , params_A(0) - , params_B(0) - , params_alpha_col(0) - , params_C(0) - , params_D(0) - , batch_count(0) - , gemm_k_size(0) - , mode(cutlass::gemm::GemmUniversalMode::kGemm) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_alpha_col(nullptr) - , ptr_alpha_row(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , batch_stride_A(0) - , batch_stride_B(0) - { - } - - Params( - 
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) - : problem_size(args.problem_size) - , swizzle_log_tile(0) - , params_A(args.ref_A.layout()) - , params_B(args.ref_B.layout()) - , params_alpha_col(args.ref_alpha_col.layout()) - , params_alpha_row(args.ref_alpha_col.layout()) - , params_C(args.ref_C.layout()) - , params_D(args.ref_D.layout()) - , mode(args.mode) - , batch_count(args.batch_count) - , gemm_k_size(args.problem_size.k()) - , ptr_A(args.ref_A.data()) - , ptr_B(args.ref_B.data()) - , quant_option(args.quant_option) - , ptr_alpha_col(args.ref_alpha_col.data()) - , ptr_alpha_row(args.ref_alpha_row.data()) - , ptr_C(args.ref_C.data()) - , ptr_D(args.ref_D.data()) - , batch_stride_A(args.batch_stride_A) - , batch_stride_B(args.batch_stride_B) - , epilogue_visitor(args.epilogue_visitor) - { - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage - { + Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, + GemmCoord problem_size_, + int batch_count_, + TensorRefA ref_A_, + TensorRefB ref_B_, + tk::QuantMode quant_option_, + TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, + TensorRefC ref_C_, + 
TensorRefC ref_D_, + int64_t batch_stride_A_, + int64_t batch_stride_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_), + problem_size(problem_size_), + batch_count(batch_count_), + ref_A(ref_A_), + ref_B(ref_B_), + quant_option(quant_option_), + ref_alpha_col(ref_alpha_col_), + ref_alpha_row(ref_alpha_row_), + ref_C(ref_C_), + ref_D(ref_D_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + batch_stride_D(0), + epilogue_visitor(epilogue_visitor_) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; - typename Mma::SharedStorage main_loop; - - struct - { - typename Epilogue::SharedStorage epilogue; - typename EpilogueVisitor::SharedStorage visitor; - } epilogue; - }; - -public: // // Methods // - CUTLASS_DEVICE - GemmWithEpilogueVisitor() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - - CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); - - static int 
const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; - - bool isAMisaligned = false; - bool isBMisaligned = false; - bool isCMisaligned = false; - - if (platform::is_same::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; - } - else if (platform::is_same::value) - { - isAMisaligned = problem_size.m() % kAlignmentA; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), + params_A(0), + params_B(0), + params_alpha_col(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_alpha_col(nullptr), + ptr_alpha_row(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0) {} + + Params(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape_, + int gemm_k_size_, + int* workspace_) + : problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_alpha_col(args.ref_alpha_col.layout()), + params_alpha_row(args.ref_alpha_col.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + quant_option(args.quant_option), + ptr_alpha_col(args.ref_alpha_col.data()), + ptr_alpha_row(args.ref_alpha_row.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + 
args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || + args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, + 128 / sizeof_bits::value), + 1); + + gemm_k_size = round_up( + ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); } + } - if (platform::is_same::value) - { - isBMisaligned = problem_size.n() % kAlignmentB; - } - else if (platform::is_same::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; - if (platform::is_same::value) - { - isCMisaligned = problem_size.n() % kAlignmentC; - } - else if (platform::is_same::value) - { - isCMisaligned = problem_size.m() % kAlignmentC; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isCMisaligned = problem_size.n() % kAlignmentC; - } + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; - if (isAMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; - if (isBMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } + public: + // + // Methods + // - if (isCMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} - CUTLASS_TRACE_HOST(" returning 
kSuccess"); + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); - return Status::kSuccess; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; } - static Status can_implement(Arguments const& args) - { - return can_implement(args.problem_size); + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; } - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; } -#define SPLIT_K_ENABLED 1 - - /// Executes one GEMM - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } - // Compute 
threadblock location - ThreadblockSwizzle threadblock_swizzle; + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { + CUTLASS_TRACE_HOST(" returning kSuccess"); - return; - } + return Status::kSuccess; + } - int offset_k = 0; - int problem_size_k = params.problem_size.k(); + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); + static size_t get_extra_workspace_size( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } -#if SPLIT_K_ENABLED - // - // Fetch pointers based on mode. 
- // - if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) - { +#define SPLIT_K_ENABLED 1 - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) - { + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) - { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } - else if (params.mode == GemmUniversalMode::kArray) - { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } -#endif + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; + int offset_k = 0; + int problem_size_k = params.problem_size.k(); - cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); - // Compute position within threadblock - int thread_idx = threadIdx.x; +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. 
+ // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast( + params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast( + params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + cutlass::MatrixCoord tb_offset_B{ + offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + // Compute position within threadblock + int thread_idx = threadIdx.x; - int lane_idx = threadIdx.x % 32; + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); - // - // Main loop - // + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - typename Mma::FragmentC accumulators; + int lane_idx = threadIdx.x % 32; - accumulators.clear(); + // + // Main loop + // - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + typename Mma::FragmentC accumulators; - // - // Masked tile iterators constructed from members - // + accumulators.clear(); - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - int block_idx = 
threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + // + // Masked tile iterators constructed from members + // - // - // Construct the epilogue visitor - // + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, - params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, - params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, - params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); - if (params.mode == GemmUniversalMode::kGemm) - { - // Indicate which position in a serial reduction the output operator is currently updating - epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) - { - epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); - } + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + // + // Construct the epilogue visitor + // - // Execute the epilogue operator to update the destination tensor. 
- epilogue(epilogue_visitor, accumulators); + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_alpha_col, + params.params_C, + params.params_D, + params.quant_option, + params.ptr_alpha_row, + params.ptr_alpha_col, + params.ptr_C, + params.ptr_D, + threadblock_offset, + blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is + // currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || + params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); } - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, + SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. 
- */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { + } + + /* + To improve compilation speed, we do not compile the device operator if the + CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel + operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 900) - // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. - run_kernel(params, shared_storage); + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); #else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); + static_assert(false, + "Invalid architecture being compiled. 
Only Volta+ supported " + "in weight-only quantization kernels."); #endif #else - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_NOT_IMPLEMENTED(); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 40f128b7a0a..56140c81d60 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,10 @@ * limitations under the License. */ /* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. + This file exists so that we use the same weight layout for MoE grouped gemm + and regular gemm when the weight is quantized. The preprocessing code reads + this template to know how to organize the quantized weight matrices to be + consumed by CUTLASS. Note that for int4, ThreadBlockK MUST be 64. 
@@ -35,128 +36,172 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/tile_interleaved_layout.h" -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { template -struct LayoutDetailsB -{ -}; +struct LayoutDetailsB {}; -// Volta specialiations. Volta will dequantize before STS, so we need a different operator +// Volta specialiations. Volta will dequantize before STS, so we need a +// different operator template -struct LayoutDetailsB -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 8; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; }; -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. +// Specializations for Turing+ when B is FP16. These are currently only used for +// MoE networks. +// TODO - Switch this to column major for weights since gemms should be more +// performant. 
template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + half_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + bfloat16_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; template -struct LayoutDetailsB -{ - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; - // for fast accumulation - // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = 
ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; }; -// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, -// which signals that we want to dequantize after loading from smem. +// Specializations for Turing+ when B is quantized. These can use the operator +// OpMultiplyAddDequantizeInterleavedBToA, which signals that we want to +// dequantize after loading from smem. template struct LayoutDetailsB < TypeA, uint8_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + typename platform::enable_if= 75 && + Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template struct LayoutDetailsB < TypeA, uint4b_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr 
int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; + typename platform::enable_if= 75 && + Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + uint2b_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; // 64 + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = + ElementsPerCacheLine / ThreadblockK; // 8 + + public: + // using Layout = layout::ColumnMajor; + // static constexpr int ElementsPerAccess = 16; // at least 4-bytes + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; // 64 + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template 
-struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + uint8_t, + Arch, + typename platform::enable_if= 90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; template -struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +struct LayoutDetailsB< + TypeA, + uint4b_t, + Arch, + typename platform::enable_if= 90>::type> { + static constexpr int ThreadblockK = + 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h index b9126e3500f..a2376e6b8b5 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -160,11 +160,11 @@ struct BaseMoeProblemVisitor { CUTLASS_HOST_DEVICE cutlass::gemm::GemmCoord problem_size(int idx) const { - int64_t gemm_m = 0; if (params.total_rows < 0) { - const int64_t prev_problem_row = idx == 0 ? 
0 : params.last_row_for_problem[idx - 1]; + const int64_t prev_problem_row = + idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; const int64_t current_problem_row = params.last_row_for_problem[idx]; gemm_m = current_problem_row - prev_problem_row; } else { diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp index 843529cde55..efb90a5e7c6 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp @@ -53,15 +53,20 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template +template class GemmUniversalGated< - ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, cute::enable_if_t && CollectiveMainloop_::isGated>> { -public: + public: // // Type Aliases // @@ -98,7 +103,9 @@ class GemmUniversalGated< using TileSchedulerTag = TileScheduler_; using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; @@ -222,9 +229,14 @@ class GemmUniversalGated< // used, therefore separate reduction will not be enabled. 
constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, - args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + TileSchedulerParams scheduler = + TileScheduler::to_underlying_arguments(problem_shape_MNKL, + TileShape{}, + ClusterShape{}, + hw_info, + args.scheduler, + scheduler_workspace, + NumEpilogueSubTiles); return {args.mode, problem_shape, @@ -242,8 +254,9 @@ class GemmUniversalGated< (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't " - "meet the requirements.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Arguments or Problem Shape don't " + "meet the requirements.\n"); return implementable; } implementable &= @@ -262,7 +275,10 @@ class GemmUniversalGated< workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, + args.scheduler, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups, NumEpilogueSubTiles); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); @@ -273,10 +289,11 @@ class GemmUniversalGated< return workspace_size; } - static cutlass::Status - initialize_workspace(Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + static cutlass::Status initialize_workspace( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { Status status = Status::kSuccess; uint8_t *workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; @@ -285,13 +302,20 @@ class GemmUniversalGated< status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, - 
args.problem_shape, args.hw_info, NumMmaWarpGroups, + args.scheduler, + workspace_ptr + workspace_offset, + stream, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, + args.scheduler, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { @@ -299,8 +323,11 @@ class GemmUniversalGated< } status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, - stream, cuda_adapter); + args.problem_shape, + args.epilogue, + workspace_ptr + workspace_offset, + stream, + cuda_adapter); workspace_offset += CollectiveEpilogue::get_workspace_size( args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -323,9 +350,12 @@ class GemmUniversalGated< params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, - TileShape{}, ClusterShape{}, - params.hw_info, args); + return TileScheduler::get_grid_shape(params.scheduler, + params.problem_shape, + TileShape{}, + ClusterShape{}, + params.hw_info, + args); } static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -337,8 +367,9 @@ class GemmUniversalGated< // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting " - "sm90a compute capability. Aborting.\n"); + printf( + "ERROR : Arch conditional MMA instruction used without targeting " + "sm90a compute capability. 
Aborting.\n"); #else // Preconditions @@ -469,7 +500,7 @@ class GemmUniversalGated< return []() { cute::cluster_wait(); }; } else { __syncthreads(); - return []() {}; // do nothing + return []() {}; // do nothing } }(); @@ -480,7 +511,7 @@ class GemmUniversalGated< // Get the appropriate blocks for this thread block -- potential for thread // block locality TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) TileScheduler scheduler{params.scheduler}; auto work_tile_info = scheduler.get_current_work(); @@ -540,10 +571,16 @@ class GemmUniversalGated< auto k_tile_iter = cute::make_coord_iterator( idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - collective_mainloop.load( - params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, - block_rank_in_cluster, shared_storage.tensors.mainloop); + collective_mainloop.load(params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, + work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop); // Update starting pipeline state for the next tile mainloop_pipe_producer_state.advance(work_k_tile_count); @@ -555,12 +592,12 @@ class GemmUniversalGated< // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End + } // Mainloop Producer Warp End // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && @@ -579,21 +616,26 @@ class GemmUniversalGated< auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, 
epi_load_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); } // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End + } // Epilogue Producer Warp End + } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { @@ -618,14 +660,18 @@ class GemmUniversalGated< // // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. auto accumulators0 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) auto accumulators1 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { - collective_mainloop.mma( - mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, - accumulators1, work_k_tile_count, mma_thread_idx, - shared_storage.tensors.mainloop, params.mainloop); + collective_mainloop.mma(mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators0, + accumulators1, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop); // Make sure the math instructions are done and free buffers before // entering the epilogue @@ -641,10 +687,16 @@ class GemmUniversalGated< canonical_warp_group_idx() - NumLoadWarpGroups; // Perform reduction across splits, if needed - 
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators0, - NumMmaWarpGroups, consumer_warp_group_idx); - TileScheduler::fixup(params.scheduler, work_tile_info, accumulators1, - NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup(params.scheduler, + work_tile_info, + accumulators0, + NumMmaWarpGroups, + consumer_warp_group_idx); + TileScheduler::fixup(params.scheduler, + work_tile_info, + accumulators1, + NumMmaWarpGroups, + consumer_warp_group_idx); Activation elt_op; CUTLASS_PRAGMA_UNROLL @@ -657,12 +709,18 @@ class GemmUniversalGated< // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, epi_load_pipe_consumer_state, - epi_store_pipeline, epi_store_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx()); + collective_epilogue.store(epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators0, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; do_store_tail = true; @@ -670,23 +728,24 @@ class GemmUniversalGated< // Get next work tile work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop + } // Scheduler work fetch loop if (do_store_tail) { - collective_epilogue.store_tail( - epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, - epi_store_pipe_producer_state); + collective_epilogue.store_tail(epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state); } - } // Consumer Warp Groups End + } // Consumer Warp Groups 
End #endif } -private: + private: // Kernel helper function to get next work unit CUTLASS_DEVICE - typename TileScheduler::WorkTileInfo - fetch_next_work(typename TileScheduler::WorkTileInfo &work_tile_info, - TileScheduler &scheduler) const { + typename TileScheduler::WorkTileInfo fetch_next_work( + typename TileScheduler::WorkTileInfo &work_tile_info, + TileScheduler &scheduler) const { // Check whether we should continue on with the current work unit. If this // is the case, the work unit will have been updated in // continue_current_work to reflect the new tile to be computed. @@ -702,4 +761,4 @@ class GemmUniversalGated< /////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::kernel +} // namespace cutlass::gemm::kernel diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp index e6cc7de5c61..9609adc32a7 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp @@ -56,15 +56,20 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template +template class GemmUniversalGated< - ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, cute::enable_if_t && CollectiveMainloop_::isGated>> { -public: + public: // // Type Aliases // @@ -103,7 +108,9 @@ class GemmUniversalGated< "Ping-pong kernel does not currently support stream-K scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; using 
TileSchedulerParams = typename TileScheduler::Params; @@ -236,9 +243,12 @@ class GemmUniversalGated< CollectiveEpilogue::to_underlying_arguments( args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, - args.scheduler, scheduler_workspace)}; + TileScheduler::to_underlying_arguments(problem_shape_MNKL, + TileShape{}, + ClusterShape{}, + hw_info, + args.scheduler, + scheduler_workspace)}; } static bool can_implement(Arguments const &args) { @@ -246,8 +256,9 @@ class GemmUniversalGated< (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't " - "meet the requirements.\n"); + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Arguments or Problem Shape don't " + "meet the requirements.\n"); return implementable; } implementable &= @@ -273,18 +284,23 @@ class GemmUniversalGated< return workspace_size; } - static cutlass::Status - initialize_workspace(Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + static cutlass::Status initialize_workspace( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { Status status = Status::kSuccess; uint8_t *workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, - args.problem_shape, args.hw_info, NumMmaWarpGroups); + args.scheduler, + workspace_ptr + workspace_offset, + stream, + args.problem_shape, + args.hw_info, + NumMmaWarpGroups); workspace_offset += TileScheduler::template get_workspace_size( @@ -295,8 +311,11 @@ class GemmUniversalGated< } status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, 
workspace_ptr + workspace_offset, - stream, cuda_adapter); + args.problem_shape, + args.epilogue, + workspace_ptr + workspace_offset, + stream, + cuda_adapter); workspace_offset += CollectiveEpilogue::get_workspace_size( args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -319,9 +338,12 @@ class GemmUniversalGated< params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape, - TileShape{}, ClusterShape{}, - params.hw_info, args); + return TileScheduler::get_grid_shape(params.scheduler, + params.problem_shape, + TileShape{}, + ClusterShape{}, + params.hw_info, + args); } static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -333,8 +355,9 @@ class GemmUniversalGated< // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting " - "sm90a compute capability. Aborting.\n"); + printf( + "ERROR : Arch conditional MMA instruction used without targeting " + "sm90a compute capability. 
Aborting.\n"); #else // Preconditions @@ -437,7 +460,7 @@ class GemmUniversalGated< params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); params_math_wg_order_barrier.group_size = - NumThreadsPerWarpGroup; // Number of threads / participants in a group + NumThreadsPerWarpGroup; // Number of threads / participants in a group MathWarpGroupOrderBarrier math_wg_order_barrier( shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); @@ -464,7 +487,7 @@ class GemmUniversalGated< return []() { cute::cluster_wait(); }; } else { __syncthreads(); - return []() {}; // do nothing + return []() {}; // do nothing } }(); @@ -476,7 +499,7 @@ class GemmUniversalGated< // Get the appropriate blocks for this thread block -- potential for thread // block locality TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) // In a warp specialized kernel, collectives expose data movement and // compute operations separately @@ -535,10 +558,16 @@ class GemmUniversalGated< auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); - collective_mainloop.load( - params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, - block_rank_in_cluster, shared_storage.tensors.mainloop); + collective_mainloop.load(params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + blk_coord, + k_tile_iter, + k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop); // Update starting pipeline state for the next tile mainloop_pipe_producer_state.advance(k_tile_count); @@ -551,12 +580,12 @@ class GemmUniversalGated< // Get next work tile scheduler.advance_to_next_work(); work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon 
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End + } // Mainloop Producer Warp End // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && @@ -570,21 +599,26 @@ class GemmUniversalGated< auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - epi_load_pipe_producer_state = collective_epilogue.load( - epi_load_pipeline, epi_load_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, - shared_storage.tensors.epilogue); + epi_load_pipe_producer_state = + collective_epilogue.load(epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue); // Get next work tile scheduler.advance_to_next_work(); work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop + } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End + } // Epilogue Producer Warp End + } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { @@ -602,17 +636,21 @@ class GemmUniversalGated< // Allocate the accumulators for the (M,N) blk_shape Tensor accumulators0 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) Tensor accumulators1 = partition_fragment_C( - tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) // Order two Math WG's MMA one after the other, helps hide Epilogue math_wg_order_barrier.wait(); - collective_mainloop.mma( - mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, - 
accumulators1, k_tile_count, warp_group_thread_idx, - shared_storage.tensors.mainloop, params.mainloop); + collective_mainloop.mma(mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators0, + accumulators1, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop); // Cue for next Math WG's MMA to start math_wg_order_barrier.arrive(); @@ -637,12 +675,17 @@ class GemmUniversalGated< // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = - collective_epilogue.store( - epi_load_pipeline, epi_load_pipe_consumer_state, - epi_store_pipeline, epi_store_pipe_producer_state, - problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, warp_group_thread_idx, - shared_storage.tensors.epilogue); + collective_epilogue.store(epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators0, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue); // TMA store pipeline wait is only visible to TMA-issuing warp, so for // multiple-consumer kernels we need to wait for all TMA stores to @@ -651,9 +694,10 @@ class GemmUniversalGated< // current consumer. 
auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] = - collective_epilogue.store_tail( - epi_load_pipeline, epi_load_pipe_consumer_state_next, - epi_store_pipeline, epi_store_pipe_producer_state_next); + collective_epilogue.store_tail(epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next); // Update starting load/store pipeline states for the next tile // state has already been incremented by 1 tile in collective calls, @@ -669,12 +713,12 @@ class GemmUniversalGated< // Get next work tile scheduler.advance_to_next_work(NumMmaWarpGroups); work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - } // Consumer Warp Groups End + } // Scheduler work fetch loop + } // Consumer Warp Groups End #endif } }; /////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm::kernel +} // namespace cutlass::gemm::kernel diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h index 5e3531f0938..5d68ea26d96 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ @@ -49,446 +50,471 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct SplitkGemmGrouped -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = Transposed; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; +template +struct SplitkGemmGrouped { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and + // complex conjugate operation. Must interact with the 'kTransposed' notion. 
+ using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementFinalOutput = typename MapArguments::ElementA; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // - using ElementFinalOutput = typename MapArguments::ElementA; + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; + typename EpilogueOutputOp::Params output_op; - // Type definitions about the mainloop. 
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; + // Only used by device-level operator + GemmCoord* host_problem_sizes; - using ProblemVisitor - = GemmGroupedProblemVisitor; + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; // - // Structures + // Methods // - /// Argument structure - struct Arguments - { - - // - // Data members - // - - GemmCoord* problem_sizes; - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // splitK - int split_k_slices; - int64_t* splitk_buffer_offsets; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0) - , threadblock_count(0) - , ptr_A(nullptr) - , 
ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, - typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, - ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, - typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, - typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, - int64_t* splitk_buffer_offsets) - : problem_sizes(problem_sizes) - , problem_count(problem_count) - , threadblock_count(threadblock_count) - , output_op(output_op) - , ptr_A(ptr_A) - , ptr_B(ptr_B) - , ptr_C(ptr_C) - , ptr_D(ptr_D) - , lda(lda) - , ldb(ldb) - , ldc(ldc) - , ldd(ldd) - , host_problem_sizes(host_problem_sizes) - , split_k_slices(split_k_slices) - , splitk_buffer_offsets(splitk_buffer_offsets) - { - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord* problem_sizes, + int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + ElementA** ptr_A, + ElementB** ptr_B, + ElementFinalOutput** ptr_C, + ElementFinalOutput** ptr_D, + typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, + typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, + GemmCoord* 
host_problem_sizes, + int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes), + problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + host_problem_sizes(host_problem_sizes), + split_k_slices(split_k_slices), + splitk_buffer_offsets(splitk_buffer_offsets) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; - /// Parameters structure - struct Params - { - - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - ElementC* ptr_C_split; - ElementC* ptr_D_split; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // - // Methods - // - - // splitk - GemmCoord grid_tiled_shape; - int swizzle_log_tile; - int gemm_k_size; - GemmCoord* host_problem_sizes; - int split_k_slices; - int64_t* splitk_buffer_offsets; - - CUTLASS_HOST_DEVICE - Params() - : ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , ptr_C_split(nullptr) - , ptr_D_split(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , swizzle_log_tile(0) - 
, gemm_k_size(0) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) - , host_problem_sizes(args.host_problem_sizes) - , threadblock_count(args.threadblock_count) - , output_op(args.output_op) - , ptr_A(args.ptr_A) - , ptr_B(args.ptr_B) - , ptr_C(args.ptr_C) - , ptr_D(args.ptr_D) - , ptr_C_split((ElementC*) workspace) - , ptr_D_split((ElementC*) workspace) - , lda(args.lda) - , ldb(args.ldb) - , ldc(args.ldc) - , ldd(args.ldd) - , split_k_slices(args.split_k_slices) - , splitk_buffer_offsets(args.splitk_buffer_offsets) - { - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); - swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); - - // only support same k - int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; - int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); - - gemm_k_size = gemm_k_iterations * Mma::Shape::kK; - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - - problem_visitor = - typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - ptr_C_split = workspace; - ptr_D_split = workspace; - - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldd = args.ldd; - } - }; - - /// Shared memory storage structure - struct SharedStorage - { - union - { - typename Mma::SharedStorage main_loop; - typename 
Epilogue::SharedStorage epilogue; - } kernel; - - // ProblemVisitor shared storage can't be overlapped with others - typename ProblemVisitor::SharedStorage problem_visitor; - }; - -public: // // Methods // - CUTLASS_DEVICE - SplitkGemmGrouped() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_C_split(nullptr), + ptr_D_split(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + swizzle_log_tile(0), + gemm_k_size(0), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor( + args.problem_sizes, args.problem_count, workspace, tile_count), + host_problem_sizes(args.host_problem_sizes), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D), + ptr_C_split((ElementC*)workspace), + ptr_D_split((ElementC*)workspace), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldd(args.ldd), + split_k_slices(args.split_k_slices), + splitk_buffer_offsets(args.splitk_buffer_offsets) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.host_problem_sizes[0], + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + // only support same k + int full_gemm_k_iterations = + args.host_problem_sizes[0].k() / Mma::Shape::kK; + 
int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; } - static Status can_implement(Arguments const& args) - { - return Status::kSuccess; + CUTLASS_HOST_DEVICE + void update(Arguments const& args, + void* workspace = nullptr, + int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params( + args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = workspace; + ptr_D_split = workspace; + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + SplitkGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to + // implement a 'transposed' GEMM that computes the transposed problems. 
+ // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { - - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - // - // Problem visitor. - // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) - { - - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - // Load element pointers. Exchange pointers and strides if working on the transpose - ElementA* ptr_A - = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); - typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); - - ElementB* ptr_B - = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); - typename LayoutB::LongIndex ldm_B = (kTransposed ? 
params.lda[problem_idx] : params.ldb[problem_idx]); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, - int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k; - if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) - { - problem_size_k = problem_size.k(); - } - else - { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
- int warp_idx = canonical_warp_idx_sync(); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - - // Wait for all threads to finish their epilogue phases from the previous tile. - __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = params.ptr_C_split; - ElementC* ptr_D = params.ptr_D_split; - - LayoutC layout_C(params.ldc[problem_idx]); - LayoutC layout_D(params.ldd[problem_idx]); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - - // assume identity swizzle - MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); - - iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); - iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + // + // Problem visitor. 
+ // + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + // Load element pointers. Exchange pointers and strides if working on the + // transpose + ElementA* ptr_A = reinterpret_cast(( + kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = + (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB* ptr_B = reinterpret_cast(( + kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = + (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + cutlass::gemm::GemmCoord threadblock_offset( + int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_offset.n()}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { + problem_size_k = problem_size.k(); + } else { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + // Compute threadblock-scoped matrix multiply-add + int 
gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / + Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), + ptr_B, + {problem_size_k, problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Wait for all threads to finish their epilogue phases from the previous + // tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), + threadblock_offset.n()); + + // Tile iterator loading from source tensor. 
+ typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); + + iterator_C.add_pointer_offset( + problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); + iterator_D.add_pointer_offset( + problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + Epilogue epilogue( + shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); - // Next tile - problem_visitor.advance(gridDim.x); - } + // Next tile + problem_visitor.advance(gridDim.x); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h index ed5e3e4daf8..a268861a0dd 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,18 +19,15 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/interleaved_numeric_conversion.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -// We need to distinguish here, since we want volta support. It is too much effort -// to write shared memory iterators that are probably needed for volta to function -// properly. As a result, we allow converters both after the LDG (for volta) and after -// the LDS for Turing+. +// We need to distinguish here, since we want volta support. It is too much +// effort to write shared memory iterators that are probably needed for volta to +// function properly. As a result, we allow converters both after the LDG (for +// volta) and after the LDS for Turing+. 
template < /// Iterator for B matrix in global memory typename IteratorB, @@ -38,9 +35,7 @@ template < typename MmaOperator, /// Math operation perform by warp level operator typename MathOperator> -struct SetConverters -{ -}; +struct SetConverters {}; // Dequantize after LDG, so set transforms accordingly template < @@ -48,14 +43,16 @@ template < typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG - = FastInterleavedAndBiasedNumericArrayConverter; +struct SetConverters { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter< + typename MmaOperator::ArchMmaOperator::ElementB, + typename IteratorB::Element, + IteratorB::Fragment::kElements>; - using TransformAfterLDS = NumericArrayConverter; + using TransformAfterLDS = + NumericArrayConverter; }; // Dequantize after LDS, so set transforms accordingly @@ -65,14 +62,18 @@ template < typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG = NumericArrayConverter; +struct SetConverters { + using TransformAfterLDG = + NumericArrayConverter; - using TransformAfterLDS - = FastInterleavedAndBiasedNumericArrayConverter; + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter< + typename MmaOperator::ArchMmaOperator::ElementB, + typename TransformAfterLDG::result_type::Element, + MmaOperator::FragmentB::kElements>; }; //////////////////////////////////////////////////////////////////////////////// @@ -120,6 +121,6 @@ template < typename Enable = void> struct DqMma; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h index 17c6346553c..566cd379b2b 100644 --- 
a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,49 +27,77 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" #include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template +template struct DefaultScaleIteratorsMultistage; // Fine grained iterators -template -struct DefaultScaleIteratorsMultistage> -{ - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; - - using SmemIteratorScale = IteratorScale; +template +struct DefaultScaleIteratorsMultistage< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + Alignment>; + + using SmemIteratorScale = IteratorScale; }; // Per column iterators -template -struct DefaultScaleIteratorsMultistage> -{ - // ThreadMap for scale iterator - static_assert((MmaShape::kN % Alignment) == 0, ""); - -private: - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; - -public: - // Define iterators over tiles 
from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; - - using SmemIteratorScale = IteratorScale; +template +struct DefaultScaleIteratorsMultistage< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaShape::kN / Alignment, + Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; + + using SmemIteratorScale = IteratorScale; }; //////////////////////////////////////////////////////////////////////////////// @@ -111,69 +139,133 @@ template < typename Operator_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, - AccessTypeB>; - - using ScaleIterators = DefaultScaleIteratorsMultistage; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +struct DqMma= 80 && + !layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 
128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + ThreadMapB, + AccessTypeB>; + + using ScaleIterators = + DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementScale, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + 
OperatorInfo::QuantOp, + SharedMemoryClear>; }; // Specialization to handle column major interleave B @@ -214,89 +306,159 @@ template < typename Operator_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; - -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; - - using ScaleIterators = DefaultScaleIteratorsMultistage; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using 
ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +struct DqMma= 80 && + layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % 
ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + AccessTypeB>; + + using ScaleIterators = + DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementScale, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + OperatorInfo::QuantOp, + SharedMemoryClear>; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h index 345cd2eec9a..ba7f2863e08 100644 --- 
a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,58 +27,95 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" #include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template +template struct DefaultScaleIteratorsPipelined; // Fine grained iterators -template -struct DefaultScaleIteratorsPipelined> -{ -private: - using SmemScaleType = half_t; - -public: - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; - - using SmemIteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, - SmemScaleType, Layout, 0, Alignment>; +template +struct DefaultScaleIteratorsPipelined< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + private: + using SmemScaleType = half_t; + + public: + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + Alignment>; + + using SmemIteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + SmemScaleType, + Layout, + 0, + Alignment>; }; // Per 
column iterators -template -struct DefaultScaleIteratorsPipelined> -{ - static_assert((MmaShape::kN % Alignment) == 0, ""); - -private: - // ThreadMap for scale iterator - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; - using SmemScaleType = half_t; - -public: - // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; - - using SmemIteratorScale - = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, - Layout, 0, IteratorScaleThreadMap, Alignment>; +template +struct DefaultScaleIteratorsPipelined< + MmaShape, + Element, + Layout, + QuantOp, + Alignment, + std::enable_if_t> { + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaShape::kN / Alignment, + Alignment>; + using SmemScaleType = half_t; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + Element, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; + + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaShape::kN>, + SmemScaleType, + Layout, + 0, + IteratorScaleThreadMap, + Alignment>; }; //////////////////////////////////////////////////////////////////////////////// @@ -116,57 +153,110 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator_> -struct DqMma::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - 
using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); - - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - using ScaleIterators = DefaultScaleIteratorsPipelined; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == + WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, + ""); + + static constexpr bool DqAfterLDG = + platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform:: + conditional::type; + + // Define the MmaCore components + using 
MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = + SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + IteratorB, + typename MmaCore::SmemIteratorB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, + typename Converters::TransformAfterLDS, + OperatorInfo::QuantOp>; }; // Specialization to handle column major interleave B @@ -203,82 +293,140 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator_> -struct DqMma::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; - 
- // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; - -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; - - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap - = transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; - - using ScaleIterators = DefaultScaleIteratorsPipelined; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || + 
platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + + static constexpr bool DqAfterLDG = + platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform:: + conditional::type; + + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + kAlignmentB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) 
== 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = + SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + IteratorB, + typename MmaCore::SmemIteratorB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, + typename Converters::TransformAfterLDS, + OperatorInfo::QuantOp>; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index bc395d04db2..31915cf3898 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,18 +18,17 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int8 weight, mma pipelined (stage=2) template < /// Layout type for A matrix operand typename LayoutA, @@ -51,34 +50,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; 
//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight, mma pipelined (stage=2) template < /// Layout type for A matrix operand typename LayoutA, @@ -100,35 +126,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage -/// (stage>=3) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int8 weight, mma multistage (stage>=3) template < /// Layout type for A matrix operand typename LayoutA, @@ -154,36 +206,64 @@ template 
< int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage -/// (stage>=3) +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight, mma multistage (stage>=3) template < /// Layout type for A matrix operand typename LayoutA, @@ -209,37 +289,65 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define 
iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; #ifdef ENABLE_FP8 //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage -/// (stage>=3) +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation +/// & int4 weight, mma multistage (stage>=3) template < /// Layout type for A matrix operand typename LayoutA, @@ -265,36 +373,65 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename 
Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; #endif -// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps +// avoid reg spills on large tile when not enough shared mem is present to do 3+ +// stage template < /// Layout type for A matrix operand typename LayoutA, @@ -320,39 +457,86 @@ template < bool GatherA, /// Gather operand B by using an index array bool GatherB> -struct DefaultMma -{ - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, - GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, - GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename 
cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage +/// Specialization for row-major output (OperatorClass TensorOp), fbf16 +/// activation & int2 weight, mma multistage template < /// Layout type for A matrix operand @@ -375,41 +559,50 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -437,44 +630,55 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 5d2c3117048..1ff648e0bfa 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" namespace cutlass { namespace gemm { @@ -27,7 +27,8 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & bf16 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -55,40 +56,85 @@ template < bool GatherA, /// Gather operand B by using an index array bool GatherB> -struct DefaultMma -{ - -private: - // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. 
- static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; - using MmaElementA = typename platform::conditional::type; - using MmaElementB = typename platform::conditional::type; - -public: - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +struct DefaultMma { + private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we + // convert before STS. 
+ static constexpr bool arch_has_bf16_mma = + ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform:: + conditional::type; + using MmaElementB = typename platform:: + conditional::type; + + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA, + GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB, + GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaPipelined; }; -// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps +// avoid reg spills on large tile when not enough shared mem is present to do 3+ +// stage template < /// Layout type for A matrix operand typename LayoutA, @@ -114,40 +160,86 @@ template < bool GatherA, /// Gather operand B by using an index array bool GatherB> -struct DefaultMma -{ - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA, GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB, GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + 
cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & int8 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -169,34 +261,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +/// 
Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & int4 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -218,34 +337,61 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & int8 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -271,35 +417,64 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename 
Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight template < /// Layout type for A matrix operand typename LayoutA, @@ -325,35 +500,64 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define 
the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; //////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fbf16 activation & int2 weight, mma multistage +/// Specialization for row-major output (OperatorClass TensorOp), fbf16 +/// activation & int2 weight, mma multistage template < /// Layout type for A matrix operand @@ -376,41 +580,50 @@ template < typename InstructionShape, /// Operation performed by GEMM typename Operator> -struct DefaultMma -{ - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -438,44 +651,55 @@ template < int kStages, /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; +struct DefaultMma { + private: + using Mma = DefaultWint2xMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h new file mode 100644 index 00000000000..58fd5644169 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * 
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Partial specialization: +/// +/// A: row-major +/// B: uint2b_t, column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = uint2b_t; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; 
+ + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access of B + static constexpr int kMaxThreadsForB = + (Shape::kK * Shape::kN * sizeof_bits::value) / + kAccessSizeInBits; + static constexpr int kThreadsForB = + kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, + Shape::kK>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, + Shape::kK>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, + ElementA, + SmemLayoutA, + 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + kThreadsForB, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = 
transform::threadblock::RegularTileAccessIterator< + MatrixShape, + ElementB, + SmemLayoutB, + 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = + typename cutlass::gemm::warp::DefaultMmaTensorOp::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, + WarpCount::kK>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h new file mode 100644 index 00000000000..e4684b610b1 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -0,0 +1,329 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/gemm/threadblock/default_mma_core.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultQuantParamsIterators { + private: + static constexpr int kAlignment = 128 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) ? 1 + : (ThreadblockShape::kK + GroupSize - 1) / GroupSize; + static constexpr int kColumns = ThreadblockShape::kN; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, + kAlignment>; + + public: + using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< + MatrixShape, + ElementT, + layout::RowMajor, + 0, + IteratorThreadMap, + kAlignment>; + using SmemIterator = Iterator; +}; + +template +struct DefaultQuantParamsIterators { + private: + static constexpr int kAlignment = 32 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) + ? 1 + : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize); + static constexpr int kColumns = + (GroupSize == -1) ? 
ThreadblockShape::kN : ThreadblockShape::kN * 2; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, + kAlignment>; + + public: + using AccessType = cutlass::Array; + using Iterator = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + MatrixShape, + uint4b_t, + layout::RowMajor, + 0, + IteratorThreadMap, + AccessType>; + + using SmemIterator = Iterator; +}; + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +struct DefaultWint2xMma; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B 
matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DefaultWint2xMma { + public: + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value, + "Element B must be uint2b_t"); + + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + using ElementSuperScale = ElementA; + using ElementLocalScale = uint4b_t; + using ElementCodeScaleZp = float; + + static constexpr int kGroupSize = 64; + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + private: + static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int kRowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), + "ThreadblockShape must be disivle by kColumnsInterleaved"); + static_assert(kRowsPerTile == MmaCore::Shape::kK, ""); + + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement; + static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), ""); + + using IteratorShapeB = MatrixShape; + using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + ThreadMapB::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + IteratorShapeB, + ElementB, + layout::ColumnMajor, + 0, + InterleavedThreadMapB, + AccessTypeB>; + + private: + // Define iterators over tiles from extra quant params for B operand + using IteratorSuperScale = + typename DefaultQuantParamsIterators::Iterator; + using SmemIteratorSuperScale = + typename DefaultQuantParamsIterators::SmemIterator; + + using IteratorLocalScale = + typename 
DefaultQuantParamsIterators::Iterator; + using SmemIteratorLocalScale = + typename DefaultQuantParamsIterators::SmemIterator; + + using IteratorCodeScaleZp = + typename DefaultQuantParamsIterators::Iterator; + using SmemIteratorCodeScaleZp = + typename DefaultQuantParamsIterators::Iterator; + + public: + using QuantParamsAccessor = Wint2ParamsAccessor; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + QuantParamsAccessor, + SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h index 1fb7f7eb28f..51e410aad35 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file @@ -47,30 +48,33 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// // SFINAE trick so I can keep the same loop code for Volta and dispatch to the -// correct warp level mma. On volta, all data is stored to shared memory as FP16. +// correct warp level mma. On volta, all data is stored to shared memory as +// FP16. template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, - int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C); +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, + typename WarpMma::FragmentB const& B, + typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) { + warp_mma(D, A, B, C); } template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C, warp_tileB_k_offset); +CUTLASS_DEVICE void run_warp_mma( + WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) { + warp_mma(D, A, B, C, warp_tileB_k_offset); } //////////////////////////////////////////////////////////////////////////////// @@ -90,168 +94,169 @@ template < WeightOnlyQuantOp DequantOp, /// Used for partial specialization, typename Enable = bool> -class DqMmaBase -{ -public: - ///< Size of the Gemm problem - 
concept: gemm::GemmShape<> - using Shape = Shape_; +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; - ///< Policy describing tuning details - using Policy = Policy_; + ///< Policy describing tuning details + using Policy = Policy_; - ///< Type of the scale to be loaded - using ElementScale = ElementScale_; + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; - static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); - // Finegrained scales get streamed in via cp.async - static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; - // We always have scales. - static constexpr int ScaleElementsPerStage = Shape::kN; - // We sometimes have a bias - static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = + hasZero(DequantOp) ? Shape::kN : 0; - // - // Dependent types - // + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. 
- using WarpGemm = typename Policy::Operator::Shape; + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); - /// Number of warp-level GEMM operations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; - static constexpr int kNumKIterationsPerWarpBLoad - = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = + kWarpGemmIterations / kNumKIterationsPerWarpBLoad; - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + /// Number of stages + static int const kStages = Stages; - /// Number of stages - static int const kStages = Stages; + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; - /// Tensor reference to the A operand - using TensorRefA = TensorRef; + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; - /// Tensor reference to the B operand - using TensorRefB = TensorRef; + // + // Nested structs + // + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: // - // Nested structs + // Type definitions // - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage - { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA - = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB - = 
MatrixShape; - - /// Shape of the shared memory buffer for the scales for the B matrix. - using ShapeScale = MatrixShape; - /// Shape of the shared memory buffer for the biases of the B matrix. - using ShapeZero = MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_scale; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_zero; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. 
+ using ShapeZero = MatrixShape; + + public: // // Data members // - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; -public: - /// Construct from tensor references + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix CUTLASS_DEVICE - DqMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) - , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + 
typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h index 3c4036dd8cc..4a4a3137bed 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file @@ -48,12 +49,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -102,9 +100,9 @@ template < typename Enable = void> class DqMmaMultistage; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass #include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" #include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h index e87a51b22c8..02b3b118407 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file @@ -48,12 +49,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -98,460 +96,587 @@ template < WeightOnlyQuantOp QuantOp_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static 
ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); - - /// Internal structure exposed for introspection. - struct Detail - { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared 
memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - /// The group size for quantization - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - 
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == 
Shape::kN, ""); + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to + /// shared memory 
+ SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), + {Base::kStages, Shape::kN}, + thread_idx, + group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& 
iterator_scale, + int stage = -1, + int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = + iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = + iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = + reinterpret_cast( + this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr = + reinterpret_cast( + this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * + IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async( + smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + cutlass::arch::cp_async( + smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); } - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) - { - static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, + "Unsupported k tile shape, can only be 64 or 128"); + } + } - typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); - typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + iterator_scale.row_groupsize64_++; - typename IteratorScale::AccessType* smem_scale_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_scale()); - typename IteratorScale::AccessType* smem_zero_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } - int const kSrcBytes = 
sizeof_bits::value * IteratorScale::kAlignment / 8; + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); - cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); - if (gmem_zero_ptr != nullptr) - { - cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); - } + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; } - iterator_scale.row_groupsize64_++; - - this->smem_iterator_scale_.add_tile_offset({1, 0}); + ++this->smem_iterator_A_; + } } - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& 
iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - // Async Copy for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - int const 
kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; } + ++this->smem_iterator_B_; + } } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS 
lds_converter; - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); + TransformBAfterLDS lds_converter; - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - ++iterator_A; - } + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); - ++this->smem_iterator_A_; - } + ++iterator_A; + } - iterator_B.set_iteration_index(0); - 
this->smem_iterator_B_.set_iteration_index(0); + ++this->smem_iterator_A_; + } - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - ++iterator_B; - } + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - ++this->smem_iterator_B_; - } + ++iterator_B; + } - copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + ++this->smem_iterator_B_; + } - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); - // Defines the boundary of a stage of cp.async. 
- cutlass::arch::cp_async_fence(); - } + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); - // Perform accumulation in the 'd' output operand - accum = src_accum; + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // + // Perform accumulation in the 'd' output operand + accum = src_accum; - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - typename IteratorA::AccessType zero_A; - zero_A.clear(); + typename IteratorA::AccessType zero_A; + zero_A.clear(); - last_smem_iterator_A.set_iteration_index(0); + last_smem_iterator_A.set_iteration_index(0); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); + *dst_ptr = zero_A; - *dst_ptr = zero_A; + ++last_smem_iterator_A; + } - ++last_smem_iterator_A; - } + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + 
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { + *dst_ptr = zero_B; - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); + ++last_smem_iterator_B; + } + } - *dst_ptr = zero_B; + // Wait until we have at least one committed global fetch stage. + // (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[0]); printf("#### + // warp_frag_b_load [0] bid:%d-%d-%d," + // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // frag_b_reg_ptr[0], + // frag_b_reg_ptr[1], + // frag_b_reg_ptr[2], + // 
frag_b_reg_ptr[3], + // frag_b_reg_ptr[4], + // frag_b_reg_ptr[5], + // frag_b_reg_ptr[6], + // frag_b_reg_ptr[7] + // ); + // } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; - ++last_smem_iterator_B; - } - } + // + // Mainloop + // - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - typename Dequantizer::FragmentZero warp_frag_zeros; + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } - Operator warp_mma; + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize( + converted_frag_B, warp_frag_scales, warp_frag_zeros); - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + using FragmentOperandB = + cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = + TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + using Converter = cutlass::NumericArrayConverter; - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + FragmentOperandB converted_frag_B_operand = + Converter::convert(converted_frag_B); // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = 
reinterpret_cast(&warp_frag_B[0]); - // printf("#### warp_frag_b_load [0] bid:%d-%d-%d," - // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) + // % 2]); uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); printf("#### + // after lds_converter bid:%d-%d-%d" + // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" + // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", // blockIdx.x,blockIdx.y,blockIdx.z, + // ((warp_tileB_k_load_offset) % 2), // frag_b_reg_ptr[0], // frag_b_reg_ptr[1], // frag_b_reg_ptr[2], @@ -559,195 +684,124 @@ class DqMmaMultistagewarp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. 
- - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - - - // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) % 2]); - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // printf("#### after lds_converter bid:%d-%d-%d" - // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" - // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // ((warp_tileB_k_load_offset) % 2), - // frag_b_reg_ptr[0], - // frag_b_reg_ptr[1], - // frag_b_reg_ptr[2], - // frag_b_reg_ptr[3], - // frag_b_reg_ptr[4], - // frag_b_reg_ptr[5], - // frag_b_reg_ptr[6], - // 
frag_b_reg_ptr[7], - // converted_frag_B_reg_ptr[0], - // converted_frag_B_reg_ptr[1], - // converted_frag_B_reg_ptr[2], - // converted_frag_B_reg_ptr[3], - // converted_frag_B_reg_ptr[4], - // converted_frag_B_reg_ptr[5], - // converted_frag_B_reg_ptr[6], - // converted_frag_B_reg_ptr[7] - // ); - // } - - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // This is the first group of a given stage, so we issue the loads for the B scales immediately. - if (group_start_iteration_B == 0) - { - copy_scales_and_advance(iterator_scale); - } - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - } - } - - // Load the scale needed for the next tile iteration. - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); - // Update internal pointer to set of scales in shared memory. 
- warp_dequantizer_.add_pointer_offset(Shape::kN); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B_operand, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for + // the B scales immediately. + if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. 
+ // (#uncommitted = Base::kStages - 1 - #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. 
+ warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h index 83efdc5cb01..6371da633ca 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file @@ -48,12 +49,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -98,550 +96,605 @@ template < WeightOnlyQuantOp QuantOp_, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of 
accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from 
tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. Not used by this main loop since it + ///< assumes per-column + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + 
IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - // - // Dependent types - // - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. 
- struct Detail - { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } - /// Iterator to write threadblock-scoped tile of A operand to shared 
memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< Group size for quantization. Not used by this main loop since it assumes per-column - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - 
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + ++this->smem_iterator_A_; + } } - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - // Async Copy for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; 
j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; } + ++this->smem_iterator_B_; + } } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB 
iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. - FragmentScale tb_frag_scales; - tb_frag_scales.clear(); - iterator_scale.load(tb_frag_scales); - this->smem_iterator_scale_.store(tb_frag_scales); - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + 
IteratorA::kAccessesPerVector / 8; - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + ++iterator_A; + } - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + ++this->smem_iterator_A_; + } - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); - ++iterator_A; - } + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - ++this->smem_iterator_A_; - } + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); + ++iterator_B; + } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * 
IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + ++this->smem_iterator_B_; + } - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); - ++iterator_B; - } + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); - ++this->smem_iterator_B_; - } + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); + // Perform accumulation in the 'd' output operand + accum = src_accum; - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - // Perform accumulation in the 'd' output operand - accum = src_accum; + typename IteratorA::AccessType zero_A; + zero_A.clear(); - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. 
- // + last_smem_iterator_A.set_iteration_index(0); - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + *dst_ptr = zero_A; - typename IteratorA::AccessType zero_A; - zero_A.clear(); + ++last_smem_iterator_A; + } - last_smem_iterator_A.set_iteration_index(0); + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); - *dst_ptr = zero_A; + *dst_ptr = zero_B; - ++last_smem_iterator_A; - } + ++last_smem_iterator_B; + } + } - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; + // Wait until we have at least one committed global fetch stage. 
+ // (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { + Operator warp_mma; - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - *dst_ptr = zero_B; + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); - ++last_smem_iterator_B; - } - } + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; - Operator warp_mma; + // + // Mainloop + // - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - warp_dequantizer_.load(warp_frag_scales); + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. 
- - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A 
= (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + using FragmentOperandB = + cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = + cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = + TransformBAfterLDS::result_type::kElements; + 
static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = + Converter::convert(converted_frag_B); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B_operand, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. 
+ // (#uncommitted = Base::kStages - 1 - #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h index bd3e38971b0..dd7e8ae4b78 100644 --- 
a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -53,27 +54,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -98,9 +99,9 @@ template < typename Enable = void> class DqMmaPipelined; -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass #include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" #include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h" diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h index 50bdd0d85b0..01fe06329be 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -53,27 +54,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -94,393 +95,442 @@ template < typename TransformBAfterLDS_, /// The quantization operator being used WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; +class DqMmaPipelined> + : public DqMmaBase { 
+ public: + ///< Base class + using Base = DqMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = + warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered + // pipeline) + static_assert((Base::kStages == 2), + 
"DqMmaPipelined requires kStages set to value 2"); + + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to + /// shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal + ///< use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + 
Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), + {Base::kStages, Shape::kN}, + thread_idx, + group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) { + using TransformScale = + NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load( + tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + arch::global_load( + tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid()); + } - // - // Dependent types - // + typename TransformScale::result_type tb_frag_scales_fp16 = + transformScale(tb_frag_scales); + typename 
TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); + + auto frag_scale_ptr_fp16 = + reinterpret_cast( + &tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = + reinterpret_cast( + &tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); + + if (iterator_scale.valid()) { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, + frag_scale_ptr_fp16); + + if (gmem_zero_ptr != nullptr) { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, + frag_zero_ptr_fp16); + } + } - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - - static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using 
WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using WarpFragmentScale = typename Dequantizer::FragmentScale; - using WarpFragmentZero = typename Dequantizer::FragmentZero; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< The group size for quantization - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { - 
- // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, + "Unsupported k tile shape, can only be 64 or 128"); + } } - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale) - { - using TransformScale = NumericArrayConverter; + iterator_scale.row_groupsize64_++; - FragmentScale tb_frag_scales; - FragmentScale tb_frag_zeros; - tb_frag_scales.clear(); - tb_frag_zeros.clear(); + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } - TransformScale transformScale; + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator 
over B operand in global memory + IteratorScale + iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile - using FragmentElement = typename FragmentScale::Element; + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; - auto gmem_scale_ptr = iterator_scale.get_scale(); - auto gmem_zero_ptr = iterator_scale.get_zero(); + using TransformA = NumericArrayConverter; - arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + // These transforms are mainly to handle when we have bfloat activations and + // weights in GMEM and want to issue HMMA on architectures older than + // Ampere. We will convert to FP16 before STS. + TransformA transformA; - if (gmem_zero_ptr != nullptr) - { - arch::global_load( - tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid()); - } + // Perform accumulation in the 'd' output operand + accum = src_accum; - typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); - typename TransformScale::result_type tb_frag_zeros_fp16; - if (gmem_zero_ptr != nullptr) - tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); - - auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); - auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); - auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); - auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); - - if (iterator_scale.valid()) - { - auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); - arch::shared_store(smem_offset, frag_scale_ptr_fp16); - - if (gmem_zero_ptr != nullptr) - { - smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); - arch::shared_store(smem_offset, frag_zero_ptr_fp16); - } - } + FragmentA tb_frag_A; + FragmentB tb_frag_B; - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) 
- { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } - } + tb_frag_A.clear(); + tb_frag_B.clear(); - iterator_scale.row_groupsize64_++; + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); - this->smem_iterator_scale_.add_tile_offset({1, 0}); - } + ++iterator_A; + ++iterator_B; - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; - using TransformA - = NumericArrayConverter; + copy_scales_and_advance(iterator_scale); - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
- TransformA transformA; + __syncthreads(); - // Perform accumulation in the 'd' output operand - accum = src_accum; + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; - FragmentA tb_frag_A; - FragmentB tb_frag_B; + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - tb_frag_A.clear(); - tb_frag_B.clear(); + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); - ++iterator_A; - ++iterator_B; + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + Operator warp_mma; - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; + int smem_write_stage_idx = 1; - copy_scales_and_advance(iterator_scale); + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); - __syncthreads(); + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency requirement). - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - WarpFragmentScale warp_frag_scales; - WarpFragmentZero warp_frag_zero; + // + // Mainloop + // - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + // Note: The main loop does not support Base::kWarpGemmIterations == 2. 
+ CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate + // the load for the next 
fragment. + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - iterator_scale.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. 
- - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - copy_scales_and_advance(iterator_scale); - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - iterator_scale.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } - - // Load the scales needed for the next tile iteration - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); - // Update internal pointer to the set of scales in shared memory - warp_dequantizer_.add_pointer_offset(Shape::kN); + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize( + converted_frag_B, warp_frag_scales, warp_frag_zero); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + 
warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h index 316ea9f80a9..e6f512edfeb 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file @@ -53,27 +54,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -94,306 +95,361 @@ template < typename TransformBAfterLDS_, /// The quantization operator being used WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - 
using Policy = Policy_; ///< Policy describing tuning details - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; +class DqMmaPipelined> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using 
Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer< + Operator, + typename Base::WarpGemm, + Operand::kB, + typename SmemIteratorScale::Fragment::Element, + LayoutScale, + 32, + QuantOp>; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // statically assert kStages for DqMmaPipelined is two (Double-buffered + // pipeline) + static_assert((Base::kStages == 2), + "DqMmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + int const + group_size, ///< Will not be used, just to adapt to 
finegrained + ///< modifications and make the compilation successful. + ///< Because DqMmaPipelined is only enabled for sm<80, so + ///< even if this argument is not added, it does not + ///< affect compilation for sm>=80. + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory 
+ IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale + iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile // - // Dependent types + // Prologue // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of 
A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation - ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this - ///< argument is not added, it does not affect compilation for sm>=80. - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add 
per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } + using TransformA = NumericArrayConverter; + + using TransformScale = + NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and + // weights in GMEM and want to issue HMMA on architectures older than + // Ampere. We will convert to FP16 before STS. + TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; - using TransformA - = NumericArrayConverter; + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); - using TransformScale = NumericArrayConverter; + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
- TransformA transformA; - TransformScale transformScale; + ++iterator_A; + ++iterator_B; - // Perform accumulation in the 'd' output operand - accum = src_accum; + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); - FragmentA tb_frag_A; - FragmentB tb_frag_B; - FragmentScale tb_frag_scales; + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; - using WarpFragmentScale = typename Dequantizer::FragmentScale; - WarpFragmentScale warp_frag_scales; + __syncthreads(); - tb_frag_A.clear(); - tb_frag_B.clear(); - tb_frag_scales.clear(); + warp_dequantizer_.load(warp_frag_scales); - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - iterator_scale.load(tb_frag_scales); + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; - ++iterator_A; - ++iterator_B; + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - __syncthreads(); + Operator warp_mma; - warp_dequantizer_.load(warp_frag_scales); + int smem_write_stage_idx = 1; - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); - this->warp_tile_iterator_A_.set_kgroup_index(0); - 
this->warp_tile_iterator_B_.set_kgroup_index(0); + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + } + + smem_write_stage_idx ^= 1; + } + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - 
iterator_B.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } + + int const warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate + // the load for the next fragment. 
+ if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + } } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index 6dd55b647a0..9b118add9aa 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -63,10 +63,10 @@ template < typename Policy_, /// Number of stages, int Stages, - /// Used for partial specialization - typename Enable = bool> + /// Size of extra quantized params + typename QuantParamsShape> class Wint2xMmaBase { -public: + public: ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; @@ -85,14 +85,23 @@ 
class Wint2xMmaBase { using WarpGemm = typename Policy::Operator::Shape; /// Shape describing the number of warps filling the CTA - using WarpCount = - GemmShape; + using WarpCount = GemmShape; - /// Number of warp-level GEMM oeprations + /// Number of warp-level GEMM operations static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + /// Number of warp-level GEMM operations per load for B + static constexpr int kWarpGemmIterationsPerLoadForB = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), ""); + + static constexpr int kWarpLoadIterationsForB = + kWarpGemmIterations / kWarpGemmIterationsPerLoadForB; + /// Number of stages static int const kStages = Stages; @@ -104,8 +113,6 @@ class Wint2xMmaBase { using TensorRefB = TensorRef; - // using TensorRefZippedB = TensorRef; - static_assert(kWarpGemmIterations > 1, "The pipelined structure requires at least two warp-level " "GEMM operations."); @@ -119,7 +126,7 @@ class Wint2xMmaBase { /// Shared storage object needed by threadblock-scoped GEMM class SharedStorage { - public: + public: // // Type definitions // @@ -130,22 +137,13 @@ class Wint2xMmaBase { Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; /// Shape of the B matrix operand in shared memory - using ShapeB = MatrixShape; - // w uint8; local_scale uint8; - constexpr static int kZippedRowsPerStages = - Shape::kK / 4 + (Shape::kK + 127) / 128; - - // code_scale float; code_zp float; super_scale ElementB - constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + - sizeof_bits::value / 8; - - using ZippedShapeB = MatrixShape; - - using NopaddingShapeB = MatrixShape; + /// Shape of all quant params in shared memory + using QuantParamsShapeB = QuantParamsShape; - public: + public: // // Data members // @@ -156,14 +154,10 @@ class Wint2xMmaBase { /// Buffer for B operand AlignedBuffer operand_B; - /// Buffer for 
quanted B operand - AlignedBuffer operand_zipped_B; + /// Buffer for extra quant params of B operand + AlignedBuffer operand_quant_params_B; - /// Buffer for unzip B operand - AlignedBuffer - operand_unzip_B; - - public: + public: // // Methods // @@ -191,17 +185,9 @@ class Wint2xMmaBase { TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } - - CUTLASS_HOST_DEVICE - uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); } - - CUTLASS_HOST_DEVICE - typename Operator::ElementB *operand_unzip_B_ptr() { - return operand_unzip_B.data(); - } }; -protected: + protected: // // Data members // @@ -212,7 +198,7 @@ class Wint2xMmaBase { /// Iterator to load a warp-scoped tile of B operand from shared memory typename Operator::IteratorB warp_tile_iterator_B_; -public: + public: /// Construct from tensor references CUTLASS_DEVICE Wint2xMmaBase( @@ -230,8 +216,8 @@ class Wint2xMmaBase { ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace gemm -} // namespace cutlass +} // namespace threadblock +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 9531b01a7c0..245a89c152c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file @@ -45,7 +46,8 @@ #include "cutlass_extensions/arch/memory_copy_sm80.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,15 +88,21 @@ template < typename Policy_, /// Number of stages, int Stages, + /// Accessor for extra quantized params + typename QuantParamsAccessor_, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> -class Wint2xMmaMultistage : - public Wint2xMmaBase { -public: + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +class Wint2xMmaMultistage + : public Wint2xMmaBase { + public: ///< Base class - using Base = Wint2xMmaBase; + using Base = Wint2xMmaBase; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; ///< Iterates over tiles of A operand in global memory @@ -107,8 +115,11 @@ class Wint2xMmaMultistage : using LayoutC = LayoutC_; ///< Policy describing tuning details using Policy 
= Policy_; + /// Accessor for extra quantized params + using QuantParamsAccessor = QuantParamsAccessor_; + using QuantArguments = typename QuantParamsAccessor::Arguments; - using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; + static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK; using SmemIteratorA = SmemIteratorA_; using SmemIteratorB = SmemIteratorB_; @@ -129,6 +140,20 @@ class Wint2xMmaMultistage : /// Minimum architecture is Sm80 to support cp.async using ArchTag = arch::Sm80; + // using LayoutScale = typename + // QuantParamsAccessor::IteratorSuperScale::Layout; + using LayoutScale = layout::RowMajor; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using WarpDequantizer = warp::MmaTensorOpWin2xDequantizer< + Operator, + typename Base::WarpGemm, + Operand::kB, + typename WarpTransformedFragmentB::Element, + LayoutScale, + QuantParamsAccessor::kGroupSize>; + static_assert(sizeof(WarpDequantizer) > 0, + "WarpDequantizer template instantiation failed"); + /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -137,7 +162,6 @@ class Wint2xMmaMultistage : /// Internal structure exposed for introspection. 
struct Detail { - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -151,44 +175,68 @@ class Wint2xMmaMultistage : /// Number of cp.async instructions to load on group of operand A static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical - // accuracy, where each mainloop iteration first accumulates into a temporary - // set of freshly-cleared accumulators, which are subsequently added to the - // final accumulator set. - static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved + // numerical accuracy, where each mainloop iteration first accumulates into + // a temporary set of freshly-cleared accumulators, which are subsequently + // added to the final accumulator set. 
+ static bool const kStagedAccumulation = + arch::detail::UseStagedAccumulation::value; }; private: - // Structure encapsulating pipeline state live from one iteration to the next struct PipeState { - using WarpLoadedFragmentA = typename Operator::FragmentA; using WarpLoadedFragmentB = typename Operator::FragmentB; using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale; + using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp; + using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale; + /// Temporary accumulator to facilitate staged-accumulation FragmentC tmp_accum_; - /// Pair of A fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentA warp_loaded_frag_A_[2]; - WarpTransformedFragmentA warp_transformed_frag_A_[2]; + /// Pair of A fragments used to overlap shared memory loads and math + /// instructions + WarpTransformedFragmentA warp_frag_A_[2]; - /// Pair of B fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentB warp_loaded_frag_B_[2]; - WarpTransformedFragmentB warp_transformed_frag_B_[2]; + /// Pair of B fragments used to overlap shared memory loads and math + /// instructions + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_frag_B_[2]; + + /// channel-wise quant params + FragmentCodeScaleZp warp_frag_code_scale_; + FragmentCodeScaleZp warp_frag_code_zp_; + FragmentSuperScale warp_frag_super_scale_; + + /// group-wise quant params + FragmentLocalScale warp_frag_local_scale_; }; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; - private: + static constexpr bool IsTileInterleaveLayout = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + 
static_assert(!IsTileInterleaveLayout || + (IsTileInterleaveLayout && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + private: // // Data members // @@ -202,19 +250,19 @@ class Wint2xMmaMultistage : /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB smem_iterator_B_; + /// Accessor for extra quant params for B + QuantParamsAccessor quant_params_accessor_B_; + + // Wint2 unzip operator + WarpDequantizer warp_dequantizer_; + /// Shared memory write stage index int smem_write_stage_idx_; /// Shared memory read stage index int smem_read_stage_idx_; - uint8_t* column_wise_smem_ptr_B_; - - uint8_t* smem_zipped_ptr_B_; - int smem_zipped_bytes_per_stage_B_; - -public: - + public: /// Construct from tensor references CUTLASS_DEVICE Wint2xMmaMultistage( @@ -225,14 +273,24 @@ class Wint2xMmaMultistage : ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_write_stage_idx_(0), - smem_read_stage_idx_(0) - { + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), + thread_idx, + warp_idx, + lane_idx), + warp_dequantizer_( + quant_params_accessor_B_.super_scale_ref(), + quant_params_accessor_B_.local_scale_ref(), + quant_params_accessor_B_.code_scale_ref(), + quant_params_accessor_B_.code_zp_ref(), + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: // _m: the warp's 
position within the threadblock along the M dimension @@ -250,44 +308,37 @@ class Wint2xMmaMultistage : {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); this->warp_tile_iterator_B_.add_tile_offset( {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - - column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); - - smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; - smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; } /// Advance shared memory read-iterators to the next stage CUTLASS_DEVICE - void advance_smem_read_stage() - { + void advance_smem_read_stage() { ++smem_read_stage_idx_; if (smem_read_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory - this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpLoadIterationsForB, + 0}); smem_read_stage_idx_ = 0; } - this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); } - /// Advance global memory read-iterators and shared memory write-iterators to the stage - template + /// Advance global memory read-iterators and shared memory write-iterators to + /// the stage CUTLASS_DEVICE - void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) - { + void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); - //iterator_B.add_tile_offset({1, 0}); - 
tile_dequanter_B.AddTileOffset({1, 0}); + iterator_B.add_tile_offset({1, 0}); // Advance shared iterators smem_iterator_A_.add_tile_offset({0, 1}); - //smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_B_.add_tile_offset({1, 0}); // Increment shared memory write stage index ++smem_write_stage_idx_; @@ -295,7 +346,7 @@ class Wint2xMmaMultistage : if (smem_write_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - //smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); smem_write_stage_idx_ = 0; } } @@ -338,9 +389,14 @@ class Wint2xMmaMultistage : } } - template CUTLASS_DEVICE void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) { + if constexpr (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + if (threadIdx.x >= IteratorB::ThreadMap::kThreads) { + return; + } + } + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); this->smem_iterator_B_.set_iteration_index(group_start_B); @@ -360,13 +416,16 @@ class Wint2xMmaMultistage : CUTLASS_PRAGMA_UNROLL for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); + bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) + ? 
iterator_B.valid() + : false; if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, is_valid); } else { - cutlass::arch::copy( - dst_ptr + v, gmem_ptr, iterator_B.valid()); + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, is_valid); } ++iterator_B; @@ -375,7 +434,6 @@ class Wint2xMmaMultistage : ++this->smem_iterator_B_; } } - __syncthreads(); } CUTLASS_DEVICE @@ -394,12 +452,9 @@ class Wint2xMmaMultistage : for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_A.get(); - int const kSrcBytes = - sizeof_bits::value * - IteratorA::ThreadMap::kElementsPerAccess / - IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_A.get(), iterator_A.valid()); @@ -411,9 +466,12 @@ class Wint2xMmaMultistage : } } - template CUTLASS_DEVICE void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { + if (threadIdx.x >= IteratorB::ThreadMap::kThreads) { + return; + } + iterator_B.set_iteration_index(0); this->smem_iterator_B_.set_iteration_index(0); @@ -428,46 +486,37 @@ class Wint2xMmaMultistage : for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); - int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - if (InitStage) { - cutlass::arch::copy_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - } else { - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { 
- cutlass::arch::copy( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - } + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); ++iterator_B; } ++this->smem_iterator_B_; } - __syncthreads(); } /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching - /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - template + /// the global fragments needed by the first kStages-1 threadblock mainloop + /// iterations CUTLASS_DEVICE - void prologue( - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, - int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + void prologue(IteratorA &iterator_A, ///< [in|out] iterator over A operand in + ///< global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in + ///< global memory + QuantArguments & + mma_quant_args, ///< iterators for extra quant params for B + int &gemm_k_iterations) ///< [in|out] number of threadblock + ///< mainloop iterations remaining { // Issue several complete stages CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); @@ -476,21 +525,32 @@ class Wint2xMmaMultistage : copy_tiles_and_advance_per_stage_A(iterator_A); // Async copy zipped B to shared memory. 
- tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + copy_tiles_and_advance_per_stage_B(iterator_B); + + // Async copy other quantized params to shared memory, local_scale, + // code_scale, code_zp, super_scale. + if (stage == 0) { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage( + mma_quant_args, stage); + } else { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage( + mma_quant_args, stage); + } // Move to the next write stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); // Defines the boundary of a stage of cp.async. cutlass::arch::cp_async_fence(); } - // Optionally clear the remaining stages of SMEM. This is a functional requirement for - // some kernels so that all accumulator elements outside the GEMM footprint are zero. + // Optionally clear the remaining stages of SMEM. This is a functional + // requirement for some kernels so that all accumulator elements outside the + // GEMM footprint are zero. 
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); typename IteratorA::AccessType zero_A; @@ -500,7 +560,6 @@ class Wint2xMmaMultistage : // Async Copy for operand A CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType *dst_ptr = reinterpret_cast( last_smem_iterator_A.get()); @@ -510,7 +569,12 @@ class Wint2xMmaMultistage : ++last_smem_iterator_A; } - /// Iterator to write threadblock-scoped tile of B operand to shared memory + if (threadIdx.x >= IteratorB::ThreadMap::kThreads) { + return; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); typename IteratorB::AccessType zero_B; @@ -520,7 +584,6 @@ class Wint2xMmaMultistage : // Async Copy for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = reinterpret_cast( last_smem_iterator_B.get()); @@ -534,67 +597,76 @@ class Wint2xMmaMultistage : /// Wait until we have at least one completed global fetch stage CUTLASS_DEVICE - void gmem_wait() - { - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + void gmem_wait() { + // Wait until we have at least one committed global fetch stage. 
+ // (#uncommitted = Base::kStages - 1 - #committed) cutlass::arch::cp_async_wait(); __syncthreads(); } /// Perform a threadblock mainloop iteration of matrix multiply-accumulate - template CUTLASS_DEVICE void mac_loop_iter( - PipeState &pipe_state, ///< [in|out] loop-carried pipeline state - FragmentC &accum, ///< [in|out] destination accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand - int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining - int stage) - { + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA + &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB + &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments + &mma_quant_args, ///< iterators for extra quant params for B + int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop + ///< iterations remaining + int stage) { + const int mma_stage = stage - Base::kStages + 1; + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - // CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k); - - // Load the next warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - // Unpack and dequant the first stage of B. 
- int unpack_stage = stage - Base::kStages + 2; - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, unpack_stage); + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + int warp_k_compute_offset_B = + warp_mma_k % Base::kWarpGemmIterationsPerLoadForB; + + if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index( + ((warp_mma_k + 1) % Base::kWarpGemmIterations) / + Base::kWarpLoadIterationsForB); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + } - // Copy dequatized data to shared memory used by mma core. - copy_tiles_and_advance_per_stage_B(iterator_B); + // load next-tile of group-wise local_scale from shared memory + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); } - // Load the next warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_B_; + // Load the next warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load( + pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; - // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary - if (warp_mma_k > 0) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); - } + // dequantizes 
next warp-tile + warp_dequantizer_.dequantize( + pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2], + ((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) + : mma_stage) * + Shape::kK, + (warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB); // Execute the current warp-tile of MMA operations - if (Detail::kStagedAccumulation) { - warp_mma_( - pipe_state.tmp_accum_, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.tmp_accum_ - ); + if constexpr (Detail::kStagedAccumulation) { + warp_mma_(pipe_state.tmp_accum_, + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_); if (warp_mma_k == 0) { plus plus_accum; @@ -602,35 +674,46 @@ class Wint2xMmaMultistage : pipe_state.tmp_accum_.clear(); } } else { - warp_mma_( - accum, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); + warp_mma_(accum, + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_[warp_mma_k % 2], + accum); } // Except for the last warp-tile, all warp-tiles issue their share of // global->shared fragment copies if (warp_mma_k < Base::kWarpGemmIterations - 1) { int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); if (warp_mma_k == 0) { - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + quant_params_accessor_B_.copy_tiles_and_advance_per_stage( + mma_quant_args, stage); } } // The second-to-last warp-tile also: 
- // - performs the last warp-tile's share of global->shared fragment copies + // - performs the last warp-tile's share of global->shared fragment + // copies // - moves to the next global fetch stage if (warp_mma_k + 2 == Base::kWarpGemmIterations) { // Performs the last warp-tile's share of global->shared fragment copies - int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + if constexpr (Detail::AsyncCopyIterationsPerStageA >= + Base::kWarpGemmIterations) { + int group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + } - copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + if constexpr (Detail::AsyncCopyIterationsPerStageB >= + Base::kWarpGemmIterations) { + int group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); + } // Inserts a memory fence between stages of cp.async instructions. 
cutlass::arch::cp_async_fence(); @@ -639,69 +722,68 @@ class Wint2xMmaMultistage : gmem_wait(); // Move to the next global fetch stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); + advance_smem_read_stage(); + int byte_offset = quant_params_accessor_B_.advance_smem_read_stage(); + warp_dequantizer_.add_pointer_offset(byte_offset); // Disable global fetching when done with global fetch iterations --gemm_k_iterations; iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); - } - - // The last warp-tile also converts the shared memory fragments used by - // the first warp-tile of the next iteration, if necessary (so we can - // immediately start issuing MMA instructions at the top of the loop ) - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + iterator_B.clear_mask(gemm_k_iterations == 0); + quant_params_accessor_B_.clear_mask(mma_quant_args, + gemm_k_iterations == 0); } } } /// Perform the specified number of threadblock mainloop iterations of matrix /// multiply-accumulate. Assumes prologue has been initiated. 
- template CUTLASS_DEVICE void gemm_iters( - int gemm_k_iterations, ///< number of threadblock mainloop iterations - FragmentC &accum, ///< [in|out] accumulator tile - IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory - { + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA + &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB + &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments &mma_quant_args) { PipeState pipe_state; - // Unpack and dequant the first stage of B. - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); - // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); - - // Load first warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); - ++this->warp_tile_iterator_A_; - - // Copy dequatized data to shared memory used by mma core. 
- copy_tiles_and_advance_per_stage_B(iterator_B); + iterator_B.clear_mask(gemm_k_iterations == 0); + quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0); // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); ++this->warp_tile_iterator_B_; - // Transform, if necessary, the first warp-tile's shared memory fragments - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[0], - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_A_[0], - pipe_state.warp_loaded_frag_B_[0]); + warp_dequantizer_.load(pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_); - if (Detail::kStagedAccumulation) { + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + // Dequantize B to in register + warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_frag_B_[0], + 0, + 0); + + if constexpr (Detail::kStagedAccumulation) { pipe_state.tmp_accum_.clear(); } @@ -710,23 +792,23 @@ class Wint2xMmaMultistage : // Mainloop CUTLASS_GEMM_LOOP for (; gemm_k_iterations > (-Base::kStages + 1);) { - mac_loop_iter( - pipe_state, - accum, - iterator_A, - iterator_B, - tile_dequanter_B, - gemm_k_iterations, - stage); + mac_loop_iter(pipe_state, + accum, + iterator_A, + iterator_B, + mma_quant_args, + gemm_k_iterations, + stage); stage += 1; } - if (Detail::kStagedAccumulation) { + if constexpr (Detail::kStagedAccumulation) { plus plus_accum; accum = 
plus_accum(accum, pipe_state.tmp_accum_); } - // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + // Commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); @@ -734,15 +816,16 @@ class Wint2xMmaMultistage : /// Prepares the class for another prologue. CUTLASS_DEVICE - void wind_down() - { - // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) - - // First, increment remaining warp tiles to get to the next full stage. (Ideally we would - // just decrement one tile, but not all iterators implement --() decrement.) - #pragma unroll - for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { + void wind_down() { +// Catch-up the smem-read iterator to the smem-write iterator (so this class can +// be reused for another tile's prologue) + +// First, increment remaining warp tiles to get to the next full stage. (Ideally +// we would just decrement one tile, but not all iterators implement --() +// decrement.) 
+#pragma unroll + for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); @@ -751,24 +834,24 @@ class Wint2xMmaMultistage : } smem_read_stage_idx_++; - // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) - static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; - if (smem_read_stage_idx_ > 1) - { + // Then wrap back two full stages (one for the tile advancing we just did, + // and one to catch the write iterators) + static const int kStageIters = + Policy::kPartitionsK * Base::kWarpGemmIterations; + if (smem_read_stage_idx_ > 1) { this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - //this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); - this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, ((Base::kStages - 2) * kStageIters)}); + this->warp_tile_iterator_B_.add_tile_offset( + {((Base::kStages - 2) * kStageIters), 0}); } smem_read_stage_idx_ = smem_write_stage_idx_; } - /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. - template + /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to + /// shared memory. 
CUTLASS_DEVICE void operator()( ///< problem size of GEMM @@ -779,13 +862,13 @@ class Wint2xMmaMultistage : IteratorA iterator_A, ///< iterator over B operand in global memory IteratorB iterator_B, - ///< pre-load and dequantize B to shared memory - TileDequanterB tile_dequanter_B, + ///< iterators for extra quant params for B + QuantArguments mma_quant_args, ///< initial value of accumulator FragmentC const &src_accum) { - - // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); + // Prologue (start fetching iterations of global fragments into shared + // memory) + prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations); // Wait until we have at least one completed global fetch stage gmem_wait(); @@ -794,7 +877,8 @@ class Wint2xMmaMultistage : accum = src_accum; // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); + gemm_iters( + gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h new file mode 100644 index 00000000000..2409087cbce --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h @@ -0,0 +1,358 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/trace.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Original data type + typename T, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterators over super scales in global memory + typename IteratorSuperScale_, + /// Iterators over super scales in shared memory + typename SmemIteratorSuperScale_, + /// Iterators over local scales in global memory + typename IteratorLocalScale_, + /// Iterators over local scales in shared memory + typename SmemIteratorLocalScale_, + /// Iterators over code scales and zps in global memory + typename IteratorCodeScaleZp_, + /// Iterators over code scales and zps in shared memory + typename SmemIteratorCodeScaleZp_, + /// Number of stages, + int Stages_, + /// Group size for quantization + int GroupSize_> +class Wint2ParamsAccessor { + public: + static_assert(platform::is_same::value || + platform::is_same::value, + "T must be fp16 or bf16"); + + using ElementType = T; + using Shape = Shape_; + + using IteratorSuperScale = IteratorSuperScale_; + using SmemIteratorSuperScale = SmemIteratorSuperScale_; + + using IteratorLocalScale = IteratorLocalScale_; + using SmemIteratorLocalScale = SmemIteratorLocalScale_; + + using IteratorCodeScaleZp = IteratorCodeScaleZp_; + using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_; + + constexpr static int kStages = Stages_; + constexpr static int kGroupSize = GroupSize_; + + using ElementSuperScale = typename IteratorSuperScale::Element; + using LayoutSuperScale = typename IteratorSuperScale::Layout; + + /// local_scale uint4 and group-wise + using ElementLocalScale = typename IteratorLocalScale::Element; + using LayoutLocalScale = typename 
IteratorLocalScale::Layout; + static_assert(platform::is_same::value, + "local_scale's type must be uint4b_t."); + + using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element; + using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout; + + /// 2 uint4b_t values are stored in a single uint8_t + constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK; + constexpr static int kLocalScaleRows = + IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * + sizeof_bits::value / 8 / Shape::kN; + + using SmemElement = uint8_t; + constexpr static int kSmemRows = kLocalScaleRows * kStages + + sizeof(ElementSuperScale) + + sizeof(ElementCodeScaleZp) * 2; + constexpr static int kSmemColumns = Shape::kN; + + using QuantParamsShape = MatrixShape; + + constexpr static int kSuperScaleSmemOffset = 0; + constexpr static int kCodeScaleSmemOffset = + kSmemColumns * sizeof(ElementSuperScale); + constexpr static int kCodeZpSmemOffset = + kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + constexpr static int kLocalScaleSmemOffset = + kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + + /// TensorRef type for loading element from a tensor + using SuperTensorRef = + cutlass::TensorRef; + using LocalTensorRef = + cutlass::TensorRef; + using CodeTensorRef = + cutlass::TensorRef; + + struct Arguments { + IteratorSuperScale iterator_super_scale; + IteratorLocalScale iterator_local_scale; + IteratorCodeScaleZp iterator_code_scale; + IteratorCodeScaleZp iterator_code_zp; + + int local_scale_pointer_offset; + + CUTLASS_DEVICE + Arguments(IteratorSuperScale iterator_super_scale, + IteratorLocalScale iterator_local_scale, + IteratorCodeScaleZp iterator_code_scale, + IteratorCodeScaleZp iterator_code_zp, + int local_scale_pointer_offset) + : iterator_super_scale(iterator_super_scale), + iterator_local_scale(iterator_local_scale), + iterator_code_scale(iterator_code_scale), + iterator_code_zp(iterator_code_zp), + 
local_scale_pointer_offset(local_scale_pointer_offset) {} + }; + + private: + // + // Data members + // + + /// Begin address of shared memory + uint8_t* smem_pointer_; + + /// Iterator to write threadblock-scoped tile of super scale operand to shared + /// memory + SmemIteratorSuperScale smem_iterator_super_scale_; + /// Iterator to write threadblock-scoped tile of local scale operand to shared + /// memory + SmemIteratorLocalScale smem_iterator_local_scale_; + /// Iterator to write threadblock-scoped tile of code scale operand to shared + /// memory + SmemIteratorCodeScaleZp smem_iterator_code_scale_; + /// Iterator to write threadblock-scoped tile of code zp operand to shared + /// memory + SmemIteratorCodeScaleZp smem_iterator_code_zp_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + CUTLASS_DEVICE + ElementSuperScale* get_super_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + + kSuperScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementLocalScale* get_local_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + + kLocalScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + + kCodeScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_zp_smem_ptr() { + return reinterpret_cast(smem_pointer_ + + kCodeZpSmemOffset); + } + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Wint2ParamsAccessor( + ///< prointer of shared memory + uint8_t* smem_pointer, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : smem_pointer_(smem_pointer), + smem_iterator_super_scale_( + LayoutSuperScale(IteratorSuperScale::Shape::kColumn), + get_super_scale_smem_ptr(), + {1, IteratorSuperScale::Shape::kColumn}, + thread_idx), + smem_iterator_local_scale_( + 
LayoutLocalScale(IteratorLocalScale::Shape::kColumn), + get_local_scale_smem_ptr(), + {1, IteratorLocalScale::Shape::kColumn}, + thread_idx), + smem_iterator_code_scale_( + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_scale_smem_ptr(), + {1, IteratorCodeScaleZp::Shape::kColumn}, + thread_idx), + smem_iterator_code_zp_( + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_zp_smem_ptr(), + {1, IteratorCodeScaleZp::Shape::kColumn}, + thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) {} + + CUTLASS_DEVICE + SuperTensorRef super_scale_ref() { + return {get_super_scale_smem_ptr(), + LayoutSuperScale(IteratorSuperScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + LocalTensorRef local_scale_ref() { + return {get_local_scale_smem_ptr(), + LayoutLocalScale(IteratorLocalScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_scale_ref() { + return {get_code_scale_smem_ptr(), + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_zp_ref() { + return {get_code_zp_smem_ptr(), + LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + template + CUTLASS_DEVICE void copy_tiles_and_advance_per_stage(Arguments& quant_args, + int stage) { + if constexpr (IsFirstStage) { + // Load channel-wise super_scale to shared memory, which only needs to be + // done once. + typename IteratorSuperScale::Fragment tb_frag_super_scale; + tb_frag_super_scale.clear(); + quant_args.iterator_super_scale.load(tb_frag_super_scale); + this->smem_iterator_super_scale_.store(tb_frag_super_scale); + + // Load channel-wise code_scale to shared memory, which only needs to be + // done once. + typename IteratorCodeScaleZp::Fragment tb_frag_code_scale; + tb_frag_code_scale.clear(); + quant_args.iterator_code_scale.load(tb_frag_code_scale); + this->smem_iterator_code_scale_.store(tb_frag_code_scale); + + // Load channel-wise code_zp to shared memory, which only needs to be done + // once. 
+ typename IteratorCodeScaleZp::Fragment tb_frag_code_zp; + tb_frag_code_zp.clear(); + quant_args.iterator_code_zp.load(tb_frag_code_zp); + this->smem_iterator_code_zp_.store(tb_frag_code_zp); + } + + if ((stage % kStagesPerLocalScaleLoad) == 0) { + // Load group-wise local_scale to shared memory, which only needs to be + // done at each stage. Since 2 uint4b_t values of local_scale are saved in + // a single uint8_t, local_scale needs to be loaded once every two stages. + using AccessType = typename IteratorLocalScale::AccessType; + cutlass::arch::CacheOperation::Kind const kCacheOp = + (sizeof_bits::value == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + quant_args.iterator_local_scale.set_iteration_index(0); + this->smem_iterator_local_scale_.set_iteration_index(0); + + // Async Copy for local_scale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; + ++j) { + AccessType* dst_ptr = reinterpret_cast( + this->smem_iterator_local_scale_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) { + auto gmem_ptr = quant_args.iterator_local_scale.get(); + + int const kSrcBytes = + sizeof_bits::value * + IteratorLocalScale::ThreadMap::kElementsPerAccess / + IteratorLocalScale::kAccessesPerVector / 8; + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid()); + } + ++quant_args.iterator_local_scale; + } + ++this->smem_iterator_local_scale_; + } + } + + CUTLASS_DEVICE + void advance_smem_write_stage(Arguments& quant_args) { + if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + // Advance global iterators + quant_args.iterator_local_scale.add_pointer_offset( + quant_args.local_scale_pointer_offset); + + // Advance shared iterators + int smem_pointer_offset = + IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + 
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + int pointer_offset = -kStages * IteratorLocalScale::Shape::kRow * + IteratorLocalScale::Shape::kColumn; + smem_iterator_local_scale_.add_pointer_offset(pointer_offset); + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + int advance_smem_read_stage() { + int byte_offset = 0; + + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + byte_offset = kLocalScaleRows * kSmemColumns; + } + + if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + smem_read_stage_idx_ = 0; + byte_offset = -(kStages - 1) * kLocalScaleRows * kSmemColumns; + } + + return byte_offset; + } + + CUTLASS_DEVICE + int clear_mask(Arguments& quant_args, bool cond) { + quant_args.iterator_local_scale.clear_mask(cond); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h deleted file mode 100644 index cec6bcea034..00000000000 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "cutlass/gemm_coord.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -template -struct TileDequanter { - using WeightQuantTraits = WintQuantTraits; - using MmaElementT = typename WeightQuantTraits::MmaWeightType; - using QuantArguments = typename WeightQuantTraits::Arguments; - - using UnzipAndDequantFunctor = - UnzipAndDequantFunctor; - - static constexpr bool kUseSharedMemory = true; - - static constexpr int kRows = Rows; - static constexpr int kColumns = Columns; - static constexpr int kStages = Stages; - - MmaElementT *out_smem_ptr{nullptr}; - - char *pointer{nullptr}; - int64_t ldm{0}; - cutlass::MatrixCoord tb_offset; - cutlass::MatrixCoord extent; - - ScaleElementT *super_scale_ptr{nullptr}; - cutlass::MatrixCoord tb_offset_scale; - - QuantArguments quant_args; - - int64_t block_start_rows[kStages]; - bool need_preload{true}; - UnzipAndDequantFunctor unzip_functor; - - CUTLASS_DEVICE - TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm, - const cutlass::MatrixCoord &extent, - const cutlass::MatrixCoord &tb_offset, - ScaleElementT *super_scale_ptr, - const cutlass::MatrixCoord &tb_offset_scale, - const QuantArguments &quant_args) - : out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent), - tb_offset(tb_offset), super_scale_ptr(super_scale_ptr), - tb_offset_scale(tb_offset_scale), quant_args(quant_args) {} - - CUTLASS_DEVICE - MmaElementT *GetOutPtr() { return out_smem_ptr; } - - CUTLASS_DEVICE - void AddTileOffset(const cutlass::MatrixCoord &tile_offset) { - tb_offset.row() += tile_offset.row() * kRows; - tb_offset.column() += tile_offset.column() * kColumns; - tb_offset_scale.column() += tile_offset.column() * kColumns; - } - - CUTLASS_DEVICE - void Load(uint8_t *zipped_smem_ptr, 
uint8_t *column_wise_smem_ptr, int stage) { - int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row()); - if (tb_offset.row() >= extent.row() || - tb_offset.column() >= extent.column()) { - return; - } - - block_start_rows[stage % kStages] = tb_offset.row(); - - using ZippedT = typename WeightQuantTraits::WeightType; - ZippedT *in_ptr = reinterpret_cast(pointer) + zipped_row * ldm + - tb_offset.column(); - ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column(); - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - const uint8_t *local_scale_ptr = quant_args.local_scale_ptr + - (tb_offset.row() / 128) * ldm + - tb_offset_scale.column(); - const float *code_scale_ptr = - quant_args.code_scale_ptr + tb_offset_scale.column(); - const float *code_zp_ptr = - quant_args.code_zp_ptr + tb_offset_scale.column(); - - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr, - scale_ptr, &args, ldm, need_preload); - need_preload = false; - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } - - CUTLASS_DEVICE - void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int64_t block_start_row = block_start_rows[stage % kStages]; - if (block_start_row >= extent.row()) { - return; - } - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row); - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h index 9d49d5eb53c..a1bc5a0ecf8 100644 --- 
a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_unzip.h @@ -29,18 +29,27 @@ namespace gemm { namespace threadblock { template -using UnzipArray = cutlass::AlignedArray::value / 8)>; - -template +using UnzipArray = + cutlass::AlignedArray::value / 8)>; + +template struct UnzipAndDequantFunctor { - __device__ void operator()(const T *in_ptr, const T *supper_scale_ptr, - T *out_ptr, const int64_t in_stride) {} + __device__ void operator()(const T *in_ptr, + const T *supper_scale_ptr, + T *out_ptr, + const int64_t in_stride) {} }; template -struct UnzipAndDequantFunctor { +struct UnzipAndDequantFunctor { using ZippedT = uint16_t; using ScaleComputeT = float; @@ -52,7 +61,8 @@ struct UnzipAndDequantFunctor> shift_bit) & kWeightMask; int32_t value = shifted_value - kBZP; @@ -61,8 +71,10 @@ struct UnzipAndDequantFunctor(scaled_value); } - __device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr, - T *out_ptr, const int64_t in_stride) { + __device__ void operator()(const uint16_t *in_ptr, + const T *super_scale_ptr, + T *out_ptr, + const int64_t in_stride) { int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0}; int tid = threadIdx.x; @@ -111,8 +123,11 @@ struct UnzipAndDequantFunctor -struct UnzipAndDequantFunctor { +struct UnzipAndDequantFunctor { using ZippedT = uint8_t; using ScaleComputeT = float; @@ -129,9 +144,11 @@ struct UnzipAndDequantFunctor(column_wise_smem_ptr); - code_zp_ptr = reinterpret_cast(column_wise_smem_ptr + sizeof(float) * TileColumns); - super_scale_ptr = reinterpret_cast(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns); + code_zp_ptr = reinterpret_cast(column_wise_smem_ptr + + sizeof(float) * TileColumns); + super_scale_ptr = reinterpret_cast(column_wise_smem_ptr + + 2 * sizeof(float) * TileColumns); } }; - __device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr, - const float *g_code_scale_ptr, const float 
*g_code_zp_ptr, + __device__ void Load(const uint8_t *g_weight_ptr, + const uint8_t *g_local_scale_ptr, + const float *g_code_scale_ptr, + const float *g_code_zp_ptr, const T *g_super_scale_ptr, - Arguments *args, const int64_t in_stride, bool need_preload) { + Arguments *args, + const int64_t in_stride, + bool need_preload) { int tid = threadIdx.x; #pragma unroll @@ -186,7 +215,8 @@ struct UnzipAndDequantFunctorlocal_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset]; + args->local_scale_ptr[ls_row_id * TileColumns + col] = + g_local_scale_ptr[local_scale_offset]; } #pragma unroll @@ -205,10 +235,12 @@ struct UnzipAndDequantFunctor( - args->weight_ptr + z_offset, g_weight_ptr + g_offset, true); + int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread); + int g_offset = + z_offset / TileColumns * in_stride + z_offset % TileColumns; + cutlass::arch::cp_async( + args->weight_ptr + z_offset, g_weight_ptr + g_offset, true); } } else if (tid < weight_threads + local_scale_threads) { constexpr int start_thread_id = weight_threads; - constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads; - constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int local_scale_per_thread_size = + local_scale_size / local_scale_threads; + constexpr int kIterations = + (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread; - int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns; - cutlass::arch::cp_async( - args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true); + int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + + i * kBytesPerThread; + int g_offset = + z_offset / TileColumns * in_stride + z_offset % TileColumns; + cutlass::arch::cp_async( + 
args->local_scale_ptr + z_offset, + g_local_scale_ptr + g_offset, + true); } } else if (need_preload) { if (tid < weight_threads + local_scale_threads + code_scale_threads) { constexpr int start_thread_id = weight_threads + local_scale_threads; - constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads; - constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int code_scale_per_thread_size = + code_scale_size / code_scale_threads; + constexpr int kIterations = + (code_scale_per_thread_size + kBytesPerThread - 1) / + kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float); - cutlass::arch::cp_async( + int offset = ((tid - start_thread_id) * code_scale_per_thread_size + + i * kBytesPerThread) / + sizeof(float); + cutlass::arch::cp_async( args->code_scale_ptr + offset, g_code_scale_ptr + offset, true); } - } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) { - constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads; + } else if (tid < weight_threads + local_scale_threads + + code_scale_threads + code_zp_threads) { + constexpr int start_thread_id = + weight_threads + local_scale_threads + code_scale_threads; constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads; - constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int kIterations = + (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float); - cutlass::arch::cp_async( + int offset = ((tid - start_thread_id) * code_zp_per_thread_size + + i * kBytesPerThread) / + sizeof(float); + 
cutlass::arch::cp_async( args->code_zp_ptr + offset, g_code_zp_ptr + offset, true); } - } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) { + } else if (tid < weight_threads + local_scale_threads + + code_scale_threads + code_zp_threads + + super_scale_threads) { if (g_super_scale_ptr) { - constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads; - constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads; - constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread; + constexpr int start_thread_id = weight_threads + local_scale_threads + + code_scale_threads + code_zp_threads; + constexpr int super_scale_per_thread_size = + super_scale_size / super_scale_threads; + constexpr int kIterations = + (super_scale_per_thread_size + kBytesPerThread - 1) / + kBytesPerThread; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kIterations; ++i) { - int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T); - cutlass::arch::cp_async( - args->super_scale_ptr + offset, g_super_scale_ptr + offset, true); + int offset = + ((tid - start_thread_id) * super_scale_per_thread_size + + i * kBytesPerThread) / + sizeof(T); + cutlass::arch::cp_async( + args->super_scale_ptr + offset, + g_super_scale_ptr + offset, + true); } } } } } - __device__ void Compute(const Arguments &args, T *out_ptr, + __device__ void Compute(const Arguments &args, + T *out_ptr, const int64_t block_start_row) { int32_t shift_bits[4] = {9, 6, 3, 0}; @@ -333,9 +408,9 @@ struct UnzipAndDequantFunctor(floor(zipped_value[zipped_row] * code_scale + code_zp + - static_cast(0.5))); + int32_t decode_value = static_cast( + floor(zipped_value[zipped_row] * code_scale + code_zp + + static_cast(0.5))); int row = group_id * 64 + zipped_row * 4; @@ -355,14 +430,17 @@ struct UnzipAndDequantFunctor= 32) ? 
4 : 2; constexpr int RowStride = NumThreads * N / TileColumns; constexpr int kNumIters = kNumWeightsPerThread / N; - static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns."); + static_assert(N * NumThreads >= TileColumns, + "N * NumThreads should be no less than TileColumns."); constexpr ScaleComputeT decode_value_zp = static_cast(0.5); @@ -373,19 +451,22 @@ struct UnzipAndDequantFunctor local_scales = - *reinterpret_cast *>(args.local_scale_ptr + begin_col_id); + *reinterpret_cast *>(args.local_scale_ptr + + begin_col_id); UnzipArray zipped_values[2]; int zipped_offset = begin_row_id * TileColumns + begin_col_id; - zipped_values[0] = - *reinterpret_cast *>(args.weight_ptr + zipped_offset); + zipped_values[0] = *reinterpret_cast *>( + args.weight_ptr + zipped_offset); - UnzipArray super_scales = - *reinterpret_cast *>(args.super_scale_ptr + begin_col_id); + UnzipArray super_scales = *reinterpret_cast *>( + args.super_scale_ptr + begin_col_id); UnzipArray code_scales = - *reinterpret_cast *>(args.code_scale_ptr + begin_col_id); + *reinterpret_cast *>(args.code_scale_ptr + + begin_col_id); UnzipArray code_zps = - *reinterpret_cast *>(args.code_zp_ptr + begin_col_id); + *reinterpret_cast *>(args.code_zp_ptr + + begin_col_id); // special for TileRows = 64 int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4; @@ -394,9 +475,10 @@ struct UnzipAndDequantFunctor(local_scales[i]) >> local_scale_shift) & kLocalScaleMask; - scales[i] = - static_cast(shifted_local_scale) * static_cast(super_scales[i]); + (static_cast(local_scales[i]) >> local_scale_shift) & + kLocalScaleMask; + scales[i] = static_cast(shifted_local_scale) * + static_cast(super_scales[i]); } #pragma unroll @@ -405,26 +487,33 @@ struct UnzipAndDequantFunctor *>(args.weight_ptr + zipped_offset); + *reinterpret_cast *>(args.weight_ptr + + zipped_offset); } UnzipArray outs[4]; #pragma unroll for (int i = 0; i < N; ++i) { - int32_t decode_value = - 
static_cast(floor(static_cast(zipped_values[iter_id & 1][i]) * code_scales[i] - + code_zps[i] + decode_value_zp)); + int32_t decode_value = static_cast( + floor(static_cast(zipped_values[iter_id & 1][i]) * + code_scales[i] + + code_zps[i] + decode_value_zp)); - ScaleComputeT value_3 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_3 = + static_cast((decode_value & kWeightMask) - kBZP); decode_value >>= 3; - ScaleComputeT value_2 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_2 = + static_cast((decode_value & kWeightMask) - kBZP); decode_value >>= 3; - ScaleComputeT value_1 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_1 = + static_cast((decode_value & kWeightMask) - kBZP); decode_value >>= 3; - ScaleComputeT value_0 = static_cast((decode_value & kWeightMask) - kBZP); + ScaleComputeT value_0 = + static_cast((decode_value & kWeightMask) - kBZP); outs[0][i] = static_cast(scales[i] * value_0); outs[1][i] = static_cast(scales[i] * value_1); outs[2][i] = static_cast(scales[i] * value_2); diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index 350b247de2e..8245aff71c3 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. + \brief Default warp-level GEMM operators selected by data type, size, and + layouts of operands. */ #pragma once @@ -41,12 +43,9 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -73,35 +72,60 @@ template < /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp -{ +struct DefaultMmaTensorOp { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; + // Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; - // Chosen so we get K=16 for int8 and K=32 for int4. 
- static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = + GemmShape; - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma, - cutlass::MatrixShape<1, 1>>; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; + // Define the warp-level tensor op + using Type = + cutlass::gemm::warp::MmaTensorOpComputeBWithF16; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 7c5088894b4..edc37d72a11 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. + \brief Templates implementing warp-level matrix multiply-accumulate + operations targeting Tensor Cores. */ #pragma once @@ -58,15 +59,13 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting Tensor Cores, for the case +/// when A is floating point and B is quantized integer. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, @@ -93,214 +92,489 @@ template < bool AccumulatorsInRowMajor = false, /// Used for partial specialization typename Enable = bool> -class MmaTensorOpComputeBWithF16 -{ -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - 
static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); - - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + 
platform::is_same::value && + ArchTag::kMinComputeCapability >= 80) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 80) || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, " + "or FP8 on Ada"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename 
IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; static_assert( - SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); - static_assert( - SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); - - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + 
TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of " + "B"); - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + D = C; - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting Tensor Cores, for the case +/// when A is floating point and B is quantized integer. Specialization for B of +/// uint2b_t. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - using IteratorA - = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, - MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator, 
Operand::kB, ElementB, - LayoutB, MatrixShape, Policy::OpDelta::kRow, - kThreadCount, kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, - int const warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " - "B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + int PartitionsK_, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor> +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = uint2b_t; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static 
constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + 
// + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? 
(MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } + } + } #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } + } + } #else - assert(0); + assert(0); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index 24e844abca3..5bb06a54146 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. 
*/ #pragma once @@ -57,12 +59,9 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { //////////////////////////////////////////////////////////////////////////////// @@ -94,193 +93,216 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 80 - && platform::is_same::value>::type> -{ - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = bfloat16_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = smem_zeros.data() + 
thread_offset; - } +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + bfloat16_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 80 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + using FragmentZero = + Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + TensorRef smem_zeros, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; } - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) - 
{ + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } + using _MmaOperandB = 
typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = + reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = + reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } #else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. - arch::device_breakpoint(); + // Slow path not implemented here on purpose. If we need to do HMMA on older + // arch, scale conversion should happen before scales are stored to shared + // memory and we should use the fp16 dequantizer. This will avoid numerous + // conversion instructions in GEMM main loop. 
+ arch::device_breakpoint(); #endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = + pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + } } + } - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - } - - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - 
__nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = + reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = + reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = + reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) { CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } + for (int ii = 0; ii < 
ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = + __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); } + } else { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } #else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. - arch::device_breakpoint(); + // Slow path not implemented here on purpose. If we need to do HMMA on older + // arch, scale conversion should happen before scales are stored to shared + // memory and we should use the fp16 dequantizer. This will avoid numerous + // conversion instructions in GEMM main loop. + arch::device_breakpoint(); #endif - } - - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } - -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; + } + + // Adds a pointer offset in units of elements. 
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; }; //////////////////////////////////////////////////////////////////////////////// @@ -293,170 +315,190 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 75 - && platform::is_same::value>::type> -{ - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = 
smem_zeros.data() + thread_offset; - } +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + using FragmentZero = + Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + TensorRef smem_zeros, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; } - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, 
lane_idx) - { + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - 
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = + pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = + pointer_scale_[mma_n_iter * InstructionShape::kN]; + } } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + 
++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = plus_op( + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), + zero_frag[mma_n_iter]); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } } - - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - - if constexpr (hasZero(QuantOp)) - { - plus plus_op; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] - = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - } - - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } - -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; + } + + // Adds a pointer offset in units of elements. 
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; }; //////////////////////////////////////////////////////////////////////////////// -// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm +// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved +// gemm template < /// Underlying matrix multiply operator (concept: MmaTensorOp) typename MmaOperator_, @@ -464,86 +506,98 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer::value - && platform::is_same::value>::type> -{ - -public: - static_assert(platform::is_same>::value, ""); - - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - /// Warp mma shape - using Shape = Shape_; - - // Fragment to hold scale data to apply to B before mma - // Each 32x32x4 matmul uses 8 elements from B. 
- static constexpr int ColsPerMmaTile = 32; - static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; - using FragmentScale = Array; - using AccessType = Array; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const base_col = lane_idx & 0xF8; - int const thread_offset = warp_offset + base_col; - pointer_ = smem_scales.data() + thread_offset; - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); - - CUTLASS_PRAGMA_UNROLL - for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) - { - // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. 
- scale_frag_ptr[tile_iter] = *reinterpret_cast(pointer_ + ColsPerMmaTile * tile_iter); - } +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, + ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. + static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + using AccessType = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = lane_idx & 0xF8; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> 
super mmas inside a warp. + scale_frag_ptr[tile_iter] = *reinterpret_cast( + pointer_ + ColsPerMmaTile * tile_iter); } + } - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { + static_assert( + FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); - multiplies mul_op; - operand_frag = mul_op(operand_frag, scale_frag); - } + multiplies mul_op; + operand_frag = mul_op(operand_frag, scale_frag); + } -private: - ElementScale const* pointer_; + private: + ElementScale const* pointer_; }; //////////////////////////////////////////////////////////////////////////////// -// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm +// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved +// gemm template < /// Underlying matrix multiply operator (concept: MmaTensorOp) typename MmaOperator_, @@ -551,98 +605,110 @@ template < typename Shape_, /// WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer::value - && platform::is_same::value>::type> -{ - -public: - static_assert(platform::is_same>::value, ""); - - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - /// Warp mma shape - using Shape = Shape_; - - // Fragment to hold scale data to apply to B before mma - // Each 32x32x4 matmul uses 8 elements from B. 
- static constexpr int ColsPerMmaTile = 32; - static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; - using FragmentScale = Array; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const base_col = lane_idx & 0xF8 + lane_idx % 4; - int const thread_offset = warp_offset + base_col; - pointer_ = smem_scales.data() + thread_offset; +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + QuantOp_, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, + ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. 
+ static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + int const warp_idx_n, + int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = lane_idx & 0xF8 + lane_idx % 4; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + // For col major B, each thread will jump 4 cols to get its next value + // inside of the super mma. + CUTLASS_PRAGMA_UNROLL + for (int mma_iter = 0; mma_iter < 2; ++mma_iter) { + scale_frag[tile_iter * 2 + mma_iter] = + pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; + } } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) - { - // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. - // For col major B, each thread will jump 4 cols to get its next value inside - // of the super mma. 
- CUTLASS_PRAGMA_UNROLL - for (int mma_iter = 0; mma_iter < 2; ++mma_iter) - { - scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; - } - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - using MmaOperandB = typename ArchMmaOperator::FragmentB; - static constexpr int total_n_mmas = 2 * TileNIterations; - static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, ""); - - multiplies mul_op; - - MmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + FragmentScale const& scale_frag) { + using MmaOperandB = typename ArchMmaOperator::FragmentB; + static constexpr int total_n_mmas = 2 * TileNIterations; + static_assert(MmaOperandB::kElements * total_n_mmas == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + MmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); } + } -private: - ElementScale const* pointer_; + private: + ElementScale const* pointer_; }; //////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h 
b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h new file mode 100644 index 00000000000..ac173150081 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h @@ -0,0 +1,485 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +namespace detail { + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits { + using Type = __nv_bfloat16; + using DualType = __nv_bfloat162; +}; + +template <> +struct DataTypeTraits { + using Type = __half; + using DualType = __half2; +}; + +template +struct LocalScaleConverter { + using FragmentSource = Array; + using FragmentResult = Array; + + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { + constexpr uint32_t kLocalScaleMask = 0xf; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + 
int32_t shifted_value = + (static_cast(local_scale_frag[i]) >> shift_bit) & + kLocalScaleMask; + scale_frag[i] = static_cast(shifted_value) * super_scale_frag[i]; + } + } +}; + +template +struct LocalScaleConverter::type> { + using FragmentSource = Array; + using FragmentResult = Array; + + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { + constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + constexpr uint32_t MASK = 0x000f000f; + // 2^10 = 1024 + constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400; + + // -2^10 = -1024 + constexpr uint32_t FP16_BIAS = 0xE400E400; + // 1.0 + constexpr uint32_t FP16_ONE = 0x3C003C00; + + __half2* scale_ptr = reinterpret_cast<__half2*>(&scale_frag); + __half2 const* super_scale_ptr = + reinterpret_cast<__half2 const*>(&super_scale_frag); + + uint32_t const* local_scale_ptr = + reinterpret_cast(&local_scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + int i4s = local_scale_ptr[i] >> shift_bit; + + // unpack: 0, 1 + int32_t low = __byte_perm(i4s, i4s, 0xF1F0); + int32_t unpack0 = lop3(low, MASK, I4s_TO_FP16s_MAGIC_NUM); + // unpack: 2, 3 + int32_t high = __byte_perm(i4s, i4s, 0xF3F2); + int32_t unpack1 = lop3(high, MASK, I4s_TO_FP16s_MAGIC_NUM); + + __half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0), + *reinterpret_cast(&FP16_ONE), + *reinterpret_cast(&FP16_BIAS)); + __half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1), + *reinterpret_cast(&FP16_ONE), + *reinterpret_cast(&FP16_BIAS)); + + scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); + scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); + } + } +}; + +template +struct LocalScaleConverter::type> { + using FragmentSource = Array; + using FragmentResult = Array; + + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + 
FragmentResult& scale_frag, + int shift_bit) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + constexpr uint32_t MASK = 0x000F000F; + constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + constexpr uint32_t BF16_BIAS = 0xC300C300; + constexpr uint32_t BF16_ONE = 0x3F803F80; + + __nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162*>(&scale_frag); + __nv_bfloat162 const* super_scale_ptr = + reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag); + + uint32_t const* local_scale_ptr = + reinterpret_cast(&local_scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + int i4s = local_scale_ptr[i] >> shift_bit; + + // unpack: 0, 1 + int32_t low = __byte_perm(i4s, i4s, 0xF1F0); + int32_t unpack0 = lop3(low, MASK, I4s_TO_BF16s_MAGIC_NUM); + // unpack: 2, 3 + int32_t high = __byte_perm(i4s, i4s, 0xF3F2); + int32_t unpack1 = lop3(high, MASK, I4s_TO_BF16s_MAGIC_NUM); + + nv_bfloat162 scale0 = + __hfma2(*reinterpret_cast(&unpack0), + *reinterpret_cast(&BF16_ONE), + *reinterpret_cast(&BF16_BIAS)); + nv_bfloat162 scale1 = + __hfma2(*reinterpret_cast(&unpack1), + *reinterpret_cast(&BF16_ONE), + *reinterpret_cast(&BF16_BIAS)); + + scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); + scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older + // arch, scale conversion should happen before scales are stored to shared + // memory and we should use the fp16 dequantizer. This will avoid numerous + // conversion instructions in GEMM main loop. 
+ arch::device_breakpoint(); +#endif + } +}; + +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename ElementOperand_, + /// Layout of operand + typename Layout_, + /// Group size for quantization + int GroupSize_, + /// + typename Enable = void> +class MmaTensorOpWin2xDequantizer { + // static_assert(false, "Not Supported!"); +}; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// Data type of Scale elements + typename ElementOperand_, + /// Group size for quantization + int GroupSize_> +class MmaTensorOpWin2xDequantizer +// typename platform::enable_if= +// 80 +// && platform::is_same::value>::type> +{ + public: + static_assert(platform::is_same::value || + platform::is_same::value, + "T must be fp16 or bf16"); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Warp mma shape + using Shape = Shape_; + + /// Type of mma operand + using ElementOperand = ElementOperand_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// Group size for quantization + static constexpr int kGroupSize = GroupSize_; + + /// Type of input + using ElementB = typename MmaOperator::FragmentB::Element; + static_assert(platform::is_same::value, + "ElementB must be uint2b_t"); + + /// Type of the 
scales + using ElementLocalScale = uint4b_t; + using ElementSuperScale = ElementOperand; + using ElementCodeScaleZp = float; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kWarpIterationsAlongN = + MmaOperator::MmaIterations::kColumn; + + // use uint8_t to save 2 4-bits local scales + using FragmentLocalScale = Array; + using FragmentSuperScale = Array; + using FragmentCodeScaleZp = Array; + + /// Fragment to hold B data before Mma + using FragmentInput = Array; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + static constexpr int kNumPacks = + sizeof_bits::value / sizeof_bits::value; + static constexpr int kUnpackFactor = + MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks); + static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor; + + /// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points. 
+ using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementOperand, + ElementB, + MmaOperator::FragmentB::kElements / kUnpackFactor>; + using FragmentInputUnpack = typename Uint2Converter::result_type; + + /// Fragment to hold internal scales before Mma + using FragmentScale = Array; + + /// Fragment of dequantized B + using FragmentOutput = + Array; + + /// TensorRef type for loading element from a tensor + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; + + private: + // + // Data members + // + + uint8_t* pointer_local_scale_; + ElementCodeScaleZp* pointer_code_scale_; + ElementCodeScaleZp* pointer_code_zp_; + ElementSuperScale* pointer_super_scale_; + + // FragmentInputUnpack unpacked_frag_; + FragmentScale scale_frag_; + + public: + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale, + LocalTensorRef smem_local_scale, + CodeTensorRef smem_code_scale, + CodeTensorRef smem_code_zp, + int warp_idx_n, + int lane_idx) { + int warp_offset = warp_idx_n * Shape::kN; + int quad = lane_idx / 4; + int thread_offset = warp_offset + quad; + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + pointer_local_scale_ = + reinterpret_cast(smem_local_scale.data()) + thread_offset; + } + + /// Channel-wise params, need to load just once + CUTLASS_DEVICE + void load(FragmentCodeScaleZp& code_scale_frag, + FragmentCodeScaleZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + super_scale_frag[mma_n_iter] = + pointer_super_scale_[mma_n_iter * + InstructionShape::kN]; // bank conflict + code_scale_frag[mma_n_iter] = + pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + code_zp_frag[mma_n_iter] = + 
pointer_code_zp_[mma_n_iter * InstructionShape::kN]; + } + } + + /// Group-wise params, need to load multiple times + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + local_scale_frag[mma_n_iter] = + pointer_local_scale_[mma_n_iter * + InstructionShape::kN]; // bank conflict + } + } + + CUTLASS_DEVICE + void dequantize(const FragmentLocalScale& local_scale_frag, + const FragmentCodeScaleZp& code_scale_frag, + const FragmentCodeScaleZp& code_zp_frag, + const FragmentSuperScale& super_scale_frag, + const FragmentInput& input_frag, + FragmentOutput& output_frag, + int tb_offset_k, + int warp_k_compute_offset) { + if constexpr (kUnpackInterval != 1) { + // unsupport now + arch::device_breakpoint(); + } + + typename Uint2Converter::source_type source_frag; + + int in_offset = warp_k_compute_offset * kUnpackInterval; + + uint8_t const* ptr_input = reinterpret_cast(&input_frag); + uint8_t* ptr_source = reinterpret_cast(&source_frag); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + ptr_source[mma_n_iter] = + ptr_input[mma_n_iter * kUnpackFactor + in_offset]; + } + FragmentInputUnpack unpacked_frag = + Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag); + + // dequantize local_scale + if (warp_k_compute_offset == 0) { + using LocalScaleConverter = + detail::LocalScaleConverter; + + // special for TileRows = 64 + int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4; + LocalScaleConverter::Apply( + local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift); + } + + // unscale + // After applying LOP3 optimizations for performance, the B operand requires + // data rearrangement. 
reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, + // 14, 11, 15] + const int kWarpIterationsAlongK = + FragmentOutput::kElements / kWarpIterationsAlongN; + + using Type = typename detail::DataTypeTraits::Type; + using DualType = typename detail::DataTypeTraits::DualType; + + Type* output_ptr = reinterpret_cast(&output_frag); + DualType const* unpacked_ptr = + reinterpret_cast(&unpacked_frag); + DualType const* scale_ptr = reinterpret_cast(&scale_frag_); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; + mma_n_iter += 2) { + int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK; + + DualType scalex2 = scale_ptr[mma_n_iter / 2]; + + CUTLASS_PRAGMA_UNROLL + for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; + ++mma_k_iter) { + DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter]; + DualType scaled_value = __hmul2(unpacked_valuex2, scalex2); + output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = + scaled_value.x; + output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = + scaled_value.y; + } + } + } + + /// Add an offset to pointer in units of elements. + /// Only group-wise params needs. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + pointer_local_scale_ += offset; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h b/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h index 81e58f20ef3..02becf23882 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h @@ -21,301 +21,322 @@ #include #include -namespace cutlass_extensions -{ -// Note: The shapes are in the format MxNxK. 
The K shape of the runtime config MUST match the K shape +namespace cutlass_extensions { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config +// MUST match the K shape // in the kernel layout details when doing weight only quantization. -enum class CutlassTileConfig -{ - // Signals that we should run heuristics do choose a config - Undefined, +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, - // SiMT config - CtaShape128x128x8_WarpShape64x64x8, + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, - // TensorCore configs CTA_N = 128, CTA_K = 64 - // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, - // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, - // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, - CtaShape64x64x128_WarpShape32x64x64, - CtaShape64x128x64_WarpShape64x32x64, + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, - // Warp configs for M=128 - CtaShape128x64x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x64x64, - CtaShape128x128x64_WarpShape128x32x64, - CtaShape128x256x64_WarpShape64x64x64, + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, - // Warp configs for M=256 - CtaShape256x128x64_WarpShape64x64x64, + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, 
- // TensorCore config CTA_N = 64, CTA_K = 128 - CtaShape128x64x128_WarpShape64x32x128, + // TensorCore config CTA_N = 64, CTA_K = 128 + CtaShape128x64x128_WarpShape64x32x128, - // TensorCore config CTA_N = 256, CTA_K = 64 - CtaShape16x256x64_WarpShape16x64x64, + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64, - // TensorCore config CTA_N = 256, CTA_K = 128 - CtaShape16x256x128_WarpShape16x64x128 + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 }; -enum class SplitKStyle -{ - NO_SPLIT_K, - SPLIT_K_SERIAL, - STREAM_K, // Sm80+ - // SPLIT_K_PARALLEL // Not supported yet +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + STREAM_K, // Sm80+ + // SPLIT_K_PARALLEL // Not supported yet }; // New enum for SM100 (Blackwell) Tile Configs // Placeholder values - actual optimal values need research -enum class CutlassTileConfigSM100 -{ - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // Actual SM100 tile configs based on user input (K-tile is 128B) - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - CtaShape256x64x128B, - CtaShape256x128x128B, - CtaShape256x256x128B - // Note: The user-provided list for get_candidate_tiles_sm100 also includes - // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm cases. - // These are already covered by the list above if general suffices. - // If they need distinct enum values, they should be added. - // For now, keeping the enum concise with unique shapes mentioned for general use. 
+enum class CutlassTileConfigSM100 { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // Actual SM100 tile configs based on user input (K-tile is 128B) + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + CtaShape256x64x128B, + CtaShape256x128x128B, + CtaShape256x256x128B + // Note: The user-provided list for get_candidate_tiles_sm100 also includes + // CtaShape128x64x128B and CtaShape256x64x128B for specific FP4 grouped gemm + // cases. These are already covered by the list above if general suffices. If + // they need distinct enum values, they should be added. For now, keeping the + // enum concise with unique shapes mentioned for general use. }; - -enum class CutlassTileConfigSM90 -{ - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // CTA configs for M=64 - CtaShape64x16x128B, - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, - - // CTA configs for M=128 - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - - // CTA configs for M=128 - CtaShape256x128x128B, +enum class CutlassTileConfigSM90 { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + + // CTA configs for M=128 + CtaShape256x128x128B, }; -enum class MainloopScheduleType -{ - AUTO // 
Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this - // defaults to the "legacy" main loop schedule. +enum class MainloopScheduleType { + AUTO // Automatically selects between pingpong and cooperative schedules on + // Hopper. On older architectures, this defaults to the "legacy" main + // loop schedule. }; -enum class EpilogueScheduleType -{ - AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For - // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. +enum class EpilogueScheduleType { + AUTO // Automatically chooses an epilogue schedule compatible with the + // selected main loop schedule for Hopper. For architectures older than + // hopper, the epilogue is always performed by the same thread block as + // the main loop. }; -enum class ClusterShape -{ - ClusterShape_1x1x1, - ClusterShape_2x1x1, - ClusterShape_1x2x1, - ClusterShape_2x2x1, - ClusterShape_1x8x1, - ClusterShape_8x1x1 +enum class ClusterShape { + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 }; -struct CutlassGemmConfig -{ - enum CandidateConfigTypeParam : int - { - NONE = 0, - WEIGHT_ONLY = 1u << 0, - SIMT_ONLY = 1u << 1, - INT8_ONLY = 1u << 2, - HOPPER = 1u << 3, // SM90 - GROUPED_GEMM = 1u << 4, - FP8_ONLY = 1u << 5, - BLACKWELL = 1u << 6, // SM100 - FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths - }; - - CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; - SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; - int split_k_factor = -1; - int stages = -1; - - // config options for sm90 - CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; - MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; - EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; - 
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; - bool is_sm90 = false; - - // config options for sm100 (Blackwell) - // Assuming SM100 might use similar schedule/cluster types as SM90 for now. - // These might need to become SM100-specific if Blackwell introduces new concepts. - CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic; - // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; // Example if SM100 has different types - // EpilogueScheduleType epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example - // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // Example - bool is_sm100 = false; - - - CutlassGemmConfig() : is_sm90(false), is_sm100(false) {} - - CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) - : tile_config(tile_config) - , split_k_style(split_k_style) - , split_k_factor(split_k_factor) - , stages(stages) - , is_sm90(false) - , is_sm100(false) - { +struct CutlassGemmConfig { + enum CandidateConfigTypeParam : int { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, // SM90 + GROUPED_GEMM = 1u << 4, + FP8_ONLY = 1u << 5, + BLACKWELL = 1u << 6, // SM100 + FP4_ONLY = 1u << 7, // For Blackwell FP4/MXFP4 paths + }; + + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = + CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool is_sm90 = false; + + // config options for sm100 (Blackwell) + // Assuming SM100 might use similar schedule/cluster types as SM90 
for now. + // These might need to become SM100-specific if Blackwell introduces new + // concepts. + CutlassTileConfigSM100 tile_config_sm100 = + CutlassTileConfigSM100::ChooseWithHeuristic; + // MainloopScheduleType mainloop_schedule_sm100 = MainloopScheduleType::AUTO; + // // Example if SM100 has different types EpilogueScheduleType + // epilogue_schedule_sm100 = EpilogueScheduleType::AUTO; // Example + // ClusterShape cluster_shape_sm100 = ClusterShape::ClusterShape_1x1x1; // + // Example + bool is_sm100 = false; + + CutlassGemmConfig() : is_sm90(false), is_sm100(false) {} + + CutlassGemmConfig(CutlassTileConfig tile_config, + SplitKStyle split_k_style, + int split_k_factor, + int stages) + : tile_config(tile_config), + split_k_style(split_k_style), + split_k_factor(split_k_factor), + stages(stages), + is_sm90(false), + is_sm100(false) {} + + // Constructor for SM90 + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, + MainloopScheduleType mainloop_schedule_in, + EpilogueScheduleType epilogue_schedule_in, + ClusterShape cluster_shape_in) + : tile_config_sm90(tile_config_sm90_in), + mainloop_schedule(mainloop_schedule_in), + epilogue_schedule(epilogue_schedule_in), + cluster_shape(cluster_shape_in), + is_sm90(true), + is_sm100(false) {} + + // Constructor for SM100 (Blackwell) + // Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for + // now. These might need to be new SM100-specific types if Blackwell's TMA + // differs significantly. 
+ CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, + MainloopScheduleType mainloop_schedule_in, + EpilogueScheduleType epilogue_schedule_in, + ClusterShape cluster_shape_in) + : tile_config_sm100(tile_config_sm100_in), + mainloop_schedule( + mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if + // types diverge + , + epilogue_schedule( + epilogue_schedule_in) // Potentially use epilogue_schedule_sm100 + , + cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100 + , + is_sm90(false) // Explicitly false + , + is_sm100(true) {} + + std::string toString() const { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (is_sm100 && + tile_config_sm100 != + cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic) { + assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100"); + tactic + << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable + << "\n\ttile shape ID: " << (int)tile_config_sm100 + << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tmainloop sched: " << (int)mainloop_schedule + << "\n\tepi sched: " << (int)epilogue_schedule; + } else if (is_sm90 && tile_config_sm90 != + cutlass_extensions::CutlassTileConfigSM90:: + ChooseWithHeuristic) { + assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90"); + tactic << "\n\tstyle=TMA_SM90" + << "\n\ttile shape ID: " << (int)tile_config_sm90 + << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tmainloop sched: " << (int)mainloop_schedule + << "\n\tepi sched: " << (int)epilogue_schedule; + } else if (tile_config != + cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { + assert(!is_sm90 && !is_sm100 && + "Invalid cutlass GEMM config: Compatible"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int)tile_config + << "\n\tstages: " << (int)stages + << "\n\tsplit_k_style: " << (int)split_k_style + << "\n\tsplit k: " << (int)split_k_factor; + } else { + tactic << "\n\tundefined"; } 
- - // Constructor for SM90 - CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90_in, MainloopScheduleType mainloop_schedule_in, - EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in) - : tile_config_sm90(tile_config_sm90_in) - , mainloop_schedule(mainloop_schedule_in) - , epilogue_schedule(epilogue_schedule_in) - , cluster_shape(cluster_shape_in) - , is_sm90(true) - , is_sm100(false) - { - } - - // Constructor for SM100 (Blackwell) - // Using existing MainloopScheduleType, EpilogueScheduleType, ClusterShape for now. - // These might need to be new SM100-specific types if Blackwell's TMA differs significantly. - CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100_in, MainloopScheduleType mainloop_schedule_in, - EpilogueScheduleType epilogue_schedule_in, ClusterShape cluster_shape_in) - : tile_config_sm100(tile_config_sm100_in) - , mainloop_schedule(mainloop_schedule_in) // Potentially use mainloop_schedule_sm100 if types diverge - , epilogue_schedule(epilogue_schedule_in) // Potentially use epilogue_schedule_sm100 - , cluster_shape(cluster_shape_in) // Potentially use cluster_shape_sm100 - , is_sm90(false) // Explicitly false - , is_sm100(true) - { - } - - - std::string toString() const - { - std::stringstream tactic; - tactic << "Cutlass GEMM Tactic"; - if (is_sm100 && tile_config_sm100 != cutlass_extensions::CutlassTileConfigSM100::ChooseWithHeuristic) - { - assert(is_sm100 && !is_sm90 && "Invalid cutlass GEMM config: SM100"); - tactic << "\n\tstyle=TMA_SM100" // Indicate SM100 specific TMA if applicable - << "\n\ttile shape ID: " << (int) tile_config_sm100 - << "\n\tcluster shape ID: " << (int) cluster_shape - << "\n\tmainloop sched: " << (int) mainloop_schedule - << "\n\tepi sched: " << (int) epilogue_schedule; - } - else if (is_sm90 && tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic) - { - assert(is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: SM90"); - tactic << "\n\tstyle=TMA_SM90" - 
<< "\n\ttile shape ID: " << (int) tile_config_sm90 - << "\n\tcluster shape ID: " << (int) cluster_shape - << "\n\tmainloop sched: " << (int) mainloop_schedule - << "\n\tepi sched: " << (int) epilogue_schedule; - } - else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) - { - assert(!is_sm90 && !is_sm100 && "Invalid cutlass GEMM config: Compatible"); - tactic << "\n\tstyle=compatible" - << "\n\ttile shape ID: " << (int) tile_config - << "\n\tstages: " << (int) stages - << "\n\tsplit_k_style: " << (int) split_k_style - << "\n\tsplit k: " << (int) split_k_factor; - } - else - { - tactic << "\n\tundefined"; - } - tactic << "\n"; - return tactic.str(); - } - - void fromString(const std::string& str) { - std::istringstream stream(str); - std::string line; - - is_sm90 = false; // Reset flags + tactic << "\n"; + return tactic.str(); + } + + void fromString(const std::string& str) { + std::istringstream stream(str); + std::string line; + + is_sm90 = false; // Reset flags + is_sm100 = false; + + while (std::getline(stream, line)) { + if (line.find("style=TMA_SM100") != std::string::npos) { + is_sm100 = true; + is_sm90 = false; + std::getline(stream, line); + tile_config_sm100 = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + cluster_shape = static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + mainloop_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + epilogue_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + } else if (line.find("style=TMA_SM90") != + std::string::npos) { // Check for SM90 specific first + is_sm90 = true; is_sm100 = false; - - while (std::getline(stream, line)) { - if (line.find("style=TMA_SM100") != std::string::npos) { - is_sm100 = true; - is_sm90 = false; - std::getline(stream, line); - tile_config_sm100 = static_cast(std::stoi(line.substr(line.find(':') + 1))); - 
std::getline(stream, line); - cluster_shape = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - mainloop_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - epilogue_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - } else if (line.find("style=TMA_SM90") != std::string::npos) { // Check for SM90 specific first - is_sm90 = true; - is_sm100 = false; - std::getline(stream, line); - tile_config_sm90 = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - cluster_shape = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - mainloop_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - epilogue_schedule = static_cast(std::stoi(line.substr(line.find(':') + 1))); - } else if (line.find("style=compatible") != std::string::npos) { - is_sm90 = false; - is_sm100 = false; - std::getline(stream, line); - tile_config = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - stages = std::stoi(line.substr(line.find(':') + 1)); - std::getline(stream, line); - split_k_style = static_cast(std::stoi(line.substr(line.find(':') + 1))); - std::getline(stream, line); - split_k_factor = std::stoi(line.substr(line.find(':') + 1)); - } - } + std::getline(stream, line); + tile_config_sm90 = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + cluster_shape = static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + mainloop_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + epilogue_schedule = + static_cast( + std::stoi(line.substr(line.find(':') + 1))); + } else if (line.find("style=compatible") != std::string::npos) { + is_sm90 = false; + is_sm100 = false; + std::getline(stream, line); + tile_config = 
static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + stages = std::stoi(line.substr(line.find(':') + 1)); + std::getline(stream, line); + split_k_style = static_cast( + std::stoi(line.substr(line.find(':') + 1))); + std::getline(stream, line); + split_k_factor = std::stoi(line.substr(line.find(':') + 1)); + } } + } }; -inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) -{ - // clang-format off +inline std::ostream& operator<<(std::ostream& out, + CutlassGemmConfig const& config) { + // clang-format off if (config.is_sm100) { out << "tile_config_sm100_enum: " << int(config.tile_config_sm100) @@ -337,8 +358,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf << ", split_k_factor: " << config.split_k_factor << ", stages: " << config.stages; } - // clang-format on - return out; + // clang-format on + return out; } -} // namespace cutlass_extensions +} // namespace cutlass_extensions diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680e6..9a9b35a324a 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file - \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t + interleaved in a register */ #pragma once @@ -39,409 +41,814 @@ #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/numeric_types.h" - -namespace cutlass -{ - -// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low -// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally -// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. -// This converter will uninterleave the data and subtract the bias while converting to the result type. +#include "cutlass/trace.h" + +namespace cutlass { + +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// This converter is meant to be used with data interleaved in a 32-bit register +// where the even elements are in the low bits and the odd elemeents are in the +// high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the +// type) to make all numbers unsigned. This converter will uninterleave the data +// and subtract the bias while converting to the result type. 
template -struct FastInterleavedAndBiasedNumericArrayConverter -{ -}; +struct FastInterleavedAndBiasedNumericArrayConverter; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); - - // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : 
"r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to + // get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = 
reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - uint32_t* bf16_result_ptr = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t fp32_base = 0x4B000000; - float fp32_intermediates[4]; - - // Construct FP32s, bfloat does not have enough mantissa for IADD trick - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); - - // Subtract out fp32_base + 128 to make the unsigned integer signed. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 4; ++ii) - { - fp32_intermediates[ii] -= 8388736.f; - } - - // Truncate the fp32 representation and pack up as bfloat16s. 
- CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 2; ++ii) - { - bf16_result_ptr[ii] - = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - result.clear(); // Suppress compiler warning - arch::device_breakpoint(); -#endif - return result; + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); } +#else + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. 
+ result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type 
const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput + // in order to convert elt_23 and elt_67 to fp16 without having to shift + // them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide + // RAW dependency if we issue immediately before required. 
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), + "n"(BOTTOM_MASK), + "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. 
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. 
- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. - - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. - - // This is the half2 {1032, 1032} represented as an integer. - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. 
- static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - static constexpr uint32_t NEG_72 = 0xd480d480; - - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, + // so we must loop. No shift needed for first item. 
+ uint32_t i4s = source_i4s; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); } +#else + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. 
+#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; template <> -struct 
FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + using ScaleComputeT = float; + using code_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + + int32_t decode_value[4]; + ScaleComputeT new_code_zp = code_zp + 0.5f; + + decode_value[0] = + __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); + decode_value[1] = + __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); + decode_value[2] = + __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); + decode_value[3] = + 
__float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert(source_type const& source, + code_type const& code_scale, + code_type const& code_zp) { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + + int32_t decode_value[4]; + + decode_value[0] = __float2int_rd( + fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); + decode_value[1] = __float2int_rd( + fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); + decode_value[2] = __float2int_rd( + fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); + decode_value[3] = __float2int_rd( + fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert_impl(int32_t* decode_value) { + result_type result; + static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + + static constexpr uint32_t MASK = 0x003F003F; + // 
2^10 = 1024 + static constexpr uint32_t EX = 0x64006400; + + uint32_t* h = reinterpret_cast(&result); + + int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); + int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); + + h[0] = lop3(q0 >> 9, MASK, EX); + h[1] = lop3(q0 >> 6, MASK, EX); + h[2] = lop3(q0 >> 3, MASK, EX); + h[3] = lop3(q0, MASK, EX); + + h[4] = lop3(q1 >> 9, MASK, EX); + h[5] = lop3(q1 >> 6, MASK, EX); + h[6] = lop3(q1 >> 3, MASK, EX); + h[7] = lop3(q1, MASK, EX); + + // 1024 + 32 = 1056 + static constexpr uint32_t SUB = 0x64206420; + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + return convert(s, code_scale, code_zp); + } +}; - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - - // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. - // No shift needed for first item. 
- uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - CUTLASS_PRAGMA_UNROLL - for (int ii = 1; ii < result_type::kElements / 2; ++ii) - { - i4s >>= sizeof_bits::value; - // (i4s & 0x000f000f) | 0x43004300 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } - - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; - - // Finally, we construct the output numbers. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < result_type::kElements / 2; ++ii) - { - // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + using ScaleComputeT = float; + using code_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : 
"r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + + int32_t decode_value[4]; + ScaleComputeT new_code_zp = code_zp + 0.5f; + + decode_value[0] = + __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); + decode_value[1] = + __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); + decode_value[2] = + __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); + decode_value[3] = + __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert(source_type const& source, + code_type const& code_scale, + code_type const& code_zp) { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[0]) + : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[1]) + : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[2]) + : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" + : "=r"(fp32_intermediates_casted[3]) + : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + 
+ int32_t decode_value[4]; + + decode_value[0] = __float2int_rd( + fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); + decode_value[1] = __float2int_rd( + fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); + decode_value[2] = __float2int_rd( + fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); + decode_value[3] = __float2int_rd( + fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert_impl(int32_t* decode_value) { + result_type result; + + static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + static constexpr uint32_t MASK = 0x003F003F; + // 2^7 = 128 + static constexpr uint32_t EX = 0x43004300; + + uint32_t* h = reinterpret_cast(&result); + + int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); + int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); + + h[0] = lop3(q0 >> 9, MASK, EX); + h[1] = lop3(q0 >> 6, MASK, EX); + h[2] = lop3(q0 >> 3, MASK, EX); + h[3] = lop3(q0, MASK, EX); + + h[4] = lop3(q1 >> 9, MASK, EX); + h[5] = lop3(q1 >> 6, MASK, EX); + h[6] = lop3(q1 >> 3, MASK, EX); + h[7] = lop3(q1, MASK, EX); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16)) + // 128 + 32 = 160 + static constexpr uint32_t SUB = 0x43204320; + + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); + + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); #else - // 
Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - arch::device_breakpoint(); - result.clear(); // Suppress compiler warning. + // 1.0 + static constexpr uint32_t MUL = 0x3F803F80; + // -160 + static constexpr uint32_t ADD = 0xC320C320; + + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(MUL), "r"(ADD)); + + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[4]) + : "r"(h[4]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[5]) + : "r"(h[5]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[6]) + : "r"(h[6]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[7]) + : "r"(h[7]), "r"(MUL), "r"(ADD)); #endif - return result; - } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s, + ScaleComputeT code_scale, + ScaleComputeT code_zp) { + return convert(s, code_scale, code_zp); + } }; -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - 
result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static_assert(platform::is_same::value || + platform::is_same::value, + "T must be fp16 or bf16"); + + static constexpr int kVecWidth = 16; + static_assert(!(N % kVecWidth), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + using code_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source, + code_type const& code_scale, + code_type const& code_zp) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + return result; + } + + CUTLASS_DEVICE + static result_type convert(source_type const& source, + Array const& code_scale, + Array const& code_zp) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + using Converter = + FastInterleavedAndBiasedNumericArrayConverter; + + result_type result; + using vec_result = typename Converter::result_type; + using vec_source = typename Converter::source_type; + using vec_code = typename Converter::code_type; + + vec_result* result_ptr 
= reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + vec_code const* code_scale_ptr = + reinterpret_cast(&code_scale); + vec_code const* code_zp_ptr = reinterpret_cast(&code_zp); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) { + result_ptr[i] = + Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]); } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s, + code_type const& code_scale, + code_type const& code_zp) { + return convert(s, code_scale, code_zp); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h b/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h index 5a0cd295708..928f2645a53 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/tile_interleaved_layout.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file @@ -38,29 +39,24 @@ #include "cutlass/matrix_coord.h" #include "cutlass/pitch_linear_coord.h" -namespace cutlass -{ -namespace layout -{ +namespace cutlass { +namespace layout { template -struct ColumnMajorTileInterleave -{ - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; +struct ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; }; template -struct IsColumnMajorTileInterleave -{ - static constexpr bool value = false; +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; }; template -struct IsColumnMajorTileInterleave> -{ - static constexpr bool value = true; +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; }; -} // namespace layout -} // namespace cutlass +} // namespace layout +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h index 6095925e372..6d45e5cb02a 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h +++ b/custom_ops/gpu_ops/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file - \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM - quantization. + \brief Templates for visiting scales to be used when dequantizing the + weights for weight-only GEMM quantization. */ #pragma once @@ -50,201 +51,205 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace transform -{ -namespace threadblock -{ +namespace cutlass { +namespace transform { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template +template class FineGrainedScaleZeroIterator; template -class FineGrainedScaleZeroIterator -{ -public: - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = 0; - static int const kAlignment = Alignment_; - - static int const kAccessesPerVector = 1; - - /// Row index of scales corresponding to the groupsize of 64 - int row_groupsize64_; - int group_size_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - using Pointer = Element*; - using NonConstPointer = typename platform::remove_const::type*; - - using AccessType = AlignedArray; - - using Fragment = cutlass::Array; - - // For compatibility with existing iterator interface - struct Params - { - LongIndex stride_ = 0; - - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_ = 0; - - // Default ctor - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : stride_(layout.stride(0)) - { - inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; - } - }; - -private: - /// Internal pointer type permits fast address 
arithmetic - using BytePointer = char*; - -private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const params_; - - /// Internal pointer to first access of tile - BytePointer pointer_scale_; - BytePointer pointer_zero_; - - bool is_valid_ = false; - -public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_DEVICE - FineGrainedScaleZeroIterator( - ///< Precomputed parameters object - Params const& params, - ///< Pointer to start of scale tensor - Pointer pointer_scale, - ///< Pointer to start of zero tensor - Pointer pointer_zero, - ///< Extent of the scale and bias - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const& threadblock_offset, - ///< Group size - int group_size) - : params_(params) - , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) - , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) - { - row_groupsize64_ = threadblock_offset.row(); - group_size_ = group_size; - - const LongIndex tb_row_byte_offset - = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; - const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; - pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); - - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); - } - - static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; - - int const thread_row = thread_id / THREADS_PER_ROW; - int const thread_col = thread_id % THREADS_PER_ROW; - - const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; - const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; - pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); - if (pointer_zero_ != nullptr) - { - pointer_zero_ += 
(thread_row_byte_offset + thread_col_byte_offset); - } - - // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on - // a given iteration. The same threads will be responsible for issues reads since the number of scales - // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ - // outside of the constructor. - int const global_row = threadblock_offset.row() + thread_row; - int const global_col = threadblock_offset.column() + thread_col * kAlignment; - - bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; - bool const col_in_bounds = global_col < extent.column(); - - is_valid_ = row_in_bounds && col_in_bounds; - } +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + using Fragment = cutlass::Array; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object - Pointer pointer_scale, ///< 
Pointer to start of scale tensor - Pointer pointer_zero, ///< Pointer to start of zero tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - int group_size) - : FineGrainedScaleZeroIterator( - params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) - { + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& tile_offset) - { - const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; - const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; - pointer_scale_ += row_byte_offset + col_byte_offset; - if (pointer_zero_ != nullptr) - { - pointer_zero_ += row_byte_offset + col_byte_offset; - } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), + pointer_scale_(reinterpret_cast( + const_cast(pointer_scale))), + 
pointer_zero_(reinterpret_cast( + const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / + (group_size / 64) * params_.stride_ * + sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = + threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); } - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) - { - is_valid_ &= (!enable); - } + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() const - { - return is_valid_; - } + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; - /// Returns a scale pointer - CUTLASS_HOST_DEVICE - AccessType* get_scale() const - { - return reinterpret_cast(pointer_scale_); + const LongIndex thread_row_byte_offset = + thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = + thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); } - /// Returns a zero pointer - CUTLASS_HOST_DEVICE - AccessType* get_zero() const - { - return reinterpret_cast(pointer_zero_); + // For the rows, we must check that we are within the extent AND the tile to + // avoid extra reads on a given iteration. The same threads will be + // responsible for issues reads since the number of scales read in a given + // iteration is a constant. Therefore, we should never have to update + // is_valid_ outside of the constructor. 
+ int const global_row = threadblock_offset.row() + thread_row; + int const global_col = + threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = + global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator( + Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator(params, + pointer_scale, + pointer_zero, + extent, + thread_id, + make_Coord(0, 0), + group_size) {} + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = + tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return is_valid_; } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { + return reinterpret_cast(pointer_zero_); + } }; -} // namespace threadblock -} // namespace transform -} // namespace cutlass +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git 
a/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp b/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp index b430380b014..b29fc4db5f6 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp +++ b/custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once @@ -38,144 +39,148 @@ using namespace cute; /// Function object that applies an index to its argument template -struct IndexedGather -{ - CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) - : indices_(indices) - { - } - - template - CUTE_HOST_DEVICE constexpr auto operator()(I i) const - { - return indices_[i]; - } - - CUTE_HOST_DEVICE friend void print(IndexedGather const& s) - { - cute::print("Indexed{"); - print(s.indices_); - print("}"); - } - - Iter indices_; +struct IndexedGather { + CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) + : indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr auto operator()(I i) const { + return indices_[i]; + } + + CUTE_HOST_DEVICE friend void print(IndexedGather const& s) { + cute::print("Indexed{"); + print(s.indices_); + print("}"); + } + + Iter indices_; }; /// Custom stride object that applies a function followed by a stride template -struct CustomStride -{ - CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride) - : func_(func) - , stride_(stride) - { - } - - template - CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) - { - return s.func_(i) * s.stride_; - } - - template - CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) - { - return s.func_(i) * s.stride_; - } - - CUTE_HOST_DEVICE friend void print(CustomStride const& s) - { - cute::print("Custom{"); - print(s.func_); - cute::print(","); - print(s.stride_); - cute::print("}"); - } - - template - CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) - { - return CustomStride(s.func_, safe_div(s.stride_, div)); - } - - // Circumvent the requirement on make_layout that shape and stride are integral - template - CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) - { - 
return Layout(shape, stride); - } - - Func func_; - Stride stride_; +struct CustomStride { + CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, + Stride const& stride) + : func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) { + return s.func_(i) * s.stride_; + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) { + return s.func_(i) * s.stride_; + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, + Div const& div) { + return CustomStride( + s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are + // integral + template + CUTE_HOST_DEVICE constexpr friend auto make_layout( + Shape const& shape, CustomStride const& stride) { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; }; template -CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) -{ - // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride - auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; - return make_layout( - repeat_like(stride, _1{}), replace(stride, CustomStride{static_cast(func), get(stride)})); +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, + Func&& func) { + // Use a dummy shape and replace the first non-unit and non-zero stride with a + // custom gather stride + auto idx = find_if(stride, [](auto x) { + return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; + }); + constexpr int I = decltype(idx)::value; + return make_layout( + repeat_like(stride, _1{}), + 
replace(stride, + CustomStride{static_cast(func), get(stride)})); } /// Helper function to optionally create a gather tensor template -CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) -{ - Layout matrix_layout = make_identity_layout(shape); - auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); - Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); - return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); +CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, + Shape const& shape, + Stride const& stride, + Func&& func) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = + make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, + ComposedLayout{gather_layout, offset, matrix_layout}); } -namespace cute -{ +namespace cute { template -CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) -{ - if constexpr (is_tuple::value) - { - return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s, d); }); - } - else if constexpr (is_scaled_basis::value) - { - if constexpr (Stride::mode() == I) - { - return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); - } - else - { - return make_layout(shape, stride); - } - } - else - { - return upcast(shape, stride); +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, + Stride const& stride) { + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { + return upcast(s, d); + }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(shape_div(shape, Int{}), + shape_div(stride, Int{})); + } else { + return make_layout(shape, stride); } + } else { + return upcast(shape, stride); + } - CUTE_GCC_UNREACHABLE; + 
CUTE_GCC_UNREACHABLE; } -template +template CUTE_HOST_DEVICE constexpr auto upcast( - ComposedLayout, Offset, Layout> const& layout) -{ - // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset - auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; - - // Upcast the outer layout (works as expected) - auto outer = upcast(layout.layout_a()); - - // Upcast the accumulated offset along stride-1 mode - auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); - - // Upcast the inner layout's shape along stride-1 mode - auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); - - return composition(outer, offset, inner); + ComposedLayout, + Offset, + Layout> const& layout) { + // Find index of the stride-1 mode - that is the only one that requires + // updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), + [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple( + replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = + upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); } -} // namespace cute +} // namespace cute diff --git a/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h b/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h index 64774428e9f..9f2d2552346 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h @@ -1,12 +1,12 @@ 
/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,41 +18,40 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. */ #pragma once -namespace cutlass -{ +namespace cutlass { -enum class WeightOnlyQuantOp -{ - UNDEFINED, - PER_COLUMN_SCALE_ONLY, - FINEGRAINED_SCALE_ONLY, - FINEGRAINED_SCALE_AND_ZEROS +enum class WeightOnlyQuantOp { + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS }; -constexpr bool isFinegrained(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || + op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; } -constexpr bool hasZero(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +constexpr bool hasZero(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; } -} // namespace cutlass +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h index 9e1c6c463b6..da6fcf41a26 100644 --- 
a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h @@ -33,19 +33,23 @@ enum WintQuantMethod { }; // Convert CUDA data type to cutlass data type -template struct CutlassDataType { +template +struct CutlassDataType { using Type = T; }; -template <> struct CutlassDataType { +template <> +struct CutlassDataType { using Type = cutlass::half_t; }; -template <> struct CutlassDataType<__nv_bfloat16> { +template <> +struct CutlassDataType<__nv_bfloat16> { using Type = cutlass::bfloat16_t; }; -template struct WintQuantTraits; +template +struct WintQuantTraits; template struct WintQuantTraits { @@ -125,10 +129,13 @@ struct WintQuantTraits { static constexpr int32_t kNumPackedValues = 4; static constexpr int32_t kPackedSize = 16; + using LocalScaleType = uint4b_t; + using CodeScaleZpType = float; + struct Arguments { - const uint8_t *local_scale_ptr; // quanted 4-bits - const float *code_scale_ptr; - const float *code_zp_ptr; + uint8_t *local_scale_ptr; // quanted 4-bits + float *code_scale_ptr; + float *code_zp_ptr; }; CUTLASS_DEVICE @@ -137,4 +144,4 @@ struct WintQuantTraits { } }; -} // namespace cutlass +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h index 3ac548c6258..f9e4aad2ff1 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h @@ -24,11 +24,11 @@ /** * Helper function for checking CUTLASS errors */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - PD_CHECK(error == cutlass::Status::kSuccess, \ - cutlassGetStatusString(error)); \ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + PD_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ } /** @@ -38,44 +38,50 @@ * __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef * into code that 
will be executed on the device where it is defined. */ -template struct enable_sm90_or_later : Kernel { - template CUTLASS_DEVICE void operator()(Args &&...args) { +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args &&...args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); #endif } }; -template class CutlassDtypeTraits; +template +class CutlassDtypeTraits; -template <> class CutlassDtypeTraits { -public: +template <> +class CutlassDtypeTraits { + public: typedef float DataType; typedef float data_t; }; -template <> class CutlassDtypeTraits { -public: +template <> +class CutlassDtypeTraits { + public: typedef cutlass::half_t DataType; typedef paddle::float16 data_t; }; -template <> class CutlassDtypeTraits { -public: +template <> +class CutlassDtypeTraits { + public: typedef cutlass::bfloat16_t DataType; typedef paddle::bfloat16 data_t; }; class CutlassGemmConfigMannager { -public: + public: static CutlassGemmConfigMannager &getInstance() { static CutlassGemmConfigMannager instance; return instance; } CutlassGemmConfigMannager(const CutlassGemmConfigMannager &) = delete; - CutlassGemmConfigMannager & - operator=(const CutlassGemmConfigMannager &) = delete; + CutlassGemmConfigMannager &operator=(const CutlassGemmConfigMannager &) = + delete; void up_date_configs(const nlohmann::json &j) { std::lock_guard lock(mutex_); @@ -102,7 +108,7 @@ class CutlassGemmConfigMannager { return &json_; } -private: + private: void save_gemm_best_configs_(const std::string &config_file_path) { std::ifstream file(config_file_path); if (!file.good()) { diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu index 6db16981c67..6ea5a275ad5 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.cu @@ -19,14 +19,14 @@ #ifndef _WIN32 #pragma GCC diagnostic 
push #pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // #ifndef _WIN32 +#endif // #ifndef _WIN32 #include "cutlass/gemm/gemm.h" #include "cutlass/numeric_types.h" #ifndef _WIN32 #pragma GCC diagnostic pop -#endif // #ifndef _WIN32 +#endif // #ifndef _WIN32 #include #include @@ -35,491 +35,509 @@ using namespace cutlass_extensions; -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { -struct TileShape -{ - int m; - int n; +struct TileShape { + int m; + int n; }; -TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) -{ - switch (tile_config) - { - case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128}; - case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileShape{16, 256}; - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128}; - case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64}; +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + return TileShape{16, 256}; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + return TileShape{64, 64}; case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileShape{64, 128}; - case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: return TileShape{128, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + return TileShape{128, 64}; case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: case 
CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: - case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128}; - case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256}; - case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128}; - case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: return TileShape{16, 256}; - default: throw("[get_grid_shape_for_config] Invalid config"); - } + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + return TileShape{16, 256}; + default: + throw("[get_grid_shape_for_config] Invalid config"); + } } -bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape, - int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only) -{ - - // All tile sizes have a k_tile of 64. - static constexpr int k_tile = 128; - - // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k - if (is_weight_only) - { - if ((k % k_tile) != 0) - { - return false; - } +bool is_valid_split_k_factor(int64_t const m, + int64_t const n, + int64_t const k, + TileShape const tile_shape, + int const split_k_factor, + size_t const workspace_bytes, + bool const is_weight_only) { + // All tile sizes have a k_tile of 64. 
+ static constexpr int k_tile = 128; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple + // of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } - if ((k % split_k_factor) != 0) - { - return false; - } + if ((k % split_k_factor) != 0) { + return false; + } - int const k_elements_per_split = k / split_k_factor; - if ((k_elements_per_split % k_tile) != 0) - { - return false; - } + int const k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; } + } - // Check that the workspace has sufficient space for this split-k factor - int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; - int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + // Check that the workspace has sufficient space for this split-k factor + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = + split_k_factor == 1 ? 
0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; - if (required_ws_bytes > workspace_bytes) - { - return false; - } + if (required_ws_bytes > workspace_bytes) { + return false; + } - return true; + return true; } std::vector get_candidate_tiles( - int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) -{ - enum class CutlassGemmType : char - { - Default, - WeightOnly, - Simt, - Int8, - Fp8 - }; - - CutlassGemmType gemm_type = CutlassGemmType::Default; - if (config_type_param & CutlassGemmConfig::SIMT_ONLY) - { - gemm_type = CutlassGemmType::Simt; - } - else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) - { - gemm_type = CutlassGemmType::WeightOnly; - } - else if (config_type_param & CutlassGemmConfig::INT8_ONLY) - { - gemm_type = CutlassGemmType::Int8; - } - else if (config_type_param & CutlassGemmConfig::FP8_ONLY) - { - gemm_type = CutlassGemmType::Fp8; - } - - std::vector base_configs{ - CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; - if (sm >= 75) - { - base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); - } - - switch (gemm_type) - { - case CutlassGemmType::Simt: return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + int const sm, + CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + enum class CutlassGemmType : char { Default, WeightOnly, Simt, Int8, Fp8 }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (config_type_param & CutlassGemmConfig::SIMT_ONLY) { + gemm_type = CutlassGemmType::Simt; + } else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (config_type_param & CutlassGemmConfig::INT8_ONLY) { + gemm_type = CutlassGemmType::Int8; + } else if (config_type_param & CutlassGemmConfig::FP8_ONLY) { + gemm_type = CutlassGemmType::Fp8; + } + + std::vector base_configs{ + 
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) { + base_configs.push_back( + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } + + switch (gemm_type) { + case CutlassGemmType::Simt: + return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; case CutlassGemmType::WeightOnly: - if (sm >= 75) - { - return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + if (sm >= 75) { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; - } - else - { - return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + } else { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; - } + } case CutlassGemmType::Int8: - return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; case CutlassGemmType::Fp8: - if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) - { - if (sm == 89) - { - return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, - CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, - 
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, - CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, - CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, - CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; - } - else - { - // no valid ampere style fp8 configs for sm90 - return {}; - } + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { + if (sm == 89) { + return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } else { + // no valid ampere style fp8 configs for sm90 + return {}; } - default: return base_configs; - } + } + default: + return base_configs; + } } std::vector get_candidate_tiles_sm90( - int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config) -{ + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config) { #ifdef FAST_BUILD - // Fast build disables all configs except this one for SM90 - return {CutlassTileConfigSM90::CtaShape128x128x128B}; + // Fast build disables all configs except this one for SM90 + return {CutlassTileConfigSM90::CtaShape128x128x128B}; #else - if (config & CutlassGemmConfig::GROUPED_GEMM) - { - return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, - CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, - CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; - } - else - { - return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B, - CutlassTileConfigSM90::CtaShape64x64x128B, 
CutlassTileConfigSM90::CtaShape64x128x128B, - CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B, - CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, - CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B}; - } + if (config & CutlassGemmConfig::GROUPED_GEMM) { + return {CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + } else { + return {CutlassTileConfigSM90::CtaShape64x16x128B, + CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, + CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, + CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B}; + } #endif } -// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve -// compilation speed. -bool supports_mcast_along_m(CutlassTileConfigSM90 const tile) -{ +// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= +// 128. This is purely to improve compilation speed. 
+bool supports_mcast_along_m(CutlassTileConfigSM90 const tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B, - CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, - CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B, - CutlassTileConfigSM90::CtaShape256x128x128B}; - return valid_tiles.count(tile) == 1; + std::set valid_tiles{ + CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; #endif } -// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve -// compilation speed. -bool supports_mcast_along_n(CutlassTileConfigSM90 const tile) -{ +// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= +// 128. This is purely to improve compilation speed. 
+bool supports_mcast_along_n(CutlassTileConfigSM90 const tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B, - CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B, - CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; - return valid_tiles.count(tile) == 1; + std::set valid_tiles{ + CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; #endif } // SM100 (Blackwell) candidate tile configurations std::vector get_candidate_tiles_sm100( - int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config) -{ + int /*sm*/, CutlassGemmConfig::CandidateConfigTypeParam const config) { #ifdef FAST_BUILD - return {CutlassTileConfigSM100::CtaShape128x128x128B}; + return {CutlassTileConfigSM100::CtaShape128x128x128B}; #else - /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) */ - if (config & CutlassGemmConfig::GROUPED_GEMM) + /* Grouped-GEMM path first (Blackwell uses 1-SM and 2-SM “cluster” kernels) + */ + if (config & CutlassGemmConfig::GROUPED_GEMM) { + if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4 { - if (config & CutlassGemmConfig::FP4_ONLY) // nvfp4 / mx_fp4 - { - return { - /* 1 SM (M=128) */ - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - /* 2 SM (M=256) */ - CutlassTileConfigSM100::CtaShape256x128x128B, - CutlassTileConfigSM100::CtaShape256x256x128B, - /* slim tiles for very tall matrices */ - CutlassTileConfigSM100::CtaShape128x64x128B, - CutlassTileConfigSM100::CtaShape256x64x128B}; - } + return {/* 1 SM (M=128) */ + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + 
/* 2 SM (M=256) */ + CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B, + /* slim tiles for very tall matrices */ + CutlassTileConfigSM100::CtaShape128x64x128B, + CutlassTileConfigSM100::CtaShape256x64x128B}; + } - /* Fp8 / Fp16 grouped-GEMM */ - return { - CutlassTileConfigSM100::CtaShape128x128x128B, + /* Fp8 / Fp16 grouped-GEMM */ + return {CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B, CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B}; - } - - /* Non-grouped path (plain GEMM or weight-only) */ - return { - /* 1 SM tiles */ - CutlassTileConfigSM100::CtaShape64x64x128B, - CutlassTileConfigSM100::CtaShape64x128x128B, - CutlassTileConfigSM100::CtaShape64x256x128B, - CutlassTileConfigSM100::CtaShape128x64x128B, - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - /* 2 SM tiles */ - CutlassTileConfigSM100::CtaShape256x64x128B, - CutlassTileConfigSM100::CtaShape256x128x128B, - CutlassTileConfigSM100::CtaShape256x256x128B}; + } + + /* Non-grouped path (plain GEMM or weight-only) */ + return {/* 1 SM tiles */ + CutlassTileConfigSM100::CtaShape64x64x128B, + CutlassTileConfigSM100::CtaShape64x128x128B, + CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x64x128B, + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + /* 2 SM tiles */ + CutlassTileConfigSM100::CtaShape256x64x128B, + CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B}; #endif } // M-multicast support for SM100. 
-bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile) -{ +bool supports_mcast_along_m_sm100(CutlassTileConfigSM100 tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set m_tiles{ - CutlassTileConfigSM100::CtaShape128x64x128B, - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - CutlassTileConfigSM100::CtaShape256x64x128B, - CutlassTileConfigSM100::CtaShape256x128x128B, - CutlassTileConfigSM100::CtaShape256x256x128B}; - return m_tiles.count(tile) == 1; + std::set m_tiles{ + CutlassTileConfigSM100::CtaShape128x64x128B, + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + CutlassTileConfigSM100::CtaShape256x64x128B, + CutlassTileConfigSM100::CtaShape256x128x128B, + CutlassTileConfigSM100::CtaShape256x256x128B}; + return m_tiles.count(tile) == 1; #endif } // N-multicast support for SM100. -bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile) -{ +bool supports_mcast_along_n_sm100(CutlassTileConfigSM100 tile) { #ifdef FAST_BUILD - return false; + return false; #else - std::set n_tiles{ - CutlassTileConfigSM100::CtaShape64x128x128B, - CutlassTileConfigSM100::CtaShape64x256x128B, - CutlassTileConfigSM100::CtaShape128x128x128B, - CutlassTileConfigSM100::CtaShape128x256x128B, - CutlassTileConfigSM100::CtaShape256x128x128B}; - return n_tiles.count(tile) == 1; + std::set n_tiles{ + CutlassTileConfigSM100::CtaShape64x128x128B, + CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x128x128B, + CutlassTileConfigSM100::CtaShape128x256x128B, + CutlassTileConfigSM100::CtaShape256x128x128B}; + return n_tiles.count(tile) == 1; #endif } - std::vector get_candidate_configs( - int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) -{ - if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) - { - std::vector tiles = get_candidate_tiles_sm90(sm, config_type_param); - - 
std::vector candidate_configs; - for (auto const& tile_config : tiles) - { - CutlassGemmConfig config( - tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); - candidate_configs.push_back(config); - - bool const has_m_mcast = supports_mcast_along_m(tile_config); - bool const has_n_mcast = supports_mcast_along_n(tile_config); - if (has_m_mcast) - { - CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x1x1); - candidate_configs.push_back(config); - } - - if (has_n_mcast) - { - CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_1x2x1); - candidate_configs.push_back(config); - } - - if (has_m_mcast && has_n_mcast) - { - CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x2x1); - candidate_configs.push_back(config); - } - } - return candidate_configs; + int sm, + int const max_split_k, + CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) { + std::vector tiles = + get_candidate_tiles_sm90(sm, config_type_param); + + std::vector candidate_configs; + for (auto const& tile_config : tiles) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); + + bool const has_m_mcast = supports_mcast_along_m(tile_config); + bool const has_n_mcast = supports_mcast_along_n(tile_config); + if (has_m_mcast) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(config); + } + + if (has_n_mcast) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + 
ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(config); + } + + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig config(tile_config, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(config); + } } - else if (sm == 100 && (config_type_param & CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell - { - std::vector tiles = get_candidate_tiles_sm100(sm, config_type_param); - std::vector candidate_configs; - - for (auto const& tile_config_sm100 : tiles) - { - // SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO similar to SM90. - // Cluster shapes are also handled similarly. - CutlassGemmConfig config( - tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); - candidate_configs.push_back(config); - - bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100); - bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100); - - if (has_m_mcast) - { - CutlassGemmConfig mcast_m_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x1x1); - candidate_configs.push_back(mcast_m_config); - } - - if (has_n_mcast) - { - CutlassGemmConfig mcast_n_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_1x2x1); - candidate_configs.push_back(mcast_n_config); - } - - if (has_m_mcast && has_n_mcast) - { - CutlassGemmConfig mcast_mn_config(tile_config_sm100, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, - ClusterShape::ClusterShape_2x2x1); - candidate_configs.push_back(mcast_mn_config); - } - } - return candidate_configs; + return candidate_configs; + } else if (sm == 100 && + (config_type_param & + CutlassGemmConfig::BLACKWELL)) // Assuming SM100 for Blackwell + { + std::vector tiles = + get_candidate_tiles_sm100(sm, config_type_param); + std::vector 
candidate_configs; + + for (auto const& tile_config_sm100 : tiles) { + // SM100 uses MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO + // similar to SM90. Cluster shapes are also handled similarly. + CutlassGemmConfig config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); + + bool const has_m_mcast = supports_mcast_along_m_sm100(tile_config_sm100); + bool const has_n_mcast = supports_mcast_along_n_sm100(tile_config_sm100); + + if (has_m_mcast) { + CutlassGemmConfig mcast_m_config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(mcast_m_config); + } + + if (has_n_mcast) { + CutlassGemmConfig mcast_n_config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(mcast_n_config); + } + + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig mcast_mn_config(tile_config_sm100, + MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(mcast_mn_config); + } } - - // Fallback to older architecture configurations - std::vector tiles = get_candidate_tiles(sm, config_type_param); - std::vector candidate_configs; //Already declared above for SM90 path, ensure scope is correct or redeclare if necessary. - // It's fine here as it's within an else if / else block. - bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY; - int const min_stages = int8_configs_only ? 3 : 2; - int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 
4 : 2); - for (auto const& tile_config : tiles) - { - for (int stages = min_stages; stages <= max_stages; ++stages) - { - CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); - candidate_configs.push_back(config); - if (sm >= 75) - { - for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) - { - auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; - candidate_configs.push_back(config); - } - } + return candidate_configs; + } + + // Fallback to older architecture configurations + std::vector tiles = + get_candidate_tiles(sm, config_type_param); + std::vector + candidate_configs; // Already declared above for SM90 path, ensure scope + // is correct or redeclare if necessary. + // It's fine here as it's within an else if / else + // block. + bool const int8_configs_only = + config_type_param & CutlassGemmConfig::INT8_ONLY; + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 
4 : 2); + for (auto const& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); + candidate_configs.push_back(config); + if (sm >= 75) { + for (int split_k_factor = 2; split_k_factor <= max_split_k; + ++split_k_factor) { + auto config = CutlassGemmConfig{ + tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); } + } } + } - return candidate_configs; + return candidate_configs; } -CutlassGemmConfig estimate_best_config_from_occupancies(std::vector const& candidate_configs, - std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts, - int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only) -{ - - if (occupancies.size() != candidate_configs.size()) - { - throw( - "[estimate_best_config_from_occupancies] occpancies and " - "candidate configs vectors must have equal length."); +CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, + int64_t const m, + int64_t const n, + int64_t const k, + int64_t const num_experts, + int const split_k_limit, + size_t const workspace_bytes, + int const multi_processor_count, + int const is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + throw( + "[estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + int const max_split_k = n >= multi_processor_count * 256 ? 
1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = + get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; } - CutlassGemmConfig best_config; - // Score will be [0, 1]. The objective is to minimize this score. - // It represents the fraction of SM resources unused in the last wave. - float config_score = 1.0f; - int config_waves = INT_MAX; - int current_m_tile = 0; - - int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; - for (int ii = 0; ii < candidate_configs.size(); ++ii) - { - CutlassGemmConfig candidate_config = candidate_configs[ii]; - TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); - int occupancy = occupancies[ii]; - - if (occupancy == 0) - { - continue; - } + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && + m < current_m_tile && current_m_tile < tile_shape.m) { + continue; + } - // Keep small tile sizes when possible. 
- if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile - && current_m_tile < tile_shape.m) - { - continue; - } + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; - int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - - for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) - { - if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) - { - int const ctas_per_wave = occupancy * multi_processor_count; - int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; - - int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; - float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave); - float const current_score = float(num_waves_total) - num_waves_fractional; - - float const score_slack = 0.1f; - if (current_score < config_score - || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) - { - config_score = current_score; - config_waves = num_waves_total; - SplitKStyle split_style - = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = CutlassGemmConfig( - candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); - current_m_tile = tile_shape.m; - } - else if (current_score == config_score - && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor - || current_m_tile < tile_shape.m)) - { - // Prefer deeper pipeline or smaller split-k - SplitKStyle split_style - = split_k_factor > 1 ? 
SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; - best_config = CutlassGemmConfig( - candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); - current_m_tile = tile_shape.m; - config_waves = num_waves_total; - } - } + for (int split_k_factor = 1; split_k_factor <= max_split_k; + ++split_k_factor) { + if (is_valid_split_k_factor(m, + n, + k, + tile_shape, + split_k_factor, + workspace_bytes, + is_weight_only)) { + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = + ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + int const num_waves_total = + (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = + ctas_for_problem / float(ctas_per_wave); + float const current_score = + float(num_waves_total) - num_waves_fractional; + + float const score_slack = 0.1f; + if (current_score < config_score || + ((config_waves > num_waves_total) && + (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 + ? SplitKStyle::SPLIT_K_SERIAL + : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig(candidate_config.tile_config, + split_style, + split_k_factor, + candidate_config.stages); + current_m_tile = tile_shape.m; + } else if (current_score == config_score && + (best_config.stages < candidate_config.stages || + split_k_factor < best_config.split_k_factor || + current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 + ? 
SplitKStyle::SPLIT_K_SERIAL + : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig(candidate_config.tile_config, + split_style, + split_k_factor, + candidate_config.stages); + current_m_tile = tile_shape.m; + config_waves = num_waves_total; } + } } + } - if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) - { - throw("Heurisitc failed to find a valid config."); - } + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { + throw("Heurisitc failed to find a valid config."); + } - return best_config; + return best_config; } -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h index 8165bc421c6..b6839be7f31 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_heuristic.h @@ -20,36 +20,47 @@ #include "cutlass_extensions/gemm_configs.h" #include "common/cudaUtils.h" -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { template -struct should_filter_sm90_gemm_problem_shape -{ +struct should_filter_sm90_gemm_problem_shape { #ifdef FAST_BUILD - constexpr static int TILE_K = 128 * 8 / cutlass::sizeof_bits::value; - using SupportedCtaShape = cute::Shape>; - using SupportedCgaShape = cute::Shape; + constexpr static int TILE_K = + 128 * 8 / cutlass::sizeof_bits::value; + using SupportedCtaShape = + cute::Shape>; + using SupportedCgaShape = cute::Shape; - constexpr static bool value - = !cute::is_same_v || !cute::is_same_v; + constexpr static bool value = + !cute::is_same_v || + !cute::is_same_v; #else - constexpr static bool value = false; + constexpr static bool value = false; #endif }; template -constexpr static bool should_filter_sm90_gemm_problem_shape_v - = should_filter_sm90_gemm_problem_shape::value; +constexpr static bool 
should_filter_sm90_gemm_problem_shape_v = + should_filter_sm90_gemm_problem_shape::value; std::vector get_candidate_configs( - int sm, int const max_split_k, cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); + int sm, + int const max_split_k, + cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( std::vector const& candidate_configs, - std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts, - int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only); + std::vector const& occupancies, + int64_t const m, + int64_t const n, + int64_t const k, + int64_t const num_experts, + int const split_k_limit, + size_t const workspace_bytes, + int const multi_processor_count, + int const is_weight_only); -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu index 5c62b58085f..41c0412cd93 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.cu @@ -19,752 +19,781 @@ #include "cutlass_kernels/cutlass_preprocessors.h" #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -namespace kernels -{ -namespace cutlass_kernels -{ - -struct LayoutDetails -{ - enum class Layout - { - UNKNOWN, - ROW_MAJOR, - COLUMN_MAJOR - }; - - Layout layoutB = Layout::UNKNOWN; - int rows_per_column_tile = 1; - int columns_interleaved = 1; - - bool uses_imma_ldsm = false; +namespace kernels { +namespace cutlass_kernels { + +struct LayoutDetails { + enum class Layout { UNKNOWN, ROW_MAJOR, COLUMN_MAJOR }; + + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; + + 
bool uses_imma_ldsm = false; }; template -struct getLayoutDetails -{ -}; +struct getLayoutDetails {}; template <> -struct getLayoutDetails -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; - return layout_details; - } +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } }; template <> -struct getLayoutDetails -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - return layout_details; - } +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } }; template -struct getLayoutDetails> -{ - LayoutDetails operator()() - { - LayoutDetails layout_details; - layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; - layout_details.rows_per_column_tile = RowsPerTile; - layout_details.columns_interleaved = ColumnsInterleaved; - return layout_details; - } +struct getLayoutDetails< + cutlass::layout::ColumnMajorTileInterleave> { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } }; template -LayoutDetails getLayoutDetailsForArchAndQuantType() -{ - - using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; - using LayoutB = typename CompileTraits::Layout; - using MmaOperator = typename CompileTraits::Operator; - LayoutDetails details = getLayoutDetails()(); - details.uses_imma_ldsm = std::is_same::value; - return details; +LayoutDetails getLayoutDetailsForArchAndQuantType() { + using CompileTraits = + cutlass::gemm::kernel::LayoutDetailsB; + using 
LayoutB = typename CompileTraits::Layout; + using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same< + MmaOperator, + cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value; + return details; } template -LayoutDetails getLayoutDetailsForArch(QuantType quant_type) -{ - int const bits_per_weight_element = get_weight_quant_bits(quant_type); - LayoutDetails details; - switch (quant_type) - { +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { + int const bits_per_weight_element = get_weight_quant_bits(quant_type); + LayoutDetails details; + switch (quant_type) { case QuantType::W8_A16: - details = getLayoutDetailsForArchAndQuantType(); - break; + details = getLayoutDetailsForArchAndQuantType(); + break; case QuantType::W4_A16: - details = getLayoutDetailsForArchAndQuantType(); - break; + details = getLayoutDetailsForArchAndQuantType(); + break; case QuantType::W4_AFP8: - details = getLayoutDetailsForArchAndQuantType(); - break; - default: PADDLE_THROW("Unsupported quantization type"); - } - return details; + details = getLayoutDetailsForArchAndQuantType(); + break; + default: + PADDLE_THROW("Unsupported quantization type"); + } + return details; } -LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) -{ - if (arch >= 70 && arch < 75) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 75 && arch < 80) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch >= 80 && arch < 90) - { - return getLayoutDetailsForArch(quant_type); - } - else if (arch == 90) - { - return getLayoutDetailsForArch(quant_type); - } - else - { - PADDLE_ENFORCE(false, "Unsupported Arch"); - return LayoutDetails(); - } +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { + if (arch >= 70 && arch < 75) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 75 && arch < 80) { + return 
getLayoutDetailsForArch(quant_type); + } else if (arch >= 80 && arch < 90) { + return getLayoutDetailsForArch(quant_type); + } else if (arch == 90) { + return getLayoutDetailsForArch(quant_type); + } else { + PADDLE_ENFORCE(false, "Unsupported Arch"); + return LayoutDetails(); + } } -// Permutes the rows of B in a way that is compatible with Turing+ architectures. +// Permutes the rows of B in a way that is compatible with Turing+ +// architectures. // // Throws an error for other architectures. // The data is permuted such that: // For W8_A16, each group of 16 rows is permuted using the map below: // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 // For W4_A16, each group of 32 rows is permuted using the map below: -// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 +// 23 30 31 // For W4_A8, see the map in the code. The idea is similar to above. -// The goal of this permutation is to ensure data ends up in the correct threads after -// we execute LDSM. It counteracts the effect of the data being of different widths. -// For more information about the expected layouts, see the MMA section in the PTX docs. -std::vector get_permutation_map(QuantType quant_type) -{ - - if (quant_type == QuantType::W8_A16) - { - return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; - } - else if (quant_type == QuantType::W4_A16) - { - return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, - 22, 23, 30, 31}; - } - else if (quant_type == QuantType::W4_AFP8) - { - return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, - 28, 29, 30, 31}; - } - else - { - PADDLE_THROW("Invalid quantization type for LDSM permutation"); - } +// The goal of this permutation is to ensure data ends up in the correct threads +// after we execute LDSM. 
It counteracts the effect of the data being of +// different widths. For more information about the expected layouts, see the +// MMA section in the PTX docs. +std::vector get_permutation_map(QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + } else if (quant_type == QuantType::W4_A16) { + return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, + 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31}; + } else if (quant_type == QuantType::W4_AFP8) { + return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, + 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + } else { + PADDLE_THROW("Invalid quantization type for LDSM permutation"); + } } -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type, int64_t const arch_version) -{ - // We only want to run this step for weight only quant. - std::vector row_permutation = get_permutation_map(quant_type); - - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - int const BITS_PER_ELT = get_weight_quant_bits(quant_type); - int const K = 16 / BITS_PER_ELT; - int const ELTS_PER_BYTE = 8 / BITS_PER_ELT; - int const ELTS_PER_REG = 32 / BITS_PER_ELT; - - uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); - - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - int const elts_in_int32 = 32 / BITS_PER_ELT; - - int const num_vec_cols = num_cols / elts_in_int32; - - PADDLE_ENFORCE( - arch_version >= 75, "Unsupported Arch. Pre-volta not supported. 
Column interleave not needed on Volta."); - - PADDLE_ENFORCE(num_rows % B_ROWS_PER_MMA == 0, - "Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of %d", - B_ROWS_PER_MMA); - PADDLE_ENFORCE(num_cols % MMA_SHAPE_N == 0, - "Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of %d.", - MMA_SHAPE_N); - - PADDLE_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted."); - - for (int expert = 0; expert < num_experts; ++expert) - { - const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) - { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) - { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) - { - int const write_row = base_row + tile_row; - int const tile_read_row = row_permutation[tile_row]; - int const read_row = base_row + tile_read_row; - int const read_col = write_col; - - const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; - const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; - - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type, + int64_t const arch_version) { + // We only want to run this step for weight only quant. + std::vector row_permutation = get_permutation_map(quant_type); + + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const K = 16 / BITS_PER_ELT; + int const ELTS_PER_BYTE = 8 / BITS_PER_ELT; + int const ELTS_PER_REG = 32 / BITS_PER_ELT; + + uint32_t const* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const num_vec_cols = num_cols / elts_in_int32; + + PADDLE_ENFORCE(arch_version >= 75, + "Unsupported Arch. Pre-volta not supported. Column interleave " + "not needed on Volta."); + + PADDLE_ENFORCE(num_rows % B_ROWS_PER_MMA == 0, + "Invalid shape for quantized tensor. Number of rows of " + "quantized matrix must be a multiple of %d", + B_ROWS_PER_MMA); + PADDLE_ENFORCE(num_cols % MMA_SHAPE_N == 0, + "Invalid shape for quantized tensor. On turing/Ampere, the " + "number of cols must be a multiple of %d.", + MMA_SHAPE_N); + + PADDLE_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), + "Unexpected number of LDSM rows permuted."); + + for (int expert = 0; expert < num_experts; ++expert) { + const int64_t matrix_offset = + expert * int64_t(num_rows) * int64_t(num_vec_cols); + for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + int const write_row = base_row + tile_row; + int const tile_read_row = row_permutation[tile_row]; + int const read_row = base_row + tile_read_row; + int const read_col = write_col; + + const int64_t read_offset = + matrix_offset + int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + matrix_offset + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; } + } } + } } // We need to use this transpose to correctly handle packed int4 and int8 
data -// The reason this code is relatively complex is that the "trivial" loops took a substantial -// amount of time to transpose leading to long preprocessing times. This seemed to be a big -// issue for relatively large models. +// The reason this code is relatively complex is that the "trivial" loops took a +// substantial amount of time to transpose leading to long preprocessing times. +// This seemed to be a big issue for relatively large models. template -void subbyte_transpose_impl( - int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector const& shape) -{ - constexpr int bits_per_elt = get_weight_quant_bits(quant_type); - - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - const size_t col_bytes = num_cols * bits_per_elt / 8; - const size_t col_bytes_trans = num_rows * bits_per_elt / 8; - const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; - - uint8_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); - - static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; - - static constexpr int M_TILE_L1 = 64; - static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; - uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; - - static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); - - // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples - // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it - // allows GCC to emit vector instructions. - PADDLE_ENFORCE(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), - "Number of bytes for rows and cols must be a multiple of %d. 
However, num_rows_bytes = %ld and " - "num_col_bytes = %ld.", - VECTOR_WIDTH, col_bytes_trans, col_bytes); - - int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; - int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; - - for (size_t expert = 0; expert < num_experts; ++expert) - { - const size_t matrix_offset = expert * num_rows * col_bytes; - for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) - { - for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) - { - - int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); - int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - int const row = row_tile_start + ii; - - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) - { - int const col = col_tile_start_byte + jj; - - const size_t logical_src_offset = matrix_offset + row * col_bytes + col; - - if (row < row_limit && col < col_limit) - { - for (int v = 0; v < VECTOR_WIDTH; ++v) - { - cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; - } - } - } - } - - if constexpr (bits_per_elt == 8) - { - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - for (int jj = ii + 1; jj < N_TILE_L1; ++jj) - { - std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); - } - } - } - else if constexpr (bits_per_elt == 4) - { - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - // Using M_TILE_L1 here is deliberate since we assume that the cache tile - // is square in the number of elements (not necessarily the number of bytes). 
- for (int jj = ii + 1; jj < M_TILE_L1; ++jj) - { - int const ii_byte = ii / ELTS_PER_BYTE; - int const ii_bit_offset = ii % ELTS_PER_BYTE; - - int const jj_byte = jj / ELTS_PER_BYTE; - int const jj_bit_offset = jj % ELTS_PER_BYTE; - - uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); - uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); - - cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); - } - } - } - else - { - PADDLE_ENFORCE(false, "Unsupported quantization type."); - } - - const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; - const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; - - int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); - int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); - - for (int ii = 0; ii < M_TILE_L1; ++ii) - { - int const row = row_tile_start_trans + ii; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) - { - int const col = col_tile_start_byte_trans + jj; - - const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; - - if (row < row_limit_trans && col < col_limit_trans) - { - for (int v = 0; v < VECTOR_WIDTH; ++v) - { - output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; - } - } - } - } +void subbyte_transpose_impl(int8_t* transposed_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape) { + constexpr int bits_per_elt = get_weight_quant_bits(quant_type); + + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; + + uint8_t const* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = + reinterpret_cast(transposed_quantized_tensor); + + static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle + // dims which are multiples of 64 for weight-only quantization. As a result, + // this seemed like a reasonable tradeoff because it allows GCC to emit vector + // instructions. + PADDLE_ENFORCE( + !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + "Number of bytes for rows and cols must be a multiple of %d. 
However, " + "num_rows_bytes = %ld and " + "num_col_bytes = %ld.", + VECTOR_WIDTH, + col_bytes_trans, + col_bytes); + + int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + + for (size_t expert = 0; expert < num_experts; ++expert) { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; + row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; + col_tile_start_byte += N_TILE_L1) { + int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + int const col_limit = + std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte + jj; + + const size_t logical_src_offset = + matrix_offset + row * col_bytes + col; + + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } } + } } - } -} -void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type) -{ - if (quant_type == QuantType::W8_A16) - { - subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); - } - else if (quant_type == QuantType::W4_A16) - { - subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); - } - else if (quant_type == QuantType::W4_AFP8) - { - subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); - } - else - { - PADDLE_ENFORCE(false, "Invalid quant_type"); - } -} - -void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) -{ - for (int ii = 0; ii < num_elts; ++ii) - { - int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); - } + if constexpr (bits_per_elt 
== 8) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } else if constexpr (bits_per_elt == 4) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache + // tile is square in the number of elements (not necessarily the + // number of bytes). + for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + int const ii_byte = ii / ELTS_PER_BYTE; + int const ii_bit_offset = ii % ELTS_PER_BYTE; + + int const jj_byte = jj / ELTS_PER_BYTE; + int const jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = + 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = + 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } else { + PADDLE_ENFORCE(false, "Unsupported quantization type."); + } - // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no - // performance benefit and is purely so that int4 and int8 have the same layout. 
- // Pictorially, this does the following: - // bit 32 0 - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - - PADDLE_ENFORCE(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); - for (size_t base = 0; base < num_elts; base += 4) - { - std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); - } -} + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; -void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) -{ - int const num_bytes = num_elts / 2; - - // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little - // instructions as possible in the CUDA code. - for (size_t ii = 0; ii < num_bytes; ++ii) - { - int8_t transformed_packed_int4s = 0; - int8_t transformed_first_elt - = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension - int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; - - PADDLE_ENFORCE( - transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); - PADDLE_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15, - "Illegal result for int4 transform (second elt)"); - - // We don't need to mask in these ops since everything should be in the range 0-15 - transformed_packed_int4s |= transformed_first_elt; - transformed_packed_int4s |= (transformed_second_elt << 4); - packed_int4_tensor[ii] = transformed_packed_int4s; - } + int const row_limit_trans = + std::min(row_tile_start_trans + M_TILE_L1, num_cols); + int const col_limit_trans = + std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); - // Step 2 will transform the layout of a 32-bit register 
in CUDA in order to minimize the number of shift & logical - // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the - // following: Take as input a 32 bit register with layout: bit 32 0 - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) - - PADDLE_ENFORCE(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); - const size_t num_registers = num_bytes / 4; - - uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); - for (size_t ii = 0; ii < num_registers; ++ii) - { - const uint32_t current_register = register_ptr[ii]; - uint32_t transformed_register = 0; - - for (int dest_idx = 0; dest_idx < 8; ++dest_idx) - { - int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; - int const src_shift = 4 * src_idx; - int const dest_shift = 4 * dest_idx; - - const uint32_t src_bits = (current_register >> src_shift) & 0xF; - transformed_register |= (src_bits << dest_shift); - } - register_ptr[ii] = transformed_register; - } -} + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte_trans + jj; -void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) -{ - if (quant_type == QuantType::W8_A16) - { - add_bias_and_interleave_int8s_inplace(tensor, num_elts); - } - else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8) - { - // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must - // be converted to FP16 before the scales can be applied using CUDA cores. 
- // As a result, we still want permute the data so that it is well aligned - // for conversion to FP16. - add_bias_and_interleave_int4s_inplace(tensor, num_elts); - } - else - { - PADDLE_ENFORCE(false, "Invalid quantization type for interleaving."); - } -} + const size_t logical_tgt_offset = + matrix_offset + row * col_bytes_trans + col; -void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type, LayoutDetails details) -{ - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - int const BITS_PER_ELT = get_weight_quant_bits(quant_type); - int const elts_in_int32 = 32 / BITS_PER_ELT; - - int const rows_per_tile = details.rows_per_column_tile; - - PADDLE_ENFORCE(!(num_rows % elts_in_int32), - "The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows); - - uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); - - PADDLE_ENFORCE(!(num_rows % rows_per_tile), - "The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows); - - int const num_vec_rows = num_rows / elts_in_int32; - int const vec_rows_per_tile = rows_per_tile / elts_in_int32; - int const interleave = details.columns_interleaved; - - for (int expert = 0; expert < num_experts; ++expert) - { - const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); - for (int read_col = 0; read_col < num_cols; ++read_col) - { - const int64_t write_col = read_col / interleave; - for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) - { - for (int vec_read_row = 
base_vec_row; - vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) - { - const int64_t vec_write_row = interleave * base_vec_row - + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; - - const int64_t read_offset = matrix_offset + int64_t(read_col) * num_vec_rows + vec_read_row; - const int64_t write_offset - = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } } + } } + } } + } } -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, - std::vector const& shape, QuantType quant_type, bool force_interleave) -{ - int arch = 89; - if (force_interleave && arch == 90) - { - // Workaround for MOE which doesn't have specialised Hopper kernels yet - arch = 80; - } - LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); - - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - - size_t num_elts = 1; - for (auto const& dim : shape) - { - num_elts *= dim; - } - - const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; +void subbyte_transpose(int8_t* transposed_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_A16) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_AFP8) { + subbyte_transpose_impl( + transposed_quantized_tensor, quantized_tensor, shape); + } else { + PADDLE_ENFORCE(false, "Invalid quant_type"); + } +} - 
std::vector src_buf(num_bytes); - std::vector dst_buf(num_bytes); - std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, + const size_t num_elts) { + for (int ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // match the int4 layout. This has no performance benefit and is purely so + // that int4 and int8 have the same layout. Pictorially, this does the + // following: bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + PADDLE_ENFORCE(num_elts % 4 == 0, + "Dimensions of int8 tensor must be a multiple of 4 for " + "register relayout"); + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} - // Works on row major data, so issue this permutation first. - if (details.uses_imma_ldsm) - { - permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); - src_buf.swap(dst_buf); - } +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, + const size_t num_elts) { + int const num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the + // dequantize take as little instructions as possible in the CUDA code. 
+ for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt = + (int8_t(packed_int4_tensor[ii] << 4) >> 4) + + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + PADDLE_ENFORCE(transformed_first_elt >= 0 && transformed_first_elt <= 15, + "Illegal result for int4 transform (first elt)"); + PADDLE_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the + // range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // minimize the number of shift & logical instructions That are needed to + // extract the int4s in the GEMM main loop. Pictorially, the loop below will + // do the following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt + // occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt + // occupies 4 bits) + + PADDLE_ENFORCE(num_bytes % 4 == 0, + "Dimensions of int4 tensor must be a multiple of 8 for " + "register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + int const src_idx = dest_idx < 4 ? 
2 * dest_idx : 2 * (dest_idx - 4) + 1; + int const src_shift = 4 * src_idx; + int const dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} - if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) - { - subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); - src_buf.swap(dst_buf); - } +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, + const size_t num_elts, + QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } else if (quant_type == QuantType::W4_A16 || + quant_type == QuantType::W4_AFP8) { + // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must + // be converted to FP16 before the scales can be applied using CUDA cores. + // As a result, we still want permute the data so that it is well aligned + // for conversion to FP16. + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } else { + PADDLE_ENFORCE(false, "Invalid quantization type for interleaving."); + } +} - if (details.columns_interleaved > 1) - { - interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); - src_buf.swap(dst_buf); +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type, + LayoutDetails details) { + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const rows_per_tile = details.rows_per_column_tile; + + PADDLE_ENFORCE(!(num_rows % elts_in_int32), + "The number of rows must be a multiple of %d but the number " + "of rows is %ld.", + elts_in_int32, + num_rows); + + uint32_t const* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); + + PADDLE_ENFORCE(!(num_rows % rows_per_tile), + "The number of rows must be a multiple of %d but the number " + "of rows is %ld.", + rows_per_tile, + num_rows); + + int const num_vec_rows = num_rows / elts_in_int32; + int const vec_rows_per_tile = rows_per_tile / elts_in_int32; + int const interleave = details.columns_interleaved; + + for (int expert = 0; expert < num_experts; ++expert) { + const int64_t matrix_offset = + expert * int64_t(num_vec_rows) * int64_t(num_cols); + for (int read_col = 0; read_col < num_cols; ++read_col) { + const int64_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; + base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < + std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); + ++vec_read_row) { + const int64_t vec_write_row = + interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = + matrix_offset + int64_t(read_col) * num_vec_rows + vec_read_row; + const int64_t write_offset = + matrix_offset + int64_t(write_col) * num_vec_rows * interleave + + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } } + } +} - if (arch >= 70 && arch < 90) - { - add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); - } - std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); +void 
preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, + int8_t const* row_major_quantized_weight, + std::vector const& shape, + QuantType quant_type, + bool force_interleave) { + int arch = 89; + if (force_interleave && arch == 90) { + // Workaround for MOE which doesn't have specialised Hopper kernels yet + arch = 80; + } + LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); + + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + + size_t num_elts = 1; + for (auto const& dim : shape) { + num_elts *= dim; + } + + const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; + + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, + row_major_quantized_weight + num_bytes, + src_buf.begin()); + + // Works on row major data, so issue this permutation first. + if (details.uses_imma_ldsm) { + permute_B_rows_for_mixed_gemm( + dst_buf.data(), src_buf.data(), shape, quant_type, arch); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1) { + interleave_column_major_tensor( + dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + if (arch >= 70 && arch < 90) { + add_bias_and_interleave_quantized_tensor_inplace( + src_buf.data(), num_elts, quant_type); + } + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); } /* Arguments: - input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D +and of type FP16. quant_type - the type of the output quantization weight. - This function does symmetric quantization on 2-D or 3-D tensors. 
It uses the full int range and assumes the - zero-point is zero and will automatically construct the scales. + This function does symmetric quantization on 2-D or 3-D tensors. It uses the +full int range and assumes the zero-point is zero and will automatically +construct the scales. - It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is - viewed as a stack of matrices and a scale is produced for each column of every matrix. + It always quantizes the last axis of the tensor. For 3-D tensors, it +operates in "batched" mode where the tensor is viewed as a stack of matrices and +a scale is produced for each column of every matrix. Outputs - processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM - unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. - scale_ptr - scales for the quantized weight. + processed_quantized_weight - quantized AND processed weight for GEMM. This +MUST be used with the CUTLASS GEMM unprocessed_quantized_weight - quantized but +unprocessed weights. Useful for reference checking. scale_ptr - scales for the +quantized weight. - Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data - layout may not make sense if printed. + Note that the returned quantized_weights will be preprocessed in a way to +accelerate the mixed type GEMM. The data layout may not make sense if printed. 
Shapes: quant_type == int8: - If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] - quant_type == int4: - If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] - If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape - [b,n] - - The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the - reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind - of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors - must have a dimension of 1, which breaks the semantics we need for batched weights. + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and +scales of shape [n] If weight is a [b,m,n] tensor, unprocessed_quantized_weight +will have shape [b,m,n] and scales of shape [b,n] quant_type == int4: If weight +is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales +of shape [n] If weight is a [b,m,n] tensor, unprocessed_quantized_weight will +have shape [b,m, ceil(n/2)] and scales of shape [b,n] + + The quantized_weight will be of type torch.int8 and have two int4 values +packed in a single byte. This is the reason for halving the shape. At the time +of writing this code, there was not an elegant way to handle this kind of +batched quantization using torch's quantized tensors (to the best of the +author's knowledge). Scale tensors must have a dimension of 1, which breaks the +semantics we need for batched weights. 
*/ template -void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, - bool force_interleave) -{ - - PADDLE_ENFORCE(processed_quantized_weight, "Processed quantized tensor is NULL"); - PADDLE_ENFORCE(scale_ptr, "Scale output pointer is NULL"); - PADDLE_ENFORCE(input_weight_ptr, "Input weight pointer is NULL"); - - PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - int const bits_in_type = get_weight_quant_bits(quant_type); - int const bytes_per_out_col = num_cols * bits_in_type / 8; - - int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); - - std::vector weight_buf; - if (unprocessed_quantized_weight == nullptr) - { - weight_buf.resize(num_experts * num_rows * num_cols); - unprocessed_quantized_weight = weight_buf.data(); - } - - int const input_mat_size = num_rows * num_cols; - int const quantized_mat_size = num_rows * bytes_per_out_col; - float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); - - std::vector per_col_max(num_cols); - - for (int expert = 0; expert < num_experts; ++expert) - { - WeightType const* current_weight = input_weight_ptr + expert * input_mat_size; - int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; - - // First we find the per column max for this expert weight. 
- for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] = 0.f; - } - - for (int ii = 0; ii < num_rows; ++ii) - { - WeightType const* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); - } - } - - // Then, we construct the scales - ComputeType* current_scales = scale_ptr + expert * num_cols; - for (int jj = 0; jj < num_cols; ++jj) - { - per_col_max[jj] *= quant_range_scale; - current_scales[jj] = ComputeType(per_col_max[jj]); - } - - // Finally, construct the weights. - for (int ii = 0; ii < num_rows; ++ii) - { - int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; - WeightType const* current_weight_row = current_weight + ii * num_cols; - for (int jj = 0; jj < bytes_per_out_col; ++jj) - { - - if (bits_per_weigtht_element == 8) - { - float const col_scale = per_col_max[jj]; - float const weight_elt = float(current_weight_row[jj]); - float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; - const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); - current_quantized_weight_row[jj] = clipped_weight; - } - else if (bits_per_weigtht_element == 4) - { - - // We will pack two int4 elements per iteration of the inner loop. - int8_t packed_int4s = 0; - for (int packed_idx = 0; packed_idx < 2; ++packed_idx) - { - int const input_idx = 2 * jj + packed_idx; - if (input_idx < num_cols) - { - float const col_scale = per_col_max[input_idx]; - float const weight_elt = float(current_weight_row[input_idx]); - float const scaled_weight = (col_scale != 0.0f) ? 
round(weight_elt / col_scale) : 0.0f; - int int_weight = int(scaled_weight); - const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); - - // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits - // if packing the second int4 and or the bits into the final result. - packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); - } - } - current_quantized_weight_row[jj] = packed_int4s; - } - else - { - PADDLE_ENFORCE(false, "Unsupported quantization type"); - } +void symmetric_quantize(int8_t* processed_quantized_weight, + int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType quant_type, + bool force_interleave) { + PADDLE_ENFORCE(processed_quantized_weight, + "Processed quantized tensor is NULL"); + PADDLE_ENFORCE(scale_ptr, "Scale output pointer is NULL"); + PADDLE_ENFORCE(input_weight_ptr, "Input weight pointer is NULL"); + + PADDLE_ENFORCE(shape.size() == 2 || shape.size() == 3, + "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + int const bits_in_type = get_weight_quant_bits(quant_type); + int const bytes_per_out_col = num_cols * bits_in_type / 8; + + int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); + + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } + + int const input_mat_size = num_rows * num_cols; + int const quantized_mat_size = num_rows * bytes_per_out_col; + float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < num_experts; ++expert) { + WeightType const* current_weight = + input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = + unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = 0.f; + } + + for (int ii = 0; ii < num_rows; ++ii) { + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = + std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } + } + + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (int jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. 
+ for (int ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = + current_quantized_weight + ii * bytes_per_out_col; + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) { + if (bits_per_weigtht_element == 8) { + float const col_scale = per_col_max[jj]; + float const weight_elt = float(current_weight_row[jj]); + float const scaled_weight = + (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + const int8_t clipped_weight = + int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } else if (bits_per_weigtht_element == 4) { + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + int const input_idx = 2 * jj + packed_idx; + if (input_idx < num_cols) { + float const col_scale = per_col_max[input_idx]; + float const weight_elt = float(current_weight_row[input_idx]); + float const scaled_weight = + (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + int int_weight = int(scaled_weight); + const int8_t clipped_weight = + std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to + // upper bits if packing the second int4 and or the bits into the + // final result. 
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); } + } + current_quantized_weight_row[jj] = packed_int4s; + } else { + PADDLE_ENFORCE(false, "Unsupported quantization type"); } + } } + } - preprocess_weights_for_mixed_gemm( - processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave); + preprocess_weights_for_mixed_gemm(processed_quantized_weight, + unprocessed_quantized_weight, + shape, + quant_type, + force_interleave); } -template void symmetric_quantize( - int8_t*, int8_t*, half*, float const*, std::vector const&, QuantType, bool); - -template void symmetric_quantize( - int8_t*, int8_t*, half*, half const*, std::vector const&, QuantType, bool); +template void symmetric_quantize(int8_t*, + int8_t*, + half*, + float const*, + std::vector const&, + QuantType, + bool); + +template void symmetric_quantize(int8_t*, + int8_t*, + half*, + half const*, + std::vector const&, + QuantType, + bool); #ifdef ENABLE_BF16 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + int8_t*, + int8_t*, + __nv_bfloat16*, + __nv_bfloat16 const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + int8_t*, + int8_t*, + __nv_bfloat16*, + float const*, + std::vector const&, + QuantType, + bool); #endif template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, - std::vector const& shape, QuantType quant_type, bool force_interleave) -{ - symmetric_quantize( - processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); +void symmetric_quantize(int8_t* processed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType 
quant_type, + bool force_interleave) { + symmetric_quantize(processed_quantized_weight, + nullptr, + scale_ptr, + input_weight_ptr, + shape, + quant_type, + force_interleave); } template void symmetric_quantize( @@ -773,21 +802,42 @@ template void symmetric_quantize( template void symmetric_quantize( int8_t*, half*, float const*, std::vector const&, QuantType, bool); -template void symmetric_quantize(int8_t*, half*, half const*, std::vector const&, QuantType, bool); +template void symmetric_quantize( + int8_t*, half*, half const*, std::vector const&, QuantType, bool); #ifdef ENABLE_BF16 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + int8_t*, + __nv_bfloat16*, + __nv_bfloat16 const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize<__nv_bfloat16, half>( - int8_t*, __nv_bfloat16*, half const*, std::vector const&, QuantType, bool); + int8_t*, + __nv_bfloat16*, + half const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize( - int8_t*, half*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + int8_t*, + half*, + __nv_bfloat16 const*, + std::vector const&, + QuantType, + bool); template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + int8_t*, + __nv_bfloat16*, + float const*, + std::vector const&, + QuantType, + bool); #endif -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h index 8d025c1289b..292d7d8cbe7 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_preprocessors.h @@ -20,52 +20,67 @@ #include #include -namespace kernels -{ -namespace cutlass_kernels -{ 
+namespace kernels { +namespace cutlass_kernels { -enum class QuantType -{ - W8_A16, - W4_A16, - W4_AFP8 -}; +enum class QuantType { W8_A16, W4_A16, W4_AFP8 }; -constexpr int get_weight_quant_bits(QuantType quant_type) -{ - switch (quant_type) - { - case QuantType::W8_A16: return 8; - case QuantType::W4_A16: return 4; - case QuantType::W4_AFP8: return 4; - default: PADDLE_THROW("Invalid quant_type"); return -1; - } +constexpr int get_weight_quant_bits(QuantType quant_type) { + switch (quant_type) { + case QuantType::W8_A16: + return 8; + case QuantType::W4_A16: + return 4; + case QuantType::W4_AFP8: + return 4; + default: + PADDLE_THROW("Invalid quant_type"); + return -1; + } } // Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols] // 3-D shapes are [num_experts, num_rows, num_cols] -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type, const int64_t arch_version); +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type, + const int64_t arch_version); -void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, - std::vector const& shape, QuantType quant_type); +void subbyte_transpose(int8_t* transposed_quantized_tensor, + int8_t const* quantized_tensor, + std::vector const& shape, + QuantType quant_type); -void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, + const size_t num_elts, + QuantType quant_type); -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, - std::vector const& shape, QuantType quant_type, bool force_interleave = false); +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, + int8_t 
const* row_major_quantized_weight, + std::vector const& shape, + QuantType quant_type, + bool force_interleave = false); template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, - std::vector const& shape, QuantType quant_type, bool force_interleave); +void symmetric_quantize(int8_t* processed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType quant_type, + bool force_interleave); -// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight -// to implement a simple reference implementation. +// This is exposed so that we can write tests that use the processed weights for +// CUTLASS but the unprocessed weight to implement a simple reference +// implementation. template -void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, - bool force_interleave); +void symmetric_quantize(int8_t* processed_quantized_weight, + int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, + WeightType const* input_weight_ptr, + std::vector const& shape, + QuantType quant_type, + bool force_interleave); -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h b/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h index cf344772a68..a10f49cc774 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h +++ b/custom_ops/gpu_ops/cutlass_kernels/cutlass_type_conversion.h @@ -24,45 +24,38 @@ #include "cutlass/float8.h" #include "cutlass/half.h" -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { 
/////////////////////////////////////////////////////////////////////////////////////////////////// // Cuda to Cutlass template -struct CudaToCutlassTypeAdapter -{ - using type = T; +struct CudaToCutlassTypeAdapter { + using type = T; }; template <> -struct CudaToCutlassTypeAdapter -{ - using type = cutlass::half_t; +struct CudaToCutlassTypeAdapter { + using type = cutlass::half_t; }; #if defined(ENABLE_BF16) template <> -struct CudaToCutlassTypeAdapter<__nv_bfloat16> -{ - using type = cutlass::bfloat16_t; +struct CudaToCutlassTypeAdapter<__nv_bfloat16> { + using type = cutlass::bfloat16_t; }; #endif #if defined(ENABLE_FP8) template <> -struct CudaToCutlassTypeAdapter<__nv_fp8_e4m3> -{ - using type = cutlass::float_e4m3_t; +struct CudaToCutlassTypeAdapter<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; }; template <> -struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> -{ - using type = cutlass::float_e5m2_t; +struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; }; #endif @@ -70,40 +63,35 @@ struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> // Cutlass to Cuda template -struct CutlassToCudaTypeAdapter -{ - using type = T; +struct CutlassToCudaTypeAdapter { + using type = T; }; template <> -struct CutlassToCudaTypeAdapter -{ - using type = half; +struct CutlassToCudaTypeAdapter { + using type = half; }; #if defined(ENABLE_BF16) template <> -struct CutlassToCudaTypeAdapter -{ - using type = __nv_bfloat16; +struct CutlassToCudaTypeAdapter { + using type = __nv_bfloat16; }; #endif #if defined(ENABLE_FP8) template <> -struct CutlassToCudaTypeAdapter -{ - using type = __nv_fp8_e4m3; +struct CutlassToCudaTypeAdapter { + using type = __nv_fp8_e4m3; }; template <> -struct CutlassToCudaTypeAdapter -{ - using type = __nv_fp8_e5m2; +struct CutlassToCudaTypeAdapter { + using type = __nv_fp8_e5m2; }; #endif /////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass_kernels -} // 
namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h index 743b6c70aaa..a4f421a01fa 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h @@ -67,87 +67,87 @@ template < ElementOutput_, ///< Data type used to compute linear combination FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> class LeftGELUAndMul { - public: - using ElementOutput = ElementOutput_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; - static int const kCount = Count; - using FragmentOutput = Array; - using FragmentAccumulator = Array; - using ComputeFragment = Array; + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; - static FloatRoundStyle const kRound = Round; + static FloatRoundStyle const kRound = Round; - struct Params { - ElementCompute alpha; + struct Params { + ElementCompute alpha; - CUTLASS_HOST_DEVICE - Params() : alpha(ElementCompute(1)) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT - }; - - private: - // - // Data members - // - - ElementCompute alpha_; - ElementCompute beta_; - - public: - /// Constructs the function object, possibly loading from pointers in host - /// memory - CUTLASS_HOST_DEVICE - LeftGELUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const { return true; } - - /// Functionally required for serial 
reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { - assert(false); - } - - /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE - FragmentOutput operator()(FragmentAccumulator const &lhs, - FragmentAccumulator const &rhs) const { - // Convert source to interal compute numeric type - NumericArrayConverter - accumulator_to_compute; - - // Convert to destination numeric type - NumericArrayConverter - compute_to_output; - - ComputeFragment converted_lhs = accumulator_to_compute(lhs); - ComputeFragment converted_rhs = accumulator_to_compute(rhs); - - cutlass::epilogue::thread::GELU_taylor gelu; - cutlass::multiplies mul; - auto gelu_lhs = gelu(converted_lhs); - // return compute_to_output(mul(gelu_lhs, converted_rhs)); - auto tmp = mul(gelu_lhs, converted_rhs); - return compute_to_output(mul(alpha_, tmp)); - } + Params() : alpha(ElementCompute(1)) {} CUTLASS_HOST_DEVICE - ElementOutput operator()(ElementAccumulator const &lhs, - ElementAccumulator const &rhs) const { - ElementCompute convert_lhs(lhs); - ElementCompute convert_rhs(rhs); - cutlass::epilogue::thread::GELU_taylor gelu; - cutlass::multiplies mul; - auto gelu_lhs = gelu(convert_lhs); - // return ElementOutput(mul(gelu_lhs, convert_rhs)); - auto tmp = mul(gelu_lhs, convert_rhs); - return compute_to_output(mul(alpha_, tmp)); - } + Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + LeftGELUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int 
k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + // Convert source to internal compute numeric type + NumericArrayConverter + accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter + compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::GELU_taylor gelu; + cutlass::multiplies mul; + auto gelu_lhs = gelu(converted_lhs); + // return compute_to_output(mul(gelu_lhs, converted_rhs)); + auto tmp = mul(gelu_lhs, converted_rhs); + return compute_to_output(mul(alpha_, tmp)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()(ElementAccumulator const &lhs, + ElementAccumulator const &rhs) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::GELU_taylor gelu; + cutlass::multiplies mul; + auto gelu_lhs = gelu(convert_lhs); + // return ElementOutput(mul(gelu_lhs, convert_rhs)); + auto tmp = mul(gelu_lhs, convert_rhs); + return compute_to_output(mul(alpha_, tmp)); + } }; } // namespace thread diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h index 7c1213c7e3e..51da87e8986 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h @@ -67,87 +67,87 @@ template < ElementOutput_, ///< Data type used to compute linear combination FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> class LeftSiLUAndMul { - public: - using ElementOutput = ElementOutput_; - using ElementAccumulator = ElementAccumulator_; - 
using ElementCompute = ElementCompute_; + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; - static int const kCount = Count; - using FragmentOutput = Array; - using FragmentAccumulator = Array; - using ComputeFragment = Array; + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; - static FloatRoundStyle const kRound = Round; + static FloatRoundStyle const kRound = Round; - struct Params { - ElementCompute alpha; + struct Params { + ElementCompute alpha; - CUTLASS_HOST_DEVICE - Params() : alpha(ElementCompute(1)) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT - }; - - private: - // - // Data members - // - - ElementCompute alpha_; - ElementCompute beta_; - - public: - /// Constructs the function object, possibly loading from pointers in host - /// memory - CUTLASS_HOST_DEVICE - LeftSiLUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const { return true; } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { - assert(false); - } - - /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE - FragmentOutput operator()(FragmentAccumulator const &lhs, - FragmentAccumulator const &rhs) const { - // Convert source to interal compute numeric type - NumericArrayConverter - accumulator_to_compute; - - // Convert to destination numeric type - NumericArrayConverter - compute_to_output; - - ComputeFragment converted_lhs = accumulator_to_compute(lhs); - ComputeFragment converted_rhs = accumulator_to_compute(rhs); - - cutlass::epilogue::thread::SiLu silu; - cutlass::multiplies mul; - auto silu_lhs = silu(converted_lhs); - // return 
compute_to_output(mul(silu_lhs, converted_rhs)); - auto tmp = mul(silu_lhs, converted_rhs); - return compute_to_output(mul(alpha_, tmp)); - } + Params() : alpha(ElementCompute(1)) {} CUTLASS_HOST_DEVICE - ElementOutput operator()(ElementAccumulator const &lhs, - ElementAccumulator const &rhs) const { - ElementCompute convert_lhs(lhs); - ElementCompute convert_rhs(rhs); - cutlass::epilogue::thread::SiLu silu; - cutlass::multiplies mul; - auto silu_lhs = silu(convert_lhs); - // return ElementOutput(mul(silu_lhs, convert_rhs)); - auto tmp = mul(silu_lhs, convert_rhs); - return ElementOutput(mul(alpha_, tmp)); - } + Params(ElementCompute alpha) : alpha(alpha) {} // NOLINT + }; + + private: + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const ¶ms) { alpha_ = params.alpha; } // NOLINT + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + // Convert source to internal compute numeric type + NumericArrayConverter + accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter + compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + // return compute_to_output(mul(silu_lhs, converted_rhs)); + auto tmp = mul(silu_lhs, converted_rhs); + return 
compute_to_output(mul(alpha_, tmp)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()(ElementAccumulator const &lhs, + ElementAccumulator const &rhs) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + // return ElementOutput(mul(silu_lhs, convert_rhs)); + auto tmp = mul(silu_lhs, convert_rhs); + return ElementOutput(mul(alpha_, tmp)); + } }; } // namespace thread diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h index 7d679c341d0..2be64080139 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_epilogue.h @@ -98,357 +98,357 @@ template ::value)> class DualEpilogue { - public: - using Base = EpilogueBase; - - using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; - static int const kPartitionsK = PartitionsK; - static bool constexpr kStoreD0 = StoreD0; - static bool constexpr kStoreD1 = StoreD1; - using OutputTileIterator = OutputTileIterator_; - using OutputTileIterator2 = OutputTileIterator2_; - using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; - using WarpTileIterator = WarpTileIterator_; - using SharedLoadIterator = SharedLoadIterator_; - using OutputOp0 = OutputOp0_; - using OutputOp1 = OutputOp1_; - using OutputOp2 = OutputOp2_; - using Padding = Padding_; - - using Layout = layout::RowMajor; - using LongIndex = typename Layout::LongIndex; - - // The complete warp-level accumulator tile - using AccumulatorTile = typename Base::AccumulatorTile; - - // Accumulator element - using ElementAccumulator = typename WarpTileIterator::Element; - - // Output element - using ElementOutput = typename OutputTileIterator::Element; - - // Output 
access size - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - // Tensor reference to destination tensor - using TensorRef = typename OutputTileIterator::TensorRef; - - // Tensor reference to sync tensor - using SyncTensorRef = - typename cutlass::TensorRef; - - // Const tensor reference to source tensor - using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; - - // Array type used to output - using OutputAccessType = Array; - - // Array type used to output - using OutputAccessType2 = Array; - - // Array type used by output functor - using AccumulatorAccessType = Array; - - // Number of warps - using WarpCount = typename Base::WarpCount; - - struct SharedStorage { - using Element = typename WarpTileIterator::Element; - - // Tensor reference to shared memory allocation - using TensorRef = typename WarpTileIterator::TensorRef; - - // Logical shape of the shared memory tile written to by all warps. - using Shape = typename Base::Shape; - - // Shape of the shared memory allocation for the epilogue - using StorageShape = typename Base::SharedStorage::StorageShape; - - // - // Data members - // - - AlignedBuffer storage[2]; - - // - // Methods - // - - // Returns a tensor reference to the shared memory buffer - CUTLASS_DEVICE - TensorRef reference(int i) { - return TensorRef( - storage[i].data(), - Layout::packed({StorageShape::kRow, StorageShape::kColumn})); - } - }; - - static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 - ? 
Base::kFragmentsPerIteration - : kPartitionsK; - static int constexpr kSmemPointerOffset = - SharedStorage::StorageShape::kCount / kSmemTiles; - - public: - static_assert( - SharedLoadIterator::Fragment::kElements == - OutputTileIterator::Fragment::kElements, - "Mismatch between shared load iterator and output tile iterator."); - - static_assert(OutputTileIterator::kElementsPerAccess, - "OutputTileIterator::kElementsPerAccess must not be zero."); - - static_assert(!(OutputTileIterator::Fragment::kElements % - OutputTileIterator::kElementsPerAccess), - "Divisibility"); - - private: - // Loads fragment from shared memory aligned with output tensor - SharedLoadIterator shared_load_iterator0_; - SharedLoadIterator shared_load_iterator1_; - - // Stores a warp's fragment of accumulators to SMEM - WarpTileIterator warp_tile_iterator0_; - WarpTileIterator warp_tile_iterator1_; - - public: - // Constructor + public: + using Base = EpilogueBase; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + using OutputTileIterator = OutputTileIterator_; + using OutputTileIterator2 = OutputTileIterator2_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp0 = OutputOp0_; + using OutputOp1 = OutputOp1_; + using OutputOp2 = OutputOp2_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + // The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + // Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + // Output element + using ElementOutput = typename OutputTileIterator::Element; + + // Output access size + static int const kElementsPerAccess = 
OutputTileIterator::kElementsPerAccess; + + // Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + // Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + // Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + // Array type used to output + using OutputAccessType = Array; + + // Array type used to output + using OutputAccessType2 = Array; + + // Array type used by output functor + using AccumulatorAccessType = Array; + + // Number of warps + using WarpCount = typename Base::WarpCount; + + struct SharedStorage { + using Element = typename WarpTileIterator::Element; + + // Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + // Logical shape of the shared memory tile written to by all warps. + using Shape = typename Base::Shape; + + // Shape of the shared memory allocation for the epilogue + using StorageShape = typename Base::SharedStorage::StorageShape; + + // + // Data members + // + + AlignedBuffer storage[2]; + + // + // Methods + // + + // Returns a tensor reference to the shared memory buffer CUTLASS_DEVICE - DualEpilogue( - SharedStorage &shared_storage, // Shared storage object // NOLINT - int thread_idx, // ID of a thread within the threadblock - int warp_idx, // ID of warp within threadblock - int lane_idx // Id of thread within warp - ) - : shared_load_iterator0_(shared_storage.reference(0), thread_idx), - shared_load_iterator1_(shared_storage.reference(1), thread_idx), - warp_tile_iterator0_(shared_storage.reference(0), lane_idx), - warp_tile_iterator1_(shared_storage.reference(1), lane_idx) { - int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); - int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); - int warp_m = warp_mn % WarpCount::kM; - int warp_n = warp_mn / WarpCount::kM; - - MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, 
warp_n}; - - warp_tile_iterator0_.add_tile_offset(warp_offset); - warp_tile_iterator1_.add_tile_offset(warp_offset); + TensorRef reference(int i) { + return TensorRef( + storage[i].data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + // Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator0_; + SharedLoadIterator shared_load_iterator1_; + + // Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator0_; + WarpTileIterator warp_tile_iterator1_; + + public: + // Constructor + CUTLASS_DEVICE + DualEpilogue( + SharedStorage &shared_storage, // Shared storage object // NOLINT + int thread_idx, // ID of a thread within the threadblock + int warp_idx, // ID of warp within threadblock + int lane_idx // Id of thread within warp + ) + : shared_load_iterator0_(shared_storage.reference(0), thread_idx), + shared_load_iterator1_(shared_storage.reference(1), thread_idx), + warp_tile_iterator0_(shared_storage.reference(0), lane_idx), + warp_tile_iterator1_(shared_storage.reference(1), lane_idx) { + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord 
warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator0_.add_tile_offset(warp_offset); + warp_tile_iterator1_.add_tile_offset(warp_offset); + } + + // Streams the result to global memory + CUTLASS_DEVICE + void operator()(OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + OutputTileIterator dest0, + OutputTileIterator dest1, + OutputTileIterator2 dest2, + AccumulatorTile const &accumulator0, + AccumulatorTile const &accumulator1, + OutputTileIterator source_iterator[2], + bool writeToD2 // true if it's the final split-k + ) { + // TODO: Implement when no source is needed // NOLINT + typename OutputTileIterator::Fragment source_fragment[2]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_fragment[i].clear(); } - // Streams the result to global memory - CUTLASS_DEVICE - void operator()(OutputOp0 const &output_op0, - OutputOp1 const &output_op1, - OutputOp2 const &output_op2, - OutputTileIterator dest0, - OutputTileIterator dest1, - OutputTileIterator2 dest2, - AccumulatorTile const &accumulator0, - AccumulatorTile const &accumulator1, - OutputTileIterator source_iterator[2], - bool writeToD2 // true if it's the final split-k - ) { - // TODO: Implement when no source is needed // NOLINT - typename OutputTileIterator::Fragment source_fragment[2]; + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, + accumulator1}; + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + // + // Load the source + // + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 2; ++i) { - source_fragment[i].clear(); + source_iterator[i].load(source_fragment[i]); + ++source_iterator[i]; } // - // Iterator over warp-level accumulator fragment + // Convert and store fragment // - AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, - accumulator1}; + __syncthreads(); + + acc2smem_source_needed>::push(iter, + accum_fragment_iterator[0], + this->warp_tile_iterator0_); + acc2smem_source_needed>::push(iter, + accum_fragment_iterator[1], + this->warp_tile_iterator1_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment0[kPartitionsK]; + typename SharedLoadIterator::Fragment + aligned_accum_fragment1[kPartitionsK]; + + shared_load_iterator0_.load(aligned_accum_fragment0[0]); + shared_load_iterator1_.load(aligned_accum_fragment1[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator0_.load(aligned_accum_fragment0[i]); + shared_load_iterator1_.load(aligned_accum_fragment1[i]); + aligned_accum_fragment0[0] = add_fragments( + aligned_accum_fragment0[0], aligned_accum_fragment0[i]); + aligned_accum_fragment1[0] = add_fragments( + aligned_accum_fragment1[0], aligned_accum_fragment1[i]); + } + + shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * + kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment 
output_fragment[2]; + typename OutputTileIterator2::Fragment output_fragment_final; + + apply_output_operator_(output_fragment, + output_fragment_final, + output_op0, + output_op1, + output_op2, + aligned_accum_fragment0[0], + aligned_accum_fragment1[0], + source_fragment); // - // Iterate over accumulator tile + // Store the final result // - #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - // - // Load the source - // - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 2; ++i) { - source_iterator[i].load(source_fragment[i]); - ++source_iterator[i]; - } - - // - // Convert and store fragment - // - - __syncthreads(); - - acc2smem_source_needed>::push(iter, - accum_fragment_iterator[0], - this->warp_tile_iterator0_); - acc2smem_source_needed>::push(iter, - accum_fragment_iterator[1], - this->warp_tile_iterator1_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - typename SharedLoadIterator::Fragment - aligned_accum_fragment0[kPartitionsK]; - typename SharedLoadIterator::Fragment - aligned_accum_fragment1[kPartitionsK]; - - shared_load_iterator0_.load(aligned_accum_fragment0[0]); - shared_load_iterator1_.load(aligned_accum_fragment1[0]); - - // If the number of k-slices is > 1 - perform a reduction amongst the - // k-slices - if (kPartitionsK > 1) { - plus add_fragments; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < kPartitionsK; ++i) { - shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator0_.load(aligned_accum_fragment0[i]); - shared_load_iterator1_.load(aligned_accum_fragment1[i]); - aligned_accum_fragment0[0] = add_fragments( - aligned_accum_fragment0[0], aligned_accum_fragment0[i]); - aligned_accum_fragment1[0] = add_fragments( - aligned_accum_fragment1[0], aligned_accum_fragment1[i]); - } - - shared_load_iterator0_.add_pointer_offset((1 - 
kPartitionsK) * - kSmemPointerOffset); - shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * - kSmemPointerOffset); - } - - // - // Compute the output result - // - - typename OutputTileIterator::Fragment output_fragment[2]; - typename OutputTileIterator2::Fragment output_fragment_final; - - apply_output_operator_(output_fragment, - output_fragment_final, - output_op0, - output_op1, - output_op2, - aligned_accum_fragment0[0], - aligned_accum_fragment1[0], - source_fragment); - - // - // Store the final result - // - - if (kStoreD0) { - dest0.store(output_fragment[0]); - ++dest0; - } - if (kStoreD1) { - dest1.store(output_fragment[1]); - ++dest1; - } - if (writeToD2) { - dest2.store(output_fragment_final); - ++dest2; - } + if (kStoreD0) { + dest0.store(output_fragment[0]); + ++dest0; + } + if (kStoreD1) { + dest1.store(output_fragment[1]); + ++dest1; + } + if (writeToD2) { + dest2.store(output_fragment_final); + ++dest2; } } + } - private: - static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, - "One of these must be exactly 1."); - - template - struct acc2smem_source_needed; - - template - struct acc2smem_source_needed> { - template - CUTLASS_DEVICE static void helper( - AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { // NOLINT - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - typename AccumulatorFragmentIterator::Fragment accum_fragment; - accum_fragment_iterator.load(accum_fragment); - warp_tile_iterator.store(accum_fragment); - } + private: + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { // NOLINT - int dummy[] = {(pos == Seq) && - (helper(iterator_begin, warp_tile_iterator), 0)...}; - } - }; + template + struct acc2smem_source_needed; - 
// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_( - typename OutputTileIterator::Fragment (&output_fragment)[2], - typename OutputTileIterator2::Fragment &output_fragment_final, // NOLINT - OutputOp0 const &output_op0, - OutputOp1 const &output_op1, - OutputOp2 const &output_op2, - typename SharedLoadIterator::Fragment const &aligned_accum_fragment0, - typename SharedLoadIterator::Fragment const &aligned_accum_fragment1, - typename OutputTileIterator::Fragment const (&source_fragment)[2]) { - OutputAccessType *output_frag_ptr[2] = { - reinterpret_cast(&output_fragment[0]), - reinterpret_cast(&output_fragment[1])}; - - OutputAccessType2 *output_frag_final_ptr = - reinterpret_cast(&output_fragment_final); - - AccumulatorAccessType const *compute_frag_ptr[2] = { - reinterpret_cast( - &aligned_accum_fragment0), - reinterpret_cast( - &aligned_accum_fragment1)}; - - OutputAccessType const *source_frag_ptr[2] = { - reinterpret_cast(&source_fragment[0]), - reinterpret_cast(&source_fragment[1])}; - - int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / - OutputTileIterator::kElementsPerAccess; + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { // NOLINT + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - // Call the output operators - output_frag_ptr[0][i] = - output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); - output_frag_ptr[1][i] = - output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); - output_frag_final_ptr[i] = - output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); - } + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + 
warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { // NOLINT + int dummy[] = {(pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + // Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment (&output_fragment)[2], + typename OutputTileIterator2::Fragment &output_fragment_final, // NOLINT + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment0, + typename SharedLoadIterator::Fragment const &aligned_accum_fragment1, + typename OutputTileIterator::Fragment const (&source_fragment)[2]) { + OutputAccessType *output_frag_ptr[2] = { + reinterpret_cast(&output_fragment[0]), + reinterpret_cast(&output_fragment[1])}; + + OutputAccessType2 *output_frag_final_ptr = + reinterpret_cast(&output_fragment_final); + + AccumulatorAccessType const *compute_frag_ptr[2] = { + reinterpret_cast( + &aligned_accum_fragment0), + reinterpret_cast( + &aligned_accum_fragment1)}; + + OutputAccessType const *source_frag_ptr[2] = { + reinterpret_cast(&source_fragment[0]), + reinterpret_cast(&source_fragment[1])}; + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operators + output_frag_ptr[0][i] = + output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); + output_frag_ptr[1][i] = + output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); + output_frag_final_ptr[i] = + output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); } + } }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_mma_base.h 
b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_mma_base.h index 2d6fbc1fd26..530bf5665f2 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/dual_gemm/threadblock/dual_mma_base.h @@ -92,7 +92,7 @@ class DualMmaBase { Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>; - /// Number of warp-level GEMM oeprations + /// Number of warp-level GEMM operations static int const kWarpGemmIterations = (WarpGemm::kK / Operator0::Policy::MmaShape::kK); diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h index 1a5b838b81e..f95f20bfbf4 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_block_gemm_act_template_3x.h @@ -34,34 +34,35 @@ using namespace cute; template < - typename InputType = phi::dtype::float8_e4m3fn, - typename OutType = phi::dtype::float16, - bool hasbias = false, - template typename Activation = cutlass::epilogue::thread::Identity, - typename TileShape = Shape<_128, _128, _128>, - typename ClusterShape = Shape<_1, _2, _1>, - typename KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, - typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative, - typename TileSchedule = cutlass::gemm::PersistentScheduler, - typename SM = cutlass::arch::Sm90 -> -bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ - using ElementA = typename std::conditional_t, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; + typename InputType = phi::dtype::float8_e4m3fn, + typename OutType = phi::dtype::float16, + bool hasbias = false, + template typename Activation = cutlass::epilogue::thread::Identity, + typename TileShape = Shape<_128, _128, 
_128>, + typename ClusterShape = Shape<_1, _2, _1>, + typename KernelSchedule = cutlass::gemm:: + KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>, + typename EpilogueSchedule = + cutlass::epilogue::TmaWarpSpecializedCooperative, + typename TileSchedule = cutlass::gemm::PersistentScheduler, + typename SM = cutlass::arch::Sm90> +bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params) { + using ElementA = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; using ElementB = ElementA; - using ElementD = typename std::conditional_t, + using ElementD = + typename std::conditional_t, cutlass::bfloat16_t, cutlass::half_t>; - using ElementC = std::conditional_t< - hasbias, - ElementD, - void>; + using ElementC = std::conditional_t; constexpr int ScaleMsPerTile = size<0>(TileShape{}); constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile; - static constexpr bool IsStreamK = cute::is_same_v; + static constexpr bool IsStreamK = + cute::is_same_v; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -80,37 +81,54 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - using FusionOperation = cutlass::epilogue::fusion::LinCombEltAct; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, - EpilogueSchedule, - FusionOperation - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, - ClusterShape, - 
cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - TileSchedule - >; + using FusionOperation = + cutlass::epilogue::fusion::LinCombEltAct; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + SM, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, + FusionOperation>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + SM, + cutlass::arch::OpClassTensorOp, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + TileSchedule>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -132,23 +150,27 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ StrideD stride_D{params.ldd, cute::Int<1>{}, params.M * params.ldd}; auto a_ptr = reinterpret_cast(const_cast(params.A)); - auto a_scale_ptr = reinterpret_cast(const_cast(params.A_scale)); + auto a_scale_ptr = + reinterpret_cast(const_cast(params.A_scale)); auto b_ptr = reinterpret_cast(const_cast(params.B)); - auto b_scale_ptr = reinterpret_cast(const_cast(params.B_scale)); + auto b_scale_ptr = + reinterpret_cast(const_cast(params.B_scale)); auto c_ptr = reinterpret_cast(const_cast(params.bias)); auto d_ptr = reinterpret_cast(params.D); - ProblemShapeType problem_size = ProblemShapeType{params.M, 
params.N, params.K, params.batch_count}; + ProblemShapeType problem_size = + ProblemShapeType{params.M, params.N, params.K, params.batch_count}; typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {a_ptr, stride_A, b_ptr, stride_B, - a_scale_ptr, b_scale_ptr}, - {{params.scale}, // epilogue.thread - c_ptr, stride_C, d_ptr, stride_D} - }; - if constexpr (hasbias){ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {a_ptr, stride_A, b_ptr, stride_B, a_scale_ptr, b_scale_ptr}, + {{params.scale}, // epilogue.thread + c_ptr, + stride_C, + d_ptr, + stride_D}}; + if constexpr (hasbias) { arguments.epilogue.thread.beta = 1.0; } @@ -162,12 +184,12 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ arguments.scheduler.reduction_mode = ReductionMode::Nondeterministic; } - Gemm gemm_op; cutlass::Status status = gemm_op.can_implement(arguments); if (status != cutlass::Status::kSuccess) { - std::cout << "Gemm::can_implement() failed. " << cutlassGetStatusString(status) << std::endl; + std::cout << "Gemm::can_implement() failed. " + << cutlassGetStatusString(status) << std::endl; return false; } size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -176,7 +198,8 @@ bool dispatch_fuse_block_gemm_c3x(GemmEpilogueAllParams params){ status = gemm_op(arguments, workspace->ptr(), params.stream); if (status != cutlass::Status::kSuccess) { - std::cout << "Gemm::run() failed." << cutlassGetStatusString(status) << std::endl; + std::cout << "Gemm::run() failed." 
<< cutlassGetStatusString(status) + << std::endl; return false; } return true; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h index 632cdc296a1..ec484ac7435 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_act_template_3x.h @@ -28,60 +28,67 @@ #include "cutlass_extensions/gemm/collective/collective_builder_gated.hpp" #include "cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp" -template class Activation = - cutlass::epilogue::thread::SiLu, + template + class Activation = cutlass::epilogue::thread::SiLu, bool SwapAB = true> bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) { using namespace cute; using ElementA = typename std::conditional_t< std::is_same_v, - cutlass::float_e4m3_t, cutlass::float_e5m2_t>; - using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using LayoutA = + cutlass::layout::RowMajor; // Layout type for A matrix operand static constexpr int AlignmentA = 128 / cutlass::sizeof_bits< - ElementA>::value; // Memory access granularity/alignment of A - // matrix in units of elements (up to 16 bytes) + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) // B matrix configuration - using ElementB = ElementA; // Element type for B matrix operand + using ElementB = ElementA; // Element type for B matrix operand using LayoutB = - cutlass::layout::ColumnMajor; // Layout type for B matrix operand + cutlass::layout::ColumnMajor; // Layout type for B matrix operand static constexpr int AlignmentB = 128 / cutlass::sizeof_bits< - ElementB>::value; // Memory access granularity/alignment of B - // matrix in units of elements (up to 16 bytes) + ElementB>::value; 
// Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) - using ElementC = ElementA; // Element type for C matrix operands + using ElementC = ElementA; // Element type for C matrix operands - using LayoutC = cute::conditional_t; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits< - ElementC>::value; // Memory access granularity/alignment of C matrices - // in units of elements (up to 16 bytes) + ElementC>::value; // Memory access granularity/alignment of C + // matrices in units of elements (up to 16 bytes) // Output matrix configuration - using ElementOutput = ElementA; // Element type for output matrix operands + using ElementOutput = ElementA; // Element type for output matrix operands // using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output // matrix operands - using LayoutOutput = cute::conditional_t; static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; // Multiply-accumulate blocking/pipelining details - using ElementAccumulator = float; // Element type for internal accumulation - using ElementCompute = float; // Element type for compute - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that - // supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = CTAShape; // Threadblock-level tile size + using ElementAccumulator = float; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size using KernelSchedule = MainloopScheduleType; using EpilogueSchedule = EpilogueScheduleType; using TileScheduler = TileSchedulerType; @@ -94,22 +101,46 @@ bool 
dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, - ElementOutput, LayoutOutput, AlignmentOutput, EpilogueSchedule, + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementAccumulator, + ElementC, + LayoutC, + AlignmentC, + ElementOutput, + LayoutOutput, + AlignmentOutput, + EpilogueSchedule, FusionOperation>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderGated< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, - LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, + ArchTag, + OperatorClass, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule, Activation, SwapAB>::CollectiveOp; + KernelSchedule, + Activation, + SwapAB>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated< - Shape, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -141,7 +172,7 @@ bool dispatch_dual_gemm_act_sm90(DualGemmEpilogueAllParams params) { cutlass::gemm::GemmUniversalMode::kGemm, {arg_m, arg_n, params.K, params.batch_count}, {ptr_A, stride_A, ptr_B0, ptr_B1, stride_B, params.scale0, params.scale1}, - {{}, // epilogue.thread + {{}, // epilogue.thread nullptr, stride_C, reinterpret_cast(params.D), diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h 
b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h index 0e18c2a3890..5d2a9638b3a 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_geglu_template.h @@ -21,9 +21,15 @@ #include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" #include "fp8_gemm_fused/dual_gemm/thread/left_gelu_and_mul.h" -template +template bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -73,8 +79,8 @@ bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< ElementInputC, // <- data type of output matrix @@ -86,7 +92,7 @@ bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function using EpilogueOp1 = cutlass::epilogue::thread::LeftGELUAndMul< ElementOutput, @@ -143,11 +149,23 @@ bool dispatch_dual_gemm_geglu(DualGemmEpilogueAllParams params) { params.lda}, {reinterpret_cast(const_cast(params.B0)), params.ldb}, - hasbias? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias0)), 0} : nullptr_ref, + hasbias + ? typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias0)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.B1)), params.ldb}, - hasbias? 
typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias1)), 0} : nullptr_ref, + hasbias + ? typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias1)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.D)), params.ldd}, diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h index b5de12e7e90..dc788397503 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_dual_gemm_swiglu_template.h @@ -21,9 +21,15 @@ #include "fp8_gemm_fused/dual_gemm/device/dual_gemm.h" #include "fp8_gemm_fused/dual_gemm/thread/left_silu_and_mul.h" -template +template bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -73,8 +79,8 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? 
cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< ElementInputC, // <- data type of output matrix @@ -86,7 +92,7 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function using EpilogueOp1 = cutlass::epilogue::thread::LeftSiLUAndMul< ElementOutput, @@ -129,9 +135,9 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; - cutlass::gemm::DualGemmMode mode = params.batch_count > 1 ? - cutlass::gemm::DualGemmMode::kBatched : - cutlass::gemm::DualGemmMode::kGemm; + cutlass::gemm::DualGemmMode mode = params.batch_count > 1 + ? cutlass::gemm::DualGemmMode::kBatched + : cutlass::gemm::DualGemmMode::kGemm; typename cutlass::TensorRef nullptr_ref{}; @@ -144,11 +150,23 @@ bool dispatch_dual_gemm_swiglu(DualGemmEpilogueAllParams params) { params.lda}, {reinterpret_cast(const_cast(params.B0)), params.ldb}, - hasbias ? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias0)), 0} : nullptr_ref, + hasbias + ? typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias0)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.B1)), params.ldb}, - hasbias ? typename cutlass::TensorRef{reinterpret_cast(const_cast(params.bias1)), 0} : nullptr_ref, + hasbias + ? 
typename cutlass::TensorRef< + typename Gemm::ElementC, + typename Gemm::LayoutC>{reinterpret_cast( + const_cast(params.bias1)), + 0} + : nullptr_ref, nullptr_ref, {reinterpret_cast(const_cast(params.D)), params.ldd}, diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h index c4701510705..83ded562704 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_act_template_3x.h @@ -26,26 +26,28 @@ #include "cutlass/gemm/kernel/tile_scheduler.hpp" #include "cutlass/util/packed_stride.hpp" -template < - typename InputType, - typename OutType, - bool hasbias, - template typename Activation, - typename TileShape, - typename ClusterShape, - typename KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, - typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized, - typename SM = cutlass::arch::Sm90> +template + typename Activation, + typename TileShape, + typename ClusterShape, + typename KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + typename EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized, + typename SM = cutlass::arch::Sm90> bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) { using namespace cute; using ElementA = typename std::conditional_t< std::is_same_v, - cutlass::float_e4m3_t, cutlass::float_e5m2_t>; + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; using ElementB = ElementA; using ElementD = typename std::conditional_t, - cutlass::bfloat16_t, cutlass::half_t>; + cutlass::bfloat16_t, + cutlass::half_t>; using ElementC = std::conditional_t; using LayoutA = cutlass::layout::RowMajor; @@ -66,29 +68,53 @@ bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) { static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; using 
FusionOperation = - cutlass::epilogue::fusion::LinCombEltAct; + cutlass::epilogue::fusion::LinCombEltAct; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, - AlignmentD, EpilogueSchedule, FusionOperation>::CollectiveOp; + SM, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, + FusionOperation>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - SM, cutlass::arch::OpClassTensorOp, ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape, + SM, + cutlass::arch::OpClassTensorOp, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>; + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -120,7 +146,7 @@ bool dispatch_fuse_gemm_act_sm90(GemmEpilogueAllParams params) { typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {a_ptr, stride_A, b_ptr, stride_B}, - {{params.scale}, // epilogue.thread + {{params.scale}, // epilogue.thread c_ptr, stride_C, d_ptr, diff --git 
a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h index 32b8a132e80..50793eff0ac 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_gelu_template.h @@ -20,9 +20,14 @@ #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_splitk_parallel.h" -template +template bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -64,9 +69,8 @@ bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -82,7 +86,7 @@ bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -164,10 +168,15 @@ bool dispatch_fuse_gemm_gelu(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -209,9 +218,8 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? 
cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -227,7 +235,7 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -309,10 +317,14 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -354,8 +366,8 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? 
cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU< ElementOutput, // <- data type of output matrix @@ -367,50 +379,55 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; - using ConvertScaledOp = cutlass::epilogue::thread::Convert< - ElementAccumulator, - cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, - ElementAccumulator>; - - /// Reduction operator - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, typename EpilogueOp::ElementAccumulator, - EpilogueOp::kCount>; - - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastAccum; - - using Gemm = cutlass::gemm::device::GemmSplitKParallel; - + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, + SmArch, + ElementInputA, + ElementInputB, + ElementAccumulator, + ElementAccumulator>::EpilogueOutputOp::kCount, + ElementAccumulator>; + + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; + + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + + using Gemm = 
cutlass::gemm::device::GemmSplitKParallel; cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; @@ -421,15 +438,18 @@ bool dispatch_fuse_gemm_split_k_gelu(GemmEpilogueAllParams params) { // Split K dimension into 16 partitions int split_k_slices = params.split_k; - // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch - // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication - {reinterpret_cast(const_cast(params.A)),params.lda}, - {reinterpret_cast(const_cast(params.B)),params.ldb}, - {reinterpret_cast(const_cast(params.bias)),0}, - {reinterpret_cast(params.D),params.ldd}, - {alpha, beta}, // <- tuple of alpha and beta - split_k_slices}; // <- k-dimension split factor + // Create a tuple of gemm kernel arguments. This is later passed as arguments + // to launch instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)), + params.lda}, + {reinterpret_cast(const_cast(params.B)), + params.ldb}, + {reinterpret_cast(const_cast(params.bias)), 0}, + {reinterpret_cast(params.D), params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor Gemm gemm_op; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h index 31d2a21e2a3..f7f2bbf3963 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_noact_template.h @@ -20,9 +20,14 @@ #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_splitk_parallel.h" -template +template bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< 
std::is_same_v, @@ -64,9 +69,8 @@ bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -82,7 +86,7 @@ bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -164,10 +168,14 @@ bool dispatch_fuse_gemm_noact(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -209,8 +217,8 @@ bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? 
cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // <- data type of output matrix @@ -222,50 +230,55 @@ bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; - using ConvertScaledOp = cutlass::epilogue::thread::Convert< - ElementAccumulator, - cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, - ElementAccumulator>; - - /// Reduction operator - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, typename EpilogueOp::ElementAccumulator, - EpilogueOp::kCount>; - - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastAccum; - - using Gemm = cutlass::gemm::device::GemmSplitKParallel; - + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, + SmArch, + ElementInputA, + ElementInputB, + ElementAccumulator, + ElementAccumulator>::EpilogueOutputOp::kCount, + ElementAccumulator>; + + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; + + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + + using Gemm = 
cutlass::gemm::device::GemmSplitKParallel; cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; @@ -276,15 +289,18 @@ bool dispatch_fuse_gemm_split_k_noact(GemmEpilogueAllParams params) { // Split K dimension into 16 partitions int split_k_slices = params.split_k; - // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch - // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication - {reinterpret_cast(const_cast(params.A)),params.lda}, - {reinterpret_cast(const_cast(params.B)),params.ldb}, - {reinterpret_cast(const_cast(params.bias)),0}, - {reinterpret_cast(params.D),params.ldd}, - {alpha, beta}, // <- tuple of alpha and beta - split_k_slices}; // <- k-dimension split factor + // Create a tuple of gemm kernel arguments. This is later passed as arguments + // to launch instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)), + params.lda}, + {reinterpret_cast(const_cast(params.B)), + params.ldb}, + {reinterpret_cast(const_cast(params.bias)), 0}, + {reinterpret_cast(params.D), params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor Gemm gemm_op; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h index d2aa189eedf..f59b84f59f1 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/fuse_gemm_relu_template.h @@ -20,9 +20,14 @@ #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_splitk_parallel.h" -template +template bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< 
std::is_same_v, @@ -64,9 +69,8 @@ bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; // <- MMA Op tile static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; - + hasbias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; // This code section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = @@ -82,7 +86,7 @@ bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; @@ -164,10 +168,14 @@ bool dispatch_fuse_gemm_relu(GemmEpilogueAllParams params) { return true; } - -template +template bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { using ElementInputA = typename std::conditional_t< std::is_same_v, @@ -209,8 +217,8 @@ bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { using ShapeMMAOp = MMAShape; static constexpr auto ScaleType = - hasbias? cutlass::epilogue::thread::ScaleType::NoBetaScaling - : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; + hasbias ? 
cutlass::epilogue::thread::ScaleType::NoBetaScaling + : cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling; using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, // <- data type of output matrix @@ -222,50 +230,55 @@ bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { ElementAccumulator, // <- data type of accumulator ElementComputeEpilogue, ScaleType>; // <- data type for alpha/beta in linear - // combination function + // combination function // Number of pipelines you want to use constexpr int NumStages = Stages; - using ConvertScaledOp = cutlass::epilogue::thread::Convert< - ElementAccumulator, - cutlass::gemm::device::DefaultGemmConfiguration::EpilogueOutputOp::kCount, - ElementAccumulator>; - - /// Reduction operator - using ReductionOp = cutlass::reduction::thread::ReduceAdd< - ElementAccumulator, typename EpilogueOp::ElementAccumulator, - EpilogueOp::kCount>; - - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastAccum; - - using Gemm = cutlass::gemm::device::GemmSplitKParallel; - + using ConvertScaledOp = cutlass::epilogue::thread::Convert< + ElementAccumulator, + cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, + SmArch, + ElementInputA, + ElementInputB, + ElementAccumulator, + ElementAccumulator>::EpilogueOutputOp::kCount, + ElementAccumulator>; + + /// Reduction operator + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOp::ElementAccumulator, + EpilogueOp::kCount>; + + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastAccum; + + using Gemm = 
cutlass::gemm::device::GemmSplitKParallel; cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{params.M, params.N, params.K}; @@ -276,15 +289,18 @@ bool dispatch_fuse_gemm_split_k_relu(GemmEpilogueAllParams params) { // Split K dimension into 16 partitions int split_k_slices = params.split_k; - // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch - // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication - {reinterpret_cast(const_cast(params.A)),params.lda}, - {reinterpret_cast(const_cast(params.B)),params.ldb}, - {reinterpret_cast(const_cast(params.bias)),0}, - {reinterpret_cast(params.D),params.ldd}, - {alpha, beta}, // <- tuple of alpha and beta - split_k_slices}; // <- k-dimension split factor + // Create a tuple of gemm kernel arguments. This is later passed as arguments + // to launch instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + {reinterpret_cast(const_cast(params.A)), + params.lda}, + {reinterpret_cast(const_cast(params.B)), + params.ldb}, + {reinterpret_cast(const_cast(params.bias)), 0}, + {reinterpret_cast(params.D), params.ldd}, + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor Gemm gemm_op; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h index 3ea4d7c33e8..82373f86938 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/per_channel_fp8_fp8_half_gemm.h @@ -17,10 +17,10 @@ #include "fp8_common.h" -#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#ifdef __GNUC__ // Check if the compiler is GCC or Clang #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" 
-#endif // __GNUC__ +#endif // __GNUC__ // clang-format off #include "cutlass/cutlass.h" @@ -30,91 +30,134 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" // clang-format on -#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#ifdef __GNUC__ // Check if the compiler is GCC or Clang #pragma GCC diagnostic pop -#endif // __GNUC__ - -template -struct DeviceGemmFp8RowwiseSm89 -{ - using ElementInput = typename std::conditional_t< - std::is_same_v, - cutlass::float_e4m3_t, - cutlass::float_e5m2_t>; - using ElementA = ElementInput; - using LayoutA = cutlass::layout::RowMajor; - static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; - - using ElementB = ElementInput; - using LayoutB = cutlass::layout::ColumnMajor; - static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; - - - using ElementOutput = - typename std::conditional_t, - cutlass::bfloat16_t, - cutlass::half_t>; - - using ElementC = ElementOutput; - using LayoutC = cutlass::layout::RowMajor; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using LayoutOutput = cutlass::layout::RowMajor; - static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; - - using ElementAccumulator = AccumElementType; - using ElementComputeEpilogueScale = float; - using ArchTag = cutlass::arch::Sm89; - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // Number of epilogue stages in EVT - static constexpr int EVTEpilogueStages = 1; - - using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; - - // Definition of EVT - using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; - - using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute; - using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; - using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeAScale = cutlass::epilogue::threadblock::VisitorCompute; - using aScaleSrc = 
cutlass::epilogue::threadblock::VisitorColBroadcast>; - using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; - - using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, ElementC, - cute::Stride // StrideMNL - >; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::plus, ElementC, ElementComputeEpilogueScale, - cutlass::FloatRoundStyle::round_to_nearest - >; - - using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT< - Compute0, - EpilogueAScale, - Bias>; - - using dTar = cutlass::epilogue::threadblock::VisitorAuxStore>; - using EpilogueStore = cutlass::epilogue::threadblock::Sm80EVT; - - using EpilogueOp = EpilogueStore; - - using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor::GemmKernel; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +#endif // __GNUC__ + +template +struct DeviceGemmFp8RowwiseSm89 { + using ElementInput = typename std::conditional_t< + std::is_same_v, + cutlass::float_e4m3_t, + cutlass::float_e5m2_t>; + using ElementA = ElementInput; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementInput; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementOutput = typename std::conditional_t< + std::is_same_v, + cutlass::bfloat16_t, + cutlass::half_t>; + + using ElementC = ElementOutput; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using LayoutOutput = cutlass::layout::RowMajor; + static constexpr int AlignmentOutput = + 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = AccumElementType; + using ElementComputeEpilogueScale = float; + using ArchTag = cutlass::arch::Sm89; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // Number of epilogue stages in EVT + static constexpr int 
EVTEpilogueStages = 1; + + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout; + + // Definition of EVT + using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, + ElementComputeEpilogueScale, + ElementComputeEpilogueScale, + cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, + ElementComputeEpilogueScale, + cute::Stride>; + using EpilogueBScale = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeAScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, + ElementComputeEpilogueScale, + ElementComputeEpilogueScale, + cutlass::FloatRoundStyle::round_to_nearest>; + using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, + ElementComputeEpilogueScale, + cute::Stride>; + using EpilogueAScale = cutlass::epilogue::threadblock:: + Sm80EVT; + + using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, + ElementC, + cute::Stride // StrideMNL + >; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, + ElementC, + ElementComputeEpilogueScale, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, + ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride>; + using EpilogueStore = + cutlass::epilogue::threadblock::Sm80EVT; + + using EpilogueOp = EpilogueStore; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, + LayoutA, + cutlass::ComplexTransform::kNone, + AlignmentA, + ElementB, + LayoutB, + cutlass::ComplexTransform::kNone, + AlignmentB, + ElementC, + LayoutC, + AlignmentC, + ElementAccumulator, + 
ElementComputeEpilogueScale, + OperatorClass, + ArchTag, + CtaShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + Stages, + cutlass::arch::OpMultiplyAddFastAccum, + EVTEpilogueStages>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h index a073c9bff3d..111b2cfd156 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fp8_gemm_fused/visitor_fp8_gemm_fused_template.h @@ -16,56 +16,68 @@ #include "per_channel_fp8_fp8_half_gemm.h" // NOLINT template -typename Gemm::Arguments prepar_gemm_args_sm89(void* D, void const* A, void const* B, void const* C_bias, - int m, int n, int k, float const* scale_d0, float const* scale_d1) -{ - using ElementT = typename Gemm::ElementA; - using ElementOutput = typename Gemm::ElementD; - using ElementComputeEpilogue = float; +typename Gemm::Arguments prepar_gemm_args_sm89(void* D, + void const* A, + void const* B, + void const* C_bias, + int m, + int n, + int k, + float const* scale_d0, + float const* scale_d1) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; - int const lda = k; - int const ldb = k; - int const ldc = n; + int const lda = k; + int const ldb = k; + int const ldc = n; - typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode - {m, n, k}, // Problem size - 1, // Split-k factor - {}, // Epilogue args - reinterpret_cast(A), // a pointer - reinterpret_cast(B), // b pointer - nullptr, // c pointer (unused) - nullptr, // d pointer (unused) - m * k, // batch stride a (unused) - n * k, // batch stride b (unused) - m * n, // batch stride c (unused) - m * n, // batch stride d 
(unused) - lda, // stride a - ldb, // stride b - ldc, // stride c (unused) - ldc); // stride d (unused) + typename Gemm::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + reinterpret_cast(A), // a pointer + reinterpret_cast(B), // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) - args.epilogue = { - { - { - { - {}, // Accumulator - {reinterpret_cast(scale_d1), ElementComputeEpilogue(0), - {cute::_0{}, cute::_1{}, cute::_0{}}}, - {} // Multiplies - }, - {reinterpret_cast(scale_d0), ElementComputeEpilogue(0), {cute::_0{}, cute::_0{}, cute::_0{}}}, - {} // Multiplies - }, // Accum - {reinterpret_cast(C_bias), ElementOutput(0), {cute::_0{}, cute::_1{}, cute::_0{}}}, // Bias - {} // Compute0 - }, - {reinterpret_cast(D), {n, cute::_1{}, cute::_0{}}} - }; - return args; + args.epilogue = { + { + { + { + {}, // Accumulator + {reinterpret_cast(scale_d1), + ElementComputeEpilogue(0), + {cute::_0{}, cute::_1{}, cute::_0{}}}, + {} // Multiplies + }, + {reinterpret_cast(scale_d0), + ElementComputeEpilogue(0), + {cute::_0{}, cute::_0{}, cute::_0{}}}, + {} // Multiplies + }, // Accum + {reinterpret_cast(C_bias), + ElementOutput(0), + {cute::_0{}, cute::_1{}, cute::_0{}}}, // Bias + {} // Compute0 + }, + {reinterpret_cast(D), {n, cute::_1{}, cute::_0{}}}}; + return args; } template -bool per_channel_fp8_fp8_gemm_scale_bias(GemmEpilogueAllParams params, typename Gemm::Arguments args) { +bool per_channel_fp8_fp8_gemm_scale_bias(GemmEpilogueAllParams params, + typename Gemm::Arguments args) { Gemm per_channel_fp8_gemm; cutlass::Status status = per_channel_fp8_gemm.can_implement(args); @@ -89,14 +101,31 @@ bool 
per_channel_fp8_fp8_gemm_scale_bias(GemmEpilogueAllParams params, typename return true; } - -template +template bool dispatch_visitor_fuse_gemm(GemmEpilogueAllParams params) { - using AccumElementType = float; - using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; - auto args = prepar_gemm_args_sm89(params.D, params.A, params.B, params.bias, params.M, params.N, params.K, params.scalar_scale, params.channel_scale); - per_channel_fp8_fp8_gemm_scale_bias(params, args); + using AccumElementType = float; + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + auto args = prepar_gemm_args_sm89(params.D, + params.A, + params.B, + params.bias, + params.M, + params.N, + params.K, + params.scalar_scale, + params.channel_scale); + per_channel_fp8_fp8_gemm_scale_bias(params, args); } diff --git a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h index 6b1ab209e35..44d5ccc220d 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -21,10 +21,8 @@ #include #include -namespace kernels -{ -namespace cutlass_kernels -{ +namespace kernels { +namespace cutlass_kernels { /* This runner only supports: @@ -32,64 +30,105 @@ namespace cutlass_kernels Activations, biases, scales and outputs are all assumed to be row-major. - However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. - In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor - will instantiate the layout and preprocess based on the instantiation, so layout changes should only require - modifications to mix_gemm_B_layout.h. + However, it is assumed that B is in a special format governed by + cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. In this case, B must be + preprocessed using the cutlass weight only quant preprocessors. 
The weight + preprocessor will instantiate the layout and preprocess based on the + instantiation, so layout changes should only require modifications to + mix_gemm_B_layout.h. */ -class CutlassFpAIntBGemmRunnerInterface -{ -public: - CutlassFpAIntBGemmRunnerInterface() {} - - virtual ~CutlassFpAIntBGemmRunnerInterface() {} - - virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, - void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, - cutlass_extensions::CutlassGemmConfig gemmConfig, void* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream - ) = 0; - - // Returns desired workspace size in bytes. - virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; - - virtual std::vector getConfigs(int k) const = 0; - -protected: - static constexpr int SPLIT_K_LIMIT = 7; - static constexpr int MIN_M_TILE = 16; - static constexpr int MIN_N_TILE = 64; +class CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunnerInterface() {} + + virtual ~CutlassFpAIntBGemmRunnerInterface() {} + + virtual void gemm(void const* A, + void const* B, + void const* weight_scales, + void const* weight_zero_points, + void const* biases, + float const alpha, + void* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemmConfig, + void* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) = 0; + + // Returns desired workspace size in bytes. 
+ virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; + + virtual std::vector getConfigs( + int k) const = 0; + + protected: + static constexpr int SPLIT_K_LIMIT = 7; + static constexpr int MIN_M_TILE = 16; + static constexpr int MIN_N_TILE = 64; }; -template -class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface -{ -public: - CutlassFpAIntBGemmRunner(); - ~CutlassFpAIntBGemmRunner(); - - void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, - void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, - cutlass_extensions::CutlassGemmConfig gemmConfig, void* workspace_ptr, const size_t workspace_bytes, - cudaStream_t stream) override; - - // Returns desired workspace size in bytes. - size_t getWorkspaceSize(int const m, int const n, int const k) override; - - std::vector getConfigs(int k) const override; - -private: - template - void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, - int k, int const group_size, cutlass_extensions::CutlassGemmConfig gemm_config, void* workspace_ptr, - const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); - -private: - int sm_; - int multi_processor_count_; +template +class CutlassFpAIntBGemmRunner + : public virtual CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); + + void gemm(void const* A, + void const* B, + void const* weight_scales, + void const* weight_zero_points, + void const* biases, + float const alpha, + void* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemmConfig, + void* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) override; + + // Returns desired workspace size in bytes. 
+ size_t getWorkspaceSize(int const m, int const n, int const k) override; + + std::vector getConfigs( + int k) const override; + + private: + template + void dispatch_to_arch(ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemm_config, + void* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr); + + private: + int sm_; + int multi_processor_count_; }; -} // namespace cutlass_kernels -} // namespace kernels +} // namespace cutlass_kernels +} // namespace kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index 2e02658d265..bc62bb34bf7 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -43,311 +43,559 @@ namespace kernels { namespace cutlass_kernels { -template < - typename ActivationType, - typename WeightType, - typename ScaleZeroType, - typename BiasType, - typename OutputType, - typename arch, - cutlass::WeightOnlyQuantOp QuantOp, - typename EpilogueTag, - typename ThreadblockShape, - typename WarpShape, - int Stages> +template void generic_mixed_gemm_kernelLauncher( - ActivationType const* A, - WeightType const* B, - ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, - BiasType const* biases, - float const alpha, - OutputType* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemm_config, - void* workspace, - size_t workspace_bytes, - cudaStream_t stream, - int* occupancy = nullptr) { - // The cutlass type for the input elements. 
This is needed to convert to cutlass::half_t if - // necessary. - using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; - using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; - using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; - using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; - using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; - - // We need separate config for each architecture since we will target different tensorcore - // instructions. For float, we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel:: - MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using EpilogueOp = typename cutlass_extensions:: - Epilogue::Op; - - using Operator = typename MixedGemmArchTraits::Operator; - using TaggedOperator = typename cutlass::arch::TagOperator::TaggedOperator; - - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< - CutlassActivationType, - cutlass::layout::RowMajor, - MixedGemmArchTraits::ElementsPerAccessA, - CutlassWeightType, - typename MixedGemmArchTraits::LayoutB, - MixedGemmArchTraits::ElementsPerAccessB, - CutlassOutputType, - cutlass::layout::RowMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - arch, - ThreadblockShape, - WarpShape, - typename MixedGemmArchTraits::InstructionShape, - EpilogueOp, - typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, - Stages, - true, - TaggedOperator>::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB< - typename GemmKernel_::Mma, - typename GemmKernel_::Epilogue, - typename GemmKernel_::ThreadblockSwizzle, - arch, // Ensure top level arch is used for dispatch - GemmKernel_::kSplitKSerial>; - - if (occupancy != nullptr) { - *occupancy = cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - - using Gemm = 
cutlass::gemm::device::GemmUniversalBaseCompat; - - int const ldb = - cutlass::platform:: - is_same::value - ? n - : k * GemmKernel::kInterleave; - - if (weight_scales == nullptr) { - throw std::runtime_error("Weight scales must always be set to a non-null value."); - } - - if constexpr (cutlass::isFinegrained(QuantOp)) { - if constexpr (cutlass::platform::is_same:: - value) { - if (group_size != 128) { - throw std::runtime_error( - "Only group size 128 supported for fine grained W4A(fp)8 kernels."); - } - } - if (group_size != 64 && group_size != 128) { - throw std::runtime_error( - "Only group size 64 and 128 supported for fine grained kernels."); - } - - if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { - if (weight_zero_points != nullptr) { - throw std::runtime_error( - "Weight zero pointer must be a nullptr for scale only fine grained"); - } - } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { - if (weight_zero_points == nullptr) { - throw std::runtime_error( - "Weight zero pointer must be valid for scale and bias fine grained"); - } - } - } else { - if (group_size != k) { - throw std::runtime_error("Invalid group size for per column scaling kernels."); - } - - if (weight_zero_points != nullptr) { - throw std::runtime_error( - "Weight zero-points must be null when running per column scaling"); - } - } - - int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; - ElementAccumulator output_op_beta = - (biases == nullptr) ? 
ElementAccumulator(0.f) : ElementAccumulator(1.f); - typename Gemm::Arguments args( - {m, n, k}, - group_size, - {reinterpret_cast(const_cast(A)), k}, - {reinterpret_cast(const_cast(B)), ldb}, - {reinterpret_cast(const_cast(weight_scales)), - ld_scale_zero}, - {reinterpret_cast( - const_cast(weight_zero_points)), - ld_scale_zero}, - {reinterpret_cast(const_cast(biases)), 0}, - {reinterpret_cast(C), n}, - gemm_config.split_k_factor, - {ElementAccumulator(alpha), output_op_beta}); - - // This assertion is enabled because because for the column interleaved layout, K MUST be a - // multiple of threadblockK. The reason for this is that the default pitchlinear iterators are - // used to handle walking over the interleaved matrix. The way masking in handled in these do - // not map to the interleaved layout. We need to write our own predicated iterator in order to - // relax this limitation. - if (GemmKernel::kInterleave > 1 && - ((k % MixedGemmArchTraits::ThreadblockK) || - ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) { + ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemm_config, + void* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + // The cutlass type for the input elements. This is needed to convert to + // cutlass::half_t if necessary. 
+ using CutlassActivationType = + typename CudaToCutlassTypeAdapter::type; + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + using CutlassScaleZeroType = + typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + + // We need separate config for each architecture since we will target + // different tensorcore instructions. For float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel:: + MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using EpilogueOp = typename cutlass_extensions::Epilogue::Op; + + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = + typename cutlass::arch::TagOperator::TaggedOperator; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< + CutlassActivationType, + cutlass::layout::RowMajor, + MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, + typename MixedGemmArchTraits::LayoutB, + MixedGemmArchTraits::ElementsPerAccessB, + CutlassOutputType, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + arch, + ThreadblockShape, + WarpShape, + typename MixedGemmArchTraits::InstructionShape, + EpilogueOp, + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + Stages, + true, + TaggedOperator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB< + typename GemmKernel_::Mma, + typename GemmKernel_::Epilogue, + typename GemmKernel_::ThreadblockSwizzle, + arch, // Ensure top level arch is used for dispatch + GemmKernel_::kSplitKSerial>; + + if (occupancy != nullptr) { + *occupancy = cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + int const ldb = + 
cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; + + if (weight_scales == nullptr) { + throw std::runtime_error( + "Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) { + if constexpr (cutlass::platform::is_same::value) { + if (group_size != 128) { throw std::runtime_error( - "Assertion: k[" + std::to_string(k) + "] must be multiple of threadblockK[" + - std::to_string(MixedGemmArchTraits::ThreadblockK) + "]"); + "Only group size 128 supported for fine grained W4A(fp)8 kernels."); + } } - - Gemm gemm; - - if (gemm.get_workspace_size(args) > workspace_bytes) { - std::cerr << "Requested split-k but workspace size insufficient. Falling back to " - "non-split-k implementation." - << std::endl; - // If requested split-k factor will require more workspace bytes, revert to standard gemm. - args.batch_count = 1; + if (group_size != 64 && group_size != 128) { + throw std::runtime_error( + "Only group size 64 and 128 supported for fine grained kernels."); } - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) { - std::string err_msg = "fp8_int4 cutlass kernel will fail for params. 
Error: " + - std::string(cutlassGetStatusString(can_implement)); - throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + if constexpr (QuantOp == + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + throw std::runtime_error( + "Weight zero pointer must be a nullptr for scale only fine " + "grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp:: + FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + throw std::runtime_error( + "Weight zero pointer must be valid for scale and bias fine " + "grained"); + } } - - auto init_status = gemm.initialize(args, workspace, stream); - if (init_status != cutlass::Status::kSuccess) { - std::string err_msg = "Failed to initialize cutlass fp8_int4 gemm. Error: " + - std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + } else { + if (group_size != k) { + throw std::runtime_error( + "Invalid group size for per column scaling kernels."); } - auto run_status = gemm.run(stream); - if (run_status != cutlass::Status::kSuccess) { - std::string err_msg = "Failed to run cutlass fp8_int4 gemm. Error: " + - std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + if (weight_zero_points != nullptr) { + throw std::runtime_error( + "Weight zero-points must be null when running per column scaling"); } + } + + int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; + ElementAccumulator output_op_beta = + (biases == nullptr) ? 
ElementAccumulator(0.f) : ElementAccumulator(1.f); + typename Gemm::Arguments args( + {m, n, k}, + group_size, + {reinterpret_cast(const_cast(A)), + k}, + {reinterpret_cast(const_cast(B)), ldb}, + {reinterpret_cast( + const_cast(weight_scales)), + ld_scale_zero}, + {reinterpret_cast( + const_cast(weight_zero_points)), + ld_scale_zero}, + {reinterpret_cast(const_cast(biases)), 0}, + {reinterpret_cast(C), n}, + gemm_config.split_k_factor, + {ElementAccumulator(alpha), output_op_beta}); + + // This assertion is enabled because because for the column interleaved + // layout, K MUST be a multiple of threadblockK. The reason for this is that + // the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the + // interleaved layout. We need to write our own predicated iterator in order + // to relax this limitation. + if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || + ((k / gemm_config.split_k_factor) % + MixedGemmArchTraits::ThreadblockK))) { + throw std::runtime_error("Assertion: k[" + std::to_string(k) + + "] must be multiple of threadblockK[" + + std::to_string(MixedGemmArchTraits::ThreadblockK) + + "]"); + } + + Gemm gemm; + + if (gemm.get_workspace_size(args) > workspace_bytes) { + std::cerr + << "Requested split-k but workspace size insufficient. Falling back to " + "non-split-k implementation." + << std::endl; + // If requested split-k factor will require more workspace bytes, revert to + // standard gemm. + args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "fp8_int4 cutlass kernel will fail for params. 
Error: " + + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to initialize cutlass fp8_int4 gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fp8_int4 gemm. Error: " + + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[fp8_int4 Runner] " + err_msg); + } } -template < - typename ActivationType, - typename WeightType, - typename ScaleZeroType, - typename BiasType, - typename OutputType, - typename arch, - cutlass::WeightOnlyQuantOp QuantOp, - typename EpilogueTag, - typename ThreadblockShape, - typename WarpShape> -void dispatch_gemm_config( - ActivationType const* A, - WeightType const* B, - ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, - BiasType const* biases, - float const alpha, - OutputType* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemm_config, - void* workspace, - size_t workspace_bytes, - cudaStream_t stream, - int* occupancy = nullptr) { - switch (gemm_config.stages) { +template +void dispatch_gemm_config(ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemm_config, + void* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { case 2: - throw std::runtime_error( - "[filter_and_run_mixed_gemm] Cutlass fp8_int4 gemm not 
supported for arch " + - std::to_string(arch::kMinComputeCapability) + " with stages set to 2"); - break; + throw std::runtime_error( + "[filter_and_run_mixed_gemm] Cutlass fp8_int4 gemm not supported for " + "arch " + + std::to_string(arch::kMinComputeCapability) + + " with stages set to 2"); + break; case 3: - generic_mixed_gemm_kernelLauncher< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - ThreadblockShape, - WarpShape, - 3>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; + generic_mixed_gemm_kernelLauncher(A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; case 4: - generic_mixed_gemm_kernelLauncher< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - ThreadblockShape, - WarpShape, - 4>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; + generic_mixed_gemm_kernelLauncher(A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; default: - std::string err_msg = "dispatch_gemm_config does not support stages " + - std::to_string(gemm_config.stages); - throw std::runtime_error("[dispatch_gemm_config] " + err_msg); - break; - } + std::string err_msg = "dispatch_gemm_config does not support stages " + + std::to_string(gemm_config.stages); + throw std::runtime_error("[dispatch_gemm_config] " + err_msg); + break; + } } -template < - typename ActivationType, - typename WeightType, - typename ScaleZeroType, - typename BiasType, - typename 
OutputType, - typename arch, - cutlass::WeightOnlyQuantOp QuantOp, - typename EpilogueTag> -void dispatch_gemm_to_cutlass( +template +void dispatch_gemm_to_cutlass(ActivationType const* A, + WeightType const* B, + ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, + BiasType const* biases, + float const alpha, + OutputType* C, + int m, + int n, + int k, + int const group_size, + void* workspace, + size_t workspace_bytes, + cutlass_extensions::CutlassGemmConfig gemm_config, + cudaStream_t stream, + int* occupancy = nullptr) { + // Note that SIMT configs are omitted here since they are not supported for + // fp8_int4. We also only instantiate configs here where threadblockShapeM == + // warpShapeM since those usually perform the best for mixed type gemms. + constexpr int tile_shape_k = + 128 * 8 / cutlass::sizeof_bits::value; + switch (gemm_config.tile_config) { + case cutlass_extensions::CutlassTileConfig:: + CtaShape16x128x64_WarpShape16x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + CtaShape16x256x64_WarpShape16x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + 
CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig:: + CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, tile_shape_k>>( + A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: + throw std::runtime_error( + "[fp8_int4][dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[fp8_int4][dispatch_gemm_to_cutlass] gemm config should have " + "already been set by " + "heuristic."); + break; + default: + printf("gemm_config.tile_config: %d", int(gemm_config.tile_config)); + throw std::runtime_error( + "[fp8_int4][dispatch_gemm_to_cutlass] Config is invalid for mixed " + "type GEMM."); + break; + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + // printf(__PRETTY_FUNCTION__); + int device{-1}; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); + sm_ = common::getSMVersion(); + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { + // printf(__PRETTY_FUNCTION__); +} + +template +template +void CutlassFpAIntBGemmRunner:: + dispatch_to_arch( ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, @@ -359,424 +607,204 @@ void dispatch_gemm_to_cutlass( int n, int k, int const group_size, - void* workspace, - size_t workspace_bytes, cutlass_extensions::CutlassGemmConfig 
gemm_config, + void* workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream, - int* occupancy = nullptr) { - // Note that SIMT configs are omitted here since they are not supported for fp8_int4. - // We also only instantiate configs here where threadblockShapeM == warpShapeM since those - // usually perform the best for mixed type gemms. - constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits::value; - switch (gemm_config.tile_config) { - case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<16, 128, tile_shape_k>, - cutlass::gemm::GemmShape<16, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<16, 256, tile_shape_k>, - cutlass::gemm::GemmShape<16, 64, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<32, 128, tile_shape_k>, - cutlass::gemm::GemmShape<32, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case 
cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<64, 128, tile_shape_k>, - cutlass::gemm::GemmShape<64, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_gemm_config< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - arch, - QuantOp, - EpilogueTag, - cutlass::gemm::GemmShape<128, 128, tile_shape_k>, - cutlass::gemm::GemmShape<128, 32, tile_shape_k>>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: - throw std::runtime_error("[fp8_int4][dispatch_gemm_to_cutlass] gemm config undefined."); - break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[fp8_int4][dispatch_gemm_to_cutlass] gemm config should have already been set by " - "heuristic."); - break; - default: - printf("gemm_config.tile_config: %d", int(gemm_config.tile_config)); - throw std::runtime_error( - "[fp8_int4][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); - break; - } -} - -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -CutlassFpAIntBGemmRunner:: - CutlassFpAIntBGemmRunner() { - // printf(__PRETTY_FUNCTION__); - int device{-1}; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device)); - sm_ = common::getSMVersion(); - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( - 
&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); + int* occupancy) { + dispatch_gemm_to_cutlass(A, + B, + weight_scales, + weight_zero_points, + biases, + alpha, + C, + m, + n, + k, + group_size, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); } -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -CutlassFpAIntBGemmRunner:: - ~CutlassFpAIntBGemmRunner() { - // printf(__PRETTY_FUNCTION__); -} - -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -template -void CutlassFpAIntBGemmRunner< - ActivationType, - WeightType, - QuantOp, - ScaleZeroType, - BiasType, - OutputType>:: - dispatch_to_arch( - ActivationType const* A, - WeightType const* B, - ScaleZeroType const* weight_scales, - ScaleZeroType const* weight_zero_points, - BiasType const* biases, - float const alpha, - OutputType* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemm_config, - void* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream, - int* occupancy) { - dispatch_gemm_to_cutlass< - ActivationType, - WeightType, - ScaleZeroType, - BiasType, - OutputType, - cutlass::arch::Sm89, - QuantOp, - EpilogueTag>( - A, - B, - weight_scales, - weight_zero_points, - biases, - alpha, - C, - m, - n, - k, - group_size, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); -} - -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> +template void CutlassFpAIntBGemmRunner< - ActivationType, - WeightType, - QuantOp, - ScaleZeroType, - BiasType, - OutputType>:: - gemm(void const* A, - void const* B, - void const* weight_scales, - void const* 
weight_zero_points, - void const* biases, - float const alpha, - void* C, - int m, - int n, - int k, - int const group_size, - cutlass_extensions::CutlassGemmConfig gemmConfig, - void* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) { - // printf(__PRETTY_FUNCTION__); - if (gemmConfig.tile_config == cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { - std::vector configs = getConfigs(k); - std::vector occupancies(configs.size()); - for (size_t i = 0; i < configs.size(); ++i) { - dispatch_to_arch( - (ActivationType const*)A, - (WeightType const*)B, - (ScaleZeroType const*)weight_scales, - (ScaleZeroType const*)weight_zero_points, - (BiasType const*)biases, - alpha, - (OutputType*)C, - m, - n, - k, - group_size, - configs[i], - workspace_ptr, - workspace_bytes, - stream, - &occupancies[i]); - } - auto best_config = estimate_best_config_from_occupancies( - configs, - occupancies, - m, - n, - k, - 1, - SPLIT_K_LIMIT, - workspace_bytes, - multi_processor_count_, - true); - dispatch_to_arch( - (ActivationType const*)A, - (WeightType const*)B, - (ScaleZeroType const*)weight_scales, - (ScaleZeroType const*)weight_zero_points, - (BiasType const*)biases, - alpha, - (OutputType*)C, - m, - n, - k, - group_size, - best_config, - workspace_ptr, - workspace_bytes, - stream, - nullptr); - } else { - dispatch_to_arch( - (ActivationType const*)A, - (WeightType const*)B, - (ScaleZeroType const*)weight_scales, - (ScaleZeroType const*)weight_zero_points, - (BiasType const*)biases, - alpha, - (OutputType*)C, - m, - n, - k, - group_size, - gemmConfig, - workspace_ptr, - workspace_bytes, - stream, - nullptr); + ActivationType, + WeightType, + QuantOp, + ScaleZeroType, + BiasType, + OutputType>::gemm(void const* A, + void const* B, + void const* weight_scales, + void const* weight_zero_points, + void const* biases, + float const alpha, + void* C, + int m, + int n, + int k, + int const group_size, + cutlass_extensions::CutlassGemmConfig gemmConfig, + void* 
workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + // printf(__PRETTY_FUNCTION__); + if (gemmConfig.tile_config == + cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { + std::vector configs = getConfigs(k); + std::vector occupancies(configs.size()); + for (size_t i = 0; i < configs.size(); ++i) { + dispatch_to_arch( + (ActivationType const*)A, + (WeightType const*)B, + (ScaleZeroType const*)weight_scales, + (ScaleZeroType const*)weight_zero_points, + (BiasType const*)biases, + alpha, + (OutputType*)C, + m, + n, + k, + group_size, + configs[i], + workspace_ptr, + workspace_bytes, + stream, + &occupancies[i]); } + auto best_config = + estimate_best_config_from_occupancies(configs, + occupancies, + m, + n, + k, + 1, + SPLIT_K_LIMIT, + workspace_bytes, + multi_processor_count_, + true); + dispatch_to_arch( + (ActivationType const*)A, + (WeightType const*)B, + (ScaleZeroType const*)weight_scales, + (ScaleZeroType const*)weight_zero_points, + (BiasType const*)biases, + alpha, + (OutputType*)C, + m, + n, + k, + group_size, + best_config, + workspace_ptr, + workspace_bytes, + stream, + nullptr); + } else { + dispatch_to_arch( + (ActivationType const*)A, + (WeightType const*)B, + (ScaleZeroType const*)weight_scales, + (ScaleZeroType const*)weight_zero_points, + (BiasType const*)biases, + alpha, + (OutputType*)C, + m, + n, + k, + group_size, + gemmConfig, + workspace_ptr, + workspace_bytes, + stream, + nullptr); + } } -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> +template std::vector -CutlassFpAIntBGemmRunner:: - getConfigs(int k) const { - // printf(__PRETTY_FUNCTION__); - cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = - cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; - config_type_param = - static_cast( - config_type_param | - 
cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam::WEIGHT_ONLY); - std::vector candidateConfigs = - get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param); - - // filter configs that are not supported on sm89 - std::vector rets; - for (auto config : candidateConfigs) { - // sm89 doesn't support stages 2 - if (config.stages == 2) { - continue; - } - - if (config.stages >= 5) { - continue; - } - if (config.split_k_style != cutlass_extensions::SplitKStyle::NO_SPLIT_K) { - int k_size = (k + config.split_k_factor - 1) / config.split_k_factor; - if (k_size % 128) { - continue; - } - } - rets.push_back(config); +CutlassFpAIntBGemmRunner::getConfigs(int k) const { + // printf(__PRETTY_FUNCTION__); + cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam + config_type_param = cutlass_extensions::CutlassGemmConfig:: + CandidateConfigTypeParam::HOPPER; + config_type_param = static_cast< + cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam>( + config_type_param | cutlass_extensions::CutlassGemmConfig:: + CandidateConfigTypeParam::WEIGHT_ONLY); + std::vector candidateConfigs = + get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param); + + // filter configs that are not supported on sm89 + std::vector rets; + for (auto config : candidateConfigs) { + // sm89 doesn't support stages 2 + if (config.stages == 2) { + continue; + } + + if (config.stages >= 5) { + continue; + } + if (config.split_k_style != cutlass_extensions::SplitKStyle::NO_SPLIT_K) { + int k_size = (k + config.split_k_factor - 1) / config.split_k_factor; + if (k_size % 128) { + continue; + } } - return rets; + rets.push_back(config); + } + return rets; } -template < - typename ActivationType, - typename WeightType, - cutlass::WeightOnlyQuantOp QuantOp, - typename ScaleZeroType, - typename BiasType, - typename OutputType> -size_t -CutlassFpAIntBGemmRunner:: - getWorkspaceSize(int const m, int const n, int const k) { - // printf(__PRETTY_FUNCTION__); - // These are the min 
tile sizes for each config, which would launch the maximum number of blocks - int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); - int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); - // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. - return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); +template +size_t CutlassFpAIntBGemmRunner::getWorkspaceSize(int const m, + int const n, + int const k) { + // printf(__PRETTY_FUNCTION__); + // These are the min tile sizes for each config, which would launch the + // maximum number of blocks + int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); + int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); + // We need 4 bytes per block in the worst case. We launch split_k_limit in z + // dim. + return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); } } // namespace cutlass_kernels diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 356f3059687..9c5e7bfc47b 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -43,7 +43,6 @@ #include "cutlass/trace.h" #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" #include "cutlass_extensions/tile_interleaved_layout.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -278,7 +277,8 @@ struct MoeFCGemm { code_scale(const_cast(code_scale)), code_zp(const_cast(code_zp)), host_problem_sizes(nullptr) { - if (quant_method != WintQuantMethod::kNone || platform::is_same::value || + if (quant_method != WintQuantMethod::kNone || + platform::is_same::value || platform::is_same::value) { assert(weight_scales); } @@ -381,7 +381,8 @@ struct MoeFCGemm { } static Status 
can_implement(Arguments const& args) { - if (args.quant_method != WintQuantMethod::kNone || platform::is_same::value || + if (args.quant_method != WintQuantMethod::kNone || + platform::is_same::value || platform::is_same::value) { if (args.weight_scales == nullptr) { CUTLASS_TRACE_HOST( @@ -417,7 +418,6 @@ struct MoeFCGemm { template struct KernelRunner { - CUTLASS_DEVICE static void run_kernel(Params const& params, SharedStorage& shared_storage) { // NOLINT @@ -472,9 +472,13 @@ struct MoeFCGemm { int64_t rows_to_jump = 0; if (params.problem_visitor.total_rows < 0) { - rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + rows_to_jump = problem_idx == 0 + ? 0 + : params.problem_visitor + .last_row_for_problem[problem_idx - 1]; } else { - rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count); + rows_to_jump = problem_idx * (params.problem_visitor.total_rows / + params.problem_visitor.problem_count); } // begin address offset for A for current tile @@ -497,11 +501,13 @@ struct MoeFCGemm { 0, }; - // the begin threadblock_offset of B, which holds the same column id with C + // the begin threadblock_offset of B, which holds the same column id + // with C cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id + // the begin threadblock_offset of scale, which holds the same column id + // with C, but with no row id cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; // Compute position within threadblock @@ -629,7 +635,7 @@ struct MoeFCGemm { static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910) +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010) static constexpr bool 
compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); @@ -650,9 +656,17 @@ template -struct Wint2xMoeFCGemm : public MoeFCGemm { +struct Wint2xMoeFCGemm : public MoeFCGemm { public: - using Base = MoeFCGemm; + using Base = MoeFCGemm; using Mma = Mma_; using Epilogue = Epilogue_; using EpilogueOutputOp = typename Epilogue::OutputOp; @@ -712,7 +726,11 @@ struct Wint2xMoeFCGemm : public MoeFCGemm struct KernelRunner { using WeightQuantTraits = WintQuantTraits; - using QuantArguments = typename WeightQuantTraits::Arguments; + using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments; CUTLASS_DEVICE - static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) { - QuantArguments quant_args; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128; - quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n; - quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n; - } - return quant_args; + static MmaQuantArguments prepare_quant_args( + Params const& params, + cutlass::gemm::GemmCoord const& threadblock_offset, + int64_t problem_idx, + const int32_t gemm_k, + const int32_t gemm_n, + const int thread_idx) { + // the begin threadblock_offset of scale, which holds the same column id + // with C, but with no row id + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2}; + + ElementScale* weight_scale_ptr = + params.weight_scales + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorSuperScale + iterator_super_scale( + Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n), + weight_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + int local_scale_pointer_offset = + ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2); + int64_t 
offset_in_bytes = problem_idx * gemm_k * gemm_n / 128; + uint4b_t* local_scale_ptr = + reinterpret_cast(params.local_scale + offset_in_bytes); + + typename Mma::QuantParamsAccessor::IteratorLocalScale + iterator_local_scale( + Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2), + local_scale_ptr, + {(gemm_k + 127) / 128, gemm_n * 2}, + thread_idx, + tb_offset_local_scale); + + float* code_scale_ptr = params.code_scale + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp + iterator_code_scale( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + float* code_zp_ptr = params.code_zp + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_zp_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + MmaQuantArguments mma_quant_args(iterator_super_scale, + iterator_local_scale, + iterator_code_scale, + iterator_code_zp, + local_scale_pointer_offset); + return mma_quant_args; } CUTLASS_DEVICE @@ -814,9 +883,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 1, "B must be row major/col major OR col major interleaved."); - // LayoutB should be RowMajor - using TileDequanterB = cutlass::gemm::threadblock::TileDequanter; - // // Problem visitor. // @@ -825,9 +891,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm::CaclPackedDim(gemm_k); - int64_t bytes_per_expert_matrix = (quant_gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + // wint2.5 and wint2.0 is quantized and packed along k dimension with + // group_size 64. 
+ const int64_t quant_gemm_k = + WintQuantTraits::CaclPackedDim(gemm_k); + int64_t bytes_per_expert_matrix = + (quant_gemm_k * gemm_n / 8) * + cutlass::sizeof_bits::value; // Outer 'persistent' loop to iterate over tiles while (problem_visitor.next_tile()) { @@ -843,20 +913,18 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(byte_ptr_B); typename LayoutB::LongIndex ldm_B = platform::is_same::value ? gemm_n : gemm_k * kInterleave; - typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; - // the begin threadblock_offset of B, which holds the same column id with C + // the begin threadblock_offset of B, which holds the same column id + // with C cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - - cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; - cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; - - MmaElementB* smem_unzip_B_ptr = nullptr; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr(); - } - QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n); - TileDequanterB tile_dequanter_B(smem_unzip_B_ptr, - byte_ptr_B, - ldm_B, - extent_B, - tb_offset_B, - weight_scale_ptr, - tb_offset_scale, - quant_args); - MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr(); + cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, + problem_size.n() / kInterleave}; // Compute position within threadblock int thread_idx = threadIdx.x; @@ -914,20 +964,22 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(params.ptr_C) + problem_idx * gemm_n : nullptr; + ElementC* ptr_C = params.ptr_C + ? 
reinterpret_cast(params.ptr_C) + + problem_idx * gemm_n + : nullptr; ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; @@ -1006,8 +1060,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 800) && (__CUDA_ARCH__ < 910) - KernelRunner::run_kernel(params, shared_storage); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010) + KernelRunner::run_kernel( + params, shared_storage); #else CUTLASS_NOT_IMPLEMENTED(); #endif diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu index d8496073fa7..765321f61f4 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_bf16.cu @@ -23,7 +23,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< - __nv_bfloat16, cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>; + __nv_bfloat16, + cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kNone>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu index 92d63948c18..991caf6b4f4 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int2.cu @@ -24,7 +24,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< __nv_bfloat16, - cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt2>>; + cutlass::WintQuantTraits<__nv_bfloat16, + cutlass::WintQuantMethod::kWeightOnlyInt2>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu 
b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu index b82fbc107c5..c3512aa45ae 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int4.cu @@ -23,7 +23,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< __nv_bfloat16, - cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt4>>; + cutlass::WintQuantTraits<__nv_bfloat16, + cutlass::WintQuantMethod::kWeightOnlyInt4>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu index 97fdd104bac..f7788ca961c 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_bf16_int8.cu @@ -24,7 +24,8 @@ namespace phi { #ifdef PADDLE_CUDA_BF16 template class MoeGemmRunner< __nv_bfloat16, - cutlass::WintQuantTraits<__nv_bfloat16, cutlass::WintQuantMethod::kWeightOnlyInt8>>; + cutlass::WintQuantTraits<__nv_bfloat16, + cutlass::WintQuantMethod::kWeightOnlyInt8>>; #endif -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu index a3d34b8e728..40608b1b98e 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_fp16.cu @@ -21,7 +21,8 @@ namespace phi { -template class MoeGemmRunner>; +template class MoeGemmRunner< + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu 
b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu index 5d84c9cfc1b..8d0519beca1 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int2.cu @@ -22,6 +22,7 @@ namespace phi { template class MoeGemmRunner< - half, cutlass::WintQuantTraits>; + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu index 51707ebbb80..ffbfd11b678 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int4.cu @@ -22,6 +22,7 @@ namespace phi { template class MoeGemmRunner< - half, cutlass::WintQuantTraits>; + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu index c796f9bbe50..adf9b91814d 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_fp16_int8.cu @@ -22,6 +22,7 @@ namespace phi { template class MoeGemmRunner< - half, cutlass::WintQuantTraits>; + half, + cutlass::WintQuantTraits>; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h index b5cb93ad3fa..db5af4f4938 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h @@ -26,10 +26,10 @@ #include 
#include "cutlass/array.h" -#include "cutlass/trace.h" -#include "cutlass/numeric_conversion.h" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/trace.h" #include "paddle/common/errors.h" #include "paddle/phi/core/enforce.h" @@ -63,24 +63,28 @@ struct CutlassLayoutB { using Type = cutlass::layout::RowMajor; }; -template +template struct CutlassGemmKernel { - using Type = - cutlass::gemm::kernel::MoeFCGemm; + using Type = cutlass::gemm::kernel::MoeFCGemm< + typename BaseGemmKernel::Mma, + typename BaseGemmKernel::Epilogue, + typename BaseGemmKernel::ThreadblockSwizzle, + Arch, + BaseGemmKernel::kGroupScheduleMode>; }; template -struct CutlassGemmKernel { - using Type = - cutlass::gemm::kernel::Wint2xMoeFCGemm; +struct CutlassGemmKernel { + using Type = cutlass::gemm::kernel::Wint2xMoeFCGemm< + typename BaseGemmKernel::Mma, + typename BaseGemmKernel::Epilogue, + typename BaseGemmKernel::ThreadblockSwizzle, + Arch, + BaseGemmKernel::kGroupScheduleMode>; }; // ======================= Variable batched Gemm things ======================= @@ -91,21 +95,22 @@ template -void generic_moe_gemm_kernelLauncher(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - const int multi_processor_count, - cudaStream_t stream, - int* kernel_occupancy = nullptr) { +void generic_moe_gemm_kernelLauncher( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig 
gemm_config, + const int multi_processor_count, + cudaStream_t stream, + int* kernel_occupancy = nullptr) { if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { throw std::runtime_error("[MoeGemm] Grouped gemm does not support split-k"); } @@ -128,12 +133,14 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t (wint4), uint16_t (wint2.5)"); + "Specialized for bfloat16, half, float, uint8_t (wint8), uint4b_t " + "(wint4), uint16_t (wint2.5)"); // The cutlass type for the input elements. This is needed to convert to // cutlass::half_t if necessary. using ElementType = typename cutlass::CutlassDataType::Type; - using CutlassWeightType = typename cutlass::CutlassDataType::Type; + using CutlassWeightType = typename cutlass::CutlassDataType< + typename WeightQuantTraits::WeightType>::Type; using CutlassMmaWeightType = typename WeightQuantTraits::MmaWeightType; using CutlassMmaKernelType = typename WeightQuantTraits::MmaKernelType; @@ -155,7 +162,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, CutlassMmaKernelType, - typename CutlassLayoutB::Type, + typename CutlassLayoutB::Type, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, @@ -172,7 +180,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; - using GemmKernel = typename CutlassGemmKernel::Type; + using GemmKernel = + typename CutlassGemmKernel::Type; using GemmGrouped = cutlass::gemm::device::GemmGrouped; if (kernel_occupancy != nullptr) { @@ -194,7 +205,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const uint8_t* local_scale_B = nullptr; const float* code_scale_B = nullptr; const float* code_zp_B = 
nullptr; - if constexpr (WeightQuantTraits::kQuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { + if constexpr (WeightQuantTraits::kQuantMethod == + cutlass::WintQuantMethod::kWeightOnlyInt2) { local_scale_B = quant_args_B.local_scale_ptr; code_scale_B = quant_args_B.code_scale_ptr; code_zp_B = quant_args_B.code_zp_ptr; @@ -205,7 +217,7 @@ void generic_moe_gemm_kernelLauncher(const T* A, threadblock_count, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), + reinterpret_cast(B), reinterpret_cast(weight_scales), reinterpret_cast(biases), reinterpret_cast(C), @@ -253,21 +265,22 @@ template struct dispatch_stages { - static void dispatch(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { // FT_LOG_DEBUG(__PRETTY_FUNCTION__); std::string err_msg = "Cutlass fpA_intB gemm. 
Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + @@ -289,21 +302,22 @@ struct dispatch_stages { - static void dispatch(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { generic_moe_gemm_kernelLauncher 2)>::type> { - static void dispatch(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { generic_moe_gemm_kernelLauncher -void dispatch_gemm_config(const T* A, - const typename 
WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +void dispatch_gemm_config( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { #define dispatch_stages_macro(STAGE) \ case STAGE: \ dispatch_stages::value && - std::is_same::value>::type* = - nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +template < + typename T, + typename WeightQuantTraits, + typename arch, + typename EpilogueTag, + typename std::enable_if< + !std::is_same::value && + std::is_same::value>::type* = + nullptr> +void dispatch_moe_gemm_to_cutlass( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + 
cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); dispatch_gemm_config_macro(64, 128, 64, 32, 64, 64); dispatch_gemm_config_macro(128, 128, 64, 64, 32, 64); case CutlassTileConfig::Undefined: - throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( @@ -518,32 +538,36 @@ template ::value && - !std::is_same::value>::type* = - nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { + typename std::enable_if< + !std::is_same::value && + !std::is_same::value>:: + type* = nullptr> +void dispatch_moe_gemm_to_cutlass( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { if constexpr (std::is_same::value) { - if constexpr (WeightQuantTraits::kQuantMethod != cutlass::WintQuantMethod::kWeightOnlyInt2) { + if constexpr (WeightQuantTraits::kQuantMethod != + cutlass::WintQuantMethod::kWeightOnlyInt2) { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(32, 128, 64, 32, 32, 64); dispatch_gemm_config_macro(64, 128, 64, 64, 
64, 64); case CutlassTileConfig::Undefined: - throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( @@ -558,7 +582,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, } } else { throw std::runtime_error( - "[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support sm70."); + "[dispatch_moe_gemm_to_cutlass] weight_only_int2 does not support " + "sm70."); } } else { switch (gemm_config.tile_config) { @@ -574,7 +599,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, dispatch_gemm_config_macro(64, 128, 64, 64, 32, 64); dispatch_gemm_config_macro(256, 128, 64, 64, 64, 64); case CutlassTileConfig::Undefined: - throw std::runtime_error("[dispatch_moe_gemm_to_cutlass] gemm config undefined."); + throw std::runtime_error( + "[dispatch_moe_gemm_to_cutlass] gemm config undefined."); break; case CutlassTileConfig::ChooseWithHeuristic: throw std::runtime_error( @@ -597,22 +623,23 @@ template < typename arch, typename EpilogueTag, typename std::enable_if::value>::type* = nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, - const typename WeightQuantTraits::WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - const typename WeightQuantTraits::Arguments& quant_args_B, - CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +void dispatch_moe_gemm_to_cutlass( + const T* A, + const typename WeightQuantTraits::WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + const typename WeightQuantTraits::Arguments& quant_args_B, + CutlassGemmConfig 
gemm_config, + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { dispatch_gemm_config_macro(128, 128, 8, 64, 64, 8); case CutlassTileConfig::Undefined: @@ -659,33 +686,34 @@ void MoeGemmRunner::dispatch_to_arch( CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) { -#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \ +#define dispatch_moe_gemm_to_cutlass_macro(ARCH) \ dispatch_moe_gemm_to_cutlass( \ - A, \ - B, \ - weight_scales, \ - biases, \ - C, \ - total_rows_before_expert, \ - total_rows, \ - gemm_n, \ - gemm_k, \ - num_experts, \ - quant_args_B, \ - gemm_config, \ - sm_, \ - multi_processor_count_, \ - stream, \ + A, \ + B, \ + weight_scales, \ + biases, \ + C, \ + total_rows_before_expert, \ + total_rows, \ + gemm_n, \ + gemm_k, \ + num_experts, \ + quant_args_B, \ + gemm_config, \ + sm_, \ + multi_processor_count_, \ + stream, \ occupancy); if (sm_ >= 70 && sm_ < 75) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm70); } else if (sm_ >= 75 && sm_ < 80) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm75); - } else if (sm_ >= 80 && sm_ < 91) { + } else if (sm_ >= 80 && sm_ < 101) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm80); } else { - throw std::runtime_error("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + throw std::runtime_error( + "[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } @@ -705,7 +733,8 @@ void MoeGemmRunner::run_gemm( int num_experts, const typename WeightQuantTraits::Arguments& quant_args_B, cudaStream_t stream) { - static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool is_weight_only = + !std::is_same::value; static constexpr bool only_simt_configs = std::is_same::value; std::vector candidate_configs = @@ -776,7 +805,8 @@ void MoeGemmRunner::run_gemm( check_cuda_error(cudaEventElapsedTime(&elapsed, start, stop)); check_cuda_error(cudaEventDestroy(start)); 
check_cuda_error(cudaEventDestroy(stop)); - //std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " ms" << std::endl; + // std::cout << "[TUNING] config: " << ii << ", time: " << elapsed << " + // ms" << std::endl; if (elapsed < best_time) { best_id = ii; best_time = elapsed; @@ -789,7 +819,8 @@ void MoeGemmRunner::run_gemm( } } if (find_one) { - //std::cout << "[TUNING] best_config: " << best_id << ", time: " << best_time << " ms" << std::endl; + // std::cout << "[TUNING] best_config: " << best_id << ", time: " << + // best_time << " ms" << std::endl; gemmConfigManager.addBestConfig(gemmId, profile_total_rows, best_config); chosen_config = best_config; } else { diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h index cefcd666dcb..72ab5c4940e 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/base64_encode.h @@ -16,11 +16,12 @@ #include // Base64 编码表 -const std::string base64_chars = "Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg"; +const std::string base64_chars = + "Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg"; // 判断字符是否为有效的 Base64 字符 inline bool is_base64(unsigned char c) { - return (isalnum(c) || (c == '+') || (c == '/')); + return (isalnum(c) || (c == '+') || (c == '/')); } // Base64 编码函数 @@ -29,96 +30,104 @@ std::string base64_encode(const std::string &input); // Base64 解码函数 std::string base64_decode(const std::string &encoded_string); - // Base64 编码函数 std::string base64_encode(const std::string &input) { - std::string ret; - int i = 0; - int j = 0; - unsigned char char_array_3[3]; - unsigned char char_array_4[4]; - - for (const auto &c : input) { - char_array_3[i++] = c; - if (i == 3) { - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = 
((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; - - for (i = 0; i < 4; i++) { - ret += base64_chars[char_array_4[i]]; - } - i = 0; - } + std::string ret; + int i = 0; + int j = 0; + unsigned char char_array_3[3]; + unsigned char char_array_4[4]; + + for (const auto &c : input) { + char_array_3[i++] = c; + if (i == 3) { + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = + ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = + ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; + + for (i = 0; i < 4; i++) { + ret += base64_chars[char_array_4[i]]; + } + i = 0; } + } - if (i) { - for (j = i; j < 3; j++) { - char_array_3[j] = '\0'; - } + if (i) { + for (j = i; j < 3; j++) { + char_array_3[j] = '\0'; + } - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; + char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; + char_array_4[1] = + ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = + ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[3] = char_array_3[2] & 0x3f; - for (j = 0; j < i + 1; j++) { - ret += base64_chars[char_array_4[j]]; - } + for (j = 0; j < i + 1; j++) { + ret += base64_chars[char_array_4[j]]; + } - while (i++ < 3) { - ret += '='; - } + while (i++ < 3) { + ret += '='; } + } - return ret; + return ret; } // Base64 解码函数 std::string base64_decode(const std::string &encoded_string) { - int in_len = encoded_string.size(); - int i = 0; - int j = 0; - int in_ = 0; - unsigned char char_array_4[4], char_array_3[3]; - std::string ret; - - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) 
{ - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) { - for (i = 0; i < 4; i++) { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } - - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; i < 3; i++) { - ret += char_array_3[i]; - } - i = 0; - } + int in_len = encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + unsigned char char_array_4[4], char_array_3[3]; + std::string ret; + + while (in_len-- && (encoded_string[in_] != '=') && + is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; + in_++; + if (i == 4) { + for (i = 0; i < 4; i++) { + char_array_4[i] = base64_chars.find(char_array_4[i]); + } + + char_array_3[0] = + (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = + ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; i < 3; i++) { + ret += char_array_3[i]; + } + i = 0; } + } - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; - } + if (i) { + for (j = i; j < 4; j++) { + char_array_4[j] = 0; + } - for (j = 0; j < 4; j++) { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } + for (j = 0; j < 4; j++) { + char_array_4[j] = base64_chars.find(char_array_4[j]); + } - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = + ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (j = 0; j < i - 1; j++) { - ret += 
char_array_3[j]; - } + for (j = 0; j < i - 1; j++) { + ret += char_array_3[j]; } + } - return ret; + return ret; } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h index 7f45b8fd615..0927d327385 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cuda_utils.h @@ -32,338 +32,324 @@ // workspace for cublas gemm : 32MB #define CUBLAS_WORKSPACE_SIZE 33554432 -typedef struct __align__(4) -{ - half x, y, z, w; +typedef struct __align__(4) { + half x, y, z, w; } half4; /* **************************** type definition ***************************** */ enum CublasDataType { - FLOAT_DATATYPE = 0, - HALF_DATATYPE = 1, - BFLOAT16_DATATYPE = 2, - INT8_DATATYPE = 3, - FP8_DATATYPE = 4 + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 }; -enum FtCudaDataType { - FP32 = 0, - FP16 = 1, - BF16 = 2, - INT8 = 3, - FP8 = 4 -}; +enum FtCudaDataType { FP32 = 0, FP16 = 1, BF16 = 2, INT8 = 3, FP8 = 4 }; -enum class OperationType { - FP32, - FP16, - BF16, - INT8, - FP8 -}; +enum class OperationType { FP32, FP16, BF16, INT8, FP8 }; /* **************************** debug tools ********************************* */ -static const char* _cudaGetErrorEnum(cudaError_t error) -{ - return cudaGetErrorString(error); +static const char* _cudaGetErrorEnum(cudaError_t error) { + return cudaGetErrorString(error); } -static const char* _cudaGetErrorEnum(cublasStatus_t error) -{ - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; +static const char* _cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return 
"CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return ""; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; } -template -void check(T result, char const* const func, const char* const file, int const line) -{ - if (result) { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " - + file + ":" + std::to_string(line) + " \n"); - } +template +void check(T result, + char const* const func, + const char* const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } } #define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) #define check_cuda_error_2(val, file, line) check((val), #val, file, line) -inline void syncAndCheck(const char* const file, int const line) -{ - 
// When FT_DEBUG_LEVEL=DEBUG, must check error - static char* level_name = std::getenv("FT_DEBUG_LEVEL"); - if (level_name != nullptr) { - static std::string level = std::string(level_name); - if (level == "DEBUG") { - cudaDeviceSynchronize(); - cudaError_t result = cudaGetLastError(); - if (result) { - throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) - + " " + file + ":" + std::to_string(line) + " \n"); - } - std::cout<<"run syncAndCheck at "< -void print_to_file(const T* result, - const int size, - const char* file, - cudaStream_t stream = 0, +#define checkCUDNN(expression) \ + { \ + cudnnStatus_t status = (expression); \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::cerr << "Error on file " << __FILE__ << " line " << __LINE__ \ + << ": " << cudnnGetErrorString(status) << std::endl; \ + std::exit(EXIT_FAILURE); \ + } \ + } + +template +void print_to_file(const T* result, + const int size, + const char* file, + cudaStream_t stream = 0, std::ios::openmode open_mode = std::ios::out); -template -void print_abs_mean(const T* buf, uint size, cudaStream_t stream, std::string name = ""); +template +void print_abs_mean(const T* buf, + uint size, + cudaStream_t stream, + std::string name = ""); -template +template void print_to_screen(const T* result, const int size); -template +template void printMatrix(T* ptr, int m, int k, int stride, bool is_device_ptr); -void printMatrix(unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr); +void printMatrix( + unsigned long long* ptr, int m, int k, int stride, bool is_device_ptr); void printMatrix(int* ptr, int m, int k, int stride, bool is_device_ptr); void printMatrix(size_t* ptr, int m, int k, int stride, bool is_device_ptr); -template +template void check_max_val(const T* result, const int size); -template +template void check_abs_mean_val(const T* result, const int size); -#define PRINT_FUNC_NAME_() \ - do { \ - std::cout << "[FT][CALL] " << __FUNCTION__ 
<< " " << std::endl; \ - } while (0) - -[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "") -{ - throw std::runtime_error(std::string("[FT][ERROR] ") + info + " Assertion fail: " + file + ":" - + std::to_string(line) + " \n"); +#define PRINT_FUNC_NAME_() \ + do { \ + std::cout << "[FT][CALL] " << __FUNCTION__ << " " << std::endl; \ + } while (0) + +[[noreturn]] inline void throwRuntimeError(const char* const file, + int const line, + std::string const& info = "") { + throw std::runtime_error(std::string("[FT][ERROR] ") + info + + " Assertion fail: " + file + ":" + + std::to_string(line) + " \n"); } -inline void myAssert(bool result, const char* const file, int const line, std::string const& info = "") -{ - if (!result) { - throwRuntimeError(file, line, info); - } +inline void myAssert(bool result, + const char* const file, + int const line, + std::string const& info = "") { + if (!result) { + throwRuntimeError(file, line, info); + } } #define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) -#define FT_CHECK_WITH_INFO(val, info) \ - do { \ - bool is_valid_val = (val); \ - if (!is_valid_val) { \ - paddle::operators::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ - } \ - } while (0) +#define FT_CHECK_WITH_INFO(val, info) \ + do { \ + bool is_valid_val = (val); \ + if (!is_valid_val) { \ + paddle::operators::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ + } \ + } while (0) #define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info) #ifdef SPARSITY_ENABLED -#define CHECK_CUSPARSE(func) \ - { \ - cusparseStatus_t status = (func); \ - if (status != CUSPARSE_STATUS_SUCCESS) { \ - throw std::runtime_error(std::string("[FT][ERROR] CUSPARSE API failed at line ") \ - + std::to_string(__LINE__) + " in file " + __FILE__ + ": " \ - + cusparseGetErrorString(status) + " " + std::to_string(status)); \ - } \ - } +#define CHECK_CUSPARSE(func) \ + { \ + cusparseStatus_t status = (func); \ + if (status != 
CUSPARSE_STATUS_SUCCESS) { \ + throw std::runtime_error( \ + std::string("[FT][ERROR] CUSPARSE API failed at line ") + \ + std::to_string(__LINE__) + " in file " + __FILE__ + ": " + \ + cusparseGetErrorString(status) + " " + std::to_string(status)); \ + } \ + } #endif /*************Time Handling**************/ class CudaTimer { -private: - cudaEvent_t event_start_; - cudaEvent_t event_stop_; - cudaStream_t stream_; - -public: - explicit CudaTimer(cudaStream_t stream = 0) - { - stream_ = stream; - } - void start() - { - check_cuda_error(cudaEventCreate(&event_start_)); - check_cuda_error(cudaEventCreate(&event_stop_)); - check_cuda_error(cudaEventRecord(event_start_, stream_)); - } - float stop() - { - float time; - check_cuda_error(cudaEventRecord(event_stop_, stream_)); - check_cuda_error(cudaEventSynchronize(event_stop_)); - check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_)); - check_cuda_error(cudaEventDestroy(event_start_)); - check_cuda_error(cudaEventDestroy(event_stop_)); - return time; - } - ~CudaTimer() {} + private: + cudaEvent_t event_start_; + cudaEvent_t event_stop_; + cudaStream_t stream_; + + public: + explicit CudaTimer(cudaStream_t stream = 0) { stream_ = stream; } + void start() { + check_cuda_error(cudaEventCreate(&event_start_)); + check_cuda_error(cudaEventCreate(&event_stop_)); + check_cuda_error(cudaEventRecord(event_start_, stream_)); + } + float stop() { + float time; + check_cuda_error(cudaEventRecord(event_stop_, stream_)); + check_cuda_error(cudaEventSynchronize(event_stop_)); + check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_)); + check_cuda_error(cudaEventDestroy(event_start_)); + check_cuda_error(cudaEventDestroy(event_stop_)); + return time; + } + ~CudaTimer() {} }; -static double diffTime(timeval start, timeval end) -{ - return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001; +static double diffTime(timeval start, timeval end) { + return (end.tv_sec - 
start.tv_sec) * 1000 + + (end.tv_usec - start.tv_usec) * 0.001; } /* ***************************** common utils ****************************** */ -inline void print_mem_usage(std::string time = "after allocation") -{ - size_t free_bytes, total_bytes; - check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); - float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; - float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; - float used = total - free; - printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", time.c_str(), free, total, used); +inline void print_mem_usage(std::string time = "after allocation") { + size_t free_bytes, total_bytes; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; + float used = total - free; + printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", + time.c_str(), + free, + total, + used); } -inline int getSMVersion() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); - check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; +inline int getSMVersion() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; } -inline int getMaxSharedMemoryPerBlock() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - int max_shared_memory_size = 0; - check_cuda_error(cudaDeviceGetAttribute(&max_shared_memory_size, 
cudaDevAttrMaxSharedMemoryPerBlock, device)); - return max_shared_memory_size; +inline int getMaxSharedMemoryPerBlock() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int max_shared_memory_size = 0; + check_cuda_error(cudaDeviceGetAttribute( + &max_shared_memory_size, cudaDevAttrMaxSharedMemoryPerBlock, device)); + return max_shared_memory_size; } -inline std::string getDeviceName() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - cudaDeviceProp props; - check_cuda_error(cudaGetDeviceProperties(&props, device)); - return std::string(props.name); +inline std::string getDeviceName() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + cudaDeviceProp props; + check_cuda_error(cudaGetDeviceProperties(&props, device)); + return std::string(props.name); } -inline int div_up(int a, int n) -{ - return (a + n - 1) / n; -} +inline int div_up(int a, int n) { return (a + n - 1) / n; } cudaError_t getSetDevice(int i_device, int* o_device = NULL); -inline int getDevice() -{ - int current_dev_id = 0; - check_cuda_error(cudaGetDevice(¤t_dev_id)); - return current_dev_id; +inline int getDevice() { + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; } -inline int getDeviceCount() -{ - int count = 0; - check_cuda_error(cudaGetDeviceCount(&count)); - return count; +inline int getDeviceCount() { + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; } -template -CublasDataType getCublasDataType() -{ - if (std::is_same::value) { - return HALF_DATATYPE; - } - else if (std::is_same::value) { - return FLOAT_DATATYPE; - } - else { - FT_CHECK(false); - return FLOAT_DATATYPE; - } +template +CublasDataType getCublasDataType() { + if (std::is_same::value) { + return HALF_DATATYPE; + } else if (std::is_same::value) { + return FLOAT_DATATYPE; + } else { + FT_CHECK(false); + return FLOAT_DATATYPE; + } } -template -cudaDataType_t getCudaDataType() -{ - if (std::is_same::value) { 
- return CUDA_R_16F; - } - - else if (std::is_same::value) { - return CUDA_R_32F; - } - else { - FT_CHECK(false); - return CUDA_R_32F; - } +template +cudaDataType_t getCudaDataType() { + if (std::is_same::value) { + return CUDA_R_16F; + } + + else if (std::is_same::value) { + return CUDA_R_32F; + } else { + FT_CHECK(false); + return CUDA_R_32F; + } } -template +template struct getTypeFromCudaDataType { - using Type = float; + using Type = float; }; -template<> +template <> struct getTypeFromCudaDataType { - using Type = half; + using Type = half; }; - // clang-format off template struct packed_type; template <> struct packed_type { using type = float; }; // we don't need to pack float by default @@ -390,59 +376,75 @@ inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } // clang-format on -template -void compareTwoTensor( - const T1* pred, const T2* ref, const int size, const int print_size = 0, const std::string filename = "") -{ - T1* h_pred = new T1[size]; - T2* h_ref = new T2[size]; - check_cuda_error(cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost)); - check_cuda_error(cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost)); - - FILE* fd = nullptr; - if (filename != "") { - fd = fopen(filename.c_str(), "w"); - fprintf(fd, "| %10s | %10s | %10s | %10s | \n", "pred", "ref", "abs_diff", "rel_diff(%)"); - } - - if (print_size > 0) { - std::cout<<" id | pred | ref |abs diff | rel diff (%) |"< +void compareTwoTensor(const T1* pred, + const T2* ref, + const int size, + const int print_size = 0, + const std::string filename = "") { + T1* h_pred = new T1[size]; + T2* h_ref = new T2[size]; + check_cuda_error( + cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost)); + check_cuda_error( + cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost)); + + FILE* fd = nullptr; + if (filename != "") { + fd 
= fopen(filename.c_str(), "w"); + fprintf(fd, + "| %10s | %10s | %10s | %10s | \n", + "pred", + "ref", + "abs_diff", + "rel_diff(%)"); + } + + if (print_size > 0) { + std::cout << " id | pred | ref |abs diff | rel diff (%) |" + << std::endl; + } + float mean_abs_diff = 0.0f; + float mean_rel_diff = 0.0f; + int count = 0; + for (int i = 0; i < size; i++) { + if (i < print_size) { + std::cout << i << " | " << (float)h_pred[i] << " | " << (float)h_ref[i] + << " | " << (abs((float)h_pred[i] - (float)h_ref[i])) << " | " + << (abs((float)h_pred[i] - (float)h_ref[i]) / + (abs((float)h_ref[i]) + 1e-6f) * 100.f) + << " | " << std::endl; } - float mean_abs_diff = 0.0f; - float mean_rel_diff = 0.0f; - int count = 0; - for (int i = 0; i < size; i++) { - if (i < print_size) { - std::cout< -struct TagOperator -{ - using TaggedOperator = MmaOp; +struct TagOperator { + using TaggedOperator = MmaOp; }; // Specializations below attach more information to the operator template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; }; template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; +struct TagOperator { + using TaggedOperator = + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale; }; - -// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original -// operator + the extra information. If no extra info was tagged, the dequant op per column scaling -// as a default. +// Here we instantiate some structs to "detag" the tagged operator. It splits it +// back to the original operator + the extra information. If no extra info was +// tagged, the dequant op per column scaling as a default. 
template -struct DetagOperator -{ - using Operator = TaggedMmaOp; - static constexpr bool FineGrained = false; +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr bool FineGrained = false; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr bool FineGrained = false; +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = false; }; template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr bool FineGrained = true; +struct DetagOperator< + OpMultiplyAddDequantizeInterleavedBToA_fine_grained_scale> { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr bool FineGrained = true; }; - -} // namespace arch -} // namespace cutlass +} // namespace arch +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h index 4f613a28cd2..58d294745ab 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/arch/mma_sm80.h @@ -19,18 +19,16 @@ namespace cutlass { namespace arch { template <> -struct Mma< - gemm::GemmShape<16,8,16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,16>; +struct Mma, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + using Shape = gemm::GemmShape<16, 8, 16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -50,13 +48,10 @@ struct Mma< /// Computes multiply-add CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - + void 
operator()(FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); @@ -65,10 +60,16 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, " + "{%4,%5}, {%6}, " "{%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + : "r"(A[0]), + "r"(A[1]), + "r"(B), + "r"(C[0]), + "r"(C[1]), + "r"(C[2]), "r"(C[3])); #else @@ -82,18 +83,16 @@ struct Mma< }; template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16,8,32>; +struct Mma, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + using Shape = gemm::GemmShape<16, 8, 32>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -112,13 +111,10 @@ struct Mma< /// Computes multiply-add CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - + void operator()(FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); @@ -128,11 +124,20 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " "{%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), 
"r"(C[2]), "r"(C[3])); + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(C[0]), + "r"(C[1]), + "r"(C[2]), + "r"(C[3])); #else assert(0); @@ -140,6 +145,5 @@ struct Mma< } }; - -} -} +} // namespace arch +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h index 38a253d93dc..ced7ae1c1ad 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h @@ -37,10 +37,10 @@ namespace epilogue { // define scaling mode enum class QuantMode { - PerTensorQuant, - PerTokenQuant, - PerChannelQuant, - PerTokenChannelQuant + PerTensorQuant, + PerTokenQuant, + PerChannelQuant, + PerTokenChannelQuant }; } // namespace epilogue diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h index 1d7abaabdf4..8f89790a3d0 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale_nf4.h @@ -52,8 +52,9 @@ namespace cutlass { namespace epilogue { namespace threadblock { template -[[gnu::warning("your type here")]] -bool print_type_1111() { return false; } +[[gnu::warning("your type here")]] bool print_type_1111() { + return false; +} template (&fragment_D_)[frag_idx]; - output = output_converter(result); + output = output_converter(result); /* // Convert to the output, with non zero C added */ // NumericArrayConverter // output_converter; @@ -341,7 
+344,6 @@ class EpilogueVisitorPerRowPerColNf4 { // OutputVector& output = // reinterpret_cast(&fragment_D_)[frag_idx]; - // OutputVector& vector_c = // reinterpret_cast(&fragment_C_)[frag_idx]; @@ -353,9 +355,9 @@ class EpilogueVisitorPerRowPerColNf4 { /// Called after accumulators have been exchanged for each accumulator vector CUTLASS_DEVICE - void visit(AccumulatorFragment const &accum, + void visit(AccumulatorFragment const& accum, int reduce_fragment_idx, - OutputTileIterator &destination_iterator) { + OutputTileIterator& destination_iterator) { NumericArrayConverter @@ -368,7 +370,8 @@ class EpilogueVisitorPerRowPerColNf4 { ComputeFragment result = source_converter(accum); // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -401,7 +404,8 @@ class EpilogueVisitorPerRowPerColNf4 { } // just for bug, pass // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -428,7 +432,8 @@ class EpilogueVisitorPerRowPerColNf4 { // auto result_tmp = output_converter(result); // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -452,8 +457,10 @@ class 
EpilogueVisitorPerRowPerColNf4 { typename OutputTileIterator::Fragment output_fragment; CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii(result[ii]); + for (int ii = 0; ii < output_fragment.size(); ++ii) { + output_fragment[ii] = + static_cast( + result[ii]); } // OutputVector& output = // reinterpret_cast(&output_fragment)[0]; @@ -464,7 +471,8 @@ class EpilogueVisitorPerRowPerColNf4 { // } // if(threadIdx.x<32){ - // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### %d-%d-%d--%d-%d-%d, reduced accu:%d-%d-%d-%d-%d-%d-%d-%d, + // dequant: accu:%f-%f-%f-%f-%f-%f-%f-%f \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // accum[0], @@ -545,7 +553,9 @@ class EpilogueVisitorPerRowPerColNf4 { ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) { // if(threadIdx.x<32){ - // printf("#### per_token_channel_scale_accumulator, %d-%d-%d--%d-%d-%d, quanted accu:%f-%f-%f-%f-%f-%f-%f-%f, scale_col:%f-%f-%f-%f-%f-%f-%f-%f \n", + // printf("#### per_token_channel_scale_accumulator, %d-%d-%d--%d-%d-%d, + // quanted accu:%f-%f-%f-%f-%f-%f-%f-%f, scale_col:%f-%f-%f-%f-%f-%f-%f-%f + // \n", // blockIdx.x,blockIdx.y,blockIdx.z, // threadIdx.x,threadIdx.y,threadIdx.z, // static_cast(accum[0]), diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h index 40da912df71..b2270332bde 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -90,28 +90,23 @@ namespace threadblock { namespace detail { -template < - typename ElementOutput, - typename ElementAccumulator, - int ElementsPerAccess, - typename 
ThreadblockShape, - typename WarpShape, - typename InstructionShape, - typename ThreadMap -> +template struct Nf4DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOp; - using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< - WarpShape, - InstructionShape, - ElementAccumulator, - layout::RowMajor - >; - - using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator< - ThreadMap, - ElementAccumulator - >; + using SharedLoadIterator = + cutlass::epilogue::threadblock::SharedLoadIterator; static int const kFragmentsPerIteration = 1; }; @@ -123,12 +118,12 @@ template struct Nf4DefaultIteratorsTensorOp { + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp struct Nf4DefaultIteratorsTensorOp { + int32_t, + 8, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp class SharedLoadIteratorMixed { public: @@ -325,19 +320,14 @@ class SharedLoadIteratorMixed { void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); } }; - - -template < - typename Shape_, - typename WarpMmaTensorOp_, - int PartitionsK, - typename OutputOp_, - int ElementsPerAccess, - bool ScatterD = false, - typename PermuteDLayout = layout::NoPermute -> +template struct DequantEpilogueTensorOp { - using Shape = Shape_; using WarpMmaTensorOp = WarpMmaTensorOp_; static int const kPartitionsK = PartitionsK; @@ -352,76 +342,76 @@ struct DequantEpilogueTensorOp { // Thread map // - using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< - Shape, - typename WarpMmaTensorOp::Shape, - kPartitionsK, - ElementOutput, - kElementsPerAccess - >::Type; - - static bool const UseCUDAStore = platform::is_same::value; - - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< - 
OutputTileThreadMap, - ElementOutput, - ScatterD, - PermuteDLayout, - UseCUDAStore - >; - - using AccumulatorFragmentIterator = typename platform::conditional::value, - cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< - typename WarpMmaTensorOp::Shape, - typename WarpMmaTensorOp::Policy::Operator::Shape, - typename WarpMmaTensorOp::Policy::Operator::ElementC, - typename WarpMmaTensorOp::Policy::Operator::FragmentC, - LayoutC>, - cutlass::epilogue::warp::FragmentIteratorTensorOp< - typename WarpMmaTensorOp::Shape, - typename WarpMmaTensorOp::Policy::Operator::Shape, - typename WarpMmaTensorOp::Policy::Operator::ElementC, - typename WarpMmaTensorOp::Policy::Operator::FragmentC, - LayoutC> >::type; + using OutputTileThreadMap = + typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess>::Type; + + static bool const UseCUDAStore = + platform::is_same::value; + + using OutputTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout, + UseCUDAStore>; + + using AccumulatorFragmentIterator = typename platform::conditional< + is_complex::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>>::type; /// Support several implementations depending on structure of epilogue using DefaultIterators = detail::Nf4DefaultIteratorsTensorOp< - ElementOutput, - ElementAccumulator, - kElementsPerAccess, - 
Shape, - typename WarpMmaTensorOp::Shape, - typename WarpMmaTensorOp::Policy::Operator::Shape, - typename OutputTileThreadMap::CompactedThreadMap - >; + ElementOutput, + ElementAccumulator, + kElementsPerAccess, + Shape, + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename OutputTileThreadMap::CompactedThreadMap>; using WarpTileIterator = typename DefaultIterators::WarpTileIterator; using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator; /// Hard-coded padding elements added - using Padding = cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; + using Padding = + cutlass::MatrixShape<0, 64 / sizeof_bits::value * 4>; - static int const kFragmentsPerIteration = (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); + static int const kFragmentsPerIteration = + (kPartitionsK == 1 ? DefaultIterators::kFragmentsPerIteration : 1); // // Define the epilogue // - using Epilogue = cutlass::epilogue::threadblock::Epilogue< - Shape, - WarpMmaTensorOp, - kPartitionsK, - OutputTileIterator, - AccumulatorFragmentIterator, - WarpTileIterator, - SharedLoadIterator, - OutputOp, - Padding, - kFragmentsPerIteration - >; + using Epilogue = + cutlass::epilogue::threadblock::Epilogue; }; - - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h index 702ff05d409..fa77599f127 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/ft_gemm_configs.h @@ -35,39 +35,39 @@ limitations under the License. */ // in the kernel layout details when doing weight only quantization. 
enum class CutlassTileConfig { // Signals that we should run heuristics do choose a config - Undefined, // 0 + Undefined, // 0 // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, // 1 + ChooseWithHeuristic, // 1 // SiMT config - CtaShape128x128x8_WarpShape64x64x8, // 2 + CtaShape128x128x8_WarpShape64x64x8, // 2 // TensorCore configs CTA_N = 128, CTA_K = 64 // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, // 3 - CtaShape16x256x64_WarpShape16x64x64, // 4 + CtaShape16x128x64_WarpShape16x32x64, // 3 + CtaShape16x256x64_WarpShape16x64x64, // 4 // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, // 5 + CtaShape32x128x64_WarpShape32x32x64, // 5 // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, // 6 - CtaShape64x128x64_WarpShape64x32x64, // 7 + CtaShape64x128x64_WarpShape32x64x64, // 6 + CtaShape64x128x64_WarpShape64x32x64, // 7 // Warp configs for M=128 - CtaShape128x128x64_WarpShape64x32x64, // 8 - CtaShape128x128x64_WarpShape128x32x64, // 9 + CtaShape128x128x64_WarpShape64x32x64, // 8 + CtaShape128x128x64_WarpShape128x32x64, // 9 // configs for large M in encoder - CtaShape128x256x64_WarpShape64x64x64, // 10 - CtaShape256x128x64_WarpShape64x64x64, // 11 + CtaShape128x256x64_WarpShape64x64x64, // 10 + CtaShape256x128x64_WarpShape64x64x64, // 11 }; enum class SplitKStyle { - NO_SPLIT_K, //0 - SPLIT_K_SERIAL, //1 - SPLIT_K_STREAM, //2 + NO_SPLIT_K, // 0 + SPLIT_K_SERIAL, // 1 + SPLIT_K_STREAM, // 2 // SPLIT_K_PARALLEL // Not supported yet }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h index c3ec38968c7..12b78b8f038 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h +++ 
b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_dequant_gemm_nf4.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,25 +18,27 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief - Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with - the appropriate threadblock-scoped epilogue. + Default kernel-level GEMM definitions combine threadblock-scoped matrix + multiply-add with the appropriate threadblock-scoped epilogue. - Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are - accommodated by exchanging A and B operands and assuming transposed layouts. Partial - specializations here choose 'device::GemmTransposed' to implement this functionality. + Note, CUTLASS epilogues universally target row-major outputs. Column-major + outputs are accommodated by exchanging A and B operands and assuming + transposed layouts. Partial specializations here choose + 'device::GemmTransposed' to implement this functionality. 
*/ #pragma once @@ -101,8 +103,7 @@ template < /// Permute result D typename PermuteDLayout = layout::NoPermute, /// - typename Enable = void -> + typename Enable = void> struct DefaultDequantGemm; /////////////////////////////////////////////////// @@ -152,40 +153,77 @@ template < /// Scatter result D by using an index array bool ScatterD, /// Permute result D - typename PermuteDLayout -> -struct DefaultDequantGemm { - - static_assert(platform::is_same::value - || platform::is_same>::value, - "Epilogue in the kernel level must be row major"); + typename PermuteDLayout> +struct DefaultDequantGemm { + static_assert(platform::is_same::value || + platform::is_same>::value, + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate - using Mma = typename cutlass::gemm::threadblock::DefaultMma< - ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, - ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; + using Mma = + typename cutlass::gemm::threadblock::DefaultMma::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; /// Define the epilogue - using RegularEpilogue = typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< - ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ScatterD, + PermuteDLayout>::Epilogue; using Epilogue = RegularEpilogue; /// Define the kernel-level GEMM operator. 
- using GemmKernel = kernel::Gemm; + using GemmKernel = + kernel::Gemm; }; - - template < /// Element type for A matrix operand typename ElementA_, @@ -235,8 +273,7 @@ template < /// Scatter result D by using an index array bool ScatterD = false, /// Permute result D - typename PermuteDLayout = layout::NoPermute -> + typename PermuteDLayout = layout::NoPermute> struct DefaultInt8InterleavedGemm; /// Partial specialization for Ampere @@ -275,54 +312,69 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator -> + typename Operator> struct DefaultInt8InterleavedGemm { - + LayoutA, + kAlignmentA, + int8_t, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + SplitKSerial, + Operator> { static_assert(platform::is_same::value, - "Epilogue in the kernel level must be row major"); + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaulInt8Nf4InterleavedMma< - int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, OperatorClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, SharedMemoryClearOption::kNone>::ThreadblockMma; - - + int8_t, + LayoutA, + kAlignmentA, + int8_t, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + false, + SharedMemoryClearOption::kNone>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; /// Define the epilogue - using RegularEpilogue = typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< - ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount, false, 
layout::NoPermute>::Epilogue; + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute>::Epilogue; using Epilogue = RegularEpilogue; /// Define the kernel-level GEMM operator. - using GemmKernel = kernel::Gemm; + using GemmKernel = + kernel::Gemm; }; - /// Partial specialization for Ampere template < /// Layout type for A matrix operand @@ -359,51 +411,67 @@ template < /// epilogue bool SplitKSerial, /// Operation performed by GEMM - typename Operator -> + typename Operator> struct DefaultInt8InterleavedGemm { - + LayoutA, + kAlignmentA, + cutlass::uint4b_t, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + SplitKSerial, + Operator> { static_assert(platform::is_same::value, - "Epilogue in the kernel level must be row major"); + "Epilogue in the kernel level must be row major"); /// Define the threadblock-scoped matrix multiply-accumulate using Mma = typename cutlass::gemm::threadblock::DefaulInt8Nf4InterleavedMma< - int8_t, LayoutA, kAlignmentA, cutlass::uint4b_t, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, OperatorClass, ArchTag, - ThreadblockShape, WarpShape, InstructionShape, Stages, - Operator, false, SharedMemoryClearOption::kNone>::ThreadblockMma; - - + int8_t, + LayoutA, + kAlignmentA, + cutlass::uint4b_t, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + false, + SharedMemoryClearOption::kNone>::ThreadblockMma; static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; /// Define the epilogue - using RegularEpilogue = typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< - 
ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount, false, layout::NoPermute>::Epilogue; + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DequantEpilogueTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute>::Epilogue; using Epilogue = RegularEpilogue; /// Define the kernel-level GEMM operator. - using GemmKernel = kernel::Gemm; + using GemmKernel = + kernel::Gemm; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h index 2a13073f1f6..9c69c037f84 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/default_intA_nf4B_traits.h @@ -44,56 +44,63 @@ namespace cutlass { namespace gemm { namespace kernel { -template -struct Int8Nf4GemmArchTraits { -}; - -template +template +struct Int8Nf4GemmArchTraits {}; + +template struct Int8Nf4GemmArchTraits { - static constexpr int Stages = 2; - using OperatorClass = cutlass::arch::OpClassSimt; - using AccType = float; - using LayoutB = cutlass::layout::RowMajor; - - static constexpr int ElementsPerAccessA = 1; - static constexpr int ElementsPerAccessB = 1; - static constexpr int ElementsPerAccessC = 1; - static constexpr int ThreadblockK = 8; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - - using Operator = cutlass::arch::OpMultiplyAdd; + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int 
ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; }; // ======================= Ampere Traits ============================== -template -struct Int8Nf4GemmArchTraits::value || - cutlass::platform::is_same::value>::type> { -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = int32_t; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - // static_assert(cutlass::platform::is_same::value, - // "input type must be int8_t"); - // static_assert((ElementsPerAccessA == 16), "====="); - // static_assert((ElementsPerAccessB == 16), "====="); - // static_assert((ElementsPerAccessC == 8), "====="); - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - // using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - using Operator = typename LayoutDetails::Operator; +template +struct Int8Nf4GemmArchTraits< + IntAType, + IntBType, + OutType, + cutlass::arch::Sm80, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = int32_t; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static 
constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + // static_assert(cutlass::platform::is_same::value, + // "input type must be int8_t"); + // static_assert((ElementsPerAccessA == 16), "====="); + // static_assert((ElementsPerAccessB == 16), "====="); + // static_assert((ElementsPerAccessC == 8), "====="); + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + // using InstructionShape = cutlass::gemm::GemmShape<16, 8, + // 16>; + using Operator = typename LayoutDetails::Operator; }; } // namespace kernel diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h index 574f9f20566..922e64d6c60 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h @@ -290,7 +290,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { // ceil_div(args.problem_size.k(), args.batch_count), kAlignK); // if (gemm_k_size) { - // grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + // grid_tiled_shape.k() = ceil_div(args.problem_size.k(), + // gemm_k_size); // } // } @@ -318,7 +319,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { /// Determines whether kernel satisfies alignment static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { - CUTLASS_TRACE_HOST("GemmWithEpilogueVisitorInterleavedNf4::can_implement()"); + CUTLASS_TRACE_HOST( + "GemmWithEpilogueVisitorInterleavedNf4::can_implement()"); static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; @@ -440,16 +442,17 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { 
params.ptr_B)[threadblock_tile_offset.k()]; } #endif - // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, offset_k:%d, threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // offset_k, - // threadblock_tile_offset.m(), - // threadblock_tile_offset.n(), - // threadblock_tile_offset.k(), - // params.gemm_k_size - // ); - // } + // if(threadIdx.x==0){ + // printf("##### block: %d-%d-%d, offset_k:%d, + // threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // offset_k, + // threadblock_tile_offset.m(), + // threadblock_tile_offset.n(), + // threadblock_tile_offset.k(), + // params.gemm_k_size + // ); + // } // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ @@ -460,7 +463,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { // cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * // Mma::Shape::kN}; // printf("#### kInterleave:%d \n", kInterleave); - // printf("###### offset_k : %d; params.gemm_k_size:%d; threadblock_tile_offset.k():%d \n", + // printf("###### offset_k : %d; params.gemm_k_size:%d; + // threadblock_tile_offset.k():%d \n", // offset_k, // params.gemm_k_size, // threadblock_tile_offset.k() @@ -470,7 +474,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { offset_k * kInterleave, threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, tb_offset_B:%d-%d, kInterleave:%d, Mma::IteratorB::Shape::kRow:%d, Mma::Shape::kK:%d \n", + // printf("##### block: %d-%d-%d, tb_offset_B:%d-%d, kInterleave:%d, + // Mma::IteratorB::Shape::kRow:%d, Mma::Shape::kK:%d \n", // blockIdx.x, blockIdx.y, blockIdx.z, // offset_k * kInterleave, // threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave, @@ -500,13 +505,11 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { thread_idx, tb_offset_B); typename Mma::IteratorNF4LookUpTable 
iterator_nf4_look_up_table = - Mma::IteratorNF4LookUpTable( - params.params_nf4_look_up_table, - params.ref_nf4_look_up_table.data(), - {0,16}, - threadIdx.x, - {0,0} - ); + Mma::IteratorNF4LookUpTable(params.params_nf4_look_up_table, + params.ref_nf4_look_up_table.data(), + {0, 16}, + threadIdx.x, + {0, 0}); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -530,7 +533,12 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; // printf("#### gemm_k_iterations: %d \n", gemm_k_iterations); // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_nf4_look_up_table, accumulators); + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_nf4_look_up_table, + accumulators); // if(threadIdx.x==0){ // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", // blockIdx.x, blockIdx.y, blockIdx.z, @@ -552,7 +560,8 @@ struct GemmWithEpilogueVisitorInterleavedNf4 { threadblock_tile_offset.n() * Mma::Shape::kN); // int block_idx = threadblock_tile_offset.m() + - // threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + // threadblock_tile_offset.n() * + // params.grid_tiled_shape.m(); // // Construct the epilogue visitor diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 35469f7ad98..b09368a8486 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -29,9 +29,10 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ /* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. + This file exists so that we use the same weight layout for MoE grouped gemm + and regular gemm when the weight is quantized. The preprocessing code reads + this template to know how to organize the quantized weight matrices to be + consumed by CUTLASS. Note that for int4, ThreadBlockK MUST be 64. @@ -53,79 +54,108 @@ namespace cutlass { namespace gemm { namespace kernel { -template -struct LayoutDetailsB { -}; +template +struct LayoutDetailsB {}; -// // Volta specialiations. Volta will dequantize before STS, so we need a different operator -template +// // Volta specialiations. Volta will dequantize before STS, so we need a +// different operator +template struct LayoutDetailsB { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 8; - using Operator = cutlass::arch::OpMultiplyAdd; + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; }; -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +// Specializations for Turing+ when B is FP16. These are currently only used for +// MoE networks. +// TODO - Switch this to column major for weights since gemms should be more +// performant. 
+template +struct LayoutDetailsB< + half_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; +template +struct LayoutDetailsB< + bfloat16_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; }; -// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, -// which signals that we want to dequantize after loading from smem. -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +// Specializations for Turing+ when B is quantized. These can use the operator +// OpMultiplyAddDequantizeInterleavedBToA, which signals that we want to +// dequantize after loading from smem. 
+template +struct LayoutDetailsB< + uint8_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +template +struct LayoutDetailsB< + uint4b_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; // For int8 int8 int32 Gemm. 
Author(zhengzekang) -template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +template +struct LayoutDetailsB< + int8_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; } // namespace kernel diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h index 107251cd84d..bc8ab30bd17 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_mma_nf4_int8_interleaved.h @@ -64,7 +64,7 @@ template < typename InstructionShape_, /// Number of stages used in the pipelined mainloop int Stages, - /// Operation perfomed by GEMM + /// Operation performed by GEMM typename Operator, /// Store the accumulators in row major or column major. Row major is used /// when output layout is interleaved. 
@@ -74,12 +74,11 @@ template < /// Gather operand A by using an index array bool GatherA = false, /// Gather operand B by using an index array - bool GatherB = false - > + bool GatherB = false> struct DefaulInt8Nf4InterleavedMma; // int8 int8 int32 Gemm specialization. Author(zhengzekang) -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements @@ -101,24 +100,6 @@ template< /// Operation performed by GEMM typename Operator> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + Operator> { + private: + using Mma = Int8Nf4InterleavedMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements @@ -174,26 +171,6 @@ template< /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define 
iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + false, + SharedMemoryClear> { + private: + using Mma = Int8Nf4InterleavedMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; - - - // int8 int4 int32 -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A matrix in units of elements @@ -250,24 +242,6 @@ template< /// Operation performed by GEMM typename Operator> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + Operator> { + private: + using Mma = Int8Nf4InterleavedMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; -template< +template < /// Layout type for A matrix operand typename LayoutA, /// Access granularity of A 
matrix in units of elements @@ -323,26 +313,6 @@ template< /// Shared memory clear option SharedMemoryClearOption SharedMemoryClear> struct DefaulInt8Nf4InterleavedMma { - -private: - - using Mma = Int8Nf4InterleavedMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; + false, + SharedMemoryClear> { + private: + using Mma = Int8Nf4InterleavedMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h index e3cb9da1bfa..5c48abd7cc2 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma.h @@ -38,57 +38,67 @@ namespace gemm { namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -// // We need to distinguish here, since we want volta support. It is too much effort -// // to write shared memory iterators that are probably needed for volta to function -// // properly. 
As a result, we allow converters both after the LDG (for volta) and after +// // We need to distinguish here, since we want volta support. It is too much +// effort +// // to write shared memory iterators that are probably needed for volta to +// function +// // properly. As a result, we allow converters both after the LDG (for volta) +// and after // // the LDS for Turing+. -template< +template < /// Iterator for B matrix in global memory typename IteratorB, /// Warp level Mma typename MmaOperator, /// Math operation perform by warp level operator typename MathOperator> -struct SetConvertersInt8Nf4Interleaved { -}; +struct SetConvertersInt8Nf4Interleaved {}; // // Dequantize after LDG, so set transforms accordingly -template< +template < /// Iterator for B matrix in global memory typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConvertersInt8Nf4Interleaved { - using TransformAfterLDG = - FastInterleavedAndBiasedNumericArrayConverterNf4; - - using TransformAfterLDS = NumericArrayConverter; +struct SetConvertersInt8Nf4Interleaved { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverterNf4< + typename MmaOperator::ArchMmaOperator::ElementB, + typename IteratorB::Element, + IteratorB::Fragment::kElements>; + + using TransformAfterLDS = + NumericArrayConverter; }; // Dequantize after LDS, so set transforms accordingly -template< +template < /// Iterator for B matrix in global memory typename IteratorB, /// Mma Policy typename MmaOperator> -struct SetConvertersInt8Nf4Interleaved { - using TransformAfterLDG = - NumericArrayConverter; - - using TransformAfterLDS = - FastInterleavedAndBiasedNumericArrayConverterNf4; +struct SetConvertersInt8Nf4Interleaved< + IteratorB, + MmaOperator, + arch::OpMultiplyAddDequantizeInterleavedBToA> { + using TransformAfterLDG = + NumericArrayConverter; + + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverterNf4< + typename MmaOperator::ArchMmaOperator::ElementB, + typename 
TransformAfterLDG::result_type::Element, + MmaOperator::FragmentB::kElements>; }; //////////////////////////////////////////////////////////////////////////////// -template< +template < /// Element type for A matrix operand typename ElementA_, /// Layout type for A matrix operand diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h index b2bc5d25114..aa3a04d9ac4 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/default_nf4_int8_interleaved_mma_multistage.h @@ -46,7 +46,7 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -template< +template < /// Type for elementA typename ElementA, /// Layout type for A matrix operand @@ -78,120 +78,134 @@ template< /// SharedMemoryClearOption SharedMemoryClear> struct Int8Nf4InterleavedMma= 80)>::type> { - - static_assert(platform::is_same::value, - "Element A must be in8t"); - - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be int8 or uint4"); - static_assert(WarpShape::kK!=0,""); - static_assert(ThreadblockShape::kK!=0,""); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? 
- cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementB, - LayoutB, - 0, - ThreadMapB, - AccessTypeB>; - static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; - using AccessTypeNF4LookUpTable = cutlass::Array; - using IteratorNF4LookUpTableThreadMap = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape<4, 1>, - 16 / kAlignmentNF4LookUpTable, - kAlignmentNF4LookUpTable>; - using IteratorNF4LookUpTable = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape<1, 16>, - int32_t, - cutlass::layout::RowMajor, - 0, - IteratorNF4LookUpTableThreadMap, - AccessTypeNF4LookUpTable>; - - using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; - - using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage; + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + SharedMemoryClear, + typename 
platform::enable_if<( + ArchTag::kMinComputeCapability >= 80)>::type> { + static_assert(platform::is_same::value, + "Element A must be in8t"); + + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be int8 or uint4"); + static_assert(WarpShape::kK != 0, ""); + static_assert(ThreadblockShape::kK != 0, ""); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + ThreadMapB, + AccessTypeB>; + static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; + using AccessTypeNF4LookUpTable = + cutlass::Array; + using IteratorNF4LookUpTableThreadMap = + transform::PitchLinearStripminedThreadMap, + 16 / kAlignmentNF4LookUpTable, + kAlignmentNF4LookUpTable>; + using IteratorNF4LookUpTable = + 
cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape<1, 16>, + int32_t, + cutlass::layout::RowMajor, + 0, + IteratorNF4LookUpTableThreadMap, + AccessTypeNF4LookUpTable>; + + using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; + + using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4< + ElementA, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorNF4LookUpTable, + SmemIteratorNF4LookUpTable, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + SharedMemoryClear>; }; -template< +template < /// Type for element A typename ElementA, /// Layout type for A matrix operand @@ -224,140 +238,156 @@ template< int RowsPerTile, /// int ColumnsInterleaved> -struct Int8Nf4InterleavedMma, - kAlignmentB, - ElementAccumulator, - layout::RowMajor, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - kStages, - Operator, - SharedMemoryClear, - typename platform::enable_if<(ArchTag::kMinComputeCapability >= 80)>::type> { - - static_assert(platform::is_same::value , - "Element A int8_t"); - - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - // static_assert(platform::is_same::value || platform::is_same::value, - // "Element B must be uint8 or uint4"); - static_assert(platform::is_same::value, - "Element B must be uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) ? 
- cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) ? - cutlass::arch::CacheOperation::Global : - cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, - ElementA, - LayoutA, - 1, - ThreadMapA, - AccessTypeA>; - -private: - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - // static_assert(ColumnsInterleaved==4 || cutlass::platform::is_same::value, "####"); - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape = - MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, - OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock:: - PredicatedTileAccessIterator; - - static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; - using AccessTypeNF4LookUpTable = cutlass::Array; - using IteratorNF4LookUpTableThreadMap = transform::PitchLinearStripminedThreadMap< - 
layout::PitchLinearShape<4, 1>, - 1, - 4>; - using IteratorNF4LookUpTable = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape<1, 16>, - int32_t, - cutlass::layout::RowMajor, - 0, - IteratorNF4LookUpTableThreadMap, - AccessTypeNF4LookUpTable>; - - - using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; - - // static_assert(MmaCore::MmaPolicy::Operator::FragmentB::kElements==64,"MmaCore::MmaPolicy::Operator::FragmentB::kElements == 32"); - using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage; +struct Int8Nf4InterleavedMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + layout::ColumnMajorTileInterleave, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + SharedMemoryClear, + typename platform::enable_if<(ArchTag::kMinComputeCapability >= + 80)>::type> { + static_assert(platform::is_same::value, "Element A int8_t"); + + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + // static_assert(platform::is_same::value || + // platform::is_same::value, + // "Element B must be uint8 or uint4"); + static_assert(platform::is_same::value, + "Element B must be uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + private: + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + // static_assert(ColumnsInterleaved==4 || cutlass::platform::is_same::value, "####"); + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + AccessTypeB>; + + static int const kAlignmentNF4LookUpTable = 128 / sizeof_bits::value; + using AccessTypeNF4LookUpTable = + cutlass::Array; + using IteratorNF4LookUpTableThreadMap = transform:: + 
PitchLinearStripminedThreadMap, 1, 4>; + using IteratorNF4LookUpTable = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape<1, 16>, + int32_t, + cutlass::layout::RowMajor, + 0, + IteratorNF4LookUpTableThreadMap, + AccessTypeNF4LookUpTable>; + + using SmemIteratorNF4LookUpTable = IteratorNF4LookUpTable; + + // static_assert(MmaCore::MmaPolicy::Operator::FragmentB::kElements==64,"MmaCore::MmaPolicy::Operator::FragmentB::kElements + // == 32"); + using Converter = FastInterleavedAndBiasedNumericArrayConverterNf4< + ElementA, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::Int8Nf4InterleavedMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorNF4LookUpTable, + SmemIteratorNF4LookUpTable, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + SharedMemoryClear>; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h index 9648deb56a5..585d88ec85c 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. @@ -63,46 +64,48 @@ namespace gemm { namespace threadblock { // //////////////////////////////////////////////////////////////////////////////// -// // SFINAE trick so I can keep the same loop code for Volta and dispatch to the -// // correct warp level mma. On volta, all data is stored to shared memory as FP16. -// template -// CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, +// // SFINAE trick so I can keep the same loop code for Volta and dispatch to +// the +// // correct warp level mma. On volta, all data is stored to shared memory as +// FP16. 
template CUTLASS_DEVICE +// void run_warp_mma(WarpMma& warp_mma, // typename WarpMma::FragmentC& D, // typename WarpMma::FragmentA const& A, // typename WarpMma::FragmentB const& B, // typename WarpMma::FragmentC const& C, -// const int warp_tileB_k_offset) +// const int warp_tileB_k_offset) // { // warp_mma(D, A, B, C); // } // template -// CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, -// typename WarpMma::FragmentC& D, -// typename WarpMma::TransformedFragmentA const& A, -// typename WarpMma::TransformedFragmentB const& B, -// typename WarpMma::FragmentC const& C, -// const int warp_tileB_k_offset) +// CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, +// typename WarpMma::FragmentC& D, typename +// WarpMma::TransformedFragmentA const& A, +// typename WarpMma::TransformedFragmentB +// const& B, typename WarpMma::FragmentC const& +// C, const int warp_tileB_k_offset) // { // warp_mma(D, A, B, C, warp_tileB_k_offset); // } -// TODO(zhengzekang): Since we annotate the first implement, we currently hack to `run_ampere_warp_mma` used in A100. -template -CUTLASS_DEVICE void run_ampere_warp_mma(WarpMma& warp_mma, - typename WarpMma::FragmentC& D, - typename WarpMma::TransformedFragmentA const& A, - typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, - const int warp_tileB_k_offset) -{ - warp_mma(D, A, B, C, warp_tileB_k_offset); +// TODO(zhengzekang): Since we annotate the first implement, we currently hack +// to `run_ampere_warp_mma` used in A100. 
+template +CUTLASS_DEVICE void run_ampere_warp_mma( + WarpMma& warp_mma, + typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, + const int warp_tileB_k_offset) { + warp_mma(D, A, B, C, warp_tileB_k_offset); } //////////////////////////////////////////////////////////////////////////////// /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Policy describing tuning details (concept: MmaPolicy) @@ -112,135 +115,139 @@ template< /// Used for partial specialization typename Enable = bool> class Int8InterleavedMmaBase { -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static_assert(Operator::IteratorB::InstructionShape::kRow >= + Operator::InstructionShape::kK, + ""); + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = + kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: // - // Dependent types + // Type definitions // - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. 
- using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,""); - static constexpr int kNumKIterationsPerWarpBLoad = - Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; - /// Number of stages - static int const kStages = Stages; + public: + // + // Data members + // - /// Tensor reference to the A operand - using TensorRefA = TensorRef; + /// Buffer for A operand + AlignedBuffer operand_A; - /// Tensor reference to the B operand - using TensorRefB = TensorRef; + /// Buffer for B operand + AlignedBuffer operand_B; + public: // - // Nested structs + // Methods // - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = - MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE 
- static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: - // - // Data members - // + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8InterleavedMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared 
memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8InterleavedMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h index 218b33c5ac5..57e19ad92b9 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,18 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); @@ -72,7 +72,7 @@ namespace threadblock { /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory @@ -107,489 +107,532 @@ template< SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization typename Enable = bool> -class Int8InterleavedMmaMultistage: public Int8InterleavedMmaBase { -public: - ///< Base class - using Base = Int8InterleavedMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. 
- struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8InterleavedMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - 
///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - Base(shared_storage, thread_idx, warp_idx, lane_idx), +class Int8InterleavedMmaMultistage + : public Int8InterleavedMmaBase { + public: + ///< Base class + using Base = Int8InterleavedMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8InterleavedMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM 
+ typename Base::SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, 
Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - CUTLASS_DEVICE - void - copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) 
{ + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; } - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); + ++this->smem_iterator_A_; + } + } - // Async Copy for operand B + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + // if((threadIdx.x==0) && threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ + // printf("### kSrcBytes:%d, + // IteratorB::kAccessesPerVector:%d\n", kSrcBytes, + // IteratorB::kAccessesPerVector); + // } + // // static_assert(kSrcBytes==32, "kSrcBytes==32"); + // if( threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ + // printf("gmem_ptr cp source of thread %d: ",threadIdx.x); + // for(int i=0;i<16;++i){ + // printf("%d, ",iterator_B.get()[i]); + // } + // printf("\n"); + // } CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - // 
if((threadIdx.x==0) && threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("### kSrcBytes:%d, IteratorB::kAccessesPerVector:%d\n", kSrcBytes, IteratorB::kAccessesPerVector); - // } - // // static_assert(kSrcBytes==32, "kSrcBytes==32"); - // if( threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("gmem_ptr cp source of thread %d: ",threadIdx.x); - // for(int i=0;i<16;++i){ - // printf("%d, ",iterator_B.get()[i]); - // } - // printf("\n"); - // } - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - ++iterator_B; - } - ++this->smem_iterator_B_; - } + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + ++iterator_B; } + ++this->smem_iterator_B_; + } } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< 
iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + TransformBAfterLDS lds_converter; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector - / 8; + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + 
reinterpret_cast( + this->smem_iterator_A_.get()); - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - ++iterator_A; - } + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); - ++this->smem_iterator_A_; - } + ++iterator_A; + } - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); + ++this->smem_iterator_A_; + } - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector - / 8; + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - ++iterator_B; - } + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); - ++this->smem_iterator_B_; - } + ++iterator_B; + } - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - 
iterator_B.add_tile_offset({1, 0}); + ++this->smem_iterator_B_; + } - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } - // Perform accumulation in the 'd' output operand - accum = src_accum; + // Perform accumulation in the 'd' output operand + accum = src_accum; - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
+ // - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + typename IteratorA::AccessType zero_A; + zero_A.clear(); - typename IteratorA::AccessType zero_A; - zero_A.clear(); + last_smem_iterator_A.set_iteration_index(0); - last_smem_iterator_A.set_iteration_index(0); + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + *dst_ptr = zero_A; - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_A.get()); + ++last_smem_iterator_A; + } - *dst_ptr = zero_A; + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; - ++last_smem_iterator_A; - } + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); + *dst_ptr = zero_B; - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < 
Detail::AsyncCopyIterationsPerStageB; ++j) { + ++last_smem_iterator_B; + } + } - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_B.get()); + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); - *dst_ptr = zero_B; + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; - ++last_smem_iterator_B; - } - } + Operator warp_mma; - // Waits until kStages-2 stages have committed. - cutlass::arch::cp_async_wait(); - __syncthreads(); + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); - Operator warp_mma; + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); - // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - // TOOD(wangbojun) lds_converter can be remove for int8 B input - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - - // 
TODO(zhengzekang) - // run_warp_mma( - run_ampere_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - // if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("### run_warp_mma: " - // "%d \n", - // reinterpret_cast(accum)); - // } - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. 
- arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } + // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); + // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + // TODO(wangbojun) lds_converter can be remove for int8 B input + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + + // 
TODO(zhengzekang) + // run_warp_mma( + run_ampere_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + // if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ + // printf("### run_warp_mma: " + // "%d \n", + // reinterpret_cast(accum)); + // } + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h index 2b7312f365a..f7c9cac33f7 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/int8_mma_pipelined.h @@ -1,12 +1,12 @@ 
/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -74,18 +75,21 @@ namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template< +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorA_, /// Iterates over tiles of A operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) typename SmemIteratorA_, /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) typename IteratorB_, /// Iterates over tiles of B operand in shared memory /// (concept: WriteableTileIterator | RandomAccessTileIterator) @@ -102,250 +106,275 @@ template< typename TransformBAfterLDS_, /// Used for partial specialization typename Enable = bool> -class Int8InterleavedMmaPipelined: public 
Int8InterleavedMmaBase { -public: - ///< Base class - using Base = Int8InterleavedMmaBase; - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; +class Int8InterleavedMmaPipelined + : public Int8InterleavedMmaBase { + public: + ///< Base class + using Base = Int8InterleavedMmaBase; + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename 
Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for Int8InterleavedMmaPipelined is two + // (Double-buffered pipeline) + static_assert((Base::kStages == 2), + "Int8InterleavedMmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8InterleavedMmaPipelined( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by + ///< threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N 
dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile // - // Dependent types + // Prologue // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for Int8InterleavedMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), 
"Int8InterleavedMmaPipelined requires kStages set to value 2"); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + using TransformA = NumericArrayConverter; -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; + // These transforms are mainly to handle when we have bfloat activations and + // weights in GMEM and want to issue HMMA on architectures older than + // Ampere. We will convert to FP16 before STS. + TransformA transformA; - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; + // Perform accumulation in the 'd' output operand + accum = src_accum; -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8InterleavedMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int 
warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } + FragmentA tb_frag_A; + FragmentB tb_frag_B; - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile + tb_frag_A.clear(); + tb_frag_B.clear(); - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); - using TransformA = - NumericArrayConverter; + ++iterator_A; + ++iterator_B; - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
- TransformA transformA; + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - // Perform accumulation in the 'd' output operand - accum = src_accum; + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; - FragmentA tb_frag_A; - FragmentB tb_frag_B; + __syncthreads(); + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; - tb_frag_A.clear(); - tb_frag_B.clear(); + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); - ++iterator_A; - ++iterator_B; + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + Operator warp_mma; - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; + int smem_write_stage_idx = 1; - __syncthreads(); + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + // + // Mainloop + // - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); + // Note: The main loop does not support Base::kWarpGemmIterations == 2. 
+ CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + } + + smem_write_stage_idx ^= 1; + } + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. 
- CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate + // the load for the next fragment. + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } - ++iterator_A; - ++iterator_B; + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } + ++iterator_A; + ++iterator_B; - typename TransformBAfterLDS::result_type converted_frag_B = - lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - // run_warp_mma( - run_ampere_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + // run_warp_mma( + run_ampere_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + 
converted_frag_B, + accum, + warp_tileB_k_compute_offset); + } } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h index 02c03c707d8..5048bc52e06 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_base.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. @@ -65,7 +66,7 @@ namespace threadblock { //////////////////////////////////////////////////////////////////////////////// /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. 
-template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Policy describing tuning details (concept: MmaPolicy) @@ -75,138 +76,142 @@ template< /// Used for partial specialization typename Enable = bool> class Int8Nf4InterleavedMmaBase { -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static_assert(Operator::IteratorB::InstructionShape::kRow >= + Operator::InstructionShape::kK, + ""); + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = + kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: // - // Dependent types + // Type definitions // - /// Warp-level Mma - using Operator = typename Policy::Operator; + /// Shape of the A matrix operand in 
shared memory + using ShapeA = + MatrixShape; - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. - using WarpGemm = typename Policy::Operator::Shape; + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM oeprations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - static_assert(Operator::IteratorB::InstructionShape::kRow>=Operator::InstructionShape::kK,""); - static constexpr int kNumKIterationsPerWarpBLoad = - Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + public: + // + // Data members + // - /// Number of stages - static int const kStages = Stages; + /// Buffer for A operand + AlignedBuffer operand_A; - /// Tensor reference to the A operand - using TensorRefA = TensorRef; + /// Buffer for B operand + AlignedBuffer operand_B; - /// Tensor reference to the B operand - using TensorRefB = TensorRef; + /// Buffer to hold scales for threadblock + AlignedBuffer operand_nf4_look_up_table; + public: // - // Nested structs + // Methods // - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA = - MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_nf4_look_up_table; - - public: - // - // Methods - // - - /// Returns a layout object for 
the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: - // - // Data members - // + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8Nf4InterleavedMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), - warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { + /// Returns a TensorRef to the B 
operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8Nf4InterleavedMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h index e7e204620d6..c28e0ed844c 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/threadblock/nf4_int8_mma_multistage.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,18 +18,18 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -66,8 +66,9 @@ limitations under the License. */ ///////////////////////////////////////////////////////////////////////////////////////////////// template -[[gnu::warning("your type here")]] -bool print_type() { return false; } +[[gnu::warning("your type here")]] bool print_type() { + return false; +} namespace cutlass { namespace gemm { @@ -77,7 +78,7 @@ namespace threadblock { /// Structure to compute the matrix product targeting CUDA cores and SIMT math /// instructions. 
-template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Iterates over tiles of A operand in global memory @@ -116,467 +117,601 @@ template< SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Used for partial specialization typename Enable = bool> -class Int8Nf4InterleavedMmaMultistage: public Int8Nf4InterleavedMmaBase { -public: - ///< Base class - using Base = Int8Nf4InterleavedMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Iterates over tiles of nf4 look up table in global memory - using IteratorNF4LookUpTable = IteratorNF4LookUpTable_; - using ElementNF4LookUpTable = typename IteratorNF4LookUpTable::Element; - using LayoutNF4LookUpTable = typename IteratorNF4LookUpTable::Layout; - - - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorNF4LookUpTable = SmemIteratorNF4LookUpTable_; - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static 
ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. - struct Detail { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; +class Int8Nf4InterleavedMmaMultistage + : public Int8Nf4InterleavedMmaBase { + public: + ///< Base class + using Base = Int8Nf4InterleavedMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of nf4 look up table in global memory + using IteratorNF4LookUpTable = IteratorNF4LookUpTable_; + using ElementNF4LookUpTable = typename IteratorNF4LookUpTable::Element; + using LayoutNF4LookUpTable = typename IteratorNF4LookUpTable::Layout; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorNF4LookUpTable = SmemIteratorNF4LookUpTable_; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using 
ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + // static_assert(Base::kWarpGemmIterations==4,"Base::kWarpGemmIterations!=4"); + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA 
smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + SmemIteratorNF4LookUpTable smem_iterator_nf4_look_up_table_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + Int8Nf4InterleavedMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_nf4_look_up_table_( + LayoutNF4LookUpTable(16), + shared_storage.operand_nf4_look_up_table.data(), + {1, 16}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + + // if((threadIdx.x % 32) == 0){ + // printf("#### %d-%d-%d-%d-%d-%d, gmem_ptr_nf4_look_up_table:%p, + // kSrcBytesNf4:%d \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // threadIdx.x, threadIdx.y, threadIdx.z, + // gmem_ptr_nf4_look_up_table, + // 
kSrcBytesNf4); + // } + // cutlass::arch::cp_async_zfill( + // dst_ptr_nf4_look_up_table, gmem_ptr_nf4_look_up_table, + // iterator_nf4_look_up_table.valid()); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } - /// Number of stages - static int const kStages = Stages; + ++this->smem_iterator_A_; + } + } - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA = - (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB = - (AsyncCopyIterationsPerStageB + 
Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; - static constexpr bool RequiresTileInterleave = - layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // int32_t* print_ptr = + // reinterpret_cast(iterator_B.get()); int32_t* + // print_ptr_smem = reinterpret_cast(dst_ptr+v); if + // (iterator_B.valid()) + // { + // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: + // %p:%x-%x-%x-%x=>%x,%x,%x,%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // threadIdx.x,threadIdx.y,threadIdx.z, + // iterator_B.get(), + // static_cast(print_ptr[0]), + // static_cast(print_ptr[1]), + // static_cast(print_ptr[2]), + // static_cast(print_ptr[3]), + // static_cast(print_ptr_smem[0]), + // static_cast(print_ptr_smem[1]), + // static_cast(print_ptr_smem[2]), + // static_cast(print_ptr_smem[3])); + // 
} + // } + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + IteratorNF4LookUpTable iterator_nf4_look_up_table, + ///< initial value of accumulator + FragmentC const& src_accum) { + // printf("gemm_k_iterations:%d\n", gemm_k_iterations); -private: // - // Data members + // Prologue // - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - SmemIteratorNF4LookUpTable smem_iterator_nf4_look_up_table_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - Int8Nf4InterleavedMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), - smem_iterator_nf4_look_up_table_(LayoutNF4LookUpTable(16), - shared_storage.operand_nf4_look_up_table.data(), - {1, 16}, - thread_idx) -{ - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % 
(Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - - - // if((threadIdx.x % 32) == 0){ - // printf("#### %d-%d-%d-%d-%d-%d, gmem_ptr_nf4_look_up_table:%p, kSrcBytesNf4:%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // threadIdx.x, threadIdx.y, threadIdx.z, - // gmem_ptr_nf4_look_up_table, - // kSrcBytesNf4); - // } - // cutlass::arch::cp_async_zfill( - // dst_ptr_nf4_look_up_table, gmem_ptr_nf4_look_up_table, iterator_nf4_look_up_table.valid()); - } + // use share memory to get look_up_table of nf4; + + // __shared__ uint32_t shared_look_up_table[16]; + + // int lane_idx=threadIdx.x%32; + // int warp_idx=threadIdx.x/32; + // if(lane_idx<16){ + // shared_look_up_table[lane_idx]=lane_idx; + // } + + // __shared__ uint32_t shared_look_up_table[32][32]; + // if(warp_idx==0){ + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii<16;++ii){ + // shared_look_up_table[lane_idx][ii]=ii; + // } + // } + + /// load look up table to smem here + // __shared__ int32_t nf4_smem_look_up_table[16]; + + // int32_t* gmem_ptr_nf4_look_up_table = + // reinterpret_cast(iterator_nf4_look_up_table.get()); + // // smem look up table + // int32_t* dst_ptr_nf4_look_up_table = + // reinterpret_cast(nf4_smem_look_up_table); + + // if(lane_idx == 0){ + // int4* dst_ptr_nf4_look_up_table_int4 = + // reinterpret_cast(nf4_smem_look_up_table); + // dst_ptr_nf4_look_up_table_int4[lane_idx] = + // *(reinterpret_cast(gmem_ptr_nf4_look_up_table) + lane_idx); + // } + // __syncthreads(); + // // reg look up table + // cutlass::Array reg_look_up_table; + // 
CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii<4;++ii){ + // reg_look_up_table[ii]=*(reinterpret_cast(dst_ptr_nf4_look_up_table) + // + ii); + // } + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); - CUTLASS_DEVICE - void - copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - 
++iterator_A; - } - - ++this->smem_iterator_A_; - } - } + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // int32_t* print_ptr = reinterpret_cast(iterator_B.get()); - // int32_t* print_ptr_smem = reinterpret_cast(dst_ptr+v); - // if (iterator_B.valid()) - // { - // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: %p:%x-%x-%x-%x=>%x,%x,%x,%x \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // threadIdx.x,threadIdx.y,threadIdx.z, - // iterator_B.get(), - // static_cast(print_ptr[0]), - // static_cast(print_ptr[1]), - // static_cast(print_ptr[2]), - // static_cast(print_ptr[3]), - // static_cast(print_ptr_smem[0]), - // static_cast(print_ptr_smem[1]), - // static_cast(print_ptr_smem[2]), - // static_cast(print_ptr_smem[3])); - // } - // } - ++iterator_B; - } - ++this->smem_iterator_B_; - } + cutlass::arch::cp_async_zfill( + dst_ptr + v, 
iterator_A.get(), iterator_A.valid()); + + ++iterator_A; } - } - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - IteratorNF4LookUpTable iterator_nf4_look_up_table, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // printf("gemm_k_iterations:%d\n", gemm_k_iterations); - - // - // Prologue - // - - // use share memory to get look_up_table of nf4; - - // __shared__ uint32_t shared_look_up_table[16]; - - // int lane_idx=threadIdx.x%32; - // int warp_idx=threadIdx.x/32; - // if(lane_idx<16){ - // shared_look_up_table[lane_idx]=lane_idx; - // } + ++this->smem_iterator_A_; + } - // __shared__ uint32_t shared_look_up_table[32][32]; - // if(warp_idx==0){ - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii<16;++ii){ - // shared_look_up_table[lane_idx][ii]=ii; - // } - // } + iterator_B.set_iteration_index(0); + // print_type(); + this->smem_iterator_B_.set_iteration_index(0); - /// load look up table to smem here - // __shared__ int32_t nf4_smem_look_up_table[16]; + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); - // int32_t* gmem_ptr_nf4_look_up_table = reinterpret_cast(iterator_nf4_look_up_table.get()); - // // smem look up table - // int32_t* dst_ptr_nf4_look_up_table = reinterpret_cast(nf4_smem_look_up_table); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), 
iterator_B.valid()); + + // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // int32_t* print_ptr = + // reinterpret_cast(iterator_B.get()); int32_t* + // print_ptr_smem = reinterpret_cast(dst_ptr+v); if + // (iterator_B.valid()) + // { + // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: + // %p:%x-%x-%x-%x=>%x,%x,%x,%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // threadIdx.x,threadIdx.y,threadIdx.z, + // iterator_B.get(), + // static_cast(print_ptr[0]), + // static_cast(print_ptr[1]), + // static_cast(print_ptr[2]), + // static_cast(print_ptr[3]), + // static_cast(print_ptr_smem[0]), + // static_cast(print_ptr_smem[1]), + // static_cast(print_ptr_smem[2]), + // static_cast(print_ptr_smem[3])); + // } + // } + ++iterator_B; + } - // if(lane_idx == 0){ - // int4* dst_ptr_nf4_look_up_table_int4 = reinterpret_cast(nf4_smem_look_up_table); - // dst_ptr_nf4_look_up_table_int4[lane_idx] = *(reinterpret_cast(gmem_ptr_nf4_look_up_table) + lane_idx); - // } - // __syncthreads(); - // // reg look up table - // cutlass::Array reg_look_up_table; - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii<4;++ii){ - // reg_look_up_table[ii]=*(reinterpret_cast(dst_ptr_nf4_look_up_table) + ii); - // } + ++this->smem_iterator_B_; + } - TransformBAfterLDS lds_converter; + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. 
+ this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector - / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - // print_type(); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector - / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - // if(true && (threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // int32_t* print_ptr = reinterpret_cast(iterator_B.get()); - // int32_t* print_ptr_smem = reinterpret_cast(dst_ptr+v); - // if (iterator_B.valid()) - // { - // printf("gmem_ptr cp source of thread %d-%d-%d;%d-%d-%d: %p:%x-%x-%x-%x=>%x,%x,%x,%x 
\n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // threadIdx.x,threadIdx.y,threadIdx.z, - // iterator_B.get(), - // static_cast(print_ptr[0]), - // static_cast(print_ptr[1]), - // static_cast(print_ptr[2]), - // static_cast(print_ptr[3]), - // static_cast(print_ptr_smem[0]), - // static_cast(print_ptr_smem[1]), - // static_cast(print_ptr_smem[2]), - // static_cast(print_ptr_smem[3])); - // } - // } - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } - // Perform accumulation in the 'd' output operand - accum = src_accum; + // Perform accumulation in the 'd' output operand + accum = src_accum; - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
+ // - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + typename IteratorA::AccessType zero_A; + zero_A.clear(); - typename IteratorA::AccessType zero_A; - zero_A.clear(); + last_smem_iterator_A.set_iteration_index(0); - last_smem_iterator_A.set_iteration_index(0); + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + *dst_ptr = zero_A; - typename IteratorA::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_A.get()); + ++last_smem_iterator_A; + } - *dst_ptr = zero_A; + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; - ++last_smem_iterator_A; - } + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); + *dst_ptr = zero_B; - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < 
Detail::AsyncCopyIterationsPerStageB; ++j) { + ++last_smem_iterator_B; + } + } - typename IteratorB::AccessType* dst_ptr = - reinterpret_cast(last_smem_iterator_B.get()); + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename TransformBAfterLDS::result_type converted_frag_B_buffer[2]; + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + converted_frag_B_buffer[0] = lds_converter(warp_frag_B[0]); + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + + // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[0]); printf("#### + // warp_frag_b_load [0] bid:%d-%d-%d," + // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // frag_b_reg_ptr[0], + // frag_b_reg_ptr[1], + // frag_b_reg_ptr[2], + // frag_b_reg_ptr[3], + // frag_b_reg_ptr[4], + // frag_b_reg_ptr[5], + // frag_b_reg_ptr[6], + // frag_b_reg_ptr[7] + // ); + // } + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; - *dst_ptr = zero_B; + // + // Mainloop + // - ++last_smem_iterator_B; - } + __syncthreads(); + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles 
from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); + // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + converted_frag_B_buffer[(warp_tileB_k_load_offset + 1) % 2] = + lds_converter(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + + ++this->warp_tile_iterator_B_; + // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset + // + 1) % 2]); printf("#### warp_frag_b load [%d] bid:%d-%d-%d," + // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // ((warp_tileB_k_load_offset + 1) % 2), + // blockIdx.x,blockIdx.y,blockIdx.z, + // frag_b_reg_ptr[0], + // frag_b_reg_ptr[1], + // frag_b_reg_ptr[2], + // frag_b_reg_ptr[3], + // frag_b_reg_ptr[4], + // frag_b_reg_ptr[5], + // frag_b_reg_ptr[6], + // frag_b_reg_ptr[7] + // ); + // } } + // TODO(wangbojun) lds_converter can be remove for int8 B input + // int4 + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - // Waits until kStages-2 stages have committed. 
- cutlass::arch::cp_async_wait(); - __syncthreads(); + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], + // reinterpret_cast(nf4_smem_look_up_table)); - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename TransformBAfterLDS::result_type converted_frag_B_buffer[2]; - Operator warp_mma; + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], + // reg_look_up_table); - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - converted_frag_B_buffer[0] = - lds_converter(warp_frag_B[0]); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); + // typename TransformBAfterLDS::result_type converted_frag_B = + // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], + // shared_look_up_table, warp_idx, lane_idx); // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[0]); - // printf("#### warp_frag_b_load [0] bid:%d-%d-%d," - // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", + // uint32_t* frag_b_reg_ptr = + // reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) + // % 2]); uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); printf("#### + // after lds_converter bid:%d-%d-%d" + // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" + // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", // blockIdx.x,blockIdx.y,blockIdx.z, + // ((warp_tileB_k_load_offset) % 2), // frag_b_reg_ptr[0], // frag_b_reg_ptr[1], // frag_b_reg_ptr[2], @@ -584,308 +719,243 @@ class Int8Nf4InterleavedMmaMultistage: public Int8Nf4InterleavedMmaBasewarp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - 
iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - __syncthreads(); - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - // static_assert(Base::kNumKIterationsPerWarpBLoad==1,"Base::kNumKIterationsPerWarpBLoad!=1"); - // static_assert(Base::kWarpGemmIterationsForB==4,"Base::kWarpGemmIterationsForB!=4"); - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { - this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) - % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - converted_frag_B_buffer[(warp_tileB_k_load_offset + 1) % 2] = - lds_converter(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - - ++this->warp_tile_iterator_B_; - // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - // printf("#### warp_frag_b load [%d] bid:%d-%d-%d," - // " frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", - // ((warp_tileB_k_load_offset + 1) % 2), - // blockIdx.x,blockIdx.y,blockIdx.z, 
- // frag_b_reg_ptr[0], - // frag_b_reg_ptr[1], - // frag_b_reg_ptr[2], - // frag_b_reg_ptr[3], - // frag_b_reg_ptr[4], - // frag_b_reg_ptr[5], - // frag_b_reg_ptr[6], - // frag_b_reg_ptr[7] - // ); - // } - } - // TOOD(wangbojun) lds_converter can be remove for int8 B input - // int4 - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], reinterpret_cast(nf4_smem_look_up_table)); - - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], reg_look_up_table); - - // typename TransformBAfterLDS::result_type converted_frag_B = - // lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2], shared_look_up_table, warp_idx, lane_idx); - - // if((threadIdx.x||threadIdx.y||threadIdx.z)==0){ - // uint32_t* frag_b_reg_ptr = reinterpret_cast(&warp_frag_B[(warp_tileB_k_load_offset) % 2]); - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // printf("#### after lds_converter bid:%d-%d-%d" - // " frag_b_reg_ptr[%d]:%x-%x-%x-%x-%x-%x-%x-%x" - // " converted_frag_b_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // ((warp_tileB_k_load_offset) % 2), - // frag_b_reg_ptr[0], - // frag_b_reg_ptr[1], - // frag_b_reg_ptr[2], - // frag_b_reg_ptr[3], - // frag_b_reg_ptr[4], - // frag_b_reg_ptr[5], - // frag_b_reg_ptr[6], - // frag_b_reg_ptr[7], - // converted_frag_B_reg_ptr[0], - // converted_frag_B_reg_ptr[1], - // converted_frag_B_reg_ptr[2], - // converted_frag_B_reg_ptr[3], - // converted_frag_B_reg_ptr[4], - // converted_frag_B_reg_ptr[5], - // converted_frag_B_reg_ptr[6], - // converted_frag_B_reg_ptr[7] - // ); - // } - - // bool ::print_type< ::cutlass::Array< ::cutlass::integer_subbyte<(int)4, (bool)0> , (int)64, (bool)0> > ()") - // print_type(); - // 
bool ::print_type< ::cutlass::Array > ()") from a - // print_type(); - // cutlass::Array - // print_type(); - - // print_type(); - // TODO(zhengzekang) - // run_warp_mma( - - - // if(true){ - // uint32_t none_zero = 0; - // // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); - // uint32_t* frag_a_reg_ptr = reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii0;none_zero_i/=2){ - // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); - // } - - // if(none_zero!=0){ - // printf("## before mma ## bidtid:%d-%d-%d-%d-%d-%d, warp_mma_k:%d, frag_B_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x; frag_a_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x" - // " accu: %d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // warp_mma_k, - // threadIdx.x,threadIdx.y,threadIdx.z, - // converted_frag_B_reg_ptr[0], - // converted_frag_B_reg_ptr[1], - // converted_frag_B_reg_ptr[2], - // converted_frag_B_reg_ptr[3], - // converted_frag_B_reg_ptr[4], - // converted_frag_B_reg_ptr[5], - // converted_frag_B_reg_ptr[6], - // converted_frag_B_reg_ptr[7], - // frag_a_reg_ptr[0], - // frag_a_reg_ptr[1], - // frag_a_reg_ptr[2], - // frag_a_reg_ptr[3], - // frag_a_reg_ptr[4], - // frag_a_reg_ptr[5], - // frag_a_reg_ptr[6], - // frag_a_reg_ptr[7], - // accum[0], - // accum[1], - // accum[2], - // accum[3], - // accum[4], - // accum[5], - // accum[6], - // accum[7], - // accum[8], - // accum[9], - // accum[10], - // accum[11], - // accum[12], - // accum[13], - // accum[14], - // accum[15], - // accum[16], - // accum[17], - // accum[18], - // accum[19], - // accum[20], - // accum[21], - // accum[22], - // accum[23], - // accum[24], - // accum[25], - // accum[26], - // accum[27], - // accum[28], - // accum[29], - // accum[30], - // accum[31] - // ); - // } - // } - 
run_ampere_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_buffer[warp_tileB_k_load_offset % 2], accum, warp_tileB_k_compute_offset); - // auto tmp = static_cast(warp_frag_B[warp_tileB_k_load_offset % 2]); - // if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && - // blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - // printf("### run_warp_mma: " - // "%d \n", - // reinterpret_cast(accum)); - // } - // if(true){ - // uint32_t none_zero = 0; - // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&converted_frag_B); - // // uint32_t* converted_frag_B_reg_ptr = reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); - // uint32_t* frag_a_reg_ptr = reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0;ii0;none_zero_i/=2){ - // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); - // } - - // // if(none_zero!=0){ - // if((blockIdx.y||blockIdx.z||threadIdx.x||threadIdx.y||threadIdx.z)==0){ - - // printf("## after mma ## bidtid:%d-%d-%d-%d-%d-%d, warp_mma_k:%d, gemm_k_iterations:%d, Base::kWarpGemmIterations:%d," - // " converted_frag_B_reg_ptr:%x; frag_a_reg_ptr:%x" - // " accu: %d \n", - // blockIdx.x,blockIdx.y,blockIdx.z, - // threadIdx.x,threadIdx.y,threadIdx.z, - // warp_mma_k, - // gemm_k_iterations, - // Base::kWarpGemmIterations, - // converted_frag_B_reg_ptr[0], - // frag_a_reg_ptr[0], - // accum[0] - // ); - // } - // } - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * 
Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Waits until kStages-2 stages have committed. - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } + + // bool ::print_type< ::cutlass::Array< + // ::cutlass::integer_subbyte<(int)4, (bool)0> , (int)64, (bool)0> > + // ()") print_type(); bool ::print_type< + // ::cutlass::Array > ()") from a + // print_type(); + // cutlass::Array + // print_type(); + + // print_type(); + // TODO(zhengzekang) + // run_warp_mma( + + // if(true){ + // uint32_t none_zero = 0; + // // uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); uint32_t* + // converted_frag_B_reg_ptr = + // 
reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); + // uint32_t* frag_a_reg_ptr = + // reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii0;none_zero_i/=2){ + // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); + // } + + // if(none_zero!=0){ + // printf("## before mma ## bidtid:%d-%d-%d-%d-%d-%d, + // warp_mma_k:%d, frag_B_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x; + // frag_a_reg_ptr:%x-%x-%x-%x-%x-%x-%x-%x" + // " accu: + // %d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d-%d + // \n", blockIdx.x,blockIdx.y,blockIdx.z, warp_mma_k, + // threadIdx.x,threadIdx.y,threadIdx.z, + // converted_frag_B_reg_ptr[0], + // converted_frag_B_reg_ptr[1], + // converted_frag_B_reg_ptr[2], + // converted_frag_B_reg_ptr[3], + // converted_frag_B_reg_ptr[4], + // converted_frag_B_reg_ptr[5], + // converted_frag_B_reg_ptr[6], + // converted_frag_B_reg_ptr[7], + // frag_a_reg_ptr[0], + // frag_a_reg_ptr[1], + // frag_a_reg_ptr[2], + // frag_a_reg_ptr[3], + // frag_a_reg_ptr[4], + // frag_a_reg_ptr[5], + // frag_a_reg_ptr[6], + // frag_a_reg_ptr[7], + // accum[0], + // accum[1], + // accum[2], + // accum[3], + // accum[4], + // accum[5], + // accum[6], + // accum[7], + // accum[8], + // accum[9], + // accum[10], + // accum[11], + // accum[12], + // accum[13], + // accum[14], + // accum[15], + // accum[16], + // accum[17], + // accum[18], + // accum[19], + // accum[20], + // accum[21], + // accum[22], + // accum[23], + // accum[24], + // accum[25], + // accum[26], + // accum[27], + // accum[28], + // accum[29], + // accum[30], + // accum[31] + // ); + // } + // } + run_ampere_warp_mma( + warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B_buffer[warp_tileB_k_load_offset % 2], + accum, + warp_tileB_k_compute_offset); + // auto tmp = static_cast(warp_frag_B[warp_tileB_k_load_offset + // % 2]); if(threadIdx.x==0 && threadIdx.y==0 && threadIdx.z==0 && + // blockIdx.x==0 && blockIdx.y==0 && 
blockIdx.z==0){ + // printf("### run_warp_mma: " + // "%d \n", + // reinterpret_cast(accum)); + // } + // if(true){ + // uint32_t none_zero = 0; + // uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&converted_frag_B); + // // uint32_t* converted_frag_B_reg_ptr = + // reinterpret_cast(&warp_frag_B[warp_mma_k % 2]); + // uint32_t* frag_a_reg_ptr = + // reinterpret_cast(&warp_frag_A[warp_mma_k % 2]); + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0;ii0;none_zero_i/=2){ + // none_zero|= __shfl_xor_sync(-1,none_zero,none_zero_i); + // } + + // // if(none_zero!=0){ + // if((blockIdx.y||blockIdx.z||threadIdx.x||threadIdx.y||threadIdx.z)==0){ + + // printf("## after mma ## bidtid:%d-%d-%d-%d-%d-%d, + // warp_mma_k:%d, gemm_k_iterations:%d, + // Base::kWarpGemmIterations:%d," + // " converted_frag_B_reg_ptr:%x; frag_a_reg_ptr:%x" + // " accu: %d \n", + // blockIdx.x,blockIdx.y,blockIdx.z, + // threadIdx.x,threadIdx.y,threadIdx.z, + // warp_mma_k, + // gemm_k_iterations, + // Base::kWarpGemmIterations, + // converted_frag_B_reg_ptr[0], + // frag_a_reg_ptr[0], + // accum[0] + // ); + // } + // } + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); } - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * 
Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index 
97fca6da675..1e11b8e10f6 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -122,10 +122,9 @@ struct DefaultMmaTensorOp; }; - - -// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> for better performance. -template< +// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> +// for better performance. +template < /// Shape of one matrix production operation (concept: GemmShape) typename WarpShape_, /// Shape of one matrix production operation (concept: GemmShape) @@ -152,48 +151,51 @@ struct DefaultMmaTensorOp { -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; - // Chosen so we get K=32. - static constexpr int LoadInstructionK = 32 * sizeof_bits::value / sizeof_bits::value; + // Chosen so we get K=32. 
+ static constexpr int LoadInstructionK = + 32 * sizeof_bits::value / sizeof_bits::value; - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = + GemmShape; -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1>>; + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + int32_t, + cutlass::layout::RowMajor, + // cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<32,128>, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; + // Define the warp-level tensor op + using Type = + cutlass::gemm::warp::MmaTensorOpComputeBWithF16; }; - - - -// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> for better performance. -template< +// Specialization for int8 int8 int32 gemm, we use instruction shape<16, 8, 32> +// for better performance. +template < /// Shape of one matrix production operation (concept: GemmShape) typename WarpShape_, /// Shape of one matrix production operation (concept: GemmShape) @@ -220,41 +222,46 @@ struct DefaultMmaTensorOp { -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; - // Chosen so we get K=64. - static constexpr int LoadInstructionK = 16 * sizeof_bits::value / sizeof_bits::value; + // Chosen so we get K=64. 
+ static constexpr int LoadInstructionK = + 16 * sizeof_bits::value / sizeof_bits::value; - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = + GemmShape; -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy, - arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1>>; + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + InstructionShape_, + 32, + int8_t, + cutlass::layout::RowMajor, + int8_t, + cutlass::layout::ColumnMajor, + int32_t, + cutlass::layout::RowMajor, + // cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<32,128>, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; + // Define the warp-level tensor op + using Type = + cutlass::gemm::warp::MmaTensorOpComputeBWithF16; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 7279541fe9f..3476b3816c9 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. 
SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,20 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. + \brief Templates implementing warp-level matrix multiply-accumulate + operations targeting Tensor Cores. */ #pragma once @@ -63,8 +64,9 @@ namespace gemm { namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template< +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Data type of A elements @@ -91,222 +93,229 @@ template< /// Used for partial specialization typename Enable = bool> class MmaTensorOpComputeBWithF16 { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports underlying HMMA"); - - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); - - // Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - - static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, - "M dimension of compute 
instruction must match load"); - static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, - "N dimension of compute instruction must match load"); - - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; - - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = - MmaTensorOpMultiplicandTileIterator, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / 
ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, - TransformedFragmentA const& A, - TransformedFragmentB const& B, - FragmentC const& C, - const int warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && 
+ platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); + + static_assert(platform::is_same::value || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + // Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in 
memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - 
CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } + } + } #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } + } + } #else - assert(0); + assert(0); #endif - } + } }; // Specialization for int8 int8 int32. 
Author(zhengzekang) -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Layout of A matrix (concept: MatrixLayout) @@ -328,233 +337,235 @@ template< bool AccumulatorsInRowMajor, /// Used for partial specialization typename Enable> -class MmaTensorOpComputeBWithF16< - Shape_, - int8_t, - LayoutA_, - int8_t, - LayoutB_, - ElementC_, - LayoutC_, - Policy_, - SharedMemoryInstructionShape_, - PartitionsK_, - AccumulatorsInRowMajor, - Enable> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = int8_t; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = int8_t; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value), - "MmaTensorOpCvtBToA only supports underlying iMMA"); - - // static_assert(platform::is_same::value, - // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - - 
static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, - "M dimension of compute instruction must match load"); - static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, - "N dimension of compute instruction must match load"); - - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; - - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - using IteratorA = MmaTensorOpMultiplicandTileIterator, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = - MmaTensorOpMultiplicandTileIterator, - Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / 
ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, - TransformedFragmentA const& A, - TransformedFragmentB const& B, - FragmentC const& C, - const int warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = int8_t; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = int8_t; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying 
instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value), + "MmaTensorOpCvtBToA only supports underlying iMMA"); + + // static_assert(platform::is_same::value, + // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on + // Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = 
MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n 
< MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } + } + } #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } + } + } #else - assert(0); + assert(0); #endif - } + } }; - - // Specialization for int8 int8 int32. 
Author(zhengzekang) -template< +template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Layout of A matrix (concept: MatrixLayout) @@ -576,235 +587,250 @@ template< bool AccumulatorsInRowMajor, /// Used for partial specialization typename Enable> -class MmaTensorOpComputeBWithF16< - Shape_, - int8_t, - LayoutA_, - cutlass::uint4b_t, - LayoutB_, - ElementC_, - LayoutC_, - Policy_, - SharedMemoryInstructionShape_, - PartitionsK_, - AccumulatorsInRowMajor, - Enable> { -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = int8_t; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = cutlass::uint4b_t; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value), - "MmaTensorOpCvtBToA only supports underlying iMMA"); - - // static_assert(platform::is_same::value, - // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - - 
static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, - "M dimension of compute instruction must match load"); - static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, - "N dimension of compute instruction must match load"); - - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; - - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - - - // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< ::cutlass::MatrixShape<(int)32, (int)64> , ( ::cutlass::gemm::Operand)0, signed char, ::cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<(int)8, (int)64> , ::cutlass::MatrixShape<(int)16, (int)32> , (int)1, (int)32, (int)1> > ()") - using IteratorA = MmaTensorOpMultiplicandTileIterator, - Operand::kA, - ElementA, - LayoutA, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; -// bool ::print_type< ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, ::cutlass::integer_subbyte<(int)4, (bool)0> , ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, (int)1> > ()") - /// Iterates over the B operand in memory - using IteratorB = - MmaTensorOpMultiplicandTileIterator, - 
Operand::kB, - ElementB, - LayoutB, - MatrixShape, - Policy::OpDelta::kRow, - kThreadCount, - kPartitionsK>; -// ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, ::cutlass::integer_subbyte<(int)4, (bool)0> , ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, (int)1> > - // print_type(); - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, - ElementC, - LayoutC, - typename ArchMmaOperator::Shape, - typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, - TransformedFragmentA const& A, - TransformedFragmentB const& B, - FragmentC const& C, - const int warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B"); - - D = C; - - MmaOperandA const* 
ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = int8_t; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = cutlass::uint4b_t; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value), + "MmaTensorOpCvtBToA only supports underlying iMMA"); + + // static_assert(platform::is_same::value, + // "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on + // Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / 
InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + + // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + // ::cutlass::MatrixShape<(int)32, (int)64> , ( ::cutlass::gemm::Operand)0, + // signed char, + // ::cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<(int)8, (int)64> , + // ::cutlass::MatrixShape<(int)16, (int)32> , (int)1, (int)32, (int)1> > ()") + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + // bool ::print_type< + // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + // ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, + // ::cutlass::integer_subbyte<(int)4, (bool)0> , + // ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, + // (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, + // (int)1> > ()") + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + // ::cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + // ::cutlass::MatrixShape<(int)64, (int)32> , ( ::cutlass::gemm::Operand)1, + // 
::cutlass::integer_subbyte<(int)4, (bool)0> , + // ::cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<(int)4, + // (int)64> , ::cutlass::MatrixShape<(int)64, (int)8> , (int)1, (int)32, + // (int)1> > print_type(); + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order 
maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } + } + } #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } + } + } #else - assert(0); + assert(0); #endif - } + } }; - - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index 73a11224334..73c8fbb4572 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -187,10 +187,9 @@ class MmaTensorOpDequantizer< // Adds a pointer offset in units of elements. 
CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_ += offset; + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; } private: @@ -297,10 +296,9 @@ class MmaTensorOpDequantizer< } // Adds a pointer offset in units of elements. CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_ += offset; + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_ += offset; } private: diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h index 6f09be73e35..d68a683f4a5 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion.h @@ -435,180 +435,197 @@ struct FastInterleavedAndBiasedNumericArrayConverter { result_type operator()(source_type const& s) { return convert(s); } }; -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt = 0x3120; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); - - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 4; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; - } + using result_type = Array; + using source_type = Array; - 
CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); -template<> -struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); - - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x7120; - static constexpr uint32_t mask_for_elt_2 = 0x3654; - - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); - - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 8; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; - } + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt = 0x3120; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 4; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; -template<> -struct FastInterleavedAndBiasedNumericArrayConverter { - using result_type = Array; - using source_type = Array; + CUTLASS_DEVICE + static result_type 
convert(source_type const& source) { + result_type result; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + uint32_t* h = reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x7120; + static constexpr uint32_t mask_for_elt_2 = 0x3654; - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x3120; - static constexpr uint32_t mask_for_elt_2 = 0x3120; - static constexpr uint32_t mask_for_elt_3 = 0x3120; - static constexpr uint32_t mask_for_elt_4 = 0x3120; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[2]) : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[3]) : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); + } + return result; + } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 16; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } + CUTLASS_DEVICE + static result_type convert(source_type const& 
source) { + result_type result; - return result; + uint32_t* h = reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); - } + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x3120; + static constexpr uint32_t mask_for_elt_2 = 0x3120; + static constexpr uint32_t mask_for_elt_3 = 0x3120; + static constexpr uint32_t mask_for_elt_4 = 0x3120; - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[2]) + : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[3]) + : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); + + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 16; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; -template +template struct FastInterleavedAndBiasedNumericArrayConverter { - static constexpr int VEC_WIDTH = 4; - static_assert(N == 32,"N must be 32"); - static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int32_t i = 0; i < N / 
VEC_WIDTH; ++i) { - result_ptr[i] = convert_vector_((source_ptr)[i]); - } - Array temp_rearrange_array; - auto lane_idx_div4 = threadIdx.x%4; - - CUTLASS_PRAGMA_UNROLL - for (int32_t i=0;i< N/8;++i){ - uint32_t* temp_rearrange_array_ptr = reinterpret_cast(&temp_rearrange_array); - uint32_t* result_reg_ptr = reinterpret_cast(result_ptr)+i * 2; - temp_rearrange_array_ptr[0] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[0],3); - temp_rearrange_array_ptr[1] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[1],3); - if( lane_idx_div4==1 || lane_idx_div4==2 ){ - result_reg_ptr[0]=temp_rearrange_array_ptr[0]; - result_reg_ptr[1]=temp_rearrange_array_ptr[1]; - } - temp_rearrange_array_ptr[0] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[0],2); - temp_rearrange_array_ptr[1] = __shfl_xor_sync(0xFFFFFFFF,reinterpret_cast(result_reg_ptr)[1],2); - if(lane_idx_div4<2){ - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[0]) : "r"(result_reg_ptr[0]), "r"(temp_rearrange_array_ptr[0]), "n"(0x5410)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[1]) : "r"(result_reg_ptr[1]), "r"(temp_rearrange_array_ptr[1]), "n"(0x5410)); - } - else{ - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[0]) : "r"(result_reg_ptr[0]), "r"(temp_rearrange_array_ptr[0]), "n"(0x3276)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(result_reg_ptr[1]) : "r"(result_reg_ptr[1]), "r"(temp_rearrange_array_ptr[1]), "n"(0x3276)); - } - } - - return result; + static constexpr int VEC_WIDTH = 4; + static_assert(N == 32, "N must be 32"); + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type 
result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int32_t i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_((source_ptr)[i]); } + Array temp_rearrange_array; + auto lane_idx_div4 = threadIdx.x % 4; - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + CUTLASS_PRAGMA_UNROLL + for (int32_t i = 0; i < N / 8; ++i) { + uint32_t* temp_rearrange_array_ptr = + reinterpret_cast(&temp_rearrange_array); + uint32_t* result_reg_ptr = + reinterpret_cast(result_ptr) + i * 2; + temp_rearrange_array_ptr[0] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[0], 3); + temp_rearrange_array_ptr[1] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[1], 3); + if (lane_idx_div4 == 1 || lane_idx_div4 == 2) { + result_reg_ptr[0] = temp_rearrange_array_ptr[0]; + result_reg_ptr[1] = temp_rearrange_array_ptr[1]; + } + temp_rearrange_array_ptr[0] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[0], 2); + temp_rearrange_array_ptr[1] = __shfl_xor_sync( + 0xFFFFFFFF, reinterpret_cast(result_reg_ptr)[1], 2); + if (lane_idx_div4 < 2) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[0]) + : "r"(result_reg_ptr[0]), + "r"(temp_rearrange_array_ptr[0]), + "n"(0x5410)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[1]) + : "r"(result_reg_ptr[1]), + "r"(temp_rearrange_array_ptr[1]), + "n"(0x5410)); + } else { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[0]) + : "r"(result_reg_ptr[0]), + "r"(temp_rearrange_array_ptr[0]), + "n"(0x3276)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_reg_ptr[1]) + : "r"(result_reg_ptr[1]), + "r"(temp_rearrange_array_ptr[1]), + "n"(0x3276)); + } } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return 
convert(s); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h index b21b204c85f..c4f03583cef 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/interleaved_numeric_conversion_nf4.h @@ -104,8 +104,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -128,7 +128,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { }; template <> -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using result_type = Array; using source_type = Array; @@ -179,7 +181,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 }; template -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { static constexpr int VEC_WIDTH = 4; static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); @@ -191,8 +195,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -314,8 +318,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; 
FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -338,7 +342,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 { }; template <> -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { using result_type = Array; using source_type = Array; @@ -400,7 +406,9 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 }; template -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { static constexpr int VEC_WIDTH = 8; static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); @@ -412,8 +420,8 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverterNf4 + scalar_source_type, + VEC_WIDTH> convert_vector_; result_type result; @@ -435,388 +443,405 @@ struct FastInterleavedAndBiasedNumericArrayConverterNf4 result_type operator()(source_type const& s) { return convert(s); } }; -template<> +template <> struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt = 0x3120; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); - - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 4; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; - } + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return 
convert(s); - } -}; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt = 0x3120; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "r"(i8s), "n"(mask_for_elt)); -template<> -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); - - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x7120; - static constexpr uint32_t mask_for_elt_2 = 0x3654; - - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); - - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 8; i++){ - result[i] = static_cast(static_cast(tmp[i]) - 128); - } - return result; + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 4; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } + return result; + } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; +template <> +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { + using result_type = Array; + using source_type = Array; -template<> -struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - using result_type = Array; - using source_type = Array; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = 
reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); + + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x7120; + static constexpr uint32_t mask_for_elt_2 = 0x3654; + + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[0]), "r"(i8s[1]), "n"(mask_for_elt_2)); - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); + } + return result; + } - uint32_t* h = reinterpret_cast(&result); - const uint32_t * i8s = reinterpret_cast(&source); + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; - // 3 2 1 0 -> 3 1 2 0 - static constexpr uint32_t mask_for_elt_1 = 0x3120; - static constexpr uint32_t mask_for_elt_2 = 0x3120; - static constexpr uint32_t mask_for_elt_3 = 0x3120; - static constexpr uint32_t mask_for_elt_4 = 0x3120; +template <> +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { + using result_type = Array; + using source_type = Array; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[2]) : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[3]) : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + uint32_t* h = reinterpret_cast(&result); + const uint32_t* i8s = reinterpret_cast(&source); - // Author zhengzekang - uint8_t* tmp = reinterpret_cast(&result); - #pragma unroll - for(int i = 0; i < 16; i++){ - 
result[i] = static_cast(static_cast(tmp[i]) - 128); - } + // 3 2 1 0 -> 3 1 2 0 + static constexpr uint32_t mask_for_elt_1 = 0x3120; + static constexpr uint32_t mask_for_elt_2 = 0x3120; + static constexpr uint32_t mask_for_elt_3 = 0x3120; + static constexpr uint32_t mask_for_elt_4 = 0x3120; - return result; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s[0]), "r"(i8s[2]), "n"(mask_for_elt_1)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s[1]), "r"(i8s[3]), "n"(mask_for_elt_2)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[2]) + : "r"(i8s[2]), "r"(i8s[0]), "n"(mask_for_elt_3)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[3]) + : "r"(i8s[3]), "r"(i8s[1]), "n"(mask_for_elt_4)); + // Author zhengzekang + uint8_t* tmp = reinterpret_cast(&result); +#pragma unroll + for (int i = 0; i < 16; i++) { + result[i] = static_cast(static_cast(tmp[i]) - 128); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; -template +template struct FastInterleavedAndBiasedNumericArrayConverterNf4 { - static constexpr int VEC_WIDTH = 4; - static_assert(N == 32,"N must be 32"); - static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); + static constexpr int VEC_WIDTH = 4; + static_assert(N == 32, "N must be 32"); + static_assert(!(N % VEC_WIDTH), "N must be multiple of 16."); - using result_type = Array; - using source_type = Array; + using result_type = Array; + using source_type = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - return source; - } + CUTLASS_DEVICE + static result_type convert(source_type const& source) { return source; } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } }; -template -struct 
FastInterleavedAndBiasedNumericArrayConverterNf4 { - static constexpr int VEC_WIDTH = 8; - // static_assert(N == 64,"N must be 64"); - static_assert(!(N % VEC_WIDTH), "N must be multiple of VEC_WIDTH."); - - using result_type = Array; - using source_type = Array; - using vec_source = Array; - using vec_result = Array; - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - //nf4 - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - result_type result; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - vec_result index; - uint32_t* result_i_ptr = reinterpret_cast(&(result_ptr[i])); - uint32_t const i4s = reinterpret_cast(source_ptr[i]); - static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; - static constexpr uint32_t immLut_0 = 0x40; - static constexpr uint32_t immLut_1 = 0x80; - asm volatile( - "lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(result_i_ptr[1]) - : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_1)); - asm volatile( - "lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(result_i_ptr[0]) - : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_0)); - result_i_ptr[0]=(result_i_ptr[0]<<4); - } - return result; +template +struct FastInterleavedAndBiasedNumericArrayConverterNf4 { + static constexpr int VEC_WIDTH = 8; + // static_assert(N == 64,"N must be 64"); + static_assert(!(N % VEC_WIDTH), "N must be multiple of VEC_WIDTH."); + + using result_type = Array; + using source_type = Array; + using vec_source = Array; + using vec_result = Array; + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + // nf4 + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + result_type result; + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = 
reinterpret_cast(&source); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + vec_result index; + uint32_t* result_i_ptr = reinterpret_cast(&(result_ptr[i])); + uint32_t const i4s = reinterpret_cast(source_ptr[i]); + static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; + static constexpr uint32_t immLut_0 = 0x40; + static constexpr uint32_t immLut_1 = 0x80; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(result_i_ptr[1]) + : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_1)); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(result_i_ptr[0]) + : "r"(i4s), "r"(i4s), "n"(up_int4_mask), "n"(immLut_0)); + result_i_ptr[0] = (result_i_ptr[0] << 4); } + return result; + } - CUTLASS_DEVICE - static result_type convert(source_type const& source, int32_t* shared_look_up_table) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - result_type result; - // static constexpr uint32_t loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - // FastInterleavedAndBiasedNumericArrayConverterNf4 - // convert_vector_; - // static constexpr Array loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; - for (int i = 0; i < N / VEC_WIDTH; ++i) { - vec_result index; - uint32_t* h = reinterpret_cast(&index); - // const int8_t* loop_up_table_int8 = reinterpret_cast(&loop_up_table); - uint32_t const i4s = reinterpret_cast(source_ptr[i]); - static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; - static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; - h[0]=i4s&down_int4_mask; - h[1]=i4s&up_int4_mask; - h[1]=h[1]>>4; - - //TODO(wangbojun)!!!! 
do nf4 lookup table - CUTLASS_PRAGMA_UNROLL - for(int ii=0; ii(index[ii]); - // } - - // CUTLASS_PRAGMA_UNROLL - // for(int ii=0; ii(index[0]), - // static_cast(index[1]), - // static_cast(index[2]), - // static_cast(index[3]), - // static_cast(index[4]), - // static_cast(index[5]), - // static_cast(index[6]), - // static_cast(index[7]), - // static_cast(result_ptr[i][0]), - // static_cast(result_ptr[i][1]), - // static_cast(result_ptr[i][2]), - // static_cast(result_ptr[i][3]), - // static_cast(result_ptr[i][4]), - // static_cast(result_ptr[i][5]), - // static_cast(result_ptr[i][6]), - // static_cast(result_ptr[i][7]) - // ); - } - return result; + CUTLASS_DEVICE + static result_type convert(source_type const& source, + int32_t* shared_look_up_table) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + result_type result; + // static constexpr uint32_t + // loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + // FastInterleavedAndBiasedNumericArrayConverterNf4 + // convert_vector_; + // static constexpr Array + // loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; + for (int i = 0; i < N / VEC_WIDTH; ++i) { + vec_result index; + uint32_t* h = reinterpret_cast(&index); + // const int8_t* loop_up_table_int8 = reinterpret_cast(&loop_up_table); + uint32_t const i4s = reinterpret_cast(source_ptr[i]); + static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; + static constexpr uint32_t up_int4_mask = 0xf0f0f0f0; + h[0] = i4s & down_int4_mask; + h[1] = i4s & up_int4_mask; + h[1] = h[1] >> 4; + + // TODO(wangbojun)!!!! 
do nf4 lookup table + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < VEC_WIDTH; ++ii) { + result_ptr[i][ii] = shared_look_up_table[index[ii]]; + } + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0; ii(index[ii]); + // } + + // CUTLASS_PRAGMA_UNROLL + // for(int ii=0; ii(index[0]), + // static_cast(index[1]), + // static_cast(index[2]), + // static_cast(index[3]), + // static_cast(index[4]), + // static_cast(index[5]), + // static_cast(index[6]), + // static_cast(index[7]), + // static_cast(result_ptr[i][0]), + // static_cast(result_ptr[i][1]), + // static_cast(result_ptr[i][2]), + // static_cast(result_ptr[i][3]), + // static_cast(result_ptr[i][4]), + // static_cast(result_ptr[i][5]), + // static_cast(result_ptr[i][6]), + // static_cast(result_ptr[i][7]) + // ); } + return result; + } + // #define NF4_LUT_DEBUG + CUTLASS_DEVICE + static result_type convert( + source_type const& source, + cutlass::Array const& reg_look_up_table) { + // static_assert(VEC_WIDTH==16, "VEC_WIDTH == 16 for int8 int8 int32") + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + result_type result; + // static constexpr uint32_t + // loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + // FastInterleavedAndBiasedNumericArrayConverterNf4 + // convert_vector_; + // static constexpr Array + // loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + uint32_t bitwise_workspace; + uint32_t* result_reg = reinterpret_cast(&(result_ptr[i])); + uint32_t const i4s = reinterpret_cast(source_ptr[i]); + static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### lup 0:%08x, lup 1:%08x, lup2:%08x, lup3:%08x \n", + 
reg_look_up_table[0], + reg_look_up_table[1], + reg_look_up_table[2], + reg_look_up_table[3]); + printf("#### i4s:%08x \n", i4s); + } +#endif + bitwise_workspace = i4s & down_int4_mask; // 0x0v0v0v0v -// #define NF4_LUT_DEBUG - CUTLASS_DEVICE - static result_type convert(source_type const& source, cutlass::Array const& reg_look_up_table) - { - // static_assert(VEC_WIDTH==16, "VEC_WIDTH == 16 for int8 int8 int32") - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - result_type result; - // static constexpr uint32_t loop_up_table[4]{0x03020100,0x07060504,0x0B0A0908,0x0F0E0D0C}; - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - // FastInterleavedAndBiasedNumericArrayConverterNf4 - // convert_vector_; - // static constexpr Array loop_up_table{0x00010203,0x04050607,0x08090A0B,0x0C0D0E0F}; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) { - uint32_t bitwise_workspace; - uint32_t* result_reg = reinterpret_cast(&(result_ptr[i])); - uint32_t const i4s = reinterpret_cast(source_ptr[i]); - static constexpr uint32_t down_int4_mask = 0x0f0f0f0f; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### lup 0:%08x, lup 1:%08x, lup2:%08x, lup3:%08x \n", - reg_look_up_table[0], - reg_look_up_table[1], - reg_look_up_table[2], - reg_look_up_table[3]); - printf("#### i4s:%08x \n", i4s); - } - #endif - bitwise_workspace = i4s & down_int4_mask; // 0x0v0v0v0v - - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h0 0x0v0v0v0v:%08x \n", bitwise_workspace); - } - #endif - - uint32_t ge_7_mask_h_0 = (bitwise_workspace |bitwise_workspace << 4) & 0x88888888; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; - - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ 
- printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); - } - #endif - - // uint32_t look_up_h_0 = h[0]; - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & 0x00007777; // make h[i] into 0x0000vvvv - - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h0 0x0000vvvv:%08x \n", bitwise_workspace); - } - #endif - - uint32_t result_1=0; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(result_1) - : "r"(reg_look_up_table[0]), "r"(reg_look_up_table[1]), "r"(bitwise_workspace)); - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result_1:%08x \n", result_1); - } - #endif - result_reg[0] = result_1 & (~ge_7_mask_h_0); - result_1 = ((~result_1)) & (ge_7_mask_h_0); - // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror - result_reg[0] = result_reg[0] | result_1; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result[0]:%08x \n", result_reg[0]); - } - #endif - bitwise_workspace = i4s >> 4; - bitwise_workspace = bitwise_workspace & down_int4_mask; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h1 0x0v0v0v0v:%08x \n", bitwise_workspace); - } - #endif - ge_7_mask_h_0 = (bitwise_workspace | bitwise_workspace << 4) & 0x88888888; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; - ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); - } - #endif - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; - bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & 0x00007777; // make look_up_h_0 into 0x0000vvvv - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && 
blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### h1 0x0000vvvv:%08x \n", bitwise_workspace); - } - #endif - asm volatile("prmt.b32 %0,%1,%2,%3;\n" - : "=r"(result_1) - : "r"(reg_look_up_table[0]), "r"(reg_look_up_table[1]), "r"(bitwise_workspace)); - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result_1:%08x \n", result_1); - } - #endif - result_reg[1] = result_1 & (~ge_7_mask_h_0); - result_1 = ((~result_1)) & (ge_7_mask_h_0); - // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror - result_reg[1] = result_reg[1] | result_1; - #ifdef NF4_LUT_DEBUG - if(threadIdx.x==0 && blockIdx.x==0 && blockIdx.y==0 && blockIdx.z==0){ - printf("#### result[1]:%08x \n", result_reg[1]); - } - #endif - } - return result; - } +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h0 0x0v0v0v0v:%08x \n", bitwise_workspace); + } +#endif + uint32_t ge_7_mask_h_0 = + (bitwise_workspace | bitwise_workspace << 4) & 0x88888888; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, uint32_t shared_look_up_table[32][32], int32_t warp_idx, int32_t lane_idx) - { - return convert(s,shared_look_up_table,warp_idx, lane_idx); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, uint32_t shared_look_up_table[16]) - { - return convert(s,shared_look_up_table); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, int32_t* shared_look_up_table) - { - return convert(s,shared_look_up_table); - } - CUTLASS_DEVICE - result_type operator()(source_type const& s, cutlass::Array const& reg_look_up_table) - { - return convert(s,reg_look_up_table); +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + 
blockIdx.z == 0) { + printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); + } +#endif + + // uint32_t look_up_h_0 = h[0]; + bitwise_workspace = + (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; + bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & + 0x00007777; // make h[i] into 0x0000vvvv + +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h0 0x0000vvvv:%08x \n", bitwise_workspace); + } +#endif + + uint32_t result_1 = 0; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_1) + : "r"(reg_look_up_table[0]), + "r"(reg_look_up_table[1]), + "r"(bitwise_workspace)); +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result_1:%08x \n", result_1); + } +#endif + result_reg[0] = result_1 & (~ge_7_mask_h_0); + result_1 = ((~result_1)) & (ge_7_mask_h_0); + // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror + result_reg[0] = result_reg[0] | result_1; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result[0]:%08x \n", result_reg[0]); + } +#endif + bitwise_workspace = i4s >> 4; + bitwise_workspace = bitwise_workspace & down_int4_mask; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h1 0x0v0v0v0v:%08x \n", bitwise_workspace); + } +#endif + ge_7_mask_h_0 = (bitwise_workspace | bitwise_workspace << 4) & 0x88888888; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 1; + ge_7_mask_h_0 |= ge_7_mask_h_0 >> 2; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### ge_7_mask_h_0:%08x \n", ge_7_mask_h_0); + } +#endif + bitwise_workspace = + (bitwise_workspace | (bitwise_workspace >> 4)) & 0x00FF00FF; + bitwise_workspace = (bitwise_workspace | (bitwise_workspace >> 8)) & + 0x00007777; // make 
look_up_h_0 into 0x0000vvvv +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### h1 0x0000vvvv:%08x \n", bitwise_workspace); + } +#endif + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(result_1) + : "r"(reg_look_up_table[0]), + "r"(reg_look_up_table[1]), + "r"(bitwise_workspace)); +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result_1:%08x \n", result_1); + } +#endif + result_reg[1] = result_1 & (~ge_7_mask_h_0); + result_1 = ((~result_1)) & (ge_7_mask_h_0); + // result_1 = ((~result_1) | 0x01010101) & (ge_7_mask_h_0); // half mirror + result_reg[1] = result_reg[1] | result_1; +#ifdef NF4_LUT_DEBUG + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && + blockIdx.z == 0) { + printf("#### result[1]:%08x \n", result_reg[1]); + } +#endif } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } + CUTLASS_DEVICE + result_type operator()(source_type const& s, + uint32_t shared_look_up_table[32][32], + int32_t warp_idx, + int32_t lane_idx) { + return convert(s, shared_look_up_table, warp_idx, lane_idx); + } + CUTLASS_DEVICE + result_type operator()(source_type const& s, + uint32_t shared_look_up_table[16]) { + return convert(s, shared_look_up_table); + } + CUTLASS_DEVICE + result_type operator()(source_type const& s, int32_t* shared_look_up_table) { + return convert(s, shared_look_up_table); + } + CUTLASS_DEVICE + result_type operator()( + source_type const& s, + cutlass::Array const& reg_look_up_table) { + return convert(s, reg_look_up_table); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h index 
89152360f1d..40265273e05 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_extensions/tile_interleaved_layout.h @@ -40,20 +40,20 @@ limitations under the License. */ namespace cutlass { namespace layout { -template +template class ColumnMajorTileInterleave { - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; }; -template +template struct IsColumnMajorTileInterleave { - static constexpr bool value = false; + static constexpr bool value = false; }; -template +template struct IsColumnMajorTileInterleave> { - static constexpr bool value = true; + static constexpr bool value = true; }; } // namespace layout diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h index 096a857de60..5fa4d7c98d2 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/cutlass_heuristic_w4a4.h @@ -33,7 +33,6 @@ limitations under the License. 
*/ #include "glog/logging.h" #include "w4a4_gemm_configs.h" - static TileShape get_cta_shape_for_config_w4a4(CutlassTileConfig tile_config) { switch (tile_config) { case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: @@ -58,12 +57,12 @@ static TileShape get_cta_shape_for_config_w4a4(CutlassTileConfig tile_config) { } static bool is_valid_split_k_factor_w4a4(const int64_t m, - const int64_t n, - const int64_t k, - const TileShape tile_shape, - const int split_k_factor, - const size_t workspace_bytes, - const bool is_weight_only) { + const int64_t n, + const int64_t k, + const TileShape tile_shape, + const int split_k_factor, + const size_t workspace_bytes, + const bool is_weight_only) { // All tile sizes have a k_tile of 64. static constexpr int k_tile = 64; @@ -125,7 +124,7 @@ static std::vector get_candidate_tiles_w4a4( CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64 // CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, // CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, - }; + }; std::vector quant_B_configs; switch (sm) { case 90: @@ -148,7 +147,6 @@ static std::vector get_candidate_tiles_w4a4( return simt_configs_only ? simt_configs : allowed_configs; } - static std::vector get_candidate_configs_nf4( int sm, const bool is_weight_only, @@ -171,7 +169,6 @@ static std::vector get_candidate_configs_nf4( return candidate_configs; } - static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( const std::vector& candidate_configs, const std::vector& occupancies, @@ -189,7 +186,7 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( "candidate configs vectors must have equal length."); } - VLOG(1)<<"estimate_best_config_from_occupancies_w4a4"; + VLOG(1) << "estimate_best_config_from_occupancies_w4a4"; CutlassGemmConfig best_config; // Score will be [0, 1]. The objective is to minimize this score. // It represents the fraction of SM resources unused in the last wave. 
@@ -198,25 +195,25 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( int current_m_tile = 0; { - VLOG(1)<<"######## begin of cutlass gemm search"; - if (m >= 256 && - std::find_if( - candidate_configs.begin(), - candidate_configs.end(), - [](const CutlassGemmConfig& gemm_config) { - return gemm_config.tile_config == - CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64; - }) != candidate_configs.end()) { - VLOG(1) << "m >= 256, encoder config"; - best_config = CutlassGemmConfig{ - CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, - SplitKStyle::SPLIT_K_STREAM, - // SplitKStyle::NO_SPLIT_K, - 1, - 3}; - - } else { - VLOG(1) << "m <= 64 , decoder config"; + VLOG(1) << "######## begin of cutlass gemm search"; + if (m >= 256 && + std::find_if( + candidate_configs.begin(), + candidate_configs.end(), + [](const CutlassGemmConfig& gemm_config) { + return gemm_config.tile_config == + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64; + }) != candidate_configs.end()) { + VLOG(1) << "m >= 256, encoder config"; + best_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + SplitKStyle::SPLIT_K_STREAM, + // SplitKStyle::NO_SPLIT_K, + 1, + 3}; + + } else { + VLOG(1) << "m <= 64 , decoder config"; const int max_split_k = n >= multi_processor_count * 256 ? 
1 : split_k_limit; for (int ii = 0; ii < candidate_configs.size(); ++ii) { @@ -240,14 +237,14 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; for (int split_k_factor = 1; split_k_factor <= max_split_k; - ++split_k_factor) { + ++split_k_factor) { if (is_valid_split_k_factor_w4a4(m, - n, - k, - tile_shape, - split_k_factor, - workspace_bytes, - is_weight_only)) { + n, + k, + tile_shape, + split_k_factor, + workspace_bytes, + is_weight_only)) { const int ctas_per_wave = occupancy * multi_processor_count; const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; @@ -262,7 +259,7 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( const float score_slack = 0.1f; if (current_score < config_score || ((config_waves > num_waves_total) && - (current_score < config_score + score_slack))) { + (current_score < config_score + score_slack))) { config_score = current_score; config_waves = num_waves_total; SplitKStyle split_style = split_k_factor > 1 @@ -273,9 +270,10 @@ static CutlassGemmConfig estimate_best_config_from_occupancies_w4a4( split_k_factor, candidate_config.stages}; current_m_tile = tile_shape.m; - // std::cout<<"#### split-k factor: "<::type>(candidate_config.tile_config)<::type>(candidate_config.tile_config)<::type>(best_config.tile_config); + VLOG(1) << "#### best split-k factor: " << best_config.split_k_factor + << " config: " + << static_cast::type>( + best_config.tile_config); return best_config; } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h index 790ec3b17bc..00c338ba977 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a4_gemm_configs.h @@ -33,44 +33,44 @@ limitations under the License. */ // Note: The shapes are in the format MxNxK. 
The K shape of the runtime config enum class CutlassTileConfig { // Signals that we should run heuristics do choose a config - Undefined, // 0 + Undefined, // 0 // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, // 1 + ChooseWithHeuristic, // 1 // SiMT config - CtaShape128x128x8_WarpShape64x64x8, // 2 + CtaShape128x128x8_WarpShape64x64x8, // 2 // TensorCore configs CTA_N = 128, CTA_K = 64 // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, // 3 - CtaShape16x256x64_WarpShape16x64x64, // 4 + CtaShape16x128x64_WarpShape16x32x64, // 3 + CtaShape16x256x64_WarpShape16x64x64, // 4 // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, // 5 + CtaShape32x128x64_WarpShape32x32x64, // 5 // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, // 6 - CtaShape64x128x64_WarpShape64x32x64, // 7 + CtaShape64x128x64_WarpShape32x64x64, // 6 + CtaShape64x128x64_WarpShape64x32x64, // 7 // Warp configs for M=128 - CtaShape128x128x64_WarpShape64x32x64, // 8 - CtaShape128x128x64_WarpShape128x32x64, // 9 + CtaShape128x128x64_WarpShape64x32x64, // 8 + CtaShape128x128x64_WarpShape128x32x64, // 9 // configs for large M in encoder - CtaShape128x256x64_WarpShape64x64x64, // 10 - CtaShape256x128x64_WarpShape64x64x64, // 11 + CtaShape128x256x64_WarpShape64x64x64, // 10 + CtaShape256x128x64_WarpShape64x64x64, // 11 - CtaShape32x256x64_WarpShape32x64x64, // 12 - CtaShape64x256x64_WarpShape64x64x64, // 13 - CtaShape128x256x64_WarpShape128x64x64, // 14 - CtaShape32x512x64_WarpShape32x128x64, // 15 + CtaShape32x256x64_WarpShape32x64x64, // 12 + CtaShape64x256x64_WarpShape64x64x64, // 13 + CtaShape128x256x64_WarpShape128x64x64, // 14 + CtaShape32x512x64_WarpShape32x128x64, // 15 }; enum class SplitKStyle { - NO_SPLIT_K, //0 - SPLIT_K_SERIAL, //1 - SPLIT_K_STREAM, //2 + NO_SPLIT_K, // 0 + SPLIT_K_SERIAL, // 1 + SPLIT_K_STREAM, // 2 // SPLIT_K_PARALLEL // Not supported yet }; diff --git 
a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h index f871cb1d8e7..ae90627fc75 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_gemm_grouped.h @@ -1,12 +1,12 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,22 +18,23 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file - \brief The universal GEMM accommodates streamk, batched strided, and batched array variants. + \brief The universal GEMM accommodates streamk, batched strided, and batched + array variants. 
*/ - #pragma once #include @@ -59,11 +60,9 @@ namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// - template class W4A8MoeGemmUniversalBase { -public: - + public: using GemmKernel = GemmKernel_; using ThreadblockShape = typename GemmKernel::Mma::Shape; @@ -92,8 +91,7 @@ class W4A8MoeGemmUniversalBase { /// Argument structure using Arguments = typename GemmKernel::Arguments; -protected: - + protected: // // Device properties (uniform across all instances of the current thread) // @@ -112,8 +110,7 @@ class W4A8MoeGemmUniversalBase { /// Initialize static thread-local members for the thread's current device, /// if necessary. - static Status init_device_props() - { + static Status init_device_props() { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::init_device_props()"); cudaError_t cudart_result; @@ -122,7 +119,8 @@ class W4A8MoeGemmUniversalBase { int current_ordinal; cudart_result = cudaGetDevice(¤t_ordinal); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } @@ -133,64 +131,80 @@ class W4A8MoeGemmUniversalBase { } // Update SM count member - cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); + cudart_result = cudaDeviceGetAttribute( + &device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } - // Update the kernel function's shared memory configuration for the current device + // Update the kernel function's shared memory configuration for the current + // device 
smem_size_ = int(sizeof(typename GemmKernel::SharedStorage)); // If requires more than 48KB: configure for extended, dynamic shared memory - if (smem_size_ >= (48 << 10)) - { - cudart_result = cudaFuncSetAttribute( - Kernel2, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size_); + if (smem_size_ >= (48 << 10)) { + cudart_result = + cudaFuncSetAttribute(Kernel2, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size_); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } - cudart_result = cudaFuncSetAttribute( - Kernel2, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); // 100% shared memory + cudart_result = + cudaFuncSetAttribute(Kernel2, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100); // 100% shared memory if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } } // Update SM occupancy member cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( - &sm_occupancy_, - Kernel2, - GemmKernel::kThreadCount, - smem_size_, - cudaOccupancyDisableCachingOverride); + &sm_occupancy_, + Kernel2, + GemmKernel::kThreadCount, + smem_size_, + cudaOccupancyDisableCachingOverride); if (cudart_result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned " + "error " + << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } // Update device ordinal member on success device_ordinal_ = current_ordinal; - 
CUTLASS_TRACE_HOST(" " - "device_ordinal: (" << device_ordinal_ << "), " - "device_sms: (" << device_sms_ << "), " - "sm_occupancy: (" << sm_occupancy_ << ") " - "smem_size: (" << smem_size_ << ") " - "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")"); + CUTLASS_TRACE_HOST( + " " + "device_ordinal: (" + << device_ordinal_ + << "), " + "device_sms: (" + << device_sms_ + << "), " + "sm_occupancy: (" + << sm_occupancy_ + << ") " + "smem_size: (" + << smem_size_ + << ") " + "GemmKernel::kThreadCount: (" + << GemmKernel::kThreadCount << ")"); return Status::kSuccess; } - -protected: - + protected: // // Instance data members // @@ -198,10 +212,8 @@ class W4A8MoeGemmUniversalBase { /// Kernel parameters typename GemmKernel::Params params_; - /// Initialize params member - Status init_params(Arguments const &args) - { + Status init_params(Arguments const &args) { // Initialize static device properties, if necessary Status result = init_device_props(); if (result != Status::kSuccess) { @@ -213,40 +225,31 @@ class W4A8MoeGemmUniversalBase { return Status::kSuccess; } -public: - + public: //--------------------------------------------------------------------------------------------- // Stateless API //--------------------------------------------------------------------------------------------- /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const &args) - { + static Status can_implement(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::can_implement()"); - // printf("--1\n"); // Initialize static kernel and device properties, if necessary. 
Status result = init_device_props(); - // printf("--1-2\n"); if (result != Status::kSuccess) { return result; } - // printf("--2\n"); dim3 grid = get_grid_shape(args); // printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z); if (!(grid.y <= std::numeric_limits::max() && - grid.z <= std::numeric_limits::max())) - { + grid.z <= std::numeric_limits::max())) { return Status::kErrorInvalidProblem; } - // printf("--3\n"); return GemmKernel::can_implement(args); } - /// Returns the workspace size (in bytes) needed for the problem /// geometry expressed by these arguments - static size_t get_workspace_size(Arguments const &args) - { + static size_t get_workspace_size(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::get_workspace_size()"); // Initialize parameters from args @@ -262,56 +265,80 @@ class W4A8MoeGemmUniversalBase { return workspace_bytes; } - /// Returns the grid extents in thread blocks to launch - static dim3 get_grid_shape(Arguments const &args) - { + static dim3 get_grid_shape(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::get_grid_shape()"); // Initialize parameters from args W4A8MoeGemmUniversalBase base; if (base.init_params(args) != Status::kSuccess) { - return dim3(0,0,0); + return dim3(0, 0, 0); } // Get dims from parameters dim3 grid_dims = base.params_.get_grid_dims(); - CUTLASS_TRACE_HOST( - " tiled_shape: " << base.params_.get_tiled_shape() << "\n" - << " grid_dims: {" << grid_dims << "}"); + CUTLASS_TRACE_HOST(" tiled_shape: " + << base.params_.get_tiled_shape() << "\n" + << " grid_dims: {" << grid_dims << "}"); return grid_dims; } - /// Returns the maximum number of active thread blocks per multiprocessor - static int maximum_active_blocks() - { + static int maximum_active_blocks(int smem_capacity = -1) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::maximum_active_blocks()"); - // Initialize static device properties, if necessary - if (init_device_props() != Status::kSuccess) { + int smem_size = 
int(sizeof(typename GemmKernel_::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel2, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + Kernel2, + GemmKernel_::kThreadCount, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); return -1; } - CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); - return sm_occupancy_; + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; } - //--------------------------------------------------------------------------------------------- // Stateful API //--------------------------------------------------------------------------------------------- /// Initializes GEMM state from arguments and workspace memory - Status initialize( - Arguments const &args, - void *workspace, - cudaStream_t stream = nullptr) - { + Status initialize(Arguments const &args, + void *workspace, + cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + << workspace + << ", stream: " << (stream ? 
"non-null" : "null")); // Initialize parameters from args Status result = init_params(args); @@ -323,59 +350,54 @@ class W4A8MoeGemmUniversalBase { return params_.init_workspace(workspace, stream); } - - /// Lightweight update given a subset of arguments. Problem geometry is assumed to - /// remain the same. - Status update(Arguments const &args) - { + /// Lightweight update given a subset of arguments. Problem geometry is + /// assumed to remain the same. + Status update(Arguments const &args) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase()::update()"); params_.update(args); return Status::kSuccess; } - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { + Status run(cudaStream_t stream = nullptr) { CUTLASS_TRACE_HOST("W4A8MoeGemmUniversalBase::run()"); // Configure grid and block dimensions dim3 block(GemmKernel::kThreadCount, 1, 1); - // dim3 grid = params_.get_grid_dims(); - dim3 grid(216, 1, 1); + dim3 grid(params_.threadblock_count, 1, 1); // Launch kernel - CUTLASS_TRACE_HOST(" " - "grid: (" << grid << "), " - "block: (" << block << "), " - "SMEM: (" << smem_size_ << ")"); + CUTLASS_TRACE_HOST( + " " + "grid: (" + << grid + << "), " + "block: (" + << block + << "), " + "SMEM: (" + << smem_size_ << ")"); Kernel2<<>>(params_); // Query for errors cudaError_t result = cudaGetLastError(); if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); return Status::kErrorInternal; } return Status::kSuccess; } - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); - } - + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } /// Runs the kernel using initialized state. 
- Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) - { + Status operator()(Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { @@ -386,7 +408,6 @@ class W4A8MoeGemmUniversalBase { } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Static initializers ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -407,12 +428,10 @@ thread_local int W4A8MoeGemmUniversalBase::sm_occupancy_ = -1; template thread_local int W4A8MoeGemmUniversalBase::smem_size_ = -1; - - ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h index dbcc8912f2d..18cdeca89cb 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel.h @@ -37,60 +37,65 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct MoeW4A8Gemm { - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; static bool const kSplitKSerial = SplitKSerial; static GroupScheduleMode const 
kGroupScheduleMode = GroupScheduleMode_; static bool const kTransposed = false; - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; using TensorRefA = TensorRef; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; using TensorRefB = TensorRef; using LayoutAlphaCol = cutlass::layout::RowMajor; using LayoutAlphaRow = cutlass::layout::ColumnMajor; using TensorRefNf4LookUpTable = TensorRef; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; using TensorRefC = TensorRef; static ComplexTransform const kTransformA = Mma::kTransformA; static ComplexTransform const kTransformB = Mma::kTransformA; // Type definitions about the mainloop. 
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; + using WarpShape = typename Mma::Operator::Shape; using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + using ArchTag = typename Mma::ArchTag; - static int const kStages = Mma::kStages; + static int const kStages = Mma::kStages; static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentD = Epilogue::OutputTileIterator::kElementsPerAccess; + static int const kAlignmentD = + Epilogue::OutputTileIterator::kElementsPerAccess; /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; + using WarpCount = typename Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; - static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; using ProblemVisitor = GemmMoeProblemVisitor; - /// Argument structure struct Arguments : UniversalArgumentsBase { // @@ -133,8 +137,7 @@ struct MoeW4A8Gemm { Arguments() {} /// constructs an arguments structure - Arguments( - cutlass::gemm::GemmUniversalMode mode_, + Arguments(cutlass::gemm::GemmUniversalMode mode_, GemmCoord problem_size_, int problem_count, int batch_count_, @@ -170,7 +173,7 @@ struct MoeW4A8Gemm { host_problem_sizes(nullptr) {} }; - /// Parameters structure + /// Parameters structure struct Params : 
UniversalParamsBase::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } else if (platform::is_same::value) { - isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value || - platform::is_same>::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } + // + // Methods + // + + CUTLASS_HOST_DEVICE + MoeW4A8Gemm() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST( + "GemmWithEpilogueVisitorInterleavedNf4::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } - if (platform::is_same::value) { - isBMisaligned = problem_size.n() % kAlignmentB; - } else if (platform::is_same::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value || - platform::is_same>::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } - if (platform::is_same::value) { - // isCMisaligned = problem_size.n() % kAlignmentC; - } else if (platform::is_same::value) { - // isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value || - platform::is_same>::value) 
{ - // isCMisaligned = problem_size.n() % kAlignmentC; - } + if (platform::is_same::value) { + // isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + // isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || + platform::is_same>::value) { + // isCMisaligned = problem_size.n() % kAlignmentC; + } - if (isAMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } - if (isBMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } - if (isCMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } - CUTLASS_TRACE_HOST(" returning kSuccess"); + CUTLASS_TRACE_HOST(" returning kSuccess"); - return Status::kSuccess; - } + return Status::kSuccess; + } static Status can_implement(Arguments const& args) { return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - + static size_t get_extra_workspace_size( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } - CUTLASS_DEVICE - static void invoke( - Params const ¶ms, - SharedStorage &shared_storage) - { - MoeW4A8Gemm op; - op(params, shared_storage); - } + CUTLASS_DEVICE + static void invoke(Params const& params, SharedStorage& shared_storage) { + 
MoeW4A8Gemm op; + op(params, shared_storage); + } #define SPLIT_K_ENABLED 1 /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const& params, SharedStorage& shared_storage) { - using ElementA = typename Mma::IteratorA::Element; using LayoutA = typename Mma::IteratorA::Layout; using ElementB = typename Mma::IteratorB::Element; @@ -340,12 +334,11 @@ struct MoeW4A8Gemm { static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert( - platform::is_same::value && - kInterleave == 1 || - platform::is_same::value && - kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); + static_assert(platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); // // Problem visitor. @@ -357,190 +350,187 @@ struct MoeW4A8Gemm { int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) { + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { // // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - cutlass::gemm::GemmCoord threadblock_offset( - int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT - int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT - 0); - - // Load element pointers. Exchange pointers and strides if working on - // the transpose - const int64_t rows_to_jump = - problem_idx == 0 - ? 
0 - : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - ElementA* ptr_A = - reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; - - char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT - problem_idx * bytes_per_expert_matrix; // NOLINT - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value - ? gemm_n - : gemm_k * kInterleave; + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT + 0); + + // Load element pointers. Exchange pointers and strides if working on + // the transpose + const int64_t rows_to_jump = + problem_idx == 0 + ? 0 + : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = + reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT + problem_idx * bytes_per_expert_matrix; // NOLINT + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value + ? 
gemm_n + : gemm_k * kInterleave; int offset_k = 0; int problem_size_k = params.problem_size.k(); - // Compute initial location in logical coordinates + // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ threadblock_offset.m(), 0, }; - cutlass::MatrixCoord tb_offset_B{ - 0, - threadblock_offset.n() / kInterleave}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A); - - - typename Mma::IteratorB iterator_B( - params.params_B, - ptr_B, - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, - thread_idx, - tb_offset_B); - typename Mma::IteratorNF4LookUpTable iterator_nf4_look_up_table = - Mma::IteratorNF4LookUpTable( - params.params_nf4_look_up_table, - params.ref_nf4_look_up_table.data(), - {0,16}, - threadIdx.x, - {0,0} - ); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = - (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - // printf("#### gemm_k_iterations: %d \n", gemm_k_iterations); - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_nf4_look_up_table, accumulators); - // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // threadblock_tile_offset.m(), - // threadblock_tile_offset.n(), - // threadblock_tile_offset.k() - // ); - // } - // - // Masked tile iterators constructed from members - // - EpilogueOutputOp output_op(params.output_op); - - threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - ElementC* ptr_D = - reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - - - int block_idx = threadblock_tile_offset.m() + - threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - // Tile iterator writing to destination tensor. 
- typename Epilogue::OutputTileIterator iterator_D(params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset, - params.scatter_D_indices); - - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B); + typename Mma::IteratorNF4LookUpTable iterator_nf4_look_up_table = + Mma::IteratorNF4LookUpTable(params.params_nf4_look_up_table, + params.ref_nf4_look_up_table.data(), + {0, 16}, + threadIdx.x, + {0, 0}); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + // printf("#### gemm_k_iterations: %d \n", gemm_k_iterations); + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_nf4_look_up_table, + accumulators); + // if(threadIdx.x==0){ + // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // threadblock_tile_offset.m(), + // threadblock_tile_offset.n(), + // threadblock_tile_offset.k() + // ); + // } + // + // Masked tile iterators constructed from members + // + EpilogueOutputOp output_op(params.output_op); + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - semaphore.wait(threadblock_tile_offset.k()); - } + ElementC* ptr_D = + reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_D); + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - // - // Release the semaphore - // + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // If performing a reduction via split-K, fetch the initial + // synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + Epilogue epilogue( + shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + + semaphore.wait(threadblock_tile_offset.k()); + } - semaphore.release(lock); - } + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_D); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } - // Next tile - shared_storage.problem_visitor.advance(gridDim.x); + semaphore.release(lock); } + // Next tile + shared_storage.problem_visitor.advance(gridDim.x); + } } }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu index ede2178d273..cf02b86af19 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_cutlass_kernel_template.cu @@ -38,10 +38,9 @@ #include "cutlass_kernels/w4a8_moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor_interleaved_nf4.h" #include "w4a8_moe_gemm_with_epilogue_visitor.h" - template class IntegerType { - public: + public: static constexpr int value = val; }; @@ -76,17 +75,21 @@ void generic_w4a8_moe_gemm_kernelLauncher( int multi_processor_count, cudaStream_t stream, int* occupancy) { - if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K){ + if (gemm_config.split_k_style == SplitKStyle::NO_SPLIT_K) { static_assert(cutlass::platform::is_same::value, "input type must be int8_t"); - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. - // using OutputElementType_ = OutputType; - using OutputElementType_ = typename cutlass::platform::conditional::value, - cutlass::bfloat16_t, OutputType>::type; + // The cutlass type for the input elements. This is needed to convert to + // cutlass::half_t if necessary. 
using OutputElementType_ = OutputType; + using OutputElementType_ = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::bfloat16_t, + OutputType>::type; - using OutputElementType = typename cutlass::platform::conditional::value, - cutlass::half_t, OutputElementType_>::type; + using OutputElementType = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::half_t, + OutputElementType_>::type; using CutlassIntAType_ = IntAType; using CutlassIntAType = CutlassIntAType_; @@ -94,47 +97,55 @@ void generic_w4a8_moe_gemm_kernelLauncher( using CutlassIntBType_ = IntBType; using CutlassIntBType = CutlassIntBType_; - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. - - using MixedGemmArchTraits = cutlass::gemm::kernel:: - Int8Nf4GemmArchTraits; + // We need separate config for each architecture since we will target + // different tensorcore instructions. For float, we do not target TCs. 
- using ElementAccumulator = typename MixedGemmArchTraits::AccType; - using ElementCompute = float; + using MixedGemmArchTraits = + cutlass::gemm::kernel::Int8Nf4GemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + using ElementCompute = float; // ============== using EpilogueOp = - typename Epilogue::Op; - - using ThreadBlockSwizzle = typename cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultInt8InterleavedGemm< - CutlassIntAType, - cutlass::layout::RowMajor, - MixedGemmArchTraits::ElementsPerAccessA, - CutlassIntBType, - typename MixedGemmArchTraits::LayoutB, - MixedGemmArchTraits::ElementsPerAccessB, - OutputElementType, - cutlass::layout::RowMajor, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - arch, - ThreadblockShape, - WarpShape, - typename MixedGemmArchTraits::InstructionShape, - EpilogueOp, - ThreadBlockSwizzle, - Stages, - true, - typename MixedGemmArchTraits::Operator>::GemmKernel; - using GemmKernel = cutlass::gemm::kernel::MoeW4A8Gemm; + typename Epilogue::Op; + + using ThreadBlockSwizzle = typename cutlass::gemm::threadblock:: + GemmBatchedIdentityThreadblockSwizzle; + using GemmKernel_ = + typename cutlass::gemm::kernel::DefaultInt8InterleavedGemm< + CutlassIntAType, + cutlass::layout::RowMajor, + MixedGemmArchTraits::ElementsPerAccessA, + CutlassIntBType, + typename MixedGemmArchTraits::LayoutB, + MixedGemmArchTraits::ElementsPerAccessB, + OutputElementType, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + arch, + ThreadblockShape, + WarpShape, + typename MixedGemmArchTraits::InstructionShape, + EpilogueOp, + ThreadBlockSwizzle, + Stages, + true, + typename MixedGemmArchTraits::Operator>::GemmKernel; + using GemmKernel = cutlass::gemm::kernel::MoeW4A8Gemm< + typename GemmKernel_::Mma, + typename GemmKernel_::Epilogue, + typename GemmKernel_::ThreadblockSwizzle, + arch, // Ensure top level 
arch is used for dispatch + GemmKernel_::kSplitKSerial, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>; using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< @@ -147,38 +158,44 @@ void generic_w4a8_moe_gemm_kernelLauncher( cutlass::sizeof_bits::value>, OutputElementType>; - using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerColNf4< - ThreadblockShape, - GemmKernel::kThreadCount, - AlphaColTileIterator, - typename GemmKernel::Epilogue::OutputTileIterator, - ElementAccumulator, - ElementCompute, - EpilogueOp>; + using EpilogueVisitor = + typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerColNf4< + ThreadblockShape, + GemmKernel::kThreadCount, + AlphaColTileIterator, + typename GemmKernel::Epilogue::OutputTileIterator, + ElementAccumulator, + ElementCompute, + EpilogueOp>; /// Epilogue using Epilogue = typename cutlass::epilogue::threadblock:: - EpilogueWithVisitorFromExistingEpilogue::Epilogue; + EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename GemmKernel::Epilogue>::Epilogue; // GEMM using GemmWithEpilogueVisitorKernel = - cutlass::gemm::kernel::MoeW4A8GemmWithEpilogueVisitorInterleavedNf4; - + cutlass::gemm::kernel::MoeW4A8GemmWithEpilogueVisitorInterleavedNf4< + typename GemmKernel::Mma, + Epilogue, + ThreadBlockSwizzle, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>; if (occupancy != nullptr) { - *occupancy = compute_occupancy_for_kernel(); - return; + *occupancy = + compute_occupancy_for_kernel(); + return; } - using Gemm = cutlass::gemm::device::W4A8MoeGemmUniversalBase; + using Gemm = cutlass::gemm::device::W4A8MoeGemmUniversalBase< + GemmWithEpilogueVisitorKernel>; const int ldb = - cutlass::platform::is_same::value ? - n : - k * GemmKernel::kInterleave; + cutlass::platform::is_same::value + ? 
n + : k * GemmKernel::kInterleave; typename EpilogueOp::Params linear_scaling_params; @@ -192,64 +209,79 @@ void generic_w4a8_moe_gemm_kernelLauncher( const int threadblock_count = multi_processor_count * occupancy_; - typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kBatched, - num_experts, - threadblock_count, - {total_rows, n, k}, - 1, - {reinterpret_cast(const_cast(A)), k}, - {reinterpret_cast(const_cast(B)), ldb}, - quant_mode, - {reinterpret_cast(const_cast(col_scale)), 0}, - {reinterpret_cast(const_cast(row_scale)), 0}, - {const_cast(nf4_look_up_table), 0}, - {reinterpret_cast(C), n}, - {reinterpret_cast(C), n}, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - n, - k, - (int64_t)0, - (int64_t)0, - typename EpilogueVisitor::Arguments(linear_scaling_params, 0, 0, 0)}; - - // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of - // threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the - // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write - // our own predicated iterator in order to relax this limitation. 
- if (GemmKernel::kInterleave > 1 - && ((k % MixedGemmArchTraits::ThreadblockK) - || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) { - throw std::runtime_error("Temp assertion: k must be multiple of threadblockK"); + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kBatched, + num_experts, + threadblock_count, + {total_rows, n, k}, + 1, + {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), ldb}, + quant_mode, + {reinterpret_cast( + const_cast(col_scale)), + 0}, + {reinterpret_cast( + const_cast(row_scale)), + 0}, + {const_cast(nf4_look_up_table), 0}, + {reinterpret_cast(C), n}, + {reinterpret_cast(C), n}, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + n, + k, + (int64_t)0, + (int64_t)0, + typename EpilogueVisitor::Arguments(linear_scaling_params, 0, 0, 0)}; + + // This assertion is enabled because because for the column interleaved + // layout, K MUST be a multiple of threadblockK. The reason for this is that + // the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the + // interleaved layout. We need to write our own predicated iterator in order + // to relax this limitation. + if (GemmKernel::kInterleave > 1 && + ((k % MixedGemmArchTraits::ThreadblockK) || + ((k / gemm_config.split_k_factor) % + MixedGemmArchTraits::ThreadblockK))) { + throw std::runtime_error( + "Temp assertion: k must be multiple of threadblockK"); } Gemm gemm; if (gemm.get_workspace_size(args) > workspace_bytes) { - std::cout<< - "Requested split-k but workspace size insufficient. 
Falling back to non-split-k implementation."< -void dispatch_moe_gemm_to_cutlass( - const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - char* workspace_ptr, - const size_t workspace_bytes, - // int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) { +void dispatch_moe_gemm_to_cutlass(const IntAType* A, + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + // int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { // VLOG(1)<<__PRETTY_FUNCTION__; auto dispatch_by_tile = [&](auto ThreadblockShapeM, @@ -371,67 +400,67 @@ void dispatch_moe_gemm_to_cutlass( auto WarpShapeM, auto WarpShapeN, auto WarpShapeK) { - dispatch_gemm_config< - OutputType, - IntAType, - IntBType, - arch, - EpilogueTag, - cutlass::gemm::GemmShape, - cutlass::gemm::GemmShape> - (A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - workspace_ptr, - workspace_bytes, - multi_processor_count, - stream, - occupancy); + dispatch_gemm_config< + OutputType, + IntAType, + IntBType, + arch, + EpilogueTag, + cutlass::gemm::GemmShape, + cutlass::gemm::GemmShape>( + A, + B, + quant_mode, + 
col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + workspace_ptr, + workspace_bytes, + multi_processor_count, + stream, + occupancy); }; switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - dispatch_by_tile(Int<16>(), Int<64>(), Int<64>(), - Int<16>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<16>(), Int<64>(), Int<64>(), Int<16>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_by_tile(Int<32>(), Int<128>(), Int<64>(), - Int<32>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<32>(), Int<128>(), Int<64>(), Int<32>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_by_tile(Int<64>(), Int<128>(), Int<64>(), - Int<64>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<64>(), Int<128>(), Int<64>(), Int<64>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_by_tile(Int<128>(), Int<128>(), Int<64>(), - Int<128>(), Int<32>(), Int<64>()); + dispatch_by_tile( + Int<128>(), Int<128>(), Int<64>(), Int<128>(), Int<32>(), Int<64>()); break; case CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64: - dispatch_by_tile(Int<32>(), Int<512>(), Int<64>(), - Int<32>(), Int<128>(), Int<64>()); + dispatch_by_tile( + Int<32>(), Int<512>(), Int<64>(), Int<32>(), Int<128>(), Int<64>()); break; case CutlassTileConfig::CtaShape32x256x64_WarpShape32x64x64: - dispatch_by_tile(Int<32>(), Int<256>(), Int<64>(), - Int<32>(), Int<64>(), Int<64>()); + dispatch_by_tile( + Int<32>(), Int<256>(), Int<64>(), Int<32>(), Int<64>(), Int<64>()); break; case CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64: - dispatch_by_tile(Int<64>(), Int<256>(), Int<64>(), - Int<64>(), Int<64>(), Int<64>()); + dispatch_by_tile( + Int<64>(), Int<256>(), Int<64>(), 
Int<64>(), Int<64>(), Int<64>()); break; // case CutlassTileConfig::CtaShape128x256x64_WarpShape128x64x64: // dispatch_by_tile(Int<128>(), Int<256>(), Int<64>(), @@ -463,7 +492,6 @@ void dispatch_moe_gemm_to_cutlass( } } - template W4A8MoeGemmRunner::W4A8MoeGemmRunner() { int device{-1}; @@ -472,137 +500,155 @@ W4A8MoeGemmRunner::W4A8MoeGemmRunner() { // sm_ = 80; check_cuda_error(cudaDeviceGetAttribute( &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); - std::string FLAGS_cutlass_w4a8_moe_best_config=""; + std::string FLAGS_cutlass_w4a8_moe_best_config = ""; if (getenv("FLAGS_cutlass_w4a8_moe_best_config")) { - FLAGS_cutlass_w4a8_moe_best_config = getenv("FLAGS_cutlass_w4a8_moe_best_config"); + FLAGS_cutlass_w4a8_moe_best_config = + getenv("FLAGS_cutlass_w4a8_moe_best_config"); } - if(tuned_configs_from_file.empty() && FLAGS_cutlass_w4a8_moe_best_config!="") { + if (tuned_configs_from_file.empty() && + FLAGS_cutlass_w4a8_moe_best_config != "") { std::string config_file_path = FLAGS_cutlass_w4a8_moe_best_config; - if (config_file_path.find(".config")!=std::string::npos) { + if (config_file_path.find(".config") != std::string::npos) { std::ifstream config_file(FLAGS_cutlass_w4a8_moe_best_config); - if (config_file.is_open()) { - VLOG(1)<<"Get tuned w4a8 moe gemm config from: "< vec_configs; - while(std::getline(ss, item, ',')) { - try { - int value = std::stoi(item); - vec_configs.push_back(value); - } catch (const std::invalid_argument& e) { - std::cerr << "Invalid argument: " << item << " is not an integer." << std::endl; - return; - } catch (const std::out_of_range& e) { - std::cerr << "Out of range: " << item << " is out of the range of representable values." 
<< std::endl; - return; - } + if (config_file.is_open()) { + VLOG(1) << "Get tuned w4a8 moe gemm config from: " << config_file_path; + std::string config_string; + while (std::getline(config_file, config_string)) { + // decode one line of base64 string + config_string = base64_decode(config_string); + VLOG(1) << "decode config_string: " << config_string; + std::stringstream ss(config_string); + std::string item; + std::vector vec_configs; + while (std::getline(ss, item, ',')) { + try { + int value = std::stoi(item); + vec_configs.push_back(value); + } catch (const std::invalid_argument& e) { + std::cerr << "Invalid argument: " << item << " is not an integer." + << std::endl; + return; + } catch (const std::out_of_range& e) { + std::cerr << "Out of range: " << item + << " is out of the range of representable values." + << std::endl; + return; } - W4A8MoeGEMMConfig search_config; - search_config.total_rows = vec_configs[0]; - search_config.n = vec_configs[1]; - search_config.k = vec_configs[2]; - search_config.num_experts = vec_configs[3]; - search_config.tile_config = static_cast(vec_configs[4]); - search_config.split_k_style = static_cast(vec_configs[5]); - search_config.split_k_factor = vec_configs[6]; - search_config.stages = vec_configs[7]; - tuned_configs_from_file.push_back(search_config); - VLOG(1)<<"tuned_configs_from_file: "<(search_config.tile_config)<<"," << static_cast(search_config.split_k_style)<<","<(vec_configs[4]); + search_config.split_k_style = + static_cast(vec_configs[5]); + search_config.split_k_factor = vec_configs[6]; + search_config.stages = vec_configs[7]; + tuned_configs_from_file.push_back(search_config); + VLOG(1) << "tuned_configs_from_file: " << search_config.total_rows + << "," << search_config.n << "," << search_config.k << "," + << search_config.num_experts << "," + << static_cast(search_config.tile_config) << "," + << static_cast(search_config.split_k_style) << "," + << search_config.split_k_factor << "," + << search_config.stages; 
} + } else { + VLOG(1) << "No tuned w4a8 gemm config."; + } } else { - FILE * fp; + FILE* fp; fp = fopen(config_file_path.c_str(), "r"); - if(fp) { - VLOG(1)<<"Get tuned w4a8 moe gemm config from: "<(tile_config); - search_config.split_k_style = static_cast(split_k_style); - search_config.split_k_factor = split_k_factor; - search_config.stages = stages; - tuned_configs_from_file.push_back(search_config); - VLOG(1)<<"tuned_configs_from_file: "<(tile_config); + search_config.split_k_style = static_cast(split_k_style); + search_config.split_k_factor = split_k_factor; + search_config.stages = stages; + tuned_configs_from_file.push_back(search_config); + VLOG(1) << "tuned_configs_from_file: " << total_rows_tmp << "," + << n_tmp << "," << k_tmp << "," << num_experts_tmp << "," + << tile_config << "," << split_k_style << "," + << split_k_factor << "," << stages; + if (feof(fp)) break; } - } else if(FLAGS_cutlass_w4a8_moe_best_config=="") { - VLOG(1)<<"No tuned w4a8 gemm config."; + } else if (FLAGS_cutlass_w4a8_moe_best_config == "") { + VLOG(1) << "No tuned w4a8 gemm config."; } } } } template -W4A8MoeGemmRunner::~W4A8MoeGemmRunner() { -} - - +W4A8MoeGemmRunner::~W4A8MoeGemmRunner() {} template template -void W4A8MoeGemmRunner::dispatch_to_arch( - const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream, - int* occupancy) { - +void W4A8MoeGemmRunner::dispatch_to_arch< + EpilogueTag>(const IntAType* A, + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* 
C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy) { // only sm80 here dispatch_moe_gemm_to_cutlass(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - workspace_ptr, - workspace_bytes, - multi_processor_count_, - stream, - occupancy); - - + IntAType, + IntBType, + cutlass::arch::Sm80, + EpilogueTag>(A, + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + workspace_ptr, + workspace_bytes, + multi_processor_count_, + stream, + occupancy); } template @@ -625,8 +671,8 @@ void W4A8MoeGemmRunner::run_gemm( int num_experts, cudaStream_t stream, CutlassGemmConfig gemm_config) { - VLOG(1)<<__PRETTY_FUNCTION__; - static constexpr bool is_weight_only = true; //todo(yuanxiaolan) + VLOG(1) << __PRETTY_FUNCTION__; + static constexpr bool is_weight_only = true; // todo(yuanxiaolan) bool is_weight_only_encoder = total_rows >= 512 ? 
true : false; VLOG(1) << "gemm_config tile_config" @@ -636,29 +682,29 @@ void W4A8MoeGemmRunner::run_gemm( VLOG(1) << "gemm_config split_k_factor " << gemm_config.split_k_factor; VLOG(1) << "gemm_config stages " << gemm_config.stages; - if(gemm_config.tile_config != CutlassTileConfig::Undefined) { + if (gemm_config.tile_config != CutlassTileConfig::Undefined) { dispatch_to_arch(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows_in_ll_else_minus1, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - workspace_ptr, - workspace_bytes, - stream); + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows_in_ll_else_minus1, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + workspace_ptr, + workspace_bytes, + stream); return; } - std::vector candidate_configs = - get_candidate_configs_nf4(80, is_weight_only, is_weight_only_encoder, false); + std::vector candidate_configs = get_candidate_configs_nf4( + 80, is_weight_only, is_weight_only_encoder, false); std::vector occupancies(candidate_configs.size()); for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { @@ -686,20 +732,21 @@ void W4A8MoeGemmRunner::run_gemm( int local_multi_processor_count{0}; check_cuda_error(cudaGetDevice(&local_device)); // sm_ = getSMVersion(); - check_cuda_error(cudaDeviceGetAttribute( - &local_multi_processor_count, cudaDevAttrMultiProcessorCount, local_device)); + check_cuda_error(cudaDeviceGetAttribute(&local_multi_processor_count, + cudaDevAttrMultiProcessorCount, + local_device)); CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies_w4a4(candidate_configs, - occupancies, - total_rows, - gemm_n, - gemm_k, - num_experts, - split_k_limit, - workspace_bytes, - local_multi_processor_count, - is_weight_only); + occupancies, + total_rows, + gemm_n, + gemm_k, + num_experts, + split_k_limit, + workspace_bytes, + 
local_multi_processor_count, + is_weight_only); VLOG(1) << "chosen_config tile_config " << static_cast(chosen_config.tile_config); @@ -711,7 +758,6 @@ void W4A8MoeGemmRunner::run_gemm( VLOG(1) << "total_rows " << total_rows << "gemm_n " << gemm_n << "gemm_k " << gemm_k; - dispatch_to_arch(A, B, quant_mode, @@ -732,7 +778,8 @@ void W4A8MoeGemmRunner::run_gemm( } // template -// void W4A8MoeGemmRunner::moe_gemm_bias_act( const IntAType* A, +// void W4A8MoeGemmRunner::moe_gemm_bias_act( +// const IntAType* A, // const IntBType* B, // QuantMode quant_mode, // const OutputType* col_scale, @@ -769,65 +816,82 @@ void W4A8MoeGemmRunner::run_gemm( template void W4A8MoeGemmRunner::moe_gemm( - const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - char* workspace_ptr, - const size_t workspace_bytes, - int num_experts, - cudaStream_t stream, - CutlassGemmConfig gemm_config) { + const IntAType* A, + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + char* workspace_ptr, + const size_t workspace_bytes, + int num_experts, + cudaStream_t stream, + CutlassGemmConfig gemm_config) { CutlassGemmConfig gemm_config_from_file_and_param = gemm_config; - if(!tuned_configs_from_file.empty()){ - bool match=false; + if (!tuned_configs_from_file.empty()) { + bool match = false; int best_total_rows, best_n, best_k, best_num_experts; - int max_config_total_rows_in_file=0; + int max_config_total_rows_in_file = 0; W4A8MoeGEMMConfig max_total_rows_config; - 
for(const auto& tuned_config:tuned_configs_from_file) { - // choose the smallest config_m with config_m >=m - if(tuned_config.total_rows <= total_rows && tuned_config.n==gemm_n && tuned_config.k==gemm_k && tuned_config.num_experts==num_experts) { - best_total_rows=tuned_config.total_rows; - best_n=tuned_config.n; - best_k=tuned_config.k; - best_num_experts=tuned_config.num_experts; - gemm_config_from_file_and_param.tile_config = tuned_config.tile_config; - gemm_config_from_file_and_param.split_k_style = tuned_config.split_k_style; - gemm_config_from_file_and_param.split_k_factor = tuned_config.split_k_factor; - gemm_config_from_file_and_param.stages = tuned_config.stages; - match=true; - } - if(tuned_config.total_rows > max_config_total_rows_in_file && tuned_config.n==gemm_n && tuned_config.k==gemm_k && tuned_config.num_experts==num_experts){ - max_config_total_rows_in_file = tuned_config.total_rows; - max_total_rows_config = tuned_config; - } + for (const auto& tuned_config : tuned_configs_from_file) { + // choose the smallest config_m with config_m >=m + if (tuned_config.total_rows <= total_rows && tuned_config.n == gemm_n && + tuned_config.k == gemm_k && tuned_config.num_experts == num_experts) { + best_total_rows = tuned_config.total_rows; + best_n = tuned_config.n; + best_k = tuned_config.k; + best_num_experts = tuned_config.num_experts; + gemm_config_from_file_and_param.tile_config = tuned_config.tile_config; + gemm_config_from_file_and_param.split_k_style = + tuned_config.split_k_style; + gemm_config_from_file_and_param.split_k_factor = + tuned_config.split_k_factor; + gemm_config_from_file_and_param.stages = tuned_config.stages; + match = true; + } + if (tuned_config.total_rows > max_config_total_rows_in_file && + tuned_config.n == gemm_n && tuned_config.k == gemm_k && + tuned_config.num_experts == num_experts) { + max_config_total_rows_in_file = tuned_config.total_rows; + max_total_rows_config = tuned_config; + } } - if(!match){ - if 
(max_total_rows_config.n==gemm_n && max_total_rows_config.k==gemm_k && max_total_rows_config.num_experts==num_experts) { + if (!match) { + if (max_total_rows_config.n == gemm_n && + max_total_rows_config.k == gemm_k && + max_total_rows_config.num_experts == num_experts) { best_total_rows = max_config_total_rows_in_file; - gemm_config_from_file_and_param.tile_config = max_total_rows_config.tile_config; - gemm_config_from_file_and_param.split_k_style = max_total_rows_config.split_k_style; - gemm_config_from_file_and_param.split_k_factor = max_total_rows_config.split_k_factor; + gemm_config_from_file_and_param.tile_config = + max_total_rows_config.tile_config; + gemm_config_from_file_and_param.split_k_style = + max_total_rows_config.split_k_style; + gemm_config_from_file_and_param.split_k_factor = + max_total_rows_config.split_k_factor; gemm_config_from_file_and_param.stages = max_total_rows_config.stages; } } - VLOG(1) <<"W4A8 moe gemm " - <<"total_rows: "<(gemm_config_from_file_and_param.tile_config) - <<"split_k_style: "<(gemm_config_from_file_and_param.split_k_style) - <<"split_k_factor: "<(gemm_config_from_file_and_param.split_k_factor) - <<"stages: "<(gemm_config_from_file_and_param.stages); + VLOG(1) << "W4A8 moe gemm " + << "total_rows: " << total_rows << " n: " << gemm_n + << " k: " << gemm_k + << "Using gemm config from config file: config_total_rows: " + << best_total_rows << " config_n: " << best_n + << " config_k: " << best_k << "tile_config: " + << static_cast(gemm_config_from_file_and_param.tile_config) + << "split_k_style: " + << static_cast(gemm_config_from_file_and_param.split_k_style) + << "split_k_factor: " + << static_cast(gemm_config_from_file_and_param.split_k_factor) + << "stages: " + << static_cast(gemm_config_from_file_and_param.stages); } else { - VLOG(1) << "tuned_configs_from_file is empty, use W4A8 gemm config in param"; + VLOG(1) + << "tuned_configs_from_file is empty, use W4A8 gemm config in param"; } run_gemm(A, B, @@ -849,7 +913,10 @@ 
void W4A8MoeGemmRunner::moe_gemm( } template -std::vector::W4A8MoeGEMMConfig> W4A8MoeGemmRunner::tuned_configs_from_file = {}; +std::vector:: + W4A8MoeGEMMConfig> + W4A8MoeGemmRunner::tuned_configs_from_file = + {}; template int W4A8MoeGemmRunner::getWorkspaceSize( @@ -863,6 +930,5 @@ int W4A8MoeGemmRunner::getWorkspaceSize( return max_grid_m * max_grid_n * split_k_limit * 4; } - template class W4A8MoeGemmRunner; template class W4A8MoeGemmRunner<__nv_bfloat16, int8_t, cutlass::uint4b_t>; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh index eb3be5fa566..f26aff8b8ce 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_config_search.sh @@ -21,12 +21,12 @@ rm -rf up_gate_proj_7168_8192.log rm -rf down_proj_8192_3584.log num_experts=8 -for tokens_per_expert in 12 +for tokens_per_expert in 1 2 4 8 16 20 24 28 32 36 48 64 96 128 160 192 224 256 384 512 768 1024 2048 3072 4096 8192 do wait -CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${up_gate_proj_n} ${up_gate_proj_k} ${tokens_per_expert} 1 0 >> up_gate_proj_${up_gate_proj_n}_${up_gate_proj_k}.log 2>&1 & -# CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${down_proj_n} ${down_proj_k} ${tokens_per_expert} 1 0 >> down_proj_${down_proj_n}_${down_proj_k}.log 2>&1 & +CUDA_VISIBLE_DEVICES=2 ./w4a8_moe_gemm_test ${num_experts} ${ffn1_n} ${ffn1_k} ${tokens_per_expert} 0 1 >> ffn1_${ffn1_n}_${ffn1_k}.log 2>&1 & +CUDA_VISIBLE_DEVICES=3 ./w4a8_moe_gemm_test ${num_experts} ${ffn2_n} ${ffn2_k} ${tokens_per_expert} 0 1 >> ffn2_${ffn2_n}_${ffn2_k}.log 2>&1 & done wait echo "#### finish ####" diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h index 7d06b59d653..6ec9121634c 100644 --- 
a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h @@ -30,45 +30,46 @@ class W4A8MoeGemmRunner { ~W4A8MoeGemmRunner(); void moe_gemm_bias_act(const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const OutputType* biases, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int m, - int n, - int k, - int num_experts, - std::string activation_type, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream); + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const OutputType* biases, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int m, + int n, + int k, + int num_experts, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); void moe_gemm(const IntAType* A, - const IntBType* B, - cutlass::epilogue::QuantMode quant_mode, - const OutputType* col_scale, - const OutputType* row_scale, - const int32_t* nf4_look_up_table, - OutputType* C, - int64_t* total_rows_before_expert, - int64_t total_rows_in_ll_else_minus1, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - char* workspace_ptr, - const size_t workspace_bytes, - int num_experts, - cudaStream_t stream, - CutlassGemmConfig gemm_config = CutlassGemmConfig{CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - SplitKStyle::NO_SPLIT_K, - 1, - 5}); - private: + const IntBType* B, + cutlass::epilogue::QuantMode quant_mode, + const OutputType* col_scale, + const OutputType* row_scale, + const int32_t* nf4_look_up_table, + OutputType* C, + int64_t* total_rows_before_expert, + int64_t total_rows_in_ll_else_minus1, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + char* workspace_ptr, + const 
size_t workspace_bytes, + int num_experts, + cudaStream_t stream, + CutlassGemmConfig gemm_config = CutlassGemmConfig{ + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + SplitKStyle::NO_SPLIT_K, + 1, + 5}); + private: template void dispatch_to_arch(const IntAType* A, const IntBType* B, @@ -108,8 +109,7 @@ class W4A8MoeGemmRunner { cudaStream_t stream, CutlassGemmConfig gemm_config); - int getWorkspaceSize( - const int m, const int n, const int k); + int getWorkspaceSize(const int m, const int n, const int k); private: static constexpr int split_k_limit = 4; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h index 80ffb06de43..3ab51307d61 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel_template.h @@ -29,7 +29,8 @@ template -void generic_w4a8_moe_gemm_kernelLauncher(const IntAType* A, +void generic_w4a8_moe_gemm_kernelLauncher( + const IntAType* A, const IntBType* B, cutlass::epilogue::QuantMode quant_mode, const OutputType* col_scale, @@ -48,7 +49,6 @@ void generic_w4a8_moe_gemm_kernelLauncher(const IntAType* A, cudaStream_t stream, int* occupancy); - template struct dispatch_stages { + IntAType, + IntBType, + arch, + EpilogueTag, + ThreadblockShape, + WarpShape, + 2> { static void dispatch(const IntAType* A, const IntBType* B, cutlass::epilogue::QuantMode quant_mode, @@ -119,31 +118,31 @@ struct dispatch_stages(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows, - total_rows_in_ll_else_minus1, - n, - k, - num_experts, - gemm_config, - workspace, - workspace_bytes, - multi_processor_count, - stream, - occupancy); + IntAType, + IntBType, + arch, + EpilogueTag, + ThreadblockShape, + WarpShape, + 2>(A, + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + 
total_rows_before_expert, + total_rows, + total_rows_in_ll_else_minus1, + n, + k, + num_experts, + gemm_config, + workspace, + workspace_bytes, + multi_processor_count, + stream, + occupancy); } }; @@ -155,12 +154,12 @@ template struct dispatch_stages 2)>::type> { static void dispatch(const IntAType* A, @@ -183,35 +182,34 @@ struct dispatch_stages(A, - B, - quant_mode, - col_scale, - row_scale, - nf4_look_up_table, - C, - total_rows_before_expert, - total_rows, - total_rows_in_ll_else_minus1, - n, - k, - num_experts, - gemm_config, - workspace, - workspace_bytes, - multi_processor_count, - stream, - occupancy); + IntAType, + IntBType, + cutlass::arch::Sm80, + EpilogueTag, + ThreadblockShape, + WarpShape, + Stages>(A, + B, + quant_mode, + col_scale, + row_scale, + nf4_look_up_table, + C, + total_rows_before_expert, + total_rows, + total_rows_in_ll_else_minus1, + n, + k, + num_experts, + gemm_config, + workspace, + workspace_bytes, + multi_processor_count, + stream, + occupancy); } }; - template -static void PrintMatrix(const T *mat_d, int num, std::string name, +static void PrintMatrix(const T *mat_d, + int num, + std::string name, int numOfCols) { std::vector tmp(num); cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); @@ -104,17 +111,17 @@ static void PrintMatrix(const T *mat_d, int num, std::string name, uint as_uint(const float x) { return *(uint *)&x; } uint16_t ConvertFloat2Half(const float x) { - const uint b = as_uint(x) + 0x00001000; // round-to-nearest-even: add last - // bit after truncated mantissa - const uint e = (b & 0x7F800000) >> 23; // exponent - const uint m = b & 0x007FFFFF; // mantissa; in line below: 0x007FF000 = - // 0x00800000-0x00001000 = decimal indicator - // flag - initial rounding + const uint b = as_uint(x) + 0x00001000; // round-to-nearest-even: add last + // bit after truncated mantissa + const uint e = (b & 0x7F800000) >> 23; // exponent + const uint m = b & 0x007FFFFF; // mantissa; in line below: 0x007FF000 = + 
// 0x00800000-0x00001000 = decimal indicator + // flag - initial rounding return (b & 0x80000000) >> 16 | (e > 112) * ((((e - 112) << 10) & 0x7C00) | m >> 13) | ((e < 113) & (e > 101)) * ((((0x007FF000 + m) >> (125 - e)) + 1) >> 1) | - (e > 143) * 0x7FFF; // sign : normalized : denormalized : saturate + (e > 143) * 0x7FFF; // sign : normalized : denormalized : saturate } inline float fp32_from_bits(uint32_t w) { @@ -274,7 +281,9 @@ float CPUHalfConvert2Float(const uint16_t h) { return fp32_from_bits(result); } -static void PrintHalfMatrix(const int16_t *mat_d, int num, std::string name, +static void PrintHalfMatrix(const int16_t *mat_d, + int num, + std::string name, int numOfCols) { std::vector tmp(num); cudaMemcpy(tmp.data(), mat_d, sizeof(int16_t) * num, cudaMemcpyDeviceToHost); @@ -296,7 +305,9 @@ static void PrintHalfMatrix(const int16_t *mat_d, int num, std::string name, } template -static void PrintMatrixCPU(const T *mat, int num, std::string name, +static void PrintMatrixCPU(const T *mat, + int num, + std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -315,7 +326,9 @@ static void PrintMatrixCPU(const T *mat, int num, std::string name, outfile.close(); } -static void PrintMatrixCPU_int4(const int8_t *mat, int num, std::string name, +static void PrintMatrixCPU_int4(const int8_t *mat, + int num, + std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -333,7 +346,9 @@ static void PrintMatrixCPU_int4(const int8_t *mat, int num, std::string name, outfile.close(); } template -static void PrintHalfMatrixCPU(const T *mat, int num, std::string name, +static void PrintHalfMatrixCPU(const T *mat, + int num, + std::string name, int numOfCols) { std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); @@ -349,8 +364,8 @@ static void PrintHalfMatrixCPU(const T *mat, int num, std::string name, } template -void naive_matmul(const T *a, const T *b, outputT *c, 
size_t m, size_t n, - size_t k) { +void naive_matmul( + const T *a, const T *b, outputT *c, size_t m, size_t n, size_t k) { for (int ik = 0; ik < k; ik++) { for (int im = 0; im < m; im++) { for (int in = 0; in < n; in++) { @@ -361,13 +376,17 @@ void naive_matmul(const T *a, const T *b, outputT *c, size_t m, size_t n, } template -void naive_matmul_fused_dequantize_nf4(const T *a, const T *b, +void naive_matmul_fused_dequantize_nf4(const T *a, + const T *b, const ScaleType *col_scale, const ScaleType *row_scale, const int32_t *nf4_look_up_table, - outputT *c, size_t num_experts, + outputT *c, + size_t num_experts, int64_t *total_rows_before_experts, - size_t total_rows, size_t n, size_t k) { + size_t total_rows, + size_t n, + size_t k) { // PrintMatrixCPU( // a, total_rows * k, "naive_matmul_a", k); // PrintMatrixCPU( @@ -442,10 +461,15 @@ void naive_matmul_fused_dequantize_nf4(const T *a, const T *b, } // Author (zhengzekang): we use float to monitor half matmul in CPU. -void CheckHalfDiff(int16_t *device_res, float *host_result, size_t elem_cnt, - float atol, float rtol) { +void CheckHalfDiff(int16_t *device_res, + float *host_result, + size_t elem_cnt, + float atol, + float rtol) { std::vector device_data(elem_cnt); - cudaMemcpy(device_data.data(), device_res, sizeof(int16_t) * elem_cnt, + cudaMemcpy(device_data.data(), + device_res, + sizeof(int16_t) * elem_cnt, cudaMemcpyDeviceToHost); for (size_t i = 0; i < elem_cnt; i++) { @@ -459,7 +483,10 @@ void CheckHalfDiff(int16_t *device_res, float *host_result, size_t elem_cnt, printf( "Here in Idx: %d, CUDA result is: %f, Host result is: %f, absolute " "diff val is: %f \n", - i, device_res_val, host_res_val, absolute_diff); + i, + device_res_val, + host_res_val, + absolute_diff); return; } } @@ -508,11 +535,11 @@ CutlassGemmConfig GetGemmConfig(int token_nums, // gemm_config_tuple:[m,n,k,tile_config,split_k_style,split_k_factor,stages] for (int i = 0; i < len_of_gemm_config_tuple; i += 7) { gemm_config.tile_config = - 
CutlassTileConfig(gemm_config_tuple[i + 3]); // tile_config + CutlassTileConfig(gemm_config_tuple[i + 3]); // tile_config gemm_config.split_k_style = - SplitKStyle(gemm_config_tuple[i + 4]); // split_k_style - gemm_config.split_k_factor = gemm_config_tuple[i + 5]; // split_k_factor - gemm_config.stages = gemm_config_tuple[i + 6]; // stages + SplitKStyle(gemm_config_tuple[i + 4]); // split_k_style + gemm_config.split_k_factor = gemm_config_tuple[i + 5]; // split_k_factor + gemm_config.stages = gemm_config_tuple[i + 6]; // stages // make sure we have at least one tuned config if (token_nums <= gemm_config_tuple[i + 0]) { break; @@ -522,7 +549,8 @@ CutlassGemmConfig GetGemmConfig(int token_nums, } template -void get_tensor_from_file(const std::string file_path, int64_t numel, +void get_tensor_from_file(const std::string file_path, + int64_t numel, T *tensor_ptr) { std::fstream datafile; datafile.open(file_path, std::ios_base::in | std::ios_base::out); @@ -641,7 +669,7 @@ int main(int argc, char *argv[]) { auto mixed_gemm_runner = W4A8MoeGemmRunner(); // int mixgemm_max_size = std::max(m, k); - int mixgemm_workspace_size_bytes = 1 * 1024 * 1024 * 1024; // 1G workspace + int mixgemm_workspace_size_bytes = 1 * 1024 * 1024 * 1024; // 1G workspace std::cout << "mixgemm_workspace_size_bytes: " << mixgemm_workspace_size_bytes << std::endl; char *mixgemm_workspace_data; @@ -663,8 +691,8 @@ int main(int argc, char *argv[]) { } } else { std::cout << "get a data from: " << a_data_file << std::endl; - get_tensor_from_file(a_data_file, total_rows * k, - a_int.data()); + get_tensor_from_file( + a_data_file, total_rows * k, a_int.data()); } // PrintMatrixCPU(a_int.data(),total_rows*k,"a_int8_cpu",n); } @@ -752,7 +780,8 @@ int main(int argc, char *argv[]) { // PrintMatrixCPU_int4(packed_b_int.data(),num_experts*k*n,"w4a8_packed_b_int4",n); permute_B_rows_for_mixed_gemm_int4<4>( b_int_processed.data() + ie * k * n / 2, - packed_b_int.data() + ie * k * n / 2, std::vector{k, n}, + 
packed_b_int.data() + ie * k * n / 2, + std::vector{k, n}, (int64_t)80); // PrintMatrixCPU_int4(b_int_processed.data(),num_experts*k*n,"w4a8_permuted_int4",n); @@ -820,8 +849,8 @@ int main(int argc, char *argv[]) { // } } } else { - get_tensor_from_file(row_scale_data_file, total_rows, - row_scale_float.data()); + get_tensor_from_file( + row_scale_data_file, total_rows, row_scale_float.data()); } // PrintMatrixCPU(row_scale_float.data(),total_rows,"row_scale_float_cpu",total_rows); } @@ -839,8 +868,8 @@ int main(int argc, char *argv[]) { // } } } else { - get_tensor_from_file(col_scale_data_file, num_experts * n, - col_scale_float.data()); + get_tensor_from_file( + col_scale_data_file, num_experts * n, col_scale_float.data()); } // PrintMatrixCPU(col_scale_float.data(),num_experts*n,"col_scale_float_cpu",n); } @@ -878,21 +907,35 @@ int main(int argc, char *argv[]) { cudaMalloc(&d_col_scale_half, num_experts * n * sizeof(uint16_t)); cudaMalloc(&d_nf4_look_up_table, 4 * sizeof(uint32_t)); cudaMalloc(&d_total_rows_before_experts, num_experts * sizeof(int64_t)); - cudaMemcpy(d_a_int, a_int.data(), total_rows * k * sizeof(int8_t), + cudaMemcpy(d_a_int, + a_int.data(), + total_rows * k * sizeof(int8_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_b_int, b_int_processed_3.data(), - num_experts * k * n / 2 * sizeof(int8_t), cudaMemcpyHostToDevice); - - cudaMemcpy(d_row_scale_half, row_scale_half.data(), - total_rows * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_col_scale_half, col_scale_half.data(), - num_experts * n * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_nf4_look_up_table, nf4_look_up_table_compress.data(), - 4 * sizeof(uint32_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_c_int, c_half.data(), total_rows * n * sizeof(uint16_t), + cudaMemcpy(d_b_int, + b_int_processed_3.data(), + num_experts * k * n / 2 * sizeof(int8_t), + cudaMemcpyHostToDevice); + + cudaMemcpy(d_row_scale_half, + row_scale_half.data(), + total_rows * sizeof(uint16_t), + 
cudaMemcpyHostToDevice); + cudaMemcpy(d_col_scale_half, + col_scale_half.data(), + num_experts * n * sizeof(uint16_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_nf4_look_up_table, + nf4_look_up_table_compress.data(), + 4 * sizeof(uint32_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_c_int, + c_half.data(), + total_rows * n * sizeof(uint16_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_total_rows_before_experts, + total_rows_before_experts.data(), + num_experts * sizeof(int64_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_total_rows_before_experts, total_rows_before_experts.data(), - num_experts * sizeof(int64_t), cudaMemcpyHostToDevice); cudaDeviceSynchronize(); cudaError_t result = cudaGetLastError(); @@ -908,7 +951,9 @@ int main(int argc, char *argv[]) { std::cout << "=== do warm up for " << kWarmTime << " times" << std::endl; auto test_config = CutlassGemmConfig{CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, - SplitKStyle::NO_SPLIT_K, 1, 5}; + SplitKStyle::NO_SPLIT_K, + 1, + 5}; std::cout << "=== do warm up end" << std::endl; for (int i = 0; i < kWarmTime; i++) { printf("warm up %d\n", i); @@ -920,9 +965,16 @@ int main(int argc, char *argv[]) { reinterpret_cast(d_row_scale_half), reinterpret_cast(d_nf4_look_up_table), reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), -1, - total_rows, n, k, mixgemm_workspace_data, mixgemm_workspace_size_bytes, - num_experts, 0, test_config); + reinterpret_cast(d_total_rows_before_experts), + -1, + total_rows, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + num_experts, + 0, + test_config); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { std::cout << "error: " << cudaGetErrorString(err) << std::endl; @@ -996,7 +1048,6 @@ int main(int argc, char *argv[]) { CutlassTileConfig::CtaShape64x256x64_WarpShape64x64x64, CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64, CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, - 
CutlassTileConfig::CtaShape32x512x64_WarpShape32x128x64, }; std::vector all_split_k_style{SplitKStyle::NO_SPLIT_K}; @@ -1013,7 +1064,6 @@ int main(int argc, char *argv[]) { cudaEventCreate(&end); cudaEventRecord(begin, 0); for (int i = 0; i < kTestTime; ++i) { - mixed_gemm_runner.moe_gemm( reinterpret_cast(d_a_int), reinterpret_cast((void *)d_b_int), @@ -1022,9 +1072,15 @@ int main(int argc, char *argv[]) { reinterpret_cast(d_row_scale_half), reinterpret_cast(d_nf4_look_up_table), reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), -1, - total_rows, n, k, mixgemm_workspace_data, - mixgemm_workspace_size_bytes, num_experts, 0, + reinterpret_cast(d_total_rows_before_experts), + -1, + total_rows, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + num_experts, + 0, test_gemm_config); } cudaEventRecord(end, 0); @@ -1150,37 +1206,58 @@ int main(int argc, char *argv[]) { if (do_check) { std::cout << "=== do accuracy check " << std::endl; cudaMemset(d_c_int, 0, total_rows * n * sizeof(uint16_t)); - PrintHalfMatrix(static_cast(d_c_int), total_rows * n, - "CUDA_c_dequantize_fp16_output_before_gemm", n); + PrintHalfMatrix(static_cast(d_c_int), + total_rows * n, + "CUDA_c_dequantize_fp16_output_before_gemm", + n); mixed_gemm_runner.moe_gemm( reinterpret_cast(d_a_int), reinterpret_cast((void *)d_b_int), cutlass::epilogue::QuantMode::PerChannelQuant, reinterpret_cast(d_col_scale_half), - nullptr, // reinterpret_cast(d_row_scale_half), - nullptr, // reinterpret_cast(d_nf4_look_up_table), + nullptr, // reinterpret_cast(d_row_scale_half), + nullptr, // reinterpret_cast(d_nf4_look_up_table), reinterpret_cast(d_c_int), - reinterpret_cast(d_total_rows_before_experts), -1, - total_rows, n, k, mixgemm_workspace_data, mixgemm_workspace_size_bytes, - num_experts, 0); + reinterpret_cast(d_total_rows_before_experts), + -1, + total_rows, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + num_experts, + 0); cudaDeviceSynchronize(); 
// PrintMatrix(reinterpret_cast(d_nf4_look_up_table),4,"d_nf4_look_up_table",1); printf("##### d_nf4_look_up_table address: %p \n", d_nf4_look_up_table); naive_matmul_fused_dequantize_nf4( - a_int.data(), b_int.data(), col_scale_float.data(), - nullptr, // row_scale_float.data(), - nullptr, // nf4_look_up_table.data(), - c_float.data(), num_experts, total_rows_before_experts.data(), - total_rows, n, k); - PrintMatrixCPU(c_float.data(), total_rows * n, - "CPU_c_fake_fp16_dequantize_output_base", n); - PrintHalfMatrix(static_cast(d_c_int), total_rows * n, - "CUDA_c_dequantize_fp16_output", n); - CheckHalfDiff(static_cast(d_c_int), c_float.data(), - total_rows * n, 1e-4, 1e-2); + a_int.data(), + b_int.data(), + col_scale_float.data(), + nullptr, // row_scale_float.data(), + nullptr, // nf4_look_up_table.data(), + c_float.data(), + num_experts, + total_rows_before_experts.data(), + total_rows, + n, + k); + PrintMatrixCPU(c_float.data(), + total_rows * n, + "CPU_c_fake_fp16_dequantize_output_base", + n); + PrintHalfMatrix(static_cast(d_c_int), + total_rows * n, + "CUDA_c_dequantize_fp16_output", + n); + CheckHalfDiff(static_cast(d_c_int), + c_float.data(), + total_rows * n, + 1e-4, + 1e-2); } // if(kTestTime > 0){ diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h index 648a21a3535..804dcfdf919 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/w4a8_moe_gemm_with_epilogue_visitor.h @@ -33,10 +33,9 @@ namespace gemm { namespace kernel { template + GroupScheduleMode GroupScheduleMode_> struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { public: using Mma = Mma_; @@ -98,10 +97,10 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { static bool const kTransposed = false; using ProblemVisitor = GemmMoeProblemVisitor; + kGroupScheduleMode, + 
kThreadCount, + kThreadCount, + kTransposed>; // // Structures @@ -206,7 +205,6 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { ElementC, LayoutA, LayoutB> { - using ParamsBase = UniversalParamsBase::value) { - isCMisaligned = problem_size.n() % kAlignmentC; + isCMisaligned = problem_size.n() % kAlignmentC; } else if (platform::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; } else if (platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - - static constexpr int kInterleave = - Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert( - platform::is_same::value && - kInterleave == 1 || - platform::is_same::value && - kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // - // Problem visitor. - // - ProblemVisitor problem_visitor( - params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - const int64_t gemm_k = params.problem_visitor.gemm_k; - const int64_t gemm_n = params.problem_visitor.gemm_n; - int64_t bytes_per_expert_matrix = - (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) { + // + // Problem visitor. 
+ // + ProblemVisitor problem_visitor( + params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = + (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { // // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - cutlass::MatrixCoord threadblock_offset( - int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT - int(cta_idx % grid_shape.n()) * Mma::Shape::kN // NOLINT - ); + cutlass::MatrixCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, // NOLINT + int(cta_idx % grid_shape.n()) * Mma::Shape::kN // NOLINT + ); - // if (threadIdx.x == 0) { - // printf("%d-%d-%d problem_size: %d, %d problem_idx: %d, cta_idx: %d\n", blockIdx.x,blockIdx.y,blockIdx.z, problem_size.m(), problem_size.n(), problem_idx, cta_idx); - // } - - // Load element pointers. Exchange pointers and strides if working on - // the transpose - int64_t rows_to_jump = 0; - if (params.problem_visitor.total_rows < 0) { - rows_to_jump = problem_idx == 0 ? 
0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - } else { - rows_to_jump = problem_idx * (params.problem_visitor.total_rows / params.problem_visitor.problem_count); - } + // if (threadIdx.x == 0) { + // printf("%d-%d-%d problem_size: %d, %d problem_idx: %d, cta_idx: + // %d\n", blockIdx.x,blockIdx.y,blockIdx.z, problem_size.m(), + // problem_size.n(), problem_idx, cta_idx); + // } - ElementA* ptr_A = - reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; + // Load element pointers. Exchange pointers and strides if working on + // the transpose + int64_t rows_to_jump = 0; + if (params.problem_visitor.total_rows < 0) { + rows_to_jump = + problem_idx == 0 + ? 0 + : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + } else { + rows_to_jump = problem_idx * (params.problem_visitor.total_rows / + params.problem_visitor.problem_count); + } - char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT - problem_idx * bytes_per_expert_matrix; // NOLINT - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value - ? gemm_n - : gemm_k * kInterleave; + ElementA* ptr_A = + reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT + problem_idx * bytes_per_expert_matrix; // NOLINT + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value + ? gemm_n + : gemm_k * kInterleave; int offset_k = 0; int problem_size_k = params.problem_size.k(); - - // Maybe need to modify? Author zhengzekang. - #if SPLIT_K_ENABLED +// Maybe need to modify? Author zhengzekang. +#if SPLIT_K_ENABLED // // Fetch pointers based on mode. 
// if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + problem_size_k = + (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; } else if (params.mode == GemmUniversalMode::kBatched) { @@ -475,17 +480,19 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { ptr_B = static_cast( params.ptr_B)[threadblock_tile_offset.k()]; } - #endif - // if(threadIdx.x==0){ - // printf("##### block: %d-%d-%d, offset_k:%d, threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d \n", - // blockIdx.x, blockIdx.y, blockIdx.z, - // offset_k, - // threadblock_tile_offset.m(), - // threadblock_tile_offset.n(), - // threadblock_tile_offset.k(), - // params.gemm_k_size - // ); - // } +#endif + // if(threadIdx.x==0){ + // printf("##### block: %d-%d-%d, offset_k:%d, + // threadblock_tile_offset.m-n-k():%d-%d-%d, params.gemm_k_size:%d + // \n", + // blockIdx.x, blockIdx.y, blockIdx.z, + // offset_k, + // threadblock_tile_offset.m(), + // threadblock_tile_offset.n(), + // threadblock_tile_offset.k(), + // params.gemm_k_size + // ); + // } // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ @@ -493,21 +500,18 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { 0, }; - cutlass::MatrixCoord tb_offset_B{ - 0, - threadblock_offset.column() / kInterleave}; + 0, threadblock_offset.column() / kInterleave}; // Compute position within threadblock int thread_idx = threadIdx.x; // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A); + typename Mma::IteratorA iterator_A(params.params_A, + ptr_A, + {problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); typename 
Mma::IteratorB iterator_B( params.params_B, @@ -516,13 +520,11 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { thread_idx, tb_offset_B); typename Mma::IteratorNF4LookUpTable iterator_nf4_look_up_table = - Mma::IteratorNF4LookUpTable( - params.params_nf4_look_up_table, - params.ref_nf4_look_up_table.data(), - {0,16}, - threadIdx.x, - {0,0} - ); + Mma::IteratorNF4LookUpTable(params.params_nf4_look_up_table, + params.ref_nf4_look_up_table.data(), + {0, 16}, + threadIdx.x, + {0, 0}); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -542,9 +544,14 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = - (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + (problem_size_k + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_nf4_look_up_table, accumulators); + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_nf4_look_up_table, + accumulators); // if(threadIdx.x==0){ // printf("##### block: %d-%d-%d, offset-m-n-k:%d-%d-%d \n", // blockIdx.x, blockIdx.y, blockIdx.z, @@ -560,44 +567,54 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - ElementC* ptr_C = - reinterpret_cast(params.ptr_C) + rows_to_jump * gemm_n; - ElementC* ptr_D = - reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - - using Element_scale = typename EpilogueVisitor::ScaleTileIterator::Element; - Element_scale* ptr_alpha_row = params.ptr_alpha_row == nullptr ? 
params.ptr_alpha_row : reinterpret_cast(params.ptr_alpha_row) + rows_to_jump; - Element_scale* ptr_alpha_col = reinterpret_cast(params.ptr_alpha_col) + problem_idx * params.problem_size.n(); + ElementC* ptr_C = + reinterpret_cast(params.ptr_C) + rows_to_jump * gemm_n; + ElementC* ptr_D = + reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + using Element_scale = + typename EpilogueVisitor::ScaleTileIterator::Element; + Element_scale* ptr_alpha_row = + params.ptr_alpha_row == nullptr + ? params.ptr_alpha_row + : reinterpret_cast(params.ptr_alpha_row) + + rows_to_jump; + Element_scale* ptr_alpha_col = + reinterpret_cast(params.ptr_alpha_col) + + problem_idx * params.problem_size.n(); // if (threadIdx.x == 0) - // printf("##### block: %d-%d-%d, ptr_alpha_row:%p,(%f) ptr_alpha_col:%p,(%f)\n", blockIdx.x, blockIdx.y, blockIdx.z, ptr_alpha_row, static_cast(*ptr_alpha_row), ptr_alpha_col, static_cast(*ptr_alpha_col)); + // printf("##### block: %d-%d-%d, ptr_alpha_row:%p,(%f) + // ptr_alpha_col:%p,(%f)\n", blockIdx.x, blockIdx.y, blockIdx.z, + // ptr_alpha_row, static_cast(*ptr_alpha_row), ptr_alpha_col, + // static_cast(*ptr_alpha_col)); // // Construct the epilogue visitor // EpilogueVisitor epilogue_visitor(params.epilogue_visitor, - shared_storage.epilogue.visitor, - problem_size.mn(), - thread_idx, - warp_idx, - lane_idx, - params.params_alpha_col, - params.params_C, - params.params_D, - params.quant_mode, - ptr_alpha_row, - ptr_alpha_col, - ptr_C, - ptr_D, - threadblock_offset, - blockIdx.y * params.problem_size.m()); + shared_storage.epilogue.visitor, + problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_alpha_col, + params.params_C, + params.params_D, + params.quant_mode, + ptr_alpha_row, + ptr_alpha_col, + ptr_C, + ptr_D, + threadblock_offset, + blockIdx.y * params.problem_size.m()); if (params.mode == GemmUniversalMode::kGemm) { // Indicate which position in a serial reduction the output operator is // currently updating 
epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), - params.grid_tiled_shape.k()); + params.grid_tiled_shape.k()); } else if (params.mode == GemmUniversalMode::kBatched || - params.mode == GemmUniversalMode::kArray) { + params.mode == GemmUniversalMode::kArray) { epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); } @@ -608,9 +625,9 @@ struct MoeW4A8GemmWithEpilogueVisitorInterleavedNf4 { // Execute the epilogue operator to update the destination tensor. epilogue(epilogue_visitor, accumulators); - // Next tile - problem_visitor.advance(gridDim.x); - } + // Next tile + problem_visitor.advance(gridDim.x); + } } }; diff --git a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h index 68779ba28ca..dd6536c762e 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h +++ b/custom_ops/gpu_ops/cutlass_kernels/w4a8_moe/weight_process_utils.h @@ -32,482 +32,496 @@ limitations under the License. */ void row_major_to_column_major(int8_t* col_major_tensor, const int8_t* row_major_tensor, - const std::vector& shape){ - size_t m = shape[0]; - size_t n = shape[1]; - for(auto i=0;i& shape) { + size_t m = shape[0]; + size_t n = shape[1]; + for (auto i = 0; i < m * n; i++) { + size_t im = i / n; + size_t in = i % n; + col_major_tensor[in * m + im] = row_major_tensor[im * n + in]; + } } void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr, - int64_t num_elts) -{ - int8_t* int8_tensor = reinterpret_cast(int8_tensor_ptr); - for (int ii = 0; ii < num_elts; ++ii) { - int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); - // int8_tensor[ii] = int8_t(int(int8_tensor[ii])); - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no - // performance benefit and is purely so that int4 and int8 have the same layout. 
- // Pictorially, this does the following: - // bit 32 0 - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - - for (int64_t base = 0; base < num_elts; base += 4) { - std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); - } + int64_t num_elts) { + int8_t* int8_tensor = reinterpret_cast(int8_tensor_ptr); + for (int ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + // int8_tensor[ii] = int8_t(int(int8_tensor[ii])); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // match the int4 layout. This has no performance benefit and is purely so + // that int4 and int8 have the same layout. Pictorially, this does the + // following: bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + for (int64_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } } +void subbyte_transpose_impl_int4(int8_t* transposed_quantized_tensor, + const int8_t* quantized_tensor, + const std::vector& shape) { + const int bits_per_elt = 4; + + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; + + const uint8_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = + reinterpret_cast(transposed_quantized_tensor); + + // static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == + // QuantType::PACKED_INT4_WEIGHT_ONLY, ""); + static constexpr int ELTS_PER_BYTE = 2; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle + // dims which are multiples of 64 for weight-only quantization. As a result, + // this seemed like a reasonable tradeoff because it allows GCC to emit vector + // instructions. 
+ + const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + + for (size_t expert = 0; expert < num_experts; ++expert) { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; + row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; + col_tile_start_byte += N_TILE_L1) { + const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + const int col_limit = + std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte + jj; + + const size_t logical_src_offset = + matrix_offset + row * col_bytes + col; + + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } -void subbyte_transpose_impl_int4(int8_t* transposed_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape) -{ - const int bits_per_elt = 4; - - const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; - - const size_t col_bytes = num_cols * bits_per_elt / 8; - const size_t col_bytes_trans = num_rows * bits_per_elt / 8; - const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; - - const uint8_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); - - // static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, ""); - static constexpr int ELTS_PER_BYTE = 2; - - static constexpr int M_TILE_L1 = 64; - static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; - uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; - - static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); - - // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples - // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it - // allows GCC to emit vector instructions. - - const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; - const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; - - for (size_t expert = 0; expert < num_experts; ++expert) { - const size_t matrix_offset = expert * num_rows * col_bytes; - for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) { - for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) { - - const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); - const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); - - for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start + ii; - - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte + jj; - - const size_t logical_src_offset = matrix_offset + row * col_bytes + col; - - if (row < row_limit && col < col_limit) { - for (int v = 0; v < VECTOR_WIDTH; ++v) { - cache_buf[ii][jj + v] = 
input_byte_ptr[logical_src_offset + v]; - } - } - } - } - - - for (int ii = 0; ii < M_TILE_L1; ++ii) { - // Using M_TILE_L1 here is deliberate since we assume that the cache tile - // is square in the number of elements (not necessarily the number of bytes). - for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { - const int ii_byte = ii / ELTS_PER_BYTE; - const int ii_bit_offset = ii % ELTS_PER_BYTE; - - const int jj_byte = jj / ELTS_PER_BYTE; - const int jj_bit_offset = jj % ELTS_PER_BYTE; + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache + // tile is square in the number of elements (not necessarily the + // number of bytes). + for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; - uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); - uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; - cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + uint8_t src_elt = + 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = + 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); - cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); - cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); - } - } + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } - const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; - const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = 
row_tile_start / ELTS_PER_BYTE; - const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); - const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + const int row_limit_trans = + std::min(row_tile_start_trans + M_TILE_L1, num_cols); + const int col_limit_trans = + std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); - for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start_trans + ii; - for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte_trans + jj; + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte_trans + jj; - const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; + const size_t logical_tgt_offset = + matrix_offset + row * col_bytes_trans + col; - if (row < row_limit_trans && col < col_limit_trans) { - for (int v = 0; v < VECTOR_WIDTH; ++v) { - output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; - } - } - } - } + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } } + } } + } } + } } - -void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) -{ - const int num_bytes = num_elts / 2; - - // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little - // instructions as possible in the CUDA code. 
- for (size_t ii = 0; ii < num_bytes; ++ii) { - int8_t transformed_packed_int4s = 0; - // We don't need to mask in these ops since everything should be in the range 0-15 - int8_t transformed_first_elt = (packed_int4_tensor[ii] & 0x0F); - int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4); - - transformed_packed_int4s |= transformed_first_elt; - transformed_packed_int4s |= (transformed_second_elt << 4); - packed_int4_tensor[ii] = transformed_packed_int4s; - } - - // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical - // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the - // following: Take as input a 32 bit register with layout: bit 32 0 - // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) - - // FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); - const size_t num_registers = num_bytes / 4; - - uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); - for (size_t ii = 0; ii < num_registers; ++ii) { - const uint32_t current_register = register_ptr[ii]; - uint32_t transformed_register = 0; - - for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { - const int src_idx = dest_idx < 4 ? 
2 * dest_idx : 2 * (dest_idx - 4) + 1; - const int src_shift = 4 * src_idx; - const int dest_shift = 4 * dest_idx; - - const uint32_t src_bits = (current_register >> src_shift) & 0xF; - transformed_register |= (src_bits << dest_shift); - - } - register_ptr[ii] = transformed_register; +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, + const size_t num_elts) { + const int num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the + // dequantize take as little instructions as possible in the CUDA code. + for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + // We don't need to mask in these ops since everything should be in the + // range 0-15 + int8_t transformed_first_elt = (packed_int4_tensor[ii] & 0x0F); + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4); + + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // minimize the number of shift & logical instructions That are needed to + // extract the int4s in the GEMM main loop. 
Pictorially, the loop below will + // do the following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt + // occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt + // occupies 4 bits) + + // FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a + // multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + const int src_shift = 4 * src_idx; + const int dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); } + register_ptr[ii] = transformed_register; + } } -void permute_B_rows_for_mixed_and_int8_gemm(int8_t* permuted_quantized_tensor, - const int8_t* quantized_tensor, +void permute_B_rows_for_mixed_and_int8_gemm(int8_t* permuted_quantized_tensor, + const int8_t* quantized_tensor, const std::vector& shape, - const int64_t arch_version) -{ - - // We only want to run this step for weight only quant. - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; - - const int BITS_PER_ELT = 8; - const int K = 16 / BITS_PER_ELT; - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; - const int ELTS_PER_REG = 32 / BITS_PER_ELT; - - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); - - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - const int elts_in_int32 = 32 / BITS_PER_ELT; - - const int num_vec_cols = num_cols / elts_in_int32; - - // The code is written as below so it works for both int8 and packed int4. - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - const int tile_read_row = - 4 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - - const int read_row = base_row + tile_read_row; - const int read_col = write_col; - - const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; - const int64_t write_offset = int64_t(write_row) * num_vec_cols + write_col; - - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } + const int64_t arch_version) { + // We only want to run this step for weight only quant. + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + const int BITS_PER_ELT = 8; + const int K = 16 / BITS_PER_ELT; + const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; + const int ELTS_PER_REG = 32 / BITS_PER_ELT; + + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int num_vec_cols = num_cols / elts_in_int32; + + // The code is written as below so it works for both int8 and packed int4. + for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + const int write_row = base_row + tile_row; + const int tile_read_row = 4 * (((tile_row % ELTS_PER_REG) / 2)) + + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); + + const int read_row = base_row + tile_read_row; + const int read_col = write_col; + + const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } } + } } -// Permutes the rows of B for Turing and Ampere. Throws an error for other architectures. -// The data is permuted such that: -// For int8, each group of 16 rows is permuted using the map below: +// Permutes the rows of B for Turing and Ampere. Throws an error for other +// architectures. 
The data is permuted such that: For int8, each group of 16 +// rows is permuted using the map below: // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 // 0 1 2 3 4 5 6 7 -template -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, - const int8_t* quantized_tensor, +template +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, + const int8_t* quantized_tensor, const std::vector& shape, - const int64_t arch_version) -{ - - // We only want to run this step for weight only quant. - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - - const int BITS_PER_ELT = bits; - const int K = 16 / BITS_PER_ELT; - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; - const int ELTS_PER_REG = 32 / BITS_PER_ELT; - - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); - - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; - const int elts_in_int32 = 32 / BITS_PER_ELT; - - const int num_vec_cols = num_cols / elts_in_int32; - - // The code is written as below so it works for both int8 and packed int4. - for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - const int tile_read_row = - 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - if(base_row == 0 && write_col == 0){ - std::cout<<"tile_read_row:"<(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int num_vec_cols = num_cols / elts_in_int32; + + // The code is written as below so it works for both int8 and packed int4. 
+ for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + const int write_row = base_row + tile_row; + const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) + + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); + if (base_row == 0 && write_col == 0) { + std::cout << "tile_read_row:" << tile_read_row << std::endl; + } + const int read_row = base_row + tile_read_row; + const int read_col = write_col; - const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; - const int64_t write_offset = int64_t(write_row) * num_vec_cols + write_col; + const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + int64_t(write_row) * num_vec_cols + write_col; - output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; - } - } + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } } + } } -template -void permute_B_rows_for_mixed_gemm_int4(int8_t* permuted_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape, - const int64_t arch_version) -{ - - // We only want to run this step for weight only quant. - const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; - const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; - - const int BITS_PER_ELT = bits; //4 - const int K = 16 / BITS_PER_ELT; // 4 - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; // 2 - const int ELTS_PER_REG = 32 / BITS_PER_ELT; // 8 - - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); - - int MMA_SHAPE_N = 8; - int B_ROWS_PER_MMA = 8 * K; // 32 - const int elts_in_int32 = 32 / BITS_PER_ELT; - - const int num_vec_cols = num_cols / elts_in_int32; - const std::vector tile_col_map{ - 0,2,16,18, - 1,3,17,19, - 4,6,20,22, - 5,7,21,23, - 8,10,24,26, - 9,11,25,27, - 12,14,28,30, - 13,15,29,31}; - - // const std::vector tile_col_map{ - // 0 0,2,16,18, - // 4 1,3,17,19, - // 8 4,6,20,22, - // 12 5,7,21,23, - // 16 8,10,24,26, - // 20 9,11,25,27, - // 24 12,14,28,30, - // 28 13,15,29,31}; - // std::vector tile_col_map(32); - // for(int i=0;i<32;i++){ - // tile_col_map[i]=i; - // } - // // tile_col_map[1]=4; - // tile_col_map[0]=0; - // tile_col_map[4]=1; - // tile_col_map[1]=2; - // tile_col_map[5]=3; - // tile_col_map[8]=4; - // tile_col_map[12]=5; - // tile_col_map[9]=6; - // tile_col_map[13]=7; - // tile_col_map[16]=8; - // tile_col_map[20]=9; - // tile_col_map[17]=10; - // tile_col_map[21]=11; - // tile_col_map[24]=12; - // tile_col_map[28]=13; - // tile_col_map[25]=14; - // tile_col_map[29]=15; - - // tile_col_map[4]=1; - // tile_col_map[4]=1; - // tile_col_map[4]=2; - - // The code is written as below so it works for both int8 and packed int4. 
- for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { - for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { - - for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - // const int tile_read_row = - // 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - // const int tile_read_row = std::distance(tile_col_map.begin(), std::find(tile_col_map.begin(),tile_col_map.end(), tile_row)); - const int tile_read_row = tile_col_map[tile_row]; - if(base_row == 0 && write_col == 0){ - std::cout<<" write_row:"< +void permute_B_rows_for_mixed_gemm_int4(int8_t* permuted_quantized_tensor, + const int8_t* quantized_tensor, + const std::vector& shape, + const int64_t arch_version) { + // We only want to run this step for weight only quant. + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const int BITS_PER_ELT = bits; // 4 + const int K = 16 / BITS_PER_ELT; // 4 + const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; // 2 + const int ELTS_PER_REG = 32 / BITS_PER_ELT; // 8 + + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; // 32 + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int num_vec_cols = num_cols / elts_in_int32; + const std::vector tile_col_map{0, 2, 16, 18, 1, 3, 17, 19, 4, 6, 20, + 22, 5, 7, 21, 23, 8, 10, 24, 26, 9, 11, + 25, 27, 12, 14, 28, 30, 13, 15, 29, 31}; + + // const std::vector tile_col_map{ + // 0 0,2,16,18, + // 4 1,3,17,19, + // 8 4,6,20,22, + // 12 5,7,21,23, + // 16 8,10,24,26, + // 20 9,11,25,27, + // 24 12,14,28,30, + // 28 13,15,29,31}; + // std::vector tile_col_map(32); + // for(int i=0;i<32;i++){ + // tile_col_map[i]=i; + // } + // // tile_col_map[1]=4; + // tile_col_map[0]=0; + // 
tile_col_map[4]=1; + // tile_col_map[1]=2; + // tile_col_map[5]=3; + // tile_col_map[8]=4; + // tile_col_map[12]=5; + // tile_col_map[9]=6; + // tile_col_map[13]=7; + // tile_col_map[16]=8; + // tile_col_map[20]=9; + // tile_col_map[17]=10; + // tile_col_map[21]=11; + // tile_col_map[24]=12; + // tile_col_map[28]=13; + // tile_col_map[25]=14; + // tile_col_map[29]=15; + + // tile_col_map[4]=1; + // tile_col_map[4]=1; + // tile_col_map[4]=2; + + // The code is written as below so it works for both int8 and packed int4. + for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + const int write_row = base_row + tile_row; + // const int tile_read_row = + // 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * + // (tile_row / ELTS_PER_REG); + // const int tile_read_row = std::distance(tile_col_map.begin(), + // std::find(tile_col_map.begin(),tile_col_map.end(), tile_row)); + const int tile_read_row = tile_col_map[tile_row]; + if (base_row == 0 && write_col == 0) { + std::cout << " write_row:" << tile_row + << " tile_read_row:" << tile_read_row << std::endl; } - } -} + const int read_row = base_row + tile_read_row; + const int read_col = write_col; + const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + int64_t(write_row) * num_vec_cols + write_col; -void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape) -{ - - // We only want to run this step for weight only quant. 
- std::cout<<"### in interleave_column_major_tensor"<(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); - - - const size_t num_vec_rows = num_rows / elts_in_int32; - const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; - const size_t interleave = 2; - std::cout<<"num_vec_rows:"<& shape) { + // We only want to run this step for weight only quant. + std::cout << "### in interleave_column_major_tensor" << std::endl; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const size_t BITS_PER_ELT = 8; + const size_t elts_in_int32 = 32 / BITS_PER_ELT; + + const size_t rows_per_tile = 64; + std::cout << "running interleave_column_major_tensor" << std::endl; + std::cout << "num_rows:" << num_rows << "," + << "num_cols:" << num_cols << "," + << "BITS_PER_ELT:" << BITS_PER_ELT << "," + << "elts_in_int32:" << elts_in_int32 << "," + << "rows_per_tile:" << rows_per_tile << std::endl; + + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); + + const size_t num_vec_rows = num_rows / elts_in_int32; + const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; + const size_t interleave = 2; + std::cout << "num_vec_rows:" << num_vec_rows << "," + << "vec_rows_per_tile:" << vec_rows_per_tile << "," + << "interleave:" << interleave << std::endl; + for (int read_col = 0; read_col < num_cols; ++read_col) { + const size_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; + base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < + std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); + ++vec_read_row) { + const size_t vec_write_row = + interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + + vec_read_row % vec_rows_per_tile; + + const size_t 
read_offset = + size_t(read_col) * num_vec_rows + vec_read_row; + const size_t write_offset = + size_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } +} -void interleave_column_major_tensor_int4(int8_t* interleaved_quantized_tensor, - const int8_t* quantized_tensor, - const std::vector& shape) -{ - - // We only want to run this step for weight only quant. - std::cout<<"### in interleave_column_major_tensor"<(quantized_tensor); - uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); - - - const size_t num_vec_rows = num_rows / elts_in_int32; - const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; - const size_t interleave = 4; - std::cout<<"num_vec_rows:"<& shape) { + // We only want to run this step for weight only quant. + std::cout << "### in interleave_column_major_tensor" << std::endl; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + const size_t BITS_PER_ELT = 4; + const size_t elts_in_int32 = 32 / BITS_PER_ELT; + + const size_t rows_per_tile = 64; + std::cout << "running interleave_column_major_tensor" << std::endl; + std::cout << "num_rows:" << num_rows << "," + << "num_cols:" << num_cols << "," + << "BITS_PER_ELT:" << BITS_PER_ELT << "," + << "elts_in_int32:" << elts_in_int32 << "," + << "rows_per_tile:" << rows_per_tile << std::endl; + + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); + + const size_t num_vec_rows = num_rows / elts_in_int32; + const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; + const size_t interleave = 4; + std::cout << "num_vec_rows:" << num_vec_rows << "," + << "vec_rows_per_tile:" << vec_rows_per_tile << "," + << "interleave:" << interleave << std::endl; + for (int read_col = 0; read_col < num_cols; ++read_col) { + const size_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; + base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < + std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); + ++vec_read_row) { + const size_t vec_write_row = + interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + + vec_read_row % vec_rows_per_tile; + + const size_t read_offset = + size_t(read_col) * num_vec_rows + vec_read_row; + const size_t write_offset = + size_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } } + } } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/cutlass_gemm_caller.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/cutlass_gemm_caller.cuh index 1944716c3e5..e35018a26fb 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/cutlass_gemm_caller.cuh +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/cutlass_gemm_caller.cuh @@ -1,4 +1,5 @@ -// 
adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh #pragma once // clang-format will break include orders @@ -21,15 +22,16 @@ namespace fastdeploy::c3x { -static inline cute::Shape -get_problem_shape(paddle::Tensor const &a, paddle::Tensor const &b) { +static inline cute::Shape get_problem_shape( + paddle::Tensor const &a, paddle::Tensor const &b) { int32_t m = a.dims()[0], n = b.dims()[0], k = a.dims()[1]; return {m, n, k, 1}; } template void cutlass_gemm_caller( - phi::Place device, cute::Shape prob_shape, + phi::Place device, + cute::Shape prob_shape, typename GemmKernel::MainloopArguments mainloop_args, typename GemmKernel::EpilogueArguments epilogue_args, typename GemmKernel::TileSchedulerArguments scheduler = {}) { @@ -57,7 +59,8 @@ void cutlass_gemm_caller( } template -void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_gemm_caller(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, EpilogueArgs &&...epilogue_params) { using ElementAB = typename Gemm::ElementAB; @@ -86,17 +89,20 @@ void cutlass_gemm_caller(paddle::Tensor &out, paddle::Tensor const &a, auto a_ptr = static_cast(const_cast(a.data())); auto b_ptr = static_cast(const_cast(b.data())); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride}; auto c_ptr = static_cast(const_cast(out.data())); typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, d_stride}; + c_ptr, + c_stride, + c_ptr, + d_stride}; - cutlass_gemm_caller(a.place(), prob_shape, mainloop_args, - 
epilogue_args); + cutlass_gemm_caller( + a.place(), prob_shape, mainloop_args, epilogue_args); } -} // namespace fastdeploy::c3x +} // namespace fastdeploy::c3x diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh index 26278a79fd4..4d8edbd6210 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh #pragma once @@ -31,16 +32,19 @@ using namespace cute; namespace fastdeploy { -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, +template + typename Epilogue_, + typename TileShape, + typename ClusterShape, + typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_gemm { using ElementAB = ElementAB_; using ElementD = ElementD_; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; + using ElementAcc = typename std:: + conditional, int32_t, float>::type; using Epilogue = Epilogue_; @@ -57,10 +61,21 @@ struct cutlass_3x_gemm { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, StrideD, - AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, + float, + ElementC, + StrideC, + AlignmentCD, + ElementD, + StrideD, + AlignmentCD, + EpilogueSchedule, 
+ EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); @@ -78,16 +93,22 @@ struct cutlass_3x_gemm { KernelSchedule>::CollectiveOp; // clang-format on - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; + using KernelType = enable_sm90_or_later< + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; struct GemmKernel : public KernelType {}; }; -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, +template + typename Epilogue_, + typename TileShape, + typename ClusterShape, + typename KernelSchedule, typename EpilogueSchedule> struct cutlass_3x_gemm_sm100 { using ElementAB = ElementAB_; @@ -108,9 +129,8 @@ struct cutlass_3x_gemm_sm100 { using LayoutD = cutlass::layout::RowMajor; static constexpr int AlignmentD = AlignmentC; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; + using ElementAcc = typename std:: + conditional, int32_t, float>::type; using Epilogue = Epilogue_; // MMA type @@ -127,23 +147,44 @@ struct cutlass_3x_gemm_sm100 { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, - ElementD, LayoutD, AlignmentD, EpilogueSchedule, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, EVTCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, 
cutlass::arch::OpClassTensorOp, ElementAB, - LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, - ElementAccumulator, TileShape, ClusterShape, + cutlass::arch::Sm100, + cutlass::arch::OpClassTensorOp, + ElementAB, + LayoutA, + AlignmentA, + ElementAB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, CollectiveMainloop, CollectiveEpilogue, void>; + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, + CollectiveEpilogue, + void>; }; -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu index f5d4d6aa28e..704f1021f87 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu // clang-format will break include orders // clang-format off @@ -10,18 +11,22 @@ namespace fastdeploy { void cutlass_scaled_mm_azp_sm90_int8( - paddle::Tensor &out, paddle::Tensor const &a, paddle::Tensor const &b, - paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, - paddle::Tensor const &azp_adj, paddle::optional const &azp, + paddle::Tensor &out, + paddle::Tensor const &a, + paddle::Tensor const &b, + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &azp, 
paddle::optional const &bias) { if (azp) { return cutlass_scaled_mm_sm90_int8_epilogue< - c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, - *azp, bias); + c3x::ScaledEpilogueBiasAzpToken>( + out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { return cutlass_scaled_mm_sm90_int8_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp index 9a601f75ad2..2bfa58231fa 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_helper.hpp @@ -1,34 +1,38 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp #include "helper.h" template -void dispatch_scaled_mm(paddle::Tensor &c, paddle::Tensor const &a, - paddle::Tensor const &b, paddle::Tensor const &a_scales, +void dispatch_scaled_mm(paddle::Tensor &c, + paddle::Tensor const &a, + paddle::Tensor const &b, + paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias, - Fp8Func fp8_func, Int8Func int8_func) { - PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32); - PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32); + Fp8Func fp8_func, + Int8Func int8_func) { + PD_CHECK(a_scales.dtype() == paddle::DataType::FLOAT32); + PD_CHECK(b_scales.dtype() == paddle::DataType::FLOAT32); - int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1]; + int M = a.dims()[0], N = b.dims()[0], K = a.dims()[1]; - if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) && - (b_scales.numel() == 1 || b_scales.numel() == 
b.dims()[0])) { - // Standard per-tensor/per-token/per-channel scaling - PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) { - fp8_func(c, a, b, a_scales, b_scales, bias); + if ((a_scales.numel() == 1 || a_scales.numel() == a.dims()[0]) && + (b_scales.numel() == 1 || b_scales.numel() == b.dims()[0])) { + // Standard per-tensor/per-token/per-channel scaling + PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == phi::DataType::FLOAT8_E4M3FN) { + fp8_func(c, a, b, a_scales, b_scales, bias); + } else { + PD_CHECK(a.dtype() == paddle::DataType::INT8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); } else { - PD_CHECK(a.dtype() == paddle::DataType::INT8); - if constexpr (!std::is_same_v) { - int8_func(c, a, b, a_scales, b_scales, bias); - } else { - PD_CHECK(false, "Int8 not supported for this architecture"); - } + PD_CHECK(false, "Int8 not supported for this architecture"); } - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "No kernel for this combination of input dtypes is implemented.")); } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "No kernel for this combination of input dtypes is implemented.")); + } } diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp index 75472ea805e..4e3e67ac5a4 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_kernels.hpp @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp #pragma once @@ -6,30 +7,35 @@ namespace fastdeploy { -void 
cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias); -void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias); -void cutlass_scaled_mm_azp_sm90_int8(paddle::Tensor& out, paddle::Tensor const& a, - paddle::Tensor const& b, - paddle::Tensor const& a_scales, - paddle::Tensor const& b_scales, - paddle::Tensor const& azp_adj, - paddle::optional const& azp, - paddle::optional const& bias); - -void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_azp_sm90_int8( + paddle::Tensor &out, + paddle::Tensor const &a, + paddle::Tensor const &b, + paddle::Tensor const &a_scales, + paddle::Tensor const &b_scales, + paddle::Tensor const &azp_adj, + paddle::optional const &azp, + paddle::optional const &bias); + +void cutlass_scaled_mm_sm100_fp8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, paddle::optional const &bias); -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu index 801e90fd733..1b197b80289 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu +// adapted from: +// 
https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu // clang-format will break include orders // clang-format off @@ -9,7 +10,8 @@ namespace fastdeploy { -void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, @@ -17,7 +19,8 @@ void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); if (bias) { PD_CHECK(bias->dtype() == out.dtype(), - "currently bias dtype must match output dtype ", out.dtype()); + "currently bias dtype must match output dtype ", + out.dtype()); return cutlass_scaled_mm_sm90_fp8_epilogue( out, a, b, a_scales, b_scales, *bias); } else { @@ -25,4 +28,4 @@ void cutlass_scaled_mm_sm90_fp8(paddle::Tensor &out, paddle::Tensor const &a, out, a, b, a_scales, b_scales); } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh index ac86aeba857..cd0eda3ad9b 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh #pragma once @@ -17,8 +18,10 @@ namespace fastdeploy { using c3x::cutlass_gemm_caller; -template typename Epilogue> +template + typename Epilogue> struct 
sm90_fp8_config_default { // M in (128, inf) static_assert(std::is_same()); @@ -27,13 +30,19 @@ struct sm90_fp8_config_default { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_fp8_config_M128 { // M in (64, 128] static_assert(std::is_same()); @@ -42,13 +51,19 @@ struct sm90_fp8_config_M128 { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_fp8_config_M64 { // M in [1, 64] static_assert(std::is_same()); @@ -58,13 +73,19 @@ struct sm90_fp8_config_M64 { using TileShape = Shape<_64, _64, _128>; using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue, +template + typename Epilogue, typename... 
EpilogueArgs> inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out, paddle::Tensor const &a, @@ -75,8 +96,8 @@ inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out, PD_CHECK(b.dtype() == phi::DataType::FLOAT8_E4M3FN); using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; + typename sm90_fp8_config_default:: + Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm90_fp8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM128 = @@ -84,7 +105,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(paddle::Tensor &out, uint32_t const m = a.dims()[0]; uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 + std::max(static_cast(64), next_pow_2(m)); // next power of 2 if (mp2 <= 64) { // m in [1, 64] @@ -112,14 +133,16 @@ void cutlass_scaled_mm_sm90_fp8_epilogue(paddle::Tensor &out, if (out.dtype() == paddle::DataType::BFLOAT16) { return cutlass_gemm_sm90_fp8_dispatch( + cutlass::bfloat16_t, + Epilogue>( out, a, b, std::forward(epilogue_args)...); } else { PD_CHECK(out.dtype() == paddle::DataType::FLOAT16); return cutlass_gemm_sm90_fp8_dispatch( + cutlass::half_t, + Epilogue>( out, a, b, std::forward(epilogue_args)...); } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu index 633f76fd887..5256b27c143 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu // clang-format will break include orders // clang-format off @@ -9,7 +10,8 @@ 
namespace fastdeploy { -void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, +void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, + paddle::Tensor const &a, paddle::Tensor const &b, paddle::Tensor const &a_scales, paddle::Tensor const &b_scales, @@ -17,7 +19,8 @@ void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, PD_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); if (bias) { PD_CHECK(bias->dtype() == out.dtype(), - "currently bias dtype must match output dtype ", out.dtype()); + "currently bias dtype must match output dtype ", + out.dtype()); return cutlass_scaled_mm_sm90_int8_epilogue( out, a, b, a_scales, b_scales, *bias); } else { @@ -26,4 +29,4 @@ void cutlass_scaled_mm_sm90_int8(paddle::Tensor &out, paddle::Tensor const &a, } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh index df63de0fa6a..1b14e1b749c 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh #pragma once @@ -17,8 +18,10 @@ namespace fastdeploy { using c3x::cutlass_gemm_caller; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_default { // For M > 128 and any N static_assert(std::is_same()); @@ -27,13 +30,19 @@ struct sm90_int8_config_default { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, 
_128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M128 { // For M in (64, 128] and any N static_assert(std::is_same()); @@ -42,13 +51,19 @@ struct sm90_int8_config_M128 { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M64 { // For M in (32, 64] and any N static_assert(std::is_same()); @@ -56,13 +71,19 @@ struct sm90_int8_config_M64 { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _64, _256>; using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M32_NBig { // For M in [1, 32] and N >= 8192 static_assert(std::is_same()); @@ -70,13 +91,19 @@ struct sm90_int8_config_M32_NBig { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _256>; using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue> +template + typename Epilogue> struct sm90_int8_config_M32_NSmall { // For M in [1, 32] and N < 8192 static_assert(std::is_same()); @@ -84,13 +111,19 @@ struct sm90_int8_config_M32_NSmall { using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _64, _256>; using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; + using Cutlass3xGemm = cutlass_3x_gemm; }; -template typename Epilogue, +template + typename 
Epilogue, typename... EpilogueArgs> inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out, paddle::Tensor const &a, @@ -101,25 +134,25 @@ inline void cutlass_gemm_sm90_int8_dispatch(paddle::Tensor &out, PD_CHECK(b.dtype() == paddle::DataType::INT8); using Cutlass3xGemmDefault = - typename sm90_int8_config_default::Cutlass3xGemm; + typename sm90_int8_config_default:: + Cutlass3xGemm; using Cutlass3xGemmM128 = typename sm90_int8_config_M128::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm90_int8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM32NBig = - typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + typename sm90_int8_config_M32_NBig:: + Cutlass3xGemm; using Cutlass3xGemmM32NSmall = - typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + typename sm90_int8_config_M32_NSmall:: + Cutlass3xGemm; uint32_t const n = out.dims()[1]; bool const is_small_n = n < 8192; uint32_t const m = a.dims()[0]; uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 + std::max(static_cast(32), next_pow_2(m)); // next power of 2 if (mp2 <= 32) { // m in [1, 32] @@ -155,7 +188,8 @@ void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out, PD_CHECK(b.dtype() == paddle::DataType::INT8); if (out.dtype() == paddle::DataType::BFLOAT16) { - return cutlass_gemm_sm90_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { @@ -165,4 +199,4 @@ void cutlass_scaled_mm_sm90_int8_epilogue(paddle::Tensor &out, } } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu b/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu index 55015ea3e96..e3ce7e1fcbf 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu +++ b/custom_ops/gpu_ops/cutlass_kernels/w8a8/scaled_mm_c2x.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +// adapted from: +// 
https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu #include "helper.h" #include @@ -20,7 +21,8 @@ using namespace fastdeploy; template