-
Notifications
You must be signed in to change notification settings - Fork 0
Virchow2 model #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Jurgee
wants to merge
40
commits into
main
Choose a base branch
from
feature/virchow2-model
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
a167a22
feat: tensorrt support
matejpekar 1d3310f
fix: remove flush
matejpekar 4e27a48
feat: add docker files for cpu/gpu
fd3154d
feat: add PVC for TensorRT
eaac807
feat: add support of TensorRT for models
46fe8b1
feat: add TensorRT cache to workers
f07723e
add Jiri as coauthor
9d6e265
fix: remove gpu number from serve.deployment in code
e7612f9
fix: warning suppress
5945f10
feat: add jobs to download virchow2
8ef4cc5
feat: add model provider for hf
a6c427e
feat: add pvc for huggingface
27c7801
feat: add virchow2 model
e5d84cb
fix
e1fcb6c
fix: fine tune
e7ac073
feat: add into dockerfile
51f07a4
fix: remove installs from model
178f226
fix: based on official docs
5cc123f
fix
964114e
fix: remove comment
181f79e
chore: update docker gpu file
Jurgee 156a7d4
feat: optimalize virchow2 deployment
Jurgee 57176d6
fix: remove hf token, create new secret
Jurgee fe51ee2
fix
Jurgee 9f37d03
Merge branch 'main' into feature/virchow2-model
Jurgee 210c7e6
fix: remove intra threads
Jurgee bf7cff1
fix: lint
Jurgee 6813264
fix: remove duplicity
Jurgee 7510c9f
fixes
Jurgee 2eae503
docker files
Jurgee c5095bd
fix: docker
Jurgee e94baec
chore: new docker image
Jurgee 7cdd290
chore: cpu docker
Jurgee 8cce2cb
fix
Jurgee 7e329a8
final changes
Jurgee 8dfea82
fix: usage of master branch
Jurgee e6f8603
Potential fix for pull request finding
Jurgee fb646c4
Potential fix for pull request finding
Jurgee b2d083c
Potential fix for pull request finding
Jurgee bfd90a9
Potential fix for pull request finding
Jurgee File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import os | ||
|
|
||
| from huggingface_hub import login, snapshot_download | ||
|
|
||
|
|
||
| HF_TOKEN = os.environ.get("HF_TOKEN") | ||
| CACHE_DIR = "/mnt/huggingface_cache" | ||
| MODEL_ID = "paige-ai/Virchow2" | ||
|
|
||
| os.environ["HF_HOME"] = CACHE_DIR | ||
| os.makedirs(CACHE_DIR, exist_ok=True) | ||
|
|
||
| print(f"Starting download for {MODEL_ID} to {CACHE_DIR}") | ||
|
|
||
| if HF_TOKEN: | ||
| print("Logging in to Hugging Face...") | ||
| login(token=HF_TOKEN) | ||
| else: | ||
| print("No HF_TOKEN provided! Download might fail for gated models.") | ||
|
|
||
| print("Downloading model snapshot...") | ||
| try: | ||
| path = snapshot_download( | ||
| repo_id=MODEL_ID, | ||
| resume_download=True, | ||
| local_files_only=False, | ||
| ) | ||
| print(f"Model downloaded to: {path}") | ||
|
|
||
| print("Verifying model files exist...") | ||
| import timm | ||
|
|
||
| try: | ||
| model = timm.create_model( | ||
| f"hf-hub:{MODEL_ID}", | ||
| pretrained=True, | ||
| num_classes=0, | ||
| ) | ||
| print(f"Model successfully loaded! Type: {type(model).__name__}") | ||
| del model # Free memory | ||
| except Exception as e: | ||
| print(f"Verification warning: {e}") | ||
|
|
||
| except Exception as e: | ||
| print(f"Download failed: {e}") | ||
| exit(1) | ||
|
|
||
| print("DONE. Model is cached and ready for offline use.") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| apiVersion: batch/v1 | ||
| kind: Job | ||
| metadata: | ||
| name: virchow-downloader | ||
| namespace: rationai-jobs-ns | ||
| spec: | ||
| template: | ||
| spec: | ||
| securityContext: | ||
| runAsNonRoot: true | ||
| runAsUser: 1000 | ||
| fsGroup: 1000 | ||
| seccompProfile: | ||
| type: RuntimeDefault | ||
| containers: | ||
| - name: downloader | ||
| image: python:3.12 | ||
| resources: | ||
| requests: | ||
| memory: "4Gi" | ||
| cpu: "1" | ||
| limits: | ||
| memory: "4Gi" | ||
| cpu: "2" | ||
| securityContext: | ||
| allowPrivilegeEscalation: false | ||
| capabilities: | ||
| drop: ["ALL"] | ||
| command: ["/bin/bash", "-c"] | ||
| args: | ||
| - | | ||
| pip install --user --no-cache-dir huggingface_hub transformers torch timm | ||
| python3 /mnt/scripts/download_virchow2.py | ||
| env: | ||
| - name: HOME | ||
| value: /tmp | ||
| - name: HF_TOKEN | ||
| valueFrom: | ||
| secretKeyRef: | ||
| name: huggingface-secret | ||
| key: token | ||
| - name: HTTPS_PROXY | ||
| value: "http://proxy.ics.muni.cz:3128" | ||
| - name: TORCH_HOME | ||
| value: /tmp/torch | ||
| - name: TORCHINDUCTOR_CACHE_DIR | ||
| value: /tmp/torch/inductor_cache | ||
| volumeMounts: | ||
| - name: huggingface-cache | ||
| mountPath: /mnt/huggingface_cache | ||
| - name: scripts | ||
| mountPath: /mnt/scripts | ||
| - name: temp | ||
| mountPath: /tmp | ||
| restartPolicy: Never | ||
| volumes: | ||
| - name: huggingface-cache | ||
| persistentVolumeClaim: | ||
| claimName: huggingface-cache-pvc | ||
| - name: scripts | ||
| configMap: | ||
| name: downloader-script | ||
| defaultMode: 0755 | ||
| - name: temp | ||
| emptyDir: {} | ||
Jurgee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| import asyncio | ||
| from typing import Any, TypedDict | ||
|
|
||
| import numpy as np | ||
| from fastapi import FastAPI, Request | ||
| from numpy.typing import NDArray | ||
| from ray import serve | ||
|
|
||
|
|
||
| class Config(TypedDict): | ||
| tile_size: int | ||
| model: dict[str, Any] | ||
| max_batch_size: int | ||
| batch_wait_timeout_s: float | ||
|
|
||
|
|
||
| fastapi = FastAPI() | ||
|
|
||
|
|
||
| @serve.deployment(num_replicas="auto") | ||
| @serve.ingress(fastapi) | ||
| class Virchow2: | ||
| """Virchow2 foundation model for pathology.""" | ||
|
|
||
| def __init__(self) -> None: | ||
| import os | ||
|
|
||
| import lz4.frame | ||
|
|
||
| # Enforce offline mode for timm/huggingface_hub | ||
| os.environ["HF_HUB_OFFLINE"] = "1" | ||
|
|
||
| import torch | ||
|
|
||
| self.torch = torch | ||
| self.lz4 = lz4.frame | ||
| self.model: Any = None | ||
| self.transforms: Any = None | ||
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| self.tile_size: int = 0 | ||
|
|
||
| def reconfigure(self, config: Config) -> None: | ||
| import importlib | ||
| import logging | ||
|
|
||
| import timm | ||
| from timm.data.config import resolve_data_config | ||
| from timm.data.transforms_factory import create_transform | ||
| from timm.layers.mlp import SwiGLUPacked | ||
|
|
||
| torch = self.torch | ||
|
|
||
| logger = logging.getLogger("ray.serve") | ||
| self.tile_size = config["tile_size"] | ||
|
|
||
| # Load model using the provider | ||
| module_path, attr_name = config["model"].pop("_target_").split(":") | ||
| provider = getattr(importlib.import_module(module_path), attr_name) | ||
| repo_id = config["model"]["repo_id"] | ||
|
|
||
| logger.info(f"Loading Virchow2 model from {repo_id}...") | ||
| provider(**config["model"]) | ||
|
|
||
| # Load model with official architecture | ||
| self.model = timm.create_model( | ||
| f"hf-hub:{repo_id}", | ||
| pretrained=True, | ||
| num_classes=0, | ||
| mlp_layer=SwiGLUPacked, | ||
| act_layer=torch.nn.SiLU, | ||
| ) | ||
| self.model = self.model.to(self.device).eval() | ||
|
|
||
| # Get transforms from model config | ||
| self.transforms = create_transform( | ||
| **resolve_data_config(self.model.pretrained_cfg, model=self.model) | ||
| ) | ||
|
|
||
| logger.info("Virchow2 model loaded and moved to GPU.") | ||
|
|
||
| self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] | ||
| self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] | ||
|
|
||
| @serve.batch | ||
| async def predict(self, images: list[NDArray[np.uint8]]) -> list[list[float]]: | ||
| from PIL import Image | ||
|
|
||
| if self.model is None or self.transforms is None: | ||
| raise RuntimeError("Model or transforms not initialized") | ||
|
|
||
| torch = self.torch | ||
|
|
||
| pil_images = [Image.fromarray(img) for img in images] | ||
| tensors = torch.stack([self.transforms(img) for img in pil_images]).to( | ||
| self.device | ||
| ) | ||
|
|
||
| device_type = self.device.type | ||
| autocast_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16 | ||
|
|
||
| with ( | ||
| torch.inference_mode(), | ||
| torch.autocast(device_type=device_type, dtype=autocast_dtype), | ||
| ): | ||
| output = self.model(tensors) | ||
|
|
||
| # Extract embeddings as per official model card | ||
| class_token = output[:, 0] # CLS token: batch x 1280 | ||
| patch_tokens = output[ | ||
| :, 5: | ||
| ] # Skip register tokens (1-4): batch x 256 x 1280 | ||
|
|
||
| # Concatenate CLS token with mean of patch tokens | ||
| embedding = torch.cat( | ||
| [class_token, patch_tokens.mean(dim=1)], dim=-1 | ||
| ) # batch x 2560 | ||
|
|
||
| return embedding.half().cpu().tolist() | ||
|
|
||
| @fastapi.post("/") | ||
| async def root(self, request: Request) -> list[float]: | ||
| data = await asyncio.to_thread(self.lz4.decompress, await request.body()) | ||
|
|
||
| # Reshape to (height, width, channels) - RGB image | ||
| image = np.frombuffer(data, dtype=np.uint8).reshape( | ||
| self.tile_size, self.tile_size, 3 | ||
| ) | ||
|
|
||
| results = await self.predict(image) | ||
| return results | ||
|
|
||
|
|
||
| app = Virchow2.bind() # type: ignore[attr-defined] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| apiVersion: v1 | ||
| kind: PersistentVolumeClaim | ||
| metadata: | ||
| name: huggingface-cache-pvc | ||
| namespace: rationai-jobs-ns | ||
Jurgee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| spec: | ||
| accessModes: | ||
| - ReadWriteMany | ||
| resources: | ||
| requests: | ||
| storage: 15Gi | ||
| storageClassName: nfs-csi | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.