@@ -190,6 +190,16 @@ class TllmGenFmhaKernel
190190 }
191191 }
192192
193+ static bool shouldUseNvrtc (FmhaOptions const & options)
194+ {
195+ // Check if the NVRTC path should be used for a given FMHA configuration.
196+ bool isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
197+ && options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8 ;
198+
199+ return options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration && !options.mIsMlaGen
200+ && options.mDtypeKv != tg::Dtype::E2m1 && options.mHeadDimQk != 64 && !isLlama70bFp4Tp4;
201+ }
202+
193203 std::pair<bool , std::string> checkIfKernelExist (RunnerParams const & params) const
194204 {
195205 // Some conditions to check if the kernel is supported.
@@ -207,10 +217,6 @@ class TllmGenFmhaKernel
207217
208218 try
209219 {
210-
211- // The selectKernelParams that might be updated.
212- SelectKernelParams selectKernelParams{params};
213-
214220 int32_t ctaDim = 512 ;
215221 FmhaOptions options;
216222 FmhaOptionsFromArgs optionsFromArgs;
@@ -227,22 +233,35 @@ class TllmGenFmhaKernel
227233 // dimension.
228234 computeNumCtas (options, params.mMultiProcessorCount );
229235
230- // Check if a precompiled cubin exists for this configuration (same lookup as run()).
231- // If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
232- algoFilterForCubinPath (options);
233- auto [hashId, info] = hashFromFmhaOptions (options);
234-
235- if (mFunctions .find (hashId) == mFunctions .end ())
236+ if (shouldUseNvrtc (options))
236237 {
237- TLLM_LOG_WARNING (" Trtllm-gen kernels not found: " + info);
238- return std::make_pair (false , info);
238+ // For the NVRTC path, report the kernel as supported as long as the autotuner successfully selected a kernel config.
239+ std::ostringstream sstream;
240+ populateJsonConfig (options, sstream);
241+ std::string info = sstream.str ();
242+
243+ return std::make_pair (true , info);
239244 }
240- TLLM_LOG_DEBUG (" TRTLLM-Gen kernel traits: %s" , info.c_str ());
245+ else
246+ {
247+ // Check if a precompiled cubin exists for this configuration (same lookup as run()).
248+ // If not, return (false, info) so the dispatcher can fall back to unfused MHA like on main.
249+ algoFilterForCubinPath (options);
250+ auto [hashId, info] = hashFromFmhaOptions (options);
251+
252+ if (mFunctions .find (hashId) == mFunctions .end ())
253+ {
254+ TLLM_LOG_WARNING (" Trtllm-gen kernels not found: " + info);
255+ return std::make_pair (false , info);
256+ }
257+ TLLM_LOG_DEBUG (" TRTLLM-Gen kernel traits: %s" , info.c_str ());
241258
242- return std::make_pair (true , info);
259+ return std::make_pair (true , info);
260+ }
243261 }
244262 catch (std::exception const & e)
245263 {
264+ // Deliberately omit e.what(): it may contain "Runtime Error" and make scripts believe a fatal error happened.
246265 std::string const errorInfo = std::string (" Exception during TrtllmGen kernel existence check" );
247266 TLLM_LOG_WARNING (errorInfo);
248267 return std::make_pair (false , errorInfo);
@@ -295,13 +314,7 @@ class TllmGenFmhaKernel
295314
296315 FmhaData fmhaData;
297316 setFmhaData (params, options, fmhaData);
298- bool isLlama70bFp4Tp4 = options.mHeadDimQk == 128 && options.mHeadDimV == 128
299- && options.mDtypeKv == tg::Dtype::E4m3 && options.mNumHeadsQ == 16 && options.mNumHeadsQPerKv == 8 ;
300-
301- bool shouldUseNvrtc = options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration && !options.mIsMlaGen
302- && options.mDtypeKv != tg::Dtype::E2m1 && options.mHeadDimQk != 64 && !isLlama70bFp4Tp4;
303-
304- if (shouldUseNvrtc)
317+ if (shouldUseNvrtc (options))
305318 {
306319 // nvrtc path - uses mFmhaInterface member for kernel caching
307320 FmhaConfig fmhaConfig;
0 commit comments