Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 82 additions & 13 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,15 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
return use_mul_mat_vec_f;
}

static inline bool ggml_cuda_mmvq_eligible(const ggml_tensor * src0,
const ggml_tensor * src1,
const ggml_tensor * dst,
const bool bad_padding_clear) {
return ggml_is_quantized(src0->type) && !bad_padding_clear &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_Q8_1) && dst->type == GGML_TYPE_F32 &&
src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
}

static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
ggml_tensor * src0 = tensor->src[0];
ggml_tensor * src1 = tensor->src[1];
Expand All @@ -2353,8 +2362,7 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
src0->view_src;

bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_vec_q = ggml_cuda_mmvq_eligible(src0, src1, dst, bad_padding_clear);

// fusion is not universally faster on Pascal
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
Expand Down Expand Up @@ -2395,9 +2403,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_f = !ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_vec_q = ggml_cuda_mmvq_eligible(src0, src1, dst, bad_padding_clear);
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

Expand Down Expand Up @@ -2472,7 +2478,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];

GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_Q8_1);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");

Expand All @@ -2481,7 +2487,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if ((src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_Q8_1) && dst->type == GGML_TYPE_F32) {
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
Expand Down Expand Up @@ -2511,6 +2517,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
GGML_ASSERT(src1->type == GGML_TYPE_F32 && "Q8_1 src1 must be handled by the MMVQ path above");
cudaStream_t stream = ctx.stream();

GGML_ASSERT(nb12 % nb11 == 0);
Expand Down Expand Up @@ -3522,10 +3529,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);

//rms norm only supports F32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
//rms norm only supports F32; Q8_1 mul output is only valid for the 2-op fusion
const bool mul_q8_1_ok = (ops.size() == 2) && (mul->type == GGML_TYPE_Q8_1);
if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 ||
(mul->type != GGML_TYPE_F32 && !mul_q8_1_ok)) {
return false;
}

Expand Down Expand Up @@ -3990,7 +3997,12 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
}

if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]);
ggml_tensor * mul_node = cgraph->nodes[i + 1];
if (mul_node->type == GGML_TYPE_Q8_1) {
ggml_cuda_op_rms_norm_mul_q8_1(*cuda_ctx, node, mul_node);
} else {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, mul_node);
}
return 1;
}

Expand Down Expand Up @@ -4342,13 +4354,70 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;

// Change mul->type F32→Q8_1 when all downstream consumers are MMVQ-eligible.
// gallocr then allocates Q8_1-sized memory; try_fuse dispatches the Q8_1 kernel.
for (int i = 0; i + 1 < cgraph->n_nodes; i++) {
ggml_tensor * rms = cgraph->nodes[i];
ggml_tensor * mul = cgraph->nodes[i + 1];

if (rms->op != GGML_OP_RMS_NORM) {
continue;
}
if (mul->op != GGML_OP_MUL) {
continue;
}
if (mul->type != GGML_TYPE_F32) {
continue;
}
if (mul->src[0] != rms && mul->src[1] != rms) {
continue;
}

if (mul->ne[0] % MATRIX_ROW_PADDING != 0) {
continue;
}
if (mul->ne[0] % QK8_1 != 0) {
continue;
}

const int32_t mul_use_count = ggml_node_get_use_count(cgraph, i + 1);
if (mul_use_count == 0) {
continue;
}
int found = 0;
bool all_mmvq = true;

for (int j = i + 2; j < cgraph->n_nodes && found < mul_use_count && all_mmvq; j++) {
ggml_tensor * cand = cgraph->nodes[j];
if (cand->src[0] != mul && cand->src[1] != mul) {
continue;
}
Comment on lines +4392 to +4394

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just iterate over GGML_MAX_SRC, I don't think it makes sense to risk a potential but in the future here.

found++;
const bool is_mmvq_op = cand->op == GGML_OP_MUL_MAT || cand->op == GGML_OP_MUL_MAT_ID;
const bool src0_quantized = cand->src[0] && ggml_is_quantized(cand->src[0]->type);
const int64_t batch = (cand->op == GGML_OP_MUL_MAT_ID) ? cand->ne[2] : cand->ne[1];
if (!is_mmvq_op || !src0_quantized || batch > MMVQ_MAX_BATCH_SIZE) {
all_mmvq = false;
}
Comment on lines +4396 to +4401

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the wrong logic, it has to exactly mirror the kernel selection logic in ggml_cuda_mul_mat.

}

if (!all_mmvq || found != mul_use_count) {
continue;
}

mul->type = GGML_TYPE_Q8_1;
mul->nb[0] = ggml_type_size(GGML_TYPE_Q8_1);
mul->nb[1] = ggml_row_size(GGML_TYPE_Q8_1, mul->ne[0]);
mul->nb[2] = mul->nb[1] * mul->ne[1];
mul->nb[3] = mul->nb[2] * mul->ne[2];
}

