Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
a167a22
feat: tensorrt support
matejpekar Jan 17, 2026
1d3310f
fix: remove flush
matejpekar Jan 17, 2026
4e27a48
feat: add docker files for cpu/gpu
Feb 8, 2026
fd3154d
feat: add PVC for TensorRT
Feb 8, 2026
eaac807
feat: add support of TensorRT for models
Feb 8, 2026
46fe8b1
feat: add TensorRT cache to workers
Feb 8, 2026
f07723e
add Jiri as coauthor
Feb 8, 2026
9d6e265
fix: remove gpu number from serve.deployment in code
Feb 8, 2026
e7612f9
fix: warning suppress
Feb 9, 2026
5945f10
feat: add jobs to download virchow2
Feb 10, 2026
8ef4cc5
feat: add model provider for hf
Feb 10, 2026
a6c427e
feat: add pvc for huggingface
Feb 10, 2026
27c7801
feat: add virchow2 model
Feb 10, 2026
e5d84cb
fix
Feb 10, 2026
e1fcb6c
fix: fine tune
Feb 10, 2026
e7ac073
feat: add into dockerfile
Feb 14, 2026
51f07a4
fix: remove installs from model
Feb 14, 2026
178f226
fix: based on official docs
Feb 14, 2026
5cc123f
fix
Feb 14, 2026
964114e
fix: remove comment
Feb 14, 2026
181f79e
chore: update docker gpu file
Jurgee Mar 13, 2026
156a7d4
feat: optimize virchow2 deployment
Jurgee Mar 13, 2026
57176d6
fix: remove hf token, create new secret
Jurgee Mar 14, 2026
fe51ee2
fix
Jurgee Mar 14, 2026
9f37d03
Merge branch 'main' into feature/virchow2-model
Jurgee Mar 14, 2026
210c7e6
fix: remove intra threads
Jurgee Mar 14, 2026
bf7cff1
fix: lint
Jurgee Mar 14, 2026
6813264
fix: remove duplication
Jurgee Mar 14, 2026
7510c9f
fixes
Jurgee Mar 14, 2026
2eae503
docker files
Jurgee Mar 14, 2026
c5095bd
fix: docker
Jurgee Mar 14, 2026
e94baec
chore: new docker image
Jurgee Mar 14, 2026
7cdd290
chore: cpu docker
Jurgee Mar 14, 2026
8cce2cb
fix
Jurgee Mar 14, 2026
7e329a8
final changes
Jurgee Mar 14, 2026
8dfea82
fix: usage of master branch
Jurgee Mar 15, 2026
e6f8603
Potential fix for pull request finding
Jurgee Mar 15, 2026
fb646c4
Potential fix for pull request finding
Jurgee Mar 15, 2026
b2d083c
Potential fix for pull request finding
Jurgee Mar 15, 2026
bfd90a9
Potential fix for pull request finding
Jurgee Mar 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docker/Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@ RUN sudo apt-get update && sudo apt-get -y upgrade && \
# Cleanup
RUN sudo apt-get remove -y --purge systemd systemd-sysv && sudo apt-get autoremove --purge -y && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir onnxruntime lz4 ratiopath "mlflow<3.0"
RUN pip install --no-cache-dir \
onnxruntime lz4 ratiopath "mlflow<3.0"
12 changes: 11 additions & 1 deletion docker/Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,14 @@ RUN sudo sh -c 'echo "/usr/local/lib" > /etc/ld.so.conf.d/custom-libs.conf' && \
sudo sh -c 'echo "/home/ray/anaconda3/lib/python3.12/site-packages/nvidia/cudnn/lib" > /etc/ld.so.conf.d/nvidia-libs.conf' && \
sudo ldconfig

RUN pip install --no-cache-dir onnxruntime-gpu tensorrt lz4 ratiopath "mlflow<3.0"
RUN pip install --no-cache-dir \
--extra-index-url https://pypi.nvidia.com \
onnxruntime-gpu tensorrt-cu12==10.3.0 lz4 ratiopath "mlflow<3.0"

RUN pip install --no-cache-dir \
torch==2.4.0+cu121 torchvision==0.19.0+cu121 \
--index-url https://download.pytorch.org/whl/cu121

RUN pip install --no-cache-dir \
"timm>=1.0.0" \
"huggingface-hub>=0.23.0"
48 changes: 48 additions & 0 deletions misc/virchow2_downloader/download_virchow2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Download the Virchow2 model snapshot into the shared Hugging Face cache.

