Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ public final class CoreAISequentialVLMEngine: MultimodalInferenceEngine, @unchec

CLILogger.log("VLM encodeImage complete: \(tokenCount) embedding tokens")

return EmbeddedInput(
return try EmbeddedInput(
embeddings: projectedEmbeddings,
embeddingPositions: placeholderRange
)
Expand Down Expand Up @@ -556,8 +556,8 @@ public final class CoreAISequentialVLMEngine: MultimodalInferenceEngine, @unchec
+ "expected \(imageTokenCount) from config. Check prompt template.")
}

let seqLen = textEmbeddings.shape.count >= 2 ? textEmbeddings.shape[1] : 0
let imgSeqLen = imageEmbeddings.shape.count >= 2 ? imageEmbeddings.shape[1] : 0
let seqLen = textEmbeddings.shape[1]
let imgSeqLen = imageEmbeddings.shape[1]
guard imgSeqLen >= imageTokenCount else {
throw InferenceRuntimeError.invalidArgument(
"scatterMerge: image embeddings have \(imgSeqLen) tokens, need \(imageTokenCount)")
Expand All @@ -570,10 +570,10 @@ public final class CoreAISequentialVLMEngine: MultimodalInferenceEngine, @unchec
}

// Copy image embeddings into placeholder positions.
precondition(
Comment thread
stikves marked this conversation as resolved.
imageEmbeddings.scalarType == .float16,
"scatterMerge only supports float16 embeddings; got \(imageEmbeddings.scalarType)"
)
guard imageEmbeddings.scalarType == .float16 else {
throw InferenceRuntimeError.invalidInputType(
"scatterMerge only supports float16 embeddings; got \(imageEmbeddings.scalarType)")
}
imageEmbeddings.view(as: Float16.self).withUnsafePointer { imgPtr, _, _ in
var mutableView = merged.mutableView(as: Float16.self)
mutableView.withUnsafeMutablePointer { mergedPtr, _, _ in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,25 @@ import Foundation
/// language model. The engine performs scatter-merge: replacing placeholder
/// token positions with these embeddings before the first forward pass.
public struct EmbeddedInput: Sendable {
/// The embedding tensor, typically shape [1, seq_len, hidden_dim].
/// The embedding tensor, shape [batch, seq_len, hidden_dim].
/// Scalar type matches the LLM's expected input (float16, bFloat16, etc.).
public let embeddings: NDArray

/// Positions in the token sequence where embeddings replace placeholders.
public let embeddingPositions: Range<Int>

public init(embeddings: NDArray, embeddingPositions: Range<Int>) {
public init(embeddings: NDArray, embeddingPositions: Range<Int>) throws {
guard embeddings.shape.count == 3 else {
throw InferenceRuntimeError.invalidArgument(
"EmbeddedInput requires 3D embeddings [batch, seq_len, hidden_dim], "
+ "got shape with \(embeddings.shape.count) dimensions")
}
self.embeddings = embeddings
self.embeddingPositions = embeddingPositions
}

/// Number of embedding tokens, i.e. the size of the seq_len dimension (`embeddings.shape[1]`).
public var tokenCount: Int {
embeddings.shape.count >= 2 ? embeddings.shape[1] : 0
}
public var tokenCount: Int { embeddings.shape[1] }

// TODO: Multi-turn support — allow multiple image regions per input,
// persistent across generate() calls (keep in KV cache on reset).
Expand Down
4 changes: 1 addition & 3 deletions swift/Sources/Tools/llm-runner/LLMRunnerMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,6 @@ struct LLMRunner: AsyncParsableCommand, Sendable {
@Option(name: .customLong("image"), help: "Path to an image file for vision-language models")
var imagePath: String?

@Option(
help: "Maximum tiles for image splitting (overrides model config). 1 = single crop, no tiling.")

@Flag(help: "Enable verbose logging")
var verbose: Bool = false

Expand Down Expand Up @@ -374,6 +371,7 @@ struct LLMRunner: AsyncParsableCommand, Sendable {
)
let vlmConfig = VLMModelConfig(base: baseConfig, visionConfig: visionConfig)

// Prepare the models one at a time: preparing them concurrently causes runtime errors.
Comment thread
stikves marked this conversation as resolved.
let visionModel = try await PreparedModel.prepare(at: visionURL)
let embedModel = try await PreparedModel.prepare(at: embedURL)
let llmModel = try await PreparedModel.prepare(at: mainURL)
Expand Down