Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public struct ConstrainedDecodingStrategy: DecodingStrategy {
constrainedOptions: InferenceOptions
) async throws -> (Int32?, [LogitsScalarType]?) {
var rawLogits: [LogitsScalarType]? = nil
for try await output in try inferenceEngine.generate(
for try await output in try await inferenceEngine.generate(
with: inputTokens,
samplingConfiguration: samplingConfiguration,
inferenceOptions: constrainedOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public struct VanillaDecodingStrategy: DecodingStrategy {
samplingConfiguration: SamplingConfiguration,
options: InferenceOptions,
stopSequences: StopSequences
) throws -> VanillaDecodedSequence {
) async throws -> VanillaDecodedSequence {
CLILogger.log("🔄 Starting vanilla decoding generation")

// Eager setup.
Expand All @@ -44,7 +44,7 @@ public struct VanillaDecodingStrategy: DecodingStrategy {
.map(Int32.init)
CLILogger.log("Input tokens: \(inputTokens.prefix(10))... (showing first 10)")

let stream = try inferenceEngine.generate(
let stream = try await inferenceEngine.generate(
with: inputTokens,
samplingConfiguration: samplingConfiguration,
inferenceOptions: options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
with input: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
) async throws -> GenerationSequence {
if inferenceOptions.includeLogits {
throw InferenceRuntimeError.invalidArgument(
"CoreAI pipelined engine does not support logits (GPU-side sampling). "
Expand All @@ -117,6 +117,14 @@ final class CoreAIPipelinedEngine: InferenceEngine, Sendable {
+ "Use a sequential engine for evaluation."
)
}

// Serialize: if a prior generation is still winding down (GPU drain),
// cancel it and wait for the engine slot to be released.
if let priorTask = _generationTask.withLock({ $0 }) {
_activeToken.withLock { $0?.cancel() }
await priorTask.value
}

let maxTokens = inferenceOptions.maxTokens
let stopReasonStore = StopReasonStore()
let (base, outputContinuation) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ public final class CoreAISequentialEngine: InferenceEngine, @unchecked Sendable
with input: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
) async throws -> GenerationSequence {
// Implicit prefix caching: resolve before creating Iterator.
// Implicit prefix caching: resolve input against history.
if history.count > 0 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ public final class CoreAISequentialVLMEngine: MultimodalInferenceEngine, @unchec
with input: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
) async throws -> GenerationSequence {
let token = GenerationToken()
_activeToken.withLock { $0 = token }
return GenerationSequence(
Expand All @@ -794,7 +794,7 @@ public final class CoreAISequentialVLMEngine: MultimodalInferenceEngine, @unchec
tokens: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
) async throws -> GenerationSequence {
let token = GenerationToken()
_activeToken.withLock { $0 = token }
return GenerationSequence(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ public final class StaticShapeEngine: InferenceEngine, @unchecked Sendable {
with input: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
) async throws -> GenerationSequence {
// Implicit prefix caching: resolve input against history.
if history.count > 0 {
let (commonPrefix, _) = history.resolve(input: input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public protocol InferenceEngine: Sendable {
with input: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> OutputSequence
) async throws -> OutputSequence

// MARK: - Lifecycle

Expand Down Expand Up @@ -311,7 +311,7 @@ public protocol MultimodalInferenceEngine: InferenceEngine {
tokens: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> OutputSequence
) async throws -> OutputSequence
}

// TODO: Multi-turn — caller can cache EmbeddedInput across turns and pass it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ public struct CoreAILanguageModel: LanguageModel {
maxTokens: Int,
channel: LanguageModelExecutorGenerationChannel
) async throws {
let tokenStream = try engine.generate(
let tokenStream = try await engine.generate(
with: promptTokens.map(Int32.init),
samplingConfiguration: samplingConfig,
inferenceOptions: InferenceOptions(maxTokens: maxTokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public class TextGenerator {
forcedContinuation: continuationTokens
)

let stream = try inferenceEngine.generate(
let stream = try await inferenceEngine.generate(
with: contextTokens,
samplingConfiguration: SamplingConfiguration.greedy,
inferenceOptions: options
Expand Down
2 changes: 1 addition & 1 deletion swift/Sources/Tools/benchmark/BenchmarkMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct LLMBenchmark: AsyncParsableCommand {

let options = InferenceOptions(maxTokens: generationTokens, includeLogits: false)
let start = SuspendingClock.now
let stream = try engine.generate(
let stream = try await engine.generate(
with: prompt, samplingConfiguration: sampling, inferenceOptions: options
)

Expand Down
2 changes: 1 addition & 1 deletion swift/Sources/Tools/llm-runner/LLMRunnerMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ struct LLMRunner: AsyncParsableCommand, Sendable {

await PerformanceMetrics.shared.setPromptTokenCount(vlmTokens.count)

let tokenStream = try vlmEngine.generate(
let tokenStream = try await vlmEngine.generate(
with: embeddedInput,
tokens: vlmTokens,
samplingConfiguration: samplingConfiguration,
Expand Down
85 changes: 81 additions & 4 deletions swift/Tests/LanguageModelsTests/CancelAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct CancelAPITests {
@Test("engine is busy during generation")
func busyDuringGeneration() async throws {
let engine = MockEngine(tokens: [10, 20, 30])
let stream = try engine.generate(
let stream = try await engine.generate(
with: [1],
samplingConfiguration: .greedy,
inferenceOptions: InferenceOptions(maxTokens: 100)
Expand All @@ -40,7 +40,7 @@ struct CancelAPITests {
@Test("cancel() stops generation and marks .cancelled")
func cancelStopsGeneration() async throws {
let engine = MockEngine(tokens: [10, 20, 30])
let stream = try engine.generate(
let stream = try await engine.generate(
with: [1],
samplingConfiguration: .greedy,
inferenceOptions: InferenceOptions(maxTokens: 100)
Expand All @@ -62,7 +62,7 @@ struct CancelAPITests {
@Test("engine becomes idle after generation completes naturally")
func idleAfterNaturalCompletion() async throws {
let engine = MockEngine(tokens: [10, 20, 30])
let stream = try engine.generate(
let stream = try await engine.generate(
with: [1],
samplingConfiguration: .greedy,
inferenceOptions: InferenceOptions(maxTokens: 3)
Expand All @@ -78,7 +78,7 @@ struct CancelAPITests {
@Test("reset() cancels active generation")
func resetCancelsGeneration() async throws {
let engine = MockEngine(tokens: [10, 20, 30])
let stream = try engine.generate(
let stream = try await engine.generate(
with: [1],
samplingConfiguration: .greedy,
inferenceOptions: InferenceOptions(maxTokens: 100)
Expand Down Expand Up @@ -117,4 +117,81 @@ struct CancelAPITests {
token.cancel()
#expect(token.isCancelled)
}

// MARK: - Back-to-back turn serialization

/// Starts a second generation right after abandoning the first one mid-stream.
@Test("back-to-back generate() calls do not crash")
func backToBackGenerate() async throws {
    let engine = MockEngine(tokens: [10, 20, 30, 40, 50])

    // First turn: walk away from the stream after two tokens, the way a
    // caller does when it breaks out of the loop on EOS.
    let firstTurn = try await engine.generate(
        with: [1, 2, 3],
        samplingConfiguration: .greedy,
        inferenceOptions: InferenceOptions(maxTokens: 5)
    )
    var consumedInFirstTurn = 0
    for try await _ in firstTurn {
        consumedInFirstTurn += 1
        guard consumedInFirstTurn < 2 else { break }
    }

    // Second turn: requested with no pause in between; it has to start
    // cleanly and be fully consumable.
    let secondTurn = try await engine.generate(
        with: [1, 2, 3, 10, 20, 4, 5],
        samplingConfiguration: .greedy,
        inferenceOptions: InferenceOptions(maxTokens: 3)
    )
    var consumedInSecondTurn = 0
    for try await _ in secondTurn {
        consumedInSecondTurn += 1
    }
    #expect(consumedInSecondTurn == 3)
}

/// Runs ten generations in a row on one engine, draining each stream fully
/// before requesting the next.
@Test("rapid-fire multi-turn stress (10 turns, no delay)")
func rapidFireMultiTurn() async throws {
    let engine = MockEngine(tokens: [10, 20, 30, 40, 50])
    let turnCount = 10

    for turnIndex in 0..<turnCount {
        // The prompt grows by one token per turn: [0], [0, 1], ..., [0, ..., 9].
        let growingPrompt = (0...turnIndex).map { Int32($0) }
        let turnStream = try await engine.generate(
            with: growingPrompt,
            samplingConfiguration: .greedy,
            inferenceOptions: InferenceOptions(maxTokens: 3)
        )

        var received: [Int32] = []
        for try await step in turnStream {
            received.append(step.tokenId)
        }

        // Only an upper bound is checked: no turn may exceed the maxTokens
        // value passed above.
        #expect(received.count <= 3)
    }
}

/// Cancels a generation whose stream was never iterated, then checks that a
/// fresh generation can be started and drained.
@Test("generate() after cancel() works cleanly")
func generateAfterCancel() async throws {
    let engine = MockEngine(tokens: [10, 20, 30])

    // Request a stream but deliberately never iterate it.
    let abandonedStream = try await engine.generate(
        with: [1],
        samplingConfiguration: .greedy,
        inferenceOptions: InferenceOptions(maxTokens: 100)
    )
    _ = abandonedStream

    try await engine.cancel()
    #expect(!engine.isBusy)

    // A new generation is requested straight after the cancel, with no delay.
    let followUpStream = try await engine.generate(
        with: [1, 2],
        samplingConfiguration: .greedy,
        inferenceOptions: InferenceOptions(maxTokens: 2)
    )
    var receivedCount = 0
    for try await _ in followUpStream {
        receivedCount += 1
    }
    #expect(receivedCount == 2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct GenerationStopReasonTests {
@Test("iterates all tokens and sets .maxTokens on normal completion")
func normalCompletion() async throws {
let engine = MockEngine(tokens: [1, 2, 3], maxContextLength: 100)
let stream = try engine.generate(
let stream = try await engine.generate(
with: [0],
samplingConfiguration: .greedy,
inferenceOptions: InferenceOptions(maxTokens: 3)
Expand All @@ -31,7 +31,7 @@ struct GenerationStopReasonTests {
func eosSetByDecoder() async throws {
let eosToken: Int32 = 99
let engine = MockEngine(tokens: [10, 20, eosToken, 40], maxContextLength: 100)
let stream = try engine.generate(
let stream = try await engine.generate(
with: [0],
samplingConfiguration: .greedy,
inferenceOptions: InferenceOptions(maxTokens: 10)
Expand Down
2 changes: 1 addition & 1 deletion swift/Tests/LanguageModelsTests/TestUtilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class MockEngine: InferenceEngine, @unchecked Sendable {
with input: [TokenId],
samplingConfiguration: SamplingConfiguration,
inferenceOptions: InferenceOptions
) throws -> GenerationSequence {
) async throws -> GenerationSequence {
let token = GenerationToken()
_activeToken.withLock { $0 = token }

Expand Down
Loading