From 8d93bdff11f553623fed9e1bc63307ad74c21e7f Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Fri, 16 Jan 2026 13:47:41 -0800 Subject: [PATCH 1/2] add slit tracing flow using Pypeit EdgeTraceSet --- src/core/pypeit_tracing.py | 134 ++++++++++++++++++++ src/core/qa.py | 69 ++++++++++ src/main.py | 4 +- src/workflows/flows/pypeit_trace_flow.py | 95 ++++++++++++++ src/workflows/prefect_tasks/pypeit_tasks.py | 118 +++++++++++++++++ 5 files changed, 419 insertions(+), 1 deletion(-) create mode 100644 src/core/pypeit_tracing.py create mode 100644 src/workflows/flows/pypeit_trace_flow.py create mode 100644 src/workflows/prefect_tasks/pypeit_tasks.py diff --git a/src/core/pypeit_tracing.py b/src/core/pypeit_tracing.py new file mode 100644 index 0000000..ddcab66 --- /dev/null +++ b/src/core/pypeit_tracing.py @@ -0,0 +1,134 @@ +""" +PyPEIT-based slit tracing. +""" +import os +from typing import Tuple, List, Optional +import numpy as np + +from pypeit.images.buildimage import TraceImage +from pypeit.edgetrace import EdgeTraceSet +from pypeit.par.pypeitpar import EdgeTracePar +from pypeit.spectrographs.util import load_spectrograph + + +def trace_slits_pypeit( + data: np.ndarray, + spectrograph_name: str = 'keck_lris_red', + fwhm_uniform: float = 3.0, + niter_uniform: int = 9, + det_min_spec_length: float = 0.3, + follow_span: int = 20, + fit_order: int = 5, +) -> Tuple[np.ndarray, np.ndarray, Optional[object]]: + """ + Trace slit edges using PyPEIT's EdgeTraceSet. + + Args: + data: 2D flat field image (spatial x spectral) + spectrograph_name: PyPEIT spectrograph name (default: keck_lris_red) + Use keck_lris_red or keck_lris_blue as proxy for LRIS2 evaluation + fwhm_uniform: FWHM for uniform filter smoothing + niter_uniform: Number of iterations for uniform filter + det_min_spec_length: Minimum spectral length for valid trace (fraction) + follow_span: Number of pixels to use when following edges + fit_order: Polynomial order for trace fitting + + Returns: + Tuple of: + - left_edges: 2D array of left edge positions (n_slits x n_spectral) + - right_edges: 2D array of right edge positions (n_slits x n_spectral) + - edges: The EdgeTraceSet object (for advanced use/saving) + """ + # Load spectrograph (required by PyPEIT) + spectrograph = load_spectrograph(spectrograph_name) + + # Build EdgeTracePar with our settings + edge_par = EdgeTracePar() + edge_par['fwhm_uniform'] = fwhm_uniform + edge_par['niter_uniform'] = niter_uniform + edge_par['det_min_spec_length'] = det_min_spec_length + edge_par['follow_span'] = follow_span + edge_par['fit_order'] = fit_order + + # Create TraceImage from numpy array + # PyPEIT expects the image in a specific format + trace_img = TraceImage(data.astype(np.float64)) + + # Create EdgeTraceSet with spectrograph for proper defaults + # Use auto=True to run the full tracing pipeline automatically + # This handles PCA decomposition, edge syncing, and all refinement steps + edges = EdgeTraceSet( + trace_img, + spectrograph=spectrograph, + par=edge_par, + auto=True, # Run full pipeline automatically + ) + + # Extract edge positions + # edges.edge_fit contains the fitted edge positions + # Shape is (nspec, ntrace) where ntrace = 2 * nslits + if edges.edge_fit is None or edges.edge_fit.size == 0: + # No edges found - return empty arrays + nspec = data.shape[0] + return np.array([]).reshape(0, nspec), np.array([]).reshape(0, nspec), edges + + # Separate left and right edges + # In PyPEIT, left edges have negative trace IDs, right have positive + left_mask = edges.traceid < 0 + right_mask = edges.traceid > 0 + + # Get edge positions (shape: nspec x ntrace) + edge_positions = edges.edge_fit + + # Extract left and right edges + left_edges = edge_positions[:, left_mask].T # (n_slits, nspec) + right_edges = edge_positions[:, right_mask].T # (n_slits, nspec) + + return left_edges, right_edges, edges + + +def get_slit_centers(left_edges: np.ndarray, right_edges: np.ndarray) -> List[int]: + """ + Calculate slit center positions from left/right edges. + + This provides compatibility with the existing trace_slits_1d interface. + + Args: + left_edges: 2D array of left edge positions (n_slits x n_spectral) + right_edges: 2D array of right edge positions (n_slits x n_spectral) + + Returns: + List of slit center positions (median across spectral direction) + """ + if left_edges.size == 0 or right_edges.size == 0: + return [] + + # Calculate center as midpoint between left and right edges + # Use median along spectral direction for single position per slit + centers = (np.median(left_edges, axis=1) + np.median(right_edges, axis=1)) / 2 + return [int(c) for c in centers] + + +def save_edge_trace( + left_edges: np.ndarray, + right_edges: np.ndarray, + output_path: str +) -> str: + """ + Save slit edge traces to a numpy file. + + Args: + left_edges: 2D array of left edge positions + right_edges: 2D array of right edge positions + output_path: Path for output file + + Returns: + Path to saved file + """ + os.makedirs(os.path.dirname(output_path), exist_ok=True) + np.savez( + output_path, + left_edges=left_edges, + right_edges=right_edges + ) + return output_path diff --git a/src/core/qa.py b/src/core/qa.py index 7b3d638..af5d392 100644 --- a/src/core/qa.py +++ b/src/core/qa.py @@ -3,6 +3,7 @@ matplotlib.use('Agg') # Use non-interactive backend for saving plots import matplotlib.pyplot as plt import numpy as np +from typing import Optional def generate_qa_plot(data: np.ndarray, output_path: str, title: str = "Flat QA") -> str: @@ -20,3 +21,71 @@ def generate_qa_plot(data: np.ndarray, output_path: str, title: str = "Flat QA") plt.savefig(output_path) plt.close() return output_path + + +def generate_trace_qa_plot( + data: np.ndarray, + left_edges: np.ndarray, + right_edges: np.ndarray, + output_path: str, + title: str = "Slit Trace QA", + vmin: Optional[float] = None, + vmax: Optional[float] = None, +) -> str: + """ + Generate a QA plot showing traced slit edges overlaid on the flat image. + + Args: + data: 2D flat field image + left_edges: 2D array of left edge positions (n_slits x n_spectral) + right_edges: 2D array of right edge positions (n_slits x n_spectral) + output_path: Path to save the plot + title: Plot title + vmin: Minimum value for image scaling + vmax: Maximum value for image scaling + + Returns: + Path to saved plot + """ + if data is None or not hasattr(data, "shape") or data.ndim != 2: + raise ValueError(f"Invalid data shape for QA plot: {getattr(data, 'shape', None)}") + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + fig, ax = plt.subplots(figsize=(12, 8)) + + # Display the flat image + if vmin is None: + vmin = np.percentile(data, 1) + if vmax is None: + vmax = np.percentile(data, 99) + + im = ax.imshow(data, cmap="gray", aspect="auto", origin="lower", vmin=vmin, vmax=vmax) + fig.colorbar(im, ax=ax, label="Counts") + + # Overlay traced edges + n_slits = left_edges.shape[0] if left_edges.size > 0 else 0 + nspec = data.shape[0] + spectral_coords = np.arange(nspec) + + colors = plt.cm.tab10(np.linspace(0, 1, max(n_slits, 1))) + + for i in range(n_slits): + color = colors[i % len(colors)] + # Plot left edge + ax.plot(left_edges[i], spectral_coords, color=color, linewidth=1.5, label=f"Slit {i+1}" if i < 10 else None) + # Plot right edge + ax.plot(right_edges[i], spectral_coords, color=color, linewidth=1.5, linestyle="--") + + ax.set_xlabel("Spatial (pixels)") + ax.set_ylabel("Spectral (pixels)") + ax.set_title(f"{title} - {n_slits} slits detected") + + if n_slits > 0 and n_slits <= 10: + ax.legend(loc="upper right", fontsize=8) + + plt.tight_layout() + plt.savefig(output_path, dpi=150) + plt.close() + + return output_path diff --git a/src/main.py b/src/main.py index 9e14c3d..719f83a 100644 --- a/src/main.py +++ b/src/main.py @@ -7,6 +7,7 @@ import os import subprocess from workflows.flows.batch_flat_flow import batch_process_all_flats +from workflows.flows.pypeit_trace_flow import pypeit_trace_flow def load_config(config_path="config/config.yaml"): """Load configuration from a YAML file.""" @@ -38,7 +39,8 @@ def load_config(config_path="config/config.yaml"): try: print(f"🟢 Starting batch processing of FITS files in {input_dir}") - batch_process_all_flats(input_dir=input_dir, output_dir=output_dir) + # batch_process_all_flats(input_dir=input_dir, output_dir=output_dir) + pypeit_trace_flow(input_dir=input_dir, output_dir=output_dir) if use_prefect_server: print("\nāœ… Pipeline completed!") diff --git a/src/workflows/flows/pypeit_trace_flow.py b/src/workflows/flows/pypeit_trace_flow.py new file mode 100644 index 0000000..913c63c --- /dev/null +++ b/src/workflows/flows/pypeit_trace_flow.py @@ -0,0 +1,95 @@ +""" +PyPEIT-based slit tracing flow. + +This flow uses PyPEIT's algorithms for slit tracing. +""" +import os +from prefect import flow, task, get_run_logger +from prefect.task_runners import ConcurrentTaskRunner +from workflows.prefect_tasks.save_trace import save_trace_solution_task +from workflows.prefect_tasks.pypeit_tasks import ( + load_flat_frame_task, + trace_slits_pypeit_task, + get_slit_centers_task, + save_edge_trace_task, + generate_trace_qa_plot_task, +) + + +@task(name="Trace Slits (PyPEIT)") +def trace_slits_pypeit(fits_path: str, output_dir: str): + """ + Trace slits in a FITS file using PyPEIT algorithms. + + Args: + fits_path: Path to input FITS file + output_dir: Output directory for results + """ + logger = get_run_logger() + filename = os.path.splitext(os.path.basename(fits_path))[0] + + # Construct output paths + trace_output = os.path.join(output_dir, filename, "slit_trace.txt") + edges_output = os.path.join(output_dir, filename, "slit_edges.npz") + qa_output = os.path.join(output_dir, filename, "trace_qa.png") + + # Ensure output dirs + os.makedirs(os.path.dirname(trace_output), exist_ok=True) + + # Load FITS + logger.info(f"Loading {fits_path}") + data, header = load_flat_frame_task(fits_path) + + # Trace slits with PyPEIT + logger.info("Tracing slits with PyPEIT EdgeTraceSet") + left_edges, right_edges = trace_slits_pypeit_task(data) + + # Get slit centers for compatibility with existing outputs + slit_positions = get_slit_centers_task(left_edges, right_edges) + logger.info(f"Found {len(slit_positions)} slits") + + # Save outputs + logger.info("Saving results") + save_trace_solution_task(slit_positions, trace_output) + save_edge_trace_task(left_edges, right_edges, edges_output) + + # Generate QA plot + logger.info("Generating QA plot") + generate_trace_qa_plot_task(data, left_edges, right_edges, qa_output, title=filename) + + logger.info(f"Finished tracing {fits_path}") + + +@flow( + name="PyPEIT Slit Tracing", + description="Trace slits in FITS frames using PyPEIT algorithms", + task_runner=ConcurrentTaskRunner(max_workers=2), +) +def pypeit_trace_flow(input_dir: str, output_dir: str): + """ + Trace slits in all FITS files using PyPEIT algorithms. + + This flow uses PyPEIT's EdgeTraceSet for slit tracing. + + Args: + input_dir: Directory containing input FITS files + output_dir: Directory for output files + """ + logger = get_run_logger() + + fits_files = [ + os.path.join(input_dir, f) + for f in os.listdir(input_dir) + if f.lower().endswith(".fits") + ] + logger.info(f"Found {len(fits_files)} FITS files in {input_dir}") + + futures = [ + trace_slits_pypeit.submit(fp, output_dir) + for fp in fits_files + ] + + for future in futures: + future.result() + + logger.info("PyPEIT slit tracing complete") diff --git a/src/workflows/prefect_tasks/pypeit_tasks.py b/src/workflows/prefect_tasks/pypeit_tasks.py new file mode 100644 index 0000000..5c3f7e4 --- /dev/null +++ b/src/workflows/prefect_tasks/pypeit_tasks.py @@ -0,0 +1,118 @@ +""" +Prefect task wrappers for PyPEIT-based processing. +""" +from typing import Tuple, List +import numpy as np +from prefect import task, get_run_logger + +from core.flat import load_flat_frame +from core.pypeit_tracing import trace_slits_pypeit, get_slit_centers, save_edge_trace +from core.qa import generate_trace_qa_plot + + +@task(name="Trace Slits PyPEIT") +def trace_slits_pypeit_task( + data: np.ndarray, + **kwargs +) -> Tuple[np.ndarray, np.ndarray]: + """ + Prefect task to trace slits using PyPEIT's EdgeTraceSet. + + Args: + data: 2D flat field image + **kwargs: Additional arguments for trace_slits_pypeit + + Returns: + Tuple of (left_edges, right_edges) + """ + logger = get_run_logger() + logger.info("Running PyPEIT edge tracing") + + left_edges, right_edges, _ = trace_slits_pypeit(data, **kwargs) + + n_slits = left_edges.shape[0] if left_edges.size > 0 else 0 + logger.info(f"PyPEIT found {n_slits} slits") + + return left_edges, right_edges + + +@task(name="Get Slit Centers PyPEIT") +def get_slit_centers_task( + left_edges: np.ndarray, + right_edges: np.ndarray +) -> List[int]: + """ + Prefect task to get slit center positions from edge arrays. + + Args: + left_edges: 2D array of left edge positions + right_edges: 2D array of right edge positions + + Returns: + List of slit center positions + """ + return get_slit_centers(left_edges, right_edges) + + +@task(name="Save Edge Trace") +def save_edge_trace_task( + left_edges: np.ndarray, + right_edges: np.ndarray, + output_path: str +) -> str: + """ + Prefect task to save edge trace arrays. + + Args: + left_edges: 2D array of left edge positions + right_edges: 2D array of right edge positions + output_path: Path for output file + + Returns: + Path to saved file + """ + logger = get_run_logger() + logger.info(f"Saving edge traces to {output_path}") + return save_edge_trace(left_edges, right_edges, output_path) + + +@task(name="Load Flat Frame") +def load_flat_frame_task(filepath: str) -> Tuple[np.ndarray, dict]: + """ + Prefect task to load a FITS file. + + Args: + filepath: Path to FITS file + + Returns: + Tuple of (data, header) + """ + logger = get_run_logger() + logger.info(f"Loading FITS file: {filepath}") + return load_flat_frame(filepath) + + +@task(name="Generate Trace QA Plot") +def generate_trace_qa_plot_task( + data: np.ndarray, + left_edges: np.ndarray, + right_edges: np.ndarray, + output_path: str, + title: str = "Slit Trace QA", +) -> str: + """ + Prefect task to generate a QA plot for traced slits. + + Args: + data: 2D flat field image + left_edges: 2D array of left edge positions + right_edges: 2D array of right edge positions + output_path: Path to save the plot + title: Plot title + + Returns: + Path to saved plot + """ + logger = get_run_logger() + logger.info(f"Generating trace QA plot: {output_path}") + return generate_trace_qa_plot(data, left_edges, right_edges, output_path, title) From ecd551e4e629ea671ba0a3b9b5ffac336b3b4741 Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Fri, 16 Jan 2026 13:49:33 -0800 Subject: [PATCH 2/2] cleanup --- src/workflows/flows/pypeit_trace_flow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/workflows/flows/pypeit_trace_flow.py b/src/workflows/flows/pypeit_trace_flow.py index 913c63c..ef097d3 100644 --- a/src/workflows/flows/pypeit_trace_flow.py +++ b/src/workflows/flows/pypeit_trace_flow.py @@ -1,7 +1,5 @@ """ PyPEIT-based slit tracing flow. - -This flow uses PyPEIT's algorithms for slit tracing. """ import os from prefect import flow, task, get_run_logger