diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 2e5eaff9bf4..ad30ecd8fa5 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -924,6 +924,13 @@ struct ggml_cuda_type_traits {
     static constexpr int qr = 1;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> {
+    static constexpr int qk = QK1_0;
+    static constexpr int qr = QR1_0;
+    static constexpr int qi = QI1_0;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
     static constexpr int qk = QK4_0;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 79ccfe568a2..61630a35a29 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
 
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cont_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
@@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
 
 to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cont_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
@@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_block_cuda<GGML_TYPE_Q4_0>;
         case GGML_TYPE_Q4_1:
@@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_block_cuda<GGML_TYPE_Q4_0>;
         case GGML_TYPE_Q4_1:
@@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        case GGML_TYPE_Q1_0:
+            return dequantize_block_cuda<GGML_TYPE_Q1_0>;
         case GGML_TYPE_Q4_0:
             return dequantize_block_cuda<GGML_TYPE_Q4_0>;
         case GGML_TYPE_Q4_1:
diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh
index e060fb29fdc..9ae1342fc0e 100644
--- a/ggml/src/ggml-cuda/dequantize.cuh
+++ b/ggml/src/ggml-cuda/dequantize.cuh
@@ -1,5 +1,27 @@
 #include "common.cuh"
 
+static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+    const block_q1_0 * x = (const block_q1_0 *) vx;
+
+    const float d = x[ib].d;
+
+    const int bit_index_0 = iqs;
+    const int bit_index_1 = iqs + 1;
+
+    const int byte_index_0 = bit_index_0 / 8;
+    const int bit_offset_0 = bit_index_0 % 8;
+
+    const int byte_index_1 = bit_index_1 / 8;
+    const int bit_offset_1 = bit_index_1 % 8;
+
+    // Extract bits: 1 = +d, 0 = -d (branchless)
+    const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1;
+    const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1;
+
+    v.x = (2*bit_0 - 1) * d;
+    v.y = (2*bit_1 - 1) * d;
+}
+
 static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
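Note: the decode rule above maps each stored bit to a sign applied to the single per-block scale (1 maps to +d, 0 to -d). A minimal host-side sketch of the same rule, useful for checking kernel output against a scalar reference; QK1_0 == 128 and the exact block layout are assumptions made for this sketch (the real block presumably stores an fp16 scale), not definitions taken from upstream ggml.

// Hedged reference: decode one hypothetical Q1_0 block on the host.
#include <cstdint>

#define QK1_0 128

struct block_q1_0_ref {
    float   d;             // per-block scale (fp16 in the actual block; float here for portability)
    uint8_t qs[QK1_0/8];   // 128 sign bits, LSB-first within each byte
};

static void dequantize_row_q1_0_ref(const block_q1_0_ref * x, float * y, int64_t k) {
    for (int64_t ib = 0; ib < k/QK1_0; ++ib) {
        for (int i = 0; i < QK1_0; ++i) {
            const int bit = (x[ib].qs[i/8] >> (i%8)) & 1;
            y[ib*QK1_0 + i] = (2*bit - 1) * x[ib].d; // 1 -> +d, 0 -> -d
        }
    }
}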
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index 2fab33243dd..e99cba63d34 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
             get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
             break;
+        case GGML_TYPE_Q1_0:
+            get_rows_cuda_q<GGML_TYPE_Q1_0>(src0_d, src1_d, dst_d,
+                ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+            break;
         case GGML_TYPE_Q4_0:
             get_rows_cuda_q<GGML_TYPE_Q4_0>(src0_d, src1_d, dst_d,
                 ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c17db3875ad..790f53cead7 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4831,6 +4831,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             switch (a->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_Q1_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
@@ -4868,6 +4869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_TYPE_F32:
                 case GGML_TYPE_BF16:
                 case GGML_TYPE_I32:
+                case GGML_TYPE_Q1_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 27b4145ac9a..3f01ff5bfb0 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -5,6 +5,9 @@
 
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
+        case GGML_TYPE_Q1_0:
+            mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
+            break;
         case GGML_TYPE_Q4_0:
             mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
@@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
     bool mmq_supported;
 
     switch (type) {
+        case GGML_TYPE_Q1_0:
        case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 18911141472..28b662df925 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -57,6 +57,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
+        case GGML_TYPE_Q1_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
             return MMQ_Q8_1_DS_LAYOUT_DS4;
@@ -185,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() {
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
         case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
         case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
 
 // ------------------------------------------------------------
 
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
+    constexpr int threads_per_row = blocks_per_iter * QI1_0;
+    constexpr int nrows = warp_size / threads_per_row;
+    constexpr int scale_entries_per_block = QK1_0 / QK8_1;
+    constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
+
+    const int txi  = threadIdx.x % threads_per_row;
+    const int kbx  = txi / QI1_0;
+    const int kqsx = txi % QI1_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
+        const int qs_offset = 4*kqsx;
+        const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] <<  8) |
+                        (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
+
+        int unpacked_bytes[8];
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+            const int shift = j * 4;
+            const int bits4 = (qs0 >> shift) & 0x0F;
+            const int b0 = (bits4 & 0x01) ? 1 : -1;
+            const int b1 = (bits4 & 0x02) ? 1 : -1;
+            const int b2 = (bits4 & 0x04) ? 1 : -1;
+            const int b3 = (bits4 & 0x08) ? 1 : -1;
+            unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
+        }
+
+        const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
+#else
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        }
+    }
+
+    const int ksx = threadIdx.x % scale_entries_per_row;
+    const int scale_block = ksx / scale_entries_per_block;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
+#else
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+}
+
 template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
@@ -3290,6 +3375,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
 template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
 struct mmq_type_traits;
 
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
+    static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
     static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
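Note: load_tiles_q1_0 and vec_dot_q1_0_q8_1 (further down) rely on the same branchless nibble-to-int8 expansion, producing packed lanes ready for dp4a (0x01 for +1, 0xFF for -1). A hedged host-side sketch that checks the expansion against the scalar decode rule; unpack_q1_nibble() is a local helper name for this sketch, not a function from this PR:

// Hedged sketch: verify the nibble -> packed-int8 expansion.
#include <cassert>
#include <cstdint>

static int unpack_q1_nibble(int v, int j) {
    // Take bits [4j, 4j+3] of v and spread them into four int8 lanes.
    const int bits4 = (v >> (4*j)) & 0x0F;
    const int b0 = (bits4 & 0x01) ? 1 : -1;
    const int b1 = (bits4 & 0x02) ? 1 : -1;
    const int b2 = (bits4 & 0x04) ? 1 : -1;
    const int b3 = (bits4 & 0x08) ? 1 : -1;
    return (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}

int main() {
    const int v = (int) 0xDEADBEEF; // arbitrary 32 packed sign bits
    for (int j = 0; j < 8; ++j) {
        const int packed = unpack_q1_nibble(v, j);
        for (int k = 0; k < 4; ++k) {
            const int8_t lane = (int8_t) (((unsigned) packed >> (8*k)) & 0xFF);
            const int    bit  = (v >> (4*j + k)) & 1;
            assert(lane == 2*bit - 1); // same convention as dequantize_q1_0
        }
    }
    return 0;
}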
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 07b10167bc4..8f55cace1a1 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
 
 static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
         case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
         case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
         case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
@@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
 
 static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
     switch (type) {
+        case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
         case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
         case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
         case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
@@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type(
     const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
     const int ids_stride, cudaStream_t stream) {
     switch (type_x) {
+        case GGML_TYPE_Q1_0:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q1_0>
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+            break;
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index 40d51f93fa4..841059c15b5 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -32,6 +32,7 @@
 SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
 
 TYPES_MMQ = [
+    "GGML_TYPE_Q1_0",
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
     "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
     "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu
new file mode 100644
index 00000000000..f0686b0d0d8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q1_0);
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 40b2b41e7e8..d1741cc8d7b 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -106,6 +106,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
+#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism
+#define VDR_Q1_0_Q8_1_MMQ  4 // Q1_0 has 128 bits (4 ints) per block
+
 #define VDR_Q4_0_Q8_1_MMVQ 2
 #define VDR_Q4_0_Q8_1_MMQ  4
 
@@ -669,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     return d6 * sumf_d;
 }
 
+static __device__ __forceinline__ float vec_dot_q1_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx;
+
+    // Q1_0: 128 elements with ONE scale
+    // Q8_1: 32 elements per block with individual scales
+    // iqs selects which of the 4 chunks of 32 elements to process (0-3)
+
+    const float d1 = bq1_0->d;
+
+    // Process only the chunk specified by iqs
+    const block_q8_1 * bq8_1_chunk = bq8_1 + iqs;
+
+    // Load 32 bits (4 bytes) for this chunk from Q1_0
+    const int offset = iqs * 4;
+    const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] <<  8) |
+                  (bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24);
+
+    // Unpack 32 bits into 32 signed values (-1 or +1)
+    int vi_bytes[8];
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        const int shift = j * 4;
+        const int bits4 = (v >> shift) & 0x0F;
+        const int b0 = (bits4 & 0x01) ? 1 : -1;
+        const int b1 = (bits4 & 0x02) ? 1 : -1;
+        const int b2 = (bits4 & 0x04) ? 1 : -1;
+        const int b3 = (bits4 & 0x08) ? 1 : -1;
+        vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
+    }
+
+    // Compute dot product for this 32-element chunk
+    int sumi = 0;
+#pragma unroll
+    for (int j = 0; j < 8; ++j) {
+        const int u = get_int_b4(bq8_1_chunk->qs, j);
+        sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi);
+    }
+
+    // Apply Q1_0's single scale and this chunk's Q8_1 scale
+    const float d8 = __low2float(bq8_1_chunk->ds);
+    return d1 * d8 * sumi;
+}
+
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
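Note: since QK1_0 == 128 and QK8_1 == 32, one Q1_0 block pairs with four q8_1 blocks, and iqs selects which 32-element chunk a thread handles (matching VDR_Q1_0_Q8_1_MMVQ == 1). A hedged scalar model of that pairing; the *_ref name and flat-pointer arguments are local to this sketch:

// Hedged scalar model of vec_dot_q1_0_q8_1 over one 32-element chunk.
#include <cstdint>

static float vec_dot_q1_0_q8_1_ref(const uint8_t * qs, float d1,
                                   const int8_t * q8, float d8, int iqs) {
    int sumi = 0;
    for (int i = 0; i < 32; ++i) {
        const int bit = (qs[iqs*4 + i/8] >> (i%8)) & 1; // bytes [4*iqs, 4*iqs+3]
        sumi += (2*bit - 1) * q8[i];                    // q8 points at chunk iqs
    }
    return d1 * d8 * sumi; // one Q1_0 scale times this chunk's Q8_1 scale
}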
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index b2a54bd85d0..702a249d754 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1394,7 +1394,7 @@ struct vk_op_im2col_push_constants {
     uint32_t IW; uint32_t IH;
     uint32_t OW; uint32_t OH;
     uint32_t KW; uint32_t KH;
-    uint32_t pelements;
+    uint32_t OH_batch;
     uint32_t CHW;
     int32_t s0; int32_t s1;
     int32_t p0; int32_t p1;
@@ -10064,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
             const uint32_t batch = src1->ne[is_2D ? 3 : 2];
 
-            elements = { OW * KW * KH, OH, batch * IC };
+            const uint32_t CHW = IC * KH * KW;
+            // Cap X workgroups to limit concurrent IC channel reads.
+            // The shader loops over X to cover the full CHW dimension.
+            // AMD prefers a lower limit.
+            const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u;
+            const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW));
+            elements = { x_elements, OW, OH * batch };
             elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
             elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
         } break;
@@ -11727,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
     const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
     const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
-    const uint32_t pelements = OW * KW * KH;
     const uint32_t batch = src1->ne[is_2D ? 3 : 2];
 
     const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@@ -11739,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
         dst_addr, batch_offset, offset_delta,
         IC, IW, IH, OW, OH, KW, KH,
-        pelements,
+        OH * batch,
         IC * KH * KW,
         s0, s1, p0, p1, d0, d1, batch * IC
     });
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
index 674f91e5ed2..ba4c2103f0c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
@@ -13,7 +13,7 @@ layout (push_constant) uniform parameter
     uint IW; uint IH;
     uint OW; uint OH;
     uint KW; uint KH;
-    uint pelements;
+    uint OH_batch;
     uint CHW;
     int s0; int s1;
     int p0; int p1;
@@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 layout (buffer_reference) buffer D_ptr {D_TYPE d;};
 #endif
 
-void im2col(const uint y, const uint z) {
-    const uint gidx = gl_GlobalInvocationID.x;
+void im2col(const uint ow, const uint z_idx) {
+    const uint oh = z_idx % p.OH;
+    const uint batch_idx = z_idx / p.OH;
 
-    const uint oh = y;
-    const uint batch = z / p.IC;
-    const uint ic = z % p.IC;
+    const uint gidx = gl_LocalInvocationID.x;
+    const uint src_batch = batch_idx * p.batch_offset;
+    const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;
 
-    const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
-    const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
-    const int oh_s1 = int(oh) * p.s1;
-    const uint ksize = p.OW * p.KH;
+    const uint KHKW = p.KH * p.KW;
 
-    const uint base_linear_idx = gidx * NUM_ITER;
+    uint wg_x = gl_WorkGroupID.x;
+    do {
+        const uint wg_offset = wg_x * 512;
 
-    uint current_kx = base_linear_idx / ksize;
-    const uint rem = base_linear_idx - (current_kx * ksize);
-    uint current_ky = rem / p.OW;
-    uint current_ix = rem % p.OW;
+        [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
+            const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE;
 
-    A_TYPE values[NUM_ITER];
-    BDA_OFFSET_T offset_dst[NUM_ITER];
-    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
-        values[idx] = A_TYPE(0);
-    }
-
-    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
-
-        const uint linear_idx = base_linear_idx + idx;
-
-        if (linear_idx >= p.pelements) {
-            continue;
-        }
-
-        const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
-        const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
-
-        offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
-
-        if ((iih < p.IH) && (iiw < p.IW)) {
-            values[idx] = data_a[src_base + iih * p.IW + iiw];
-        }
-
-        if (++current_ix == p.OW) {
-            current_ix = 0;
-            if (++current_ky == p.KH) {
-                current_ky = 0;
-                current_kx++;
+            if (chw_idx >= p.CHW) {
+                return;
             }
-        }
-    }
 
-    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+            const uint ic = chw_idx / KHKW;
+            const uint rem = chw_idx - ic * KHKW;
+            const uint ky = rem / p.KW;
+            const uint kx = rem - ky * p.KW;
 
-        const uint linear_idx = base_linear_idx + idx;
+            const uint iiw = ow * p.s0 + kx * p.d0 - p.p0;
+            const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
 
-        if (linear_idx >= p.pelements) {
-            continue;
-        }
+            A_TYPE val = A_TYPE(0);
+            if (iih < p.IH && iiw < p.IW) {
+                val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw];
+            }
 
 #if BDA
-        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
-        dst_addr.d = D_TYPE(values[idx]);
+            D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx));
+            out_ptr.d = D_TYPE(val);
 #else
-        data_d[offset_dst[idx]] = D_TYPE(values[idx]);
+            data_d[dst_row + chw_idx] = D_TYPE(val);
 #endif
-    }
+        }
+
+        wg_x += gl_NumWorkGroups.x;
+    } while (wg_x * 512 < p.CHW);
 }
 
 void main() {
-    uint y = gl_GlobalInvocationID.y;
-    while (y < p.OH) {
+    uint ow = gl_GlobalInvocationID.y;
+    while (ow < p.OW) {
         uint z = gl_GlobalInvocationID.z;
-        while (z < p.batch_IC) {
-            im2col(y, z);
+        while (z < p.OH_batch) {
+            im2col(ow, z);
             z += gl_NumWorkGroups.z;
         }
-        y += gl_NumWorkGroups.y;
+        ow += gl_NumWorkGroups.y;
     }
 }
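Note: the rework changes the dispatch from one thread per (ow, kx, ky) element to whole CHW rows per (oh, ow) output position, so each destination row is written contiguously. A hedged CPU model of the new indexing, single batch for brevity; im2col_ref is a local name and s*/p*/d* are the usual stride/pad/dilation parameters:

// Hedged CPU model of the reworked im2col indexing above.
#include <cstdint>

static void im2col_ref(const float * src, float * dst,
                       int IC, int IH, int IW, int OH, int OW, int KH, int KW,
                       int s0, int s1, int p0, int p1, int d0, int d1) {
    const int KHKW = KH * KW;
    const int CHW  = IC * KHKW;
    for (int oh = 0; oh < OH; ++oh)
    for (int ow = 0; ow < OW; ++ow)
    for (int chw = 0; chw < CHW; ++chw) {
        // Decode chw into (ic, ky, kx) exactly as the shader does.
        const int ic  = chw / KHKW;
        const int rem = chw - ic*KHKW;
        const int ky  = rem / KW;
        const int kx  = rem - ky*KW;
        const int iiw = ow*s0 + kx*d0 - p0;
        const int iih = oh*s1 + ky*d1 - p1;
        const bool inside = iih >= 0 && iih < IH && iiw >= 0 && iiw < IW;
        dst[((int64_t) oh*OW + ow)*CHW + chw] =
            inside ? src[((int64_t) ic*IH + iih)*IW + iiw] : 0.0f;
    }
}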
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
index 0d3501c34a2..62fe72ee3b1 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
@@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
 #endif
 
 #ifdef U32_DEQUANT_HELPERS
-fn load_u16_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> u32 {
-    let word = buf[byte_offset / 4];
-    let shift = (byte_offset & 0x2) * 8;
-    return (word >> shift) & 0xFFFF;
+#ifdef DECLARE_BYTE_LOADERS_SRC
+fn load_u16_at_src(byte_offset: u32) -> u32 {
+    let word = src[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    return (word >> shift) & 0xFFFFu;
 }
 
-fn load_u32_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> u32 {
-    let word_idx = byte_offset / 4;
-    let shift = (byte_offset & 0x3) * 8;
-    let lo = buf[word_idx];
-    let hi = buf[word_idx + 1];
-    let shifted = (lo >> shift) | (hi << (32 - shift));
-    return select(shifted, lo, shift == 0);
+fn load_u32_at_src(byte_offset: u32) -> u32 {
+    let word_idx = byte_offset / 4u;
+    let shift = (byte_offset & 0x3u) * 8u;
+    let lo = src[word_idx];
+    let hi = src[word_idx + 1u];
+    let shifted = (lo >> shift) | (hi << (32u - shift));
+    return select(shifted, lo, shift == 0u);
 }
 
-fn load_f16_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> f16 {
-    let packed = unpack2x16float(load_u16_at(buf, byte_offset));
+fn load_f16_at_src(byte_offset: u32) -> f16 {
+    let packed = unpack2x16float(load_u16_at_src(byte_offset));
     return f16(packed[0]);
 }
 
-fn load_f16_as_f32_at(
-    buf: ptr<storage, array<u32>, read_write>,
-    byte_offset: u32) -> f32 {
-    let word = buf[byte_offset / 4];
-    let shift = (byte_offset & 0x2) * 8;
-    let d_bits = (word >> shift) & 0xFFFF;
+fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 {
+    let word = src[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    let d_bits = (word >> shift) & 0xFFFFu;
     return unpack2x16float(d_bits)[0];
 }
 #endif
 
+#ifdef DECLARE_BYTE_LOADERS_SRC0
+fn load_u16_at_src0(byte_offset: u32) -> u32 {
+    let word = src0[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    return (word >> shift) & 0xFFFFu;
+}
+
+fn load_u32_at_src0(byte_offset: u32) -> u32 {
+    let word_idx = byte_offset / 4u;
+    let shift = (byte_offset & 0x3u) * 8u;
+    let lo = src0[word_idx];
+    let hi = src0[word_idx + 1u];
+    let shifted = (lo >> shift) | (hi << (32u - shift));
+    return select(shifted, lo, shift == 0u);
+}
+
+fn load_f16_at_src0(byte_offset: u32) -> f16 {
+    let packed = unpack2x16float(load_u16_at_src0(byte_offset));
+    return f16(packed[0]);
+}
+
+fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 {
+    let word = src0[byte_offset / 4u];
+    let shift = (byte_offset & 0x2u) * 8u;
+    let d_bits = (word >> shift) & 0xFFFFu;
+    return unpack2x16float(d_bits)[0];
+}
+#endif
+#endif
+
 #ifdef Q4_1_T
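Note: these helpers exist because WGSL storage buffers are read word-wise as array<u32>, so an unaligned 32-bit value is stitched together from two adjacent words; the select() guards the aligned case, where a shift by 32 would (on most implementations, which mask the shift count) wrap to a shift by 0 and corrupt the result. A hedged C++ model of the same logic; load_u32_at_ref is a local name:

// Hedged C++ model of load_u32_at_src / load_u32_at_src0 above.
#include <cstdint>

static uint32_t load_u32_at_ref(const uint32_t * buf, uint32_t byte_offset) {
    const uint32_t word_idx = byte_offset / 4;
    const uint32_t shift    = (byte_offset & 0x3) * 8;
    if (shift == 0) {
        return buf[word_idx]; // aligned: a single word suffices
    }
    const uint32_t lo = buf[word_idx];
    const uint32_t hi = buf[word_idx + 1];
    return (lo >> shift) | (hi << (32 - shift));
}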
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
index 3c8b84c9ac3..1415798fa6b 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
@@ -1,6 +1,8 @@
 enable f16;
 
+#define DECLARE_BYTE_LOADERS_SRC
 #include "common_decls.tmpl"
+
 
 #ifdef F32_VEC
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
@@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef Q4_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     for (var j: u32 = 0u; j < 4; j++) {
         let q_byte_offset = block_byte_base + 2 + j * 4;
-        let q_packed = load_u32_at(&src, q_byte_offset);
+        let q_packed = load_u32_at_src(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
             let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef Q5_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
-    let qh_packed = load_u32_at(&src, block_byte_base + 2);
+    let d = load_f16_as_f32_at_src(block_byte_base);
+    let qh_packed = load_u32_at_src(block_byte_base + 2);
     for (var j: u32 = 0; j < 4; j++) {
         let q_byte_offset = block_byte_base + 6 + j * 4;
-        let q_packed = load_u32_at(&src, q_byte_offset);
+        let q_packed = load_u32_at_src(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
 
@@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef Q8_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     for (var j: u32 = 0u; j < 8u; j++) {
         let q_byte_offset = block_byte_base + 2u + j * 4u;
-        let q_packed = load_u32_at(&src, q_byte_offset);
+        let q_packed = load_u32_at_src(q_byte_offset);
         for (var k: u32 = 0u; k < 4u; k++) {
             let q_byte = get_byte_i32(q_packed, k);
             let q_val = f32(q_byte) * d;
@@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
 
     // Bytes 108-109: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src, block_byte_base + 108);
+    let d = load_f16_as_f32_at_src(block_byte_base + 108);
 
     // Bytes 96-107: 12 bytes of scales (3 u32s)
     let kmask1: u32 = 0x03030303;
     let kmask2: u32 = 0x0f0f0f0f;
 
     var scale_vals: array<u32, 3>;
-    scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
-    scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
-    scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
+    scale_vals[0] = load_u32_at_src(block_byte_base + 96);
+    scale_vals[1] = load_u32_at_src(block_byte_base + 100);
+    scale_vals[2] = load_u32_at_src(block_byte_base + 104);
 
     var tmp: u32 = scale_vals[2];
     scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     // Bytes 0-31: 32 bytes of hmask (8 u32s)
     var hmask_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
+        hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4);
     }
 
     // Bytes 32-95: 64 bytes of qs (16 u32s)
     var qs_vals: array<u32, 16>;
     for (var i: u32 = 0u; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
+        qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4);
     }
 
     var dst_i = dst_base + offset * 256;
@@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
 
     // Bytes 208-209: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src, block_byte_base + 208);
+    let d = load_f16_as_f32_at_src(block_byte_base + 208);
 
     // Bytes 0-127: 128 bytes of ql (32 u32s)
     var ql_vals: array<u32, 32>;
     for (var i: u32 = 0; i < 32; i++) {
-        ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
+        ql_vals[i] = load_u32_at_src(block_byte_base + i * 4);
     }
 
     // Bytes 128-191: 64 bytes of qh (16 u32s)
     var qh_vals: array<u32, 16>;
     for (var i: u32 = 0; i < 16u; i++) {
-        qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
+        qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u);
     }
 
     // Bytes 192-207: 16 bytes of scales (4 u32s)
     var scale_vals: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
+        scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4);
     }
 
     var dst_i = dst_base + offset * 256;
@@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ2_XXS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
     for (var ib: u32 = 0; ib < 32; ib += 4) {
         let aux0_offset = block_byte_base + 2 + ib * 2;
         let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
-        let aux0 = load_u32_at(&src, aux0_offset);
-        let aux1 = load_u32_at(&src, aux1_offset);
+        let aux0 = load_u32_at_src(aux0_offset);
+        let aux1 = load_u32_at_src(aux1_offset);
         let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
         for (var l: u32 = 0; l < 4; l++) {
             let ig = get_byte(aux0, l) * 8;
@@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ2_XS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
 
     var scale_vals = array<u32, 2>(
-        load_u32_at(&src, block_byte_base + 66),
-        load_u32_at(&src, block_byte_base + 70)
+        load_u32_at_src(block_byte_base + 66),
+        load_u32_at_src(block_byte_base + 70)
     );
 
     for (var ib: u32 = 0; ib < 32; ib += 4) {
@@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         );
         for (var l: u32 = 0; l < 4; l++) {
             let qs_offset = block_byte_base + 2 + (ib + l) * 2;
-            let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
+            let qs_val = load_u32_at_src(qs_offset) & 0xFFFF;
             let ig = (qs_val & 511) * 8;
             let is = qs_val >> 9;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ2_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
 
     var qs_vals : array<u32, 16>;
     for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
+        qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
     }
 
     var qh_vals: array<u32, 2>;
-    qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
-    qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
+    qh_vals[0] = load_u32_at_src(block_byte_base + 66);
+    qh_vals[1] = load_u32_at_src(block_byte_base + 70);
 
     var scale_vals: array<u32, 2>;
-    scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
-    scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
+    scale_vals[0] = load_u32_at_src(block_byte_base + 74);
+    scale_vals[1] = load_u32_at_src(block_byte_base + 78);
 
     for (var ib: u32 = 0; ib < 8; ib ++) {
         let s = get_byte(scale_vals[ib / 4], ib % 4);
@@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ3_XXS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
     for (var ib: u32 = 0; ib < 16; ib += 2) {
         let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
-        let sc_sign = load_u32_at(&src, sc_sign_offset);
+        let sc_sign = load_u32_at_src(sc_sign_offset);
         let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
         for (var l: u32 = 0; l < 4; l++) {
             let is = (sc_sign >> (7 * l)) & 127;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
+            let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
             let ig1 = get_byte(ig_val, 0);
             let ig2 = get_byte(ig_val, 1);
             for (var j: u32 = 0; j < 4; j++) {
@@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ3_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
 
     var qh_vals = array<u32, 2>(
-        load_u32_at(&src, block_byte_base + 66),
-        load_u32_at(&src, block_byte_base + 70)
+        load_u32_at_src(block_byte_base + 66),
+        load_u32_at_src(block_byte_base + 70)
     );
 
     var sign_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
+        sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4);
     }
 
-    var scale_vals = load_u32_at(&src, block_byte_base + 106);
+    var scale_vals = load_u32_at_src(block_byte_base + 106);
 
     for (var ib: u32 = 0; ib < 4; ib++) {
         let s = get_byte(scale_vals, ib);
@@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
             let sign_w = sign_vals[ib * 2 + k];
             for (var l: u32 = 0; l < 4; l++) {
                 let signs = get_byte(sign_w, l);
-                let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
+                let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
                 let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                 let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                 for (var j: u32 = 0; j < 4; j++) {
@@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ1_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 256;
     for (var ib: u32 = 0; ib < 8; ib++) {
-        let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
+        let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF;
         let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
         let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
-        let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
+        let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4);
         for (var l: u32 = 0; l < 4; l++) {
             let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
             for (var j: u32 = 0; j < 8; j++) {
@@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
 #ifdef IQ4_NL
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src, block_byte_base);
+    let d = load_f16_as_f32_at_src(block_byte_base);
     var dst_i = dst_base + offset * 32;
     var qs: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
+        qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
     }
     for (var j: u32 = 0; j < 16; j++) {
         let qsb = get_byte(qs[j / 4], j % 4);
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
index fdabaf09b2e..fcbefdeb802 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
@@ -1,7 +1,9 @@
 enable f16;
 
+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 
 #ifdef FLOAT
 
 const BLOCK_SIZE = 1u;
@@ -21,11 +23,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef Q4_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var sum: f32 = 0.0;
     for (var j: u32 = 0; j < 4; j++) {
         let q_byte_offset = block_byte_base + 2 + j * 4;
-        let q_packed = load_u32_at(&src0, q_byte_offset);
+        let q_packed = load_u32_at_src0(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
             let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
@@ -63,12 +65,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef Q5_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var sum: f32 = 0.0;
-    let qh_packed = load_u32_at(&src0, block_byte_base + 2);
+    let qh_packed = load_u32_at_src0(block_byte_base + 2);
     for (var j: u32 = 0; j < 4; j++) {
         let q_byte_offset = block_byte_base + 6 + j * 4;
-        let q_packed = load_u32_at(&src0, q_byte_offset);
+        let q_packed = load_u32_at_src0(q_byte_offset);
         for (var k: u32 = 0; k < 4; k++) {
             let q_byte = get_byte(q_packed, k);
             let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
@@ -110,11 +112,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef Q8_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var sum: f32 = 0.0;
     for (var j: u32 = 0; j < 8; j++) {
         let q_byte_offset = block_byte_base + 2 + j * 4;
-        let q_packed = load_u32_at(&src0, q_byte_offset);
+        let q_packed = load_u32_at_src0(q_byte_offset);
         for (var k: u32 = 0u; k < 4u; k++) {
             let q_byte = get_byte_i32(q_packed, k);
             let q_val = f32(q_byte) * d;
@@ -184,7 +186,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
 
     // Bytes 108-109: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
+    let d = load_f16_as_f32_at_src0(block_byte_base + 108);
 
     // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
     // and 2-bits from the last 4 bytes
@@ -192,9 +194,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let kmask1: u32 = 0x03030303;
     let kmask2: u32 = 0x0f0f0f0f;
     var scale_vals: array<u32, 3>;
-    scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
-    scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
-    scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
+    scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
+    scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
+    scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
 
     var tmp: u32 = scale_vals[2];
     scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@@ -205,13 +207,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     // Bytes 0-31: 32 bytes of hmask (8 u32s)
     var hmask_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
+        hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
     }
 
     // Bytes 32-95: 64 bytes of qs (16 u32s)
     var qs_vals: array<u32, 16>;
     for (var i: u32 = 0u; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
+        qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
     }
 
     var sum = 0.0;
@@ -313,24 +315,24 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
 
     // Bytes 208-209: f16 scale 'd'
-    let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
+    let d = load_f16_as_f32_at_src0(block_byte_base + 208);
 
     // Bytes 0-127: 128 bytes of ql (32 u32s)
     var ql_vals: array<u32, 32>;
     for (var i: u32 = 0; i < 32; i++) {
-        ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
+        ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
     }
 
     // Bytes 128-191: 64 bytes of qh (16 u32s)
     var qh_vals: array<u32, 16>;
     for (var i: u32 = 0; i < 16; i++) {
-        qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
+        qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
     }
 
     // Bytes 192-207: 16 bytes of scales (4 u32s)
     var scale_vals: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
+        scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
     }
 
     var sum = 0.0;
@@ -374,14 +376,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ2_XXS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 32; ib += 4) {
         let aux0_offset = block_byte_base + 2 + ib * 2;
         let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
-        let aux0 = load_u32_at(&src0, aux0_offset);
-        let aux1 = load_u32_at(&src0, aux1_offset);
+        let aux0 = load_u32_at_src0(aux0_offset);
+        let aux1 = load_u32_at_src0(aux1_offset);
         let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
         for (var l: u32 = 0; l < 4; l++) {
             let ig = get_byte(aux0, l) * 8;
@@ -402,12 +404,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ2_XS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
 
     var scale_vals = array<u32, 2>(
-        load_u32_at(&src0, block_byte_base + 66),
-        load_u32_at(&src0, block_byte_base + 70)
+        load_u32_at_src0(block_byte_base + 66),
+        load_u32_at_src0(block_byte_base + 70)
     );
 
     var sum = 0.0;
@@ -419,7 +421,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
         );
         for (var l: u32 = 0; l < 4; l++) {
             let qs_offset = block_byte_base + 2 + (ib + l) * 2;
-            let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
+            let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
             let ig = (qs_val & 511) * 8;
             let is = qs_val >> 9;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@@ -439,21 +441,21 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ2_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
 
     var qs_vals : array<u32, 16>;
     for (var i: u32 = 0; i < 16; i++) {
-        qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
+        qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
     }
 
     var qh_vals: array<u32, 2>;
-    qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
-    qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
+    qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
+    qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
 
     var scale_vals: array<u32, 2>;
-    scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
-    scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
+    scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
+    scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
 
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 8; ib ++) {
@@ -483,17 +485,17 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ3_XXS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 16; ib += 2) {
         let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
-        let sc_sign = load_u32_at(&src0, sc_sign_offset);
+        let sc_sign = load_u32_at_src0(sc_sign_offset);
         let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
         for (var l: u32 = 0; l < 4; l++) {
             let is = (sc_sign >> (7 * l)) & 127;
             let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
-            let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
+            let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
             let ig1 = get_byte(ig_val, 0);
             let ig2 = get_byte(ig_val, 1);
             for (var j: u32 = 0; j < 4; j++) {
@@ -515,20 +517,20 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ3_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
 
     var qh_vals = array<u32, 2>(
-        load_u32_at(&src0, block_byte_base + 66),
-        load_u32_at(&src0, block_byte_base + 70)
+        load_u32_at_src0(block_byte_base + 66),
+        load_u32_at_src0(block_byte_base + 70)
    );
 
     var sign_vals: array<u32, 8>;
     for (var i: u32 = 0; i < 8; i++) {
-        sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
+        sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
     }
 
-    var scale_vals = load_u32_at(&src0, block_byte_base + 106);
+    var scale_vals = load_u32_at_src0(block_byte_base + 106);
 
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 4; ib++) {
@@ -543,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
             let sign_w = sign_vals[ib * 2 + k];
             for (var l: u32 = 0; l < 4; l++) {
                 let signs = get_byte(sign_w, l);
-                let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
+                let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
                 let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                 let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                 for (var j: u32 = 0; j < 4; j++) {
@@ -566,14 +568,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ1_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 256;
     var sum = 0.0;
     for (var ib: u32 = 0; ib < 8; ib++) {
-        let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
+        let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
         let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
         let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
-        let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
+        let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
         for (var l: u32 = 0; l < 4; l++) {
             let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
             for (var j: u32 = 0; j < 8; j++) {
@@ -638,12 +640,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
 #ifdef IQ4_NL
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
-    let d = load_f16_as_f32_at(&src0, block_byte_base);
+    let d = load_f16_as_f32_at_src0(block_byte_base);
     var src1_i = src1_idx_base + offset * 32;
     var sum = 0.0;
     var qs: array<u32, 4>;
     for (var i: u32 = 0; i < 4; i++) {
-        qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
+        qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
     }
     for (var j: u32 = 0; j < 16; j++) {
         let qsb = get_byte(qs[j / 4], j % 4);
q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); + let d = load_f16_at_src0(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let m = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base + 80u); - let dmin = load_f16_at(&src0, block_byte_base + 82u); + let d = load_f16_at_src0(block_byte_base + 80u); + let dmin = load_f16_at_src0(block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u)); + let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base + 108u); + let d = load_f16_at_src0(block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i); + scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var 
hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i); + hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i); + qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -499,8 +499,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let dmin = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); // Map k_in_block to loop structure: // Outer loop over 64-element groups (alternating q_b_idx) @@ -520,14 +520,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let scale_base = block_byte_base + 4u; if (is < 4u) { - let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); - let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -537,7 +537,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -571,8 +571,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let dmin = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); // The original loop processes elements in groups of 64 @@ -597,14 +597,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let scale_base = block_byte_base + 4u; if (is < 4u) { - let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); - let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_min_lo = 
@@ -597,14 +597,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let scale_base = block_byte_base + 4u;
         if (is < 4u) {
-            let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
-            let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
+            let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
+            let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
             sc = sc_byte & 63u;
             mn = min_byte & 63u;
         } else {
-            let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
-            let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
-            let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
+            let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
+            let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
+            let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
             sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
             mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@@ -614,11 +614,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         let ml = dmin * f16(mn);

         let q_idx = q_b_idx + l;
-        let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
+        let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
         let q_byte = get_byte(q_packed, q_idx % 4u);

-        let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
+        let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
         let qh_byte = get_byte(qh_packed, l % 4u);
@@ -666,17 +666,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load only ql13 word needed
         let ql13_flat = ql_b_idx + l;
-        let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
+        let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
         let ql13_b = get_byte(ql13, 0u);

         // Load only ql24 word needed
         let ql24_flat = ql_b_idx + l + 32u;
-        let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
+        let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
         let ql24_b = get_byte(ql24, 0u);

         // Load only qh word needed
         let qh_flat = qh_b_idx + l;
-        let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
+        let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
         let qh_b = get_byte(qh, 0u);

         let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
@@ -687,10 +687,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         // Load only the scale word needed
         let is = l / 16u;
         let sc_idx = sc_b_idx + is + quarter * 2u;
-        let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
+        let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
         let sc_val = get_byte_i32(sc, 0u);

-        let d = load_f16_at(&src0, block_byte_base + 208u);
+        let d = load_f16_at_src0(block_byte_base + 208u);

         var q_val: f16;
         if (quarter == 0u) {
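The Q6_K hunks above assemble each 6-bit quant from a low nibble in `ql` plus two bits from the high-bit array `qh`, recentre by 32, and multiply by a signed per-group scale (hence `get_byte_i32`/`sbyte_of`). A minimal C++ sketch of the first-quarter case shown (`q1` in the diff); the function name is illustrative:

```cpp
#include <cstdint>

// One Q6_K weight for quarter 0 of a 128-element half, mirroring
// q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0) above.
static float q6k_weight_quarter0(uint8_t ql_b, uint8_t qh_b, int8_t sc, float d) {
    const int q = int((ql_b & 0xF) | ((qh_b & 3u) << 4)) - 32; // 6-bit value, recentred
    return d * float(sc) * float(q);
}
```

The other quarters differ only in which nibble of `ql` and which bit pair of `qh` they draw from.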
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl
index 5f763a6400a..91039ff2546 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl
@@ -1,6 +1,8 @@
 enable f16;

+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 #include "mul_mat_decls.tmpl"

 #ifdef VEC
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
index ee37e6d249c..98bbdeb83ba 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
@@ -1,6 +1,8 @@
 enable f16;

+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 #include "mul_mat_decls.tmpl"

 #ifdef VEC
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
index 4151ce430b0..d86a72ce6e0 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
@@ -3,7 +3,9 @@ enable f16;
 enable subgroups;
 enable chromium_experimental_subgroup_matrix;

+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+
 #include "mul_mat_decls.tmpl"

 // TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs.
@@ -196,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         }
     }
 }
-
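The `#define DECLARE_BYTE_LOADERS_SRC0` added to each matmul shader asks `common_decls.tmpl` to emit src0-specific loaders (`load_u32_at_src0`, `load_f16_at_src0`) in place of the old generic `load_u32_at(&src0, ...)` calls, presumably because baseline WGSL restricts passing storage-buffer pointers into user-declared functions. A CPU emulation of what such loaders plausibly compute, assuming src0 is bound as an array of 32-bit words and all offsets are in bytes (both inferred from the call sites, not confirmed by the patch):

```cpp
#include <cstdint>

// Unaligned u32 read stitched from two adjacent words. Note: for rem != 0 this
// touches buf[word + 1], so a real loader would need a bounds guard at the
// very end of the buffer.
static uint32_t load_u32_at(const uint32_t * buf, uint32_t byte_offset) {
    const uint32_t word = byte_offset / 4u;
    const uint32_t rem  = byte_offset % 4u;
    if (rem == 0u) return buf[word];
    return (buf[word] >> (8u * rem)) | (buf[word + 1u] << (8u * (4u - rem)));
}

// f16 reads in the shaders use even byte offsets; the bit pattern is just the
// low 16 bits of the (possibly stitched) word starting at that offset.
static uint16_t load_f16_bits_at(const uint32_t * buf, uint32_t byte_offset) {
    return uint16_t(load_u32_at(buf, byte_offset) & 0xFFFFu);
}
```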
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
index 6f6bcaf7940..9f7b3e32eca 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
@@ -1,7 +1,9 @@
 enable f16;

+#define DECLARE_BYTE_LOADERS_SRC0
 #include "common_decls.tmpl"
+

 #ifdef VEC

 #define VEC_SIZE 4
@@ -65,10 +67,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
+        let d = f32(load_f16_at_src0(block_byte_base));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@@ -98,11 +100,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let m = f32(load_f16_at(&src0, block_byte_base + 2u));
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let m = f32(load_f16_at_src0(block_byte_base + 2u));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte(q_packed, k);
                 let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@@ -132,12 +134,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let qh_packed = load_u32_at_src0(block_byte_base + 2u);

         for (var j = 0u; j < 2; j++) {
             let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);

             let j_adjusted = j + (block_offset / 2u);
@@ -176,13 +178,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let m = load_f16_at(&src0, block_byte_base + 2u);
-        let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let m = load_f16_at_src0(block_byte_base + 2u);
+        let qh_packed = load_u32_at_src0(block_byte_base + 4u);

         for (var j = 0u; j < 2; j++) {
             let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);

             let j_adjusted = j + (block_offset / 2u);
@@ -221,11 +223,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
+        let d = f32(load_f16_at_src0(block_byte_base));
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d;
@@ -254,12 +256,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
         let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
         // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
         let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(load_f16_at(&src0, block_byte_base));
-        let m = load_f16_at(&src0, block_byte_base + 2u);
+        let d = f32(load_f16_at_src0(block_byte_base));
+        let m = load_f16_at_src0(block_byte_base + 2u);
         for (var j = 0u; j < F16_PER_THREAD; j += 2) {
             let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
-            let q_packed = load_u32_at(&src0, q_byte_offset);
+            let q_packed = load_u32_at_src0(q_byte_offset);
             for (var k: u32 = 0; k < 4; k++) {
                 let q_byte = get_byte_i32(q_packed, k);
                 let q_val = f32(q_byte) * d + f32(m);
@@ -309,13 +311,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
     for (var i = ix; i < nb; i += 2u) {
         let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
-        let d = f32(load_f16_at(&src0, bbase + 208u));
+        let d = f32(load_f16_at_src0(bbase + 208u));

-        let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
-        let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
-        let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
-        let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
-        let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
+        let ql1_u32 = load_u32_at_src0(bbase + q_offset_l);
+        let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u);
+        let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h);
+        let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte);
+        let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u);

         let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
         let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
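For orientation, the per-byte dequantization visible in the mul_mat_vec hunks is the standard Q4_0/Q4_1 nibble decode: each packed byte holds two 4-bit quants, Q4_0 recentres by 8 and scales by `d`, Q4_1 scales and adds the block minimum `m`. A plain-C++ rendering (only `q_hi` appears in the context lines above; the `q_lo` half is the obvious counterpart and is written here by analogy):

```cpp
#include <cstdint>

// Two Q4_0 weights from one packed byte: value = (nibble - 8) * d.
static void dequant_q4_0_byte(uint8_t q_byte, float d, float * lo, float * hi) {
    *lo = (float(q_byte & 0xF)        - 8.0f) * d;
    *hi = (float((q_byte >> 4) & 0xF) - 8.0f) * d;
}

// Two Q4_1 weights from one packed byte: value = nibble * d + m.
static void dequant_q4_1_byte(uint8_t q_byte, float d, float m, float * lo, float * hi) {
    *lo = float(q_byte & 0xF)        * d + m;
    *hi = float((q_byte >> 4) & 0xF) * d + m;
}
```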
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
index 8c334817ccd..b8f1bca1284 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
@@ -147,15 +147,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
                                   -9.010913, 9.010913)));
 #endif
 #ifdef XIELU
+    let val = f32(src[params.offset_src + src_idx]);
     let res =
-        select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
-                    src[params.offset_src + src_idx]) *
-                       TYPE(params.alpha_n) +
-                   TYPE(params.beta) * src[params.offset_src + src_idx],
-               TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
-                       src[params.offset_src + src_idx] +
-                   TYPE(params.beta) * src[params.offset_src + src_idx],
-               src[params.offset_src + src_idx] > 0.0);
+        TYPE(select(
+            ((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val,
+            params.alpha_p * val * val + params.beta * val,
+            val > 0.0));
 #endif
 #ifdef SOFTPLUS
     let src_f32 = f32(src[params.offset_src + src_idx]);
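The XIELU rewrite loads the input once, evaluates the whole activation in f32, and casts only the final result back to `TYPE`, rather than evaluating every term in `TYPE` (f16 on the half-precision variants, where `exp` can overflow or lose precision). A plain-C++ reference of the function being computed, for checking against the new shader expression (parameter names mirror `params.*`; this is a sketch, not the shader itself):

```cpp
#include <algorithm>
#include <cmath>

// XIELU as the rewritten branch computes it: quadratic-plus-linear for
// positive inputs, a clamped-exponential ELU-style term otherwise.
static float xielu(float val, float alpha_n, float alpha_p, float beta, float eps) {
    if (val > 0.0f) {
        return alpha_p * val * val + beta * val;
    }
    return ((std::exp(std::min(val, eps)) - 1.0f) - val) * alpha_n + beta * val;
}
```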