diff --git a/.github/workflows/test-main.yaml b/.github/workflows/test-main.yaml new file mode 100644 index 0000000..6ce461c --- /dev/null +++ b/.github/workflows/test-main.yaml @@ -0,0 +1,32 @@ +name: Run DRP Main + +on: + push: + branches: [main, mike/**] + pull_request: + branches: [main] + +jobs: + run-main: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install project dependencies + run: | + python -m venv venv + source venv/bin/activate + pip install --upgrade pip + pip install -e . + + - name: Run main.py + run: | + source venv/bin/activate + python src/main.py diff --git a/.gitignore b/.gitignore index 068ef90..7d20e5d 100644 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,6 @@ cython_debug/ .pypirc # Output files -output/* \ No newline at end of file +output/* + +.python-version \ No newline at end of file diff --git a/config.yaml b/config/config.yaml similarity index 100% rename from config.yaml rename to config/config.yaml diff --git a/main.py b/main.py deleted file mode 100644 index ffeabe4..0000000 --- a/main.py +++ /dev/null @@ -1,18 +0,0 @@ -# main.py -import os -if "PREFECT_API_URL" in os.environ: - del os.environ["PREFECT_API_URL"] -import yaml -from lris2_drp.flows import batch_process_all_flats - -def load_config(config_path="config.yaml"): - """Load configuration from a YAML file.""" - with open(config_path, "r") as f: - return yaml.safe_load(f) - -if __name__ == "__main__": - config = load_config() - batch_process_all_flats( - input_dir=config["input_dir"], - output_dir=config["output_dir"] - ) diff --git a/output/.gitignore b/output/.gitignore deleted file mode 100644 index c9302cf..0000000 --- a/output/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Ignore everything in this directory -* -# Except this .gitignore file -!.gitignore diff --git a/pyproject.toml b/pyproject.toml index 04c7ad9..082e8f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,18 +1,18 @@ [build-system] # Specifies the build system to use. -requires = ["setuptools>=42"] +requires = ["setuptools>=61"] build-backend = "setuptools.build_meta" [project] -# Basic information about your project. name = "lris2-drp" version = "0.1.0" dependencies = [ - "prefect>=3.4.11", + "prefect", "astropy", "numpy", "matplotlib", - "scipy" + "scipy", + "keckdrpframework" ] requires-python = ">=3.9" authors = [ @@ -24,7 +24,6 @@ maintainers = [ description = "Data Reduction Pipeline for LRIS2" readme = "README.md" license = { text = "MIT" } -keywords = ["example", "package", "keywords"] classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -41,17 +40,13 @@ Repository = "https://github.com/CaltechOpticalObservatories/lris2-drp" "Bug Tracker" = "https://github.com/CaltechOpticalObservatories/lris2-drp/issues" #Changelog = "https://github.com/yourusername/your-repo/blob/master/CHANGELOG.md" -[project.scripts] -# Defines command-line scripts for your package. Replace with your script and function. -lris2-drp = "lris2_drp.main:main" [project.optional-dependencies] -# Optional dependencies that can be installed with extra tags, like "dev". dev = [ - "pytest", + "pytest", "black", "flake8" ] [tool.setuptools] -packages = { find = { include = ["lris2_drp"] } } +packages = { find = { include = ["src"] } } diff --git a/lris2_drp/__init__.py b/src/core/__init__.py similarity index 100% rename from lris2_drp/__init__.py rename to src/core/__init__.py diff --git a/lris2_drp/flat.py b/src/core/flat.py similarity index 100% rename from lris2_drp/flat.py rename to src/core/flat.py diff --git a/lris2_drp/flows.py b/src/core/flows.py similarity index 92% rename from lris2_drp/flows.py rename to src/core/flows.py index 3d9f9dc..9e07c7a 100644 --- a/lris2_drp/flows.py +++ b/src/core/flows.py @@ -1,14 +1,14 @@ from prefect import flow, task, get_run_logger from prefect.task_runners import ConcurrentTaskRunner import os -from lris2_drp.flat import ( +from core.flat import ( load_flat_frame, normalize_flat, create_flat_correction, save_corrected_fits, ) -from lris2_drp.tracing import trace_slits_1d, save_trace_solution -from lris2_drp.qa import generate_qa_plot +from core.tracing import trace_slits_1d, save_trace_solution +from core.qa import generate_qa_plot @task(name="Process Single LRIS2 Flat Frame", description="Run all DRP steps on a single flat FITS file") def process_flat_frame(filepath: str, output_dir: str): diff --git a/lris2_drp/qa.py b/src/core/qa.py similarity index 67% rename from lris2_drp/qa.py rename to src/core/qa.py index d9168fe..dfa0f26 100644 --- a/lris2_drp/qa.py +++ b/src/core/qa.py @@ -9,10 +9,15 @@ @task(name="Generate QA Plot", description="Save normalized flat as PNG", tags=["qa", "plot"]) def generate_qa_plot(data: np.ndarray, output_path: str, title: str = "Flat QA") -> str: """Generate a QA plot for the normalized flat field data.""" + + if data is None or not hasattr(data, "shape") or data.ndim != 2: + raise ValueError(f"Invalid data shape for QA plot: {getattr(data, 'shape', None)}") + os.makedirs(os.path.dirname(output_path), exist_ok=True) plt.figure(figsize=(10, 4)) - im = plt.imshow(data, cmap="gray", aspect="auto", origin="lower") - plt.colorbar(im) + fig, ax = plt.subplots() + im = ax.imshow(data, cmap="gray", aspect="auto", origin="lower") + fig.colorbar(im, ax=ax) plt.title(title) plt.savefig(output_path) plt.close() diff --git a/lris2_drp/tracing.py b/src/core/tracing.py similarity index 100% rename from lris2_drp/tracing.py rename to src/core/tracing.py diff --git a/src/keck_primitives/__init__.py b/src/keck_primitives/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/keck_primitives/normalize_flat.py b/src/keck_primitives/normalize_flat.py new file mode 100644 index 0000000..b702946 --- /dev/null +++ b/src/keck_primitives/normalize_flat.py @@ -0,0 +1,13 @@ +import numpy as np +from keckdrpframework.models.arguments import Arguments +from keckdrpframework.primitives.base_primitive import BasePrimitive + +class NormalizeFlat(BasePrimitive): + def __init__(self, action, context): + super().__init__(action, context) + + def _perform(self, input_args: Arguments, config: dict) -> dict: + data = input_args["flat_data"] + median = np.median(data[data > 0]) + norm = data / median + return {"norm_data": norm} diff --git a/src/keck_primitives/trace_slits.py b/src/keck_primitives/trace_slits.py new file mode 100644 index 0000000..195d7cc --- /dev/null +++ b/src/keck_primitives/trace_slits.py @@ -0,0 +1,14 @@ +import numpy as np +from scipy.signal import find_peaks +from keckdrpframework.models.arguments import Arguments +from keckdrpframework.primitives.base_primitive import BasePrimitive + +class TraceSlits1D(BasePrimitive): + def __init__(self, action, context): + super().__init__(action, context) + + def _perform(self, input_args: Arguments, config: dict) -> dict: + data = input_args["flat_data"] + profile = np.median(data, axis=1) + peaks, _ = find_peaks(profile, distance=20, prominence=0.05) + return {"slit_positions": list(peaks)} diff --git a/src/keck_primitives/utils.py b/src/keck_primitives/utils.py new file mode 100644 index 0000000..53d134b --- /dev/null +++ b/src/keck_primitives/utils.py @@ -0,0 +1,15 @@ +"""Utility functions and classes for Keck DRP primitives.""" + +class DummyAction: + def __init__(self, args=None): + self.args = args or {} + self.action_type = "dummy" + self.timestamp = None + self.context = None + self.logger = None + +class DummyContext: + def __init__(self, logger=None, config=None): + import logging + self.logger = logger or logging.getLogger("dummy_logger") + self.config = config or {} diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..ee46364 --- /dev/null +++ b/src/main.py @@ -0,0 +1,23 @@ +""" +Main entry point for the LRIS2 Data Reduction Pipeline (DRP). +This script initializes the pipeline and processes all flat field FITS files +found in the specified input directory, saving the results to the output directory. +""" +import yaml +import os +from workflows.flows.batch_flat_flow import batch_process_all_flats + +def load_config(config_path="config/config.yaml"): + """Load configuration from a YAML file.""" + with open(config_path, "r") as f: + return yaml.safe_load(f) + +if __name__ == "__main__": + config = load_config() + input_dir = config["input_dir"] + output_dir = config["output_dir"] + + os.makedirs(output_dir, exist_ok=True) + + print(f"🟢 Starting batch processing of FITS files in {input_dir}") + batch_process_all_flats(input_dir=input_dir, output_dir=output_dir) diff --git a/src/workflows/__init__.py b/src/workflows/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/workflows/flows/__init__.py b/src/workflows/flows/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/workflows/flows/batch_flat_flow.py b/src/workflows/flows/batch_flat_flow.py new file mode 100644 index 0000000..2703fd9 --- /dev/null +++ b/src/workflows/flows/batch_flat_flow.py @@ -0,0 +1,61 @@ +import os +from prefect import flow, task, get_run_logger +from prefect.task_runners import ConcurrentTaskRunner +from workflows.prefect_tasks.load_flat import load_flat_frame_task +from workflows.prefect_tasks.normalize_flat import normalize_flat_task +from workflows.prefect_tasks.create_correction import create_flat_correction_task +from workflows.prefect_tasks.trace_slits import trace_slits_task +from workflows.prefect_tasks.save_corrected import save_corrected_fits_task +from workflows.prefect_tasks.save_trace import save_trace_solution_task +from workflows.prefect_tasks.qa_plot import generate_qa_plot_task + + +@task(name="Process Single Flat Frame") +def process_single_flat_frame(flat_fits_path: str, output_dir: str): + """Process a single LRIS2 flat FITS file through all DRP steps.""" + logger = get_run_logger() + filename = os.path.splitext(os.path.basename(flat_fits_path))[0] + + # Construct output paths + corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits") + trace_output = os.path.join(output_dir, filename, "slit_trace.txt") + qa_output = os.path.join(output_dir, filename, "flat_norm_qa.png") + + # Ensure output dirs + os.makedirs(os.path.dirname(corrected_output), exist_ok=True) + + # Load FITS + data, header = load_flat_frame_task(flat_fits_path) + + # DRP steps + norm = normalize_flat_task(data) + correction = create_flat_correction_task(norm) + slit_positions = trace_slits_task(data) + + # Save outputs + save_corrected_fits_task(data, correction, header, corrected_output) + save_trace_solution_task(slit_positions, trace_output) + generate_qa_plot_task(norm, qa_output) + + logger.info(f"Finished processing {flat_fits_path}") + +@flow( + name="Batch Process LRIS2 Flats", + description="Process all flat frames concurrently using Prefect", + task_runner=ConcurrentTaskRunner(max_workers=2), # You can adjust this +) +def batch_process_all_flats(input_dir: str, output_dir: str): + """Process all FITS files in a directory using concurrent subflows.""" + logger = get_run_logger() + + fits_files = [ + os.path.join(input_dir, f) + for f in os.listdir(input_dir) + if f.lower().endswith(".fits") + ] + logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.") + + futures = [process_single_flat_frame.submit(fp, output_dir) for fp in fits_files] + + for future in futures: + future.result() diff --git a/src/workflows/prefect_tasks/__init__.py b/src/workflows/prefect_tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/workflows/prefect_tasks/create_correction.py b/src/workflows/prefect_tasks/create_correction.py new file mode 100644 index 0000000..349e050 --- /dev/null +++ b/src/workflows/prefect_tasks/create_correction.py @@ -0,0 +1,6 @@ +from prefect import task +from core.flat import create_flat_correction + +@task(name="Create Flat Correction") +def create_flat_correction_task(norm_data): + return create_flat_correction.fn(norm_data) \ No newline at end of file diff --git a/src/workflows/prefect_tasks/load_flat.py b/src/workflows/prefect_tasks/load_flat.py new file mode 100644 index 0000000..62089c8 --- /dev/null +++ b/src/workflows/prefect_tasks/load_flat.py @@ -0,0 +1,6 @@ +from prefect import task +from core.flat import load_flat_frame + +@task(name="Load Flat Frame") +def load_flat_frame_task(filepath: str): + return load_flat_frame.fn(filepath) \ No newline at end of file diff --git a/src/workflows/prefect_tasks/normalize_flat.py b/src/workflows/prefect_tasks/normalize_flat.py new file mode 100644 index 0000000..6599645 --- /dev/null +++ b/src/workflows/prefect_tasks/normalize_flat.py @@ -0,0 +1,16 @@ +from prefect import task +from keckdrpframework.models.arguments import Arguments +from keck_primitives.normalize_flat import NormalizeFlat +from keck_primitives.utils import DummyAction, DummyContext + + +@task(name="Normalize Flat") +def normalize_flat_task(data): + args = Arguments() + args["flat_data"] = data + + action = DummyAction(args=args) + context = DummyContext() + + result = NormalizeFlat(action, context)._perform(args, config={}) + return result["norm_data"] diff --git a/src/workflows/prefect_tasks/qa_plot.py b/src/workflows/prefect_tasks/qa_plot.py new file mode 100644 index 0000000..fc23cad --- /dev/null +++ b/src/workflows/prefect_tasks/qa_plot.py @@ -0,0 +1,6 @@ +from prefect import task +from core.qa import generate_qa_plot + +@task(name="Generate QA Plot") +def generate_qa_plot_task(data, output_path: str, title: str = "Flat QA"): + return generate_qa_plot.fn(data, output_path, title) \ No newline at end of file diff --git a/src/workflows/prefect_tasks/save_corrected.py b/src/workflows/prefect_tasks/save_corrected.py new file mode 100644 index 0000000..61d1cb3 --- /dev/null +++ b/src/workflows/prefect_tasks/save_corrected.py @@ -0,0 +1,6 @@ +from prefect import task +from core.flat import save_corrected_fits + +@task(name="Save Corrected FITS") +def save_corrected_fits_task(original_data, correction, header, output_path: str): + return save_corrected_fits.fn(original_data, correction, header, output_path) \ No newline at end of file diff --git a/src/workflows/prefect_tasks/save_trace.py b/src/workflows/prefect_tasks/save_trace.py new file mode 100644 index 0000000..d78b3bc --- /dev/null +++ b/src/workflows/prefect_tasks/save_trace.py @@ -0,0 +1,6 @@ +from prefect import task +from core.tracing import save_trace_solution + +@task(name="Save Trace Solution") +def save_trace_solution_task(slit_positions, output_path: str): + return save_trace_solution.fn(slit_positions, output_path) \ No newline at end of file diff --git a/src/workflows/prefect_tasks/trace_slits.py b/src/workflows/prefect_tasks/trace_slits.py new file mode 100644 index 0000000..74f7b9c --- /dev/null +++ b/src/workflows/prefect_tasks/trace_slits.py @@ -0,0 +1,14 @@ +from prefect import task +from keck_primitives.trace_slits import TraceSlits1D +from keckdrpframework.models.arguments import Arguments +from keck_primitives.utils import DummyAction, DummyContext + +@task(name="Trace Slits 1D") +def trace_slits_task(data): + """Task to trace slits in 1D data.""" + args = Arguments() + args["flat_data"] = data + action = DummyAction(args=args) + context = DummyContext() + result = TraceSlits1D(action, context)._perform(args, config={}) + return result["slit_positions"]