Skip to content

Commit d26ca5e

Browse files
unamedkrclaude
andcommitted
fix(deltanet): align decay formula with llama.cpp (#95)
Changed DeltaNet recurrent update from: OLD: S=decay*S; sk=S@K; d=beta*(V-sk); S=S+K*d NEW: sk=S@K; d=(V-decay*sk)*beta; S=decay*S+K*d Matches llama.cpp gated_delta_net.cu: sk computed on ORIGINAL state before decay decay applied to sk in delta AND to S in update Impact: fixes numerical difference with reference, but does NOT fix the short-prompt instability alone. Document QA still works (2/2). Short prompts still fail (0/5). The issue requires additional fixes (L2 norm removal or Q scaling timing change per reference analysis). Refs #95 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7e2ca31 commit d26ca5e

1 file changed

Lines changed: 32 additions & 28 deletions

File tree

quant.h

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13922,50 +13922,56 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
1392213922
float decay = decay_vals[h]; /* precomputed exp(gate) */
1392313923

1392413924
#ifdef __ARM_NEON
13925-
/* NEON-optimized: fused decay + sk computation.
13926-
* For each row i of state: decay state, accumulate sk.
13927-
* sk[j] = sum_i(S[i,j] * K[i]) after decay */
13925+
/* NEON-optimized: llama.cpp-aligned delta rule.
13926+
* Formula (matches gated_delta_net.cu):
13927+
* sk = S @ K (BEFORE decay)
13928+
* d = (V - g*sk) * beta
13929+
* S = g*S + K * d
13930+
* o = S @ Q
13931+
* The key difference from the previous impl: sk is computed
13932+
* on the ORIGINAL state, then decay is applied to both sk
13933+
* (in the delta) and S (in the update). This prevents
13934+
* short-prompt instability where early tokens have near-zero
13935+
* state and the decay-first approach loses information. */
1392813936
float* sk = s->delta_sk;
1392913937
memset(sk, 0, (size_t)dv * sizeof(float));
1393013938

13931-
float32x4_t vdecay = vdupq_n_f32(decay);
13939+
/* Step A: sk = S @ K (on original state, BEFORE decay) */
1393213940
for (int i = 0; i < dk; i++) {
1393313941
float* sp = sh + i * dv;
1393413942
float ki = kh[i];
1393513943
float32x4_t vki = vdupq_n_f32(ki);
1393613944
int j = 0;
1393713945
for (; j + 3 < dv; j += 4) {
1393813946
float32x4_t vs = vld1q_f32(sp + j);
13939-
vs = vmulq_f32(vs, vdecay); /* decay */
13940-
vst1q_f32(sp + j, vs); /* store decayed state */
1394113947
float32x4_t vsk = vld1q_f32(sk + j);
13942-
vsk = vfmaq_f32(vsk, vs, vki); /* accumulate sk */
13948+
vsk = vfmaq_f32(vsk, vs, vki);
1394313949
vst1q_f32(sk + j, vsk);
1394413950
}
1394513951
for (; j < dv; j++) {
13946-
sp[j] *= decay;
1394713952
sk[j] += sp[j] * ki;
1394813953
}
1394913954
}
1395013955

13951-
/* Delta: d = beta * (V - sk) */
13956+
/* Step B: d = (V - g*sk) * beta */
1395213957
float* d_vec = s->delta_dvec;
1395313958
float32x4_t vbeta = vdupq_n_f32(beta_h);
13959+
float32x4_t vdecay = vdupq_n_f32(decay);
1395413960
{
1395513961
int j = 0;
1395613962
for (; j + 3 < dv; j += 4) {
1395713963
float32x4_t vv = vld1q_f32(vh + j);
13958-
float32x4_t vs = vld1q_f32(sk + j);
13959-
float32x4_t vd = vmulq_f32(vbeta, vsubq_f32(vv, vs));
13964+
float32x4_t vsk = vld1q_f32(sk + j);
13965+
float32x4_t vd = vmulq_f32(vbeta, vsubq_f32(vv, vmulq_f32(vdecay, vsk)));
1396013966
vst1q_f32(d_vec + j, vd);
1396113967
}
1396213968
for (; j < dv; j++) {
13963-
d_vec[j] = beta_h * (vh[j] - sk[j]);
13969+
d_vec[j] = beta_h * (vh[j] - decay * sk[j]);
1396413970
}
1396513971
}
1396613972

13967-
/* State update: S[i][j] += K[i] * d[j] (rank-1 outer product)
13968-
* + Output: o[j] = sum_i(S[i,j] * Q[i]) (simultaneously) */
13973+
/* Step C: S = g*S + K*d (state update)
13974+
* + Output: o = S @ Q (simultaneously) */
1396913975
float* oh = s->delta_out + h * dv;
1397013976
memset(oh, 0, (size_t)dv * sizeof(float));
1397113977

@@ -13978,26 +13984,24 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
1397813984
int j = 0;
1397913985
for (; j + 3 < dv; j += 4) {
1398013986
float32x4_t vs = vld1q_f32(sp + j);
13987+
vs = vmulq_f32(vs, vdecay); /* S = g*S */
1398113988
float32x4_t vd = vld1q_f32(d_vec + j);
13982-
vs = vfmaq_f32(vs, vki, vd); /* S += K[i] * d */
13989+
vs = vfmaq_f32(vs, vki, vd); /* S += K[i] * d */
1398313990
vst1q_f32(sp + j, vs);
1398413991
float32x4_t vo = vld1q_f32(oh + j);
13985-
vo = vfmaq_f32(vo, vs, vqi); /* o += S * Q[i] */
13992+
vo = vfmaq_f32(vo, vs, vqi); /* o += S * Q[i] */
1398613993
vst1q_f32(oh + j, vo);
1398713994
}
1398813995
for (; j < dv; j++) {
13989-
sp[j] += ki * d_vec[j];
13996+
sp[j] = decay * sp[j] + ki * d_vec[j];
1399013997
oh[j] += sp[j] * qi;
1399113998
}
1399213999
}
1399314000
#else
13994-
/* Scalar fallback */
13995-
/* Decay: S = S * exp(gate) */
13996-
for (int i = 0; i < dk * dv; i++) {
13997-
sh[i] *= decay;
13998-
}
14001+
/* Scalar fallback — llama.cpp-aligned formula:
14002+
* sk = S @ K, d = (V - g*sk) * beta, S = g*S + K*d, o = S @ Q */
1399914003

