diff --git a/Sources/ContainerCommands/Container/ContainerRun.swift b/Sources/ContainerCommands/Container/ContainerRun.swift
index c83fbf790..081a6c1fd 100644
--- a/Sources/ContainerCommands/Container/ContainerRun.swift
+++ b/Sources/ContainerCommands/Container/ContainerRun.swift
@@ -51,6 +51,9 @@ extension Application {
         @OptionGroup(title: "Image fetch options")
         var imageFetchFlags: Flags.ImageFetch
 
+        @OptionGroup(title: "GPU options")
+        var gpuFlags: Flags.GPU
+
         @OptionGroup
         public var logOptions: Flags.Logging
 
@@ -94,7 +97,7 @@ extension Application {
                 )
             }
 
-            let ck = try await Utility.containerConfigFromFlags(
+            var ck = try await Utility.containerConfigFromFlags(
                 id: id,
                 image: image,
                 arguments: arguments,
@@ -107,6 +110,34 @@ extension Application {
                 log: log
             )
 
+            // GPU support: inject vsock environment variables so the guest can
+            // reach the MLX Container Daemon on the host.
+            if gpuFlags.gpu {
+                let gpuEnv = [
+                    "MLX_VSOCK_CID=2",  // CID 2 is VMADDR_CID_HOST: the guest dials the host
+                    "MLX_VSOCK_PORT=\(gpuFlags.gpuPort)",
+                    "MLX_GPU_ENABLED=1",
+                ]
+                if var config = ck.0 as? ContainerConfiguration {
+                    config.process.env.append(contentsOf: gpuEnv)
+                    if let model = gpuFlags.gpuModel {
+                        config.process.env.append("MLX_GPU_MODEL=\(model)")
+                    }
+                    if let mem = gpuFlags.gpuMemory {
+                        config.process.env.append("MLX_GPU_MEMORY=\(mem)")
+                    }
+                    // ContainerConfiguration has value semantics: `config` is a
+                    // copy, so the mutated copy must be written back or the env
+                    // injection above is silently discarded.
+                    ck.0 = config
+                } else {
+                    // NOTE(review): the cast failing means GPU env vars were NOT
+                    // injected; surface that instead of logging "GPU enabled" only.
+                    log.warning("--gpu requested but container configuration could not be updated")
+                }
+                log.info("GPU enabled: vsock port \(gpuFlags.gpuPort)")
+                if let model = gpuFlags.gpuModel {
+                    log.info("GPU model: \(model)")
+                }
+            }
+
             progress.set(description: "Starting container")
 
             let options = ContainerCreateOptions(autoRemove: managementFlags.remove)
diff --git a/Sources/Services/ContainerAPIService/Client/Flags.swift b/Sources/Services/ContainerAPIService/Client/Flags.swift
index 88de209f9..983047ea3 100644
--- a/Sources/Services/ContainerAPIService/Client/Flags.swift
+++ b/Sources/Services/ContainerAPIService/Client/Flags.swift
@@ -353,4 +353,53 @@ public struct Flags {
         @Option(name: .long, help: "Maximum number of concurrent downloads (default: 3)")
         public var maxConcurrentDownloads: Int = 3
     }
+
+    /// GPU acceleration flags for MLX inference on Apple Silicon.
+    ///
+    /// When `--gpu` is passed, the container runtime starts the MLX Container
+    /// Daemon on the host (if not already running) and injects vsock environment
+    /// variables into the guest VM. Code inside the container can then use the
+    /// `mlx-container` Python package to run Metal-accelerated inference without
+    /// installing MLX or Metal drivers in the Linux guest.
+    ///
+    /// Requires: container-toolkit-mlx
+    /// https://github.com/RobotFlow-Labs/container-toolkit-mlx
+    public struct GPU: ParsableArguments {
+        public init() {}
+
+        public init(
+            gpu: Bool,
+            gpuMemory: UInt64?,
+            gpuModel: String?,
+            gpuMaxTokens: Int,
+            gpuPort: UInt32
+        ) {
+            self.gpu = gpu
+            self.gpuMemory = gpuMemory
+            self.gpuModel = gpuModel
+            self.gpuMaxTokens = gpuMaxTokens
+            self.gpuPort = gpuPort
+        }
+
+        @Flag(name: .long, help: "Enable GPU access via the MLX Container Toolkit")
+        public var gpu: Bool = false
+
+        @Option(
+            name: .customLong("gpu-memory"),
+            help: ArgumentHelp("GPU memory budget in GB (0 = share all available)", valueName: "gb")
+        )
+        public var gpuMemory: UInt64?
+
+        @Option(
+            name: .customLong("gpu-model"),
+            help: ArgumentHelp("HuggingFace model to pre-load (e.g. mlx-community/Llama-3.2-1B-4bit)", valueName: "id")
+        )
+        public var gpuModel: String?
+
+        @Option(name: .customLong("gpu-max-tokens"), help: "Maximum tokens per inference request (default: 4096)")
+        public var gpuMaxTokens: Int = 4096
+
+        @Option(name: .customLong("gpu-port"), help: "vsock port for the GPU daemon (default: 2048)")
+        public var gpuPort: UInt32 = 2048
+    }
 }