Skip to content

Commit 1d060ec

Browse files
committed
add NVRTC to checkIfKernelExist and extract shouldUseNvrtc utility function
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
1 parent f4e84a9 commit 1d060ec

1 file changed

Lines changed: 34 additions & 21 deletions

File tree

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,16 @@ class TllmGenFmhaKernel
190190
}
191191
}
192192

193+
static bool shouldUseNvrtc(FmhaOptions const& options)
194+
{
195+
// Check if the NVRTC path should be used for a given FMHA configuration.
196+
bool isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
197+
&& options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8;
198+
199+
return options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration && !options.mIsMlaGen
200+
&& options.mDtypeKv != tg::Dtype::E2m1 && options.mHeadDimQk != 64 && !isLlama70bFp4Tp4;
201+
}
202+
193203
std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const
194204
{
195205
// Some conditions to check if the kernel is supported.
@@ -207,10 +217,6 @@ class TllmGenFmhaKernel
207217

208218
try
209219
{
210-
211-
// The selectKernelParams that might be updated.
212-
SelectKernelParams selectKernelParams{params};
213-
214220
int32_t ctaDim = 512;
215221
FmhaOptions options;
216222
FmhaOptionsFromArgs optionsFromArgs;
@@ -227,22 +233,35 @@ class TllmGenFmhaKernel
227233
// dimension.
228234
computeNumCtas(options, params.mMultiProcessorCount);
229235

230-
// Check if a precompiled cubin exists for this configuration (same lookup as run()).
231-
// If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
232-
algoFilterForCubinPath(options);
233-
auto [hashId, info] = hashFromFmhaOptions(options);
234-
235-
if (mFunctions.find(hashId) == mFunctions.end())
236+
if (shouldUseNvrtc(options))
236237
{
237-
TLLM_LOG_WARNING("Trtllm-gen kernels not found: " + info);
238-
return std::make_pair(false, info);
238+
// For the NVRTC path, we return supported as long as autotuner successfully selected a kernel config.
239+
std::ostringstream sstream;
240+
populateJsonConfig(options, sstream);
241+
std::string info = sstream.str();
242+
243+
return std::make_pair(true, info);
239244
}
240-
TLLM_LOG_DEBUG("TRTLLM-Gen kernel traits: %s", info.c_str());
245+
else
246+
{
247+
// Check if a precompiled cubin exists for this configuration (same lookup as run()).
248+
// If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
249+
algoFilterForCubinPath(options);
250+
auto [hashId, info] = hashFromFmhaOptions(options);
251+
252+
if (mFunctions.find(hashId) == mFunctions.end())
253+
{
254+
TLLM_LOG_WARNING("Trtllm-gen kernels not found: " + info);
255+
return std::make_pair(false, info);
256+
}
257+
TLLM_LOG_DEBUG("TRTLLM-Gen kernel traits: %s", info.c_str());
241258

242-
return std::make_pair(true, info);
259+
return std::make_pair(true, info);
260+
}
243261
}
244262
catch (std::exception const& e)
245263
{
264+
// Omitting e.what(), they may contain "Runtime Error" and make scripts believe a fatal error happened.
246265
std::string const errorInfo = std::string("Exception during TrtllmGen kernel existence check");
247266
TLLM_LOG_WARNING(errorInfo);
248267
return std::make_pair(false, errorInfo);
@@ -295,13 +314,7 @@ class TllmGenFmhaKernel
295314

296315
FmhaData fmhaData;
297316
setFmhaData(params, options, fmhaData);
298-
bool isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
299-
&& options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8;
300-
301-
bool shouldUseNvrtc = options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration && !options.mIsMlaGen
302-
&& options.mDtypeKv != tg::Dtype::E2m1 && options.mHeadDimQk != 64 && !isLlama70bFp4Tp4;
303-
304-
if (shouldUseNvrtc)
317+
if (shouldUseNvrtc(options))
305318
{
306319
// nvrtc path - uses mFmhaInterface member for kernel caching
307320
FmhaConfig fmhaConfig;

0 commit comments

Comments
 (0)