Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion Sources/ContainerCommands/Container/ContainerRun.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ extension Application {
@OptionGroup(title: "Image fetch options")
var imageFetchFlags: Flags.ImageFetch

@OptionGroup(title: "GPU options")
var gpuFlags: Flags.GPU

@OptionGroup
public var logOptions: Flags.Logging

Expand Down Expand Up @@ -94,7 +97,7 @@ extension Application {
)
}

let ck = try await Utility.containerConfigFromFlags(
var ck = try await Utility.containerConfigFromFlags(
id: id,
image: image,
arguments: arguments,
Expand All @@ -107,6 +110,29 @@ extension Application {
log: log
)

// GPU support: inject vsock environment variables so the guest can
// reach the MLX Container Daemon on the host.
if gpuFlags.gpu {
    // Build the complete env list up front so it is appended in one place.
    var gpuEnv = [
        // CID 2 is the well-known vsock address of the host.
        "MLX_VSOCK_CID=2",
        "MLX_VSOCK_PORT=\(gpuFlags.gpuPort)",
        "MLX_GPU_ENABLED=1",
    ]
    if let model = gpuFlags.gpuModel {
        gpuEnv.append("MLX_GPU_MODEL=\(model)")
    }
    if let mem = gpuFlags.gpuMemory {
        gpuEnv.append("MLX_GPU_MEMORY=\(mem)")
    }
    if var config = ck.0 as? ContainerConfiguration {
        config.process.env.append(contentsOf: gpuEnv)
        // Write the mutated copy back. `config` is a copy (presumably a
        // value type, given the `var` binding — confirm), so without this
        // assignment every env var appended above is silently discarded;
        // `ck` was made `var` to permit exactly this write-back.
        ck.0 = config
        log.info("GPU enabled: vsock port \(gpuFlags.gpuPort)")
        if let model = gpuFlags.gpuModel {
            log.info("GPU model: \(model)")
        }
    } else {
        // Previously the cast failure was silent while "GPU enabled" was
        // still logged; surface it instead of claiming success.
        log.error("GPU requested but container configuration could not be updated; skipping env injection")
    }
}

progress.set(description: "Starting container")

let options = ContainerCreateOptions(autoRemove: managementFlags.remove)
Expand Down
49 changes: 49 additions & 0 deletions Sources/Services/ContainerAPIService/Client/Flags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -353,4 +353,53 @@ public struct Flags {
@Option(name: .long, help: "Maximum number of concurrent downloads (default: 3)")
public var maxConcurrentDownloads: Int = 3
}

/// GPU acceleration flags for MLX inference on Apple Silicon.
///
/// When `--gpu` is passed, the container runtime starts the MLX Container
/// Daemon on the host (if not already running) and injects vsock environment
/// variables into the guest VM. Code inside the container can then use the
/// `mlx-container` Python package to run Metal-accelerated inference without
/// installing MLX or Metal drivers in the Linux guest.
///
/// Requires: container-toolkit-mlx
/// https://github.com/RobotFlow-Labs/container-toolkit-mlx
public struct GPU: ParsableArguments {
/// Required by `ParsableArguments`; every stored property below has a
/// default, so argument parsing can construct the struct with no input.
public init() {}

/// Memberwise initializer for constructing the flags programmatically
/// (bypassing argument parsing), e.g. when another command forwards a
/// pre-built set of GPU options.
/// - Parameters:
///   - gpu: Whether GPU access is enabled.
///   - gpuMemory: Optional GPU memory budget in GB; `nil` leaves it unset.
///   - gpuModel: Optional HuggingFace model id to pre-load; `nil` leaves it unset.
///   - gpuMaxTokens: Maximum tokens per inference request.
///   - gpuPort: vsock port on which the GPU daemon listens.
public init(
gpu: Bool,
gpuMemory: UInt64?,
gpuModel: String?,
gpuMaxTokens: Int,
gpuPort: UInt32
) {
self.gpu = gpu
self.gpuMemory = gpuMemory
self.gpuModel = gpuModel
self.gpuMaxTokens = gpuMaxTokens
self.gpuPort = gpuPort
}

/// `--gpu`: master switch; the other options below only take effect when
/// this flag is set (presumably — confirm against the consuming command).
@Flag(name: .long, help: "Enable GPU access via the MLX Container Toolkit")
public var gpu: Bool = false

/// `--gpu-memory <gb>`: GPU memory budget in GB. `nil` when not passed;
/// per the help text, 0 means share all available memory.
@Option(
name: .customLong("gpu-memory"),
help: ArgumentHelp("GPU memory budget in GB (0 = share all available)", valueName: "gb")
)
public var gpuMemory: UInt64?

/// `--gpu-model <id>`: HuggingFace model id to pre-load. `nil` when not passed.
@Option(
name: .customLong("gpu-model"),
help: ArgumentHelp("HuggingFace model to pre-load (e.g. mlx-community/Llama-3.2-1B-4bit)", valueName: "id")
)
public var gpuModel: String?

/// `--gpu-max-tokens`: cap on tokens per inference request.
@Option(name: .customLong("gpu-max-tokens"), help: "Maximum tokens per inference request (default: 4096)")
public var gpuMaxTokens: Int = 4096

/// `--gpu-port`: vsock port the guest uses to reach the host GPU daemon.
@Option(name: .customLong("gpu-port"), help: "vsock port for the GPU daemon (default: 2048)")
public var gpuPort: UInt32 = 2048
}
}