#ifdef USE_CUDA_GRAPH
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
#else
const bool use_cuda_graph = false;
GGML_UNUSED(cuda_ctx);
GGML_UNUSED(cgraph);
#endif

static bool enable_graph_optimization = [] {
Expand Down
25 changes: 16 additions & 9 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ static void mul_mat_vec_q_switch_type(
void ggml_cuda_mul_mat_vec_q(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_Q8_1);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.

Expand Down Expand Up @@ -1092,12 +1092,21 @@ void ggml_cuda_mul_mat_vec_q(
}

const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
{

const char * src1_q8_1_data;
ggml_cuda_pool_alloc<char> src1_q8_1_buf(ctx.pool());
if (src1->type == GGML_TYPE_Q8_1) {
// src1 is already quantized (fused RMS+MUL→Q8_1 kernel wrote directly)
GGML_ASSERT(ne10 == ne10_padded);
src1_q8_1_data = (const char *) src1->data;
} else {
src1_q8_1_buf.alloc(ne13 * ne12 * ne11 * ne10_padded * sizeof(block_q8_1) / QK8_1);
const int64_t s11 = src1->nb[1] / ts_src1;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s13 = src1->nb[3] / ts_src1;
quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1_buf.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11,
ne12, ne13, stream);
src1_q8_1_data = src1_q8_1_buf.get();
}

const int64_t s01 = src0->nb[1] / ts_src0;
Expand All @@ -1122,11 +1131,9 @@ void ggml_cuda_mul_mat_vec_q(

const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, ids_stride, stream);
mul_mat_vec_q_switch_type(src0->data, src0->type, src1_q8_1_data, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst,
s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02,
stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, ids_stride, stream);
}

void ggml_cuda_op_mul_mat_vec_q(
Expand Down
95 changes: 95 additions & 0 deletions ggml/src/ggml-cuda/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,101 @@ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
eps, stream);
}

template <int block_size>
static __global__ __launch_bounds__(block_size) void rms_norm_mul_q8_1_f32(const float * __restrict__ x,
Comment thread
lnigam marked this conversation as resolved.
const float * __restrict__ weight,
block_q8_1 * __restrict__ dst,
const int ncols,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;

const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;

x += sample * stride_sample + channel * stride_channel + row * stride_row;
dst += ((sample * nchannels + channel) * nrows + row) * (ncols / QK8_1);

float partial_sum_sq = 0.0f;
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
partial_sum_sq += xi * xi;
}

extern __shared__ float s_sum[];
float sum_sq = block_reduce<block_reduce_method::SUM, block_size>(partial_sum_sq, s_sum);
const float scale = rsqrtf(sum_sq / ncols + eps);

const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;
const int nwarps = block_size / WARP_SIZE;
const int nq8_blocks = ncols / QK8_1;

for (int b = warp_id; b < nq8_blocks; b += nwarps) {
const int col = b * QK8_1 + lane_id;
const float val = scale * x[col] * weight[col];

float amax = fabsf(val);
float fsum = val;

amax = warp_reduce_max<QK8_1>(amax);
fsum = warp_reduce_sum<QK8_1>(fsum);

const float d = amax / 127.0f;
const int8_t q = amax == 0.0f ? 0 : roundf(val / d);

dst[b].qs[lane_id] = q;

if (lane_id == 0) {
dst[b].ds = make_half2(d, fsum);
}
}
}

