Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QDragEnterEvent, QDropEvent


STYLE_SHEET = """
QGroupBox {
font-weight: bold;
Expand Down Expand Up @@ -238,6 +237,15 @@ def patched_read_tile_region(tile_idx, y_slice, x_slice):
tf.refine_tile_positions_with_cross_correlation()
self.progress.emit(f"Found {len(tf.pairwise_metrics)} pairs")

# Estimate pixel size
try:
estimated_px, deviation = tf.estimate_pixel_size()
self.progress.emit(
f"Estimated pixel size: {estimated_px:.4f} µm ({deviation:+.1f}% from metadata)"
)
except ValueError as e:
self.progress.emit(f"Pixel size estimation skipped: {e}")

tf.optimize_shifts(
method="TWO_ROUND_ITERATIVE", rel_thresh=0.5, abs_thresh=2.0, iterative=True
)
Expand Down Expand Up @@ -344,6 +352,7 @@ def __init__(
registration_z=None,
registration_t=0,
registration_channel=0,
use_estimated_pixel_size=False,
):
super().__init__()
self.tiff_path = tiff_path
Expand All @@ -356,8 +365,24 @@ def __init__(
self.registration_z = registration_z
self.registration_t = registration_t
self.registration_channel = registration_channel
self.use_estimated_pixel_size = use_estimated_pixel_size
self.output_path = None

def _apply_estimated_pixel_size(self, tf, estimated_px: float, deviation: float) -> None:
"""Apply estimated pixel size if enabled and within reasonable bounds."""
if not self.use_estimated_pixel_size or abs(deviation) <= 1.0:
return

# Sanity check: estimated should be within 50% of original
ratio = estimated_px / tf._pixel_size[0]
if 0.5 < ratio < 2.0:
tf._pixel_size = (estimated_px, estimated_px)
self.progress.emit("Using estimated pixel size for stitching")
else:
self.progress.emit(
f"Warning: Estimated pixel size {estimated_px:.4f} is unreasonable, ignoring"
)

def run(self):
try:
from tilefusion import TileFusion
Expand Down Expand Up @@ -426,6 +451,16 @@ def run(self):
self.progress.emit(
f"Registration complete: {len(tf.pairwise_metrics)} pairs [{reg_time:.1f}s]"
)

# Estimate pixel size
try:
estimated_px, deviation = tf.estimate_pixel_size()
self.progress.emit(
f"Estimated pixel size: {estimated_px:.4f} µm ({deviation:+.1f}% from metadata)"
)
self._apply_estimated_pixel_size(tf, estimated_px, deviation)
except ValueError as e:
self.progress.emit(f"Could not estimate pixel size: {e}")
else:
tf.threshold = 1.0 # Skip registration
self.progress.emit("Using stage positions (no registration)")
Expand Down Expand Up @@ -963,6 +998,13 @@ def setup_ui(self):
blend_value_layout.addStretch()
settings_layout.addWidget(self.blend_value_widget)

self.use_estimated_px_checkbox = QCheckBox("Use estimated pixel size")
self.use_estimated_px_checkbox.setToolTip(
"Estimate pixel size from registration and use it for stitching"
)
self.use_estimated_px_checkbox.setChecked(False)
settings_layout.addWidget(self.use_estimated_px_checkbox)

layout.addWidget(settings_group)

# Run button
Expand Down Expand Up @@ -1364,6 +1406,7 @@ def run_stitching(self):
registration_z=registration_z,
registration_t=registration_t,
registration_channel=registration_channel,
use_estimated_pixel_size=self.use_estimated_px_checkbox.isChecked(),
)
self.worker.progress.connect(self.log)
self.worker.finished.connect(self.on_fusion_finished)
Expand Down
1 change: 1 addition & 0 deletions scripts/view_in_napari.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Simple script to view fused OME-Zarr in napari.
Works around napari-ome-zarr plugin issues with Zarr v3.
"""

import sys
from pathlib import Path

