From 6f1950c4c9c50ded402f75922a678c97be72d710 Mon Sep 17 00:00:00 2001 From: AIFlow_ML Date: Sun, 15 Mar 2026 11:08:54 +0100 Subject: [PATCH] Add --gpu flag for Metal/MLX inference in containers via vsock Adds GPU acceleration support for Linux containers on Apple Silicon through the MLX Container Toolkit. When --gpu is passed to `container run`, the runtime injects vsock environment variables into the guest VM, enabling code inside the container to access the host's Metal GPU for ML inference. Architecture: Container (Linux VM) --[gRPC over vsock]--> Host daemon (MLX/Metal) The host-side daemon (mlx-container-daemon) manages model loading and serves inference requests over the same vsock channel that vminitd already uses for container management. No GPU drivers or Metal frameworks are needed inside the Linux guest. New flags on `container run`: --gpu Enable GPU access --gpu-model Pre-load a HuggingFace model --gpu-memory GPU memory budget --gpu-max-tokens Max tokens per request --gpu-port vsock port (default: 2048) Example: container run --gpu --gpu-model mlx-community/Llama-3.2-1B-4bit \ ubuntu:latest python3 -c \ "from mlx_container import generate; print(generate('Hello', model='mlx-community/Llama-3.2-1B-4bit').text)" Requires: https://github.com/RobotFlow-Labs/container-toolkit-mlx Signed-off-by: ilessio --- .../Container/ContainerRun.swift | 28 ++++++++++- .../ContainerAPIService/Client/Flags.swift | 49 +++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/Sources/ContainerCommands/Container/ContainerRun.swift b/Sources/ContainerCommands/Container/ContainerRun.swift index c83fbf790..081a6c1fd 100644 --- a/Sources/ContainerCommands/Container/ContainerRun.swift +++ b/Sources/ContainerCommands/Container/ContainerRun.swift @@ -51,6 +51,9 @@ extension Application { @OptionGroup(title: "Image fetch options") var imageFetchFlags: Flags.ImageFetch + @OptionGroup(title: "GPU options") + var gpuFlags: Flags.GPU + @OptionGroup public var 
logOptions: Flags.Logging @@ -94,7 +97,7 @@ extension Application { ) } - let ck = try await Utility.containerConfigFromFlags( + var ck = try await Utility.containerConfigFromFlags( id: id, image: image, arguments: arguments, @@ -107,6 +110,30 @@ extension Application { log: log ) + // GPU support: inject vsock environment variables so the guest can + // reach the MLX Container Daemon on the host. + if gpuFlags.gpu { + let gpuEnv = [ + "MLX_VSOCK_CID=2", + "MLX_VSOCK_PORT=\(gpuFlags.gpuPort)", + "MLX_GPU_ENABLED=1", + ] + if var config = ck.0 as? ContainerConfiguration { + config.process.env.append(contentsOf: gpuEnv) + if let model = gpuFlags.gpuModel { + config.process.env.append("MLX_GPU_MODEL=\(model)") + } + if let mem = gpuFlags.gpuMemory { + config.process.env.append("MLX_GPU_MEMORY=\(mem)") + } + ck.0 = config // write back: `config` is a mutated value copy; without this the env vars are dropped + } + log.info("GPU enabled: vsock port \(gpuFlags.gpuPort)") + if let model = gpuFlags.gpuModel { + log.info("GPU model: \(model)") + } + } + progress.set(description: "Starting container") let options = ContainerCreateOptions(autoRemove: managementFlags.remove) diff --git a/Sources/Services/ContainerAPIService/Client/Flags.swift b/Sources/Services/ContainerAPIService/Client/Flags.swift index 88de209f9..983047ea3 100644 --- a/Sources/Services/ContainerAPIService/Client/Flags.swift +++ b/Sources/Services/ContainerAPIService/Client/Flags.swift @@ -353,4 +353,53 @@ public struct Flags { @Option(name: .long, help: "Maximum number of concurrent downloads (default: 3)") public var maxConcurrentDownloads: Int = 3 } + + /// GPU acceleration flags for MLX inference on Apple Silicon. + /// + /// When `--gpu` is passed, the container runtime starts the MLX Container + /// Daemon on the host (if not already running) and injects vsock environment + /// variables into the guest VM. Code inside the container can then use the + /// `mlx-container` Python package to run Metal-accelerated inference without + /// installing MLX or Metal drivers in the Linux guest.
+ /// + /// Requires: container-toolkit-mlx + /// https://github.com/RobotFlow-Labs/container-toolkit-mlx + public struct GPU: ParsableArguments { + public init() {} + + public init( + gpu: Bool, + gpuMemory: UInt64?, + gpuModel: String?, + gpuMaxTokens: Int, + gpuPort: UInt32 + ) { + self.gpu = gpu + self.gpuMemory = gpuMemory + self.gpuModel = gpuModel + self.gpuMaxTokens = gpuMaxTokens + self.gpuPort = gpuPort + } + + @Flag(name: .long, help: "Enable GPU access via the MLX Container Toolkit") + public var gpu: Bool = false + + @Option( + name: .customLong("gpu-memory"), + help: ArgumentHelp("GPU memory budget in GB (0 = share all available)", valueName: "gb") + ) + public var gpuMemory: UInt64? + + @Option( + name: .customLong("gpu-model"), + help: ArgumentHelp("HuggingFace model to pre-load (e.g. mlx-community/Llama-3.2-1B-4bit)", valueName: "id") + ) + public var gpuModel: String? + + @Option(name: .customLong("gpu-max-tokens"), help: "Maximum tokens per inference request (default: 4096)") + public var gpuMaxTokens: Int = 4096 + + @Option(name: .customLong("gpu-port"), help: "vsock port for the GPU daemon (default: 2048)") + public var gpuPort: UInt32 = 2048 + } }