From f5a710786b886dc5d06d3eec78afc42e3244c288 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 20 Feb 2026 13:32:08 +0100 Subject: [PATCH] small fixes --- kernels/csrc/utils.cuh | 12 ++++++------ kernels/python/quartet2/linear.py | 6 ++++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/kernels/csrc/utils.cuh b/kernels/csrc/utils.cuh index 8e4d2f0..c3d9a0d 100644 --- a/kernels/csrc/utils.cuh +++ b/kernels/csrc/utils.cuh @@ -31,22 +31,22 @@ struct m16_n16_k32_c_fragment { }; template -constexpr char ptx_type_name[] = "unknown_dtype"; +inline constexpr char ptx_type_name[] = "unknown_dtype"; template<> -constexpr char ptx_type_name[4] = "f32"; +inline constexpr char ptx_type_name[4] = "f32"; template<> -constexpr char ptx_type_name[4] = "f16"; +inline constexpr char ptx_type_name[4] = "f16"; template<> -constexpr char ptx_type_name[5] = "bf16"; +inline constexpr char ptx_type_name[5] = "bf16"; template<> -constexpr char ptx_type_name<__nv_fp8_e4m3>[5] = "e4m3"; +inline constexpr char ptx_type_name<__nv_fp8_e4m3>[5] = "e4m3"; template<> -constexpr char ptx_type_name<__nv_fp8_e5m2>[5] = "e5m2"; +inline constexpr char ptx_type_name<__nv_fp8_e5m2>[5] = "e5m2"; __device__ __forceinline__ m16_n16_a_fragment load_fragment_a(int lane_id, const nv_bfloat16* base, int ldd) { // see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-16816-i8-f8 diff --git a/kernels/python/quartet2/linear.py b/kernels/python/quartet2/linear.py index ef2e90d..4b4fbed 100644 --- a/kernels/python/quartet2/linear.py +++ b/kernels/python/quartet2/linear.py @@ -164,8 +164,10 @@ def forward(ctx, input, weight, had, mode: NVFP4QuantMode, disable_backward_quan if autocast_enabled: input = input.to(torch.bfloat16) weight = weight.to(torch.bfloat16) - - assert input.dtype == torch.bfloat16 + elif weight.dtype != torch.bfloat16: + raise TypeError("Weight must be bfloat16. 
Either set `dtype=torch.bfloat16` or enable autocast") + elif input.dtype != torch.bfloat16: + raise TypeError("Input must be bfloat16. Either cast input to bfloat16 or enable autocast") forward_scale_override = 1.0