14000-
/* Compute sk */
14004+
/* Compute sk = S @ K (original state, before decay) */
1400114005
float* sk = s->delta_sk;
1400214006
for (int j = 0; j < dv; j++) {
1400314007
float sum = 0.0f;
@@ -14007,20 +14011,20 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
1400714011
sk[j] = sum;
1400814012
}
1400914013

14010-
/* Delta */
14014+
/* Delta: d = (V - g*sk) * beta */
1401114015
float* d_vec = s->delta_dvec;
1401214016
for (int j = 0; j < dv; j++) {
14013-
d_vec[j] = beta_h * (vh[j] - sk[j]);
14017+
d_vec[j] = beta_h * (vh[j] - decay * sk[j]);
1401414018
}
1401514019

14016-
/* State update */
14020+
/* State update: S = g*S + K*d */
1401714021
for (int i = 0; i < dk; i++) {
1401814022
for (int j = 0; j < dv; j++) {
14019-
sh[i * dv + j] += kh[i] * d_vec[j];
14023+
sh[i * dv + j] = decay * sh[i * dv + j] + kh[i] * d_vec[j];
1402014024
}
1402114025
}
1402214026

14023-
/* Output */
14027+
/* Output: o = S @ Q */
1402414028
float* oh = s->delta_out + h * dv;
1402514029
for (int j = 0; j < dv; j++) {
1402614030
float sum = 0.0f;

0 commit comments

Comments
 (0)