diff --git a/.gitignore b/.gitignore index 8ebe52b..8ad75b4 100644 --- a/.gitignore +++ b/.gitignore @@ -176,4 +176,4 @@ output/* .python-version -.DS_Store \ No newline at end of file +.DS_Store diff --git a/README.md b/README.md index 112c347..d9be295 100644 --- a/README.md +++ b/README.md @@ -67,16 +67,21 @@ pip install -e .[dev] To batch-process a folder of flat-field FITS files: -```python -from lris2_drp.flows import batch_process_all_flats +```bash +python src/main.py +``` + +Configure the pipeline via `config/config.yaml`: -batch_process_all_flats( - input_dir="/path/to/flats", - output_dir="/path/to/results", -) +```yaml +input_dir: data/lris2_flats +output_dir: output +use_prefect_server: true +save_corrected_flat: false # save corrected flat images +max_workers: 2 ``` -This will process up to 2 files in parallel (configurable) using Prefect’s `ConcurrentTaskRunner`. +Files are processed in parallel using Prefect's `ConcurrentTaskRunner`. Each file goes through: @@ -94,12 +99,13 @@ For each input FITS file, the following will be written to the output directory: ``` / -├── flat_corrected.fits # Flat-field corrected image -├── flat_norm_qa.png # QA plot of normalized flat -└── slit_trace.txt # Slit trace positions +├── flat_correction.fits # Flat-field correction matrix +├── flat_corrected.fits # Corrected flat image (if save_corrected: true) +├── flat_norm_qa.png # QA plot of normalized flat +└── slit_trace.txt # Slit trace positions ``` -The corrected FITS file includes a `FLATCOR` keyword in the header: +The correction FITS file includes a `FLATCOR` keyword in the header: ``` FLATCOR = 'True' / Flat-field correction applied @@ -113,11 +119,10 @@ Additional keywords track reduction steps (optional to expand). ### Adjust Parallelism -To change the number of files processed in parallel, set `max_workers` in `ConcurrentTaskRunner`: +To change the number of files processed in parallel, set `max_workers` in `config/config.yaml`: -```python -@flow(task_runner=ConcurrentTaskRunner(max_workers=4)) -def batch_process_all_flats(...): +```yaml +max_workers: 4 ``` ### Customize Output Paths @@ -125,7 +130,7 @@ def batch_process_all_flats(...): Output filenames and directory structure can be customized in: - `save_trace_solution()` -- `save_corrected_fits()` +- `save_flat_fits()` - `generate_qa_plot()` --- diff --git a/config/config.yaml b/config/config.yaml index 17f7b35..b24e748 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,3 +4,9 @@ output_dir: output # Set to true to start Prefect UI (at http://127.0.0.1:4200) use_prefect_server: true + +# Set to true to also save corrected flat images (flat_corrected.fits) +save_corrected_flat: true + +# Number of files to process in parallel +max_workers: 2 diff --git a/src/core/flat.py b/src/core/flat.py index 1c88f19..632d8fc 100644 --- a/src/core/flat.py +++ b/src/core/flat.py @@ -43,15 +43,38 @@ def create_flat_correction(norm_data: np.ndarray) -> np.ndarray: return correction -def save_corrected_fits(original_data: np.ndarray, correction: np.ndarray, header: dict, output_path: str) -> str: - """Apply the flat correction to the original data and save as a new FITS file.""" - corrected_data = original_data.astype(np.float64) * correction +def save_flat_fits( + correction: np.ndarray, + header: dict, + output_path: str, + original_data: Optional[np.ndarray] = None +) -> str: + """Save flat field results to a FITS file. - # Add DRP history to header - header.add_history("DRP: Flat field correction applied") - header["FLATCORR"] = (True, "Flat field correction applied") + By default, saves the correction matrix (multiplicative factors ~1.0 to apply to science frames). + If original_data is provided, saves the corrected image instead. - hdu = fits.PrimaryHDU(data=corrected_data, header=header) + Args: + correction: The flat field correction matrix + header: FITS header to include + output_path: Path to save the FITS file + original_data: If provided, save corrected image (correction / original_data) + + Returns: + Path to the saved FITS file + """ + header = header.copy() + + if original_data is not None: + header.add_history("DRP: Flat field corrected image created") + header["FLATCORR"] = (True, "Flat field corrected image") + data_to_save = correction / original_data.astype(np.float64) + else: + header.add_history("DRP: Flat field correction map created") + header["FLATCORR"] = (True, "Flat field correction map") + data_to_save = correction + + hdu = fits.PrimaryHDU(data=data_to_save, header=header) hdul = fits.HDUList([hdu]) os.makedirs(os.path.dirname(output_path), exist_ok=True) hdul.writeto(output_path, overwrite=True) @@ -202,8 +225,5 @@ def create_master_flat(flat_data: np.ndarray, **kwargs) -> np.ndarray: Returns: Master flat correction map (multiply science data by this) """ - ratio = normalize_flat_spectroscopic(flat_data, **kwargs) - # The ratio is already the correction map - correction = ratio - + correction = normalize_flat_spectroscopic(flat_data, **kwargs) return correction diff --git a/src/keck_primitives/save_corrected.py b/src/keck_primitives/save_correction.py similarity index 51% rename from src/keck_primitives/save_corrected.py rename to src/keck_primitives/save_correction.py index 66da9e0..bbf632c 100644 --- a/src/keck_primitives/save_corrected.py +++ b/src/keck_primitives/save_correction.py @@ -1,18 +1,22 @@ from keckdrpframework.models.arguments import Arguments from keckdrpframework.primitives.base_primitive import BasePrimitive -from core.flat import save_corrected_fits +from core.flat import save_flat_fits -class SaveCorrectedFits(BasePrimitive): +class SaveFlatFits(BasePrimitive): def __init__(self, action, context): super().__init__(action, context) def _perform(self, input_args: Arguments, config: dict) -> dict: - """Apply the flat correction to the original data and save as a new FITS file.""" - original_data = input_args["original_data"] + """Save flat field results to a FITS file. + + By default saves the correction matrix. If original_data is provided, + saves the corrected image instead. + """ correction = input_args["correction"] header = input_args["header"] output_path = input_args["output_path"] + original_data = input_args["original_data"] if "original_data" in input_args else None - result_path = save_corrected_fits(original_data, correction, header, output_path) + result_path = save_flat_fits(correction, header, output_path, original_data) return {"output_path": result_path} diff --git a/src/main.py b/src/main.py index ad7bc5f..14b3f3a 100644 --- a/src/main.py +++ b/src/main.py @@ -3,10 +3,9 @@ This script initializes the pipeline and processes all flat field FITS files found in the specified input directory, saving the results to the output directory. """ -import yaml import os import subprocess -import time +import yaml from workflows.flows.batch_flat_flow import batch_process_all_flats def load_config(config_path="config/config.yaml"): @@ -19,6 +18,8 @@ def load_config(config_path="config/config.yaml"): input_dir = config["input_dir"] output_dir = config["output_dir"] use_prefect_server = config.get("use_prefect_server", True) + save_corrected_flat = config.get("save_corrected_flat", False) + max_workers = config.get("max_workers", 2) os.makedirs(output_dir, exist_ok=True) @@ -39,7 +40,7 @@ def load_config(config_path="config/config.yaml"): try: print(f"🟢 Starting batch processing of FITS files in {input_dir}") - batch_process_all_flats(input_dir=input_dir, output_dir=output_dir) + batch_process_all_flats(input_dir=input_dir, output_dir=output_dir, save_corrected_flat=save_corrected_flat, max_workers=max_workers) if use_prefect_server: print("\n✅ Pipeline completed!") diff --git a/src/workflows/flows/batch_flat_flow.py b/src/workflows/flows/batch_flat_flow.py index e75aebe..3d01035 100644 --- a/src/workflows/flows/batch_flat_flow.py +++ b/src/workflows/flows/batch_flat_flow.py @@ -4,30 +4,31 @@ from workflows.prefect_tasks.load_flat import load_flat_frame_task from workflows.prefect_tasks.create_master_flat import create_master_flat_task from workflows.prefect_tasks.trace_slits import trace_slits_task -from workflows.prefect_tasks.save_corrected import save_corrected_fits_task +from workflows.prefect_tasks.save_correction import save_flat_fits_task from workflows.prefect_tasks.save_trace import save_trace_solution_task from workflows.prefect_tasks.qa_plot import generate_qa_plot_task @task(name="Process Single Flat Frame") -def process_single_flat_frame(flat_fits_path: str, output_dir: str): +def process_single_flat_frame(flat_fits_path: str, output_dir: str, save_corrected_flat: bool = False): """ Process a single LRIS2 flat FITS file through all DRP steps. Args: flat_fits_path: Path to input FITS file output_dir: Output directory for results + save_corrected_flat: If True, also save the corrected flat image """ logger = get_run_logger() filename = os.path.splitext(os.path.basename(flat_fits_path))[0] # Construct output paths - corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits") + correction_output = os.path.join(output_dir, filename, "flat_correction.fits") trace_output = os.path.join(output_dir, filename, "slit_trace.txt") qa_output = os.path.join(output_dir, filename, "flat_norm_qa.png") # Ensure output dirs - os.makedirs(os.path.dirname(corrected_output), exist_ok=True) + os.makedirs(os.path.dirname(correction_output), exist_ok=True) # Load FITS logger.info(f"Loading {flat_fits_path}") @@ -50,35 +51,44 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str): # Save outputs logger.info("Saving results") - save_corrected_fits_task(data, correction, header, corrected_output) + save_flat_fits_task(correction, header, correction_output) + if save_corrected_flat: + corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits") + save_flat_fits_task(correction, header, corrected_output, original_data=data) save_trace_solution_task(slit_positions, trace_output) generate_qa_plot_task(correction, qa_output) logger.info(f"Finished processing {flat_fits_path}") -@flow( - name="Batch Process LRIS2 Flats", - description="Process all flat frames using spectroscopic flat fielding", - task_runner=ConcurrentTaskRunner(max_workers=2), # You can adjust this -) -def batch_process_all_flats(input_dir: str, output_dir: str): +def batch_process_all_flats(input_dir: str, output_dir: str, save_corrected_flat: bool = False, max_workers: int = 2): """ Process all FITS files in a directory using spectroscopic flat fielding. Args: input_dir: Directory containing input FITS files output_dir: Directory for output files + save_corrected_flat: If True, also save the corrected flat image + max_workers: Number of files to process in parallel """ - logger = get_run_logger() + @flow( + name="Batch Process LRIS2 Flats", + description="Process all flat frames using spectroscopic flat fielding", + task_runner=ConcurrentTaskRunner(max_workers=max_workers), + ) + def _batch_process_all_flats(input_dir: str, output_dir: str, save_corrected_flat: bool, max_workers: int): + logger = get_run_logger() + + fits_files = [ + os.path.join(input_dir, f) + for f in os.listdir(input_dir) + if f.lower().endswith(".fits") + ] + logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.") + logger.info(f"Settings: save_corrected_flat={save_corrected_flat}, max_workers={max_workers}") - fits_files = [ - os.path.join(input_dir, f) - for f in os.listdir(input_dir) - if f.lower().endswith(".fits") - ] - logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.") + futures = [process_single_flat_frame.submit(fp, output_dir, save_corrected_flat) for fp in fits_files] - futures = [process_single_flat_frame.submit(fp, output_dir) for fp in fits_files] + for future in futures: + future.result() - for future in futures: - future.result() + return _batch_process_all_flats(input_dir, output_dir, save_corrected_flat, max_workers) diff --git a/src/workflows/prefect_tasks/load_flat.py b/src/workflows/prefect_tasks/load_flat.py index 4bd7764..019096d 100644 --- a/src/workflows/prefect_tasks/load_flat.py +++ b/src/workflows/prefect_tasks/load_flat.py @@ -13,4 +13,4 @@ def load_flat_frame_task(filepath: str): context = DummyContext() result = LoadFlat(action, context)._perform(args, config={}) - return result["flat_data"], result["header"] \ No newline at end of file + return result["flat_data"], result["header"] diff --git a/src/workflows/prefect_tasks/qa_plot.py b/src/workflows/prefect_tasks/qa_plot.py index 8a94f69..7bad944 100644 --- a/src/workflows/prefect_tasks/qa_plot.py +++ b/src/workflows/prefect_tasks/qa_plot.py @@ -15,4 +15,4 @@ def generate_qa_plot_task(data, output_path: str, title: str = "Flat QA"): context = DummyContext() result = GenerateQAPlot(action, context)._perform(args, config={}) - return result["output_path"] \ No newline at end of file + return result["output_path"] diff --git a/src/workflows/prefect_tasks/save_corrected.py b/src/workflows/prefect_tasks/save_corrected.py deleted file mode 100644 index 4f12796..0000000 --- a/src/workflows/prefect_tasks/save_corrected.py +++ /dev/null @@ -1,19 +0,0 @@ -from prefect import task -from keckdrpframework.models.arguments import Arguments -from keck_primitives.save_corrected import SaveCorrectedFits -from keck_primitives.utils import DummyAction, DummyContext - - -@task(name="Save Corrected FITS") -def save_corrected_fits_task(original_data, correction, header, output_path: str): - args = Arguments() - args["original_data"] = original_data - args["correction"] = correction - args["header"] = header - args["output_path"] = output_path - - action = DummyAction(args=args) - context = DummyContext() - - result = SaveCorrectedFits(action, context)._perform(args, config={}) - return result["output_path"] \ No newline at end of file diff --git a/src/workflows/prefect_tasks/save_correction.py b/src/workflows/prefect_tasks/save_correction.py new file mode 100644 index 0000000..fc76543 --- /dev/null +++ b/src/workflows/prefect_tasks/save_correction.py @@ -0,0 +1,25 @@ +from prefect import task +from keckdrpframework.models.arguments import Arguments +from keck_primitives.save_correction import SaveFlatFits +from keck_primitives.utils import DummyAction, DummyContext + + +@task(name="Save Flat FITS") +def save_flat_fits_task(correction, header, output_path: str, original_data=None): + """Save flat field results to a FITS file. + + By default saves the correction matrix. If original_data is provided, + saves the corrected image instead. + """ + args = Arguments() + args["correction"] = correction + args["header"] = header + args["output_path"] = output_path + if original_data is not None: + args["original_data"] = original_data + + action = DummyAction(args=args) + context = DummyContext() + + result = SaveFlatFits(action, context)._perform(args, config={}) + return result["output_path"] diff --git a/src/workflows/prefect_tasks/save_trace.py b/src/workflows/prefect_tasks/save_trace.py index c32a8d2..ff22305 100644 --- a/src/workflows/prefect_tasks/save_trace.py +++ b/src/workflows/prefect_tasks/save_trace.py @@ -14,4 +14,4 @@ def save_trace_solution_task(slit_positions, output_path: str): context = DummyContext() result = SaveTraceSolution(action, context)._perform(args, config={}) - return result["output_path"] \ No newline at end of file + return result["output_path"]