Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ output/*

.python-version

.DS_Store
.DS_Store
37 changes: 21 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,21 @@ pip install -e .[dev]

To batch-process a folder of flat-field FITS files:

```python
from lris2_drp.flows import batch_process_all_flats
```bash
python src/main.py
```

Configure the pipeline via `config/config.yaml`:

batch_process_all_flats(
input_dir="/path/to/flats",
output_dir="/path/to/results",
)
```yaml
input_dir: data/lris2_flats
output_dir: output
use_prefect_server: true
save_corrected_flat: false # save corrected flat images
max_workers: 2
```

This will process up to 2 files in parallel (configurable) using Prefects `ConcurrentTaskRunner`.
Files are processed in parallel using Prefect's `ConcurrentTaskRunner`.

Each file goes through:

Expand All @@ -94,12 +99,13 @@ For each input FITS file, the following will be written to the output directory:

```
<filename>/
├── flat_corrected.fits # Flat-field corrected image
├── flat_norm_qa.png # QA plot of normalized flat
└── slit_trace.txt # Slit trace positions
├── flat_correction.fits # Flat-field correction matrix
├── flat_corrected.fits # Corrected flat image (if save_corrected_flat: true)
├── flat_norm_qa.png # QA plot of normalized flat
└── slit_trace.txt # Slit trace positions
```

The corrected FITS file includes a `FLATCOR` keyword in the header:
The correction FITS file includes a `FLATCORR` keyword in the header:

```
FLATCORR = T / Flat field correction map
Expand All @@ -113,19 +119,18 @@ Additional keywords track reduction steps (optional to expand).

### Adjust Parallelism

To change the number of files processed in parallel, set `max_workers` in `ConcurrentTaskRunner`:
To change the number of files processed in parallel, set `max_workers` in `config/config.yaml`:

```python
@flow(task_runner=ConcurrentTaskRunner(max_workers=4))
def batch_process_all_flats(...):
```yaml
max_workers: 4
```

### Customize Output Paths

Output filenames and directory structure can be customized in:

- `save_trace_solution()`
- `save_corrected_fits()`
- `save_flat_fits()`
- `generate_qa_plot()`

---
6 changes: 6 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ output_dir: output

# Set to true to start Prefect UI (at http://127.0.0.1:4200)
use_prefect_server: true

# Set to true to also save corrected flat images (flat_corrected.fits)
save_corrected_flat: false

# Number of files to process in parallel
max_workers: 2
42 changes: 31 additions & 11 deletions src/core/flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,38 @@ def create_flat_correction(norm_data: np.ndarray) -> np.ndarray:
return correction


def save_corrected_fits(original_data: np.ndarray, correction: np.ndarray, header: dict, output_path: str) -> str:
"""Apply the flat correction to the original data and save as a new FITS file."""
corrected_data = original_data.astype(np.float64) * correction
def save_flat_fits(
correction: np.ndarray,
header: dict,
output_path: str,
original_data: Optional[np.ndarray] = None
) -> str:
"""Save flat field results to a FITS file.

# Add DRP history to header
header.add_history("DRP: Flat field correction applied")
header["FLATCORR"] = (True, "Flat field correction applied")
By default, saves the correction matrix (multiplicative factors ~1.0 to apply to science frames).
If original_data is provided, saves the corrected image instead.

hdu = fits.PrimaryHDU(data=corrected_data, header=header)
Args:
correction: The flat field correction matrix
header: FITS header to include
output_path: Path to save the FITS file
original_data: If provided, save corrected image (correction / original_data)

Returns:
Path to the saved FITS file
"""
header = header.copy()

if original_data is not None:
header.add_history("DRP: Flat field corrected image created")
header["FLATCORR"] = (True, "Flat field corrected image")
data_to_save = correction / original_data.astype(np.float64)
else:
header.add_history("DRP: Flat field correction map created")
header["FLATCORR"] = (True, "Flat field correction map")
data_to_save = correction

hdu = fits.PrimaryHDU(data=data_to_save, header=header)
hdul = fits.HDUList([hdu])
os.makedirs(os.path.dirname(output_path), exist_ok=True)
hdul.writeto(output_path, overwrite=True)
Expand Down Expand Up @@ -202,8 +225,5 @@ def create_master_flat(flat_data: np.ndarray, **kwargs) -> np.ndarray:
Returns:
Master flat correction map (multiply science data by this)
"""
ratio = normalize_flat_spectroscopic(flat_data, **kwargs)
# The ratio is already the correction map
correction = ratio

correction = normalize_flat_spectroscopic(flat_data, **kwargs)
return correction
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from keckdrpframework.models.arguments import Arguments
from keckdrpframework.primitives.base_primitive import BasePrimitive
from core.flat import save_corrected_fits
from core.flat import save_flat_fits


class SaveFlatFits(BasePrimitive):
    """Primitive that saves flat-field results to a FITS file.

    Thin wrapper around ``core.flat.save_flat_fits`` conforming to the
    keckdrpframework BasePrimitive interface.
    """

    def __init__(self, action, context):
        super().__init__(action, context)

    def _perform(self, input_args: Arguments, config: dict) -> dict:
        """Save flat field results to a FITS file.

        By default saves the correction matrix. If ``"original_data"`` is
        present in ``input_args``, the corrected flat image is saved instead.

        Args:
            input_args: Must contain ``"correction"``, ``"header"`` and
                ``"output_path"``; may optionally contain ``"original_data"``.
            config: Pipeline configuration dict (unused here, required by the
                primitive API).

        Returns:
            Dict with key ``"output_path"`` giving the written file's path.
        """
        correction = input_args["correction"]
        header = input_args["header"]
        output_path = input_args["output_path"]
        # Optional key: only present when the caller wants the corrected image.
        original_data = input_args["original_data"] if "original_data" in input_args else None

        result_path = save_flat_fits(correction, header, output_path, original_data)
        return {"output_path": result_path}
7 changes: 4 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
This script initializes the pipeline and processes all flat field FITS files
found in the specified input directory, saving the results to the output directory.
"""
import yaml
import os
import subprocess
import time
import yaml
from workflows.flows.batch_flat_flow import batch_process_all_flats

def load_config(config_path="config/config.yaml"):
Expand All @@ -19,6 +18,8 @@ def load_config(config_path="config/config.yaml"):
input_dir = config["input_dir"]
output_dir = config["output_dir"]
use_prefect_server = config.get("use_prefect_server", True)
save_corrected_flat = config.get("save_corrected_flat", False)
max_workers = config.get("max_workers", 2)

os.makedirs(output_dir, exist_ok=True)

Expand All @@ -39,7 +40,7 @@ def load_config(config_path="config/config.yaml"):

try:
print(f"🟢 Starting batch processing of FITS files in {input_dir}")
batch_process_all_flats(input_dir=input_dir, output_dir=output_dir)
batch_process_all_flats(input_dir=input_dir, output_dir=output_dir, save_corrected_flat=save_corrected_flat, max_workers=max_workers)

if use_prefect_server:
print("\n✅ Pipeline completed!")
Expand Down
52 changes: 31 additions & 21 deletions src/workflows/flows/batch_flat_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,31 @@
from workflows.prefect_tasks.load_flat import load_flat_frame_task
from workflows.prefect_tasks.create_master_flat import create_master_flat_task
from workflows.prefect_tasks.trace_slits import trace_slits_task
from workflows.prefect_tasks.save_corrected import save_corrected_fits_task
from workflows.prefect_tasks.save_correction import save_flat_fits_task
from workflows.prefect_tasks.save_trace import save_trace_solution_task
from workflows.prefect_tasks.qa_plot import generate_qa_plot_task


@task(name="Process Single Flat Frame")
def process_single_flat_frame(flat_fits_path: str, output_dir: str):
def process_single_flat_frame(flat_fits_path: str, output_dir: str, save_corrected_flat: bool = False):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there is a use case to correct the flat image itself by the master_flat and save the corrected flat. As far as I know it is not needed for any of the DRP steps. We just need the master flat.

"""
Process a single LRIS2 flat FITS file through all DRP steps.

Args:
flat_fits_path: Path to input FITS file
output_dir: Output directory for results
save_corrected_flat: If True, also save the corrected flat image
"""
logger = get_run_logger()
filename = os.path.splitext(os.path.basename(flat_fits_path))[0]

# Construct output paths
corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits")
correction_output = os.path.join(output_dir, filename, "flat_correction.fits")
trace_output = os.path.join(output_dir, filename, "slit_trace.txt")
qa_output = os.path.join(output_dir, filename, "flat_norm_qa.png")

# Ensure output dirs
os.makedirs(os.path.dirname(corrected_output), exist_ok=True)
os.makedirs(os.path.dirname(correction_output), exist_ok=True)

# Load FITS
logger.info(f"Loading {flat_fits_path}")
Expand All @@ -50,35 +51,44 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str):

# Save outputs
logger.info("Saving results")
save_corrected_fits_task(data, correction, header, corrected_output)
save_flat_fits_task(correction, header, correction_output)
if save_corrected_flat:
corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits")
save_flat_fits_task(correction, header, corrected_output, original_data=data)
save_trace_solution_task(slit_positions, trace_output)
generate_qa_plot_task(correction, qa_output)

logger.info(f"Finished processing {flat_fits_path}")

@flow(
name="Batch Process LRIS2 Flats",
description="Process all flat frames using spectroscopic flat fielding",
task_runner=ConcurrentTaskRunner(max_workers=2), # You can adjust this
)
def batch_process_all_flats(input_dir: str, output_dir: str):
def batch_process_all_flats(input_dir: str, output_dir: str, save_corrected_flat: bool = False, max_workers: int = 2):
"""
Process all FITS files in a directory using spectroscopic flat fielding.

Args:
input_dir: Directory containing input FITS files
output_dir: Directory for output files
save_corrected_flat: If True, also save the corrected flat image
max_workers: Number of files to process in parallel
"""
logger = get_run_logger()
@flow(
name="Batch Process LRIS2 Flats",
description="Process all flat frames using spectroscopic flat fielding",
task_runner=ConcurrentTaskRunner(max_workers=max_workers),
)
def _batch_process_all_flats(input_dir: str, output_dir: str, save_corrected_flat: bool, max_workers: int):
logger = get_run_logger()

fits_files = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if f.lower().endswith(".fits")
]
logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.")
logger.info(f"Settings: save_corrected_flat={save_corrected_flat}, max_workers={max_workers}")

fits_files = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if f.lower().endswith(".fits")
]
logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.")
futures = [process_single_flat_frame.submit(fp, output_dir, save_corrected_flat) for fp in fits_files]

futures = [process_single_flat_frame.submit(fp, output_dir) for fp in fits_files]
for future in futures:
future.result()

for future in futures:
future.result()
return _batch_process_all_flats(input_dir, output_dir, save_corrected_flat, max_workers)
2 changes: 1 addition & 1 deletion src/workflows/prefect_tasks/load_flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ def load_flat_frame_task(filepath: str):
context = DummyContext()

result = LoadFlat(action, context)._perform(args, config={})
return result["flat_data"], result["header"]
return result["flat_data"], result["header"]
2 changes: 1 addition & 1 deletion src/workflows/prefect_tasks/qa_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ def generate_qa_plot_task(data, output_path: str, title: str = "Flat QA"):
context = DummyContext()

result = GenerateQAPlot(action, context)._perform(args, config={})
return result["output_path"]
return result["output_path"]
19 changes: 0 additions & 19 deletions src/workflows/prefect_tasks/save_corrected.py

This file was deleted.

25 changes: 25 additions & 0 deletions src/workflows/prefect_tasks/save_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from prefect import task
from keckdrpframework.models.arguments import Arguments
from keck_primitives.save_correction import SaveFlatFits
from keck_primitives.utils import DummyAction, DummyContext


@task(name="Save Flat FITS")
def save_flat_fits_task(correction, header, output_path: str, original_data=None):
    """Save flat field outputs to a FITS file via the SaveFlatFits primitive.

    The correction matrix is written by default; when ``original_data`` is
    supplied, the corrected flat image is written instead.
    """
    task_args = Arguments()
    task_args["correction"] = correction
    task_args["header"] = header
    task_args["output_path"] = output_path
    # Only forward the optional image so the primitive's default path is taken
    # when the caller did not request a corrected flat.
    if original_data is not None:
        task_args["original_data"] = original_data

    primitive = SaveFlatFits(DummyAction(args=task_args), DummyContext())
    return primitive._perform(task_args, config={})["output_path"]
2 changes: 1 addition & 1 deletion src/workflows/prefect_tasks/save_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ def save_trace_solution_task(slit_positions, output_path: str):
context = DummyContext()

result = SaveTraceSolution(action, context)._perform(args, config={})
return result["output_path"]
return result["output_path"]
Loading