Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 140 additions & 2 deletions gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def __init__(
downsample_factor,
flatfield=None,
darkfield=None,
registration_z=None,
registration_t=0,
registration_channel=0,
):
super().__init__()
self.tiff_path = tiff_path
Expand All @@ -139,6 +142,9 @@ def __init__(
self.downsample_factor = downsample_factor
self.flatfield = flatfield
self.darkfield = darkfield
self.registration_z = registration_z
self.registration_t = registration_t
self.registration_channel = registration_channel

def run(self):
try:
Expand All @@ -153,6 +159,9 @@ def run(self):
downsample_factors=(self.downsample_factor, self.downsample_factor),
flatfield=self.flatfield,
darkfield=self.darkfield,
registration_z=self.registration_z,
registration_t=self.registration_t,
channel_to_use=self.registration_channel,
)

positions = np.array(tf_full._tile_positions)
Expand Down Expand Up @@ -198,7 +207,11 @@ def run(self):

# Create a new TileFusion for the subset
tf = TileFusion(
self.tiff_path, downsample_factors=(self.downsample_factor, self.downsample_factor)
self.tiff_path,
downsample_factors=(self.downsample_factor, self.downsample_factor),
registration_z=self.registration_z,
registration_t=self.registration_t,
channel_to_use=self.registration_channel,
)
tf._tile_positions = selected_positions
tf.n_tiles = len(selected_indices)
Expand Down Expand Up @@ -328,6 +341,9 @@ def __init__(
fusion_mode="blended",
flatfield=None,
darkfield=None,
registration_z=None,
registration_t=0,
registration_channel=0,
):
super().__init__()
self.tiff_path = tiff_path
Expand All @@ -337,6 +353,9 @@ def __init__(
self.fusion_mode = fusion_mode
self.flatfield = flatfield
self.darkfield = darkfield
self.registration_z = registration_z
self.registration_t = registration_t
self.registration_channel = registration_channel
self.output_path = None

def run(self):
Expand Down Expand Up @@ -379,6 +398,9 @@ def run(self):
downsample_factors=(self.downsample_factor, self.downsample_factor),
flatfield=self.flatfield,
darkfield=self.darkfield,
registration_z=self.registration_z,
registration_t=self.registration_t,
channel_to_use=self.registration_channel,
)
load_time = time.time() - step_start
self.progress.emit(f"Loaded {tf.n_tiles} tiles ({tf.Y}x{tf.X} each) [{load_time:.1f}s]")
Expand Down Expand Up @@ -702,7 +724,7 @@ class StitcherGUI(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Stitcher")
self.setMinimumSize(500, 850)
self.setMinimumSize(580, 850)

self.worker = None
self.output_path = None
Expand All @@ -714,6 +736,12 @@ def __init__(self):
self.darkfield = None # Shape (C, Y, X) or None
self.flatfield_worker = None

# Dataset dimension state (for registration z/t selection)
self.dataset_n_z = 1
self.dataset_n_t = 1
self.dataset_n_channels = 1
self.dataset_channel_names = []

self.setup_ui()

def setup_ui(self):
Expand Down Expand Up @@ -887,6 +915,36 @@ def setup_ui(self):
downsample_layout.addStretch()
settings_layout.addWidget(self.downsample_widget)

# Registration z/t selection (shown when registration enabled AND multi-z/t dataset)
self.reg_zt_widget = QWidget()
self.reg_zt_widget.setVisible(False)
reg_zt_layout = QHBoxLayout(self.reg_zt_widget)
reg_zt_layout.setContentsMargins(20, 0, 0, 0)
self.reg_z_label = QLabel("Z-level:")
reg_zt_layout.addWidget(self.reg_z_label)
self.reg_z_spin = QSpinBox()
self.reg_z_spin.setRange(0, 0)
self.reg_z_spin.setValue(0)
self.reg_z_spin.setToolTip("Z-level to use for registration")
self.reg_z_spin.setFixedWidth(60)
reg_zt_layout.addWidget(self.reg_z_spin)
self.reg_t_label = QLabel("Timepoint:")
reg_zt_layout.addWidget(self.reg_t_label)
self.reg_t_spin = QSpinBox()
self.reg_t_spin.setRange(0, 0)
self.reg_t_spin.setValue(0)
self.reg_t_spin.setToolTip("Timepoint to use for registration")
self.reg_t_spin.setFixedWidth(60)
reg_zt_layout.addWidget(self.reg_t_spin)
self.reg_channel_label = QLabel("Channel:")
reg_zt_layout.addWidget(self.reg_channel_label)
self.reg_channel_combo = QComboBox()
self.reg_channel_combo.setToolTip("Channel to use for registration")
self.reg_channel_combo.setMinimumWidth(120)
reg_zt_layout.addWidget(self.reg_channel_combo)
reg_zt_layout.addStretch()
settings_layout.addWidget(self.reg_zt_widget)

self.blend_checkbox = QCheckBox("Enable blending")
self.blend_checkbox.setChecked(False)
self.blend_checkbox.toggled.connect(self.on_blend_toggled)
Expand Down Expand Up @@ -978,6 +1036,31 @@ def on_file_dropped(self, file_path):
self.clear_flatfield_button.setEnabled(False)
self.save_flatfield_button.setEnabled(False)

# Load dataset dimensions for registration z/t selection
try:
from tilefusion import TileFusion

tf_temp = TileFusion(file_path)
self.dataset_n_z = tf_temp.n_z
self.dataset_n_t = tf_temp.n_t
self.dataset_n_channels = tf_temp.channels
if "channel_names" in tf_temp._metadata:
self.dataset_channel_names = tf_temp._metadata["channel_names"]
else:
self.dataset_channel_names = [
f"Channel {i}" for i in range(self.dataset_n_channels)
]
tf_temp.close()
if self.dataset_n_z > 1 or self.dataset_n_t > 1:
self.log(f"Dataset: {self.dataset_n_z} z-levels, {self.dataset_n_t} timepoints")
self._update_reg_zt_controls()
except Exception:
Comment on lines +1040 to +1057
Copy link

Copilot AI Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on_file_dropped creates a temporary TileFusion instance to probe dataset dimensions, but tf_temp.close() is only called on the success path. If anything raises after construction (e.g., during metadata access), file handles may leak. Use a context manager (with TileFusion(...) as tf_temp:) or a try/finally to guarantee cleanup.

Copilot uses AI. Check for mistakes.
self.dataset_n_z = 1
self.dataset_n_t = 1
self.dataset_n_channels = 1
self.dataset_channel_names = []
self._update_reg_zt_controls()

# Auto-load existing flatfield if present, otherwise disable correction
# For directories (SQUID folders), also check inside the directory
if path.is_dir():
Expand All @@ -997,6 +1080,41 @@ def on_file_dropped(self, file_path):

def on_registration_toggled(self, checked):
self.downsample_widget.setVisible(checked)
self._update_reg_zt_controls()

def _update_reg_zt_controls(self):
"""Update visibility and ranges of registration z/t controls."""
registration_enabled = self.registration_checkbox.isChecked()
has_multi_z = self.dataset_n_z > 1
has_multi_t = self.dataset_n_t > 1
has_multi_channel = self.dataset_n_channels > 1

# Show z/t widget only when registration is enabled AND dataset has multi-z or multi-t or multi-channel
show_zt = registration_enabled and (has_multi_z or has_multi_t or has_multi_channel)
self.reg_zt_widget.setVisible(show_zt)

if show_zt:
# Update z spinbox
self.reg_z_label.setVisible(has_multi_z)
self.reg_z_spin.setVisible(has_multi_z)
if has_multi_z:
self.reg_z_spin.setRange(0, self.dataset_n_z - 1)
self.reg_z_spin.setValue(self.dataset_n_z // 2) # Default to middle

# Update t spinbox
self.reg_t_label.setVisible(has_multi_t)
self.reg_t_spin.setVisible(has_multi_t)
if has_multi_t:
self.reg_t_spin.setRange(0, self.dataset_n_t - 1)
self.reg_t_spin.setValue(0) # Default to first timepoint

# Update channel combo
self.reg_channel_label.setVisible(has_multi_channel)
self.reg_channel_combo.setVisible(has_multi_channel)
if has_multi_channel:
self.reg_channel_combo.clear()
self.reg_channel_combo.addItems(self.dataset_channel_names)
self.reg_channel_combo.setCurrentIndex(0)
Comment on lines +1097 to +1117
Copy link

Copilot AI Feb 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_update_reg_zt_controls() resets the z/time/channel selections every time it runs (setValue(...) / setCurrentIndex(0)). Since it’s called when registration is toggled and when files are dropped, a user’s selection can be lost unexpectedly. Preserve the current values when still in range, and only apply defaults when the dataset changes or on first show.

Suggested change
# Update z spinbox
self.reg_z_label.setVisible(has_multi_z)
self.reg_z_spin.setVisible(has_multi_z)
if has_multi_z:
self.reg_z_spin.setRange(0, self.dataset_n_z - 1)
self.reg_z_spin.setValue(self.dataset_n_z // 2) # Default to middle
# Update t spinbox
self.reg_t_label.setVisible(has_multi_t)
self.reg_t_spin.setVisible(has_multi_t)
if has_multi_t:
self.reg_t_spin.setRange(0, self.dataset_n_t - 1)
self.reg_t_spin.setValue(0) # Default to first timepoint
# Update channel combo
self.reg_channel_label.setVisible(has_multi_channel)
self.reg_channel_combo.setVisible(has_multi_channel)
if has_multi_channel:
self.reg_channel_combo.clear()
self.reg_channel_combo.addItems(self.dataset_channel_names)
self.reg_channel_combo.setCurrentIndex(0)
# Preserve previous selections where possible
prev_z = self.reg_z_spin.value() if self.reg_z_spin.isVisible() else None
prev_t = self.reg_t_spin.value() if self.reg_t_spin.isVisible() else None
prev_channel_index = (
self.reg_channel_combo.currentIndex()
if self.reg_channel_combo.count() > 0
else None
)
existing_channels = [
self.reg_channel_combo.itemText(i)
for i in range(self.reg_channel_combo.count())
]
# Update z spinbox
self.reg_z_label.setVisible(has_multi_z)
self.reg_z_spin.setVisible(has_multi_z)
if has_multi_z:
self.reg_z_spin.setRange(0, self.dataset_n_z - 1)
max_z = self.dataset_n_z - 1
if prev_z is not None and 0 <= prev_z <= max_z:
self.reg_z_spin.setValue(prev_z)
else:
# Default to middle slice when no previous valid selection
self.reg_z_spin.setValue(self.dataset_n_z // 2)
# Update t spinbox
self.reg_t_label.setVisible(has_multi_t)
self.reg_t_spin.setVisible(has_multi_t)
if has_multi_t:
self.reg_t_spin.setRange(0, self.dataset_n_t - 1)
max_t = self.dataset_n_t - 1
if prev_t is not None and 0 <= prev_t <= max_t:
self.reg_t_spin.setValue(prev_t)
else:
# Default to first timepoint when no previous valid selection
self.reg_t_spin.setValue(0)
# Update channel combo
self.reg_channel_label.setVisible(has_multi_channel)
self.reg_channel_combo.setVisible(has_multi_channel)
if has_multi_channel:
# Only rebuild items if the dataset channels changed
if existing_channels != self.dataset_channel_names:
self.reg_channel_combo.clear()
self.reg_channel_combo.addItems(self.dataset_channel_names)
# On dataset change, default to first channel
self.reg_channel_combo.setCurrentIndex(0)
else:
# Preserve previous channel selection if still valid
max_channel_index = self.reg_channel_combo.count() - 1
if (
prev_channel_index is not None
and 0 <= prev_channel_index <= max_channel_index
):
self.reg_channel_combo.setCurrentIndex(prev_channel_index)
else:
self.reg_channel_combo.setCurrentIndex(0)

Copilot uses AI. Check for mistakes.

def on_blend_toggled(self, checked):
self.blend_value_widget.setVisible(checked)
Expand Down Expand Up @@ -1228,6 +1346,13 @@ def run_stitching(self):
flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None
darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None

# Get registration z/t values (None means use default middle z)
registration_z = self.reg_z_spin.value() if self.dataset_n_z > 1 else None
registration_t = self.reg_t_spin.value() if self.dataset_n_t > 1 else 0
registration_channel = (
self.reg_channel_combo.currentIndex() if self.dataset_n_channels > 1 else 0
)

self.worker = FusionWorker(
self.drop_area.file_path,
self.registration_checkbox.isChecked(),
Expand All @@ -1236,6 +1361,9 @@ def run_stitching(self):
fusion_mode,
flatfield=flatfield,
darkfield=darkfield,
registration_z=registration_z,
registration_t=registration_t,
registration_channel=registration_channel,
)
self.worker.progress.connect(self.log)
self.worker.finished.connect(self.on_fusion_finished)
Expand Down Expand Up @@ -1294,13 +1422,23 @@ def run_preview(self):
flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None
darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None

# Get registration z/t values (None means use default middle z)
registration_z = self.reg_z_spin.value() if self.dataset_n_z > 1 else None
registration_t = self.reg_t_spin.value() if self.dataset_n_t > 1 else 0
registration_channel = (
self.reg_channel_combo.currentIndex() if self.dataset_n_channels > 1 else 0
)

self.preview_worker = PreviewWorker(
self.drop_area.file_path,
self.preview_cols_spin.value(),
self.preview_rows_spin.value(),
self.downsample_spin.value(),
flatfield=flatfield,
darkfield=darkfield,
registration_z=registration_z,
registration_t=registration_t,
registration_channel=registration_channel,
)
self.preview_worker.progress.connect(self.log)
self.preview_worker.finished.connect(self.on_preview_finished)
Expand Down
26 changes: 22 additions & 4 deletions src/tilefusion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __init__(
region: Optional[str] = None,
flatfield: Optional[np.ndarray] = None,
darkfield: Optional[np.ndarray] = None,
registration_z: Optional[int] = None,
registration_t: int = 0,
):
self.tiff_path = Path(tiff_path)
if not self.tiff_path.exists():
Expand Down Expand Up @@ -194,6 +196,18 @@ def __init__(
self._time_folders = self._metadata.get("time_folders", None)
self._middle_z = self.n_z // 2 # Use middle z-level for registration

# Registration z/t selection (validate after n_z/n_t are known)
if registration_z is None:
self._registration_z = self._middle_z
else:
if registration_z < 0 or registration_z >= self.n_z:
raise ValueError(f"registration_z={registration_z} out of range [0, {self.n_z})")
self._registration_z = registration_z

if registration_t < 0 or registration_t >= self.n_t:
raise ValueError(f"registration_t={registration_t} out of range [0, {self.n_t})")
self._registration_t = registration_t

# Configuration
self.downsample_factors = tuple(downsample_factors)
self.ssim_window = int(ssim_window)
Expand Down Expand Up @@ -447,10 +461,12 @@ def _update_profiles(self) -> None:
# I/O methods (delegate to format-specific loaders)
# -------------------------------------------------------------------------

def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = 0) -> np.ndarray:
def _read_tile(self, tile_idx: int, z_level: int = None, time_idx: int = None) -> np.ndarray:
"""Read a single tile from the input data (all channels)."""
if z_level is None:
z_level = self._middle_z # Default to middle z for registration
z_level = self._registration_z # Default to registration z-level
if time_idx is None:
time_idx = self._registration_t # Default to registration timepoint

if self._is_zarr_format:
zarr_ts = self._metadata["tensorstore"]
Expand Down Expand Up @@ -493,11 +509,13 @@ def _read_tile_region(
y_slice: slice,
x_slice: slice,
z_level: int = None,
time_idx: int = 0,
time_idx: int = None,
) -> np.ndarray:
"""Read a region of a tile from the input data."""
if z_level is None:
z_level = self._middle_z # Default to middle z for registration
z_level = self._registration_z # Default to registration z-level
if time_idx is None:
time_idx = self._registration_t # Default to registration timepoint

if self._is_zarr_format:
zarr_ts = self._metadata["tensorstore"]
Expand Down
Loading