Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/core/pypeit_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
PyPEIT-based slit tracing.
"""
import os
from typing import Tuple, List, Optional
import numpy as np

from pypeit.images.buildimage import TraceImage
from pypeit.edgetrace import EdgeTraceSet
from pypeit.par.pypeitpar import EdgeTracePar
from pypeit.spectrographs.util import load_spectrograph


def trace_slits_pypeit(
data: np.ndarray,
spectrograph_name: str = 'keck_lris_red',
fwhm_uniform: float = 3.0,
niter_uniform: int = 9,
det_min_spec_length: float = 0.3,
follow_span: int = 20,
fit_order: int = 5,
) -> Tuple[np.ndarray, np.ndarray, Optional[object]]:
"""
Trace slit edges using PyPEIT's EdgeTraceSet.

Args:
data: 2D flat field image (spatial x spectral)
spectrograph_name: PyPEIT spectrograph name (default: keck_lris_red)
Use keck_lris_red or keck_lris_blue as proxy for LRIS2 evaluation
fwhm_uniform: FWHM for uniform filter smoothing
niter_uniform: Number of iterations for uniform filter
det_min_spec_length: Minimum spectral length for valid trace (fraction)
follow_span: Number of pixels to use when following edges
fit_order: Polynomial order for trace fitting

Returns:
Tuple of:
- left_edges: 2D array of left edge positions (n_slits x n_spectral)
- right_edges: 2D array of right edge positions (n_slits x n_spectral)
- edges: The EdgeTraceSet object (for advanced use/saving)
"""
# Load spectrograph (required by PyPEIT)
spectrograph = load_spectrograph(spectrograph_name)

# Build EdgeTracePar with our settings
edge_par = EdgeTracePar()
edge_par['fwhm_uniform'] = fwhm_uniform
edge_par['niter_uniform'] = niter_uniform
edge_par['det_min_spec_length'] = det_min_spec_length
edge_par['follow_span'] = follow_span
edge_par['fit_order'] = fit_order

# Create TraceImage from numpy array
# PyPEIT expects the image in a specific format
trace_img = TraceImage(data.astype(np.float64))

# Create EdgeTraceSet with spectrograph for proper defaults
# Use auto=True to run the full tracing pipeline automatically
# This handles PCA decomposition, edge syncing, and all refinement steps
edges = EdgeTraceSet(
trace_img,
spectrograph=spectrograph,
par=edge_par,
auto=True, # Run full pipeline automatically
)

# Extract edge positions
# edges.edge_fit contains the fitted edge positions
# Shape is (nspec, ntrace) where ntrace = 2 * nslits
if edges.edge_fit is None or edges.edge_fit.size == 0:
# No edges found - return empty arrays
nspec = data.shape[0]
return np.array([]).reshape(0, nspec), np.array([]).reshape(0, nspec), edges

# Separate left and right edges
# In PyPEIT, left edges have negative trace IDs, right have positive
left_mask = edges.traceid < 0
right_mask = edges.traceid > 0

# Get edge positions (shape: nspec x ntrace)
edge_positions = edges.edge_fit

# Extract left and right edges
left_edges = edge_positions[:, left_mask].T # (n_slits, nspec)
right_edges = edge_positions[:, right_mask].T # (n_slits, nspec)

return left_edges, right_edges, edges


def get_slit_centers(left_edges: np.ndarray, right_edges: np.ndarray) -> List[int]:
"""
Calculate slit center positions from left/right edges.

This provides compatibility with the existing trace_slits_1d interface.

Args:
left_edges: 2D array of left edge positions (n_slits x n_spectral)
right_edges: 2D array of right edge positions (n_slits x n_spectral)

Returns:
List of slit center positions (median across spectral direction)
"""
if left_edges.size == 0 or right_edges.size == 0:
return []

# Calculate center as midpoint between left and right edges
# Use median along spectral direction for single position per slit
centers = (np.median(left_edges, axis=1) + np.median(right_edges, axis=1)) / 2
return [int(c) for c in centers]


def save_edge_trace(
left_edges: np.ndarray,
right_edges: np.ndarray,
output_path: str
) -> str:
"""
Save slit edge traces to a numpy file.

Args:
left_edges: 2D array of left edge positions
right_edges: 2D array of right edge positions
output_path: Path for output file

Returns:
Path to saved file
"""
os.makedirs(os.path.dirname(output_path), exist_ok=True)
np.savez(
output_path,
left_edges=left_edges,
right_edges=right_edges
)
return output_path
69 changes: 69 additions & 0 deletions src/core/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
matplotlib.use('Agg') # Use non-interactive backend for saving plots
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional


def generate_qa_plot(data: np.ndarray, output_path: str, title: str = "Flat QA") -> str:
Expand All @@ -20,3 +21,71 @@ def generate_qa_plot(data: np.ndarray, output_path: str, title: str = "Flat QA")
plt.savefig(output_path)
plt.close()
return output_path


def generate_trace_qa_plot(
data: np.ndarray,
left_edges: np.ndarray,
right_edges: np.ndarray,
output_path: str,
title: str = "Slit Trace QA",
vmin: Optional[float] = None,
vmax: Optional[float] = None,
) -> str:
"""
Generate a QA plot showing traced slit edges overlaid on the flat image.

Args:
data: 2D flat field image
left_edges: 2D array of left edge positions (n_slits x n_spectral)
right_edges: 2D array of right edge positions (n_slits x n_spectral)
output_path: Path to save the plot
title: Plot title
vmin: Minimum value for image scaling
vmax: Maximum value for image scaling