Expand Down
77 changes: 72 additions & 5 deletions src/tilefusion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ class TileFusion:
Channel index for registration.
multiscale_downsample : str
Either "stride" (default) or "block_mean" to control multiscale reduction.
registration_z : int, optional
Z-level to use for registration. If None, uses middle z-level.
registration_t : int
Timepoint to use for registration. Defaults to 0.
"""

def __init__(
Expand Down Expand Up @@ -461,7 +465,9 @@ def _update_profiles(self) -> None:
# I/O methods (delegate to format-specific loaders)
# -------------------------------------------------------------------------

def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -> np.ndarray:
def _read_tile(
self, tile_idx: int, z_level: Optional[int] = None, time_idx: Optional[int] = None
) -> np.ndarray:
"""Read a single tile from the input data (all channels)."""
if z_level is None:
z_level = self._registration_z # Default to registration z-level
Expand All @@ -471,7 +477,7 @@ def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -
if self._is_zarr_format:
zarr_ts = self._metadata["tensorstore"]
is_3d = self._metadata.get("is_3d", False)
tile = read_zarr_tile(zarr_ts, tile_idx, is_3d)
tile = read_zarr_tile(zarr_ts, tile_idx, is_3d, z_level=z_level, time_idx=time_idx)
elif self._is_individual_tiffs_format:
tile = read_individual_tiffs_tile(
self._metadata["image_folder"],
Expand Down Expand Up @@ -508,8 +514,8 @@ def _read_tile_region(
tile_idx: int,
y_slice: slice,
x_slice: slice,
z_level: int = None,
time_idx: int = None,
z_level: Optional[int] = None,
time_idx: Optional[int] = None,
) -> np.ndarray:
"""Read a region of a tile from the input data."""
if z_level is None:
Expand All @@ -521,7 +527,14 @@ def _read_tile_region(
zarr_ts = self._metadata["tensorstore"]
is_3d = self._metadata.get("is_3d", False)
region = read_zarr_region(
zarr_ts, tile_idx, y_slice, x_slice, self.channel_to_use, is_3d
zarr_ts,
tile_idx,
y_slice,
x_slice,
self.channel_to_use,
is_3d,
z_level=z_level,
time_idx=time_idx,
)
elif self._is_individual_tiffs_format:
region = read_individual_tiffs_region(
Expand Down Expand Up @@ -799,6 +812,60 @@ def read_patch(idx, y_bounds, x_bounds):

io_executor.shutdown(wait=True)

def estimate_pixel_size(self) -> Tuple[float, float]:
"""
Estimate pixel size from registration results.

Compares expected shifts (from stage positions / metadata pixel size)
with measured shifts (from cross-correlation) to estimate true pixel size.

Returns
-------
estimated_pixel_size : float
Estimated pixel size in same units as metadata (typically um).
deviation_percent : float
Percentage deviation from metadata: (estimated/metadata - 1) * 100

