Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 203 additions & 128 deletions Libraries/MLXLLM/Models/Gemma4Text.swift

Large diffs are not rendered by default.

77 changes: 71 additions & 6 deletions Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,14 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
// Logits from the previous step's MTP heads
var mtpLogits: [MLXArray]?

// Partial rollback state (llama.cpp PR #22673 style).
// After accepting k of N drafts, the backbone hidden state at position k is stored here.
// On the NEXT cold-start round (mtpLogits=nil), callMTPHeadOnly seeds one draft from this
// state — turning a zero-draft round into a one-draft round without an extra main-model pass.
private var rollbackH: MLXArray? = nil // [B, 1, D] backbone state at accepted pos
private var rollbackToken: MLXArray? = nil // [1, 1] int32 — the output token (x_{k+1})
private var rollbackPosOffset: Int = 0 // sequence position of the accepted token

// Buffer of accepted tokens from the current speculation round
private var pendingTokens = [Int]()
private var pendingIndex = 0
Expand Down Expand Up @@ -1187,7 +1195,36 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
}
}

// If no draft tokens were generated (e.g. first step), fallback to regular generation
// Partial rollback (llama.cpp PR #22673 style): after a partial accept, use the stored
// backbone hidden state to seed one draft via callMTPHeadOnly, then verify it.
// Empirically at 40K: rollback-ON=28.5 tok/s vs rollback-OFF=21.1 tok/s (+35%).
// The rollback-seeded 2-token verify replaces a cold-start callMTP round — despite
// cascading 44% rejection, the shorter verify batch (2 vs 5 tokens) recovers faster.
if draftTokens.isEmpty {
if let rH = rollbackH,
let rTok = rollbackToken,
let assistantModel = model as? any MTPPartialRollback {
rollbackH = nil
rollbackToken = nil
// depth=1: only depth-0 is well-conditioned from h_k.
// depth>1 chains MTP greedy argmax, compounding context misalignment
// vs the trimmed KV cache — empirically caused -31% TPS at 40K with depth=4.
let depth = 1
let headLogits = assistantModel.callMTPHeadOnly(
rH, nextToken: rTok, cache: cache, posOffset: rollbackPosOffset, mtpDepth: depth)
if !headLogits.isEmpty {
var draftProcessor = processor
let draftLogit = headLogits[0][0..., 0, 0...] // [B, V]
var dl = draftProcessor?.process(logits: draftLogit) ?? draftLogit
let draftToken = sampler.sample(logits: dl)
draftProcessor?.didSample(token: draftToken)
draftTokens.append(draftToken)
draftProcessedLogits.append(dl)
}
}
}


