diff --git a/cellacdc/myutils.py b/cellacdc/myutils.py index f2695bcd3..b36ea318f 100644 --- a/cellacdc/myutils.py +++ b/cellacdc/myutils.py @@ -2092,6 +2092,9 @@ def download_model(model_name): except Exception as e: traceback.print_exc() return False + elif model_name == 'micro-sam': + # micro-sam downloads weights on first get_sam_model() via pooch + return True elif model_name == 'DeepSea': try: _download_deepsea_models() @@ -3997,9 +4000,12 @@ def import_promptable_segment_module(model_name): f'cellacdc.promptable_models.{model_name}.acdcPromptSegment' ) except ModuleNotFoundError as e: - # Check if custom model - cp = config.ConfigParser() + # Check if custom model (config is GUI-only) + import configparser + cp = configparser.ConfigParser() cp.read(promptable_models_list_file_path) + if not cp.has_section(model_name): + raise e from None model_path = cp[model_name]['path'] spec = importlib.util.spec_from_file_location( 'acdcPromptSegment', model_path @@ -4116,8 +4122,9 @@ def import_segment_module(model_name): try: acdcSegment = import_module(f'cellacdc.models.{model_name}.acdcSegment') except ModuleNotFoundError as e: - # Check if custom model - cp = config.ConfigParser() + # Check if custom model (config is GUI-only) + import configparser + cp = configparser.ConfigParser() cp.read(models_list_file_path) model_path = cp[model_name]['path'] spec = importlib.util.spec_from_file_location('acdcSegment', model_path) diff --git a/cellacdc/promptable_models/micro-sam/acdcPromptSegment.py b/cellacdc/promptable_models/micro-sam/acdcPromptSegment.py new file mode 100644 index 000000000..9ac080882 --- /dev/null +++ b/cellacdc/promptable_models/micro-sam/acdcPromptSegment.py @@ -0,0 +1,266 @@ +"""Promptable segmentation via micro-sam (domain-adapted SAM for microscopy).""" + +from collections import defaultdict + +from cellacdc.promptable_models.utils import build_combined_mask + +import numpy as np +import cv2 + +from cellacdc import myutils +from micro_sam.util import get_sam_model, get_device, precompute_image_embeddings +from micro_sam.prompt_based_segmentation import segment_from_points + + +class AvailableModels: + # Light microscopy first, then vanilla + values = [ + "vit_b_lm", "vit_t_lm", "vit_l_lm", + "vit_b", "vit_t", "vit_l", "vit_h", + ] + + +class NotParam: + not_a_param = True + + +class Model: + def __init__(self, model_type: AvailableModels = "vit_b_lm", gpu: bool = True): + """Promptable micro-sam model (domain-adapted for microscopy). + + Parameters + ---------- + model_type : AvailableModels, optional + Model variant. Default is "vit_b_lm" (light microscopy). + gpu : bool, optional + Whether to run on GPU if available. Default is True. + """ + if gpu: + from cellacdc import is_mac_arm64 + if is_mac_arm64: + device = "cpu" + else: + device = "cuda" + else: + device = "cpu" + + self.model = get_sam_model(model_type=model_type, device=get_device(device)) + + self._image_embeddings = None + self._embedded_shape = None + + self.prompt_ids_image_mapper = {} + self.prompts = defaultdict(list) + self.negative_prompts = defaultdict(list) + + def _normalize_prompt(self, prompt): + prompt = tuple(prompt) + if len(prompt) != 3: + raise ValueError( + "Point prompt must be a sequence of 3 coordinates (z, y, x)." + ) + z, y, x = prompt + if z is None or (isinstance(z, float) and np.isnan(z)): + z = 0 + return int(z), float(y), float(x) + + def _to_rgb(self, image): + img = myutils.to_uint8(image) + if img.ndim == 2: + try: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + except Exception: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + else: + if img.shape[-1] == 4: + img = img[..., :3] + try: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + except Exception: + pass + return img + + def _set_image(self, image): + img_rgb = self._to_rgb(image) + if self._embedded_shape is None or self._embedded_shape != img_rgb.shape: + self._image_embeddings = precompute_image_embeddings( + self.model, img_rgb, ndim=2, verbose=False, + ) + self._embedded_shape = img_rgb.shape + + def _collect_prompts(self, prompt_id, treat_other_objects_as_background): + pos_prompts = self.prompts.get(prompt_id, []) + neg_prompts = list(self.negative_prompts.get(0, [])) + neg_prompts.extend(self.negative_prompts.get(prompt_id, [])) + + if treat_other_objects_as_background: + for other_id, other_prompts in self.prompts.items(): + if other_id == prompt_id: + continue + neg_prompts.extend(other_prompts) + + return pos_prompts, neg_prompts + + def _points_for_slice(self, prompts, z): + coords = [] + labels = [] + num_pos = 0 + for prompt, prompt_type, label in prompts: + if prompt_type != "point": + raise ValueError(f"Unsupported prompt type: {prompt_type}") + + z_p, y, x = self._normalize_prompt(prompt) + if z is not None and z_p != z: + continue + + coords.append([x, y]) + labels.append(label) + if label == 1: + num_pos += 1 + + if not coords: + return None, None, 0 + + return np.array(coords), np.array(labels), num_pos + + def add_prompt( + self, + prompt, + prompt_id: int, + *args, + image=None, + image_origin=(0, 0, 0), + parent_obj_id=0, + prompt_type="point", + **kwargs, + ): + """Add prompt to model.""" + prompt = self._normalize_prompt(prompt) + + if prompt_id not in self.prompt_ids_image_mapper and prompt_id != 0: + self.prompt_ids_image_mapper[prompt_id] = (image, image_origin) + + if prompt_id != 0: + self.prompts[prompt_id].append((prompt, prompt_type)) + elif parent_obj_id != 0: + self.negative_prompts[parent_obj_id].append((prompt, prompt_type)) + else: + self.negative_prompts[0].append((prompt, prompt_type)) + + def segment( + self, + image, + lab: NotParam = None, + treat_other_objects_as_background: bool = False, + *args, + **kwargs, + ): + """Run segmentation using the prompts added.""" + is_rgb_image = image.ndim >= 3 and image.shape[-1] in (3, 4) + is_z_stack = (image.ndim == 3 and not is_rgb_image) or (image.ndim == 4) + + if is_rgb_image: + lab_out = np.zeros(image.shape[:-1], dtype=np.uint32) + else: + lab_out = np.zeros(image.shape, dtype=np.uint32) + + for prompt_id, (prompt_image, image_origin) in self.prompt_ids_image_mapper.items(): + if prompt_id == 0: + continue + + if prompt_image is None: + prompt_image = image + + pos_prompts, neg_prompts = self._collect_prompts( + prompt_id, treat_other_objects_as_background + ) + + is_prompt_rgb = ( + prompt_image.ndim >= 3 and prompt_image.shape[-1] in (3, 4) + ) + is_prompt_z_stack = ( + (prompt_image.ndim == 3 and not is_prompt_rgb) + or (prompt_image.ndim == 4) + ) + + if is_prompt_rgb: + obj_mask = np.zeros(prompt_image.shape[:-1], dtype=bool) + else: + obj_mask = np.zeros(prompt_image.shape, dtype=bool) + + prompts = [] + for prompt, prompt_type in neg_prompts: + prompts.append((prompt, prompt_type, 0)) + for prompt, prompt_type in pos_prompts: + prompts.append((prompt, prompt_type, 1)) + + if not prompts: + continue + + if is_prompt_z_stack: + z_dim = obj_mask.shape[0] + for z in range(z_dim): + point_coords, point_labels, num_pos = self._points_for_slice( + prompts, z + ) + if num_pos == 0: + continue + + self._set_image(prompt_image[z]) + # micro-sam expects points (y, x); we have [x, y] + points_yx = point_coords[:, ::-1].astype(np.float64) + mask = segment_from_points( + self.model, + points_yx, + point_labels, + image_embeddings=self._image_embeddings, + use_best_multimask=True, + ) + obj_mask[z] = np.asarray(mask).squeeze().astype(bool) + else: + point_coords, point_labels, num_pos = self._points_for_slice( + prompts, None + ) + if num_pos == 0: + continue + + self._set_image(prompt_image) + points_yx = point_coords[:, ::-1].astype(np.float64) + mask = segment_from_points( + self.model, + points_yx, + point_labels, + image_embeddings=self._image_embeddings, + use_best_multimask=True, + ) + obj_mask[:] = np.asarray(mask).squeeze().astype(bool) + + if not np.any(obj_mask): + continue + + z0, y0, x0 = map(int, image_origin) + if obj_mask.ndim == 2: + obj_slice = ( + slice(y0, y0 + obj_mask.shape[0]), + slice(x0, x0 + obj_mask.shape[1]), + ) + else: + obj_slice = ( + slice(z0, z0 + obj_mask.shape[0]), + slice(y0, y0 + obj_mask.shape[1]), + slice(x0, x0 + obj_mask.shape[2]), + ) + + lab_out[obj_slice][obj_mask] = prompt_id + + lab_out = build_combined_mask(lab_out) + + self.prompt_ids_image_mapper = {} + self.prompts = defaultdict(list) + self.negative_prompts = defaultdict(list) + + return lab_out + + +def url_help(): + return "https://computational-cell-analytics.github.io/micro-sam/" diff --git a/cellacdc/promptable_models/sam2/acdcPromptSegment.py b/cellacdc/promptable_models/sam2/acdcPromptSegment.py index a9c7d69c6..745f30b1a 100644 --- a/cellacdc/promptable_models/sam2/acdcPromptSegment.py +++ b/cellacdc/promptable_models/sam2/acdcPromptSegment.py @@ -154,7 +154,7 @@ def segment( self, image, lab: NotParam = None, - treat_other_objects_as_background: bool = True, + treat_other_objects_as_background: bool = False, *args, **kwargs, ): diff --git a/cellacdc/promptable_models/segment_anything/acdcPromptSegment.py b/cellacdc/promptable_models/segment_anything/acdcPromptSegment.py index 20ee1014b..1a74a1061 100644 --- a/cellacdc/promptable_models/segment_anything/acdcPromptSegment.py +++ b/cellacdc/promptable_models/segment_anything/acdcPromptSegment.py @@ -154,7 +154,7 @@ def segment( self, image, lab: NotParam = None, - treat_other_objects_as_background: bool = True, + treat_other_objects_as_background: bool = False, *args, **kwargs, ): diff --git a/pyproject.toml b/pyproject.toml index e96d7a741..40d1abc1b 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "numpy", - "pandas<3.0", + "pandas", "opencv-python-headless", "natsort", "h5py", @@ -50,9 +50,7 @@ dependencies = [ "boto3", "requests", "setuptools-scm", - "matplotlib", - "sympy", - "imagecodecs" + "matplotlib" ] dynamic = [ "version", @@ -127,6 +125,26 @@ deepsea = [ "munkres", ] +sam = [ + "torch", + "segment_anything @ git+https://github.com/facebookresearch/segment-anything.git", +] + +sam2 = [ + "torch", + "sam-2 @ git+https://github.com/facebookresearch/sam2.git", +] + +cellsam = [ + "torch", + "cellSAM @ git+https://github.com/vanvalenlab/cellSAM.git", +] + +microsam = [ + "torch", + "micro-sam @ git+https://github.com/keejkrej/micro-sam.git@cellacdc", +] + all = [ "PyQt6", "torchvision", diff --git a/tests/prompt_segm/test_microsam.py b/tests/prompt_segm/test_microsam.py new file mode 100644 index 000000000..cc7c58756 --- /dev/null +++ b/tests/prompt_segm/test_microsam.py @@ -0,0 +1,76 @@ +"""Tests for promptable micro-sam.""" + +from pathlib import Path + +import pytest + +from cellacdc import myutils +from tests.utils import ( + ensure_microsam, + get_test_dataset, + get_ground_truth_centroids, + validate_labels, + save_segmentation_overlay, +) + +ensure_microsam() + + +class TestPromptableMicroSAM: + @pytest.fixture(scope="class", autouse=True) + def download_models(self): + """micro-sam downloads weights on first get_sam_model().""" + myutils.download_model("micro-sam") + + @pytest.fixture + def test_data(self): + """Load test dataset with ground truth.""" + dataset = get_test_dataset() + segm_data = dataset.segm_data() + posData = dataset.posData() + posData.loadImgData() + posData.loadOtherFiles(load_segm_data=False, load_metadata=True) + posData.buildPaths() + return posData, segm_data + + def test_promptable_segmentation_with_ground_truth_centroids(self, test_data): + """Test micro-sam promptable segmentation using ground truth centroids.""" + posData, segm_data = test_data + + # Use last frame + frame_index = len(posData.img_data) - 1 + frame = posData.img_data[frame_index] + gt_mask = segm_data[frame_index] + + # Get centroids from ground truth + centroids = get_ground_truth_centroids(gt_mask) + assert len(centroids) > 0, "No objects found in ground truth" + + acdcPromptSegment = myutils.import_promptable_segment_module("micro-sam") + model = acdcPromptSegment.Model(model_type="vit_b_lm", gpu=True) + + # Add prompts for each ground truth centroid + for label_id, y, x in centroids: + model.add_prompt( + prompt=(0, y, x), + prompt_id=label_id, + image=frame, + image_origin=(0, 0, 0), + prompt_type="point", + ) + + labels = model.segment(frame, treat_other_objects_as_background=False) + + validate_labels(labels, frame.shape[:2]) + + num_gt_objects = len(centroids) + num_detected = labels.max() + print(f"[INFO] Ground truth objects: {num_gt_objects}") + print(f"[INFO] Detected objects: {num_detected}") + + plots_dir = Path(__file__).parent.parent / "_plots" / "prompt_segm" / "microsam" + save_segmentation_overlay( + labels, frame, frame_index, + plots_dir / f"test_promptable_microsam_frame_{frame_index:04d}.png", + prompt_points=centroids, + ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 8228927eb..156a90026 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -9,6 +9,7 @@ print_segmentation_results, ensure_sam, ensure_sam2, + ensure_microsam, ensure_cellsam, get_test_posdata, get_test_dataset, @@ -24,6 +25,7 @@ "print_segmentation_results", "ensure_sam", "ensure_sam2", + "ensure_microsam", "ensure_cellsam", "get_test_posdata", "get_test_dataset", diff --git a/tests/utils/segmentation.py b/tests/utils/segmentation.py index 0bdee1ca4..214df9252 100644 --- a/tests/utils/segmentation.py +++ b/tests/utils/segmentation.py @@ -221,41 +221,23 @@ def save_labels_image(labels: np.ndarray, output_path: Path): def ensure_sam(): - """Ensure segment_anything is importable, checking local repo as fallback.""" - import importlib - import sys - - try: - importlib.import_module("segment_anything") - return - except ModuleNotFoundError: - repo_root = Path(__file__).resolve().parents[3].parent - candidate = repo_root / "segment-anything" - if candidate.exists(): - sys.path.insert(0, str(candidate)) - + """Ensure segment_anything is importable.""" import pytest pytest.importorskip("segment_anything") def ensure_sam2(): - """Ensure sam2 is importable, checking local repo as fallback.""" - import importlib - import sys - - try: - importlib.import_module("sam2") - return - except ModuleNotFoundError: - repo_root = Path(__file__).resolve().parents[3].parent - candidate = repo_root / "sam2" - if candidate.exists(): - sys.path.insert(0, str(candidate)) - + """Ensure sam2 is importable.""" import pytest pytest.importorskip("sam2") +def ensure_microsam(): + """Ensure micro_sam is importable.""" + import pytest + pytest.importorskip("micro_sam") + + def ensure_cellsam(): """Ensure cellSAM is importable.""" import pytest