Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/test-main.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Run DRP Main

on:
push:
branches: [main, mike/**]
pull_request:
branches: [main]

jobs:
run-main:
runs-on: ubuntu-latest

steps:
- name: Checkout repository
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Install project dependencies
run: |
python -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install -e .

- name: Run main.py
run: |
source venv/bin/activate
python src/main.py
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@ cython_debug/
.pypirc

# Output files
output/*
output/*

.python-version
File renamed without changes.
18 changes: 0 additions & 18 deletions main.py

This file was deleted.

4 changes: 0 additions & 4 deletions output/.gitignore

This file was deleted.

17 changes: 6 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
[build-system]
# Specifies the build system to use.
requires = ["setuptools>=42"]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"

[project]
# Basic information about your project.
name = "lris2-drp"
version = "0.1.0"
dependencies = [
"prefect>=3.4.11",
"prefect",
"astropy",
"numpy",
"matplotlib",
"scipy"
"scipy",
"keckdrpframework"
]
requires-python = ">=3.9"
authors = [
Expand All @@ -24,7 +24,6 @@ maintainers = [
description = "Data Reduction Pipeline for LRIS2"
readme = "README.md"
license = { text = "MIT" }
keywords = ["example", "package", "keywords"]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
Expand All @@ -41,17 +40,13 @@ Repository = "https://github.com/CaltechOpticalObservatories/lris2-drp"
"Bug Tracker" = "https://github.com/CaltechOpticalObservatories/lris2-drp/issues"
#Changelog = "https://github.com/yourusername/your-repo/blob/master/CHANGELOG.md"

[project.scripts]
# Defines command-line scripts for your package. Replace with your script and function.
lris2-drp = "lris2_drp.main:main"

[project.optional-dependencies]
# Optional dependencies that can be installed with extra tags, like "dev".
dev = [
"pytest",
"pytest",
"black",
"flake8"
]

[tool.setuptools]
packages = { find = { include = ["lris2_drp"] } }
packages = { find = { include = ["src"] } }
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions lris2_drp/flows.py → src/core/flows.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from prefect import flow, task, get_run_logger
from prefect.task_runners import ConcurrentTaskRunner
import os
from lris2_drp.flat import (
from core.flat import (
load_flat_frame,
normalize_flat,
create_flat_correction,
save_corrected_fits,
)
from lris2_drp.tracing import trace_slits_1d, save_trace_solution
from lris2_drp.qa import generate_qa_plot
from core.tracing import trace_slits_1d, save_trace_solution
from core.qa import generate_qa_plot

@task(name="Process Single LRIS2 Flat Frame", description="Run all DRP steps on a single flat FITS file")
def process_flat_frame(filepath: str, output_dir: str):
Expand Down
9 changes: 7 additions & 2 deletions lris2_drp/qa.py → src/core/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
@task(name="Generate QA Plot", description="Save normalized flat as PNG", tags=["qa", "plot"])
def generate_qa_plot(data: np.ndarray, output_path: str, title: str = "Flat QA") -> str:
"""Generate a QA plot for the normalized flat field data."""

if data is None or not hasattr(data, "shape") or data.ndim != 2:
raise ValueError(f"Invalid data shape for QA plot: {getattr(data, 'shape', None)}")

os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.figure(figsize=(10, 4))
im = plt.imshow(data, cmap="gray", aspect="auto", origin="lower")
plt.colorbar(im)
fig, ax = plt.subplots()
im = ax.imshow(data, cmap="gray", aspect="auto", origin="lower")
fig.colorbar(im, ax=ax)
plt.title(title)
plt.savefig(output_path)
plt.close()
Expand Down
File renamed without changes.
Empty file added src/keck_primitives/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions src/keck_primitives/normalize_flat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np
from keckdrpframework.models.arguments import Arguments
from keckdrpframework.primitives.base_primitive import BasePrimitive

class NormalizeFlat(BasePrimitive):
def __init__(self, action, context):
super().__init__(action, context)

def _perform(self, input_args: Arguments, config: dict) -> dict:
data = input_args["flat_data"]
median = np.median(data[data > 0])
norm = data / median
return {"norm_data": norm}
14 changes: 14 additions & 0 deletions src/keck_primitives/trace_slits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np
from scipy.signal import find_peaks
from keckdrpframework.models.arguments import Arguments
from keckdrpframework.primitives.base_primitive import BasePrimitive

class TraceSlits1D(BasePrimitive):
def __init__(self, action, context):
super().__init__(action, context)

def _perform(self, input_args: Arguments, config: dict) -> dict:
data = input_args["flat_data"]
profile = np.median(data, axis=1)
peaks, _ = find_peaks(profile, distance=20, prominence=0.05)
return {"slit_positions": list(peaks)}
15 changes: 15 additions & 0 deletions src/keck_primitives/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Utility functions and classes for Keck DRP primitives."""

class DummyAction:
def __init__(self, args=None):
self.args = args or {}
self.action_type = "dummy"
self.timestamp = None
self.context = None
self.logger = None

class DummyContext:
def __init__(self, logger=None, config=None):
import logging
self.logger = logger or logging.getLogger("dummy_logger")
self.config = config or {}
23 changes: 23 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Main entry point for the LRIS2 Data Reduction Pipeline (DRP).
This script initializes the pipeline and processes all flat field FITS files
found in the specified input directory, saving the results to the output directory.
"""
import yaml
import os
from workflows.flows.batch_flat_flow import batch_process_all_flats

def load_config(config_path="config/config.yaml"):
"""Load configuration from a YAML file."""
with open(config_path, "r") as f:
return yaml.safe_load(f)

if __name__ == "__main__":
config = load_config()
input_dir = config["input_dir"]
output_dir = config["output_dir"]

os.makedirs(output_dir, exist_ok=True)

print(f"🟢 Starting batch processing of FITS files in {input_dir}")
batch_process_all_flats(input_dir=input_dir, output_dir=output_dir)
Empty file added src/workflows/__init__.py
Empty file.
Empty file added src/workflows/flows/__init__.py
Empty file.
61 changes: 61 additions & 0 deletions src/workflows/flows/batch_flat_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
from prefect import flow, task, get_run_logger
from prefect.task_runners import ConcurrentTaskRunner
from workflows.prefect_tasks.load_flat import load_flat_frame_task
from workflows.prefect_tasks.normalize_flat import normalize_flat_task
from workflows.prefect_tasks.create_correction import create_flat_correction_task
from workflows.prefect_tasks.trace_slits import trace_slits_task
from workflows.prefect_tasks.save_corrected import save_corrected_fits_task
from workflows.prefect_tasks.save_trace import save_trace_solution_task
from workflows.prefect_tasks.qa_plot import generate_qa_plot_task


@task(name="Process Single Flat Frame")
def process_single_flat_frame(flat_fits_path: str, output_dir: str):
"""Process a single LRIS2 flat FITS file through all DRP steps."""
logger = get_run_logger()
filename = os.path.splitext(os.path.basename(flat_fits_path))[0]

# Construct output paths
corrected_output = os.path.join(output_dir, filename, "flat_corrected.fits")
trace_output = os.path.join(output_dir, filename, "slit_trace.txt")
qa_output = os.path.join(output_dir, filename, "flat_norm_qa.png")

# Ensure output dirs
os.makedirs(os.path.dirname(corrected_output), exist_ok=True)

# Load FITS
data, header = load_flat_frame_task(flat_fits_path)

# DRP steps
norm = normalize_flat_task(data)
correction = create_flat_correction_task(norm)
slit_positions = trace_slits_task(data)

# Save outputs
save_corrected_fits_task(data, correction, header, corrected_output)
save_trace_solution_task(slit_positions, trace_output)
generate_qa_plot_task(norm, qa_output)

logger.info(f"Finished processing {flat_fits_path}")

@flow(
name="Batch Process LRIS2 Flats",
description="Process all flat frames concurrently using Prefect",
task_runner=ConcurrentTaskRunner(max_workers=2), # You can adjust this
)
def batch_process_all_flats(input_dir: str, output_dir: str):
"""Process all FITS files in a directory using concurrent subflows."""
logger = get_run_logger()

fits_files = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if f.lower().endswith(".fits")
]
logger.info(f"Found {len(fits_files)} FITS files in {input_dir}.")

futures = [process_single_flat_frame.submit(fp, output_dir) for fp in fits_files]

for future in futures:
future.result()
Empty file.
6 changes: 6 additions & 0 deletions src/workflows/prefect_tasks/create_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from prefect import task
from core.flat import create_flat_correction

@task(name="Create Flat Correction")
def create_flat_correction_task(norm_data):
return create_flat_correction.fn(norm_data)
6 changes: 6 additions & 0 deletions src/workflows/prefect_tasks/load_flat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from prefect import task
from core.flat import load_flat_frame

@task(name="Load Flat Frame")
def load_flat_frame_task(filepath: str):
return load_flat_frame.fn(filepath)
16 changes: 16 additions & 0 deletions src/workflows/prefect_tasks/normalize_flat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from prefect import task
from keckdrpframework.models.arguments import Arguments
from keck_primitives.normalize_flat import NormalizeFlat
from keck_primitives.utils import DummyAction, DummyContext


@task(name="Normalize Flat")
def normalize_flat_task(data):
args = Arguments()
args["flat_data"] = data

action = DummyAction(args=args)
context = DummyContext()

result = NormalizeFlat(action, context)._perform(args, config={})
return result["norm_data"]
6 changes: 6 additions & 0 deletions src/workflows/prefect_tasks/qa_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from prefect import task
from core.qa import generate_qa_plot

@task(name="Generate QA Plot")
def generate_qa_plot_task(data, output_path: str, title: str = "Flat QA"):
return generate_qa_plot.fn(data, output_path, title)
6 changes: 6 additions & 0 deletions src/workflows/prefect_tasks/save_corrected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from prefect import task
from core.flat import save_corrected_fits

@task(name="Save Corrected FITS")
def save_corrected_fits_task(original_data, correction, header, output_path: str):
return save_corrected_fits.fn(original_data, correction, header, output_path)
6 changes: 6 additions & 0 deletions src/workflows/prefect_tasks/save_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from prefect import task
from core.tracing import save_trace_solution

@task(name="Save Trace Solution")
def save_trace_solution_task(slit_positions, output_path: str):
return save_trace_solution.fn(slit_positions, output_path)
14 changes: 14 additions & 0 deletions src/workflows/prefect_tasks/trace_slits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from prefect import task
from keck_primitives.trace_slits import TraceSlits1D
from keckdrpframework.models.arguments import Arguments
from keck_primitives.utils import DummyAction, DummyContext

@task(name="Trace Slits 1D")
def trace_slits_task(data):
"""Task to trace slits in 1D data."""
args = Arguments()
args["flat_data"] = data
action = DummyAction(args=args)
context = DummyContext()
result = TraceSlits1D(action, context)._perform(args, config={})
return result["slit_positions"]