diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
index 93e6a2c18e21..6912cc11b850 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
@@ -190,6 +190,16 @@ class TllmGenFmhaKernel
         }
     }
 
+    static bool shouldUseNvrtc(FmhaOptions const& options)
+    {
+        // Check if the NVRTC path should be used for a given FMHA configuration.
+        bool isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
+            && options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8;
+
+        return options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration && !options.mIsMlaGen
+            && options.mDtypeKv != tg::Dtype::E2m1 && options.mHeadDimQk != 64 && !isLlama70bFp4Tp4;
+    }
+
     std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const
     {
         // Some conditions to check if the kernel is supported.
@@ -205,38 +215,57 @@
             return std::make_pair(false, "Empty batch or zero sequence length");
         }
 
-        // The selectKernelParams that might be updated.
-        SelectKernelParams selectKernelParams{params};
-
-        int32_t ctaDim = 512;
-        FmhaOptions options;
-        FmhaOptionsFromArgs optionsFromArgs;
-        parseOptionsFromRunnerParams(params, options);
-        options.mCudaArch = intToCudaArch(mSM);
+        try
+        {
+            int32_t ctaDim = 512;
+            FmhaOptions options;
+            FmhaOptionsFromArgs optionsFromArgs;
+            parseOptionsFromRunnerParams(params, options);
+            options.mCudaArch = intToCudaArch(mSM);
+
+            FmhaAutoTuner autoTuner(options, optionsFromArgs, params.mMultiProcessorCount);
+            std::tie(options, optionsFromArgs, ctaDim) = autoTuner.selectKernel();
+            // Check if the options are valid or not.
+            checkFmhaOptions(options, optionsFromArgs);
+            // Update the options if needed.
+            updateFmhaOptions(options, optionsFromArgs);
+            // The number of CtasQ and CtasKv per sequence, Ctas in the Y dimension, and Ctas in the Z
+            // dimension.
+            computeNumCtas(options, params.mMultiProcessorCount);
+
+            if (shouldUseNvrtc(options))
+            {
+                // For the NVRTC path, we return supported as long as autotuner successfully selected a kernel config.
+                std::ostringstream sstream;
+                populateJsonConfig(options, sstream);
+                std::string info = sstream.str();
 
-        FmhaAutoTuner autoTuner(options, optionsFromArgs, params.mMultiProcessorCount);
-        std::tie(options, optionsFromArgs, ctaDim) = autoTuner.selectKernel();
-        // Check if the options are valid or not.
-        checkFmhaOptions(options, optionsFromArgs);
-        // Update the options if needed.
-        updateFmhaOptions(options, optionsFromArgs);
-        // The number of CtasQ and CtasKv per sequence, Ctas in the Y dimension, and Ctas in the Z
-        // dimension.
-        computeNumCtas(options, params.mMultiProcessorCount);
+                return std::make_pair(true, info);
+            }
+            else
+            {
+                // Check if a precompiled cubin exists for this configuration (same lookup as run()).
+                // If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
+                algoFilterForCubinPath(options);
+                auto [hashId, info] = hashFromFmhaOptions(options);
 
-        // Check if a precompiled cubin exists for this configuration (same lookup as run()).
-        // If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
-        algoFilterForCubinPath(options);
-        auto [hashId, info] = hashFromFmhaOptions(options);
+                if (mFunctions.find(hashId) == mFunctions.end())
+                {
+                    TLLM_LOG_WARNING("Trtllm-gen kernels not found: " + info);
+                    return std::make_pair(false, info);
+                }
+                TLLM_LOG_DEBUG("TRTLLM-Gen kernel traits: %s", info.c_str());
 
-        if (mFunctions.find(hashId) == mFunctions.end())
+                return std::make_pair(true, info);
+            }
+        }
+        catch (std::exception const& e)
         {
-            TLLM_LOG_WARNING("Trtllm-gen kernels not found: " + info);
-            return std::make_pair(false, info);
+            // Omitting e.what(), they may contain "Runtime Error" and make scripts believe a fatal error happened.
+            std::string const errorInfo = std::string("Exception during TrtllmGen kernel existence check");
+            TLLM_LOG_WARNING(errorInfo);
+            return std::make_pair(false, errorInfo);
         }
-        TLLM_LOG_DEBUG("TRTLLM-Gen kernel traits: %s", info.c_str());
-
-        return std::make_pair(true, info);
     }
 
     void algoFilterForCubinPath(FmhaOptions& options) const
@@ -285,13 +314,7 @@ class TllmGenFmhaKernel
         FmhaData fmhaData;
         setFmhaData(params, options, fmhaData);
 
-        bool isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
-            && options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8;
-
-        bool shouldUseNvrtc = options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration && !options.mIsMlaGen
-            && options.mDtypeKv != tg::Dtype::E2m1 && options.mHeadDimQk != 64 && !isLlama70bFp4Tp4;
-
-        if (shouldUseNvrtc)
+        if (shouldUseNvrtc(options))
         {
             // nvrtc path - uses mFmhaInterface member for kernel caching
             FmhaConfig fmhaConfig;