From 67129df17c4fffe4c86b2b7ac260b3cbecbc4716 Mon Sep 17 00:00:00 2001 From: sukru tikves Date: Thu, 2 Jul 2026 13:08:28 -0700 Subject: [PATCH] Make generate() async to serialize back-to-back turns InferenceEngine.generate() is now async throws instead of throws. The pipelined engine awaits the prior generation Task before starting a new one, preventing the fatalError on rapid multi-turn conversations. The serialization preserves KV cache state -- prefix caching handles reuse automatically across turns. No data is lost; the engine just waits for the GPU pipeline to drain before restarting. Adds stress tests: back-to-back turns, rapid-fire 10-turn, and generate-after-cancel sequences. --- .../ConstrainedDecodingStrategy.swift | 2 +- .../VanillaDecodingStrategy.swift | 4 +- .../CoreAIPipelinedEngine.swift | 10 ++- .../CoreAISequentialEngine.swift | 2 +- .../CoreAISequentialVLMEngine.swift | 4 +- .../CoreAIStaticShapeEngine.swift | 2 +- .../InferenceEngines/InferenceEngine.swift | 4 +- .../LanguageModel/CoreAILanguageModel.swift | 2 +- .../TextGeneration/TextGenerator.swift | 2 +- .../Tools/benchmark/BenchmarkMain.swift | 2 +- .../Tools/llm-runner/LLMRunnerMain.swift | 2 +- .../LanguageModelsTests/CancelAPITests.swift | 85 ++++++++++++++++++- .../GenerationStopReasonTests.swift | 4 +- .../LanguageModelsTests/TestUtilities.swift | 2 +- .../UnifiedGenerationAPITests.swift | 62 +++++++------- 15 files changed, 137 insertions(+), 52 deletions(-) diff --git a/swift/Sources/CoreAILanguageModels/DecodingStrategies/ConstrainedDecodingStrategy.swift b/swift/Sources/CoreAILanguageModels/DecodingStrategies/ConstrainedDecodingStrategy.swift index 369c5a6..275f919 100644 --- a/swift/Sources/CoreAILanguageModels/DecodingStrategies/ConstrainedDecodingStrategy.swift +++ b/swift/Sources/CoreAILanguageModels/DecodingStrategies/ConstrainedDecodingStrategy.swift @@ -122,7 +122,7 @@ public struct ConstrainedDecodingStrategy: DecodingStrategy { constrainedOptions: InferenceOptions ) async throws -> (Int32?, [LogitsScalarType]?) { var rawLogits: [LogitsScalarType]? = nil - for try await output in try inferenceEngine.generate( + for try await output in try await inferenceEngine.generate( with: inputTokens, samplingConfiguration: samplingConfiguration, inferenceOptions: constrainedOptions diff --git a/swift/Sources/CoreAILanguageModels/DecodingStrategies/VanillaDecodingStrategy.swift b/swift/Sources/CoreAILanguageModels/DecodingStrategies/VanillaDecodingStrategy.swift index 015863a..9d40942 100644 --- a/swift/Sources/CoreAILanguageModels/DecodingStrategies/VanillaDecodingStrategy.swift +++ b/swift/Sources/CoreAILanguageModels/DecodingStrategies/VanillaDecodingStrategy.swift @@ -34,7 +34,7 @@ public struct VanillaDecodingStrategy: DecodingStrategy { samplingConfiguration: SamplingConfiguration, options: InferenceOptions, stopSequences: StopSequences - ) throws -> VanillaDecodedSequence { + ) async throws -> VanillaDecodedSequence { CLILogger.log("🔄 Starting vanilla decoding generation") // Eager setup. @@ -44,7 +44,7 @@ public struct VanillaDecodingStrategy: DecodingStrategy { .map(Int32.init) CLILogger.log("Input tokens: \(inputTokens.prefix(10))... (showing first 10)") - let stream = try inferenceEngine.generate( + let stream = try await inferenceEngine.generate( with: inputTokens, samplingConfiguration: samplingConfiguration, inferenceOptions: options diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift index 0d57887..1ee9839 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIPipelinedEngine.swift @@ -104,7 +104,7 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable { with input: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> GenerationSequence { + ) async throws -> GenerationSequence { if inferenceOptions.includeLogits { throw InferenceRuntimeError.invalidArgument( "CoreAI pipelined engine does not support logits (GPU-side sampling). " @@ -117,6 +117,14 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable { + "Use a sequential engine for evaluation." ) } + + // Serialize: if a prior generation is still winding down (GPU drain), + // cancel it and wait for the engine slot to be released. + if let priorTask = _generationTask.withLock({ $0 }) { + _activeToken.withLock { $0?.cancel() } + await priorTask.value + } + let maxTokens = inferenceOptions.maxTokens let stopReasonStore = StopReasonStore() let (base, outputContinuation) = diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift index e1a3816..620b9a8 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialEngine.swift @@ -348,7 +348,7 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable with input: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> GenerationSequence { + ) async throws -> GenerationSequence { // Implicit prefix caching: resolve before creating Iterator. // Implicit prefix caching: resolve input against history. if history.count > 0 { diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialVLMEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialVLMEngine.swift index b9d2795..ff85811 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialVLMEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAISequentialVLMEngine.swift @@ -769,7 +769,7 @@ public final class CoreAISequentialVLMEngine: MultimodalInferenceEngine, @unchec with input: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> GenerationSequence { + ) async throws -> GenerationSequence { let token = GenerationToken() _activeToken.withLock { $0 = token } return GenerationSequence( @@ -794,7 +794,7 @@ public final class CoreAISequentialVLMEngine: MultimodalInferenceEngine, @unchec tokens: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> GenerationSequence { + ) async throws -> GenerationSequence { let token = GenerationToken() _activeToken.withLock { $0 = token } return GenerationSequence( diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift index 36b39bc..5e9ccc0 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/CoreAIStaticShapeEngine.swift @@ -331,7 +331,7 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable { with input: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> GenerationSequence { + ) async throws -> GenerationSequence { // Implicit prefix caching: resolve input against history. if history.count > 0 { let (commonPrefix, _) = history.resolve(input: input) diff --git a/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift b/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift index 20cca65..f9d4161 100644 --- a/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift +++ b/swift/Sources/CoreAILanguageModels/InferenceEngines/InferenceEngine.swift @@ -102,7 +102,7 @@ public protocol InferenceEngine: Sendable { with input: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> OutputSequence + ) async throws -> OutputSequence // MARK: - Lifecycle @@ -311,7 +311,7 @@ public protocol MultimodalInferenceEngine: InferenceEngine { tokens: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> OutputSequence + ) async throws -> OutputSequence } // TODO: Multi-turn — caller can cache EmbeddedInput across turns and pass it diff --git a/swift/Sources/CoreAILanguageModels/LanguageModel/CoreAILanguageModel.swift b/swift/Sources/CoreAILanguageModels/LanguageModel/CoreAILanguageModel.swift index 238900b..3fd4ef2 100644 --- a/swift/Sources/CoreAILanguageModels/LanguageModel/CoreAILanguageModel.swift +++ b/swift/Sources/CoreAILanguageModels/LanguageModel/CoreAILanguageModel.swift @@ -335,7 +335,7 @@ public struct CoreAILanguageModel: LanguageModel { maxTokens: Int, channel: LanguageModelExecutorGenerationChannel ) async throws { - let tokenStream = try engine.generate( + let tokenStream = try await engine.generate( with: promptTokens.map(Int32.init), samplingConfiguration: samplingConfig, inferenceOptions: InferenceOptions(maxTokens: maxTokens) diff --git a/swift/Sources/CoreAILanguageModels/TextGeneration/TextGenerator.swift b/swift/Sources/CoreAILanguageModels/TextGeneration/TextGenerator.swift index 40bb48e..64f11a6 100644 --- a/swift/Sources/CoreAILanguageModels/TextGeneration/TextGenerator.swift +++ b/swift/Sources/CoreAILanguageModels/TextGeneration/TextGenerator.swift @@ -138,7 +138,7 @@ public class TextGenerator { forcedContinuation: continuationTokens ) - let stream = try inferenceEngine.generate( + let stream = try await inferenceEngine.generate( with: contextTokens, samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: options diff --git a/swift/Sources/Tools/benchmark/BenchmarkMain.swift b/swift/Sources/Tools/benchmark/BenchmarkMain.swift index 1e855e7..d4f18e7 100644 --- a/swift/Sources/Tools/benchmark/BenchmarkMain.swift +++ b/swift/Sources/Tools/benchmark/BenchmarkMain.swift @@ -128,7 +128,7 @@ struct LLMBenchmark: AsyncParsableCommand { let options = InferenceOptions(maxTokens: generationTokens, includeLogits: false) let start = SuspendingClock.now - let stream = try engine.generate( + let stream = try await engine.generate( with: prompt, samplingConfiguration: sampling, inferenceOptions: options ) diff --git a/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift b/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift index cff76fa..3b13638 100644 --- a/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift +++ b/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift @@ -891,7 +891,7 @@ struct LLMRunner: AsyncParsableCommand, Sendable { await PerformanceMetrics.shared.setPromptTokenCount(vlmTokens.count) - let tokenStream = try vlmEngine.generate( + let tokenStream = try await vlmEngine.generate( with: embeddedInput, tokens: vlmTokens, samplingConfiguration: samplingConfiguration, diff --git a/swift/Tests/LanguageModelsTests/CancelAPITests.swift b/swift/Tests/LanguageModelsTests/CancelAPITests.swift index ca60398..7f25f73 100644 --- a/swift/Tests/LanguageModelsTests/CancelAPITests.swift +++ b/swift/Tests/LanguageModelsTests/CancelAPITests.swift @@ -25,7 +25,7 @@ struct CancelAPITests { @Test("engine is busy during generation") func busyDuringGeneration() async throws { let engine = MockEngine(tokens: [10, 20, 30]) - let stream = try engine.generate( + let stream = try await engine.generate( with: [1], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 100) @@ -40,7 +40,7 @@ struct CancelAPITests { @Test("cancel() stops generation and marks .cancelled") func cancelStopsGeneration() async throws { let engine = MockEngine(tokens: [10, 20, 30]) - let stream = try engine.generate( + let stream = try await engine.generate( with: [1], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 100) @@ -62,7 +62,7 @@ struct CancelAPITests { @Test("engine becomes idle after generation completes naturally") func idleAfterNaturalCompletion() async throws { let engine = MockEngine(tokens: [10, 20, 30]) - let stream = try engine.generate( + let stream = try await engine.generate( with: [1], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 3) @@ -78,7 +78,7 @@ struct CancelAPITests { @Test("reset() cancels active generation") func resetCancelsGeneration() async throws { let engine = MockEngine(tokens: [10, 20, 30]) - let stream = try engine.generate( + let stream = try await engine.generate( with: [1], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 100) @@ -117,4 +117,81 @@ struct CancelAPITests { token.cancel() #expect(token.isCancelled) } + + // MARK: - Back-to-back turn serialization + + @Test("back-to-back generate() calls do not crash") + func backToBackGenerate() async throws { + let engine = MockEngine(tokens: [10, 20, 30, 40, 50]) + + // Turn 1: consume partially (simulates EOS break mid-stream) + let stream1 = try await engine.generate( + with: [1, 2, 3], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 5) + ) + var count1 = 0 + for try await _ in stream1 { + count1 += 1 + if count1 == 2 { break } + } + + // Turn 2: immediately start next generation — must not crash + let stream2 = try await engine.generate( + with: [1, 2, 3, 10, 20, 4, 5], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 3) + ) + var count2 = 0 + for try await _ in stream2 { + count2 += 1 + } + #expect(count2 == 3) + } + + @Test("rapid-fire multi-turn stress (10 turns, no delay)") + func rapidFireMultiTurn() async throws { + let engine = MockEngine(tokens: [10, 20, 30, 40, 50]) + + for turn in 0..<10 { + let prompt = Array(0..<(turn + 1)).map { Int32($0) } + let stream = try await engine.generate( + with: prompt, + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 3) + ) + var tokens: [Int32] = [] + for try await output in stream { + tokens.append(output.tokenId) + } + #expect(tokens.count <= 3) + } + } + + @Test("generate() after cancel() works cleanly") + func generateAfterCancel() async throws { + let engine = MockEngine(tokens: [10, 20, 30]) + + let stream1 = try await engine.generate( + with: [1], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 100) + ) + _ = stream1 // don't consume at all + + try await engine.cancel() + #expect(!engine.isBusy) + + // Should work immediately after cancel + let stream2 = try await engine.generate( + with: [1, 2], + samplingConfiguration: .greedy, + inferenceOptions: InferenceOptions(maxTokens: 2) + ) + var count = 0 + for try await _ in stream2 { + count += 1 + } + #expect(count == 2) + } } diff --git a/swift/Tests/LanguageModelsTests/GenerationStopReasonTests.swift b/swift/Tests/LanguageModelsTests/GenerationStopReasonTests.swift index b889981..4fbc139 100644 --- a/swift/Tests/LanguageModelsTests/GenerationStopReasonTests.swift +++ b/swift/Tests/LanguageModelsTests/GenerationStopReasonTests.swift @@ -12,7 +12,7 @@ struct GenerationStopReasonTests { @Test("iterates all tokens and sets .maxTokens on normal completion") func normalCompletion() async throws { let engine = MockEngine(tokens: [1, 2, 3], maxContextLength: 100) - let stream = try engine.generate( + let stream = try await engine.generate( with: [0], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 3) @@ -31,7 +31,7 @@ struct GenerationStopReasonTests { func eosSetByDecoder() async throws { let eosToken: Int32 = 99 let engine = MockEngine(tokens: [10, 20, eosToken, 40], maxContextLength: 100) - let stream = try engine.generate( + let stream = try await engine.generate( with: [0], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 10) diff --git a/swift/Tests/LanguageModelsTests/TestUtilities.swift b/swift/Tests/LanguageModelsTests/TestUtilities.swift index 1fcd925..801cd65 100644 --- a/swift/Tests/LanguageModelsTests/TestUtilities.swift +++ b/swift/Tests/LanguageModelsTests/TestUtilities.swift @@ -71,7 +71,7 @@ class MockEngine: InferenceEngine, @unchecked Sendable { with input: [TokenId], samplingConfiguration: SamplingConfiguration, inferenceOptions: InferenceOptions - ) throws -> GenerationSequence { + ) async throws -> GenerationSequence { let token = GenerationToken() _activeToken.withLock { $0 = token } diff --git a/swift/Tests/LanguageModelsTests/UnifiedGenerationAPITests.swift b/swift/Tests/LanguageModelsTests/UnifiedGenerationAPITests.swift index 12783de..81cc436 100644 --- a/swift/Tests/LanguageModelsTests/UnifiedGenerationAPITests.swift +++ b/swift/Tests/LanguageModelsTests/UnifiedGenerationAPITests.swift @@ -58,7 +58,7 @@ struct GenerateDefaultExtensionTests { var outputs: [InferenceOutput] = [] let generation = InferenceOptions(maxTokens: 5) - for try await output in try engine.generate( + for try await output in try await engine.generate( with: [1, 2, 3], samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: generation @@ -80,7 +80,7 @@ struct GenerateDefaultExtensionTests { let engine = MockEngine(tokens: [42]) let generation = InferenceOptions(maxTokens: 1, includeLogits: false) - for try await output in try engine.generate( + for try await output in try await engine.generate( with: [1], samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: generation @@ -94,7 +94,7 @@ struct GenerateDefaultExtensionTests { let engine = MockEngine(tokens: [42], vocabSize: 50) let generation = InferenceOptions(maxTokens: 1, includeLogits: true) - for try await output in try engine.generate( + for try await output in try await engine.generate( with: [1], samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: generation @@ -109,7 +109,7 @@ struct GenerateDefaultExtensionTests { let engine = MockEngine(tokens: [5], vocabSize: 10) let generation = InferenceOptions(maxTokens: 1, includeLogits: true) - for try await output in try engine.generate( + for try await output in try await engine.generate( with: [1], samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: generation @@ -129,7 +129,7 @@ struct GenerateDefaultExtensionTests { let engine = MockEngine(tokens: [42], vocabSize: nil) let generation = InferenceOptions(maxTokens: 1, includeLogits: true) - for try await output in try engine.generate( + for try await output in try await engine.generate( with: [1], samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: generation @@ -145,7 +145,7 @@ struct GenerateDefaultExtensionTests { var count = 0 let generation = InferenceOptions(maxTokens: 100) // Request way more than available - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], // 3 tokens prompt samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: generation @@ -162,7 +162,7 @@ struct GenerateDefaultExtensionTests { var count = 0 let generation = InferenceOptions() // nil maxTokens - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], // 3 tokens prompt samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: generation @@ -178,7 +178,7 @@ struct GenerateDefaultExtensionTests { let engine = MockEngine(tokens: [10]) // Generate a token to advance state - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1], samplingConfiguration: SamplingConfiguration.greedy, inferenceOptions: InferenceOptions(maxTokens: 1) @@ -205,7 +205,7 @@ struct GenerateMultiCallTests { // Simulate guided-generation pattern: call generate(maxTokens:1) repeatedly for _ in 0..<20 { var got: InferenceOutput? - for try await output in try engine.generate( + for try await output in try await engine.generate( with: tokens, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 1, includeLogits: true) @@ -227,7 +227,7 @@ struct GenerateMultiCallTests { // Turn 1 var count1 = 0 - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 5) @@ -241,7 +241,7 @@ struct GenerateMultiCallTests { // Turn 2 var count2 = 0 - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [4, 5, 6], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 3) @@ -257,7 +257,7 @@ struct GenerateMultiCallTests { let forced: [Int32] = [7, 8, 9] var outputs: [InferenceOutput] = [] - for try await output in try engine.generate( + for try await output in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions( @@ -278,7 +278,7 @@ struct GenerateMultiCallTests { let engine = MockEngine(tokens: [10, 20]) var count = 0 - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(forcedContinuation: []) @@ -346,7 +346,7 @@ struct PartialResetParityTests { // Generate the full reference sequence once try await engine.reset() var referenceTokens: [Int32] = [] - for try await output in try engine.generate( + for try await output in try await engine.generate( with: prompt, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: totalTokens) @@ -366,7 +366,7 @@ struct PartialResetParityTests { // --- Path A: Full reset + re-generate everything --- try await engine.reset() var fullTokens: [Int32] = [] - for try await output in try engine.generate( + for try await output in try await engine.generate( with: prompt, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: totalTokens) @@ -379,7 +379,7 @@ struct PartialResetParityTests { var partialTokens: [Int32] = [] // Generate first resetPoint tokens - for try await output in try engine.generate( + for try await output in try await engine.generate( with: prompt, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: resetPoint) @@ -395,7 +395,7 @@ struct PartialResetParityTests { // Continue generating the remaining tokens let remaining = totalTokens - resetPoint let continueInput = prompt + partialTokens // full context so far - for try await output in try engine.generate( + for try await output in try await engine.generate( with: continueInput, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: remaining) @@ -423,7 +423,7 @@ struct PartialResetParityTests { #expect(engine.processedTokenCount == 0) // After generating 10 tokens from a 5-token prompt - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: prompt, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 10) @@ -436,7 +436,7 @@ struct PartialResetParityTests { // Generate again — input matches cached prefix, so no new prefill let continueInput = prompt - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: continueInput, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 7) @@ -551,7 +551,7 @@ struct PrefixCachingTests { // First generation: prompt [1, 2, 3] var tokens: [Int32] = [1, 2, 3] - for try await output in try engine.generate( + for try await output in try await engine.generate( with: tokens, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 3) @@ -563,7 +563,7 @@ struct PrefixCachingTests { // Second generation with same prefix + new suffix: // Engine should detect prefix hit for the first 6 tokens - for try await output in try engine.generate( + for try await output in try await engine.generate( with: tokens, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 2) @@ -581,7 +581,7 @@ struct PrefixCachingTests { let engine = MockEngine(tokens: [10, 20, 30, 40, 50], maxContextLength: 100) // First generation: prompt [1, 2, 3] - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 3) @@ -592,7 +592,7 @@ struct PrefixCachingTests { // Second generation with divergent prefix: [1, 2, 99, ...] // Should auto-detect divergence at position 2 and rewind var tokens: [Int32] = [] - for try await output in try engine.generate( + for try await output in try await engine.generate( with: [1, 2, 99, 100], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 2) @@ -612,7 +612,7 @@ struct PrefixCachingTests { let engine = MockEngine(tokens: [10, 20, 30], maxContextLength: 100) // First generation - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 2) @@ -620,7 +620,7 @@ struct PrefixCachingTests { #expect(engine.processedTokenCount == 5) // Completely different input - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [99, 98, 97], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 1) @@ -637,7 +637,7 @@ struct PrefixCachingTests { // Turn 1: generate with prompt [1, 2, 3] var context: [Int32] = [1, 2, 3] - for try await output in try engine.generate( + for try await output in try await engine.generate( with: context, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 5) @@ -650,7 +650,7 @@ struct PrefixCachingTests { // Turn 2: append user message tokens, generate again // The first 8 tokens should be a prefix hit context.append(contentsOf: [77, 78, 79]) // new user message - for try await output in try engine.generate( + for try await output in try await engine.generate( with: context, samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 3) @@ -667,7 +667,7 @@ struct PrefixCachingTests { let engine = MockEngine(tokens: [10, 20], maxContextLength: 100) // Generate some tokens - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 2) @@ -680,7 +680,7 @@ struct PrefixCachingTests { #expect(engine.processedTokenCount == 0) // Next generation starts fresh - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 1) @@ -693,7 +693,7 @@ struct PrefixCachingTests { let engine = MockEngine(tokens: [10, 20, 30], maxContextLength: 100) // Generate: prompt [1, 2, 3] + 3 tokens = history of 6 - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 3) @@ -708,7 +708,7 @@ struct PrefixCachingTests { #expect(engine.history.tokens == [1, 2, 3]) // Re-generate with same prompt — should be a full prefix hit - for try await _ in try engine.generate( + for try await _ in try await engine.generate( with: [1, 2, 3], samplingConfiguration: .greedy, inferenceOptions: InferenceOptions(maxTokens: 2)