Raises
------
ValueError
If no valid pairwise metrics available.
"""
if not self.pairwise_metrics:
raise ValueError("No pairwise metrics available. Run registration first.")

ratios = []

for (i, j), (dy_measured, dx_measured, score) in self.pairwise_metrics.items():
# Get stage positions
pos_i = np.array(self._tile_positions[i])
pos_j = np.array(self._tile_positions[j])

# Expected shift in pixels = stage_distance / pixel_size
stage_diff = pos_j - pos_i # (dy, dx) in physical units
expected_dy = stage_diff[0] / self._pixel_size[0]
expected_dx = stage_diff[1] / self._pixel_size[1]

# Compute ratio for non-zero shifts (both expected and measured must be significant)
if abs(dx_measured) > 5 and abs(expected_dx) > 5: # Horizontal shift
ratio = expected_dx / dx_measured
ratios.append(ratio)
if abs(dy_measured) > 5 and abs(expected_dy) > 5: # Vertical shift
ratio = expected_dy / dy_measured
ratios.append(ratio)

if not ratios:
raise ValueError("No valid shift measurements for pixel size estimation.")

# Use median to filter outliers
median_ratio = float(np.median(ratios))

# Estimated pixel size (assume isotropic)
estimated = self._pixel_size[0] * median_ratio
deviation_percent = (median_ratio - 1.0) * 100.0

return estimated, deviation_percent

# -------------------------------------------------------------------------
# Optimization
# -------------------------------------------------------------------------
Expand Down
43 changes: 33 additions & 10 deletions src/tilefusion/io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def read_zarr_tile(
zarr_ts: ts.TensorStore,
tile_idx: int,
is_3d: bool = False,
z_level: int = None,
time_idx: int = 0,
) -> np.ndarray:
"""
Read all channels of a tile from Zarr format.
Expand All @@ -111,18 +113,27 @@ def read_zarr_tile(
tile_idx : int
Index of the tile.
is_3d : bool
If True, data is 3D and max projection is applied.
If True, data is 3D.
z_level : int, optional
Z-level to read. If None and is_3d, uses max projection.
time_idx : int
Timepoint to read. Defaults to 0.

Returns
-------
arr : ndarray of shape (C, Y, X)
Tile data as float32.
"""
if is_3d:
arr = zarr_ts[0, tile_idx, :, :, :, :].read().result()
arr = np.max(arr, axis=1) # Max projection along Z
if z_level is not None:
# Read specific z-level
arr = zarr_ts[time_idx, tile_idx, :, z_level, :, :].read().result()
else:
# Max projection along Z (legacy behavior)
arr = zarr_ts[time_idx, tile_idx, :, :, :, :].read().result()
arr = np.max(arr, axis=1)
else:
arr = zarr_ts[0, tile_idx, :, :, :].read().result()
arr = zarr_ts[time_idx, tile_idx, :, :, :].read().result()
return arr.astype(np.float32)


Expand All @@ -133,6 +144,8 @@ def read_zarr_region(
x_slice: slice,
channel_idx: int = 0,
is_3d: bool = False,
z_level: int = None,
time_idx: int = 0,
) -> np.ndarray:
"""
Read a region of a single channel from Zarr format.
Expand All @@ -149,20 +162,30 @@ def read_zarr_region(
Channel index.
is_3d : bool
If True, data is 3D.
z_level : int, optional
Z-level to read. If None and is_3d, uses max projection.
time_idx : int
Timepoint to read. Defaults to 0.

Returns
-------
arr : ndarray of shape (1, h, w)
Tile region as float32.
"""
if is_3d:
arr = zarr_ts[0, tile_idx, channel_idx, :, y_slice, x_slice].read().result()
arr = np.max(arr, axis=0)
arr = arr[np.newaxis, :, :]
if z_level is not None:
# Read specific z-level
arr = (
zarr_ts[time_idx, tile_idx, channel_idx, z_level, y_slice, x_slice].read().result()
)
else:
# Max projection along Z (legacy behavior)
arr = zarr_ts[time_idx, tile_idx, channel_idx, :, y_slice, x_slice].read().result()
arr = np.max(arr, axis=0)
else:
arr = zarr_ts[0, tile_idx, channel_idx, y_slice, x_slice].read().result()
arr = arr[np.newaxis, :, :]
return arr.astype(np.float32)
arr = zarr_ts[time_idx, tile_idx, channel_idx, y_slice, x_slice].read().result()

return arr[np.newaxis, :, :].astype(np.float32)


def create_zarr_store(
Expand Down
58 changes: 58 additions & 0 deletions tests/test_core_pixel_estimation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Tests for pixel size estimation."""

import numpy as np
import pytest
from unittest.mock import MagicMock


class TestEstimatePixelSize:
"""Tests for TileFusion.estimate_pixel_size()."""

def _create_mock_tilefusion(self, tile_positions, pixel_size, pairwise_metrics):
"""Create a mock TileFusion with required state."""
from tilefusion.core import TileFusion

mock = MagicMock(spec=TileFusion)
mock._tile_positions = tile_positions
mock._pixel_size = pixel_size
mock.pairwise_metrics = pairwise_metrics

# Bind the real method
mock.estimate_pixel_size = lambda: TileFusion.estimate_pixel_size(mock)
return mock

def test_perfect_calibration(self):
"""When measured shifts match expected, deviation should be ~0%."""
tile_positions = [(0, 0), (0, 90), (90, 0), (90, 90)]
pixel_size = (1.0, 1.0)
pairwise_metrics = {
(0, 1): (0, 90, 0.95),
(0, 2): (90, 0, 0.95),
(1, 3): (90, 0, 0.95),
(2, 3): (0, 90, 0.95),
}

tf = self._create_mock_tilefusion(tile_positions, pixel_size, pairwise_metrics)
estimated, deviation = tf.estimate_pixel_size()

assert abs(estimated - 1.0) < 0.01
assert abs(deviation) < 1.0

def test_pixel_size_underestimated(self):
"""When metadata pixel size is too small, estimated should be larger."""
tile_positions = [(0, 0), (0, 90)]
pixel_size = (1.0, 1.0)
pairwise_metrics = {(0, 1): (0, 82, 0.95)}

tf = self._create_mock_tilefusion(tile_positions, pixel_size, pairwise_metrics)
estimated, deviation = tf.estimate_pixel_size()

assert 1.05 < estimated < 1.15
assert deviation > 5.0

def test_no_metrics_raises(self):
"""Should raise if no pairwise metrics."""
tf = self._create_mock_tilefusion([], (1.0, 1.0), {})

with pytest.raises(ValueError, match="No pairwise metrics"):
tf.estimate_pixel_size()