@@ -13922,50 +13922,56 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
1392213922 float decay = decay_vals[h]; /* precomputed exp(gate) */
1392313923
1392413924#ifdef __ARM_NEON
13925- /* NEON-optimized: fused decay + sk computation.
13926- * For each row i of state: decay state, accumulate sk.
13927- * sk[j] = sum_i(S[i,j] * K[i]) after decay */
13925+ /* NEON-optimized: llama.cpp-aligned delta rule.
13926+ * Formula (matches gated_delta_net.cu):
13927+ * sk = S @ K (BEFORE decay)
13928+ * d = (V - g*sk) * beta
13929+ * S = g*S + K * d
13930+ * o = S @ Q
13931+ * The key difference from the previous impl: sk is computed
13932+ * on the ORIGINAL state, then decay is applied to both sk
13933+ * (in the delta) and S (in the update). This prevents
13934+ * short-prompt instability where early tokens have near-zero
13935+ * state and the decay-first approach loses information. */
1392813936 float* sk = s->delta_sk;
1392913937 memset(sk, 0, (size_t)dv * sizeof(float));
1393013938
13931- float32x4_t vdecay = vdupq_n_f32( decay);
13939+ /* Step A: sk = S @ K (on original state, BEFORE decay) */
1393213940 for (int i = 0; i < dk; i++) {
1393313941 float* sp = sh + i * dv;
1393413942 float ki = kh[i];
1393513943 float32x4_t vki = vdupq_n_f32(ki);
1393613944 int j = 0;
1393713945 for (; j + 3 < dv; j += 4) {
1393813946 float32x4_t vs = vld1q_f32(sp + j);
13939- vs = vmulq_f32(vs, vdecay); /* decay */
13940- vst1q_f32(sp + j, vs); /* store decayed state */
1394113947 float32x4_t vsk = vld1q_f32(sk + j);
13942- vsk = vfmaq_f32(vsk, vs, vki); /* accumulate sk */
13948+ vsk = vfmaq_f32(vsk, vs, vki);
1394313949 vst1q_f32(sk + j, vsk);
1394413950 }
1394513951 for (; j < dv; j++) {
13946- sp[j] *= decay;
1394713952 sk[j] += sp[j] * ki;
1394813953 }
1394913954 }
1395013955
13951- /* Delta : d = beta * (V - sk) */
13956+ /* Step B : d = (V - g* sk) * beta */
1395213957 float* d_vec = s->delta_dvec;
1395313958 float32x4_t vbeta = vdupq_n_f32(beta_h);
13959+ float32x4_t vdecay = vdupq_n_f32(decay);
1395413960 {
1395513961 int j = 0;
1395613962 for (; j + 3 < dv; j += 4) {
1395713963 float32x4_t vv = vld1q_f32(vh + j);
13958- float32x4_t vs = vld1q_f32(sk + j);
13959- float32x4_t vd = vmulq_f32(vbeta, vsubq_f32(vv, vs ));
13964+ float32x4_t vsk = vld1q_f32(sk + j);
13965+ float32x4_t vd = vmulq_f32(vbeta, vsubq_f32(vv, vmulq_f32(vdecay, vsk) ));
1396013966 vst1q_f32(d_vec + j, vd);
1396113967 }
1396213968 for (; j < dv; j++) {
13963- d_vec[j] = beta_h * (vh[j] - sk[j]);
13969+ d_vec[j] = beta_h * (vh[j] - decay * sk[j]);
1396413970 }
1396513971 }
1396613972
13967- /* State update : S[i][j] += K[i] * d[j] (rank-1 outer product )
13968- * + Output: o[j] = sum_i(S[i,j] * Q[i]) (simultaneously) */
13973+ /* Step C : S = g*S + K*d (state update )
13974+ * + Output: o = S @ Q (simultaneously) */
1396913975 float* oh = s->delta_out + h * dv;
1397013976 memset(oh, 0, (size_t)dv * sizeof(float));
1397113977
@@ -13978,26 +13984,24 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
1397813984 int j = 0;
1397913985 for (; j + 3 < dv; j += 4) {
1398013986 float32x4_t vs = vld1q_f32(sp + j);
13987+ vs = vmulq_f32(vs, vdecay); /* S = g*S */
1398113988 float32x4_t vd = vld1q_f32(d_vec + j);
13982- vs = vfmaq_f32(vs, vki, vd); /* S += K[i] * d */
13989+ vs = vfmaq_f32(vs, vki, vd); /* S += K[i] * d */
1398313990 vst1q_f32(sp + j, vs);
1398413991 float32x4_t vo = vld1q_f32(oh + j);
13985- vo = vfmaq_f32(vo, vs, vqi); /* o += S * Q[i] */
13992+ vo = vfmaq_f32(vo, vs, vqi); /* o += S * Q[i] */
1398613993 vst1q_f32(oh + j, vo);
1398713994 }
1398813995 for (; j < dv; j++) {
13989- sp[j] += ki * d_vec[j];
13996+ sp[j] = decay * sp[j] + ki * d_vec[j];
1399013997 oh[j] += sp[j] * qi;
1399113998 }
1399213999 }
1399314000#else
13994- /* Scalar fallback */
13995- /* Decay: S = S * exp(gate) */
13996- for (int i = 0; i < dk * dv; i++) {
13997- sh[i] *= decay;
13998- }
14001+ /* Scalar fallback — llama.cpp-aligned formula:
14002+ * sk = S @ K, d = (V - g*sk) * beta, S = g*S + K*d, o = S @ Q */
1399914003
14000- /* Compute sk */
14004+ /* Compute sk = S @ K (original state, before decay) */
1400114005 float* sk = s->delta_sk;
1400214006 for (int j = 0; j < dv; j++) {
1400314007 float sum = 0.0f;
@@ -14007,20 +14011,20 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
1400714011 sk[j] = sum;
1400814012 }
1400914013
14010- /* Delta */
14014+ /* Delta: d = (V - g*sk) * beta */
1401114015 float* d_vec = s->delta_dvec;
1401214016 for (int j = 0; j < dv; j++) {
14013- d_vec[j] = beta_h * (vh[j] - sk[j]);
14017+ d_vec[j] = beta_h * (vh[j] - decay * sk[j]);
1401414018 }
1401514019
14016- /* State update */
14020+ /* State update: S = g*S + K*d */
1401714021 for (int i = 0; i < dk; i++) {
1401814022 for (int j = 0; j < dv; j++) {
14019- sh[i * dv + j] += kh[i] * d_vec[j];
14023+ sh[i * dv + j] = decay * sh[i * dv + j] + kh[i] * d_vec[j];
1402014024 }
1402114025 }
1402214026
14023- /* Output */
14027+ /* Output: o = S @ Q */
1402414028 float* oh = s->delta_out + h * dv;
1402514029 for (int j = 0; j < dv; j++) {
1402614030 float sum = 0.0f;
0 commit comments