7070#include < condition_variable>
7171#include < cstddef>
7272#include < cstdint>
73- #include < float.h >
73+ #include < cfloat >
7474#include < initializer_list>
7575#include < limits>
7676#include < map>
7777#include < memory>
7878#include < mutex>
79- #include < stdarg.h >
80- #include < stdio.h >
81- #include < stdlib.h >
79+ #include < cstdarg >
80+ #include < cstdio >
81+ #include < cstdlib >
8282#include < string>
8383#include < vector>
84+ #include < unordered_set>
8485
8586static_assert (sizeof (half) == sizeof (ggml_fp16_t ), " wrong fp16 size" );
8687
@@ -2916,22 +2917,26 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
29162917}
29172918
29182919static void ggml_cuda_graph_node_set_properties (ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2919- props->node_address = node->data ;
2920+ memset (props, 0 , sizeof (ggml_cuda_graph_node_properties));
2921+ props->node_data = node->data ;
29202922 props->node_op = node->op ;
29212923 props->flags = node->flags ;
29222924 for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
29232925 props->ne [i] = node->ne [i];
29242926 props->nb [i] = node->nb [i];
29252927 }
29262928 for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2927- props->src_address [i] = node->src [i] ? node->src [i]->data : nullptr ;
2929+ if (!node->src [i]) {
2930+ continue ;
2931+ }
2932+
2933+ props->src_data [i] = node->src [i]->data ;
29282934 }
29292935 memcpy (props->op_params , node->op_params , GGML_MAX_OP_PARAMS);
29302936}
29312937
29322938static bool ggml_cuda_graph_node_properties_match (ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2933- if (node->data != props->node_address &&
2934- node->op != GGML_OP_VIEW) {
2939+ if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
29352940 return false ;
29362941 }
29372942
@@ -2948,12 +2953,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
29482953 }
29492954 }
29502955
2951- for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2952- if (node->src [i] &&
2953- node->src [i]->data != props->src_address [i] &&
2954- node->op != GGML_OP_VIEW
2955- ) {
2956- return false ;
2956+ if (node->op != GGML_OP_VIEW) {
2957+ for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2958+ if (!node->src [i]) {
2959+ if (props->src_data [i] != nullptr ) {
2960+ return false ;
2961+ }
2962+ continue ;
2963+ }
2964+
2965+ if (node->src [i]->data != props->src_data [i]) {
2966+ return false ;
2967+ }
29572968 }
29582969 }
29592970
@@ -2974,7 +2985,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
29742985}
29752986
29762987static bool ggml_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2977-
29782988 bool res = false ;
29792989
29802990 const void * graph_key = ggml_cuda_graph_get_key (cgraph);
@@ -2985,33 +2995,52 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
29852995 }
29862996
29872997 // Check if the graph size has changed
2988- if (graph->props .size () != (size_t )cgraph->n_nodes + cgraph-> n_leafs ) {
2998+ if (graph->props .size () != (size_t )cgraph->n_nodes ) {
29892999 res = true ;
2990- graph->props .resize (cgraph->n_nodes + cgraph-> n_leafs );
3000+ graph->props .resize (cgraph->n_nodes );
29913001 }
29923002
29933003 // Loop over nodes in GGML graph to determine if CUDA graph update is required
29943004 // and store properties to allow this comparison for the next token
3005+ std::unordered_set<ggml_tensor *> seen_node;
3006+ std::vector<ggml_tensor *> srcs_extra;
29953007 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
29963008 bool props_match = true ;
3009+
3010+ seen_node.insert (cgraph->nodes [i]);
3011+
29973012 if (!res) {
29983013 props_match = ggml_cuda_graph_node_properties_match (cgraph->nodes [i], &graph->props [i]);
29993014 }
30003015 if (!props_match) {
30013016 res = true ;
30023017 }
30033018 ggml_cuda_graph_node_set_properties (&graph->props [i], cgraph->nodes [i]);
3019+
3020+ for (int src_idx = 0 ; src_idx < GGML_MAX_SRC; ++src_idx) {
3021+ ggml_tensor * src = cgraph->nodes [i]->src [src_idx];
3022+ if (src && seen_node.find (src) == seen_node.end ()) {
3023+ srcs_extra.push_back (src);
3024+ }
3025+ }
3026+ }
3027+
3028+ if (graph->extra .size () != (size_t ) srcs_extra.size ()) {
3029+ res = true ;
3030+ graph->extra .resize (srcs_extra.size ());
30043031 }
30053032
3006- for (int i = 0 ; i < cgraph-> n_leafs ; i++ ) {
3033+ for (size_t i = 0 ; i < srcs_extra. size (); ++i ) {
30073034 bool props_match = true ;
3035+
30083036 if (!res) {
3009- props_match = ggml_cuda_graph_node_properties_match (cgraph-> leafs [i], &graph->props [cgraph-> n_nodes + i]);
3037+ props_match = ggml_cuda_graph_node_properties_match (srcs_extra [i], &graph->extra [ i]);
30103038 }
3039+
30113040 if (!props_match) {
30123041 res = true ;
30133042 }
3014- ggml_cuda_graph_node_set_properties (&graph->props [cgraph-> n_nodes + i], cgraph-> leafs [i]);
3043+ ggml_cuda_graph_node_set_properties (&graph->extra [ i], srcs_extra [i]);
30153044 }
30163045
30173046 return res;
0 commit comments