forked from Scottcjn/ram-coffers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvcipher-flash-attn-patch.c
More file actions
155 lines (133 loc) · 6.1 KB
/
vcipher-flash-attn-patch.c
File metadata and controls
155 lines (133 loc) · 6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/*
* vcipher-flash-attn-patch.c
*
* Integration patch for ops.cpp: ggml_compute_forward_flash_attn_ext_f16_one_chunk
*
* This shows the modified inner loop with vcipher prefiltering.
* Apply to ~/llama.cpp/ggml/src/ggml-cpu/ops.cpp around line 8190.
*
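* PHASE 2: vcipher prefilter — O(1) hardware crypto check per K-V pair.
* Skip full dot product + softmax + V accumulation for low-score pairs.
* Skipped pairs contribute nothing to the softmax sum, so the output is an
* approximation of exact attention, not a bit-identical result.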
*
* Expected: 4-16x fewer full dot products on long sequences.
*/
/* === BEFORE (original inner loop) ===
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) { continue; }
float s;
const char * k_data = (const char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
s = s*scale;
... softmax + V accumulation ...
}
=== AFTER (vcipher prefiltered) === */
#ifdef GGML_PSE_VCIPHER_PREFILTER
/* ─── PSE vcipher Prefilter Phase ───
*
* PASS 1: Quick vcipher_attention_score() on first 16 bytes of Q and K.
* This is O(1) per pair (single vcipher instruction + byte reduction).
* Build a score array, find threshold, then only do full dot product
* for positions above threshold.
*
* Cost: ~0.044 us per K-V pair (vs ~1-10 us for full kq_vec_dot on DK=128+)
*/
{
    /*
     * PSE vcipher prefilter wrapped around the flash-attention inner loop.
     *
     *   PASS 1    - score every K position with vcipher_attention_score();
     *               positions whose mask value is -INFINITY get score 0.
     *   THRESHOLD - pick the keep_count-th largest score (exact selection).
     *   PASS 2    - run the full kq_vec_dot + online softmax + V accumulation
     *               only for positions whose score is >= threshold.
     *
     * Uses from the enclosing scope: nek1, mp, slope, k, v, nbk1..3, nbv1..3,
     * ik2, ik3, iv2, iv3, pq, iq2, Q_q, DK, DV, scale, logit_softcap,
     * kq_vec_dot, v_to_float, V32. Updates M, S and VKQ16 / VKQ32.
     *
     * If filtering is not possible (keep_count is 0, or the score buffer
     * cannot be allocated) threshold stays 0 and every position takes the
     * full path, i.e. the loop degrades to the unfiltered computation.
     */
    const float   prefilter_ratio = 0.25f;   /* Keep top 25% of K-V pairs */
    const int64_t keep_count      = (int64_t)(nek1 * prefilter_ratio);

    /* Score buffer: stack for short sequences, heap for long ones. */
    uint32_t  prefilter_stack[512];
    uint32_t *prefilter_scores = prefilter_stack;
    if (nek1 > 512) {
        /* Cast kept on purpose: this snippet is applied to ops.cpp (C++). */
        prefilter_scores = (uint32_t *)malloc((size_t)nek1 * sizeof *prefilter_scores);
    }

    /* PASS 1: vcipher prefilter. Skipped entirely if the allocation failed. */
    if (prefilter_scores != NULL) {
        for (int64_t ic = 0; ic < nek1; ++ic) {
            const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
            if (mv == -INFINITY) {
                prefilter_scores[ic] = 0;   /* Masked out */
                continue;
            }

            const char * k_data = (const char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3);

            /* Use the leading bytes of Q and K as vcipher input.
             * NOTE(review): k_data is passed as const float* regardless of
             * k->type; confirm vcipher_attention_score() only treats it as
             * raw bytes and tolerates the row alignment of quantized/F16 K. */
            prefilter_scores[ic] = vcipher_attention_score(
                pq, (const float *)k_data, iq2, (int)ic);
        }
    }

    /*
     * THRESHOLD: the largest value t such that at least keep_count scores
     * are >= t, which is exactly the keep_count-th largest score.
     *
     * Built bit by bit from the MSB: a bit is kept only if enough scores
     * still reach the candidate. At most 32 linear scans, no extra memory,
     * and each scan stops as soon as keep_count hits are found.
     *
     * (The previous version tracked only the top 64 scores, so for
     * keep_count > 64 it kept roughly 64 positions instead of 25%.)
     */
    uint32_t threshold = 0;
    if (prefilter_scores != NULL && keep_count > 0 && keep_count < nek1) {
        for (int bit = 31; bit >= 0; --bit) {
            const uint32_t candidate = threshold | ((uint32_t)1 << bit);
            int64_t n_at_least = 0;
            for (int64_t ic = 0; ic < nek1 && n_at_least < keep_count; ++ic) {
                if (prefilter_scores[ic] >= candidate) {
                    ++n_at_least;
                }
            }
            if (n_at_least >= keep_count) {
                threshold = candidate;
            }
        }
    }

    /* PASS 2: full dot product ONLY for positions at or above the threshold. */
    for (int64_t ic = 0; ic < nek1; ++ic) {
        /* threshold is tested first so prefilter_scores is never read when
         * filtering is disabled (it may be NULL after a failed malloc).
         * Ties with the threshold are kept. */
        if (threshold > 0 && prefilter_scores[ic] < threshold) {
            continue;   /* SKIP: no kq_vec_dot, no expf, no V accumulation */
        }

        const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
        if (mv == -INFINITY) {
            continue;
        }

        float s;
        const char * k_data = (const char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3);
        kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);

        s = s*scale;
        if (logit_softcap != 0.0f) {
            s = logit_softcap*tanhf(s);
        }
        s += mv;

        /* Online softmax: ms rescales the running accumulator when a new
         * maximum appears, vs weights the current V row. */
        const float Mold = M;
        float ms = 1.0f;
        float vs = 1.0f;

        const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

        if (v->type == GGML_TYPE_F16) {
            if (s > M) {
                M = s;
                ms = expf(Mold - M);
                ggml_vec_scale_f16(DV, VKQ16, ms);
            } else {
                vs = expf(s - M);
            }
            ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
        } else {
            if (s > M) {
                M = s;
                ms = expf(Mold - M);
                ggml_vec_scale_f32(DV, VKQ32, ms);
            } else {
                vs = expf(s - M);
            }
            if (v_to_float) {
                /* Convert the V row into the V32 scratch buffer first. */
                v_to_float(v_data, V32, DV);
                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
            } else {
                ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
            }
        }

        S = S*ms + vs;
    }

    /* free(NULL) is a no-op, so a failed allocation needs no special case. */
    if (prefilter_scores != prefilter_stack) {
        free(prefilter_scores);
    }
}
#else
/* Original unmodified loop */
/* Placeholder for builds without GGML_PSE_VCIPHER_PREFILTER: the loop body
 * is intentionally elided here. When applying this patch, keep the existing
 * inner loop (the one shown in the BEFORE listing near the top of this file)
 * in this branch; as written, this stub iterates without doing any work. */
for (int64_t ic = 0; ic < nek1; ++ic) {
/* ... original code ... */
}
#endif