diff --git a/gui/app.py b/gui/app.py index c2ad469..8e0894b 100644 --- a/gui/app.py +++ b/gui/app.py @@ -36,7 +36,6 @@ from PyQt5.QtCore import Qt, QThread, pyqtSignal from PyQt5.QtGui import QDragEnterEvent, QDropEvent - STYLE_SHEET = """ QGroupBox { font-weight: bold; @@ -238,6 +237,15 @@ def patched_read_tile_region(tile_idx, y_slice, x_slice): tf.refine_tile_positions_with_cross_correlation() self.progress.emit(f"Found {len(tf.pairwise_metrics)} pairs") + # Estimate pixel size + try: + estimated_px, deviation = tf.estimate_pixel_size() + self.progress.emit( + f"Estimated pixel size: {estimated_px:.4f} µm ({deviation:+.1f}% from metadata)" + ) + except ValueError as e: + self.progress.emit(f"Pixel size estimation skipped: {e}") + tf.optimize_shifts( method="TWO_ROUND_ITERATIVE", rel_thresh=0.5, abs_thresh=2.0, iterative=True ) @@ -344,6 +352,7 @@ def __init__( registration_z=None, registration_t=0, registration_channel=0, + use_estimated_pixel_size=False, ): super().__init__() self.tiff_path = tiff_path @@ -356,8 +365,24 @@ def __init__( self.registration_z = registration_z self.registration_t = registration_t self.registration_channel = registration_channel + self.use_estimated_pixel_size = use_estimated_pixel_size self.output_path = None + def _apply_estimated_pixel_size(self, tf, estimated_px: float, deviation: float) -> None: + """Apply estimated pixel size if enabled and within reasonable bounds.""" + if not self.use_estimated_pixel_size or abs(deviation) <= 1.0: + return + + # Sanity check: estimated should be within 50% of original + ratio = estimated_px / tf._pixel_size[0] + if 0.5 < ratio < 2.0: + tf._pixel_size = (estimated_px, estimated_px) + self.progress.emit("Using estimated pixel size for stitching") + else: + self.progress.emit( + f"Warning: Estimated pixel size {estimated_px:.4f} is unreasonable, ignoring" + ) + def run(self): try: from tilefusion import TileFusion @@ -426,6 +451,16 @@ def run(self): self.progress.emit( f"Registration complete: {len(tf.pairwise_metrics)} pairs [{reg_time:.1f}s]" ) + + # Estimate pixel size + try: + estimated_px, deviation = tf.estimate_pixel_size() + self.progress.emit( + f"Estimated pixel size: {estimated_px:.4f} µm ({deviation:+.1f}% from metadata)" + ) + self._apply_estimated_pixel_size(tf, estimated_px, deviation) + except ValueError as e: + self.progress.emit(f"Could not estimate pixel size: {e}") else: tf.threshold = 1.0 # Skip registration self.progress.emit("Using stage positions (no registration)") @@ -963,6 +998,13 @@ def setup_ui(self): blend_value_layout.addStretch() settings_layout.addWidget(self.blend_value_widget) + self.use_estimated_px_checkbox = QCheckBox("Use estimated pixel size") + self.use_estimated_px_checkbox.setToolTip( + "Estimate pixel size from registration and use it for stitching" + ) + self.use_estimated_px_checkbox.setChecked(False) + settings_layout.addWidget(self.use_estimated_px_checkbox) + layout.addWidget(settings_group) # Run button @@ -1364,6 +1406,7 @@ def run_stitching(self): registration_z=registration_z, registration_t=registration_t, registration_channel=registration_channel, + use_estimated_pixel_size=self.use_estimated_px_checkbox.isChecked(), ) self.worker.progress.connect(self.log) self.worker.finished.connect(self.on_fusion_finished) diff --git a/scripts/view_in_napari.py b/scripts/view_in_napari.py index c38628c..a4697d8 100644 --- a/scripts/view_in_napari.py +++ b/scripts/view_in_napari.py @@ -3,6 +3,7 @@ Simple script to view fused OME-Zarr in napari. Works around napari-ome-zarr plugin issues with Zarr v3. """ + import sys from pathlib import Path diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index d28bc7e..ab81ea8 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -84,6 +84,10 @@ class TileFusion: Channel index for registration. multiscale_downsample : str Either "stride" (default) or "block_mean" to control multiscale reduction. + registration_z : int, optional + Z-level to use for registration. If None, uses middle z-level. + registration_t : int + Timepoint to use for registration. Defaults to 0. """ def __init__( @@ -461,7 +465,9 @@ def _update_profiles(self) -> None: # I/O methods (delegate to format-specific loaders) # ------------------------------------------------------------------------- - def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -> np.ndarray: + def _read_tile( + self, tile_idx: int, z_level: Optional[int] = None, time_idx: Optional[int] = None + ) -> np.ndarray: """Read a single tile from the input data (all channels).""" if z_level is None: z_level = self._registration_z # Default to registration z-level @@ -471,7 +477,7 @@ def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) - if self._is_zarr_format: zarr_ts = self._metadata["tensorstore"] is_3d = self._metadata.get("is_3d", False) - tile = read_zarr_tile(zarr_ts, tile_idx, is_3d) + tile = read_zarr_tile(zarr_ts, tile_idx, is_3d, z_level=z_level, time_idx=time_idx) elif self._is_individual_tiffs_format: tile = read_individual_tiffs_tile( self._metadata["image_folder"], @@ -508,8 +514,8 @@ def _read_tile_region( tile_idx: int, y_slice: slice, x_slice: slice, - z_level: int = None, - time_idx: int = None, + z_level: Optional[int] = None, + time_idx: Optional[int] = None, ) -> np.ndarray: """Read a region of a tile from the input data.""" if z_level is None: @@ -521,7 +527,14 @@ def _read_tile_region( zarr_ts = self._metadata["tensorstore"] is_3d = self._metadata.get("is_3d", False) region = read_zarr_region( - zarr_ts, tile_idx, y_slice, x_slice, self.channel_to_use, is_3d + zarr_ts, + tile_idx, + y_slice, + x_slice, + self.channel_to_use, + is_3d, + z_level=z_level, + time_idx=time_idx, ) elif self._is_individual_tiffs_format: region = read_individual_tiffs_region( @@ -799,6 +812,60 @@ def read_patch(idx, y_bounds, x_bounds): io_executor.shutdown(wait=True) + def estimate_pixel_size(self) -> Tuple[float, float]: + """ + Estimate pixel size from registration results. + + Compares expected shifts (from stage positions / metadata pixel size) + with measured shifts (from cross-correlation) to estimate true pixel size. + + Returns + ------- + estimated_pixel_size : float + Estimated pixel size in same units as metadata (typically um). + deviation_percent : float + Percentage deviation from metadata: (estimated/metadata - 1) * 100 + + Raises + ------ + ValueError + If no valid pairwise metrics available. + """ + if not self.pairwise_metrics: + raise ValueError("No pairwise metrics available. Run registration first.") + + ratios = [] + + for (i, j), (dy_measured, dx_measured, score) in self.pairwise_metrics.items(): + # Get stage positions + pos_i = np.array(self._tile_positions[i]) + pos_j = np.array(self._tile_positions[j]) + + # Expected shift in pixels = stage_distance / pixel_size + stage_diff = pos_j - pos_i # (dy, dx) in physical units + expected_dy = stage_diff[0] / self._pixel_size[0] + expected_dx = stage_diff[1] / self._pixel_size[1] + + # Compute ratio for non-zero shifts (both expected and measured must be significant) + if abs(dx_measured) > 5 and abs(expected_dx) > 5: # Horizontal shift + ratio = expected_dx / dx_measured + ratios.append(ratio) + if abs(dy_measured) > 5 and abs(expected_dy) > 5: # Vertical shift + ratio = expected_dy / dy_measured + ratios.append(ratio) + + if not ratios: + raise ValueError("No valid shift measurements for pixel size estimation.") + + # Use median to filter outliers + median_ratio = float(np.median(ratios)) + + # Estimated pixel size (assume isotropic) + estimated = self._pixel_size[0] * median_ratio + deviation_percent = (median_ratio - 1.0) * 100.0 + + return estimated, deviation_percent + # ------------------------------------------------------------------------- # Optimization # ------------------------------------------------------------------------- diff --git a/src/tilefusion/io/zarr.py b/src/tilefusion/io/zarr.py index dfd66ab..0fd760a 100644 --- a/src/tilefusion/io/zarr.py +++ b/src/tilefusion/io/zarr.py @@ -100,6 +100,8 @@ def read_zarr_tile( zarr_ts: ts.TensorStore, tile_idx: int, is_3d: bool = False, + z_level: int = None, + time_idx: int = 0, ) -> np.ndarray: """ Read all channels of a tile from Zarr format. @@ -111,7 +113,11 @@ def read_zarr_tile( tile_idx : int Index of the tile. is_3d : bool - If True, data is 3D and max projection is applied. + If True, data is 3D. + z_level : int, optional + Z-level to read. If None and is_3d, uses max projection. + time_idx : int + Timepoint to read. Defaults to 0. Returns ------- @@ -119,10 +125,15 @@ def read_zarr_tile( Tile data as float32. """ if is_3d: - arr = zarr_ts[0, tile_idx, :, :, :, :].read().result() - arr = np.max(arr, axis=1) # Max projection along Z + if z_level is not None: + # Read specific z-level + arr = zarr_ts[time_idx, tile_idx, :, z_level, :, :].read().result() + else: + # Max projection along Z (legacy behavior) + arr = zarr_ts[time_idx, tile_idx, :, :, :, :].read().result() + arr = np.max(arr, axis=1) else: - arr = zarr_ts[0, tile_idx, :, :, :].read().result() + arr = zarr_ts[time_idx, tile_idx, :, :, :].read().result() return arr.astype(np.float32) @@ -133,6 +144,8 @@ def read_zarr_region( x_slice: slice, channel_idx: int = 0, is_3d: bool = False, + z_level: int = None, + time_idx: int = 0, ) -> np.ndarray: """ Read a region of a single channel from Zarr format. @@ -149,6 +162,10 @@ def read_zarr_region( Channel index. is_3d : bool If True, data is 3D. + z_level : int, optional + Z-level to read. If None and is_3d, uses max projection. + time_idx : int + Timepoint to read. Defaults to 0. Returns ------- @@ -156,13 +173,19 @@ def read_zarr_region( Tile region as float32. """ if is_3d: - arr = zarr_ts[0, tile_idx, channel_idx, :, y_slice, x_slice].read().result() - arr = np.max(arr, axis=0) - arr = arr[np.newaxis, :, :] + if z_level is not None: + # Read specific z-level + arr = ( + zarr_ts[time_idx, tile_idx, channel_idx, z_level, y_slice, x_slice].read().result() + ) + else: + # Max projection along Z (legacy behavior) + arr = zarr_ts[time_idx, tile_idx, channel_idx, :, y_slice, x_slice].read().result() + arr = np.max(arr, axis=0) else: - arr = zarr_ts[0, tile_idx, channel_idx, y_slice, x_slice].read().result() - arr = arr[np.newaxis, :, :] - return arr.astype(np.float32) + arr = zarr_ts[time_idx, tile_idx, channel_idx, y_slice, x_slice].read().result() + + return arr[np.newaxis, :, :].astype(np.float32) def create_zarr_store( diff --git a/tests/test_core_pixel_estimation.py b/tests/test_core_pixel_estimation.py new file mode 100644 index 0000000..ef9a738 --- /dev/null +++ b/tests/test_core_pixel_estimation.py @@ -0,0 +1,58 @@ +"""Tests for pixel size estimation.""" + +import numpy as np +import pytest +from unittest.mock import MagicMock + + +class TestEstimatePixelSize: + """Tests for TileFusion.estimate_pixel_size().""" + + def _create_mock_tilefusion(self, tile_positions, pixel_size, pairwise_metrics): + """Create a mock TileFusion with required state.""" + from tilefusion.core import TileFusion + + mock = MagicMock(spec=TileFusion) + mock._tile_positions = tile_positions + mock._pixel_size = pixel_size + mock.pairwise_metrics = pairwise_metrics + + # Bind the real method + mock.estimate_pixel_size = lambda: TileFusion.estimate_pixel_size(mock) + return mock + + def test_perfect_calibration(self): + """When measured shifts match expected, deviation should be ~0%.""" + tile_positions = [(0, 0), (0, 90), (90, 0), (90, 90)] + pixel_size = (1.0, 1.0) + pairwise_metrics = { + (0, 1): (0, 90, 0.95), + (0, 2): (90, 0, 0.95), + (1, 3): (90, 0, 0.95), + (2, 3): (0, 90, 0.95), + } + + tf = self._create_mock_tilefusion(tile_positions, pixel_size, pairwise_metrics) + estimated, deviation = tf.estimate_pixel_size() + + assert abs(estimated - 1.0) < 0.01 + assert abs(deviation) < 1.0 + + def test_pixel_size_underestimated(self): + """When metadata pixel size is too small, estimated should be larger.""" + tile_positions = [(0, 0), (0, 90)] + pixel_size = (1.0, 1.0) + pairwise_metrics = {(0, 1): (0, 82, 0.95)} + + tf = self._create_mock_tilefusion(tile_positions, pixel_size, pairwise_metrics) + estimated, deviation = tf.estimate_pixel_size() + + assert 1.05 < estimated < 1.15 + assert deviation > 5.0 + + def test_no_metrics_raises(self): + """Should raise if no pairwise metrics.""" + tf = self._create_mock_tilefusion([], (1.0, 1.0), {}) + + with pytest.raises(ValueError, match="No pairwise metrics"): + tf.estimate_pixel_size()