From 376645045920df4bd8f2682d5729aefaac6becb4 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Thu, 28 Aug 2025 11:09:10 -0400 Subject: [PATCH 1/2] Add batch submission code used during ripple batch runs extract_stac_geometries.py extracts and uploads stac geometries so that they can be used to query a HAND index for HAND data that coincides with a given STAC item submit_stac_batch.py submits batches of pipelines where each pipeline evaluates the data associated with a STAC item. The list of STAC items to be processed in a batch is given as an input to submit_stac_batch.py --- tools/extract_stac_geometries.py | 150 +++++++++ tools/submit_stac_batch.py | 530 +++++++++++++++++++++++++++++++ 2 files changed, 680 insertions(+) create mode 100644 tools/extract_stac_geometries.py create mode 100755 tools/submit_stac_batch.py diff --git a/tools/extract_stac_geometries.py b/tools/extract_stac_geometries.py new file mode 100644 index 0000000..f7a884e --- /dev/null +++ b/tools/extract_stac_geometries.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import geopandas as gpd +import pystac_client +from shapely.geometry import shape +from shapely.ops import unary_union + + +def extract_geometry_by_stac_id( + item_id: str, + stac_api_url: str = "http://benchmark-stac.test.nextgenwaterprediction.com:8000", + collection: Optional[str] = None, + use_convex_hull: bool = False, +) -> gpd.GeoDataFrame: + """ + Fetch geometry for a STAC item by its ID. + + Args: + item_id: The STAC item ID to query for + stac_api_url: URL of the STAC API endpoint + collection: Optional collection to search within + use_convex_hull: If True, compute convex hull of the geometry + + Returns: + GeoDataFrame with the requested geometry in EPSG:4326. + + Raises: + ValueError: if no item is found with that ID + """ + # Connect to STAC API + catalog = pystac_client.Client.open(stac_api_url) + + # Build search parameters + search_params = {"ids": [item_id]} + if collection: + search_params["collections"] = [collection] + + # Search for the item + search = catalog.search(**search_params) + items = list(search.items()) + + if not items: + raise ValueError(f"No STAC item found with ID: {item_id}") + + # Get the first (should be only) item + item = items[0] + + # Extract geometry + if not item.geometry: + raise ValueError(f"STAC item {item_id} has no geometry") + + # Convert to shapely geometry + geom = shape(item.geometry) + + # Apply convex hull if requested + if use_convex_hull: + geom = geom.convex_hull + + # Create GeoDataFrame + gdf = gpd.GeoDataFrame( + [{"item_id": item_id, "collection": item.collection_id}], + geometry=[geom], + crs="EPSG:4326" + ) + + return gdf + + +def should_use_convex_hull(collection_id: Optional[str]) -> bool: + """ + Determine if convex hull should be used based on collection ID. + + Args: + collection_id: The collection ID from the STAC item + + Returns: + True if convex hull should be applied + """ + fim_collections = {"ble-collection", "nws-fim-collection", "usgs-fim-collection"} + return collection_id in fim_collections if collection_id else False + + +def main(): + parser = argparse.ArgumentParser(description="Extract geometries for STAC items") + parser.add_argument("item_list", help="Text file: one STAC item ID per line") + parser.add_argument("output_dir", help="Directory to save individual .gpkg files") + parser.add_argument( + "--stac-api-url", + default="http://benchmark-stac.test.nextgenwaterprediction.com:8000", + help="STAC API URL" + ) + parser.add_argument( + "--collection", + help="Optional: specific collection to search within" + ) + args = parser.parse_args() + + logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"), format="%(asctime)s %(levelname)s %(message)s") + + outdir = Path(args.output_dir) + outdir.mkdir(parents=True, exist_ok=True) + + with open(args.item_list) as fh: + item_ids = [line.strip() for line in fh if line.strip()] + + logging.info(f"Fetching {len(item_ids)} STAC item geometries from {args.stac_api_url}") + success = 0 + fail = 0 + + for idx, item_id in enumerate(item_ids): + try: + logging.info(f"[{idx}] Fetching STAC item {item_id}") + + # First, fetch without convex hull to get collection info + gdf = extract_geometry_by_stac_id( + item_id, + stac_api_url=args.stac_api_url, + collection=args.collection, + use_convex_hull=False + ) + + # Check if we should use convex hull based on collection + collection_id = gdf.iloc[0]["collection"] + if should_use_convex_hull(collection_id): + logging.info(f"[{idx}] Applying convex hull for collection {collection_id}") + gdf = extract_geometry_by_stac_id( + item_id, + stac_api_url=args.stac_api_url, + collection=args.collection, + use_convex_hull=True + ) + + out_fp = outdir / f"stac_{item_id}.gpkg" + gdf.to_file(out_fp, driver="GPKG") + logging.info(f"[{idx}] Saved {item_id} → {out_fp.name}") + success += 1 + except Exception as e: + logging.error(f"[{idx}] Failed {item_id}: {e}") + fail += 1 + + logging.info(f"Done: {success} succeeded, {fail} failed ({len(item_ids)} total).") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tools/submit_stac_batch.py b/tools/submit_stac_batch.py new file mode 100755 index 0000000..b0889fc --- /dev/null +++ b/tools/submit_stac_batch.py @@ -0,0 +1,530 @@ +import argparse +import logging +import os +import sys +import time +from functools import wraps +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional +from urllib.parse import urlparse + +import fsspec +import nomad + +from extract_stac_geometries import ( + extract_geometry_by_stac_id, + should_use_convex_hull, +) + + +def retry_with_backoff(max_retries: int = 2, backoff_base: float = 2.0): + """ + Decorator to add retry logic with exponential backoff to functions. + + Args: + max_retries: Maximum number of retry attempts (default: 2) + backoff_base: Base for exponential backoff calculation (default: 2.0) + """ + + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + last_exception = None + func_name = func.__name__ + + for attempt in range(max_retries + 1): + try: + result = func(*args, **kwargs) + + if attempt > 0: + logging.info( + f"Function {func_name} succeeded on attempt {attempt + 1}" + ) + + return result + + except Exception as e: + last_exception = e + if attempt < max_retries: + wait_time = backoff_base**attempt + logging.warning( + f"Function {func_name} failed (attempt {attempt + 1}/{max_retries + 1}): {e}. Retrying in {wait_time} seconds..." + ) + time.sleep(wait_time) + else: + logging.error( + f"Function {func_name} failed after {max_retries + 1} attempts: {e}" + ) + + raise last_exception + + return wrapper + + return decorator + + +@retry_with_backoff(max_retries=2) +def submit_pipeline_job( + nomad_client: nomad.Nomad, + item_id: str, + gpkg_path: str, + batch_name: str, + output_root: str, + hand_index_path: str, + benchmark_sources: str, + collection_id: str, + nomad_token: Optional[str] = None, + use_local_creds: bool = False, +) -> str: + """ + Submit a pipeline job for a single STAC item. + + Returns: + Dispatched job ID + """ + # Ensure output_root doesn't have redundant trailing slashes + output_root_clean = output_root.rstrip("/") + + meta = { + "aoi": str(gpkg_path), + "outputs_path": f"{output_root_clean}/{item_id}/", + "hand_index_path": hand_index_path, + "benchmark_sources": benchmark_sources, + "tags": f"batch_name={batch_name} aoi_name={item_id}", + "nomad_token": nomad_token or os.environ.get("NOMAD_TOKEN", ""), + "registry_token": os.environ.get("REGISTRY_TOKEN", ""), + } + + # Include AWS credentials from environment if using local creds + if use_local_creds: + meta.update( + { + "aws_access_key": os.environ.get("AWS_ACCESS_KEY_ID", ""), + "aws_secret_key": os.environ.get("AWS_SECRET_ACCESS_KEY", ""), + "aws_session_token": os.environ.get("AWS_SESSION_TOKEN", ""), + } + ) + + # Create the id_prefix_template in the format [batch_name=value,aoi_name=value,collection=value] + id_prefix_template = f"[batch_name={batch_name},aoi_name={item_id},collection={collection_id}]" + + result = nomad_client.job.dispatch_job( + id_="pipeline", + payload=None, + meta=meta, + id_prefix_template=id_prefix_template, + ) + + return result["DispatchedJobID"] + + +@retry_with_backoff(max_retries=2) +def get_running_pipeline_jobs(nomad_client: nomad.Nomad) -> int: + """ + Get the count of running and queued pipeline jobs, including dispatched jobs waiting for allocation. + + Returns: + Number of pipeline jobs in active states (not finished) + """ + # Get all jobs to include dispatched jobs that haven't been allocated yet + jobs = nomad_client.jobs.get_jobs() + pipeline_jobs = [ + job for job in jobs if job.get("ID", "").startswith("pipeline") + ] + + running_count = 0 + for job in pipeline_jobs: + job_status = job.get("Status", "") + # Debug logging to see actual job statuses + logging.debug( + f"Pipeline job {job.get('ID', 'unknown')}: Status={job_status}" + ) + + # Count jobs that are not finished (dead = finished) + # "running" includes both allocated jobs and dispatched jobs waiting for allocation + if job_status != "dead": + running_count += 1 + + logging.debug( + f"Found {running_count} active pipeline jobs out of {len(pipeline_jobs)} total pipeline jobs" + ) + return running_count + + +def extract_items( + item_ids: List[str], + temp_dir: Path, + stac_api_url: str, + collection: Optional[str] = None, +) -> Dict[str, tuple[Path, str]]: + """ + Extract STAC item geometries and save as individual gpkg files. + + Returns: + Dict mapping item IDs to tuples of (gpkg file path, collection ID) + """ + item_files = {} + + for item_id in item_ids: + try: + logging.info(f"Extracting geometry for STAC item {item_id}") + + # First, fetch without convex hull to get collection info + gdf = extract_geometry_by_stac_id( + item_id, + stac_api_url=stac_api_url, + collection=collection, + use_convex_hull=False, + ) + + # Check if we should use convex hull based on collection + collection_id = gdf.iloc[0]["collection"] + if should_use_convex_hull(collection_id): + logging.info( + f"Applying convex hull for collection {collection_id}" + ) + gdf = extract_geometry_by_stac_id( + item_id, + stac_api_url=stac_api_url, + collection=collection, + use_convex_hull=True, + ) + + # Save to temp file + output_file = temp_dir / f"stac_{item_id}.gpkg" + gdf.to_file(output_file, driver="GPKG") + + item_files[item_id] = (output_file, collection_id) + logging.info( + f"Saved STAC item {item_id} (collection: {collection_id}) to {output_file}" + ) + + except Exception as e: + logging.error(f"Failed to extract STAC item {item_id}: {e}") + # Continue with other items + + return item_files + + +def main(): + parser = argparse.ArgumentParser( + description="Submit batch of pipeline jobs for multiple STAC items" + ) + + # Required arguments + parser.add_argument( + "--batch_name", + required=True, + help="Name for this batch of jobs (passed as tag)", + ) + parser.add_argument( + "--output_root", + required=True, + help="Root directory for outputs (STAC item ID will be appended to create an individual pipelines output path)", + ) + parser.add_argument( + "--hand_index_path", + required=True, + help="Path to HAND index (passed to pipeline)", + ) + parser.add_argument( + "--benchmark_sources", + required=True, + help="Comma-separated benchmark sources (passed to pipeline)", + ) + parser.add_argument( + "--item_list", + required=True, + help="Path to text file with STAC item IDs (one per line)", + ) + + # Optional arguments + parser.add_argument( + "--temp_dir", + default="/tmp/stac_batch", + help="Temporary directory for extracted STAC item .gpkg files", + ) + parser.add_argument( + "--wait_seconds", + type=int, + default=0, + help="Seconds to wait between job submissions", + ) + parser.add_argument( + "--stop_threshold", + type=int, + required=True, + help="Stop submitting jobs when this many pipelines are running/queued", + ) + parser.add_argument( + "--resume_threshold", + type=int, + required=True, + help="Resume submitting jobs when running/queued pipelines drop to this level (must be less than stop_threshold)", + ) + + # Nomad connection arguments + parser.add_argument( + "--nomad_addr", + default=os.environ.get("NOMAD_ADDR", "http://localhost:4646"), + help="Nomad server address", + ) + parser.add_argument( + "--nomad_namespace", + default=os.environ.get("NOMAD_NAMESPACE", "default"), + help="Nomad namespace", + ) + parser.add_argument( + "--nomad_token", + default=os.environ.get("NOMAD_TOKEN"), + help="Nomad ACL token", + ) + + # AWS authentication arguments + parser.add_argument( + "--use-local-creds", + action="store_true", + help="Use AWS credentials from shell environment instead of IAM roles", + ) + + # Local output arguments + parser.add_argument( + "--use-local-output", + action="store_true", + help="Copy AOI files to a local directory instead of uploading to S3", + ) + + # STAC-specific arguments + parser.add_argument( + "--stac_api_url", + default="http://benchmark-stac.test.nextgenwaterprediction.com:8000", + help="STAC API URL", + ) + parser.add_argument( + "--collection", help="Optional: specific collection to search within" + ) + + args = parser.parse_args() + + # Validate threshold relationship + if args.resume_threshold >= args.stop_threshold: + parser.error("--resume_threshold must be less than --stop_threshold") + + # Setup logging + logging.basicConfig( + level=os.environ.get("LOG_LEVEL", "INFO"), + format="%(asctime)s %(levelname)s %(message)s", + ) + + # Create temp directory + temp_dir = Path(args.temp_dir) + temp_dir.mkdir(parents=True, exist_ok=True) + + # Read STAC item IDs + with open(args.item_list, "r") as f: + item_ids = [line.strip() for line in f if line.strip()] + + logging.info(f"Loaded {len(item_ids)} STAC item IDs from {args.item_list}") + + # Extract STAC item geometries + logging.info("Extracting STAC item geometries...") + item_files = extract_items( + item_ids, temp_dir, args.stac_api_url, args.collection + ) + + if not item_files: + logging.error("No STAC item geometries extracted successfully") + return 1 + + logging.info( + f"Successfully extracted {len(item_files)} STAC item geometries" + ) + + # Initialize appropriate filesystem based on output mode + if args.use_local_output: + # Use local filesystem + fs = fsspec.filesystem("file") + base_path = f"{args.output_root.rstrip('/')}/{args.batch_name}/stac_aois" + logging.info(f"Using local output directory: {base_path}") + else: + # Use S3 filesystem with fimc-data profile + fs = fsspec.filesystem("s3", profile="fimbucket") + base_path = f"{args.output_root.rstrip('/')}/{args.batch_name}/stac_aois" + logging.info(f"Using S3 output path: {base_path}") + + # Create output directory if using local filesystem + if args.use_local_output: + fs.makedirs(base_path, exist_ok=True) + + # Upload/copy AOI files using fsspec + aoi_paths = {} + action_verb = "Copying" if args.use_local_output else "Uploading" + + logging.info(f"{action_verb} AOI files to {base_path}") + for item_id, (local_path, collection_id) in item_files.items(): + dest_path = f"{base_path}/stac_{item_id}.gpkg" + try: + with open(local_path, "rb") as local_file: + with fs.open(dest_path, "wb") as dest_file: + dest_file.write(local_file.read()) + + # For local output, convert host path to container path + if args.use_local_output: + # Get the absolute path on the host + abs_dest_path = os.path.abspath(dest_path) + # Find where local-batches is in the path and replace everything before it with / + if '/local-batches/' in abs_dest_path: + # Split at local-batches and rejoin with container mount point + parts = abs_dest_path.split('/local-batches/') + container_path = '/local-batches/' + parts[-1] + else: + # Fallback - just use the dest_path as is + container_path = dest_path + aoi_paths[item_id] = (container_path, collection_id) + logging.info(f"{action_verb} {local_path} to {dest_path} (container: {container_path})") + else: + aoi_paths[item_id] = (dest_path, collection_id) + logging.info(f"{action_verb} {local_path} to {dest_path}") + except Exception as e: + logging.error(f"Failed to {action_verb.lower()} AOI for STAC item {item_id}: {e}") + continue + + if not aoi_paths: + logging.error(f"No AOI files {action_verb.lower()} successfully") + return 1 + + # Initialize Nomad client + parsed = urlparse(args.nomad_addr) + nomad_client = nomad.Nomad( + host=parsed.hostname, + port=parsed.port or 4646, + verify=False, + token=args.nomad_token, + namespace=args.nomad_namespace, + ) + + # Process STAC items - submit all jobs immediately + submitted_jobs = [] + failed_submissions = [] + + # Track submission state for hysteresis + submission_paused = False + + logging.info(f"Starting job submission for {len(aoi_paths)} STAC items") + logging.info( + f"Thresholds - Stop: {args.stop_threshold}, Resume: {args.resume_threshold}" + ) + + for item_id, (aoi_path, collection_id) in aoi_paths.items(): + # Implement hysteresis for job submission control + while True: + current_jobs = get_running_pipeline_jobs(nomad_client) + # Subtract 1 to account for the parameterized job template that's always running + actual_running = current_jobs - 1 + + if not submission_paused: + # Currently submitting - check if we should pause + if actual_running >= args.stop_threshold: + submission_paused = True + logging.info( + f"Stop threshold ({args.stop_threshold}) reached. Current jobs: {actual_running}. " + f"Pausing submissions until jobs drop to {args.resume_threshold}..." + ) + else: + # Can continue submitting + break + else: + # Currently paused - check if we should resume + if actual_running <= args.resume_threshold: + submission_paused = False + logging.info( + f"Resume threshold ({args.resume_threshold}) reached. Current jobs: {actual_running}. " + f"Resuming submissions..." + ) + break + else: + # Still need to wait + wait_time = max( + args.wait_seconds, 10 + ) # Minimum 10 seconds to avoid hammering the API + logging.debug( + f"Waiting for jobs to drop to resume threshold. Current: {actual_running}, " + f"Resume at: {args.resume_threshold}. Waiting {wait_time} seconds..." + ) + time.sleep(wait_time) + + logging.info( + f"Submitting job for STAC item {item_id} (collection: {collection_id})" + ) + + try: + job_id = submit_pipeline_job( + nomad_client=nomad_client, + item_id=item_id, + gpkg_path=aoi_path, + batch_name=args.batch_name, + output_root=args.output_root, + hand_index_path=args.hand_index_path, + benchmark_sources=args.benchmark_sources, + collection_id=collection_id, + nomad_token=args.nomad_token, + use_local_creds=args.use_local_creds, + ) + + submitted_jobs.append((item_id, job_id)) + logging.info( + f"Successfully submitted job {job_id} for STAC item {item_id}" + ) + + # Wait between submissions if specified + if args.wait_seconds > 0: + logging.info( + f"Waiting {args.wait_seconds} seconds before next submission..." + ) + time.sleep(args.wait_seconds) + + except Exception as e: + logging.error(f"Failed to submit job for STAC item {item_id}: {e}") + failed_submissions.append((item_id, str(e))) + + # Summary + logging.info("\n" + "=" * 60) + logging.info("BATCH SUBMISSION COMPLETE") + logging.info("=" * 60) + logging.info(f"Total STAC items processed: {len(item_files)}") + logging.info(f"Successfully submitted: {len(submitted_jobs)}") + logging.info(f"Failed submissions: {len(failed_submissions)}") + + if submitted_jobs: + logging.info("\nSubmitted jobs:") + for item_id, job_id in submitted_jobs: + logging.info(f" STAC item {item_id}: {job_id}") + + if failed_submissions: + logging.info("\nFailed submissions:") + for item_id, error in failed_submissions: + logging.info(f" STAC item {item_id}: {error}") + + # Monitor running jobs until all are complete + if submitted_jobs: + logging.info("\nMonitoring job completion...") + while True: + current_jobs = get_running_pipeline_jobs(nomad_client) + logging.info( + f"Currently running pipeline jobs: {current_jobs - 1}" + ) # don't count the parent job + + if ( + current_jobs <= 1 + ): # Only the parameterized job template should remain + logging.info("All submitted jobs have completed!") + break + + # Wait before checking again + time.sleep(60) # Check every minute + + return 0 if not failed_submissions else 1 + + +if __name__ == "__main__": + sys.exit(main()) From 53b5f0ece9ad6e8168e5cd7c299b59e4e1f6cb92 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Wed, 10 Sep 2025 15:23:54 -0400 Subject: [PATCH 2/2] Reformat line lengths --- tools/extract_stac_geometries.py | 55 ++++++++------------- tools/submit_stac_batch.py | 85 +++++++++----------------------- 2 files changed, 44 insertions(+), 96 deletions(-) diff --git a/tools/extract_stac_geometries.py b/tools/extract_stac_geometries.py index f7a884e..ae75b24 100644 --- a/tools/extract_stac_geometries.py +++ b/tools/extract_stac_geometries.py @@ -34,50 +34,46 @@ def extract_geometry_by_stac_id( """ # Connect to STAC API catalog = pystac_client.Client.open(stac_api_url) - + # Build search parameters search_params = {"ids": [item_id]} if collection: search_params["collections"] = [collection] - + # Search for the item search = catalog.search(**search_params) items = list(search.items()) - + if not items: raise ValueError(f"No STAC item found with ID: {item_id}") - + # Get the first (should be only) item item = items[0] - + # Extract geometry if not item.geometry: raise ValueError(f"STAC item {item_id} has no geometry") - + # Convert to shapely geometry geom = shape(item.geometry) - + # Apply convex hull if requested if use_convex_hull: geom = geom.convex_hull - + # Create GeoDataFrame - gdf = gpd.GeoDataFrame( - [{"item_id": item_id, "collection": item.collection_id}], - geometry=[geom], - crs="EPSG:4326" - ) - + gdf = gpd.GeoDataFrame([{"item_id": item_id, "collection": item.collection_id}], geometry=[geom], crs="EPSG:4326") + return gdf def should_use_convex_hull(collection_id: Optional[str]) -> bool: """ Determine if convex hull should be used based on collection ID. - + Args: collection_id: The collection ID from the STAC item - + Returns: True if convex hull should be applied """ @@ -90,14 +86,9 @@ def main(): parser.add_argument("item_list", help="Text file: one STAC item ID per line") parser.add_argument("output_dir", help="Directory to save individual .gpkg files") parser.add_argument( - "--stac-api-url", - default="http://benchmark-stac.test.nextgenwaterprediction.com:8000", - help="STAC API URL" - ) - parser.add_argument( - "--collection", - help="Optional: specific collection to search within" + "--stac-api-url", default="http://benchmark-stac.test.nextgenwaterprediction.com:8000", help="STAC API URL" ) + parser.add_argument("--collection", help="Optional: specific collection to search within") args = parser.parse_args() logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"), format="%(asctime)s %(levelname)s %(message)s") @@ -115,26 +106,20 @@ def main(): for idx, item_id in enumerate(item_ids): try: logging.info(f"[{idx}] Fetching STAC item {item_id}") - + # First, fetch without convex hull to get collection info gdf = extract_geometry_by_stac_id( - item_id, - stac_api_url=args.stac_api_url, - collection=args.collection, - use_convex_hull=False + item_id, stac_api_url=args.stac_api_url, collection=args.collection, use_convex_hull=False ) - + # Check if we should use convex hull based on collection collection_id = gdf.iloc[0]["collection"] if should_use_convex_hull(collection_id): logging.info(f"[{idx}] Applying convex hull for collection {collection_id}") gdf = extract_geometry_by_stac_id( - item_id, - stac_api_url=args.stac_api_url, - collection=args.collection, - use_convex_hull=True + item_id, stac_api_url=args.stac_api_url, collection=args.collection, use_convex_hull=True ) - + out_fp = outdir / f"stac_{item_id}.gpkg" gdf.to_file(out_fp, driver="GPKG") logging.info(f"[{idx}] Saved {item_id} → {out_fp.name}") @@ -147,4 +132,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tools/submit_stac_batch.py b/tools/submit_stac_batch.py index b0889fc..51c5eec 100755 --- a/tools/submit_stac_batch.py +++ b/tools/submit_stac_batch.py @@ -10,7 +10,6 @@ import fsspec import nomad - from extract_stac_geometries import ( extract_geometry_by_stac_id, should_use_convex_hull, @@ -37,9 +36,7 @@ def wrapper(*args, **kwargs) -> Any: result = func(*args, **kwargs) if attempt > 0: - logging.info( - f"Function {func_name} succeeded on attempt {attempt + 1}" - ) + logging.info(f"Function {func_name} succeeded on attempt {attempt + 1}") return result @@ -52,9 +49,7 @@ def wrapper(*args, **kwargs) -> Any: ) time.sleep(wait_time) else: - logging.error( - f"Function {func_name} failed after {max_retries + 1} attempts: {e}" - ) + logging.error(f"Function {func_name} failed after {max_retries + 1} attempts: {e}") raise last_exception @@ -128,26 +123,20 @@ def get_running_pipeline_jobs(nomad_client: nomad.Nomad) -> int: """ # Get all jobs to include dispatched jobs that haven't been allocated yet jobs = nomad_client.jobs.get_jobs() - pipeline_jobs = [ - job for job in jobs if job.get("ID", "").startswith("pipeline") - ] + pipeline_jobs = [job for job in jobs if job.get("ID", "").startswith("pipeline")] running_count = 0 for job in pipeline_jobs: job_status = job.get("Status", "") # Debug logging to see actual job statuses - logging.debug( - f"Pipeline job {job.get('ID', 'unknown')}: Status={job_status}" - ) + logging.debug(f"Pipeline job {job.get('ID', 'unknown')}: Status={job_status}") # Count jobs that are not finished (dead = finished) # "running" includes both allocated jobs and dispatched jobs waiting for allocation if job_status != "dead": running_count += 1 - logging.debug( - f"Found {running_count} active pipeline jobs out of {len(pipeline_jobs)} total pipeline jobs" - ) + logging.debug(f"Found {running_count} active pipeline jobs out of {len(pipeline_jobs)} total pipeline jobs") return running_count @@ -180,9 +169,7 @@ def extract_items( # Check if we should use convex hull based on collection collection_id = gdf.iloc[0]["collection"] if should_use_convex_hull(collection_id): - logging.info( - f"Applying convex hull for collection {collection_id}" - ) + logging.info(f"Applying convex hull for collection {collection_id}") gdf = extract_geometry_by_stac_id( item_id, stac_api_url=stac_api_url, @@ -195,9 +182,7 @@ def extract_items( gdf.to_file(output_file, driver="GPKG") item_files[item_id] = (output_file, collection_id) - logging.info( - f"Saved STAC item {item_id} (collection: {collection_id}) to {output_file}" - ) + logging.info(f"Saved STAC item {item_id} (collection: {collection_id}) to {output_file}") except Exception as e: logging.error(f"Failed to extract STAC item {item_id}: {e}") @@ -207,9 +192,7 @@ def extract_items( def main(): - parser = argparse.ArgumentParser( - description="Submit batch of pipeline jobs for multiple STAC items" - ) + parser = argparse.ArgumentParser(description="Submit batch of pipeline jobs for multiple STAC items") # Required arguments parser.add_argument( @@ -286,7 +269,7 @@ def main(): action="store_true", help="Use AWS credentials from shell environment instead of IAM roles", ) - + # Local output arguments parser.add_argument( "--use-local-output", @@ -300,9 +283,7 @@ def main(): default="http://benchmark-stac.test.nextgenwaterprediction.com:8000", help="STAC API URL", ) - parser.add_argument( - "--collection", help="Optional: specific collection to search within" - ) + parser.add_argument("--collection", help="Optional: specific collection to search within") args = parser.parse_args() @@ -328,17 +309,13 @@ def main(): # Extract STAC item geometries logging.info("Extracting STAC item geometries...") - item_files = extract_items( - item_ids, temp_dir, args.stac_api_url, args.collection - ) + item_files = extract_items(item_ids, temp_dir, args.stac_api_url, args.collection) if not item_files: logging.error("No STAC item geometries extracted successfully") return 1 - logging.info( - f"Successfully extracted {len(item_files)} STAC item geometries" - ) + logging.info(f"Successfully extracted {len(item_files)} STAC item geometries") # Initialize appropriate filesystem based on output mode if args.use_local_output: @@ -359,7 +336,7 @@ def main(): # Upload/copy AOI files using fsspec aoi_paths = {} action_verb = "Copying" if args.use_local_output else "Uploading" - + logging.info(f"{action_verb} AOI files to {base_path}") for item_id, (local_path, collection_id) in item_files.items(): dest_path = f"{base_path}/stac_{item_id}.gpkg" @@ -367,16 +344,16 @@ def main(): with open(local_path, "rb") as local_file: with fs.open(dest_path, "wb") as dest_file: dest_file.write(local_file.read()) - + # For local output, convert host path to container path if args.use_local_output: # Get the absolute path on the host abs_dest_path = os.path.abspath(dest_path) # Find where local-batches is in the path and replace everything before it with / - if '/local-batches/' in abs_dest_path: + if "/local-batches/" in abs_dest_path: # Split at local-batches and rejoin with container mount point - parts = abs_dest_path.split('/local-batches/') - container_path = '/local-batches/' + parts[-1] + parts = abs_dest_path.split("/local-batches/") + container_path = "/local-batches/" + parts[-1] else: # Fallback - just use the dest_path as is container_path = dest_path @@ -411,9 +388,7 @@ def main(): submission_paused = False logging.info(f"Starting job submission for {len(aoi_paths)} STAC items") - logging.info( - f"Thresholds - Stop: {args.stop_threshold}, Resume: {args.resume_threshold}" - ) + logging.info(f"Thresholds - Stop: {args.stop_threshold}, Resume: {args.resume_threshold}") for item_id, (aoi_path, collection_id) in aoi_paths.items(): # Implement hysteresis for job submission control @@ -444,18 +419,14 @@ def main(): break else: # Still need to wait - wait_time = max( - args.wait_seconds, 10 - ) # Minimum 10 seconds to avoid hammering the API + wait_time = max(args.wait_seconds, 10) # Minimum 10 seconds to avoid hammering the API logging.debug( f"Waiting for jobs to drop to resume threshold. Current: {actual_running}, " f"Resume at: {args.resume_threshold}. Waiting {wait_time} seconds..." ) time.sleep(wait_time) - logging.info( - f"Submitting job for STAC item {item_id} (collection: {collection_id})" - ) + logging.info(f"Submitting job for STAC item {item_id} (collection: {collection_id})") try: job_id = submit_pipeline_job( @@ -472,15 +443,11 @@ def main(): ) submitted_jobs.append((item_id, job_id)) - logging.info( - f"Successfully submitted job {job_id} for STAC item {item_id}" - ) + logging.info(f"Successfully submitted job {job_id} for STAC item {item_id}") # Wait between submissions if specified if args.wait_seconds > 0: - logging.info( - f"Waiting {args.wait_seconds} seconds before next submission..." - ) + logging.info(f"Waiting {args.wait_seconds} seconds before next submission...") time.sleep(args.wait_seconds) except Exception as e: @@ -510,13 +477,9 @@ def main(): logging.info("\nMonitoring job completion...") while True: current_jobs = get_running_pipeline_jobs(nomad_client) - logging.info( - f"Currently running pipeline jobs: {current_jobs - 1}" - ) # don't count the parent job + logging.info(f"Currently running pipeline jobs: {current_jobs - 1}") # don't count the parent job - if ( - current_jobs <= 1 - ): # Only the parameterized job template should remain + if current_jobs <= 1: # Only the parameterized job template should remain logging.info("All submitted jobs have completed!") break