Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 57 additions & 34 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,16 @@ class TllmGenFmhaKernel
}
}

static bool shouldUseNvrtc(FmhaOptions const& options)
{
    // Decide whether the NVRTC (runtime-compilation) path should handle this FMHA configuration.
    // Only the swaps-MMA-AB generation kernel type is eligible.
    if (options.mFmhaKernelType != FmhaKernelType::SwapsMmaAbForGeneration)
    {
        return false;
    }
    // MLA generation, E2m1 (FP4) KV caches, and 64-wide QK head dims stay on the cubin path.
    if (options.mIsMlaGen || options.mDtypeKv == tg::Dtype::E2m1 || options.mHeadDimQk == 64)
    {
        return false;
    }
    // Special-cased shape (presumably Llama-70B FP4 TP4 — see the flag name): 128/128 head
    // dims with an E4m3 KV cache, 16 query heads, and 8 query heads per KV head is excluded.
    bool const isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
        && options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8;
    return !isLlama70bFp4Tp4;
}

std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const
{
// Some conditions to check if the kernel is supported.
Expand All @@ -205,38 +215,57 @@ class TllmGenFmhaKernel
return std::make_pair(false, "Empty batch or zero sequence length");
}

// The selectKernelParams that might be updated.
SelectKernelParams selectKernelParams{params};

int32_t ctaDim = 512;
FmhaOptions options;
FmhaOptionsFromArgs optionsFromArgs;
parseOptionsFromRunnerParams(params, options);
options.mCudaArch = intToCudaArch(mSM);
try
{
int32_t ctaDim = 512;
FmhaOptions options;
FmhaOptionsFromArgs optionsFromArgs;
parseOptionsFromRunnerParams(params, options);
options.mCudaArch = intToCudaArch(mSM);

FmhaAutoTuner autoTuner(options, optionsFromArgs, params.mMultiProcessorCount);
std::tie(options, optionsFromArgs, ctaDim) = autoTuner.selectKernel();
// Check if the options are valid or not.
checkFmhaOptions(options, optionsFromArgs);
// Update the options if needed.
updateFmhaOptions(options, optionsFromArgs);
// The number of CtasQ and CtasKv per sequence, Ctas in the Y dimension, and Ctas in the Z
// dimension.
computeNumCtas(options, params.mMultiProcessorCount);

if (shouldUseNvrtc(options))
{
// For the NVRTC path, we return supported as long as autotuner successfully selected a kernel config.
std::ostringstream sstream;
populateJsonConfig(options, sstream);
std::string info = sstream.str();

FmhaAutoTuner autoTuner(options, optionsFromArgs, params.mMultiProcessorCount);
std::tie(options, optionsFromArgs, ctaDim) = autoTuner.selectKernel();
// Check if the options are valid or not.
checkFmhaOptions(options, optionsFromArgs);
// Update the options if needed.
updateFmhaOptions(options, optionsFromArgs);
// The number of CtasQ and CtasKv per sequence, Ctas in the Y dimension, and Ctas in the Z
// dimension.
computeNumCtas(options, params.mMultiProcessorCount);
return std::make_pair(true, info);
}
else
{
// Check if a precompiled cubin exists for this configuration (same lookup as run()).
// If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
algoFilterForCubinPath(options);
auto [hashId, info] = hashFromFmhaOptions(options);

// Check if a precompiled cubin exists for this configuration (same lookup as run()).
// If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
algoFilterForCubinPath(options);
auto [hashId, info] = hashFromFmhaOptions(options);
if (mFunctions.find(hashId) == mFunctions.end())
{
TLLM_LOG_WARNING("Trtllm-gen kernels not found: " + info);
return std::make_pair(false, info);
}
TLLM_LOG_DEBUG("TRTLLM-Gen kernel traits: %s", info.c_str());

if (mFunctions.find(hashId) == mFunctions.end())
return std::make_pair(true, info);
}
}
catch (std::exception const& e)
{
TLLM_LOG_WARNING("Trtllm-gen kernels not found: " + info);
return std::make_pair(false, info);
// Omit e.what(): it may contain "Runtime Error", which would make scripts believe a fatal error happened.
std::string const errorInfo = std::string("Exception during TrtllmGen kernel existence check");
TLLM_LOG_WARNING(errorInfo);
return std::make_pair(false, errorInfo);
Comment thread
pengbowang-nv marked this conversation as resolved.
}
TLLM_LOG_DEBUG("TRTLLM-Gen kernel traits: %s", info.c_str());

return std::make_pair(true, info);
}

void algoFilterForCubinPath(FmhaOptions& options) const
Expand Down Expand Up @@ -285,13 +314,7 @@ class TllmGenFmhaKernel

FmhaData fmhaData;
setFmhaData(params, options, fmhaData);
bool isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
&& options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8;

bool shouldUseNvrtc = options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration && !options.mIsMlaGen
&& options.mDtypeKv != tg::Dtype::E2m1 && options.mHeadDimQk != 64 && !isLlama70bFp4Tp4;

if (shouldUseNvrtc)
if (shouldUseNvrtc(options))
{
// nvrtc path - uses mFmhaInterface member for kernel caching
FmhaConfig fmhaConfig;
Expand Down
Loading