Returns:
Path to saved plot
"""
if data is None or not hasattr(data, "shape") or data.ndim != 2:
raise ValueError(f"Invalid data shape for QA plot: {getattr(data, 'shape', None)}")

os.makedirs(os.path.dirname(output_path), exist_ok=True)

fig, ax = plt.subplots(figsize=(12, 8))

# Display the flat image
if vmin is None:
vmin = np.percentile(data, 1)
if vmax is None:
vmax = np.percentile(data, 99)

im = ax.imshow(data, cmap="gray", aspect="auto", origin="lower", vmin=vmin, vmax=vmax)
fig.colorbar(im, ax=ax, label="Counts")

# Overlay traced edges
n_slits = left_edges.shape[0] if left_edges.size > 0 else 0
nspec = data.shape[0]
spectral_coords = np.arange(nspec)

colors = plt.cm.tab10(np.linspace(0, 1, max(n_slits, 1)))

for i in range(n_slits):
color = colors[i % len(colors)]
# Plot left edge
ax.plot(left_edges[i], spectral_coords, color=color, linewidth=1.5, label=f"Slit {i+1}" if i < 10 else None)
# Plot right edge
ax.plot(right_edges[i], spectral_coords, color=color, linewidth=1.5, linestyle="--")

ax.set_xlabel("Spatial (pixels)")
ax.set_ylabel("Spectral (pixels)")
ax.set_title(f"{title} - {n_slits} slits detected")

if n_slits > 0 and n_slits <= 10:
ax.legend(loc="upper right", fontsize=8)

plt.tight_layout()
plt.savefig(output_path, dpi=150)
plt.close()

return output_path
4 changes: 3 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import subprocess
from workflows.flows.batch_flat_flow import batch_process_all_flats
from workflows.flows.pypeit_trace_flow import pypeit_trace_flow

def load_config(config_path="config/config.yaml"):
"""Load configuration from a YAML file."""
Expand Down Expand Up @@ -38,7 +39,8 @@ def load_config(config_path="config/config.yaml"):

try:
print(f"🟢 Starting batch processing of FITS files in {input_dir}")
batch_process_all_flats(input_dir=input_dir, output_dir=output_dir)
# batch_process_all_flats(input_dir=input_dir, output_dir=output_dir)
pypeit_trace_flow(input_dir=input_dir, output_dir=output_dir)

if use_prefect_server:
print("\n✅ Pipeline completed!")
Expand Down
93 changes: 93 additions & 0 deletions src/workflows/flows/pypeit_trace_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
PyPEIT-based slit tracing flow.
"""
import os
from prefect import flow, task, get_run_logger
from prefect.task_runners import ConcurrentTaskRunner
from workflows.prefect_tasks.save_trace import save_trace_solution_task
from workflows.prefect_tasks.pypeit_tasks import (
load_flat_frame_task,
trace_slits_pypeit_task,
get_slit_centers_task,
save_edge_trace_task,
generate_trace_qa_plot_task,
)


@task(name="Trace Slits (PyPEIT)")
def trace_slits_pypeit(fits_path: str, output_dir: str):
"""
Trace slits in a FITS file using PyPEIT algorithms.

Args:
fits_path: Path to input FITS file
output_dir: Output directory for results
"""
logger = get_run_logger()
filename = os.path.splitext(os.path.basename(fits_path))[0]

# Construct output paths
trace_output = os.path.join(output_dir, filename, "slit_trace.txt")
edges_output = os.path.join(output_dir, filename, "slit_edges.npz")
qa_output = os.path.join(output_dir, filename, "trace_qa.png")

# Ensure output dirs
os.makedirs(os.path.dirname(trace_output), exist_ok=True)

# Load FITS
logger.info(f"Loading {fits_path}")
data, header = load_flat_frame_task(fits_path)

# Trace slits with PyPEIT
logger.info("Tracing slits with PyPEIT EdgeTraceSet")
left_edges, right_edges = trace_slits_pypeit_task(data)

# Get slit centers for compatibility with existing outputs
slit_positions = get_slit_centers_task(left_edges, right_edges)
logger.info(f"Found {len(slit_positions)} slits")

# Save outputs
logger.info("Saving results")
save_trace_solution_task(slit_positions, trace_output)
save_edge_trace_task(left_edges, right_edges, edges_output)

# Generate QA plot
logger.info("Generating QA plot")
generate_trace_qa_plot_task(data, left_edges, right_edges, qa_output, title=filename)

logger.info(f"Finished tracing {fits_path}")


@flow(
name="PyPEIT Slit Tracing",
description="Trace slits in FITS frames using PyPEIT algorithms",
task_runner=ConcurrentTaskRunner(max_workers=2),
)
def pypeit_trace_flow(input_dir: str, output_dir: str):
"""
Trace slits in all FITS files using PyPEIT algorithms.

This flow uses PyPEIT's EdgeTraceSet for slit tracing.

Args:
input_dir: Directory containing input FITS files
output_dir: Directory for output files
"""
logger = get_run_logger()

fits_files = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if f.lower().endswith(".fits")
]
logger.info(f"Found {len(fits_files)} FITS files in {input_dir}")

futures = [
trace_slits_pypeit.submit(fp, output_dir)
for fp in fits_files
]

for future in futures:
future.result()

logger.info("PyPEIT slit tracing complete")
Loading
Loading