Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
// activations : fp32 -> fp16

static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) {
for (int r = 0; r < n_rows; r += 2) {
const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS);
const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS;

int r = 0;

#pragma unroll(2)
for (r = 0; r < n_rows_tiled; r += 2) {
int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx

const bool next_row_valid = (r + 1) < n_rows;

const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride);
const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride);
for (int c = 0; c < k_block; c += 32) {
HVX_Vector v0 = *pv_in0++;
HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero();
HVX_Vector v1 = *pv_in1++;

HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);

// compute output position
int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index
int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0;

HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
tile[r1 / 2] = v_out;
}
}

for (; r < n_rows_padded; r += 2) {
int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index
int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx

const bool row0_valid = r < n_rows;
const bool row1_valid = (r + 1) < n_rows;

const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL;
const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL;
for (int c = 0; c < k_block; c += 32) {
HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero();
HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero();

HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1);

Expand Down Expand Up @@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h
// n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper).
const size_t m_block_cost = (size_t) n * 3;
const size_t n_block_cost = (size_t) m * 2;
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE,
if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
m_block_cost, n_block_cost, &M_BLOCK_SIZE,
&N_BLOCK_SIZE, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
return -1;
Expand Down Expand Up @@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds

if (m >= 128) {
size_t mc = 0, nc = 0, used = 0;
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n,
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
/*m_block_cost=*/(size_t) n * 3,
/*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 &&
hmx_ceil_div((size_t) n, nc) >= 2) {
Expand All @@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
}

if (!use_pipeline) {
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n,
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
/*m_block_cost=*/(size_t) n * 3,
/*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
Expand Down Expand Up @@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
/*per_n=*/3 * vec_dot_size,
/*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
/*per_mn=*/sizeof(__fp16), params->m, params->n,
/*per_mn=*/sizeof(__fp16),
hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n,
/*m_block_cost=*/(size_t) params->n,
/*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
Expand Down Expand Up @@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
/*per_n=*/3 * vec_dot_size, // W + S0 + S1
/*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
/*per_mn=*/sizeof(__fp16), // O
m, n,
hex_align_up(m, HMX_FP16_TILE_N_ROWS), n,
/*m_block_cost=*/(size_t) n,
/*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
Expand Down
36 changes: 6 additions & 30 deletions ggml/src/ggml-hexagon/htp/matmul-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) {
return op_matmul_hvx(octx);
}

// M alignment: when M > 32 but not 32-aligned, we split into
// HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows).
// When M <= 32 and not 32-aligned, fall back entirely to HVX.
// M alignment: use HMX whenever M >= 32; the last partial tile (m_total % 32 rows)
// is handled by HMX itself. When M < 32, fall back entirely to HVX.
const int m_total = (int) src1->ne[1];
const int m_tail = m_total % 32;
const int m_hmx = m_total - m_tail;
const int m_hmx = m_total & ~31; // 0 when M < 32

if (m_hmx == 0) {
return op_matmul_hvx(octx);
Expand All @@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) {
int k = (int) src0->ne[0]; // inner dimension
int n = (int) src0->ne[1]; // weight columns

// --- Phase 1: HMX on the first m_hmx (32-aligned) rows ---
int ret = -1;

// Row strides in elements. For compact tensors these equal k; for
Expand All @@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) {
.dst = (float *) dst->data,
.activation = (float *) src1->data,
.permuted_weight = (const __fp16 *) src0->data,
.m = m_hmx,
.m = m_total,
.k = k,
.n = n,
.act_stride = act_stride,
Expand All @@ -3048,40 +3045,19 @@ int op_matmul(struct htp_ops_context * octx) {
} else {
ret = hmx_mat_mul_permuted_w16a32(octx->ctx,
(float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data,
m_hmx, k, n, act_stride, wgt_stride);
m_total, k, n, act_stride, wgt_stride);
}
} else {
ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx,
(float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data,
m_hmx, k, n, (int) src0->type);
m_total, k, n, (int) src0->type);
}

if (ret != 0) {
FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret);
return op_matmul(octx);
}

// --- Phase 2: HVX on the remaining m_tail rows ---
if (m_tail > 0) {
// copy of src1 and dst
struct htp_tensor src1_tail = *src1;
struct htp_tensor dst_tail = *dst;

src1_tail.ne[1] = m_tail; // only tail rows
dst_tail.ne[1] = m_tail; // only tail rows

// Offset activation and dst pointers past the HMX-processed rows.
// Use nb[1] (row stride in bytes) to compute the byte offset.
src1_tail.data += (uint32_t) m_hmx * src1->nb[1];
dst_tail.data += (uint32_t) m_hmx * dst->nb[1];

octx->src[1] = &src1_tail;
octx->dst = &dst_tail;

FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data);
return op_matmul_hvx(octx);
}

return 0;
#endif // HTP_HAS_HMX
}
Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-opencl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,13 @@ set(GGML_OPENCL_KERNELS
diag
div
gelu
gemv_noshuffle_general
gemv_noshuffle
get_rows
glu
group_norm
solve_tri
im2col_f32
im2col_f16
mean
mul_mat_Ab_Bi_8x4
mul_mv_f16_f16
mul_mv_f16_f32_1row
mul_mv_f16_f32_l4
Expand Down Expand Up @@ -120,12 +117,15 @@ set(GGML_OPENCL_KERNELS
mul_mm_q4_k_f32_l4_lm
mul_mm_q5_k_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_q4_0_f32
gemv_noshuffle_q4_0_f32_spec
gemm_noshuffle_q4_0_f32
gemv_noshuffle_q4_1_f32
gemm_noshuffle_q4_1_f32
gemv_noshuffle_iq4_nl_f32
gemm_noshuffle_iq4_nl_f32
gemv_noshuffle_general_q8_0_f32
gemv_noshuffle_q8_0_f32
gemm_noshuffle_q8_0_f32
gemv_noshuffle_q4_k_f32
gemm_noshuffle_q4_k_f32
gemv_noshuffle_q6_k_f32
Expand Down
Loading
Loading