diff --git a/.claude/skills/nimbus-interface/SKILL.md b/.claude/skills/nimbus-interface/SKILL.md new file mode 100644 index 0000000..aa466bf --- /dev/null +++ b/.claude/skills/nimbus-interface/SKILL.md @@ -0,0 +1,110 @@ +--- +name: nimbus-interface +description: Reference for the NimbusImage/Girder API used by all workers in this repository. Use when building, debugging, or testing NimbusImage workers — including image loading, annotation CRUD, property computation, multi-channel merging, coordinate conversions, local test environments, and infrastructure troubleshooting (e.g. HTTP 500 errors). Also use when writing test scripts that interact with the Nimbus API. +--- + +# NimbusImage Worker Development + +## Quick Start + +Determine the task type: +- **Building/modifying a worker** → See [references/api.md](references/api.md) for full API patterns +- **Debugging HTTP 500 errors** → Check prerequisites below +- **Writing local test scripts** → See local testing section below +- **Coordinate confusion** → See critical pitfalls below + +## Infrastructure Prerequisites + +The Girder server requires **MongoDB**. Without it, all endpoints return HTTP 500 (except `/system/version`). Debug with: +```bash +docker ps | grep mongo # Must be running +curl -s http://localhost:8080/api/v1/system/version # Works without MongoDB +``` + +Full stack: `girder`, `worker` (celery), `rabbitmq`, `memcached`, `mongodb`. +Compose file: `/home/arjun/UPennContrast/docker-compose.yaml`. + +## Critical Pitfalls + +### Coordinate swap (numpy vs annotations) +Numpy is `[row, col]` = `[y, x]`. Annotations use `{'x': pixel_x, 'y': pixel_y}`. 
+```python +# skimage contour (row, col) → annotation: +coords = [{'x': float(col), 'y': float(row)} for row, col in contour] + +# Use annotation_tools helpers to avoid manual swaps: +from annotation_utilities.annotation_tools import polygons_to_annotations, annotations_to_polygons +``` + +### The 0.5 pixel offset +scikit-image uses pixel centers; Girder uses top-left corner: +```python +polygon = np.array([[c['y'] - 0.5, c['x'] - 0.5] for c in annotation['coordinates']]) +rr, cc = draw.polygon(polygon[:, 0], polygon[:, 1], shape=image.shape) +``` + +### Tags interface returns a list, not a dict +```python +# CORRECT: +tags = params['workerInterface'].get('Training Tag', []) +# WRONG (crashes with AttributeError): +tags = params['workerInterface'].get('Training Tag', {}).get('tags', []) +``` + +### Multi-channel merge output dtype +`process_and_merge_channels` returns `float64` with values 0-255 (not 0-1). Convert for ML: +```python +rgb_uint8 = np.clip(merged, 0, 255).astype(np.uint8) +``` + +Typical shapes: +- `getRegion().squeeze()`: `(H, W)` uint16 +- `get_images_for_all_channels`: each `(H, W, 1)` uint16 +- `process_and_merge_channels`: `(H, W, 3)` float64, values 0-255 + +## Local Testing + +### Avoid importing entrypoint.py +Worker entrypoints import heavy ML libraries (torch, sam2) at module level. Copy helper functions locally instead of importing the entrypoint. + +### Local venv dependencies +```bash +pip install girder-client tifffile +pip install -e /home/arjun/UPennContrast/devops/girder/annotation_client +pip install -e /home/arjun/ImageAnalysisProject/annotation_utilities +pip install -e /home/arjun/ImageAnalysisProject/worker_client +pip install numpy scipy scikit-image shapely matplotlib pillow numba +# ML deps (torch, sam2, etc.) 
only needed for inference, not API testing +``` + +### Authentication for test scripts +```python +import girder_client +gc = girder_client.GirderClient(apiUrl='http://localhost:8080/api/v1') +gc.authenticate('username', 'password') +token = gc.token # Use this token with annotation_client classes +``` +Env vars: `NIMBUS_API_URL` (default `http://localhost:8080/api/v1`), `NIMBUS_TOKEN`. + +### Test dataset +Dataset `69988c84b48d8121b565aba4`: 2 channels (Brightfield, YFP), 7Z, 4T, 6XY, 1024x1022 uint16. 544 polygons tagged "YFP blob" at XY=0 Z=3 Time=0. + +## Key Packages + +| Package | Location | +|---------|----------| +| annotation_client | `/home/arjun/UPennContrast/devops/girder/annotation_client/` | +| annotation_utilities | `/home/arjun/ImageAnalysisProject/annotation_utilities/` | +| worker_client | `/home/arjun/ImageAnalysisProject/worker_client/` | +| Workers | `/home/arjun/ImageAnalysisProject/workers/` | + +Key source files: `annotation_client/{annotations,tiles,workers}.py`, `annotation_utilities/{annotation_tools,batch_argument_parser}.py` + +## Detailed API Reference + +See [references/api.md](references/api.md) for complete API patterns including: +- Image access (single frame, subregion, multi-channel merge) +- Annotation CRUD (fetch, filter, create, delete) +- Property value computation and submission +- Writing images back to Girder +- Worker interface type table diff --git a/.claude/skills/nimbus-interface/references/api.md b/.claude/skills/nimbus-interface/references/api.md new file mode 100644 index 0000000..18b18d4 --- /dev/null +++ b/.claude/skills/nimbus-interface/references/api.md @@ -0,0 +1,204 @@ +# NimbusImage API Reference + +## Table of Contents +- [Image Access](#image-access) +- [Annotations](#annotations) +- [Property Values](#property-values) +- [Writing Images to Girder](#writing-images-to-girder) +- [Worker Interface Types](#worker-interface-types) + +--- + +## Image Access + +### Setup +```python +import annotation_client.tiles 
as tiles +tileClient = tiles.UPennContrastDataset(apiUrl=apiUrl, token=token, datasetId=datasetId) +``` + +### Metadata +```python +idx = tileClient.tiles['IndexRange'] +num_channels = idx.get('IndexC', 1) +num_z = idx.get('IndexZ', 1) +num_time = idx.get('IndexT', 1) +num_xy = idx.get('IndexXY', 1) +size_x = tileClient.tiles['sizeX'] +size_y = tileClient.tiles['sizeY'] +channel_names = tileClient.tiles.get('channels', []) +pixel_scale = tileClient.tiles.get('mm_x') # mm per pixel +``` + +### Single frame +```python +frame = tileClient.coordinatesToFrameIndex(XY, Z=z, T=time, channel=channel) +image = tileClient.getRegion(datasetId, frame=frame).squeeze() +# Returns (H, W) uint16 +``` + +### Subregion +```python +image = tileClient.getRegion(datasetId, frame=frame, + left=x_min, top=y_min, right=x_max, bottom=y_max, + units="base_pixels").squeeze() +``` + +### Multi-channel merged RGB +```python +import annotation_utilities.annotation_tools as annotation_tools + +images = annotation_tools.get_images_for_all_channels(tileClient, datasetId, XY, Z, Time) +# Each: (H, W, 1) uint16 +layers = annotation_tools.get_layers(tileClient.client, datasetId) +merged = annotation_tools.process_and_merge_channels(images, layers) +# Returns: (H, W, 3) float64, values 0-255 +``` +Merge modes: `'lighten'` (max, default), `'add'` (sum), `'screen'`. 
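The three merge modes correspond to standard blend formulas. An illustrative sketch in plain NumPy (this is just the blend math, not the actual `process_and_merge_channels` internals):

```python
import numpy as np

# Two single-channel layers already colorized to 0-255 float values
a = np.array([[100.0, 200.0]])
b = np.array([[150.0, 100.0]])

lighten = np.maximum(a, b)                        # per-pixel max (default)
add = np.clip(a + b, 0, 255)                      # clipped sum
screen = 255 - (255 - a) * (255 - b) / 255        # screen blend
```

`lighten` keeps whichever layer is brighter at each pixel; `screen` brightens without saturating as aggressively as `add`.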
+ +--- + +## Annotations + +### Client setup +```python +import annotation_client.annotations as annotations_client +annotationClient = annotations_client.UPennContrastAnnotationClient(apiUrl=apiUrl, token=token) +``` + +### Data structure +```python +{ + 'shape': 'polygon', # or 'point', 'line' + 'coordinates': [{'x': float, 'y': float}, ...], + 'location': {'XY': int, 'Z': int, 'Time': int}, + 'channel': int, + 'datasetId': str, + 'tags': ['tag1', 'tag2'], +} +``` + +### Fetch +```python +polygons = annotationClient.getAnnotationsByDatasetId(datasetId, shape='polygon') + +# Filter by tags server-side (must JSON-serialize) +import json +polygons = annotationClient.getAnnotationsByDatasetId( + datasetId, shape='polygon', tags=json.dumps(['my_tag'])) + +ann = annotationClient.getAnnotationById(annotationId) +``` + +### Client-side filtering +```python +import annotation_utilities.annotation_tools as annotation_tools + +filtered = annotation_tools.get_annotations_with_tags(annotations, tags, exclusive=False) +# exclusive=False: any matching tag; exclusive=True: exact tag set match + +filtered = annotation_tools.filter_elements_T_XY_Z(annotations, time, xy, z) +``` + +### Create +```python +annotationClient.createAnnotation(annotation_dict) +annotationClient.createMultipleAnnotations(annotation_list) # preferred + +# Using helpers (handles coordinate swap): +from annotation_utilities.annotation_tools import polygons_to_annotations +annotations = polygons_to_annotations( + shapely_polygons, datasetId, XY=0, Time=0, Z=0, tags=['my_tag'], channel=0) +``` + +### Delete +```python +annotationClient.deleteAnnotation(annotationId) +annotationClient.deleteMultipleAnnotations([id1, id2, ...]) +``` + +--- + +## Property Values + +### Setup +```python +import annotation_client.workers as workers +workerClient = workers.UPennContrastWorkerClient(datasetId, apiUrl, token, params) +``` + +### Get annotations for computation +```python +annotationList = 
workerClient.get_annotation_list_by_shape('polygon', limit=0) +annotationList = annotation_tools.get_annotations_with_tags( + annotationList, + params.get('tags', {}).get('tags', []), + params.get('tags', {}).get('exclusive', False)) +``` + +### Submit values +```python +property_values = {} +for ann in annotationList: + property_values[ann['_id']] = { + 'Area': float(area), + 'MeanIntensity': float(mean), + } +workerClient.add_multiple_annotation_property_values({datasetId: property_values}) +``` + +### Nested properties (per-Z, per-channel) +```python +property_values[ann['_id']] = { + 'MeanIntensity': {'z001': 42.0, 'z002': 84.0}, +} +``` + +### Pixel scale +```python +pixel_size = params['scales']['pixelSize'] # {'unit': 'mm', 'value': 0.000219} +z_step = params['scales']['zStep'] +t_step = params['scales']['tStep'] +``` + +--- + +## Writing Images to Girder + +```python +import large_image as li + +sink = li.new() +for i, frame in enumerate(tileClient.tiles['frames']): + large_image_params = {f'{k.lower()[5:]}': v for k, v in frame.items() + if k.startswith('Index') and len(k) > 5} + image = tileClient.getRegion(datasetId, frame=i).squeeze() + processed = your_function(image) + sink.addTile(processed, 0, 0, **large_image_params) + +if 'channels' in tileClient.tiles: + sink.channelNames = tileClient.tiles['channels'] +sink.mm_x = tileClient.tiles['mm_x'] +sink.mm_y = tileClient.tiles['mm_y'] +sink.magnification = tileClient.tiles['magnification'] + +sink.write('/tmp/output.tiff') +gc = tileClient.client +item = gc.uploadFileToFolder(datasetId, '/tmp/output.tiff') +gc.addMetadataToItem(item['itemId'], {'tool': 'YourWorker'}) +``` + +--- + +## Worker Interface Types + +| Type | Returns | Example | +|------|---------|---------| +| `number` | `int`/`float` | `32`, `0.5` | +| `text` | `str` | `"1-3, 5-8"` | +| `select` | `str` | `"model_name.pt"` | +| `checkbox` | `bool` | `True` | +| `channel` | `int` | `0` | +| `channelCheckboxes` | `dict[str, bool]` | `{"0": 
True, "1": False}` | +| `tags` | **`list` of `str`** | `["DAPI blob"]`, `["cell", "nucleus"]` | +| `layer` | `str` | `"layer_id"` | + +**Common pitfall with `tags`**: The `tags` type returns a **plain list of strings**, NOT a dict. Do not call `.get('tags')` on the result. + +```python +# CORRECT - tags returns a list directly: +training_tags = params['workerInterface'].get('Training Tag', []) +# training_tags = ["DAPI blob"] + +# WRONG - will crash with AttributeError: 'list' object has no attribute 'get': +training_tags = params['workerInterface'].get('Training Tag', {}).get('tags', []) +``` + +**Note**: The shape of the top-level `params['tags']` (the worker's output tags, NOT a workerInterface field) depends on the worker type. Annotation workers receive a plain list of strings (e.g., `["DAPI blob"]`), while property workers receive a dict of the form `{'tags': [...], 'exclusive': bool}`, used when filtering the results of `workerClient.get_annotation_list_by_shape()`. These are two different things. 
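Since the two shapes are easy to confuse, a minimal side-by-side sketch (plain dict literals stand in for the real `params` objects):

```python
# Annotation worker: top-level 'tags' is a plain list of strings
annotation_params = {'tags': ['DAPI blob']}
output_tags = annotation_params['tags']          # ['DAPI blob']

# Property worker: 'tags' is a dict with 'tags' and 'exclusive' keys
property_params = {'tags': {'tags': ['DAPI blob'], 'exclusive': False}}
tag_list = property_params.get('tags', {}).get('tags', [])
exclusive = property_params.get('tags', {}).get('exclusive', False)
```

The property-worker pattern matches the `params.get('tags', {}).get('tags', [])` idiom used in the API reference; the annotation-worker pattern is the one the "common pitfall" above applies to.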
+ +**Validating tags** (recommended pattern from cellpose_train, piscis): +```python +tags = workerInterface.get('My Tag Field', []) +if not tags or len(tags) == 0: + sendError("No tag selected", "Please select at least one tag.") + return +``` + +**Using tags to filter annotations**: +```python +# Pass the list directly to annotation_tools +filtered = annotation_tools.get_annotations_with_tags( + annotation_list, tags, exclusive=False) + +# Or with Girder API (must JSON-serialize) +annotations = annotationClient.getAnnotationsByDatasetId( + datasetId, shape='polygon', tags=json.dumps(tags)) +``` + ### Key APIs **annotation_client** (installed from NimbusImage repo): diff --git a/build_machine_learning_workers.sh b/build_machine_learning_workers.sh index 732b51a..b4c0069 100755 --- a/build_machine_learning_workers.sh +++ b/build_machine_learning_workers.sh @@ -39,6 +39,11 @@ docker build . -f ./workers/annotations/sam2_automatic_mask_generator/Dockerfile # Command for M1: # docker build . -f ./workers/annotations/sam2_automatic_mask_generator/Dockerfile_M1 -t annotations/sam2_automatic_mask_generator:latest $NO_CACHE +echo "Building SAM2 few-shot segmentation worker" +docker build . -f ./workers/annotations/sam2_fewshot_segmentation/Dockerfile -t annotations/sam2_fewshot_segmentation:latest $NO_CACHE +# Command for M1: +# docker build . -f ./workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 -t annotations/sam2_fewshot_segmentation:latest $NO_CACHE + echo "Building SAM2 propagate worker" docker build . 
-f ./workers/annotations/sam2_propagate/$DOCKERFILE -t annotations/sam2_propagate_worker:latest $NO_CACHE diff --git a/workers/annotations/sam2_fewshot_segmentation/Dockerfile b/workers/annotations/sam2_fewshot_segmentation/Dockerfile new file mode 100644 index 0000000..f3d057a --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/Dockerfile @@ -0,0 +1,70 @@ +FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04 as base +LABEL isUPennContrastWorker=True +LABEL com.nvidia.volumes.needed="nvidia_driver" + +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -qy tzdata && \ + apt-get install -qy software-properties-common python3-software-properties && \ + apt-get update && apt-get install -qy \ + build-essential \ + wget \ + python3 \ + r-base \ + libffi-dev \ + libssl-dev \ + libjpeg-dev \ + zlib1g-dev \ + r-base \ + git \ + libpython3-dev && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +ENV PATH="/root/miniforge3/bin:$PATH" +ARG PATH="/root/miniforge3/bin:$PATH" + +RUN wget \ + https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \ + && mkdir /root/.conda \ + && bash Miniforge3-Linux-x86_64.sh -b \ + && rm -f Miniforge3-Linux-x86_64.sh + +FROM base as build + +COPY ./workers/annotations/sam2_fewshot_segmentation/environment.yml / +RUN conda env create --file /environment.yml +SHELL ["conda", "run", "-n", "worker", "/bin/bash", "-c"] + +RUN pip install rtree shapely + +RUN git clone https://github.com/arjunrajlaboratory/NimbusImage/ + +RUN pip install -r /NimbusImage/devops/girder/annotation_client/requirements.txt +RUN pip install -e /NimbusImage/devops/girder/annotation_client/ + +RUN mkdir -p /code +RUN git clone https://github.com/facebookresearch/sam2.git /code/sam2 +RUN pip install -e /code/sam2 + +# Change directory to sam2/checkpoints +WORKDIR /code/sam2/checkpoints +# Download the checkpoints into the 
checkpoints directory +RUN ./download_ckpts.sh +# Change back to the root directory +WORKDIR / + +COPY ./workers/annotations/sam2_fewshot_segmentation/entrypoint.py / + +COPY ./annotation_utilities /annotation_utilities +RUN pip install /annotation_utilities + +LABEL isUPennContrastWorker="" \ + isAnnotationWorker="" \ + interfaceName="SAM2 few-shot segmentation" \ + interfaceCategory="SAM2" \ + description="Uses SAM2 features for few-shot segmentation based on training annotations" \ + annotationShape="polygon" \ + defaultToolName="SAM2 few-shot" + +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "worker", "python", "/entrypoint.py"] diff --git a/workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 b/workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 new file mode 100644 index 0000000..5e70586 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 @@ -0,0 +1,73 @@ +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 as base +LABEL isUPennContrastWorker=True +LABEL com.nvidia.volumes.needed="nvidia_driver" + +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -qy tzdata && \ + apt-get install -qy software-properties-common python3-software-properties && \ + apt-get update && apt-get install -qy \ + build-essential \ + wget \ + python3 \ + r-base \ + libffi-dev \ + libssl-dev \ + libjpeg-dev \ + zlib1g-dev \ + r-base \ + git \ + libpython3-dev && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +# The below is for the M1 Macs and should be changed for other architectures +ENV PATH="/root/miniconda3/bin:$PATH" +ARG PATH="/root/miniconda3/bin:$PATH" + +RUN wget \ + https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh \ + && mkdir /root/.conda \ + && bash Miniconda3-latest-Linux-aarch64.sh -b \ + && rm -f Miniconda3-latest-Linux-aarch64.sh +# END M1 Mac specific + + +FROM base as build + +COPY 
./workers/annotations/sam2_fewshot_segmentation/environment.yml / +RUN conda env create --file /environment.yml +SHELL ["conda", "run", "-n", "worker", "/bin/bash", "-c"] + +RUN pip install rtree shapely + +RUN git clone https://github.com/arjunrajlaboratory/NimbusImage/ + +RUN pip install -r /NimbusImage/devops/girder/annotation_client/requirements.txt +RUN pip install -e /NimbusImage/devops/girder/annotation_client/ + +RUN mkdir -p /code +RUN git clone https://github.com/facebookresearch/sam2.git /code/sam2 +RUN pip install -e /code/sam2 + +# Change directory to sam2/checkpoints +WORKDIR /code/sam2/checkpoints +# Download the checkpoints into the checkpoints directory +RUN ./download_ckpts.sh +# Change back to the root directory +WORKDIR / + +COPY ./workers/annotations/sam2_fewshot_segmentation/entrypoint.py / + +COPY ./annotation_utilities /annotation_utilities +RUN pip install /annotation_utilities + +LABEL isUPennContrastWorker="" \ + isAnnotationWorker="" \ + interfaceName="SAM2 few-shot segmentation" \ + interfaceCategory="SAM2" \ + description="Uses SAM2 features for few-shot segmentation based on training annotations" \ + annotationShape="polygon" \ + defaultToolName="SAM2 few-shot" + +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "worker", "python", "/entrypoint.py"] diff --git a/workers/annotations/sam2_fewshot_segmentation/SAM2_FEWSHOT.md b/workers/annotations/sam2_fewshot_segmentation/SAM2_FEWSHOT.md new file mode 100644 index 0000000..5030298 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/SAM2_FEWSHOT.md @@ -0,0 +1,143 @@ +# SAM2 Few-Shot Segmentation Worker + +## Overview + +This worker segments objects in microscopy images using few-shot learning with SAM2. Users annotate a small number of training examples (5-20 objects) with a specific tag, and the worker finds similar objects across the dataset using SAM2's frozen image encoder features. No model training is required. 
+ +## How It Works + +### Phase 1: Training Feature Extraction + +For each polygon annotation matching the user-specified Training Tag: + +1. Load the merged multi-channel image at the annotation's location +2. Convert the annotation polygon to a binary mask +3. Crop the image around the object with context padding (object occupies ~20% of crop area by default) +4. Encode the crop through SAM2's image encoder via `SAM2ImagePredictor.set_image()` +5. Extract the `image_embed` feature map (shape: `1, 256, 64, 64`) +6. Pool the feature map using mask-weighted averaging to produce a 256-dimensional feature vector +7. Average all training feature vectors into a single L2-normalized prototype + +### Phase 2: Inference + +For each image frame in the batch: + +1. Run `SAM2AutomaticMaskGenerator` to generate all candidate masks +2. For each candidate mask: + - Apply the same crop-encode-pool pipeline as training + - Compute cosine similarity between the candidate's feature vector and the training prototype + - Keep the mask if similarity >= threshold +3. Convert passing masks to polygon annotations via `find_contours` + `polygons_to_annotations` +4. 
Upload all annotations to the server + +## Interface Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| Training Tag | tags | (required) | Tag identifying training annotation examples | +| Batch XY | text | current | XY positions to process (e.g., "1-3, 5-8") | +| Batch Z | text | current | Z slices to process | +| Batch Time | text | current | Time points to process | +| Model | select | sam2.1_hiera_small.pt | SAM2 checkpoint to use | +| Similarity Threshold | number | 0.5 | Minimum cosine similarity to keep a mask (0.0-1.0) | +| Target Occupancy | number | 0.20 | Fraction of crop area the object should occupy (0.05-0.80) | +| Points per side | number | 32 | Grid density for SAM2 mask generation (16-128) | +| Min Mask Area | number | 100 | Minimum mask area in pixels to consider | +| Max Mask Area | number | 0 | Maximum mask area in pixels (0 = no limit) | +| Smoothing | number | 0.3 | Polygon simplification tolerance | + +## Key Design Decisions + +### Context Padding (Target Occupancy) + +SAM2 was trained on images where objects occupy a reasonable fraction of the frame. Tight crops around objects would be out-of-distribution. The `Target Occupancy` parameter controls how much of the crop the object fills: + +- `crop_side = sqrt(object_area / target_occupancy)` +- Default 0.20 means the object occupies ~20% of the crop area +- The same occupancy is used for both training and inference to ensure consistent feature extraction + +### Mask-Weighted Feature Pooling + +Since we have binary masks for both training annotations and candidate masks, we use them to focus the feature pooling on the actual object pixels rather than background: + +``` +feature_vector = (features * mask).sum(dim=[2,3]) / mask.sum() +``` + +The mask is bilinearly resized from the crop resolution to the feature map resolution (64x64). 
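The crop-sizing and pooling arithmetic above can be sketched in plain NumPy (a simplified stand-in for the worker's torch implementation; the mask is assumed here to be already resized to the feature-grid resolution):

```python
import numpy as np

# Crop side so the object fills ~20% of the crop area (the default occupancy)
object_area = 2000.0
target_occupancy = 0.20
crop_side = int(np.sqrt(object_area / target_occupancy))  # sqrt(10000) = 100

# Mask-weighted pooling over a (C, H, W) feature map
C, H, W = 256, 64, 64
rng = np.random.default_rng(0)
features = rng.standard_normal((C, H, W))
mask = np.zeros((H, W))
mask[20:40, 20:40] = 1.0
vec = (features * mask).sum(axis=(1, 2)) / mask.sum()     # shape (C,)

# L2-normalize and score against a prototype via cosine similarity
vec = vec / np.linalg.norm(vec)
prototype = vec.copy()
similarity = float(vec @ prototype)                        # 1.0 for identical vectors
```

At inference, candidates whose `similarity` falls below the threshold are discarded.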
+ +### SAM2ImagePredictor for Encoding + +We use `SAM2ImagePredictor.set_image()` rather than calling `forward_image` directly. This ensures proper handling of: +- Image transforms (resize to 1024x1024, normalization) +- `no_mem_embed` addition (SAM2's learned "no memory" token) +- Consistent feature extraction matching SAM2's internal pipeline + +The `image_embed` from `predictor._features["image_embed"]` gives a `(1, 256, 64, 64)` feature map -- the lowest-resolution, highest-semantic features from SAM2's FPN neck. + +## Tuning Guide + +### Similarity Threshold + +- **Too many false positives**: Increase threshold (try 0.6-0.8) +- **Too few detections (missing objects)**: Decrease threshold (try 0.3-0.4) +- **Start at 0.5** and adjust based on results + +### Target Occupancy + +- **Objects are very small in the image**: Try 0.10-0.15 (more context) +- **Objects are large in the image**: Try 0.30-0.40 (less context) +- **Default 0.20** works well for most microscopy objects + +### Points per side + +- **More masks needed (small objects)**: Increase to 48-64 +- **Faster processing**: Decrease to 16-24 +- **Default 32** balances coverage and speed + +### Min/Max Mask Area + +- Use training annotation areas as a guide +- Set Min to ~50% of smallest training annotation area +- Set Max to ~200% of largest training annotation area +- Set Max to 0 to disable upper limit + +## Performance Characteristics + +- **GPU required**: SAM2 encoder needs CUDA +- **Memory**: ~4GB VRAM for SAM2 small model +- **Speed**: Most time is spent encoding candidate masks individually (one forward pass per candidate). With 32 points per side, expect ~50-200 candidate masks per image. 
+ +- **Data efficiency**: Works with 5-20 training examples + +## TODO / Future Work + +- [ ] **Tiled image support**: Large microscopy images should be processed in tiles (like cellposesam's deeptile approach) rather than loading the entire image at once. This would reduce memory usage and allow processing of arbitrarily large images. +- [ ] **Multiple prototypes**: Keep all training feature vectors instead of averaging into a single prototype. Use max similarity or k-NN voting at inference. This would help when training examples show significant morphological variation. +- [ ] **Full-image encoding optimization**: Encode each inference image once and pool from the full feature map for each candidate mask, instead of cropping and re-encoding per candidate. Much faster but may reduce feature quality for small objects. +- [ ] **Negative examples**: Add a "Negative Tag" interface field so users can tag objects they do NOT want to match. Subtract negative similarity from positive similarity to reduce false positives. +- [ ] **Size/shape priors**: Learn area and aspect ratio distributions from training annotations and use them as an additional filter (e.g., reject candidates whose area is >2 std from training mean). 
+- [ ] **Adaptive thresholding**: Instead of a fixed similarity threshold, use relative ranking (e.g., keep top N% of candidates) or Otsu-style automatic thresholding on the similarity distribution. +- [ ] **Multi-scale feature extraction**: Extract features at multiple occupancy levels (e.g., 0.15, 0.25, 0.40) and concatenate for a richer feature vector. Helps when objects vary significantly in size. +- [ ] **Batch encoding**: Group multiple candidate crops into a batch tensor and encode them in a single forward pass through SAM2 for better GPU utilization. +- [ ] **Cache training prototype**: If the same training tag is used repeatedly, cache the prototype to avoid re-computing features on every run. +- [ ] **Similarity score as property**: Expose the similarity score as an annotation property so users can sort/filter results by confidence. +- [ ] **Support for point annotations as training**: Allow users to provide point prompts (not just polygon masks) as training examples, using SAM2's prompt-based segmentation to generate masks from points first. 
+ +## Files + +| File | Purpose | +|------|---------| +| `entrypoint.py` | Worker logic: interface definition, feature extraction, inference pipeline | +| `Dockerfile` | x86_64 production build (CUDA 12.1, SAM2 checkpoints) | +| `Dockerfile_M1` | arm64/M1 Mac build (CUDA 11.8) | +| `environment.yml` | Conda environment specification | +| `tests/test_sam2_fewshot.py` | Unit tests for helper functions | +| `tests/Dockerfile_Test` | Test Docker image | diff --git a/workers/annotations/sam2_fewshot_segmentation/entrypoint.py b/workers/annotations/sam2_fewshot_segmentation/entrypoint.py new file mode 100644 index 0000000..0908c73 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/entrypoint.py @@ -0,0 +1,488 @@ +import argparse +import json +import sys +import os + +from itertools import product + +import annotation_client.annotations as annotations_client +import annotation_client.workers as workers +import annotation_client.tiles as tiles + +import annotation_utilities.annotation_tools as annotation_tools +import annotation_utilities.batch_argument_parser as batch_argument_parser + +import numpy as np +from shapely.geometry import Polygon +from skimage.measure import find_contours + +import torch +import torch.nn.functional as F +from sam2.build_sam import build_sam2 +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from sam2.sam2_image_predictor import SAM2ImagePredictor + +from annotation_client.utils import sendProgress, sendError + + +def interface(image, apiUrl, token): + client = workers.UPennContrastWorkerPreviewClient(apiUrl=apiUrl, token=token) + + models = [f for f in os.listdir('/code/sam2/checkpoints') if f.endswith('.pt')] + default_model = 'sam2.1_hiera_small.pt' if 'sam2.1_hiera_small.pt' in models else models[0] if models else None + + interface = { + 'Training Tag': { + 'type': 'tags', + 'displayOrder': 0, + }, + 'Batch XY': { + 'type': 'text', + 'displayOrder': 1, + }, + 'Batch Z': { + 'type': 'text', + 
'displayOrder': 2, + }, + 'Batch Time': { + 'type': 'text', + 'displayOrder': 3, + }, + 'Model': { + 'type': 'select', + 'items': models, + 'default': default_model, + 'displayOrder': 4, + }, + 'Similarity Threshold': { + 'type': 'number', + 'min': 0.0, + 'max': 1.0, + 'default': 0.5, + 'displayOrder': 5, + }, + 'Target Occupancy': { + 'type': 'number', + 'min': 0.05, + 'max': 0.80, + 'default': 0.20, + 'displayOrder': 6, + }, + 'Points per side': { + 'type': 'number', + 'min': 16, + 'max': 128, + 'default': 32, + 'displayOrder': 7, + }, + 'Min Mask Area': { + 'type': 'number', + 'min': 0, + 'max': 100000, + 'default': 100, + 'displayOrder': 8, + }, + 'Max Mask Area': { + 'type': 'number', + 'min': 0, + 'max': 10000000, + 'default': 0, + 'displayOrder': 9, + }, + 'Smoothing': { + 'type': 'number', + 'min': 0, + 'max': 3, + 'default': 0.3, + 'displayOrder': 10, + }, + } + client.setWorkerImageInterface(image, interface) + + +def extract_crop_with_context(image, mask, target_occupancy=0.20): + """Extract a crop of the image where the masked object occupies roughly + target_occupancy fraction of the crop area. 
+ + Args: + image: numpy array (H, W, C) or (H, W) + mask: binary numpy array (H, W) + target_occupancy: desired fraction of crop area occupied by object + + Returns: + crop_image: numpy array resized/cropped region + crop_mask: binary numpy array of same spatial size as crop_image + """ + ys, xs = np.where(mask > 0) + if len(ys) == 0: + return image, mask + + y_min, y_max = ys.min(), ys.max() + x_min, x_max = xs.min(), xs.max() + obj_h = y_max - y_min + 1 + obj_w = x_max - x_min + 1 + + obj_area = mask.sum() + if obj_area == 0: + return image, mask + + # Determine crop size so that object occupies target_occupancy of area + crop_area = obj_area / target_occupancy + crop_side = int(np.sqrt(crop_area)) + # Ensure crop is at least as large as the object bounding box + crop_side = max(crop_side, obj_h, obj_w) + + # Center the crop on the object center + cy = (y_min + y_max) / 2.0 + cx = (x_min + x_max) / 2.0 + + h, w = image.shape[:2] + + half = crop_side / 2.0 + top = int(max(0, cy - half)) + left = int(max(0, cx - half)) + bottom = int(min(h, top + crop_side)) + right = int(min(w, left + crop_side)) + + # Adjust if we hit boundaries + if bottom - top < crop_side and top > 0: + top = max(0, bottom - crop_side) + if right - left < crop_side and left > 0: + left = max(0, right - crop_side) + + crop_image = image[top:bottom, left:right] + crop_mask = mask[top:bottom, left:right] + + return crop_image, crop_mask + + +def encode_image_with_sam2(predictor, image_np): + """Encode an image crop using SAM2's image encoder via SAM2ImagePredictor. + + Uses set_image() which handles transforms, backbone encoding, and + no_mem_embed addition consistently with SAM2's internal pipeline. 
+ + Args: + predictor: SAM2ImagePredictor instance + image_np: numpy array (H, W, 3) uint8 + + Returns: + features: tensor of shape [1, 256, 64, 64] (image_embed) + """ + predictor.set_image(image_np) + # image_embed is the lowest-resolution, highest-semantic feature map + # Shape: (1, 256, 64, 64) for 1024x1024 input + return predictor._features["image_embed"] + + +def pool_features_with_mask(features, mask_np, feat_h, feat_w): + """Pool feature map using a binary mask via weighted averaging. + + Args: + features: tensor (1, C, feat_h, feat_w) + mask_np: binary numpy array (crop_h, crop_w) + feat_h: feature map height + feat_w: feature map width + + Returns: + feature_vector: tensor of shape (C,) + """ + # Resize mask to feature map dimensions + mask_tensor = torch.from_numpy(mask_np.astype(np.float32)).unsqueeze(0).unsqueeze(0) + mask_resized = F.interpolate(mask_tensor, size=(feat_h, feat_w), mode='bilinear', align_corners=False) + mask_resized = mask_resized.to(features.device) + + # Weighted pooling + mask_sum = mask_resized.sum() + if mask_sum > 0: + weighted = (features * mask_resized).sum(dim=[2, 3]) / mask_sum + else: + weighted = features.mean(dim=[2, 3]) + + return weighted.squeeze(0) # (C,) + + +def ensure_rgb(image): + """Ensure image is (H, W, 3) uint8 RGB.""" + if image.ndim == 2: + image = np.stack([image, image, image], axis=-1) + elif image.ndim == 3 and image.shape[2] == 1: + image = np.repeat(image, 3, axis=2) + elif image.ndim == 3 and image.shape[2] == 4: + image = image[:, :, :3] + + if image.dtype == np.float32 or image.dtype == np.float64: + if image.max() <= 1.0 and image.min() >= 0.0: + image = (image * 255).astype(np.uint8) + else: + image = np.clip(image, 0, 255).astype(np.uint8) + elif image.dtype == np.uint16: + image = (image / 256).astype(np.uint8) + elif image.dtype != np.uint8: + image = image.astype(np.uint8) + + return image + + +def annotation_to_mask(annotation, image_shape): + """Convert a polygon annotation to a binary mask. 
+ + Args: + annotation: annotation dict with 'coordinates' list of {'x': ..., 'y': ...} + image_shape: (H, W) of the target image + + Returns: + mask: binary numpy array (H, W) + """ + from skimage.draw import polygon as draw_polygon + + coords = annotation['coordinates'] + # Annotation coordinates use the top-left-corner convention; shift by 0.5 + # so they align with skimage's pixel-center convention before drawing + rows = np.array([c['y'] - 0.5 for c in coords]) + cols = np.array([c['x'] - 0.5 for c in coords]) + + mask = np.zeros(image_shape[:2], dtype=np.uint8) + rr, cc = draw_polygon(rows, cols, shape=image_shape[:2]) + mask[rr, cc] = 1 + return mask + + +def compute(datasetId, apiUrl, token, params): + annotationClient = annotations_client.UPennContrastAnnotationClient(apiUrl=apiUrl, token=token) + workerClient = workers.UPennContrastWorkerClient(datasetId, apiUrl, token, params) + tileClient = tiles.UPennContrastDataset(apiUrl=apiUrl, token=token, datasetId=datasetId) + + # Parse parameters + model_name = params['workerInterface']['Model'] + similarity_threshold = float(params['workerInterface']['Similarity Threshold']) + target_occupancy = float(params['workerInterface']['Target Occupancy']) + points_per_side = int(params['workerInterface']['Points per side']) + min_mask_area = int(params['workerInterface']['Min Mask Area']) + max_mask_area = int(params['workerInterface']['Max Mask Area']) + smoothing = float(params['workerInterface']['Smoothing']) + + batch_xy = params['workerInterface'].get('Batch XY', '') + batch_z = params['workerInterface'].get('Batch Z', '') + batch_time = params['workerInterface'].get('Batch Time', '') + + batch_xy = batch_argument_parser.process_range_list(batch_xy, convert_one_to_zero_index=True) + batch_z = batch_argument_parser.process_range_list(batch_z, convert_one_to_zero_index=True) + batch_time = batch_argument_parser.process_range_list(batch_time, convert_one_to_zero_index=True) + + tile = params['tile'] + channel = params['channel'] + tags = params['tags'] + + if batch_xy is None: + batch_xy = [tile['XY']] + if
batch_z is None: + batch_z = [tile['Z']] + if batch_time is None: + batch_time = [tile['Time']] + + # Parse training tag - 'type': 'tags' returns a list of strings directly + training_tags = params['workerInterface'].get('Training Tag', []) + if not training_tags or len(training_tags) == 0: + sendError("No training tag selected", + "Please select a tag that identifies your training annotations.") + return + + # ── SAM2 model setup ── + sendProgress(0.0, "Loading model", "Initializing SAM2...") + torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + checkpoint_path = f"/code/sam2/checkpoints/{model_name}" + model_to_cfg = { + 'sam2.1_hiera_base_plus.pt': 'sam2.1_hiera_b+.yaml', + 'sam2.1_hiera_large.pt': 'sam2.1_hiera_l.yaml', + 'sam2.1_hiera_small.pt': 'sam2.1_hiera_s.yaml', + 'sam2.1_hiera_tiny.pt': 'sam2.1_hiera_t.yaml', + } + model_cfg = f"configs/sam2.1/{model_to_cfg[model_name]}" + sam2_model = build_sam2(model_cfg, checkpoint_path, device='cuda', apply_postprocessing=False) + predictor = SAM2ImagePredictor(sam2_model) + + # ── Phase 1: Extract training prototype ── + sendProgress(0.05, "Extracting training features", "Fetching training annotations...") + + # Fetch all polygon annotations from the dataset + all_annotations = annotationClient.getAnnotationsByDatasetId(datasetId, shape='polygon') + training_annotations = annotation_tools.get_annotations_with_tags( + all_annotations, training_tags, exclusive=False + ) + + if len(training_annotations) == 0: + sendError("No training annotations found", f"No polygon annotations found with tags: {training_tags}") + return + + print(f"Found {len(training_annotations)} training annotations") + + feature_vectors = [] + for idx, annotation in enumerate(training_annotations): + loc = annotation['location'] + ann_xy = loc.get('XY', 0) + ann_z = loc.get('Z', 0) + 
ann_time = loc.get('Time', 0) + + # Get the merged image at the annotation's location + images = annotation_tools.get_images_for_all_channels(tileClient, datasetId, ann_xy, ann_z, ann_time) + layers = annotation_tools.get_layers(tileClient.client, datasetId) + merged_image = annotation_tools.process_and_merge_channels(images, layers) + merged_image = ensure_rgb(merged_image) + + # Convert annotation to mask + mask = annotation_to_mask(annotation, merged_image.shape) + + if mask.sum() == 0: + print(f"Warning: training annotation {idx} produced empty mask, skipping") + continue + + # Extract crop with context padding + crop_image, crop_mask = extract_crop_with_context(merged_image, mask, target_occupancy) + crop_image = ensure_rgb(crop_image) + + # Encode with SAM2 + features = encode_image_with_sam2(predictor, crop_image) + feat_h, feat_w = features.shape[2], features.shape[3] + + # Pool features with mask + feature_vec = pool_features_with_mask(features, crop_mask, feat_h, feat_w) + feature_vectors.append(feature_vec) + + sendProgress(0.05 + 0.15 * (idx + 1) / len(training_annotations), + "Extracting training features", + f"Processed {idx + 1}/{len(training_annotations)} training examples") + + if len(feature_vectors) == 0: + sendError("No valid training features", "All training annotations produced empty masks") + return + + # Create prototype by averaging feature vectors + training_prototype = torch.stack(feature_vectors).mean(dim=0) + training_prototype = F.normalize(training_prototype.unsqueeze(0), dim=1).squeeze(0) + + print(f"Training prototype shape: {training_prototype.shape}") + + # Optionally learn size statistics from training annotations + training_areas = [] + for annotation in training_annotations: + coords = annotation['coordinates'] + rows = [c['y'] for c in coords] + cols = [c['x'] for c in coords] + poly = Polygon(zip(cols, rows)) + if poly.is_valid: + training_areas.append(poly.area) + + mean_area = np.mean(training_areas) if training_areas else 
None + std_area = np.std(training_areas) if training_areas else None + print(f"Training area stats: mean={mean_area}, std={std_area}") + + # ── Phase 2: Inference ── + mask_generator = SAM2AutomaticMaskGenerator( + sam2_model, + points_per_side=points_per_side, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + min_mask_region_area=min_mask_area, + ) + + batches = list(product(batch_xy, batch_z, batch_time)) + total_batches = len(batches) + new_annotations = [] + + for i, batch in enumerate(batches): + XY, Z, Time = batch + + sendProgress(0.2 + 0.7 * i / total_batches, + "Segmenting", + f"Processing frame {i + 1}/{total_batches}") + + # Get merged image for this batch + images = annotation_tools.get_images_for_all_channels(tileClient, datasetId, XY, Z, Time) + layers = annotation_tools.get_layers(tileClient.client, datasetId) + merged_image = annotation_tools.process_and_merge_channels(images, layers) + merged_image_rgb = ensure_rgb(merged_image) + + # Generate candidate masks with SAM2 (generate expects uint8 RGB, HWC; + # passing float32 would skip torchvision's /255 rescale and break normalization) + candidate_masks = mask_generator.generate(merged_image_rgb) + print(f"Frame {i + 1}: generated {len(candidate_masks)} candidate masks") + + # Filter candidates by similarity to training prototype + filtered_polygons = [] + for mask_data in candidate_masks: + mask = mask_data['segmentation'] + area = mask.sum() + + # Area filtering + if min_mask_area > 0 and area < min_mask_area: + continue + if max_mask_area > 0 and area > max_mask_area: + continue + + # Extract crop with context, encode, and compare + crop_image, crop_mask = extract_crop_with_context( + merged_image_rgb, mask, target_occupancy + ) + crop_image = ensure_rgb(crop_image) + + if crop_mask.sum() == 0: + continue + + features = encode_image_with_sam2(predictor, crop_image) + feat_h, feat_w = features.shape[2], features.shape[3] + feature_vec = pool_features_with_mask(features, crop_mask, feat_h, feat_w) + + # Compute cosine similarity + feature_vec_norm = F.normalize(feature_vec.unsqueeze(0),
dim=1) + similarity = F.cosine_similarity( + feature_vec_norm, + training_prototype.unsqueeze(0) + ).item() + + if similarity >= similarity_threshold: + # Convert mask to polygon + contours = find_contours(mask, 0.5) + if len(contours) == 0: + continue + # find_contours yields (row, col) at pixel centers; swap to (x, y) + # and add 0.5 to match annotation coordinates; use the longest contour + contour = max(contours, key=len) + polygon = Polygon([(c + 0.5, r + 0.5) for r, c in contour]).simplify(smoothing, preserve_topology=True) + if polygon.is_valid and not polygon.is_empty: + filtered_polygons.append(polygon) + + print(f"Frame {i + 1}: {len(filtered_polygons)} masks passed similarity filter") + + # Convert polygons to annotations + temp_annotations = annotation_tools.polygons_to_annotations( + filtered_polygons, datasetId, XY=XY, Time=Time, Z=Z, tags=tags, channel=channel + ) + new_annotations.extend(temp_annotations) + + sendProgress(0.9, "Uploading annotations", f"Sending {len(new_annotations)} annotations to server") + annotationClient.createMultipleAnnotations(new_annotations) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='SAM2 Few-Shot Segmentation') + + parser.add_argument('--datasetId', type=str, required=False, action='store') + parser.add_argument('--apiUrl', type=str, required=True, action='store') + parser.add_argument('--token', type=str, required=True, action='store') + parser.add_argument('--request', type=str, required=True, action='store') + parser.add_argument('--parameters', type=str, + required=True, action='store') + + args = parser.parse_args(sys.argv[1:]) + + params = json.loads(args.parameters) + datasetId = args.datasetId + apiUrl = args.apiUrl + token = args.token + + match args.request: + case 'compute': + compute(datasetId, apiUrl, token, params) + case 'interface': + interface(params['image'], apiUrl, token) diff --git a/workers/annotations/sam2_fewshot_segmentation/environment.yml b/workers/annotations/sam2_fewshot_segmentation/environment.yml new file mode 100644 index 0000000..39e0881 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/environment.yml @@ -0,0 +1,18 @@ +name: worker
+channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - imageio + - rasterio + - shapely + - pillow + - opencv + - matplotlib + - scikit-image + - pip: + - pycocotools + - onnxruntime + - onnx diff --git a/workers/annotations/sam2_fewshot_segmentation/local_tests/.gitignore b/workers/annotations/sam2_fewshot_segmentation/local_tests/.gitignore new file mode 100644 index 0000000..77ac754 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/local_tests/.gitignore @@ -0,0 +1,3 @@ +.venv/ +__pycache__/ +*.pyc diff --git a/workers/annotations/sam2_fewshot_segmentation/local_tests/setup_env.sh b/workers/annotations/sam2_fewshot_segmentation/local_tests/setup_env.sh new file mode 100755 index 0000000..f3650b3 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/local_tests/setup_env.sh @@ -0,0 +1,114 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VENV_DIR="$SCRIPT_DIR/.venv" + +SKIP_SAM2=false +for arg in "$@"; do + case $arg in + --skip-sam2) SKIP_SAM2=true ;; + esac +done + +echo "=== SAM2 Few-Shot Local Test Environment Setup ===" +echo "Script dir: $SCRIPT_DIR" +if [ "$SKIP_SAM2" = true ]; then + echo "Skipping SAM2/PyTorch install (--skip-sam2 flag)" +fi + +# --- Create venv --- +if [ -d "$VENV_DIR" ]; then + echo "Virtual environment already exists at $VENV_DIR" + echo "To recreate, delete it first: rm -rf $VENV_DIR" +else + echo "Creating virtual environment..." + python3 -m venv "$VENV_DIR" +fi + +# shellcheck disable=SC1091 +source "$VENV_DIR/bin/activate" +echo "Using Python: $(which python)" + +# --- Install pip basics --- +pip install --upgrade pip setuptools wheel + +# --- annotation_client from NimbusImage local clone --- +ANNOTATION_CLIENT_DIR="/home/arjun/UPennContrast/devops/girder/annotation_client" +if [ -d "$ANNOTATION_CLIENT_DIR" ]; then + echo "Installing annotation_client from $ANNOTATION_CLIENT_DIR..." 
+ pip install girder-client tifffile + pip install -e "$ANNOTATION_CLIENT_DIR" +else + echo "WARNING: annotation_client not found at $ANNOTATION_CLIENT_DIR" + echo "You'll need to install it manually." +fi + +# --- annotation_utilities --- +ANNOTATION_UTILS_DIR="/home/arjun/ImageAnalysisProject/annotation_utilities" +if [ -d "$ANNOTATION_UTILS_DIR" ]; then + echo "Installing annotation_utilities from $ANNOTATION_UTILS_DIR..." + pip install -e "$ANNOTATION_UTILS_DIR" +else + echo "WARNING: annotation_utilities not found at $ANNOTATION_UTILS_DIR" +fi + +# --- worker_client --- +WORKER_CLIENT_DIR="/home/arjun/ImageAnalysisProject/worker_client" +if [ -d "$WORKER_CLIENT_DIR" ]; then + echo "Installing worker_client from $WORKER_CLIENT_DIR..." + pip install -e "$WORKER_CLIENT_DIR" +else + echo "WARNING: worker_client not found at $WORKER_CLIENT_DIR" +fi + +# --- Scientific Python stack --- +echo "Installing scientific Python packages..." +pip install numpy scipy scikit-image shapely matplotlib pillow numba + +# --- PyTorch + SAM2 (skippable for connection-only testing) --- +if [ "$SKIP_SAM2" = false ]; then + echo "Installing PyTorch..." + if command -v nvidia-smi &> /dev/null; then + echo "NVIDIA GPU detected, installing CUDA PyTorch..." + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 + else + echo "No NVIDIA GPU detected, installing CPU PyTorch..." + pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + fi + + SAM2_DIR="/tmp/sam2" + echo "Installing SAM2..." + if [ -d "$SAM2_DIR" ]; then + echo "Using existing SAM2 clone at $SAM2_DIR" + else + echo "Cloning SAM2 to $SAM2_DIR..." + git clone https://github.com/facebookresearch/sam2.git "$SAM2_DIR" + fi + SAM2_BUILD_CUDA=0 pip install -e "$SAM2_DIR" + + # Download SAM2 checkpoints if needed + CHECKPOINT_DIR="$SAM2_DIR/checkpoints" + if [ ! -f "$CHECKPOINT_DIR/sam2.1_hiera_small.pt" ]; then + echo "Downloading SAM2 checkpoints..." 
+ pushd "$SAM2_DIR" > /dev/null + if [ -f "checkpoints/download_ckpts.sh" ]; then + bash checkpoints/download_ckpts.sh + else + mkdir -p checkpoints + echo "WARNING: Checkpoint download script not found." + echo "Download manually from https://github.com/facebookresearch/sam2#download-checkpoints" + fi + popd > /dev/null + else + echo "SAM2 checkpoints already present." + fi +else + echo "Skipping PyTorch and SAM2 install." + echo "test_connection.py will still work for API testing." +fi + +echo "" +echo "=== Setup complete ===" +echo "Activate with: source $VENV_DIR/bin/activate" +echo "Then run: python $SCRIPT_DIR/test_connection.py" diff --git a/workers/annotations/sam2_fewshot_segmentation/local_tests/test_connection.py b/workers/annotations/sam2_fewshot_segmentation/local_tests/test_connection.py new file mode 100644 index 0000000..eb99253 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/local_tests/test_connection.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" +Test connection to NimbusImage API. + +Validates image loading, annotation retrieval, and multi-channel merging +against a live Nimbus instance. Run after setup_env.sh creates the venv. + +Usage: + # With env vars: + export NIMBUS_API_URL=http://localhost:8080/api/v1 + export NIMBUS_TOKEN=your_token_here + python test_connection.py + + # With interactive login (will prompt for username/password): + python test_connection.py + + # With a specific dataset: + python test_connection.py --dataset 69988c84b48d8121b565aba4 +""" + +import argparse +import getpass +import os +import sys +from collections import Counter + +import numpy as np + +import annotation_client.annotations as annotations_client +import annotation_client.tiles as tiles +import annotation_utilities.annotation_tools as annotation_tools + + +def ensure_rgb(image): + """Ensure image is (H, W, 3) uint8 RGB. 
Copied from entrypoint.py to avoid sam2 import.""" + if image.ndim == 2: + image = np.stack([image, image, image], axis=-1) + elif image.ndim == 3 and image.shape[2] == 1: + image = np.repeat(image, 3, axis=2) + elif image.ndim == 3 and image.shape[2] == 4: + image = image[:, :, :3] + + if image.dtype == np.float32 or image.dtype == np.float64: + if image.max() <= 1.0 and image.min() >= 0.0: + image = (image * 255).astype(np.uint8) + else: + image = np.clip(image, 0, 255).astype(np.uint8) + elif image.dtype == np.uint16: + image = (image / 256).astype(np.uint8) + elif image.dtype != np.uint8: + image = image.astype(np.uint8) + + return image + + +def get_auth(args): + """Get API URL and token from env vars, CLI args, or interactive login.""" + api_url = args.api_url or os.environ.get('NIMBUS_API_URL', 'http://localhost:8080/api/v1') + token = args.token or os.environ.get('NIMBUS_TOKEN') + + if token: + print(f"API URL: {api_url}") + print(f"Token: {token[:8]}...") + return api_url, token + + # Login via girder_client (CLI args or interactive) + import girder_client + gc = girder_client.GirderClient(apiUrl=api_url) + + username = args.username or os.environ.get('NIMBUS_USERNAME') or input("Username: ") + password = args.password or os.environ.get('NIMBUS_PASSWORD') or getpass.getpass("Password: ") + + print(f"Logging in as '{username}' to {api_url}...") + gc.authenticate(username, password) + token = gc.token + + print(f"API URL: {api_url}") + print(f"Token: {token[:8]}...") + return api_url, token + + +def test_image_loading(api_url, token, dataset_id): + """Test image metadata and loading.""" + print("\n" + "=" * 60) + print("IMAGE LOADING TEST") + print("=" * 60) + + tileClient = tiles.UPennContrastDataset( + apiUrl=api_url, token=token, datasetId=dataset_id + ) + + # Report metadata + idx_range = tileClient.tiles.get('IndexRange', {}) + num_channels = idx_range.get('IndexC', 1) + num_z = idx_range.get('IndexZ', 1) + num_time = idx_range.get('IndexT', 1) + num_xy 
= idx_range.get('IndexXY', 1) + + print(f" Channels: {num_channels}") + print(f" Z-planes: {num_z}") + print(f" Timepoints: {num_time}") + print(f" XY pos: {num_xy}") + + if 'channels' in tileClient.tiles: + print(f" Channel names: {tileClient.tiles['channels']}") + + size_x = tileClient.tiles.get('sizeX', 'unknown') + size_y = tileClient.tiles.get('sizeY', 'unknown') + print(f" Image size: {size_x} x {size_y}") + print(f" mm_x: {tileClient.tiles.get('mm_x', 'N/A')}") + print(f" mm_y: {tileClient.tiles.get('mm_y', 'N/A')}") + print(f" Total frames: {len(tileClient.tiles.get('frames', []))}") + + # Load a single frame + print("\n Loading single frame (XY=0, Z=0, T=0, C=0)...") + frame = tileClient.coordinatesToFrameIndex(0, Z=0, T=0, channel=0) + image = tileClient.getRegion(dataset_id, frame=frame).squeeze() + print(f" Single frame: shape={image.shape}, dtype={image.dtype}") + + # Load merged RGB via the same pipeline as the SAM2 worker + print("\n Loading merged RGB image (same as SAM2 worker)...") + images = annotation_tools.get_images_for_all_channels(tileClient, dataset_id, 0, 0, 0) + print(f" Loaded {len(images)} channel images") + for i, img in enumerate(images): + print(f" Channel {i}: shape={img.shape}, dtype={img.dtype}, " + f"range=[{img.min():.1f}, {img.max():.1f}]") + + layers = annotation_tools.get_layers(tileClient.client, dataset_id) + print(f" Found {len(layers)} layers") + + merged = annotation_tools.process_and_merge_channels(images, layers) + print(f" Merged image: shape={merged.shape}, dtype={merged.dtype}") + + rgb = ensure_rgb(merged) + print(f" RGB output: shape={rgb.shape}, dtype={rgb.dtype}, " + f"range=[{rgb.min()}, {rgb.max()}]") + + return tileClient + + +def test_annotation_loading(api_url, token, dataset_id): + """Test annotation retrieval and analysis.""" + print("\n" + "=" * 60) + print("ANNOTATION LOADING TEST") + print("=" * 60) + + annotationClient = annotations_client.UPennContrastAnnotationClient( + apiUrl=api_url, token=token + 
) + + # Get all polygon annotations + print(" Fetching polygon annotations...") + polygons = annotationClient.getAnnotationsByDatasetId(dataset_id, shape='polygon') + print(f" Total polygons: {len(polygons)}") + + if len(polygons) == 0: + print(" No polygon annotations found. Skipping detailed analysis.") + return + + # Tag breakdown + tag_counter = Counter() + for ann in polygons: + ann_tags = ann.get('tags', []) + if ann_tags: + for t in ann_tags: + tag_counter[t] += 1 + else: + tag_counter['(untagged)'] += 1 + + print(f"\n Tag breakdown ({len(tag_counter)} unique tags):") + for tag, count in tag_counter.most_common(): + print(f" {tag}: {count}") + + # Location distribution + xy_counter = Counter() + z_counter = Counter() + time_counter = Counter() + for ann in polygons: + loc = ann.get('location', {}) + xy_counter[loc.get('XY', 0)] += 1 + z_counter[loc.get('Z', 0)] += 1 + time_counter[loc.get('Time', 0)] += 1 + + print(f"\n Location distribution:") + print(f" XY positions: {dict(xy_counter.most_common())}") + print(f" Z planes: {dict(z_counter.most_common())}") + print(f" Timepoints: {dict(time_counter.most_common())}") + + # Coordinate stats + all_x = [] + all_y = [] + for ann in polygons: + for coord in ann.get('coordinates', []): + all_x.append(coord['x']) + all_y.append(coord['y']) + + if all_x: + print(f"\n Coordinate ranges:") + print(f" X: [{min(all_x):.1f}, {max(all_x):.1f}]") + print(f" Y: [{min(all_y):.1f}, {max(all_y):.1f}]") + + # Also check for points and lines + points = annotationClient.getAnnotationsByDatasetId(dataset_id, shape='point') + lines = annotationClient.getAnnotationsByDatasetId(dataset_id, shape='line') + print(f"\n Other shapes: {len(points)} points, {len(lines)} lines") + + +def main(): + parser = argparse.ArgumentParser(description='Test NimbusImage API connection') + parser.add_argument('--dataset', type=str, + default='69988c84b48d8121b565aba4', + help='Dataset ID to test with') + parser.add_argument('--api-url', type=str, 
default=None, + help='API URL (default: $NIMBUS_API_URL or http://localhost:8080/api/v1)') + parser.add_argument('--token', type=str, default=None, + help='Auth token (default: $NIMBUS_TOKEN or interactive login)') + parser.add_argument('--username', type=str, default=None, + help='Username for login (default: $NIMBUS_USERNAME or interactive)') + parser.add_argument('--password', type=str, default=None, + help='Password for login (default: $NIMBUS_PASSWORD or interactive)') + args = parser.parse_args() + + print("=" * 60) + print("NimbusImage API Connection Test") + print("=" * 60) + + api_url, token = get_auth(args) + dataset_id = args.dataset + print(f"Dataset: {dataset_id}") + + try: + test_image_loading(api_url, token, dataset_id) + except Exception as e: + print(f"\n ERROR in image loading: {e}") + import traceback + traceback.print_exc() + + try: + test_annotation_loading(api_url, token, dataset_id) + except Exception as e: + print(f"\n ERROR in annotation loading: {e}") + import traceback + traceback.print_exc() + + print("\n" + "=" * 60) + print("TEST COMPLETE") + print("=" * 60) + + +if __name__ == '__main__': + main() diff --git a/workers/annotations/sam2_fewshot_segmentation/tests/Dockerfile_Test b/workers/annotations/sam2_fewshot_segmentation/tests/Dockerfile_Test new file mode 100644 index 0000000..f04c20a --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/tests/Dockerfile_Test @@ -0,0 +1,14 @@ +# Use the existing sam2_fewshot_segmentation worker as the base +FROM annotations/sam2_fewshot_segmentation:latest AS test + +# Install test dependencies +SHELL ["conda", "run", "-n", "worker", "/bin/bash", "-c"] +RUN pip install pytest pytest-mock + +# Copy test files +RUN mkdir -p /tests +COPY ./workers/annotations/sam2_fewshot_segmentation/tests/*.py /tests +WORKDIR /tests + +# Command to run tests +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "worker", "python3", "-m", "pytest", "-v"] diff --git 
a/workers/annotations/sam2_fewshot_segmentation/tests/__init__.py b/workers/annotations/sam2_fewshot_segmentation/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workers/annotations/sam2_fewshot_segmentation/tests/test_sam2_fewshot.py b/workers/annotations/sam2_fewshot_segmentation/tests/test_sam2_fewshot.py new file mode 100644 index 0000000..00bcc41 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/tests/test_sam2_fewshot.py @@ -0,0 +1,278 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock + +from entrypoint import ( + extract_crop_with_context, + pool_features_with_mask, + ensure_rgb, + annotation_to_mask, + interface, +) + + +class TestExtractCropWithContext: + """Tests for the context-aware crop extraction.""" + + def test_basic_crop_centered(self): + """Test that crop is centered on the object.""" + image = np.zeros((200, 200, 3), dtype=np.uint8) + mask = np.zeros((200, 200), dtype=np.uint8) + # Place a 20x20 object in the center + mask[90:110, 90:110] = 1 + image[90:110, 90:110] = 128 + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + # Object should still be present in the crop + assert crop_mask.sum() > 0 + # Crop should be larger than the object itself + assert crop_image.shape[0] >= 20 + assert crop_image.shape[1] >= 20 + + def test_small_object_gets_more_context(self): + """A small object should get a proportionally larger crop.""" + image = np.zeros((500, 500, 3), dtype=np.uint8) + mask = np.zeros((500, 500), dtype=np.uint8) + # Small 10x10 object + mask[245:255, 245:255] = 1 + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + # With 100 pixels of object area and 0.20 occupancy, + # crop area should be ~500, so side ~22 + obj_area = 100 + expected_crop_area = obj_area / 0.20 + expected_side = int(np.sqrt(expected_crop_area)) + assert crop_image.shape[0] >= expected_side - 2 # Allow small margin + 
+ def test_object_at_edge(self): + """Object near image edge should still produce valid crop.""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + mask = np.zeros((100, 100), dtype=np.uint8) + # Object at top-left corner + mask[0:10, 0:10] = 1 + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + # Should not crash and mask should be preserved + assert crop_mask.sum() > 0 + assert crop_image.shape[0] > 0 + assert crop_image.shape[1] > 0 + + def test_empty_mask_returns_original(self): + """Empty mask should return the original image and mask.""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + mask = np.zeros((100, 100), dtype=np.uint8) + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + assert np.array_equal(crop_image, image) + assert np.array_equal(crop_mask, mask) + + def test_large_object_respects_bounding_box(self): + """Crop should be at least as large as the object bounding box.""" + image = np.zeros((200, 200, 3), dtype=np.uint8) + mask = np.zeros((200, 200), dtype=np.uint8) + # Large 80x80 object + mask[60:140, 60:140] = 1 + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + # Crop must encompass the full object + assert crop_image.shape[0] >= 80 + assert crop_image.shape[1] >= 80 + # And the mask pixels should all be within the crop + assert crop_mask.sum() == mask.sum() + + +class TestPoolFeaturesWithMask: + """Tests for the weighted feature pooling.""" + + def test_basic_pooling(self): + """Pooling with full mask should equal global average.""" + import torch + + C, H, W = 32, 8, 8 + features = torch.ones(1, C, H, W) + mask = np.ones((H, W), dtype=np.float32) + + result = pool_features_with_mask(features, mask, H, W) + + assert result.shape == (C,) + # With all-ones features and all-ones mask, result should be all ones + assert torch.allclose(result, torch.ones(C), atol=1e-3) + + def test_masked_region_pooling(self): + 
"""Pooling should focus on masked region.""" + import torch + + C, H, W = 16, 8, 8 + features = torch.zeros(1, C, H, W) + # Set top-left quadrant to 1.0 + features[:, :, :4, :4] = 1.0 + + # Mask only the top-left quadrant + mask = np.zeros((H, W), dtype=np.float32) + mask[:4, :4] = 1.0 + + result = pool_features_with_mask(features, mask, H, W) + + # Should be close to 1.0 since we're pooling from the region with value 1 + assert torch.allclose(result, torch.ones(C), atol=0.2) + + def test_empty_mask_fallback(self): + """Empty mask should fall back to global average pooling.""" + import torch + + C, H, W = 16, 8, 8 + features = torch.ones(1, C, H, W) * 3.0 + mask = np.zeros((H, W), dtype=np.float32) + + result = pool_features_with_mask(features, mask, H, W) + + # Should fall back to mean pooling + assert result.shape == (C,) + assert torch.allclose(result, torch.ones(C) * 3.0, atol=1e-3) + + def test_mask_upscaling(self): + """Test that mask is properly resized to match feature dimensions.""" + import torch + + C = 16 + feat_h, feat_w = 8, 8 + features = torch.ones(1, C, feat_h, feat_w) + + # Mask at different resolution than features + mask = np.ones((32, 32), dtype=np.float32) + + result = pool_features_with_mask(features, mask, feat_h, feat_w) + + assert result.shape == (C,) + assert torch.allclose(result, torch.ones(C), atol=1e-3) + + +class TestEnsureRgb: + """Tests for image format normalization.""" + + def test_grayscale_to_rgb(self): + image = np.zeros((100, 100), dtype=np.uint8) + result = ensure_rgb(image) + assert result.shape == (100, 100, 3) + assert result.dtype == np.uint8 + + def test_single_channel_to_rgb(self): + image = np.zeros((100, 100, 1), dtype=np.uint8) + result = ensure_rgb(image) + assert result.shape == (100, 100, 3) + + def test_rgba_to_rgb(self): + image = np.zeros((100, 100, 4), dtype=np.uint8) + result = ensure_rgb(image) + assert result.shape == (100, 100, 3) + + def test_float_0_1_to_uint8(self): + image = np.ones((100, 100, 3), 
dtype=np.float32) * 0.5 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert result.max() == 127 or result.max() == 128 # rounding + + def test_float_0_255_to_uint8(self): + image = np.ones((100, 100, 3), dtype=np.float32) * 200.0 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert result.max() == 200 + + def test_uint16_to_uint8(self): + image = np.ones((100, 100, 3), dtype=np.uint16) * 512 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert result.max() == 2 # 512 / 256 = 2 + + def test_rgb_uint8_passthrough(self): + image = np.ones((100, 100, 3), dtype=np.uint8) * 42 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert np.array_equal(result, image) + + +class TestAnnotationToMask: + """Tests for converting polygon annotations to binary masks.""" + + def test_square_annotation(self): + annotation = { + 'coordinates': [ + {'x': 10, 'y': 10}, + {'x': 10, 'y': 20}, + {'x': 20, 'y': 20}, + {'x': 20, 'y': 10}, + ] + } + mask = annotation_to_mask(annotation, (30, 30)) + assert mask.shape == (30, 30) + assert mask.sum() > 0 + # Center of the square should be 1 + assert mask[15, 15] == 1 + # Outside should be 0 + assert mask[0, 0] == 0 + + def test_mask_matches_image_shape(self): + annotation = { + 'coordinates': [ + {'x': 5, 'y': 5}, + {'x': 5, 'y': 15}, + {'x': 15, 'y': 15}, + {'x': 15, 'y': 5}, + ] + } + mask = annotation_to_mask(annotation, (100, 200)) + assert mask.shape == (100, 200) + + +class TestInterface: + """Test the interface function.""" + + @patch('annotation_client.workers.UPennContrastWorkerPreviewClient') + def test_interface_sets_all_fields(self, mock_client_class): + mock_client = mock_client_class.return_value + + # Mock the checkpoint directory + with patch('os.listdir', return_value=['sam2.1_hiera_small.pt', 'sam2.1_hiera_large.pt']): + interface('test_image', 'http://test-api', 'test-token') + + mock_client.setWorkerImageInterface.assert_called_once() + interface_data = 
mock_client.setWorkerImageInterface.call_args[0][1] + + # Verify all expected fields are present + expected_fields = [ + 'Training Tag', 'Batch XY', 'Batch Z', 'Batch Time', + 'Model', 'Similarity Threshold', 'Target Occupancy', + 'Points per side', 'Min Mask Area', 'Max Mask Area', 'Smoothing', + ] + for field in expected_fields: + assert field in interface_data, f"Missing interface field: {field}" + + # Verify types + assert interface_data['Training Tag']['type'] == 'tags' + assert interface_data['Model']['type'] == 'select' + assert interface_data['Similarity Threshold']['type'] == 'number' + assert interface_data['Target Occupancy']['type'] == 'number' + assert interface_data['Points per side']['type'] == 'number' + assert interface_data['Smoothing']['type'] == 'number' + + # Verify defaults + assert interface_data['Similarity Threshold']['default'] == 0.5 + assert interface_data['Target Occupancy']['default'] == 0.20 + assert interface_data['Points per side']['default'] == 32 + assert interface_data['Model']['default'] == 'sam2.1_hiera_small.pt' + + @patch('annotation_client.workers.UPennContrastWorkerPreviewClient') + def test_interface_with_no_models(self, mock_client_class): + mock_client = mock_client_class.return_value + + with patch('os.listdir', return_value=[]): + interface('test_image', 'http://test-api', 'test-token') + + interface_data = mock_client.setWorkerImageInterface.call_args[0][1] + assert interface_data['Model']['default'] is None + assert interface_data['Model']['items'] == []
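
The prototype-matching core of `compute()` (masked pooling of a feature map, averaging into a prototype, then cosine-similarity thresholding) can be sketched without SAM2 or torch. This is a minimal numpy illustration of the idea, not worker code — `masked_pool` and `cosine` are hypothetical stand-ins for `pool_features_with_mask` and `F.cosine_similarity`:

```python
import numpy as np

def masked_pool(features, mask):
    """Average (H, W, C) features over mask pixels; global mean if mask is empty."""
    if mask.sum() == 0:
        return features.reshape(-1, features.shape[-1]).mean(axis=0)
    return features[mask > 0].mean(axis=0)

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

rng = np.random.default_rng(0)
feats = rng.normal(size=(8, 8, 16))          # stand-in for a SAM2 feature map
train_mask = np.zeros((8, 8))
train_mask[:4, :4] = 1

# Prototype = average of pooled training vectors (a single example here)
prototype = masked_pool(feats, train_mask)

# A candidate pooled over the same region matches exactly ...
same = cosine(prototype, masked_pool(feats, train_mask))
# ... while a disjoint region generally scores lower
other_mask = np.zeros((8, 8))
other_mask[4:, 4:] = 1
other = cosine(prototype, masked_pool(feats, other_mask))
```

In the worker, `similarity_threshold` plays the role of the cutoff applied to `same`/`other`: candidates whose pooled vector scores below it are discarded before the contour-to-polygon step.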