diff --git a/Libraries/MLXLLM/Models/Gemma4Text.swift b/Libraries/MLXLLM/Models/Gemma4Text.swift index 8afbba152..babf1d483 100644 --- a/Libraries/MLXLLM/Models/Gemma4Text.swift +++ b/Libraries/MLXLLM/Models/Gemma4Text.swift @@ -990,7 +990,7 @@ extension Gemma4TextModel: LoRAModel { // MARK: - Assistant -public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimensionProvider { +public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, MTPPartialRollback, KVCacheDimensionProvider { public let vocabularySize: Int public let kvHeads: [Int] @@ -1016,6 +1016,16 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens // Reference to the main model so we can call it inside callMTP public var mainModelRef: (any BaseLanguageModel)? = nil + /// Full [B, S, D] backbone hidden state from the most recent callMTP verification pass. + /// Stored so MTPTokenIterator can extract the hidden state at the accepted position + /// for partial rollback (re-seeding the MTP head without re-running the main model). + public var lastBackboneHiddenStateAll: MLXArray? = nil + + /// Number of draft tokens to produce per MTP head call. + /// depth=2: each pass costs 24% overhead (2 × ~12% per assistant layer pass at 40K). + /// depth=4: costs 48% overhead — empirically worse due to Metal kernel launch cost per depth. + public var numMTPDraftTokens: Int = 2 + public init(_ fullConfig: Gemma4Configuration) { let config = fullConfig.textConfig self.config = config @@ -1116,7 +1126,7 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens // Use mlx scatter via the __setitem__ approach: let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32) let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates]) - var output2D = output.reshaped([B * S, vocabSize]) + let output2D = output.reshaped([B * S, vocabSize]) let rowIndices = MLXArray.arange(B * S).asType(.int32).reshaped([B * S, 1]) output2D[rowIndices, scatterIdx2D] = selectedLogits2D output = output2D.reshaped([B, S, vocabSize]) @@ -1134,86 +1144,53 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens return model.embedTokens.asLinear(h) } - public func callMTP(_ inputs: MLXArray, cache: [KVCache]?, mtpCaches: [[KVCache]]?) -> [MLXArray] { - guard let mainModel = mainModelRef else { - fatalError("mainModelRef must be set on Gemma4AssistantModel before calling callMTP") - } - - let posOffset = cache?.first.map { gemma4CapturePositionOffset(from: $0) } - - // 1. Run the main model to get main logits and backbone hidden state - guard let llmMain = mainModel as? any LLMModel else { - fatalError("mainModelRef must be an LLMModel") - } - let mainLogits = llmMain(inputs, cache: cache) - - // Extract the NORMALIZED hidden state from the backbone - var hBackbone: MLXArray - if let g4m = mainModel as? Gemma4Model, let lhs = g4m.lastHiddenState { - hBackbone = lhs - } else if let g4tm = mainModel as? Gemma4TextModel, let lhs = g4tm.lastHiddenState { - hBackbone = lhs - } else { - fatalError("[MTP] Could not extract normalized hidden state from main model") - } - - var allLogits = [mainLogits] - - // pre_projection: [256, 3072] — expects concat(hBackbone, embedToken) both 1536-dim → 3072 - // post_projection: [1536, 256] — maps assistant 256-dim state back to 1536 backbone dim - - // For depth=0, we don't have a draft token yet — we use the LAST token from inputs as the "current" token. - // hBackbone[..., -1:, ...] is the hidden state after the last real token. - // We embed the last input token to form the first concatenation. - let backboneDim = hBackbone.dim(-1) // 1536 - - // Get the last hidden state (the one that will predict the next token) - let seqLen = hBackbone.dim(1) - var hLast = hBackbone[0..., (seqLen-1).. PrepareResult { + guard let mainModel = mainModelRef as? any LLMModel else { + // mainModelRef not set yet — fall through to token-by-token (no prefill cache warming) + return .tokens(input.text) } + return try mainModel.prepare(input, cache: cache, windowSize: windowSize) + } - // Run as many depth iterations as needed for numDraftTokens + 1 (the accepted token's head) - // For numDraft=2 we need 2 MTP heads (depth 0 and 1 give us draft 1 and draft 2). - // Running only what we need avoids extra compute. - let mtpDepth = (mtpCaches?.count ?? 0) + 2 // fallback: 2 depths for 2 draft tokens - - for _ in 0 ..< mtpDepth { - // Step A: Concatenate token embedding + backbone hidden state → [B, 1, 3072] - // HF does torch.cat([last_token_embedding, last_hidden_state], dim=-1) - let hConcat = concatenated([eEmbed, hLast], axis: -1) // [B, 1, 3072] - - // Step B: Pre-projection → [B, 1, 256] + // MARK: - MTP Head Loop (shared by callMTP and callMTPHeadOnly) + + /// Run the iterative MTP head loop. + /// - Parameters: + /// - hLast: [B, 1, backboneDim] — initial backbone hidden state + /// - eEmbed: [B, 1, backboneDim] — embedding of the first "next" token + /// - posOffset: fixed position offset for assistant RoPE + /// - backboneDim: dimension of backbone hidden state + /// - cache: main model KV cache (for cross-attention in assistant layers) + /// - depth: how many MTP outputs to produce + /// - Returns: [depth-0 logits, depth-1 logits, ...] each [B, 1, V] + private func runMTPHead( + hLast hLastIn: MLXArray, + eEmbed eEmbedIn: MLXArray, + posOffset: Gemma4PositionOffset, + backboneDim: Int, + cache: [KVCache]?, + depth: Int + ) -> [MLXArray] { + var hLast = hLastIn + var eEmbed = eEmbedIn + var results = [MLXArray]() + + for _ in 0 ..< depth { + let hConcat = concatenated([eEmbed, hLast], axis: -1) var hAssistant: MLXArray - if let preProjWeight = preProjectionWeight { - hAssistant = matmul(hConcat, preProjWeight.T) // [B, 1, 256] + if let w = preProjectionWeight { + hAssistant = matmul(hConcat, w.T) } else { hAssistant = hConcat if hAssistant.dim(-1) != config.hiddenSize { @@ -1221,94 +1198,192 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens } } - // Step C: Run all 4 assistant transformer layers for i in 0 ..< config.numHiddenLayers { let layer = model.layers[i] - - // Pass main model KV cache as sharedKV for cross-attention var sharedKV: (MLXArray, MLXArray)? = nil if let fullCache = cache { let layerType = model.layers[i].layerType - // Assistant layers attend to the main model's last SWA or FA cache - // Full-attention layers use the last full-attention cache; SWA uses last SWA cache let mainIdx = layerType == "sliding_attention" ? fullCache.count - 2 : fullCache.count - 1 if mainIdx >= 0 { + // Cap shared-KV cross-attention to the last N backbone positions. + // The backbone hLast already encodes the full history; the assistant + // only needs local conditioning. Capping to 16 positions reduces + // cross-attention bandwidth from O(T) → O(16) at long contexts, + // eliminating the 2× slowdown at 40K–100K without hurting short-ctx. + let maxSharedKV = 16 let cacheElement = fullCache[mainIdx] if let c = cacheElement as? KVCacheSimple, let k = c.keys, let v = c.values { - // Slice to valid offset (avoid zero-padded buffer positions) - let validK = k[0..., 0..., 0..