void ggml_cuda_op_rms_norm_mul_q8_1(ggml_backend_cuda_context & ctx,
ggml_tensor * rms_norm_node,
ggml_tensor * mul_node) {
const ggml_tensor * src = rms_norm_node->src[0];
const ggml_tensor * weight = (mul_node->src[0] == rms_norm_node) ? mul_node->src[1] : mul_node->src[0];

const float * src_d = (const float *) src->data;
const float * weight_d = (const float *) weight->data;

block_q8_1 * dst_d = (block_q8_1 *) mul_node->data;
cudaStream_t stream = ctx.stream();

float eps;
memcpy(&eps, rms_norm_node->op_params, sizeof(float));

const int64_t ne00 = src->ne[0];
const int64_t ne01 = src->ne[1];
const int64_t ne02 = src->ne[2];
const int64_t ne03 = src->ne[3];

GGML_ASSERT(ne00 % QK8_1 == 0);
GGML_ASSERT(src->nb[0] == sizeof(float));
GGML_ASSERT(weight->ne[0] == ne00 && weight->ne[1] == 1 && weight->ne[2] == 1 && weight->ne[3] == 1);
GGML_ASSERT(weight->nb[0] == sizeof(float));

const int64_t s01 = src->nb[1] / sizeof(float);
const int64_t s02 = src->nb[2] / sizeof(float);
const int64_t s03 = src->nb[3] / sizeof(float);

const dim3 blocks_num(ne01, ne02, ne03);
if (ne00 < 1024) {
rms_norm_mul_q8_1_f32<256>
<<<blocks_num, dim3(256), 32 * sizeof(float), stream>>>(src_d, weight_d, dst_d, ne00, s01, s02, s03, eps);
} else {
rms_norm_mul_q8_1_f32<1024>
<<<blocks_num, dim3(1024), 32 * sizeof(float), stream>>>(src_d, weight_d, dst_d, ne00, s01, s02, s03, eps);
}
}

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * grad = dst->src[0]; // gradients
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
ggml_tensor * mul_tensor,
ggml_tensor * add_tensor);

void ggml_cuda_op_rms_norm_mul_q8_1(ggml_backend_cuda_context & ctx, ggml_tensor * rms_dst, ggml_tensor * mul_tensor);

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
57 changes: 57 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2554,6 +2554,55 @@ struct test_rms_norm_mul_rope : public test_case {
}
};

// GGML_OP_RMS_NORM + GGML_OP_MUL -> Q8_1 + GGML_OP_MUL_MAT (fused quantize path)
struct test_rms_norm_mul_q8_1_mul_mat : public test_case {
const ggml_type weight_type; // quantized weight type for MUL_MAT src0
const int64_t k; // hidden dim (ne[0] of input and weight)
const int64_t m; // output rows (ne[1] of MUL_MAT src0)
const int64_t n_tokens; // batch size (ne[1] of input; must be <= MMVQ_MAX_BATCH_SIZE)
const float eps;

std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "RMS_NORM_MUL_Q8_1_MUL_MAT";
}

std::string vars() override { return VARS_TO_STR5(weight_type, k, m, n_tokens, eps); }

double max_nmse_err() override { return 5e-4; }

bool run_whole_graph() override { return true; }

test_rms_norm_mul_q8_1_mul_mat(ggml_type weight_type = GGML_TYPE_Q4_K,
int64_t k = 2048,
int64_t m = 512,
int64_t n_tokens = 1,
float eps = 1e-6f) :
weight_type(weight_type),
k(k),
m(m),
n_tokens(n_tokens),
eps(eps) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, n_tokens);
ggml_set_name(x, "x");

ggml_tensor * norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, k);
ggml_set_name(norm_w, "norm_w");

ggml_tensor * mm_w = ggml_new_tensor_2d(ctx, weight_type, k, m);
ggml_set_name(mm_w, "mm_w");

ggml_tensor * rms = ggml_rms_norm(ctx, x, eps);
ggml_tensor * mul = ggml_mul(ctx, rms, norm_w);
ggml_tensor * out = ggml_mul_mat(ctx, mm_w, mul);
ggml_set_name(out, "out");

return out;
}
};

// GGML_OP_ARGMAX
struct test_argmax : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -8064,6 +8113,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
}
// RMS_NORM + MUL -> Q8_1 fused quantize + MUL_MAT
for (ggml_type wt : { GGML_TYPE_Q4_K, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0 }) {
for (int64_t n_tokens : { 1, 4, 8 }) {
test_cases.emplace_back(new test_rms_norm_mul_q8_1_mul_mat(wt, 512, 256, n_tokens));
test_cases.emplace_back(new test_rms_norm_mul_q8_1_mul_mat(wt, 2048, 512, n_tokens));
}
}

for (int64_t d_conv : {3, 4, 9}) {
for (int64_t d_inner: {1024, 1536, 2048}) {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
Expand Down
Loading