From 906019f4b0dec6acc22c37156317b9f00d88c1f9 Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Tue, 9 Dec 2025 11:28:56 -0800 Subject: [PATCH 1/4] fix flatfield results, return correction matrix --- .gitignore | 2 +- src/core/flat.py | 14 ++++++++------ src/keck_primitives/save_corrected.py | 9 ++++----- src/main.py | 3 +-- src/workflows/flows/batch_flat_flow.py | 8 ++++---- src/workflows/prefect_tasks/load_flat.py | 2 +- src/workflows/prefect_tasks/qa_plot.py | 2 +- src/workflows/prefect_tasks/save_corrected.py | 9 ++++----- src/workflows/prefect_tasks/save_trace.py | 2 +- 9 files changed, 25 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 8ebe52b..8ad75b4 100644 --- a/.gitignore +++ b/.gitignore @@ -176,4 +176,4 @@ output/* .python-version -.DS_Store \ No newline at end of file +.DS_Store diff --git a/src/core/flat.py b/src/core/flat.py index 1c88f19..4e1a609 100644 --- a/src/core/flat.py +++ b/src/core/flat.py @@ -43,15 +43,17 @@ def create_flat_correction(norm_data: np.ndarray) -> np.ndarray: return correction -def save_corrected_fits(original_data: np.ndarray, correction: np.ndarray, header: dict, output_path: str) -> str: - """Apply the flat correction to the original data and save as a new FITS file.""" - corrected_data = original_data.astype(np.float64) * correction +def save_correction_fits(correction: np.ndarray, header: dict, output_path: str) -> str: + """Save the flat field correction map to a FITS file. + The correction map contains multiplicative factors (~1.0) to apply to science frames: + corrected_science = raw_science * correction + """ # Add DRP history to header - header.add_history("DRP: Flat field correction applied") - header["FLATCORR"] = (True, "Flat field correction applied") + header.add_history("DRP: Flat field correction map created") + header["FLATCORR"] = (True, "Flat field correction map") - hdu = fits.PrimaryHDU(data=corrected_data, header=header) + hdu = fits.PrimaryHDU(data=correction, header=header) hdul = fits.HDUList([hdu]) os.makedirs(os.path.dirname(output_path), exist_ok=True) hdul.writeto(output_path, overwrite=True) diff --git a/src/keck_primitives/save_corrected.py b/src/keck_primitives/save_corrected.py index 66da9e0..b2363da 100644 --- a/src/keck_primitives/save_corrected.py +++ b/src/keck_primitives/save_corrected.py @@ -1,18 +1,17 @@ from keckdrpframework.models.arguments import Arguments from keckdrpframework.primitives.base_primitive import BasePrimitive -from core.flat import save_corrected_fits +from core.flat import save_correction_fits -class SaveCorrectedFits(BasePrimitive): +class SaveCorrectionFits(BasePrimitive): def __init__(self, action, context): super().__init__(action, context) def _perform(self, input_args: Arguments, config: dict) -> dict: - """Apply the flat correction to the original data and save as a new FITS file.""" - original_data = input_args["original_data"] + """Save the flat field correction map to a FITS file.""" correction = input_args["correction"] header = input_args["header"] output_path = input_args["output_path"] - result_path = save_corrected_fits(original_data, correction, header, output_path) + result_path = save_correction_fits(correction, header, output_path) return {"output_path": result_path} diff --git a/src/main.py b/src/main.py index ad7bc5f..6810d3a 100644 --- a/src/main.py +++ b/src/main.py @@ -3,10 +3,9 @@ This script initializes the pipeline and processes all flat field FITS files found in the specified input directory, saving the results to the output directory. """ -import yaml import os import subprocess -import time +import yaml from workflows.flows.batch_flat_flow import batch_process_all_flats def load_config(config_path="config/config.yaml"): diff --git a/src/workflows/flows/batch_flat_flow.py b/src/workflows/flows/batch_flat_flow.py index e75aebe..dcdb9f6 100644 --- a/src/workflows/flows/batch_flat_flow.py +++ b/src/workflows/flows/batch_flat_flow.py @@ -4,7 +4,7 @@ from workflows.prefect_tasks.load_flat import load_flat_frame_task from workflows.prefect_tasks.create_master_flat import create_master_flat_task from workflows.prefect_tasks.trace_slits import trace_slits_task -from workflows.prefect_tasks.save_corrected import save_corrected_fits_task +from workflows.prefect_tasks.save_corrected import save_correction_fits_task from workflows.prefect_tasks.save_trace import save_trace_solution_task from workflows.prefect_tasks.qa_plot import generate_qa_plot_task @@ -22,12 +22,12 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str): filename = os.path.splitext(os.path.basename(flat_fits_path))[0] # Construct output paths - corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits") + correction_output = os.path.join(output_dir, filename, "flat_correction.fits") trace_output = os.path.join(output_dir, filename, "slit_trace.txt") qa_output = os.path.join(output_dir, filename, "flat_norm_qa.png") # Ensure output dirs - os.makedirs(os.path.dirname(corrected_output), exist_ok=True) + os.makedirs(os.path.dirname(correction_output), exist_ok=True) # Load FITS logger.info(f"Loading {flat_fits_path}") @@ -50,7 +50,7 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str): # Save outputs logger.info("Saving results") - save_corrected_fits_task(data, correction, header, corrected_output) + save_correction_fits_task(correction, header, correction_output) save_trace_solution_task(slit_positions, trace_output) generate_qa_plot_task(correction, qa_output) diff --git a/src/workflows/prefect_tasks/load_flat.py b/src/workflows/prefect_tasks/load_flat.py index 4bd7764..019096d 100644 --- a/src/workflows/prefect_tasks/load_flat.py +++ b/src/workflows/prefect_tasks/load_flat.py @@ -13,4 +13,4 @@ def load_flat_frame_task(filepath: str): context = DummyContext() result = LoadFlat(action, context)._perform(args, config={}) - return result["flat_data"], result["header"] \ No newline at end of file + return result["flat_data"], result["header"] diff --git a/src/workflows/prefect_tasks/qa_plot.py b/src/workflows/prefect_tasks/qa_plot.py index 8a94f69..7bad944 100644 --- a/src/workflows/prefect_tasks/qa_plot.py +++ b/src/workflows/prefect_tasks/qa_plot.py @@ -15,4 +15,4 @@ def generate_qa_plot_task(data, output_path: str, title: str = "Flat QA"): context = DummyContext() result = GenerateQAPlot(action, context)._perform(args, config={}) - return result["output_path"] \ No newline at end of file + return result["output_path"] diff --git a/src/workflows/prefect_tasks/save_corrected.py b/src/workflows/prefect_tasks/save_corrected.py index 4f12796..44d85df 100644 --- a/src/workflows/prefect_tasks/save_corrected.py +++ b/src/workflows/prefect_tasks/save_corrected.py @@ -1,13 +1,12 @@ from prefect import task from keckdrpframework.models.arguments import Arguments -from keck_primitives.save_corrected import SaveCorrectedFits +from keck_primitives.save_corrected import SaveCorrectionFits from keck_primitives.utils import DummyAction, DummyContext -@task(name="Save Corrected FITS") -def save_corrected_fits_task(original_data, correction, header, output_path: str): +@task(name="Save Correction FITS") +def save_correction_fits_task(correction, header, output_path: str): args = Arguments() - args["original_data"] = original_data args["correction"] = correction args["header"] = header args["output_path"] = output_path @@ -15,5 +14,5 @@ def save_corrected_fits_task(original_data, correction, header, output_path: str action = DummyAction(args=args) context = DummyContext() - result = SaveCorrectedFits(action, context)._perform(args, config={}) + result = SaveCorrectionFits(action, context)._perform(args, config={}) return result["output_path"] \ No newline at end of file diff --git a/src/workflows/prefect_tasks/save_trace.py b/src/workflows/prefect_tasks/save_trace.py index c32a8d2..ff22305 100644 --- a/src/workflows/prefect_tasks/save_trace.py +++ b/src/workflows/prefect_tasks/save_trace.py @@ -14,4 +14,4 @@ def save_trace_solution_task(slit_positions, output_path: str): context = DummyContext() result = SaveTraceSolution(action, context)._perform(args, config={}) - return result["output_path"] \ No newline at end of file + return result["output_path"] From ed105fc8a7767c0f3e317c9c8000af5018cd9f2a Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Tue, 9 Dec 2025 11:54:29 -0800 Subject: [PATCH 2/4] update naming and README to be consitent --- README.md | 6 +++--- src/core/flat.py | 3 +-- .../{save_corrected.py => save_correction.py} | 0 src/workflows/flows/batch_flat_flow.py | 2 +- .../prefect_tasks/{save_corrected.py => save_correction.py} | 4 ++-- 5 files changed, 7 insertions(+), 8 deletions(-) rename src/keck_primitives/{save_corrected.py => save_correction.py} (100%) rename src/workflows/prefect_tasks/{save_corrected.py => save_correction.py} (84%) diff --git a/README.md b/README.md index 112c347..0d22fe3 100644 --- a/README.md +++ b/README.md @@ -94,12 +94,12 @@ For each input FITS file, the following will be written to the output directory: ``` / -├── flat_corrected.fits # Flat-field corrected image +├── flat_correction.fits # Flat-field correction matrix ├── flat_norm_qa.png # QA plot of normalized flat └── slit_trace.txt # Slit trace positions ``` -The corrected FITS file includes a `FLATCOR` keyword in the header: +The correction FITS file includes a `FLATCOR` keyword in the header: ``` FLATCOR = 'True' / Flat-field correction applied @@ -125,7 +125,7 @@ def batch_process_all_flats(...): Output filenames and directory structure can be customized in: - `save_trace_solution()` -- `save_corrected_fits()` +- `save_correction_fits()` - `generate_qa_plot()` --- diff --git a/src/core/flat.py b/src/core/flat.py index 4e1a609..9a4c38a 100644 --- a/src/core/flat.py +++ b/src/core/flat.py @@ -46,8 +46,7 @@ def create_flat_correction(norm_data: np.ndarray) -> np.ndarray: def save_correction_fits(correction: np.ndarray, header: dict, output_path: str) -> str: """Save the flat field correction map to a FITS file. - The correction map contains multiplicative factors (~1.0) to apply to science frames: - corrected_science = raw_science * correction + The correction map contains multiplicative factors (~1.0) to apply to science frames. """ # Add DRP history to header header.add_history("DRP: Flat field correction map created") diff --git a/src/keck_primitives/save_corrected.py b/src/keck_primitives/save_correction.py similarity index 100% rename from src/keck_primitives/save_corrected.py rename to src/keck_primitives/save_correction.py diff --git a/src/workflows/flows/batch_flat_flow.py b/src/workflows/flows/batch_flat_flow.py index dcdb9f6..992bfb6 100644 --- a/src/workflows/flows/batch_flat_flow.py +++ b/src/workflows/flows/batch_flat_flow.py @@ -4,7 +4,7 @@ from workflows.prefect_tasks.load_flat import load_flat_frame_task from workflows.prefect_tasks.create_master_flat import create_master_flat_task from workflows.prefect_tasks.trace_slits import trace_slits_task -from workflows.prefect_tasks.save_corrected import save_correction_fits_task +from workflows.prefect_tasks.save_correction import save_correction_fits_task from workflows.prefect_tasks.save_trace import save_trace_solution_task from workflows.prefect_tasks.qa_plot import generate_qa_plot_task diff --git a/src/workflows/prefect_tasks/save_corrected.py b/src/workflows/prefect_tasks/save_correction.py similarity index 84% rename from src/workflows/prefect_tasks/save_corrected.py rename to src/workflows/prefect_tasks/save_correction.py index 44d85df..d49eac0 100644 --- a/src/workflows/prefect_tasks/save_corrected.py +++ b/src/workflows/prefect_tasks/save_correction.py @@ -1,6 +1,6 @@ from prefect import task from keckdrpframework.models.arguments import Arguments -from keck_primitives.save_corrected import SaveCorrectionFits +from keck_primitives.save_correction import SaveCorrectionFits from keck_primitives.utils import DummyAction, DummyContext @@ -15,4 +15,4 @@ def save_correction_fits_task(correction, header, output_path: str): context = DummyContext() result = SaveCorrectionFits(action, context)._perform(args, config={}) - return result["output_path"] \ No newline at end of file + return result["output_path"] From 777efc258cff54cd07642dc74bb200b62f5d5c7d Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Fri, 12 Dec 2025 13:15:05 -0800 Subject: [PATCH 3/4] add parameter to save corrected flat fits file --- config/config.yaml | 3 ++ src/core/flat.py | 41 ++++++++++++++----- src/keck_primitives/save_correction.py | 13 ++++-- src/main.py | 3 +- src/workflows/flows/batch_flat_flow.py | 15 ++++--- .../prefect_tasks/save_correction.py | 15 +++++-- 6 files changed, 65 insertions(+), 25 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 17f7b35..8dcc719 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,3 +4,6 @@ output_dir: output # Set to true to start Prefect UI (at http://127.0.0.1:4200) use_prefect_server: true + +# Set to true to also save corrected flat images (default: false, saves only correction matrix) +save_corrected: true diff --git a/src/core/flat.py b/src/core/flat.py index 9a4c38a..632d8fc 100644 --- a/src/core/flat.py +++ b/src/core/flat.py @@ -43,16 +43,38 @@ def create_flat_correction(norm_data: np.ndarray) -> np.ndarray: return correction -def save_correction_fits(correction: np.ndarray, header: dict, output_path: str) -> str: - """Save the flat field correction map to a FITS file. +def save_flat_fits( + correction: np.ndarray, + header: dict, + output_path: str, + original_data: Optional[np.ndarray] = None +) -> str: + """Save flat field results to a FITS file. - The correction map contains multiplicative factors (~1.0) to apply to science frames. + By default, saves the correction matrix (multiplicative factors ~1.0 to apply to science frames). + If original_data is provided, saves the corrected image instead. + + Args: + correction: The flat field correction matrix + header: FITS header to include + output_path: Path to save the FITS file + original_data: If provided, save corrected image (correction / original_data) + + Returns: + Path to the saved FITS file """ - # Add DRP history to header - header.add_history("DRP: Flat field correction map created") - header["FLATCORR"] = (True, "Flat field correction map") + header = header.copy() + + if original_data is not None: + header.add_history("DRP: Flat field corrected image created") + header["FLATCORR"] = (True, "Flat field corrected image") + data_to_save = correction / original_data.astype(np.float64) + else: + header.add_history("DRP: Flat field correction map created") + header["FLATCORR"] = (True, "Flat field correction map") + data_to_save = correction - hdu = fits.PrimaryHDU(data=correction, header=header) + hdu = fits.PrimaryHDU(data=data_to_save, header=header) hdul = fits.HDUList([hdu]) os.makedirs(os.path.dirname(output_path), exist_ok=True) hdul.writeto(output_path, overwrite=True) @@ -203,8 +225,5 @@ def create_master_flat(flat_data: np.ndarray, **kwargs) -> np.ndarray: Returns: Master flat correction map (multiply science data by this) """ - ratio = normalize_flat_spectroscopic(flat_data, **kwargs) - # The ratio is already the correction map - correction = ratio - + correction = normalize_flat_spectroscopic(flat_data, **kwargs) return correction diff --git a/src/keck_primitives/save_correction.py b/src/keck_primitives/save_correction.py index b2363da..bbf632c 100644 --- a/src/keck_primitives/save_correction.py +++ b/src/keck_primitives/save_correction.py @@ -1,17 +1,22 @@ from keckdrpframework.models.arguments import Arguments from keckdrpframework.primitives.base_primitive import BasePrimitive -from core.flat import save_correction_fits +from core.flat import save_flat_fits -class SaveCorrectionFits(BasePrimitive): +class SaveFlatFits(BasePrimitive): def __init__(self, action, context): super().__init__(action, context) def _perform(self, input_args: Arguments, config: dict) -> dict: - """Save the flat field correction map to a FITS file.""" + """Save flat field results to a FITS file. + + By default saves the correction matrix. If original_data is provided, + saves the corrected image instead. + """ correction = input_args["correction"] header = input_args["header"] output_path = input_args["output_path"] + original_data = input_args["original_data"] if "original_data" in input_args else None - result_path = save_correction_fits(correction, header, output_path) + result_path = save_flat_fits(correction, header, output_path, original_data) return {"output_path": result_path} diff --git a/src/main.py b/src/main.py index 6810d3a..854db35 100644 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ def load_config(config_path="config/config.yaml"): input_dir = config["input_dir"] output_dir = config["output_dir"] use_prefect_server = config.get("use_prefect_server", True) + save_corrected = config.get("save_corrected", False) os.makedirs(output_dir, exist_ok=True) @@ -38,7 +39,7 @@ def load_config(config_path="config/config.yaml"): try: print(f"🟢 Starting batch processing of FITS files in {input_dir}") - batch_process_all_flats(input_dir=input_dir, output_dir=output_dir) + batch_process_all_flats(input_dir=input_dir, output_dir=output_dir, save_corrected=save_corrected) if use_prefect_server: print("\n✅ Pipeline completed!") diff --git a/src/workflows/flows/batch_flat_flow.py b/src/workflows/flows/batch_flat_flow.py index 992bfb6..5e304de 100644 --- a/src/workflows/flows/batch_flat_flow.py +++ b/src/workflows/flows/batch_flat_flow.py @@ -4,19 +4,20 @@ from workflows.prefect_tasks.load_flat import load_flat_frame_task from workflows.prefect_tasks.create_master_flat import create_master_flat_task from workflows.prefect_tasks.trace_slits import trace_slits_task -from workflows.prefect_tasks.save_correction import save_correction_fits_task +from workflows.prefect_tasks.save_correction import save_flat_fits_task from workflows.prefect_tasks.save_trace import save_trace_solution_task from workflows.prefect_tasks.qa_plot import generate_qa_plot_task @task(name="Process Single Flat Frame") -def process_single_flat_frame(flat_fits_path: str, output_dir: str): +def process_single_flat_frame(flat_fits_path: str, output_dir: str, save_corrected: bool = False): """ Process a single LRIS2 flat FITS file through all DRP steps. Args: flat_fits_path: Path to input FITS file output_dir: Output directory for results + save_corrected: If True, also save the corrected flat image """ logger = get_run_logger() filename = os.path.splitext(os.path.basename(flat_fits_path))[0] @@ -50,7 +51,10 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str): # Save outputs logger.info("Saving results") - save_correction_fits_task(correction, header, correction_output) + save_flat_fits_task(correction, header, correction_output) + if save_corrected: + corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits") + save_flat_fits_task(correction, header, corrected_output, original_data=data) save_trace_solution_task(slit_positions, trace_output) generate_qa_plot_task(correction, qa_output) @@ -61,13 +65,14 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str): description="Process all flat frames using spectroscopic flat fielding", task_runner=ConcurrentTaskRunner(max_workers=2), # You can adjust this ) -def batch_process_all_flats(input_dir: str, output_dir: str): +def batch_process_all_flats(input_dir: str, output_dir: str, save_corrected: bool = False): """ Process all FITS files in a directory using spectroscopic flat fielding. Args: input_dir: Directory containing input FITS files output_dir: Directory for output files + save_corrected: If True, also save the corrected flat image (default: False) """ logger = get_run_logger() @@ -78,7 +83,7 @@ def batch_process_all_flats(input_dir: str, output_dir: str): ] logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.") - futures = [process_single_flat_frame.submit(fp, output_dir) for fp in fits_files] + futures = [process_single_flat_frame.submit(fp, output_dir, save_corrected) for fp in fits_files] for future in futures: future.result() diff --git a/src/workflows/prefect_tasks/save_correction.py b/src/workflows/prefect_tasks/save_correction.py index d49eac0..fc76543 100644 --- a/src/workflows/prefect_tasks/save_correction.py +++ b/src/workflows/prefect_tasks/save_correction.py @@ -1,18 +1,25 @@ from prefect import task from keckdrpframework.models.arguments import Arguments -from keck_primitives.save_correction import SaveCorrectionFits +from keck_primitives.save_correction import SaveFlatFits from keck_primitives.utils import DummyAction, DummyContext -@task(name="Save Correction FITS") -def save_correction_fits_task(correction, header, output_path: str): +@task(name="Save Flat FITS") +def save_flat_fits_task(correction, header, output_path: str, original_data=None): + """Save flat field results to a FITS file. + + By default saves the correction matrix. If original_data is provided, + saves the corrected image instead. + """ args = Arguments() args["correction"] = correction args["header"] = header args["output_path"] = output_path + if original_data is not None: + args["original_data"] = original_data action = DummyAction(args=args) context = DummyContext() - result = SaveCorrectionFits(action, context)._perform(args, config={}) + result = SaveFlatFits(action, context)._perform(args, config={}) return result["output_path"] From 0cffce442b34128255fea35186f3948043d5f3ea Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Fri, 12 Dec 2025 13:43:57 -0800 Subject: [PATCH 4/4] update README, add max_workers config property --- README.md | 33 +++++++++++-------- config/config.yaml | 7 ++-- src/main.py | 5 +-- src/workflows/flows/batch_flat_flow.py | 45 ++++++++++++++------------ 4 files changed, 52 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 0d22fe3..d9be295 100644 --- a/README.md +++ b/README.md @@ -67,16 +67,21 @@ pip install -e .[dev] To batch-process a folder of flat-field FITS files: -```python -from lris2_drp.flows import batch_process_all_flats +```bash +python src/main.py +``` + +Configure the pipeline via `config/config.yaml`: -batch_process_all_flats( - input_dir="/path/to/flats", - output_dir="/path/to/results", -) +```yaml +input_dir: data/lris2_flats +output_dir: output +use_prefect_server: true +save_corrected_flat: false # save corrected flat images +max_workers: 2 ``` -This will process up to 2 files in parallel (configurable) using Prefect’s `ConcurrentTaskRunner`. +Files are processed in parallel using Prefect's `ConcurrentTaskRunner`. Each file goes through: @@ -95,8 +100,9 @@ For each input FITS file, the following will be written to the output directory: ``` / ├── flat_correction.fits # Flat-field correction matrix -├── flat_norm_qa.png # QA plot of normalized flat -└── slit_trace.txt # Slit trace positions +├── flat_corrected.fits # Corrected flat image (if save_corrected: true) +├── flat_norm_qa.png # QA plot of normalized flat +└── slit_trace.txt # Slit trace positions ``` The correction FITS file includes a `FLATCOR` keyword in the header: @@ -113,11 +119,10 @@ Additional keywords track reduction steps (optional to expand). ### Adjust Parallelism -To change the number of files processed in parallel, set `max_workers` in `ConcurrentTaskRunner`: +To change the number of files processed in parallel, set `max_workers` in `config/config.yaml`: -```python -@flow(task_runner=ConcurrentTaskRunner(max_workers=4)) -def batch_process_all_flats(...): +```yaml +max_workers: 4 ``` ### Customize Output Paths @@ -125,7 +130,7 @@ def batch_process_all_flats(...): Output filenames and directory structure can be customized in: - `save_trace_solution()` -- `save_correction_fits()` +- `save_flat_fits()` - `generate_qa_plot()` --- diff --git a/config/config.yaml b/config/config.yaml index 8dcc719..b24e748 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -5,5 +5,8 @@ output_dir: output # Set to true to start Prefect UI (at http://127.0.0.1:4200) use_prefect_server: true -# Set to true to also save corrected flat images (default: false, saves only correction matrix) -save_corrected: true +# Set to true to also save corrected flat images (flat_corrected.fits) +save_corrected_flat: true + +# Number of files to process in parallel +max_workers: 2 diff --git a/src/main.py b/src/main.py index 854db35..14b3f3a 100644 --- a/src/main.py +++ b/src/main.py @@ -18,7 +18,8 @@ def load_config(config_path="config/config.yaml"): input_dir = config["input_dir"] output_dir = config["output_dir"] use_prefect_server = config.get("use_prefect_server", True) - save_corrected = config.get("save_corrected", False) + save_corrected_flat = config.get("save_corrected_flat", False) + max_workers = config.get("max_workers", 2) os.makedirs(output_dir, exist_ok=True) @@ -39,7 +40,7 @@ def load_config(config_path="config/config.yaml"): try: print(f"🟢 Starting batch processing of FITS files in {input_dir}") - batch_process_all_flats(input_dir=input_dir, output_dir=output_dir, save_corrected=save_corrected) + batch_process_all_flats(input_dir=input_dir, output_dir=output_dir, save_corrected_flat=save_corrected_flat, max_workers=max_workers) if use_prefect_server: print("\n✅ Pipeline completed!") diff --git a/src/workflows/flows/batch_flat_flow.py b/src/workflows/flows/batch_flat_flow.py index 5e304de..3d01035 100644 --- a/src/workflows/flows/batch_flat_flow.py +++ b/src/workflows/flows/batch_flat_flow.py @@ -10,14 +10,14 @@ @task(name="Process Single Flat Frame") -def process_single_flat_frame(flat_fits_path: str, output_dir: str, save_corrected: bool = False): +def process_single_flat_frame(flat_fits_path: str, output_dir: str, save_corrected_flat: bool = False): """ Process a single LRIS2 flat FITS file through all DRP steps. Args: flat_fits_path: Path to input FITS file output_dir: Output directory for results - save_corrected: If True, also save the corrected flat image + save_corrected_flat: If True, also save the corrected flat image """ logger = get_run_logger() filename = os.path.splitext(os.path.basename(flat_fits_path))[0] @@ -52,7 +52,7 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str, save_correct # Save outputs logger.info("Saving results") save_flat_fits_task(correction, header, correction_output) - if save_corrected: + if save_corrected_flat: corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits") save_flat_fits_task(correction, header, corrected_output, original_data=data) save_trace_solution_task(slit_positions, trace_output) @@ -60,30 +60,35 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str, save_correct logger.info(f"Finished processing {flat_fits_path}") -@flow( - name="Batch Process LRIS2 Flats", - description="Process all flat frames using spectroscopic flat fielding", - task_runner=ConcurrentTaskRunner(max_workers=2), # You can adjust this -) -def batch_process_all_flats(input_dir: str, output_dir: str, save_corrected: bool = False): +def batch_process_all_flats(input_dir: str, output_dir: str, save_corrected_flat: bool = False, max_workers: int = 2): """ Process all FITS files in a directory using spectroscopic flat fielding. Args: input_dir: Directory containing input FITS files output_dir: Directory for output files - save_corrected: If True, also save the corrected flat image (default: False) + save_corrected_flat: If True, also save the corrected flat image + max_workers: Number of files to process in parallel """ - logger = get_run_logger() + @flow( + name="Batch Process LRIS2 Flats", + description="Process all flat frames using spectroscopic flat fielding", + task_runner=ConcurrentTaskRunner(max_workers=max_workers), + ) + def _batch_process_all_flats(input_dir: str, output_dir: str, save_corrected_flat: bool, max_workers: int): + logger = get_run_logger() + + fits_files = [ + os.path.join(input_dir, f) + for f in os.listdir(input_dir) + if f.lower().endswith(".fits") + ] + logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.") + logger.info(f"Settings: save_corrected_flat={save_corrected_flat}, max_workers={max_workers}") - fits_files = [ - os.path.join(input_dir, f) - for f in os.listdir(input_dir) - if f.lower().endswith(".fits") - ] - logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.") + futures = [process_single_flat_frame.submit(fp, output_dir, save_corrected_flat) for fp in fits_files] - futures = [process_single_flat_frame.submit(fp, output_dir, save_corrected) for fp in fits_files] + for future in futures: + future.result() - for future in futures: - future.result() + return _batch_process_all_flats(input_dir, output_dir, save_corrected_flat, max_workers)