Skip to content

Commit 4fdbc1e

Browse files
authored
cuda : fix nkvo, offload and cuda graph node properties matching (ggml-org#19165)
* cuda : fix nkvo * cont : more robust cuda graph node property matching * cont : restore pre-leafs implementation * cont : comments + static_assert
1 parent 7b7ae85 commit 4fdbc1e

4 files changed

Lines changed: 59 additions & 32 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1122,15 +1122,17 @@ struct ggml_tensor_extra_gpu {
11221122
#endif
11231123

11241124
struct ggml_cuda_graph_node_properties {
1125-
void * node_address;
1125+
void * node_data;
11261126
ggml_op node_op;
11271127
int32_t flags;
11281128
int64_t ne[GGML_MAX_DIMS];
11291129
size_t nb[GGML_MAX_DIMS];
1130-
void * src_address[GGML_MAX_SRC];
1130+
void * src_data[GGML_MAX_SRC];
11311131
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
11321132
};
11331133

1134+
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
1135+
11341136
struct ggml_cuda_graph {
11351137
#ifdef USE_CUDA_GRAPH
11361138
~ggml_cuda_graph() {
@@ -1150,6 +1152,12 @@ struct ggml_cuda_graph {
11501152
int number_consecutive_updates = 0;
11511153
std::vector<ggml_cuda_graph_node_properties> props;
11521154

1155+
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
1156+
// their properties also have to match in order to be able to safely reuse a CUDA graph
1157+
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
1158+
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
1159+
std::vector<ggml_cuda_graph_node_properties> extra;
1160+
11531161
void record_update(bool use_graph, bool update_required) {
11541162
if (use_graph && update_required) {
11551163
number_consecutive_updates++;

ggml/src/ggml-cuda/fattn.cu

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
310310
}
311311
}
312312

313-
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
314-
315313
const int cc = ggml_cuda_info().devices[device].cc;
316314

317315
switch (K->ne[0]) {
@@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
334332
if (!gqa_opt_applies) {
335333
return BEST_FATTN_KERNEL_NONE;
336334
}
337-
if (!V_is_K_view) {
338-
return BEST_FATTN_KERNEL_NONE;
339-
}
340335
break;
341336
default:
342337
return BEST_FATTN_KERNEL_NONE;

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 49 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -70,17 +70,18 @@
7070
#include <condition_variable>
7171
#include <cstddef>
7272
#include <cstdint>
73-
#include <float.h>
73+
#include <cfloat>
7474
#include <initializer_list>
7575
#include <limits>
7676
#include <map>
7777
#include <memory>
7878
#include <mutex>
79-
#include <stdarg.h>
80-
#include <stdio.h>
81-
#include <stdlib.h>
79+
#include <cstdarg>
80+
#include <cstdio>
81+
#include <cstdlib>
8282
#include <string>
8383
#include <vector>
84+
#include <unordered_set>
8485

8586
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
8687

@@ -2916,22 +2917,26 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
29162917
}
29172918

29182919
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2919-
props->node_address = node->data;
2920+
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
2921+
props->node_data = node->data;
29202922
props->node_op = node->op;
29212923
props->flags = node->flags;
29222924
for (int i = 0; i < GGML_MAX_DIMS; i++) {
29232925
props->ne[i] = node->ne[i];
29242926
props->nb[i] = node->nb[i];
29252927
}
29262928
for (int i = 0; i < GGML_MAX_SRC; i++) {
2927-
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2929+
if (!node->src[i]) {
2930+
continue;
2931+
}
2932+
2933+
props->src_data[i] = node->src[i]->data;
29282934
}
29292935
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
29302936
}
29312937

29322938
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2933-
if (node->data != props->node_address &&
2934-
node->op != GGML_OP_VIEW) {
2939+
if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
29352940
return false;
29362941
}
29372942

@@ -2948,12 +2953,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
29482953
}
29492954
}
29502955

2951-
for (int i = 0; i < GGML_MAX_SRC; i++) {
2952-
if (node->src[i] &&
2953-
node->src[i]->data != props->src_address[i] &&
2954-
node->op != GGML_OP_VIEW
2955-
) {
2956-
return false;
2956+
if (node->op != GGML_OP_VIEW) {
2957+
for (int i = 0; i < GGML_MAX_SRC; i++) {
2958+
if (!node->src[i]) {
2959+
if (props->src_data[i] != nullptr) {
2960+
return false;
2961+
}
2962+
continue;
2963+
}
2964+
2965+
if (node->src[i]->data != props->src_data[i]) {
2966+
return false;
2967+
}
29572968
}
29582969
}
29592970

@@ -2974,7 +2985,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
29742985
}
29752986

29762987
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2977-
29782988
bool res = false;
29792989

29802990
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
@@ -2985,33 +2995,52 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
29852995
}
29862996

29872997
// Check if the graph size has changed
2988-
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
2998+
if (graph->props.size() != (size_t)cgraph->n_nodes) {
29892999
res = true;
2990-
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
3000+
graph->props.resize(cgraph->n_nodes);
29913001
}
29923002

29933003
// Loop over nodes in GGML graph to determine if CUDA graph update is required
29943004
// and store properties to allow this comparison for the next token
3005+
std::unordered_set<ggml_tensor *> seen_node;
3006+
std::vector<ggml_tensor *> srcs_extra;
29953007
for (int i = 0; i < cgraph->n_nodes; i++) {
29963008
bool props_match = true;
3009+
3010+
seen_node.insert(cgraph->nodes[i]);
3011+
29973012
if (!res) {
29983013
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
29993014
}
30003015
if (!props_match) {
30013016
res = true;
30023017
}
30033018
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
3019+
3020+
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3021+
ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
3022+
if (src && seen_node.find(src) == seen_node.end()) {
3023+
srcs_extra.push_back(src);
3024+
}
3025+
}
3026+
}
3027+
3028+
if (graph->extra.size() != (size_t) srcs_extra.size()) {
3029+
res = true;
3030+
graph->extra.resize(srcs_extra.size());
30043031
}
30053032

3006-
for (int i = 0; i < cgraph->n_leafs; i++) {
3033+
for (size_t i = 0; i < srcs_extra.size(); ++i) {
30073034
bool props_match = true;
3035+
30083036
if (!res) {
3009-
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
3037+
props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
30103038
}
3039+
30113040
if (!props_match) {
30123041
res = true;
30133042
}
3014-
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
3043+
ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
30153044
}
30163045

30173046
return res;

src/llama-graph.cpp

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1630,11 +1630,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
16301630
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
16311631
cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
16321632

1633-
if (!cparams.offload_kqv) {
1634-
// all nodes between the KV store and the attention output are run on the CPU
1635-
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
1636-
}
1637-
16381633
ggml_flash_attn_ext_add_sinks(cur, sinks);
16391634
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
16401635

0 commit comments

Comments (0)