Intended to run as a Kubernetes Job with the cache PVC mounted at
/mnt/huggingface_cache; the serving deployment then loads the model with
HF_HUB_OFFLINE=1 from that cache.
"""

import os
import sys

# Shared cache location (PVC mount) and the gated model to fetch.
CACHE_DIR = "/mnt/huggingface_cache"
MODEL_ID = "paige-ai/Virchow2"

# HF_HOME must be exported *before* huggingface_hub is imported: the library
# resolves its default cache directory at import time, so setting it after the
# import would leave downloads in the default ~/.cache location instead of the
# mounted PVC.
os.environ["HF_HOME"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

from huggingface_hub import login, snapshot_download  # noqa: E402

HF_TOKEN = os.environ.get("HF_TOKEN")

print(f"Starting download for {MODEL_ID} to {CACHE_DIR}")

if HF_TOKEN:
    print("Logging in to Hugging Face...")
    login(token=HF_TOKEN)
else:
    print("No HF_TOKEN provided! Download might fail for gated models.")

print("Downloading model snapshot...")
try:
    # NOTE: resume_download is deprecated (and a no-op) in recent
    # huggingface_hub versions — downloads always resume — so it is
    # intentionally not passed here.
    path = snapshot_download(
        repo_id=MODEL_ID,
        local_files_only=False,
    )
    print(f"Model downloaded to: {path}")

    print("Verifying model files exist...")
    # Imported lazily: timm is only needed for this optional sanity check.
    import timm

    try:
        model = timm.create_model(
            f"hf-hub:{MODEL_ID}",
            pretrained=True,
            num_classes=0,
        )
        print(f"Model successfully loaded! Type: {type(model).__name__}")
        del model  # Free memory
    except Exception as e:
        # Verification is best-effort: a failure here does not invalidate the
        # downloaded snapshot, so only warn instead of failing the Job.
        print(f"Verification warning: {e}")

except Exception as e:
    print(f"Download failed: {e}")
    # sys.exit instead of the builtin exit(): the builtin comes from the
    # `site` module and is not guaranteed in all runtime configurations.
    sys.exit(1)

print("DONE. Model is cached and ready for offline use.")
65 changes: 65 additions & 0 deletions misc/virchow2_downloader/virchow2_downloader_job.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Kubernetes Job that downloads the Virchow2 model snapshot into the shared
# Hugging Face cache PVC so serving pods can later load it offline.
apiVersion: batch/v1
kind: Job
metadata:
  name: virchow-downloader
  namespace: rationai-jobs-ns
spec:
  template:
    spec:
      # Run unprivileged with a restricted seccomp profile (cluster policy).
      securityContext:
        runAsNonRoot: true
        runAsUser: 1000
        fsGroup: 1000
        seccompProfile:
          type: RuntimeDefault
      containers:
        - name: downloader
          image: python:3.12
          resources:
            requests:
              memory: "4Gi"
              cpu: "1"
            limits:
              memory: "4Gi"
              cpu: "2"
          securityContext:
            allowPrivilegeEscalation: false
            capabilities:
              drop: ["ALL"]
          command: ["/bin/bash", "-c"]
          args:
            # Install deps at runtime (--user because the container runs as a
            # non-root user), then run the downloader script from the ConfigMap.
            - |
              pip install --user --no-cache-dir huggingface_hub transformers torch timm
              python3 /mnt/scripts/download_virchow2.py
          env:
            # HOME=/tmp so pip --user and library caches write to the
            # writable emptyDir instead of the read-only root filesystem.
            - name: HOME
              value: /tmp
            # Token for the gated paige-ai/Virchow2 repository.
            - name: HF_TOKEN
              valueFrom:
                secretKeyRef:
                  name: huggingface-secret
                  key: token
            # NOTE(review): cluster-specific egress proxy — confirm this is
            # still required / correct for the target cluster.
            - name: HTTPS_PROXY
              value: "http://proxy.ics.muni.cz:3128"
            - name: TORCH_HOME
              value: /tmp/torch
            - name: TORCHINDUCTOR_CACHE_DIR
              value: /tmp/torch/inductor_cache
          volumeMounts:
            # Persistent model cache shared with the serving deployment.
            - name: huggingface-cache
              mountPath: /mnt/huggingface_cache
            # Downloader script delivered via ConfigMap.
            - name: scripts
              mountPath: /mnt/scripts
            - name: temp
              mountPath: /tmp
      restartPolicy: Never
      volumes:
        - name: huggingface-cache
          persistentVolumeClaim:
            claimName: huggingface-cache-pvc
        - name: scripts
          configMap:
            name: downloader-script
            defaultMode: 0755
        - name: temp
          emptyDir: {}
9 changes: 5 additions & 4 deletions models/binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class Config(TypedDict):
model: dict[str, Any]
max_batch_size: int
batch_wait_timeout_s: float
intra_op_num_threads: int
trt_cache_path: str


Expand Down Expand Up @@ -69,7 +68,7 @@ def reconfigure(self, config: Config) -> None:
"trt_engine_cache_path": cache_path,
"trt_max_workspace_size": config.get(
"trt_max_workspace_size", 8 * 1024 * 1024 * 1024
), # type: ignore[typeddict-item]
),
"trt_builder_optimization_level": 5,
"trt_timing_cache_enable": True,
"trt_profile_min_shapes": min_shape,
Expand All @@ -79,7 +78,6 @@ def reconfigure(self, config: Config) -> None:

# Configure ONNX Runtime session
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = config["intra_op_num_threads"]
sess_options.inter_op_num_threads = 1

# Enable all graph optimizations (constant folding, node fusion, etc.) for maximum inference performance.
Expand Down Expand Up @@ -118,7 +116,10 @@ async def predict(self, images: list[NDArray[np.uint8]]) -> list[float]:
"""Run inference on a batch of images."""
batch = np.stack(images, axis=0, dtype=np.uint8)

outputs = self.session.run([self.output_name], {self.input_name: batch})
outputs = self.session.run(
[self.output_name],
{self.input_name: batch},
)

return outputs[0].flatten().tolist() # pyright: ignore[reportAttributeAccessIssue]

Expand Down
4 changes: 1 addition & 3 deletions models/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class Config(TypedDict):
model: dict[str, Any]
max_batch_size: int
batch_wait_timeout_s: float
intra_op_num_threads: int
trt_cache_path: str


Expand Down Expand Up @@ -65,7 +64,7 @@ def reconfigure(self, config: Config) -> None:
"trt_engine_cache_path": cache_path,
"trt_max_workspace_size": config.get(
"trt_max_workspace_size", 8 * 1024 * 1024 * 1024
), # type: ignore[typeddict-item]
),
"trt_builder_optimization_level": 5,
"trt_timing_cache_enable": True,
"trt_profile_min_shapes": min_shape,
Expand All @@ -75,7 +74,6 @@ def reconfigure(self, config: Config) -> None:

# Configure ONNX Runtime session
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = config["intra_op_num_threads"]
sess_options.inter_op_num_threads = 1

# Enable all graph optimizations (constant folding, node fusion, etc.) for maximum inference performance.
Expand Down
133 changes: 133 additions & 0 deletions models/virchow2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import asyncio
from typing import Any, TypedDict

import numpy as np
from fastapi import FastAPI, Request
from numpy.typing import NDArray
from ray import serve


class Config(TypedDict):
    """Runtime configuration delivered to Virchow2.reconfigure (Ray Serve user_config)."""

    # Edge length in pixels of the square RGB tiles posted to the endpoint.
    tile_size: int
    # Model-provider spec; expected to hold a "_target_" of the form
    # "module.path:attr" plus provider kwargs such as "repo_id".
    model: dict[str, Any]
    # Maximum number of requests fused into one forward pass by @serve.batch.
    max_batch_size: int
    # How long @serve.batch waits to fill a batch before running a partial one.
    batch_wait_timeout_s: float


fastapi = FastAPI()  # HTTP ingress app; routes are attached via @serve.ingress


@serve.deployment(num_replicas="auto")
@serve.ingress(fastapi)
class Virchow2:
    """Virchow2 foundation model for pathology."""

    def __init__(self) -> None:
        """Prepare lazy handles; the actual model is loaded in reconfigure()."""
        import os

        import lz4.frame

        # Enforce offline mode for timm/huggingface_hub
        os.environ["HF_HUB_OFFLINE"] = "1"

        import torch

        # Heavy modules are imported inside the replica process and kept as
        # instance attributes so importing this module stays cheap.
        self.torch = torch
        self.lz4 = lz4.frame
        self.model: Any = None  # set in reconfigure()
        self.transforms: Any = None  # set in reconfigure()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tile_size: int = 0  # set from config in reconfigure()

    def reconfigure(self, config: Config) -> None:
        """Load the model and its transforms; invoked by Ray Serve with user_config."""
        import importlib
        import logging

        import timm
        from timm.data.config import resolve_data_config
        from timm.data.transforms_factory import create_transform
        from timm.layers.mlp import SwiGLUPacked

        torch = self.torch

        logger = logging.getLogger("ray.serve")
        self.tile_size = config["tile_size"]

        # Load model using the provider
        # NOTE(review): pop() mutates config["model"]; assumes Serve hands a
        # fresh config dict on every reconfigure call — confirm.
        module_path, attr_name = config["model"].pop("_target_").split(":")
        provider = getattr(importlib.import_module(module_path), attr_name)
        repo_id = config["model"]["repo_id"]

        logger.info(f"Loading Virchow2 model from {repo_id}...")
        provider(**config["model"])

        # Load model with official architecture
        self.model = timm.create_model(
            f"hf-hub:{repo_id}",
            pretrained=True,
            num_classes=0,
            mlp_layer=SwiGLUPacked,
            act_layer=torch.nn.SiLU,
        )
        self.model = self.model.to(self.device).eval()

        # Get transforms from model config
        self.transforms = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
        )

        logger.info("Virchow2 model loaded and moved to GPU.")

        # Propagate batching knobs to the @serve.batch-wrapped predict method.
        self.predict.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
        self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]

    @serve.batch
    async def predict(self, images: list[NDArray[np.uint8]]) -> list[list[float]]:
        """Embed a batch of RGB tiles; returns one embedding per input image."""
        from PIL import Image

        if self.model is None or self.transforms is None:
            raise RuntimeError("Model or transforms not initialized")

        torch = self.torch

        pil_images = [Image.fromarray(img) for img in images]
        tensors = torch.stack([self.transforms(img) for img in pil_images]).to(
            self.device
        )

        device_type = self.device.type
        # float16 autocast is CUDA-only; fall back to bfloat16 on CPU.
        autocast_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

        with (
            torch.inference_mode(),
            torch.autocast(device_type=device_type, dtype=autocast_dtype),
        ):
            output = self.model(tensors)

        # Extract embeddings as per official model card
        class_token = output[:, 0]  # CLS token: batch x 1280
        patch_tokens = output[
            :, 5:
        ]  # Skip register tokens (1-4): batch x 256 x 1280

        # Concatenate CLS token with mean of patch tokens
        embedding = torch.cat(
            [class_token, patch_tokens.mean(dim=1)], dim=-1
        )  # batch x 2560

        # Cast to half before serialization to shrink the response payload.
        return embedding.half().cpu().tolist()

    @fastapi.post("/")
    async def root(self, request: Request) -> list[float]:
        """Accept one LZ4-compressed raw RGB tile and return its embedding."""
        # Decompress off the event loop; body is raw uint8 pixel data.
        data = await asyncio.to_thread(self.lz4.decompress, await request.body())

        # Reshape to (height, width, channels) - RGB image
        image = np.frombuffer(data, dtype=np.uint8).reshape(
            self.tile_size, self.tile_size, 3
        )

        # @serve.batch fuses concurrent requests; each caller receives the
        # single embedding corresponding to its own image.
        results = await self.predict(image)
        return results


app = Virchow2.bind()  # type: ignore[attr-defined]
22 changes: 22 additions & 0 deletions providers/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,25 @@ def mlflow(artifact_uri: str) -> str:
import mlflow.artifacts

return mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri)


def huggingface(repo_id: str, filename: str | None = None) -> str:
    """Resolve a Hugging Face model to a local path from the shared cache.

    Args:
        repo_id: Repository id, e.g. "paige-ai/Virchow2".
        filename: Optional single file to resolve; when omitted the whole
            snapshot directory path is returned.

    Returns:
        Local filesystem path to the cached file or snapshot directory.

    Raises:
        Propagates huggingface_hub errors when the requested files are not
        already cached (local_files_only=True means no network access).
    """
    import os

    from huggingface_hub import hf_hub_download, snapshot_download

    hf_home = os.environ.get("HF_HOME", "/mnt/huggingface_cache")
    os.makedirs(hf_home, exist_ok=True)
    # Export for any child processes. huggingface_hub resolves its default
    # cache location at import time, so setting HF_HOME here would NOT affect
    # the already-imported library in this process — pass cache_dir explicitly
    # instead of relying on the env var. ($HF_HOME/hub is the standard layout
    # used by the downloader Job, which sets HF_HOME itself.)
    os.environ["HF_HOME"] = hf_home
    cache_dir = os.path.join(hf_home, "hub")

    if filename:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            local_files_only=True,
        )
    return snapshot_download(
        repo_id=repo_id,
        cache_dir=cache_dir,
        local_files_only=True,
    )
12 changes: 12 additions & 0 deletions pvc/huggingface-pvc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Shared cache for Hugging Face model snapshots, written by the downloader
# Job and read by the serving workers.
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: huggingface-cache-pvc
  namespace: rationai-jobs-ns
spec:
  accessModes:
    # ReadWriteMany: the downloader Job and multiple serving replicas mount
    # this volume concurrently.
    - ReadWriteMany
  resources:
    requests:
      storage: 15Gi
  storageClassName: nfs-csi
Loading