From be3e546526b9f787d91c021769e9892a7b50ceff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=88=BF=E8=8D=A3=E7=A5=A5?=
Date: Tue, 11 Nov 2025 23:23:47 +0800
Subject: [PATCH] Feature: Add macOS MPS support

---
Reviewer notes (text in this section is ignored by `git am`):
 * The apt-get line no longer adds "python3.12 python3.12-pip". Those
   packages are not available from the stock Ubuntu 20.04 repositories, so
   the image build would stop at that step. The image already uses the
   interpreter at /opt/conda/envs/nnInteractive/bin/python3.12.
 * This patch was restored from a copy whose line breaks had been collapsed.
   Indentation and blank context lines inside the hunks are reconstructed
   and should be confirmed by regenerating the patch with `git format-patch`
   before applying. The post-image blob id on the server/Dockerfile index
   line predates the apt-get change above.

 server/Dockerfile                          |  4 ++--
 server/nninteractive_slicer_server/main.py | 16 ++++++++++++++--
 server/pyproject.toml                      |  2 +-
 server/requirements.txt                    |  2 +-
 4 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/server/Dockerfile b/server/Dockerfile
index ea97f8a..01268be 100644
--- a/server/Dockerfile
+++ b/server/Dockerfile
@@ -1,6 +1,6 @@
-FROM nvidia/cuda:12.1.0-devel-ubuntu20.04
+FROM ubuntu:20.04
 
 RUN apt-get update && apt-get install -y wget
 
 RUN useradd -m user
 
@@ -25,7 +25,7 @@ RUN /bin/bash -c "\
 
 RUN echo "Conda env nnInteractive created"
 
-RUN /opt/conda/envs/nnInteractive/bin/python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
+RUN /opt/conda/envs/nnInteractive/bin/python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
 WORKDIR /opt/server
 
diff --git a/server/nninteractive_slicer_server/main.py b/server/nninteractive_slicer_server/main.py
index 4ce5ea8..c8d3d18 100644
--- a/server/nninteractive_slicer_server/main.py
+++ b/server/nninteractive_slicer_server/main.py
@@ -139,14 +139,26 @@ def download_weights(self):
     def make_session(self):
         """
         Creates an nnInteractiveInferenceSession, points it at the downloaded model.
+        Automatically detects the best available device (MPS > CUDA > CPU).
         """
+        # Automatically detect the best available device
+        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
+            device = torch.device("mps")
+            print("Using MPS (Metal Performance Shaders) device for inference")
+        elif torch.cuda.is_available():
+            device = torch.device("cuda:0")
+            print("Using CUDA device for inference")
+        else:
+            device = torch.device("cpu")
+            print("Using CPU device for inference")
+
         session = nnInteractiveInferenceSession(
-            device=torch.device("cuda:0"),  # Set inference device
+            device=device,  # Set inference device automatically
             use_torch_compile=False,  # Experimental: Not tested yet
             verbose=True,
             torch_n_threads=os.cpu_count(),  # Use available CPU cores
             do_autozoom=True,  # Enables AutoZoom for better patching
-            use_pinned_memory=True,  # Optimizes GPU memory transfers
+            use_pinned_memory=(device.type == "cuda"),  # Only use pinned memory for CUDA
         )
 
         # Load the trained model
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 1ddbd61..59554bf 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
     "nninteractive==1.0.1",
     "fastapi==0.111.0",
     "numpy==2.2.3",
-    "torch==2.6.0",
+    "torch>=2.9.0",
     "Pillow==11.1.0",
     "transformers==4.49.0",
     "xxhash==3.5.0"
diff --git a/server/requirements.txt b/server/requirements.txt
index 0249852..642336d 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -1,7 +1,7 @@
 nninteractive==1.0.1
 fastapi==0.111.0
 numpy==2.2.3
-torch==2.6.0
+torch>=2.9.0
 Pillow==11.1.0
 transformers==4.49.0
 xxhash==3.5.0
\ No newline at end of file