if draftTokens.isEmpty {
let mtpResult = model.callMTP(y.tokens[.newAxis], cache: cache, mtpCaches: mtpCaches)
guard !mtpResult.isEmpty else { return }
Expand Down Expand Up @@ -1353,19 +1390,47 @@ public struct MTPTokenIterator: TokenIteratorProtocol {
// Set y for the next round
y = .init(tokens: finalTokenOut)

// Update mtpLogits from the verification pass for the NEXT speculation round.
// mtpResult[1..N] contains the MTP head outputs for each depth.
// Each head output is [B, 1, vocab] — extract directly (no position indexing needed).
// Only keep them if ALL drafts were accepted, otherwise they are invalid due to cache rewind.
// Capture partial rollback state (llama.cpp PR #22673 / pending_h approach).
// When k < N drafts are accepted:
// hBackbone[:, verifyStart+k, :] = hidden state after the k-th accepted token
// finalTokenOut = x_{k+1}, the bonus token being output this round
// On the next cold-start round, callMTPHeadOnly will use this state to seed
// one draft (predicting x_{k+2}) without re-running the main model.
if accepted < draftTokens.count,
let assistantModel = model as? any MTPPartialRollback,
let allH = assistantModel.lastBackboneHiddenStateAll {
let seedPos = verifyStart + accepted
if seedPos < allH.dim(1) {
rollbackH = allH[0..., seedPos..<(seedPos + 1), 0...] // [B, 1, D]
rollbackToken = finalTokenOut.flattened().reshaped([1, 1]) // [B=1, 1] — flatten first to handle 0-D scalars
// posOffset = current KV cache length after trim (= position of finalTokenOut)
rollbackPosOffset = cache.first.map {
if let c = $0 as? KVCacheSimple { return c.offset }
if let c = $0 as? RotatingKVCache { return c.offset }
return 0
} ?? 0
}
} else {
// All accepted or rollback not available — clear any stale state
rollbackH = nil
rollbackToken = nil
}

// Update mtpLogits for the NEXT speculation round.
// Only valid to reuse when ALL drafts accepted: the MTP head ran from the last position
// which matches the trimmed-cache state. On partial accept the head ran from a stale
// position (rejected tokens still in context at inference time) — stale logits hurt.
// Partial-accept cold-start is handled by the rollback path (rollbackH) above.
if accepted == draftTokens.count && mtpResult.count > 1 {
self.mtpLogits = mtpResult.dropFirst().map { headLogits in
// headLogits shape: [B, 1, vocab] — squeeze to [B, vocab] for the sampler
headLogits[0..., headLogits.dim(1) - 1, 0...]
}
} else {
self.mtpLogits = nil
}



// Force evaluation of MTP state to prevent graph collapse
var evalArrays = [mainTokens] + draftTokens
if let mtpLogits = self.mtpLogits { evalArrays.append(contentsOf: mtpLogits) }
Expand Down
18 changes: 18 additions & 0 deletions Libraries/MLXLMCommon/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,24 @@ public protocol DualModelMTP: MTPLanguageModel {
var mainModelRef: (any BaseLanguageModel)? { get set }
}

/// Capability protocol for MTP models that can perform a partial rollback
/// (in the style of llama.cpp PR #22673).
///
/// When only k of N drafts are accepted, a conforming model can run just its MTP head
/// from a stored backbone hidden state, producing a draft without re-running the
/// full main model.
public protocol MTPPartialRollback: MTPLanguageModel {
    /// Backbone hidden state covering the most recent `callMTP` pass, shaped [B, S, D].
    var lastBackboneHiddenStateAll: MLXArray? { get }

    /// Runs only the MTP head, seeded from a stored backbone hidden state.
    ///
    /// - Parameters:
    ///   - h: Backbone hidden state at the accepted position, shaped [B, 1, D].
    ///   - nextToken: The output token (x_{k+1}) as int32, shaped [B, 1].
    ///   - cache: Main-model KV cache, post-trim, used for cross-attention.
    ///   - posOffset: Sequence position of the accepted token.
    ///   - mtpDepth: Number of draft logits to produce.
    /// - Returns: Draft logits ordered from depth 0, each shaped [B, 1, V]. The result
    ///   carries no main-logits prefix.
    func callMTPHeadOnly(
        _ h: MLXArray,
        nextToken: MLXArray,
        cache: [KVCache]?,
        posOffset: Int,
        mtpDepth: Int
    ) -> [MLXArray]
}

extension MTPLanguageModel {
/// Default: call the two-argument overload with no MTP caches.
/// Models that don't override `makeMTPCaches` get a zero-element array.
Expand Down
2 changes: 0 additions & 2 deletions Libraries/MLXLMCommon/Load.swift
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,10 @@ public func loadWeights(
let allPrefixes = ["", "model.", "language_model.", "model.language_model."]
let candidates = [expert0Name, stripped0Name, strippedMtpName] + allPrefixes.map { $0 + stripped0Name } + allPrefixes.map { $0 + strippedMtpName }
var foundUnstacked = false
var matchedCandidate = ""

for candidate in candidates {
if ExpertStreamerManager.shared?.getFile(for: candidate) != nil {
foundUnstacked = true
matchedCandidate = candidate
var map = [Int: (path: String, tensorName: String)]()
for i in 0 ..< sl.numExperts {
let c = candidate.replacingOccurrences(of: ".experts.0.", with: ".experts.\(i).")
Expand Down
5 changes: 1 addition & 4 deletions Libraries/MLXLMCommon/SwitchLayers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
var outShape = x.shape
outShape[outShape.count - 1] = downProj.outputDims
let result = MLXArray.zeros(outShape).asType(.float16)
if doSort {
return MLX.squeezed(scatterUnsort(x: result, invOrder: inverseOrder, shape: indices.shape), axis: -2)
}
return MLX.squeezed(result, axis: -2)
return MLX.squeezed(scatterUnsort(x: result, invOrder: inverseOrder, shape: indices.shape), axis: -2)
}

// Parse routing — `idx.asArray()` is the actual sync point on GPU.
Expand Down
7 changes: 7 additions & 0 deletions test_array_init.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import Foundation
import MLX
// Scratch script: build an MLXArray from an integer range and print it.

// Cap the MLX GPU cache limit at 10 * 1024 * 1024 before creating any arrays.
let cacheLimit = 10 * 1024 * 1024
MLX.GPU.set(cacheLimit: cacheLimit)

// Construct the array from the half-open range 0..<10, convert it to int32, and print it.
let elementCount = 10
let values = MLXArray(0 ..< elementCount).asType(.int32)
print(values)
Comment on lines +1 to +7
13 changes: 13 additions & 0 deletions test_scatter.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import Foundation
import MLX

// Scratch script: assign values into a zero-filled array through a pair of index arrays.

// Cap the MLX GPU cache limit at 10 * 1024 * 1024 before creating any arrays.
MLX.GPU.set(cacheLimit: 10 * 1024 * 1024)

// Zero-filled [4, 10] destination; declared `var` because it is assigned into below.
var target = MLXArray.zeros([4, 10])

// Index arrays: row indices reshaped to [4, 1], column indices reshaped to [4, 2].
let rowIndices = MLXArray(0 ..< Int32(4)).reshaped([4, 1])
let columnIndices = MLXArray([1, 2, 0, 4, 3, 5, 2, 9]).reshaped([4, 2])

// Values to write, reshaped to [4, 2] to match the column index array.
let updates = MLXArray([10, 20, 30, 40, 50, 60, 70, 80]).reshaped([4, 2])

// Assign through the two index arrays, force evaluation, then print the result.
target[rowIndices, columnIndices] = updates
MLX.eval(target)
print(target)
Comment on lines +3 to +13
Loading