diff --git a/src/data_service.py b/src/data_service.py index f789af3..5d0867e 100644 --- a/src/data_service.py +++ b/src/data_service.py @@ -4,11 +4,15 @@ import os import shutil import tempfile +import threading +import time from contextlib import suppress +from datetime import datetime, timedelta from io import StringIO from pathlib import Path from typing import Any, Dict, List, Optional +import boto3 import fsspec import geopandas as gpd from botocore.exceptions import ClientError, NoCredentialsError @@ -20,12 +24,33 @@ from stac_querier import StacQuerier +class DataServiceException(Exception): + """Custom exception for data service operations.""" + + pass + + class DataService: """Service to query data sources and interact with S3 via fsspec.""" - def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collections: Optional[List[str]] = None): + def __init__( + self, + config: AppConfig, + hand_index_path: str, + benchmark_collections: Optional[List[str]] = None, + aoi_is_item: bool = False, + ): self.config = config - # Configure S3 filesystem options from environment + self.aoi_is_item = aoi_is_item + + # Initialize credential refresh tracking + self._credential_lock = threading.Lock() + self._last_credential_refresh = datetime.now() + self._refresh_thread = None + self._stop_refresh = threading.Event() + self._use_iam_credentials = False + + # Configure S3 filesystem options - try config first, then IAM credentials self._s3_options = {} if config.aws.AWS_ACCESS_KEY_ID: self._s3_options["key"] = config.aws.AWS_ACCESS_KEY_ID @@ -34,6 +59,25 @@ def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collection if config.aws.AWS_SESSION_TOKEN: self._s3_options["token"] = config.aws.AWS_SESSION_TOKEN + # If no explicit credentials in config, try to get from IAM instance profile + if not self._s3_options: + try: + credentials = boto3.Session().get_credentials() + if credentials: + self._s3_options["key"] = credentials.access_key + self._s3_options["secret"] = credentials.secret_key + if credentials.token: + self._s3_options["token"] = credentials.token + self._use_iam_credentials = True + logging.info("Using IAM instance profile credentials for S3 access") + + # Start credential refresh thread for IAM credentials + self._start_credential_refresh_thread() + else: + logging.warning("No AWS credentials found in config or IAM instance profile") + except Exception as e: + logging.warning(f"Failed to get IAM credentials: {e}") + # Initialize HandIndexQuerier with provided path self.hand_querier = None if hand_index_path: @@ -50,6 +94,7 @@ def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collection collections=benchmark_collections, overlap_threshold_percent=config.stac.overlap_threshold_percent, datetime_filter=config.stac.datetime_filter, + aoi_is_item=aoi_is_item, ) # Initialize FlowfileCombiner @@ -57,6 +102,45 @@ def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collection output_dir=config.flow_scenarios.output_dir if config.flow_scenarios else "combined_flowfiles" ) + def _start_credential_refresh_thread(self): + """Start a background thread to refresh IAM credentials every 4 hours.""" + if self._use_iam_credentials: + self._refresh_thread = threading.Thread(target=self._credential_refresh_loop, daemon=True) + self._refresh_thread.start() + logging.info("Started IAM credential refresh thread (4-hour interval)") + + def _credential_refresh_loop(self): + """Background thread loop to refresh credentials every 4 hours.""" + refresh_interval = 4 * 60 * 60 # sec + + while not self._stop_refresh.is_set(): + # Wait for 4 hours or until stop event is set + if self._stop_refresh.wait(refresh_interval): + break + + # Refresh credentials + self._refresh_credentials() + + def _refresh_credentials(self): + """Refresh IAM credentials from boto3 session.""" + try: + logging.info("Refreshing IAM credentials...") + credentials = boto3.Session().get_credentials() + + if credentials: + with self._credential_lock: + self._s3_options["key"] = credentials.access_key + self._s3_options["secret"] = credentials.secret_key + if credentials.token: + self._s3_options["token"] = credentials.token + self._last_credential_refresh = datetime.now() + + logging.info("Successfully refreshed IAM credentials") + else: + logging.error("Failed to refresh credentials: No credentials available from boto3 session") + except Exception as e: + logging.error(f"Error refreshing IAM credentials: {e}") + def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame: """ Args: @@ -71,8 +155,7 @@ def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame: """ try: # geopandas can read from both local and S3 paths - # For S3 paths, it will use storage_options; for local paths, it ignores them - gdf = gpd.read_file(file_path, storage_options=self._s3_options if file_path.startswith("s3://") else None) + gdf = gpd.read_file(file_path) if len(gdf) == 0: raise ValueError(f"Empty GeoDataFrame in file: {file_path}") @@ -93,12 +176,17 @@ def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame: except Exception as e: raise ValueError(f"Error loading GeoDataFrame from {file_path}: {e}") - async def query_stac_for_flow_scenarios(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: + async def query_stac_for_flow_scenarios( + self, + polygon_gdf: gpd.GeoDataFrame, + tags: Optional[Dict[str, str]] = None, + ) -> Dict: """ - Query STAC API for flow scenarios based on polygon. + Query STAC API for flow scenarios based on polygon or item ID. Args: polygon_gdf: GeoDataFrame containing polygon geometry + tags: Optional tags dictionary (required when aoi_is_item is True) Returns: Dictionary with STAC query results and combined flowfiles @@ -114,12 +202,28 @@ async def query_stac_for_flow_scenarios(self, polygon_gdf: gpd.GeoDataFrame) -> try: # Run STAC query in executor to avoid blocking loop = asyncio.get_running_loop() - stac_results = await loop.run_in_executor( - None, - self.stac_querier.query_stac_for_polygon, - polygon_gdf, - None, # roi_geojson not needed since we have polygon_gdf - ) + + if self.aoi_is_item: + # Direct item query mode - extract aoi_name from tags + if not tags or "aoi_name" not in tags: + raise ValueError("aoi_name tag is required when --aoi_is_item is used") + + aoi_name = tags["aoi_name"] + logging.info(f"Querying STAC for specific item ID: {aoi_name}") + + stac_results = await loop.run_in_executor( + None, + self.stac_querier.query_stac_by_item_id, + aoi_name, + ) + else: + # Standard spatial query mode + stac_results = await loop.run_in_executor( + None, + self.stac_querier.query_stac_for_polygon, + polygon_gdf, + None, # roi_geojson not needed since we have polygon_gdf + ) if not stac_results: logging.info(f"No STAC results found for polygon {polygon_id}") @@ -209,7 +313,10 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: logging.info( f"Data service returning {len(catchments)} catchments for polygon {polygon_id} (real query)." ) - return {"catchments": catchments, "hand_version": "real_query"} + return { + "catchments": catchments, + "hand_version": "real_query", + } except Exception as e: logging.error(f"Error in hand index query for polygon {polygon_id}: {e}") @@ -259,7 +366,7 @@ async def copy_file_to_uri(self, source_path: str, dest_uri: str): return dest_uri except Exception as e: logging.exception(f"Failed to copy {source_path} to {dest_uri}") - raise ConnectionError(f"Failed to copy file to {dest_uri}") from e + raise DataServiceException(f"Failed to copy file to {dest_uri}: {str(e)}") from e else: # If already on S3 or both local with same path, just return the source path logging.debug(f"No copy needed - using existing path: {source_path}") @@ -267,16 +374,72 @@ async def copy_file_to_uri(self, source_path: str, dest_uri: str): def _sync_copy_file(self, source_path: str, dest_uri: str): """Synchronous helper for copying files using fsspec.""" - # Create filesystem based on destination URI + # Copy file using fsspec.open directly (allows fallback to default AWS credentials) + with open(source_path, "rb") as src: + # Use thread-safe credential access for S3 operations + if dest_uri.startswith("s3://"): + with self._credential_lock: + s3_opts = self._s3_options.copy() + with fsspec.open(dest_uri, "wb", **s3_opts) as dst: + dst.write(src.read()) + else: + with fsspec.open(dest_uri, "wb") as dst: + dst.write(src.read()) + + async def append_file_to_uri(self, source_path: str, dest_uri: str): + """Appends a file to a URI (local or S3) using fsspec. + If the destination file doesn't exist, creates it with the source content. + If it exists, appends the source content to it.""" + + logging.debug(f"Appending file from {source_path} to {dest_uri}") + try: + # Check if destination exists + dest_exists = await self.check_file_exists(dest_uri) + + # Run in executor for async operation + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + self._sync_append_file, + source_path, + dest_uri, + dest_exists, + ) + logging.info(f"Successfully {'appended to' if dest_exists else 'created'} {dest_uri}") + return dest_uri + except Exception as e: + logging.exception(f"Failed to append {source_path} to {dest_uri}") + raise DataServiceException(f"Failed to append file to {dest_uri}: {str(e)}") from e + + def _sync_append_file(self, source_path: str, dest_uri: str, dest_exists: bool): + """Synchronous helper for appending files using fsspec.""" + # Read source content + with open(source_path, "rb") as src: + source_content = src.read() + + # Use thread-safe credential access for S3 operations if dest_uri.startswith("s3://"): - fs = fsspec.filesystem("s3", **self._s3_options) + with self._credential_lock: + s3_opts = self._s3_options.copy() else: - fs = fsspec.filesystem("file") + s3_opts = {} - # Copy file using fsspec - with open(source_path, "rb") as src: - with fs.open(dest_uri, "wb") as dst: - dst.write(src.read()) + if dest_exists: + # If destination exists, read existing content first + with fsspec.open(dest_uri, "rb", **s3_opts) as dst: + existing_content = dst.read() + + # Write combined content + with fsspec.open(dest_uri, "wb", **s3_opts) as dst: + dst.write(existing_content) + dst.write(source_content) + else: + # If destination doesn't exist, just write the source content + if not dest_uri.startswith(("s3://", "http://", "https://")): + os.makedirs(os.path.dirname(dest_uri), exist_ok=True) + + with fsspec.open(dest_uri, "wb", **s3_opts) as dst: + dst.write(source_content) async def check_file_exists(self, uri: str) -> bool: """Check if a file exists (S3 or local). @@ -295,14 +458,17 @@ async def check_file_exists(self, uri: str) -> bool: uri, ) except Exception as e: - logging.warning(f"Error checking file {uri}: {e}") - return False + logging.exception(f"Error checking file {uri}") + raise DataServiceException(f"Failed to check file existence {uri}: {str(e)}") from e def _sync_check_file_exists(self, uri: str) -> bool: """Synchronous helper to check if file exists (S3 or local).""" try: if uri.startswith("s3://"): - fs = fsspec.filesystem("s3", **self._s3_options) + # Use thread-safe credential access for S3 operations + with self._credential_lock: + s3_opts = self._s3_options.copy() + fs = fsspec.filesystem("s3", **s3_opts) return fs.exists(uri) else: # Local file @@ -365,26 +531,38 @@ def find_metrics_files(self, base_path: str) -> List[str]: try: if base_path.startswith("s3://"): - fs = fsspec.filesystem("s3", **self._s3_options) - # Use glob to find all metrics.csv files recursively - pattern = f"{base_path.rstrip('/')}/*/*/metrics.csv" + # Use thread-safe credential access for S3 operations + with self._credential_lock: + s3_opts = self._s3_options.copy() + fs = fsspec.filesystem("s3", **s3_opts) + # Use glob to find all metrics.csv files recursively (including tagged versions) + pattern = f"{base_path.rstrip('/')}/**/*__metrics.csv" metrics_files = fs.glob(pattern) # Add s3:// prefix back since glob strips it metrics_files = [f"s3://{path}" for path in metrics_files] else: base_path_obj = Path(base_path) if base_path_obj.exists(): - metrics_files = [str(p) for p in base_path_obj.glob("*/*/metrics.csv")] + metrics_files = [str(p) for p in base_path_obj.glob("**/*__metrics.csv")] logging.info(f"Found {len(metrics_files)} metrics.csv files in {base_path}") return metrics_files except Exception as e: - logging.error(f"Error finding metrics files in {base_path}: {e}") - return [] + logging.exception(f"Error finding metrics files in {base_path}") + raise DataServiceException(f"Failed to find metrics files in {base_path}: {str(e)}") from e def cleanup(self): - """Clean up resources, including HandIndexQuerier connection.""" + """Clean up resources, including HandIndexQuerier connection and refresh thread.""" + # Stop the credential refresh thread if it's running + if self._refresh_thread and self._refresh_thread.is_alive(): + logging.info("Stopping credential refresh thread...") + self._stop_refresh.set() + self._refresh_thread.join(timeout=5) + if self._refresh_thread.is_alive(): + logging.warning("Credential refresh thread did not stop gracefully") + + # Clean up HandIndexQuerier if self.hand_querier: self.hand_querier.close() self.hand_querier = None diff --git a/src/hand_index_querier.py b/src/hand_index_querier.py index d7814f7..8fea0dc 100644 --- a/src/hand_index_querier.py +++ b/src/hand_index_querier.py @@ -3,10 +3,12 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple +import boto3 import duckdb import geopandas as gpd import pandas as pd from shapely.geometry import Polygon +from shapely.wkb import loads from shapely.wkt import dumps logger = logging.getLogger(__name__) @@ -14,7 +16,7 @@ class HandIndexQuerier: """ - A class for querying HAND index data from partitioned parquet files. + A class for querying HAND index data from parquet files. Provides spatial intersection and filtering capabilities. """ @@ -23,12 +25,14 @@ def __init__(self, partitioned_base_path: str, overlap_threshold_percent: float Initialize the HandIndexQuerier. Args: - partitioned_base_path: Base path to partitioned parquet files (local or s3://) + partitioned_base_path: Base path to parquet files (local or s3://) overlap_threshold_percent: Minimum overlap percentage to keep a catchment """ self.partitioned_base_path = partitioned_base_path self.overlap_threshold_percent = overlap_threshold_percent self.con = None + self.credentials = boto3.Session().get_credentials() + self.s3_region = boto3.Session().region_name def _ensure_connection(self): """Ensure DuckDB connection is established with required extensions.""" @@ -43,14 +47,41 @@ def _ensure_connection(self): self.con.execute("LOAD httpfs;") self.con.execute("INSTALL aws;") self.con.execute("LOAD aws;") + + # Try using DuckDB's credential chain provider instead of manual credentials + # This should work better with temporary credentials + try: + self.con.execute(f""" + CREATE SECRET s3_secret ( + TYPE S3, + PROVIDER CREDENTIAL_CHAIN, + REGION '{self.s3_region or "us-east-1"}' + ); + """) + logger.info("Using DuckDB credential chain for S3 access") + except duckdb.Error as cred_chain_error: + logger.warning(f"Credential chain failed, falling back to manual credentials: {cred_chain_error}") + # Fallback to manual credential setting + self.con.execute(f"SET s3_region = '{self.s3_region or 'us-east-1'}';") + if self.credentials: + if self.credentials.access_key: + self.con.execute(f"SET s3_access_key_id='{self.credentials.access_key}';") + if self.credentials.secret_key: + self.con.execute(f"SET s3_secret_access_key='{self.credentials.secret_key}';") + if self.credentials.token: + self.con.execute(f"SET s3_session_token='{self.credentials.token}';") + + self.con.execute("SET memory_limit = '7GB';") + self.con.execute("SET temp_directory = '/tmp';") + except duckdb.Error as e: logger.warning("Could not load DuckDB extensions: %s", e) - # Create partitioned views - self._create_partitioned_views() + # Create views + self._create_views() - def _create_partitioned_views(self): - """Create views for partitioned tables.""" + def _create_views(self): + """Create views for non-partitioned tables.""" base_path = self.partitioned_base_path if not base_path.endswith("/"): base_path += "/" @@ -59,13 +90,13 @@ def _create_partitioned_views(self): CREATE OR REPLACE VIEW catchments_partitioned AS SELECT * FROM read_parquet('{base_path}catchments/*/*.parquet', hive_partitioning = 1); - CREATE OR REPLACE VIEW hydrotables_partitioned AS - SELECT * FROM read_parquet('{base_path}hydrotables/*/*.parquet', hive_partitioning = 1); + CREATE OR REPLACE VIEW hydrotables AS + SELECT * FROM read_parquet('{base_path}hydrotables.parquet'); - CREATE OR REPLACE VIEW hand_rem_rasters_partitioned AS + CREATE OR REPLACE VIEW hand_rem_rasters AS SELECT * FROM read_parquet('{base_path}hand_rem_rasters.parquet'); - CREATE OR REPLACE VIEW hand_catchment_rasters_partitioned AS + CREATE OR REPLACE VIEW hand_catchment_rasters AS SELECT * FROM read_parquet('{base_path}hand_catchment_rasters.parquet'); """ @@ -91,9 +122,9 @@ def _partitioned_query_cte(self, wkt4326: str) -> str: SELECT c.catchment_id, c.geometry, - c.h3_partition_key + c.h3_index FROM catchments_partitioned c - JOIN transformed_query tq ON ST_Intersects(c.geometry, tq.query_geom) + JOIN transformed_query tq ON ST_Intersects(ST_GeomFromWKB(c.geometry), tq.query_geom) ) """ @@ -136,7 +167,7 @@ def _get_catchment_data_for_polygon( + """ SELECT fc.catchment_id, - ST_AsWKB(fc.geometry) AS geom_wkb + fc.geometry AS geom_wkb FROM filtered_catchments AS fc; """ ) @@ -146,28 +177,27 @@ def _get_catchment_data_for_polygon( empty_gdf = gpd.GeoDataFrame(columns=["catchment_id", "geometry"], geometry="geometry", crs="EPSG:5070") return empty_gdf, pd.DataFrame(), query_poly_5070 - # Decode WKB → shapely geometries - wkb_series = geom_df["geom_wkb"].apply(lambda x: bytes(x) if isinstance(x, bytearray) else x) + # Convert WKB data to Shapely geometry objects. Wrapping wkb in bytes because duckdb exports a bytearray but shapely wants bytes. + geom_df["geometry"] = geom_df["geom_wkb"].apply(lambda wkb: loads(bytes(wkb)) if wkb is not None else None) geometries_gdf = gpd.GeoDataFrame( - geom_df[["catchment_id"]], - geometry=gpd.GeoSeries.from_wkb(wkb_series, crs="EPSG:5070"), + geom_df[["catchment_id", "geometry"]], + geometry="geometry", crs="EPSG:5070", ) - # Build and run the attribute query using partitioned tables + # Build and run the attribute query using non-partitioned tables sql_attr = ( cte + """ SELECT fc.catchment_id, - h.* EXCLUDE (catchment_id, h3_partition_key), - hrr.rem_raster_id, + h.csv_path, hrr.raster_path AS rem_raster_path, hcr.raster_path AS catchment_raster_path FROM filtered_catchments AS fc - LEFT JOIN hydrotables_partitioned AS h ON fc.catchment_id = h.catchment_id - LEFT JOIN hand_rem_rasters_partitioned AS hrr ON fc.catchment_id = hrr.catchment_id - LEFT JOIN hand_catchment_rasters_partitioned AS hcr ON hrr.rem_raster_id = hcr.rem_raster_id; + LEFT JOIN hydrotables AS h ON fc.catchment_id = h.catchment_id + LEFT JOIN hand_rem_rasters AS hrr ON fc.catchment_id = hrr.catchment_id + LEFT JOIN hand_catchment_rasters AS hcr ON fc.catchment_id = hcr.catchment_id; """ ) attributes_df = self.con.execute(sql_attr).fetch_df() @@ -275,12 +305,6 @@ def query_catchments_for_polygon( for catch_id, group in filtered_attrs.groupby("catchment_id"): df = group.drop(columns=["catchment_id"]).copy() - # Convert UUID columns to strings for Parquet compatibility - uuid_columns = ["rem_raster_id", "catchment_raster_id"] - for col in uuid_columns: - if col in df.columns: - df[col] = df[col].astype(str) - out_path = outdir / f"{catch_id}.parquet" df.to_parquet(str(out_path), index=False) logger.info("Wrote %d rows for catchment '%s' → %s", len(df), catch_id, out_path) diff --git a/src/stac_querier.py b/src/stac_querier.py index d531afc..3509bfe 100644 --- a/src/stac_querier.py +++ b/src/stac_querier.py @@ -81,6 +81,7 @@ def __init__( collections: Optional[List[str]] = None, overlap_threshold_percent: float = 40.0, datetime_filter: Optional[str] = None, + aoi_is_item: bool = False, ): """ Initialize the StacQuerier. @@ -90,11 +91,13 @@ def __init__( collections: List of STAC collection IDs to query (None means query all available) overlap_threshold_percent: Minimum overlap percentage to keep a STAC item datetime_filter: STAC datetime or interval filter + aoi_is_item: If True, query specific STAC items directly by ID """ self.api_url = api_url self.collections = collections self.overlap_threshold_percent = overlap_threshold_percent self.datetime_filter = datetime_filter + self.aoi_is_item = aoi_is_item self.client = None # Collection specifications @@ -269,6 +272,8 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ ripple_best_items = {} # Track STAC item IDs for each scenario item_ids = defaultdict(lambda: defaultdict(set)) + gauge_info = defaultdict(lambda: defaultdict(lambda: None)) + hucs_info = defaultdict(lambda: defaultdict(list)) for idx, item in enumerate(item_iter, start=1): if item.collection_id == "ripple-fim-collection": @@ -285,7 +290,10 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ key not in ripple_best_items or compare_versions(flows2fim_version, ripple_best_items[key][1]) > 0 ): - ripple_best_items[key] = (item.id, flows2fim_version) + ripple_best_items[key] = ( + item.id, + flows2fim_version, + ) else: # Skip this item - we already have a better version logger.info( @@ -302,16 +310,25 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ logger.info(f"Processed {idx} items (last: {item.id})") coll = item.collection_id or "" - short = coll.replace("-collection", "").replace("-fim", "") - if short == "gfm-expanded": - short = "gfm_expanded" + collection_key = coll + # Handle special case for gfm-expanded-collection to maintain compatibility + if coll == "gfm-expanded-collection": + collection_key = "gfm_expanded" # 1) item‐level grouping if coll in self.COLLECTIONS: group_fn, tests = self.COLLECTIONS[coll] gid = group_fn(item) - bucket = results[short][gid] - item_ids[short][gid].add(item.id) + bucket = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) + + # Extract hucs for item-level collections + hucs = item.properties.get("hucs", []) + if hucs: + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) + for k, a in item.assets.items(): if not a.href: continue @@ -319,10 +336,49 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ if test(k, a) and a.href not in bucket[atype]: bucket[atype].append(a.href) - # 2) BLE/NWS/USGS/Ripple asset‐level grouping - elif coll == "ble-collection" or coll.endswith("-fim-collection"): + # 2) BLE/NWS/USGS asset‐level grouping (excluding Ripple) + elif coll == "ble-collection" or coll in [ + "nws-fim-collection", + "usgs-fim-collection", + ]: + specs = self.BLE_SPEC if coll == "ble-collection" else self.NWS_USGS_SPEC + + found = set() + for k, a in item.assets.items(): + if not a.href: + continue + for pat, gid_t, at_t in specs: + m = pat.match(k) + if not m: + continue + gid = m.expand(gid_t) if "\\" in gid_t else gid_t + at = m.expand(at_t) + bkt = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) + if a.href not in bkt[at]: + bkt[at].append(a.href) + found.add(gid) + + if coll in [ + "nws-fim-collection", + "usgs-fim-collection", + ]: + gauge = item.properties.get("gauge") + if gauge and gauge_info[collection_key][gid] is None: + gauge_info[collection_key][gid] = gauge + + # Extract hucs for all asset-level collections (BLE, NWS, USGS) + hucs = item.properties.get("hucs", []) + if hucs: + for gid in found: + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) + + # 3) Ripple asset-level grouping with source-based separation + elif coll == "ripple-fim-collection": # preload ripple assets once - if coll == "ripple-fim-collection" and not ripple_cache: + if not ripple_cache: try: col = item.get_collection() for ak, aa in col.assets.items(): @@ -334,13 +390,8 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ except Exception as e: logger.warning(f"Ripple cache failed: {e}") - specs = ( - self.BLE_SPEC - if coll == "ble-collection" - else self.RIPPLE_SPEC - if coll == "ripple-fim-collection" - else self.NWS_USGS_SPEC - ) + specs = self.RIPPLE_SPEC + source = item.properties.get("source", "unknown_source") found = set() for k, a in item.assets.items(): @@ -350,29 +401,54 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ m = pat.match(k) if not m: continue - gid = m.expand(gid_t) if "\\" in gid_t else gid_t + # Get the base ID (e.g., "100yr") + base_gid = m.expand(gid_t) if "\\" in gid_t else gid_t + # Create the new, unique group ID + gid = f"{source}-{base_gid}" at = m.expand(at_t) - bkt = results[short][gid] - item_ids[short][gid].add(item.id) + bkt = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) if a.href not in bkt[at]: bkt[at].append(a.href) found.add(gid) - # append ripple flowfiles - if coll == "ripple-fim-collection": + # Extract hucs for Ripple collections + hucs = item.properties.get("hucs", []) + if hucs: for gid in found: - if gid in ripple_cache: - logger.info(f"Adding cached flowfiles for {gid}") - for href in ripple_cache[gid]: - if href not in results[short][gid]["flowfiles"]: - results[short][gid]["flowfiles"].append(href) + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) - # 3) fallback + # append ripple flowfiles + for composite_gid in found: + # Extract "100yr" from "SourceA-100yr" + base_gid = composite_gid.split("-", 1)[-1] + if base_gid in ripple_cache: + logger.info(f"Adding cached flowfiles for {composite_gid}") + for href in ripple_cache[base_gid]: + if href not in results[collection_key][composite_gid]["flowfiles"]: + results[collection_key][composite_gid]["flowfiles"].append(href) + + # 4) fallback else: logger.warning(f"Unknown coll '{coll}'; grouping by item.id") gid = item.id - bkt = results[short][gid] - item_ids[short][gid].add(item.id) + bkt = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) + + if "nws" in coll.lower() or "usgs" in coll.lower(): + gauge = item.properties.get("gauge") + if gauge and gauge_info[collection_key][gid] is None: + gauge_info[collection_key][gid] = gauge + + # Extract hucs for fallback collections + hucs = item.properties.get("hucs", []) + if hucs: + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) + for k, a in item.assets.items(): if not a.href: continue @@ -381,11 +457,13 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ elif "flow" in k and a.media_type and "csv" in a.media_type: bkt["flowfiles"].append(a.href) - # Add item IDs to results - for collection_short in results: - for scenario_id in results[collection_short]: - results[collection_short][scenario_id]["stac_items"] = list(item_ids[collection_short][scenario_id]) - + # Add item IDs, gauge information, and hucs to results + for collection_key in results: + for scenario_id in results[collection_key]: + results[collection_key][scenario_id]["stac_items"] = list(item_ids[collection_key][scenario_id]) + results[collection_key][scenario_id]["gauge"] = gauge_info[collection_key][scenario_id] + results[collection_key][scenario_id]["hucs"] = hucs_info[collection_key][scenario_id] + logger.info(f"Finished formatting {len(seen)} items.") return results @@ -410,6 +488,18 @@ def _merge_gfm_expanded( for h in hs: if h not in iv.assets[at]: iv.assets[at].append(h) + elif at == "gauge": + # Handle gauge specially - it's a single value, not a list + if hs is not None: + iv.assets[at] = hs + elif at == "hucs": + # Handle hucs specially - merge lists without duplicates + if hs: + existing_hucs = iv.assets.get(at, []) + for h in hs: + if h not in existing_hucs: + existing_hucs.append(h) + iv.assets[at] = existing_hucs else: for h in hs: if h not in iv.assets[at]: @@ -430,6 +520,18 @@ def _merge_gfm_expanded( for h in hs: if h not in cur.assets[at]: cur.assets[at].append(h) + elif at == "gauge": + # Handle gauge specially - use the first non-null value + if hs is not None and at not in cur.assets: + cur.assets[at] = hs + elif at == "hucs": + # Handle hucs specially - merge lists without duplicates + if hs: + existing_hucs = cur.assets.get(at, []) + for h in hs: + if h not in existing_hucs: + existing_hucs.append(h) + cur.assets[at] = existing_hucs else: for h in hs: if h not in cur.assets[at]: @@ -447,7 +549,9 @@ def _merge_gfm_expanded( return out def query_stac_for_polygon( - self, polygon_gdf: Optional[gpd.GeoDataFrame] = None, roi_geojson: Optional[Dict] = None + self, + polygon_gdf: Optional[gpd.GeoDataFrame] = None, + roi_geojson: Optional[Dict] = None, ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: """ Query STAC API for items and group them by collection/scenario. @@ -489,11 +593,16 @@ def query_stac_for_polygon( "datetime": self.datetime_filter, **({"intersects": intersects} if intersects else {}), } - + + # KLUDGE: Hardcode 2022 date range for GFM collections. this will be removed after trial batches with gfm data are completed. + if self.collections is not None and any("gfm" in coll.lower() for coll in self.collections): + search_kw["datetime"] = "2022-01-01T00:00:00Z/2022-12-31T23:59:59Z" + logger.info("KLUDGE: Overriding datetime filter to 2022 for GFM collections") + # Only include collections if specified, otherwise query all available if self.collections is not None: search_kw["collections"] = self.collections - + search_kw = {k: v for k, v in search_kw.items() if v is not None} try: @@ -502,15 +611,37 @@ def query_stac_for_polygon( search = self.client.search(**search_kw) items = list(search.items()) # Convert to list for geometry filtering + if not items: + logger.info("STAC query returned no items") + return {} + # Apply geometry filtering if query polygon is provided if query_polygon and self.overlap_threshold_percent: logger.info(f"Applying geometry filtering with {self.overlap_threshold_percent}% overlap threshold") items = self._filter_items_by_geometry(items, query_polygon) + if not items: + logger.info("No items remained after geometry filtering") + return {} grouped = self._format_results(items) + + if not grouped: + logger.info("No valid scenarios found after processing STAC items") + return {} if "gfm_expanded" in grouped: grouped["gfm_expanded"] = self._merge_gfm_expanded(grouped["gfm_expanded"]) + # Truncate timestamp ranges to just the first timestamp to get rid of double directory writing issue + truncated = {} + for scenario_key, scenario_data in grouped["gfm_expanded"].items(): + # If key contains a slash (timestamp range), take only the first part + if "/" in scenario_key: + first_timestamp = scenario_key.split("/")[0] + truncated[first_timestamp] = scenario_data + else: + truncated[scenario_key] = scenario_data + grouped["gfm_expanded"] = truncated + return dictify(grouped) except requests.RequestException as rex: @@ -520,6 +651,44 @@ def query_stac_for_polygon( logger.error(f"Unexpected error: {ex}") raise + def query_stac_by_item_id(self, item_id: str) -> Dict[str, Dict[str, Dict[str, List[str]]]]: + """ + Query STAC API for a specific item by ID. + + Args: + item_id: STAC item ID to retrieve + + Returns: + Dictionary mapping collection -> scenario -> asset_type -> [paths] + Each scenario also includes 'stac_items' key with list of STAC item IDs + """ + self._ensure_client() + + try: + logger.info(f"Querying STAC for item ID: {item_id}") + + # Get the specific item + item = next(self.client.search(ids=[item_id]).items(), None) + if not item: + logger.info(f"No item found with ID: {item_id}") + return {} + + # Process the single item using the existing format_results method + grouped = self._format_results([item]) + + if not grouped: + logger.info(f"No valid scenarios found after processing item {item_id}") + return {} + + return dictify(grouped) + + except requests.RequestException as rex: + logger.error(f"STAC request failed for item {item_id}: {rex}") + raise + except Exception as ex: + logger.error(f"Unexpected error querying item {item_id}: {ex}") + raise + def save_results(self, results: Dict, output_path: str): """Save query results to JSON file.""" with open(output_path, "w") as f: