Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions cpp/tensorrt_llm/kernels/fusedLayernormKernels/fp4_converter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ struct FP4Converter<TIn, UE8M0_SF, std::enable_if_t<std::is_same_v<TIn, half> ||
{
}

// write_output=false: participate in warp shuffles but skip global writes (OOB threads).
template <size_t ELTS_PER_THREAD, typename T>
__device__ __forceinline__ void post_process(int rowIdx, int n_base, T packed_input) const
__device__ __forceinline__ void post_process(int rowIdx, int n_base, T packed_input, bool write_output = true) const
{

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
Expand All @@ -82,10 +83,6 @@ struct FP4Converter<TIn, UE8M0_SF, std::enable_if_t<std::is_same_v<TIn, half> ||

int colIdx = n_base / ELTS_PER_THREAD;

// Get the input tensor offset.
// int inOffset = rowIdx * (numCols / ELTS_PER_THREAD) + colIdx;
// PackedVec vec = reinterpret_cast<PackedVec const*>(in)[inOffset];

// Get absolute maximum values among the local 8 values.
auto localMax = __habs2({packed_input.array[0], packed_input.array[1]});

Expand Down Expand Up @@ -123,9 +120,6 @@ struct FP4Converter<TIn, UE8M0_SF, std::enable_if_t<std::is_same_v<TIn, half> ||
SFValue = static_cast<float>(tmp);
}

auto SFOffset = cvt_quant_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF>(std::nullopt /* batchIdx */, rowIdx,
colIdx, std::nullopt /* numRows */, numCols / SF_VEC_SIZE, SFout, QuantizationSFLayout::SWIZZLED);
*SFOffset = fp8SFVal;
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
float outputScale = reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal));
Expand All @@ -151,11 +145,16 @@ struct FP4Converter<TIn, UE8M0_SF, std::enable_if_t<std::is_same_v<TIn, half> ||
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * (numCols / ELTS_PER_THREAD) + colIdx;
// Write the e2m1 values to global memory.
out[outOffset] = e2m1Vec;
if (write_output)
{
auto SFOffset
= cvt_quant_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF>(std::nullopt /* batchIdx */, rowIdx, colIdx,
std::nullopt /* numRows */, numCols / SF_VEC_SIZE, SFout, QuantizationSFLayout::SWIZZLED);
*SFOffset = fp8SFVal;

int64_t outOffset = rowIdx * (numCols / ELTS_PER_THREAD) + colIdx;
out[outOffset] = e2m1Vec;
}
#else
printf("FP4 is not supported pre-Blackwell!\n");
#endif
Expand Down Expand Up @@ -187,7 +186,7 @@ struct FP4Converter<float, UE8M0_SF>
}

template <size_t ELTS_PER_THREAD, typename T>
__device__ __forceinline__ void post_process(int rowIdx, int n_base, T packed_input) const
__device__ __forceinline__ void post_process(int rowIdx, int n_base, T packed_input, bool write_output = true) const
{

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
Expand Down Expand Up @@ -236,8 +235,6 @@ struct FP4Converter<float, UE8M0_SF>
SFValue = static_cast<float>(tmp);
}

auto SFOffset = cvt_quant_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF>(std::nullopt /* batchIdx */, rowIdx,
colIdx, std::nullopt /* numRows */, numCols / SF_VEC_SIZE, SFout, QuantizationSFLayout::SWIZZLED);
float outputScale = reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal));

// Convert the input to float.
Expand All @@ -253,11 +250,16 @@ struct FP4Converter<float, UE8M0_SF>
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * (numCols / ELTS_PER_THREAD) + colIdx;
// Write the e2m1 values to global memory.
out[outOffset] = e2m1Vec;
if (write_output)
{
auto SFOffset
= cvt_quant_get_sf_out_offset<uint32_t, NUM_THREADS_PER_SF>(std::nullopt /* batchIdx */, rowIdx, colIdx,
std::nullopt /* numRows */, numCols / SF_VEC_SIZE, SFout, QuantizationSFLayout::SWIZZLED);
*SFOffset = fp8SFVal;

int64_t outOffset = rowIdx * (numCols / ELTS_PER_THREAD) + colIdx;
out[outOffset] = e2m1Vec;
}
#else
printf("FP4 is not supported pre-Blackwell!\n");
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,14 @@ struct LowLatencyLayerNorm
{
auto n_base = (thread_id + i * N_THREADS) * Traits::PACKED_ELEMS_PER_COMPUTE;
auto in_bound = n_base < param.n;
if (!in_bound)

// FP4Converter uses __shfl_xor_sync — all warp threads must stay converged.
if constexpr (std::is_same_v<typename Traits::FusedOperator, void>)
{
break;
if (!in_bound)
{
break;
}
}

typename PackType<typename Traits::OutputType, Traits::PACKED_ELEMS_PER_COMPUTE>::type normed_output;
Expand Down Expand Up @@ -317,13 +322,16 @@ struct LowLatencyLayerNorm
else
{
fused_operator.template post_process<Traits::PACKED_ELEMS_PER_COMPUTE, decltype(normed_output)>(
work_id, n_base, normed_output);
work_id, n_base, normed_output, in_bound);
}
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
reinterpret_cast<decltype(high_precision_normed_output)*>(
&param.high_precision_normed_output[work_id * param.n + n_base])[0]
= high_precision_normed_output;
if (in_bound)
{
reinterpret_cast<decltype(high_precision_normed_output)*>(
&param.high_precision_normed_output[work_id * param.n + n_base])[0]
= high_precision_normed_output;
}
}
}
}
Expand Down
36 changes: 25 additions & 11 deletions cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct DummyFusedOperator
}

// No-op post-processing hook: DummyFusedOperator stand-in that accepts the
// packed normed values and discards them. `write_output` mirrors the
// FP4Converter::post_process signature (see fp4_converter.cuh) so callers can
// pass their in-bound flag uniformly; it is ignored here.
template <size_t ELEMS_PER_THREAD, typename T>
__device__ __forceinline__ void post_process(int m, int n_base, T packed_input, bool write_output = true)
{
}
}
};
Expand Down Expand Up @@ -655,9 +655,14 @@ struct WarpSpecializedLayerNorm

auto n_base = (thread_id + i * 128) * Traits::PACKED_ELEMS_PER_COMPUTE;
auto in_bound = n_base < param.n;
if (!in_bound)

// FP4Converter uses __shfl_xor_sync — all warp threads must stay converged.
if constexpr (std::is_same_v<typename Traits::FusedOperator, void>)
{
break;
if (!in_bound)
{
break;
}
}

if constexpr (Traits::GAMMA)
Expand Down Expand Up @@ -753,16 +758,22 @@ struct WarpSpecializedLayerNorm
= normed_output;
if constexpr (Traits::UNNORMED_OUTPUT)
{
reinterpret_cast<decltype(output)*>(
&shared->output_vec[0][buffer_id][m_offset * Traits::N_BLOCK + n_base])[0]
= output;
if (in_bound)
{
reinterpret_cast<decltype(output)*>(
&shared->output_vec[0][buffer_id][m_offset * Traits::N_BLOCK + n_base])[0]
= output;
}
}
}
else
{
if constexpr (Traits::UNNORMED_OUTPUT)
{
reinterpret_cast<decltype(output)*>(&param.output[m * param.n + n_base])[0] = output;
if (in_bound)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also add an in_bound check for the USE_BULK_STORE case?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the check.
For our current code, it seems USE_BULK_STORE is always false, BTW.

{
reinterpret_cast<decltype(output)*>(&param.output[m * param.n + n_base])[0] = output;
}
}
// TODO: Move this generic writeback into dummy fused operator.
if constexpr (std::is_same_v<typename Traits::FusedOperator, void>)
Expand All @@ -774,13 +785,16 @@ struct WarpSpecializedLayerNorm
{
fused_operator
.template post_process<Traits::PACKED_ELEMS_PER_COMPUTE, decltype(normed_output)>(
m, n_base, normed_output);
m, n_base, normed_output, in_bound);
}
if constexpr (Traits::HIGH_PRECISION_NORMED_OUTPUT)
{
reinterpret_cast<decltype(high_precision_normed_output)*>(
&param.high_precision_normed_output[m * param.n + n_base])[0]
= high_precision_normed_output;
if (in_bound)
{
reinterpret_cast<decltype(high_precision_normed_output)*>(
&param.high_precision_normed_output[m * param.n + n_base])[0]
= high_precision_normed_output;
}
}
}
}
Expand Down
21 changes: 11 additions & 10 deletions tensorrt_llm/_torch/models/modeling_nemotron_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ def _compute_routed_output():
lora_params=lora_params,
layer_idx=self.layer_idx)
else:
routed_hidden_states = hidden_states
# Use bf16; fused norm's swizzled SF is incompatible with MoE alltoall.
routed_hidden_states = hidden_states_hp

final_hidden_states = self.experts(
routed_hidden_states,
Expand Down Expand Up @@ -397,16 +398,16 @@ def __init__(
if key.startswith(layer_prefix) and cfg.quant_mode.has_nvfp4():
self.is_nvfp4 = True
break
# The fused RMSNorm+NVFP4 CUDA kernel requires hidden_size to be
# a supported tile size. Non-power-of-2 hidden sizes within tile
# ranges may cause kernel hangs. Disable fused NVFP4 for such cases.
# Supported tile sizes: 2048, 4096, 8192, 16384
_SUPPORTED_NVFP4_HIDDEN_SIZES = {2048, 4096, 8192, 16384}
if self.is_nvfp4 and config.hidden_size not in _SUPPORTED_NVFP4_HIDDEN_SIZES:

# Fused RMSNorm+NVFP4 kernel constraints (fusedAddRMSNormQuant.cpp).
hidden_size = config.hidden_size
if self.is_nvfp4 and not (2048 <= hidden_size <= 16384
and hidden_size % 16 == 0):
logger.warning_once(
f"Layer {layer_idx}: Disabling fused NVFP4 RMSNorm for hidden_size={config.hidden_size}. "
f"Supported sizes: {_SUPPORTED_NVFP4_HIDDEN_SIZES}. Using non-fused path.",
key=f"disable_nvfp4_rmsnorm_with_{config.hidden_size}",
f"Layer {layer_idx}: Disabling fused NVFP4 RMSNorm for hidden_size={hidden_size}. "
f"Requires 2048 <= hidden_size <= 16384 and hidden_size % 16 == 0. "
"Using non-fused path.",
key=f"disable_nvfp4_rmsnorm_with_{hidden_size}",
)
self.is_nvfp4 = False
# LoRA layers require regular bf16 tensors, not Fp4QuantizedTensor.
Expand Down
44 changes: 44 additions & 0 deletions tests/unittest/_torch/modules/test_fused_add_rms_norm_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,47 @@ def test_low_latency_layernorm_hp_output_consistency(dtype):
sf_out_hp_valid = sf_out_hp[valid_sf_indices]
sf_out_no_hp_valid = sf_out_no_hp[valid_sf_indices]
torch.testing.assert_close(sf_out_hp_valid, sf_out_no_hp_valid, rtol=0, atol=0)


@skip_unsupported
@pytest.mark.parametrize("m,n", [(1, 5120), (4, 7168), (64, 2560)])
def test_fused_add_rms_norm_quant_non_power_of_2_hidden(m, n):
    """Non-power-of-2 hidden sizes must not hang or corrupt output."""
    torch.manual_seed(42)
    dev = torch.device("cuda")
    dt = torch.bfloat16
    eps = 1e-6

    x = torch.randn(m, n, dtype=dt, device=dev)
    res_in = torch.randn(m, n, dtype=dt, device=dev)
    weight = torch.randn(n, dtype=dt, device=dev) * 0.5 + 1.0

    # Reference norm/residual; SF scale derived from the reference amax.
    ref_norm, ref_residual = rms_norm_ref(x, res_in, weight, eps)
    sf_scale = (ref_norm.abs().amax().float() / (6.0 * 448.0)).view(1)

    def run_kernel(with_hp):
        # Invoke the fused op under test; with_hp toggles the bf16 normed output.
        return torch.ops.trtllm.fused_add_rms_norm_quant(
            x, res_in, weight, sf_scale, True, eps=eps, output_hp_norm=with_hp
        )

    # Run without hp_output: fourth slot must be absent.
    fp4_a, res_a, sf_a, hp_a = run_kernel(False)
    assert hp_a is None

    # Run with hp_output.
    fp4_b, res_b, sf_b, hp_b = run_kernel(True)

    # Residual and HP norm agree with the references.
    torch.testing.assert_close(res_a, ref_residual, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(res_b, ref_residual, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(hp_b, ref_norm, rtol=1e-2, atol=1e-2)

    # Quantized outputs must be identical regardless of the hp_output flag.
    torch.testing.assert_close(fp4_b, fp4_a, rtol=0, atol=0)
    torch.testing.assert_close(res_b, res_a, rtol=0, atol=0)
    # Swizzled SF layout pads rows to 128; compare only the valid entries.
    valid = get_swizzled_sf_indices(m, n)
    torch.testing.assert_close(sf_b[valid], sf_a[valid], rtol=0, atol=0)
Loading