From a1ae3e3bfe6a0902dc2c401987ab16d2f5851ea7 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Thu, 28 Aug 2025 12:10:28 -0400 Subject: [PATCH 1/3] Update STAC and HAND index queries For stac_querier.py go ahead and add aoi_is_item query functionality. For hand_index_querier.py improve DuckDB credential setting, simplified catchment data parquet structure, and better error handling --- src/hand_index_querier.py | 82 ++++++--- src/stac_querier.py | 370 +++++++++++++++++++++++++++++++------- 2 files changed, 363 insertions(+), 89 deletions(-) diff --git a/src/hand_index_querier.py b/src/hand_index_querier.py index d7814f7..ac896f2 100644 --- a/src/hand_index_querier.py +++ b/src/hand_index_querier.py @@ -3,10 +3,12 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple +import boto3 import duckdb import geopandas as gpd import pandas as pd from shapely.geometry import Polygon +from shapely.wkb import loads from shapely.wkt import dumps logger = logging.getLogger(__name__) @@ -14,7 +16,7 @@ class HandIndexQuerier: """ - A class for querying HAND index data from partitioned parquet files. + A class for querying HAND index data from parquet files. Provides spatial intersection and filtering capabilities. """ @@ -23,12 +25,14 @@ def __init__(self, partitioned_base_path: str, overlap_threshold_percent: float Initialize the HandIndexQuerier. Args: - partitioned_base_path: Base path to partitioned parquet files (local or s3://) + partitioned_base_path: Base path to parquet files (local or s3://) overlap_threshold_percent: Minimum overlap percentage to keep a catchment """ self.partitioned_base_path = partitioned_base_path self.overlap_threshold_percent = overlap_threshold_percent self.con = None + self.credentials = boto3.Session().get_credentials() + self.s3_region = boto3.Session().region_name def _ensure_connection(self): """Ensure DuckDB connection is established with required extensions.""" @@ -43,14 +47,41 @@ def _ensure_connection(self): self.con.execute("LOAD httpfs;") self.con.execute("INSTALL aws;") self.con.execute("LOAD aws;") + + # Try using DuckDB's credential chain provider instead of manual credentials + # This should work better with temporary credentials + try: + self.con.execute(f""" + CREATE SECRET s3_secret ( + TYPE S3, + PROVIDER CREDENTIAL_CHAIN, + REGION '{self.s3_region or 'us-east-1'}' + ); + """) + logger.info("Using DuckDB credential chain for S3 access") + except duckdb.Error as cred_chain_error: + logger.warning(f"Credential chain failed, falling back to manual credentials: {cred_chain_error}") + # Fallback to manual credential setting + self.con.execute(f"SET s3_region = '{self.s3_region or 'us-east-1'}';") + if self.credentials: + if self.credentials.access_key: + self.con.execute(f"SET s3_access_key_id='{self.credentials.access_key}';") + if self.credentials.secret_key: + self.con.execute(f"SET s3_secret_access_key='{self.credentials.secret_key}';") + if self.credentials.token: + self.con.execute(f"SET s3_session_token='{self.credentials.token}';") + + self.con.execute("SET memory_limit = '7GB';") + self.con.execute("SET temp_directory = '/tmp';") + except duckdb.Error as e: logger.warning("Could not load DuckDB extensions: %s", e) - # Create partitioned views - self._create_partitioned_views() + # Create views + self._create_views() - def _create_partitioned_views(self): - """Create views for partitioned tables.""" + def _create_views(self): + """Create views for non-partitioned tables.""" base_path = self.partitioned_base_path if not base_path.endswith("/"): base_path += "/" @@ -59,13 +90,13 @@ def _create_partitioned_views(self): CREATE OR REPLACE VIEW catchments_partitioned AS SELECT * FROM read_parquet('{base_path}catchments/*/*.parquet', hive_partitioning = 1); - CREATE OR REPLACE VIEW hydrotables_partitioned AS - SELECT * FROM read_parquet('{base_path}hydrotables/*/*.parquet', hive_partitioning = 1); + CREATE OR REPLACE VIEW hydrotables AS + SELECT * FROM read_parquet('{base_path}hydrotables.parquet'); - CREATE OR REPLACE VIEW hand_rem_rasters_partitioned AS + CREATE OR REPLACE VIEW hand_rem_rasters AS SELECT * FROM read_parquet('{base_path}hand_rem_rasters.parquet'); - CREATE OR REPLACE VIEW hand_catchment_rasters_partitioned AS + CREATE OR REPLACE VIEW hand_catchment_rasters AS SELECT * FROM read_parquet('{base_path}hand_catchment_rasters.parquet'); """ @@ -91,9 +122,9 @@ def _partitioned_query_cte(self, wkt4326: str) -> str: SELECT c.catchment_id, c.geometry, - c.h3_partition_key + c.h3_index FROM catchments_partitioned c - JOIN transformed_query tq ON ST_Intersects(c.geometry, tq.query_geom) + JOIN transformed_query tq ON ST_Intersects(ST_GeomFromWKB(c.geometry), tq.query_geom) ) """ @@ -136,7 +167,7 @@ def _get_catchment_data_for_polygon( + """ SELECT fc.catchment_id, - ST_AsWKB(fc.geometry) AS geom_wkb + fc.geometry AS geom_wkb FROM filtered_catchments AS fc; """ ) @@ -146,28 +177,27 @@ def _get_catchment_data_for_polygon( empty_gdf = gpd.GeoDataFrame(columns=["catchment_id", "geometry"], geometry="geometry", crs="EPSG:5070") return empty_gdf, pd.DataFrame(), query_poly_5070 - # Decode WKB → shapely geometries - wkb_series = geom_df["geom_wkb"].apply(lambda x: bytes(x) if isinstance(x, bytearray) else x) + # Convert WKB data to Shapely geometry objects. Wrapping wkb in bytes because duckdb exports a bytearray but shapely wants bytes. + geom_df["geometry"] = geom_df["geom_wkb"].apply(lambda wkb: loads(bytes(wkb)) if wkb is not None else None) geometries_gdf = gpd.GeoDataFrame( - geom_df[["catchment_id"]], - geometry=gpd.GeoSeries.from_wkb(wkb_series, crs="EPSG:5070"), + geom_df[["catchment_id", "geometry"]], + geometry="geometry", crs="EPSG:5070", ) - # Build and run the attribute query using partitioned tables + # Build and run the attribute query using non-partitioned tables sql_attr = ( cte + """ SELECT fc.catchment_id, - h.* EXCLUDE (catchment_id, h3_partition_key), - hrr.rem_raster_id, + h.csv_path, hrr.raster_path AS rem_raster_path, hcr.raster_path AS catchment_raster_path FROM filtered_catchments AS fc - LEFT JOIN hydrotables_partitioned AS h ON fc.catchment_id = h.catchment_id - LEFT JOIN hand_rem_rasters_partitioned AS hrr ON fc.catchment_id = hrr.catchment_id - LEFT JOIN hand_catchment_rasters_partitioned AS hcr ON hrr.rem_raster_id = hcr.rem_raster_id; + LEFT JOIN hydrotables AS h ON fc.catchment_id = h.catchment_id + LEFT JOIN hand_rem_rasters AS hrr ON fc.catchment_id = hrr.catchment_id + LEFT JOIN hand_catchment_rasters AS hcr ON fc.catchment_id = hcr.catchment_id; """ ) attributes_df = self.con.execute(sql_attr).fetch_df() @@ -275,12 +305,6 @@ def query_catchments_for_polygon( for catch_id, group in filtered_attrs.groupby("catchment_id"): df = group.drop(columns=["catchment_id"]).copy() - # Convert UUID columns to strings for Parquet compatibility - uuid_columns = ["rem_raster_id", "catchment_raster_id"] - for col in uuid_columns: - if col in df.columns: - df[col] = df[col].astype(str) - out_path = outdir / f"{catch_id}.parquet" df.to_parquet(str(out_path), index=False) logger.info("Wrote %d rows for catchment '%s' → %s", len(df), catch_id, out_path) diff --git a/src/stac_querier.py b/src/stac_querier.py index d531afc..d728c26 100644 --- a/src/stac_querier.py +++ b/src/stac_querier.py @@ -66,7 +66,9 @@ def compare_versions(v1: str, v2: str) -> int: class Interval: start: datetime end: datetime - assets: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list)) + assets: Dict[str, List[str]] = field( + default_factory=lambda: defaultdict(list) + ) class StacQuerier: @@ -81,6 +83,7 @@ def __init__( collections: Optional[List[str]] = None, overlap_threshold_percent: float = 40.0, datetime_filter: Optional[str] = None, + aoi_is_item: bool = False, ): """ Initialize the StacQuerier. @@ -90,11 +93,13 @@ def __init__( collections: List of STAC collection IDs to query (None means query all available) overlap_threshold_percent: Minimum overlap percentage to keep a STAC item datetime_filter: STAC datetime or interval filter + aoi_is_item: If True, query specific STAC items directly by ID """ self.api_url = api_url self.collections = collections self.overlap_threshold_percent = overlap_threshold_percent self.datetime_filter = datetime_filter + self.aoi_is_item = aoi_is_item self.client = None # Collection specifications @@ -103,15 +108,20 @@ def __init__( ] self.NWS_USGS_MAGS = ["action", "minor", "moderate", "major"] self.NWS_USGS_SPEC = [ - (re.compile(rf"^{mag}_(extent_raster|flow_file)$"), mag, r"\1") for mag in self.NWS_USGS_MAGS + (re.compile(rf"^{mag}_(extent_raster|flow_file)$"), mag, r"\1") + for mag in self.NWS_USGS_MAGS + ] + self.RIPPLE_SPEC: List[Tuple[Pattern, str, str]] = [ + (re.compile(r"^(\d+yr)_extent$"), r"\1", "extents") ] - self.RIPPLE_SPEC: List[Tuple[Pattern, str, str]] = [(re.compile(r"^(\d+yr)_extent$"), r"\1", "extents")] CollectionConfig = Dict[ str, Tuple[ Callable[[Any], str], # grouping fn → group_id - Dict[str, Callable[[str, Any], bool]], # asset_type → test(key, asset) + Dict[ + str, Callable[[str, Any], bool] + ], # asset_type → test(key, asset) ], ] @@ -119,21 +129,28 @@ def __init__( "gfm-collection": ( lambda item: str(item.properties.get("dfo_event_id", item.id)), { - "extents": lambda k, a: k.endswith("_Observed_Water_Extent"), + "extents": lambda k, a: k.endswith( + "_Observed_Water_Extent" + ), "flowfiles": lambda k, a: k.endswith("_flowfile"), }, ), "gfm-expanded-collection": ( self._group_gfm_expanded_initial, { - "extents": lambda k, a: k.endswith("_Observed_Water_Extent"), - "flowfiles": lambda k, a: k.endswith("_flowfile") or k == "NWM_ANA_flowfile", + "extents": lambda k, a: k.endswith( + "_Observed_Water_Extent" + ), + "flowfiles": lambda k, a: k.endswith("_flowfile") + or k == "NWM_ANA_flowfile", }, ), "hwm-collection": ( lambda item: item.id, { - "points": lambda k, a: k == "data" and a.media_type and "geopackage" in a.media_type, + "points": lambda k, a: k == "data" + and a.media_type + and "geopackage" in a.media_type, "flowfiles": lambda k, a: k.endswith("-flowfile"), }, ), @@ -149,7 +166,9 @@ def _ensure_client(self): logger.error(f"Could not open STAC API: {e}") raise - def _filter_items_by_geometry(self, items: List[Any], query_polygon: Polygon) -> List[Any]: + def _filter_items_by_geometry( + self, items: List[Any], query_polygon: Polygon + ) -> List[Any]: """ Filter STAC items using geometry relationships similar to HAND query filtering. @@ -195,14 +214,18 @@ def _filter_items_by_geometry(self, items: List[Any], query_polygon: Polygon) -> # Compute overlap percentage relative to item's area if not contains_query and not within_query: - intersection_area = item_geom.intersection(query_polygon).area + intersection_area = item_geom.intersection( + query_polygon + ).area item_area = item_geom.area if item_area > 0: overlap_pct = (intersection_area / item_area) * 100 else: overlap_pct = 0.0 else: - overlap_pct = 100.0 # Contains or within means 100% relevant + overlap_pct = ( + 100.0 # Contains or within means 100% relevant + ) # Apply selection criteria if contains_query: @@ -218,7 +241,9 @@ def _filter_items_by_geometry(self, items: List[Any], query_polygon: Polygon) -> stats["removed_count"] += 1 except Exception as e: - logger.warning(f"Error processing item geometry for {getattr(item, 'id', 'unknown')}: {e}") + logger.warning( + f"Error processing item geometry for {getattr(item, 'id', 'unknown')}: {e}" + ) # Include item if geometry processing fails filtered_items.append(item) @@ -262,20 +287,26 @@ def _group_gfm_expanded_initial(self, item: Any) -> str: fmt = "%Y-%m-%dT%H:%M:%SZ" return f"{start.strftime(fmt)}/{end.strftime(fmt)}" - def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[str]]]]: + def _format_results( + self, item_iter: Any + ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: results = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) seen = set() ripple_cache: Dict[str, List[str]] = {} ripple_best_items = {} # Track STAC item IDs for each scenario item_ids = defaultdict(lambda: defaultdict(set)) + gauge_info = defaultdict(lambda: defaultdict(lambda: None)) + hucs_info = defaultdict(lambda: defaultdict(list)) for idx, item in enumerate(item_iter, start=1): if item.collection_id == "ripple-fim-collection": try: source = item.properties.get("source", "") hucs = tuple(item.properties.get("hucs", [])) - flows2fim_version = item.properties.get("flows2fim_version", "") + flows2fim_version = item.properties.get( + "flows2fim_version", "" + ) if source and hucs and flows2fim_version: key = (source, hucs) @@ -283,9 +314,15 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ # If we haven't seen this source+hucs combo, or this version is better if ( key not in ripple_best_items - or compare_versions(flows2fim_version, ripple_best_items[key][1]) > 0 + or compare_versions( + flows2fim_version, ripple_best_items[key][1] + ) + > 0 ): - ripple_best_items[key] = (item.id, flows2fim_version) + ripple_best_items[key] = ( + item.id, + flows2fim_version, + ) else: # Skip this item - we already have a better version logger.info( @@ -293,7 +330,9 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ ) continue except Exception as e: - logger.warning(f"Error processing Ripple item {item.id}: {e}") + logger.warning( + f"Error processing Ripple item {item.id}: {e}" + ) if item.id in seen: continue @@ -302,16 +341,25 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ logger.info(f"Processed {idx} items (last: {item.id})") coll = item.collection_id or "" - short = coll.replace("-collection", "").replace("-fim", "") - if short == "gfm-expanded": - short = "gfm_expanded" + collection_key = coll + # Handle special case for gfm-expanded-collection to maintain compatibility + if coll == "gfm-expanded-collection": + collection_key = "gfm_expanded" # 1) item‐level grouping if coll in self.COLLECTIONS: group_fn, tests = self.COLLECTIONS[coll] gid = group_fn(item) - bucket = results[short][gid] - item_ids[short][gid].add(item.id) + bucket = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) + + # Extract hucs for item-level collections + hucs = item.properties.get("hucs", []) + if hucs: + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) + for k, a in item.assets.items(): if not a.href: continue @@ -319,28 +367,71 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ if test(k, a) and a.href not in bucket[atype]: bucket[atype].append(a.href) - # 2) BLE/NWS/USGS/Ripple asset‐level grouping - elif coll == "ble-collection" or coll.endswith("-fim-collection"): + # 2) BLE/NWS/USGS asset‐level grouping (excluding Ripple) + elif coll == "ble-collection" or coll in [ + "nws-fim-collection", + "usgs-fim-collection", + ]: + specs = ( + self.BLE_SPEC + if coll == "ble-collection" + else self.NWS_USGS_SPEC + ) + + found = set() + for k, a in item.assets.items(): + if not a.href: + continue + for pat, gid_t, at_t in specs: + m = pat.match(k) + if not m: + continue + gid = m.expand(gid_t) if "\\" in gid_t else gid_t + at = m.expand(at_t) + bkt = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) + if a.href not in bkt[at]: + bkt[at].append(a.href) + found.add(gid) + + if coll in [ + "nws-fim-collection", + "usgs-fim-collection", + ]: + gauge = item.properties.get("gauge") + if ( + gauge + and gauge_info[collection_key][gid] is None + ): + gauge_info[collection_key][gid] = gauge + + # Extract hucs for all asset-level collections (BLE, NWS, USGS) + hucs = item.properties.get("hucs", []) + if hucs: + for gid in found: + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) + + # 3) Ripple asset-level grouping with source-based separation + elif coll == "ripple-fim-collection": # preload ripple assets once - if coll == "ripple-fim-collection" and not ripple_cache: + if not ripple_cache: try: col = item.get_collection() for ak, aa in col.assets.items(): m = re.search(r"flows_(\d+)_yr_", ak) ri = f"{m.group(1)}yr" if m else None if ri and aa.media_type == "text/csv": - logger.info(f"Caching Ripple flowfile for {ri}: {aa.href}") + logger.info( + f"Caching Ripple flowfile for {ri}: {aa.href}" + ) ripple_cache.setdefault(ri, []).append(aa.href) except Exception as e: logger.warning(f"Ripple cache failed: {e}") - specs = ( - self.BLE_SPEC - if coll == "ble-collection" - else self.RIPPLE_SPEC - if coll == "ripple-fim-collection" - else self.NWS_USGS_SPEC - ) + specs = self.RIPPLE_SPEC + source = item.properties.get("source", "unknown_source") found = set() for k, a in item.assets.items(): @@ -350,42 +441,88 @@ def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[ m = pat.match(k) if not m: continue - gid = m.expand(gid_t) if "\\" in gid_t else gid_t + # Get the base ID (e.g., "100yr") + base_gid = m.expand(gid_t) if "\\" in gid_t else gid_t + # Create the new, unique group ID + gid = f"{source}-{base_gid}" at = m.expand(at_t) - bkt = results[short][gid] - item_ids[short][gid].add(item.id) + bkt = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) if a.href not in bkt[at]: bkt[at].append(a.href) found.add(gid) - # append ripple flowfiles - if coll == "ripple-fim-collection": + # Extract hucs for Ripple collections + hucs = item.properties.get("hucs", []) + if hucs: for gid in found: - if gid in ripple_cache: - logger.info(f"Adding cached flowfiles for {gid}") - for href in ripple_cache[gid]: - if href not in results[short][gid]["flowfiles"]: - results[short][gid]["flowfiles"].append(href) + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) - # 3) fallback + # append ripple flowfiles + for composite_gid in found: + # Extract "100yr" from "SourceA-100yr" + base_gid = composite_gid.split("-", 1)[-1] + if base_gid in ripple_cache: + logger.info( + f"Adding cached flowfiles for {composite_gid}" + ) + for href in ripple_cache[base_gid]: + if ( + href + not in results[collection_key][composite_gid][ + "flowfiles" + ] + ): + results[collection_key][composite_gid][ + "flowfiles" + ].append(href) + + # 4) fallback else: logger.warning(f"Unknown coll '{coll}'; grouping by item.id") gid = item.id - bkt = results[short][gid] - item_ids[short][gid].add(item.id) + bkt = results[collection_key][gid] + item_ids[collection_key][gid].add(item.id) + + if "nws" in coll.lower() or "usgs" in coll.lower(): + gauge = item.properties.get("gauge") + if gauge and gauge_info[collection_key][gid] is None: + gauge_info[collection_key][gid] = gauge + + # Extract hucs for fallback collections + hucs = item.properties.get("hucs", []) + if hucs: + for huc in hucs: + if huc not in hucs_info[collection_key][gid]: + hucs_info[collection_key][gid].append(huc) + for k, a in item.assets.items(): if not a.href: continue - if "extent" in k and a.media_type and "tiff" in a.media_type: + if ( + "extent" in k + and a.media_type + and "tiff" in a.media_type + ): bkt["extents"].append(a.href) elif "flow" in k and a.media_type and "csv" in a.media_type: bkt["flowfiles"].append(a.href) - # Add item IDs to results - for collection_short in results: - for scenario_id in results[collection_short]: - results[collection_short][scenario_id]["stac_items"] = list(item_ids[collection_short][scenario_id]) - + # Add item IDs, gauge information, and hucs to results + for collection_key in results: + for scenario_id in results[collection_key]: + results[collection_key][scenario_id]["stac_items"] = list( + item_ids[collection_key][scenario_id] + ) + results[collection_key][scenario_id]["gauge"] = gauge_info[ + collection_key + ][scenario_id] + results[collection_key][scenario_id]["hucs"] = hucs_info[ + collection_key + ][scenario_id] + logger.info(f"Finished formatting {len(seen)} items.") return results @@ -410,6 +547,18 @@ def _merge_gfm_expanded( for h in hs: if h not in iv.assets[at]: iv.assets[at].append(h) + elif at == "gauge": + # Handle gauge specially - it's a single value, not a list + if hs is not None: + iv.assets[at] = hs + elif at == "hucs": + # Handle hucs specially - merge lists without duplicates + if hs: + existing_hucs = iv.assets.get(at, []) + for h in hs: + if h not in existing_hucs: + existing_hucs.append(h) + iv.assets[at] = existing_hucs else: for h in hs: if h not in iv.assets[at]: @@ -430,6 +579,18 @@ def _merge_gfm_expanded( for h in hs: if h not in cur.assets[at]: cur.assets[at].append(h) + elif at == "gauge": + # Handle gauge specially - use the first non-null value + if hs is not None and at not in cur.assets: + cur.assets[at] = hs + elif at == "hucs": + # Handle hucs specially - merge lists without duplicates + if hs: + existing_hucs = cur.assets.get(at, []) + for h in hs: + if h not in existing_hucs: + existing_hucs.append(h) + cur.assets[at] = existing_hucs else: for h in hs: if h not in cur.assets[at]: @@ -447,7 +608,9 @@ def _merge_gfm_expanded( return out def query_stac_for_polygon( - self, polygon_gdf: Optional[gpd.GeoDataFrame] = None, roi_geojson: Optional[Dict] = None + self, + polygon_gdf: Optional[gpd.GeoDataFrame] = None, + roi_geojson: Optional[Dict] = None, ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: """ Query STAC API for items and group them by collection/scenario. @@ -482,34 +645,79 @@ def query_stac_for_polygon( try: query_polygon = shape(roi_geojson) except Exception as e: - logger.warning(f"Could not convert ROI to shapely geometry: {e}") + logger.warning( + f"Could not convert ROI to shapely geometry: {e}" + ) # Build search kwargs search_kw = { "datetime": self.datetime_filter, **({"intersects": intersects} if intersects else {}), } - + + # KLUDGE: Hardcode 2022 date range for GFM collections. this will be removed after trial batches with gfm data are completed. + if self.collections is not None and any( + "gfm" in coll.lower() for coll in self.collections + ): + search_kw["datetime"] = "2022-01-01T00:00:00Z/2022-12-31T23:59:59Z" + logger.info( + "KLUDGE: Overriding datetime filter to 2022 for GFM collections" + ) + # Only include collections if specified, otherwise query all available if self.collections is not None: search_kw["collections"] = self.collections - + search_kw = {k: v for k, v in search_kw.items() if v is not None} try: - collections_msg = search_kw.get("collections", "all available collections") + collections_msg = search_kw.get( + "collections", "all available collections" + ) logger.info(f"Searching collections {collections_msg}") search = self.client.search(**search_kw) - items = list(search.items()) # Convert to list for geometry filtering + items = list( + search.items() + ) # Convert to list for geometry filtering + + if not items: + logger.info("STAC query returned no items") + return {} # Apply geometry filtering if query polygon is provided if query_polygon and self.overlap_threshold_percent: - logger.info(f"Applying geometry filtering with {self.overlap_threshold_percent}% overlap threshold") + logger.info( + f"Applying geometry filtering with {self.overlap_threshold_percent}% overlap threshold" + ) items = self._filter_items_by_geometry(items, query_polygon) + if not items: + logger.info("No items remained after geometry filtering") + return {} grouped = self._format_results(items) + + if not grouped: + logger.info( + "No valid scenarios found after processing STAC items" + ) + return {} if "gfm_expanded" in grouped: - grouped["gfm_expanded"] = self._merge_gfm_expanded(grouped["gfm_expanded"]) + grouped["gfm_expanded"] = self._merge_gfm_expanded( + grouped["gfm_expanded"] + ) + + # Truncate timestamp ranges to just the first timestamp to get rid of double directory writing issue + truncated = {} + for scenario_key, scenario_data in grouped[ + "gfm_expanded" + ].items(): + # If key contains a slash (timestamp range), take only the first part + if "/" in scenario_key: + first_timestamp = scenario_key.split("/")[0] + truncated[first_timestamp] = scenario_data + else: + truncated[scenario_key] = scenario_data + grouped["gfm_expanded"] = truncated return dictify(grouped) @@ -520,6 +728,48 @@ def query_stac_for_polygon( logger.error(f"Unexpected error: {ex}") raise + def query_stac_by_item_id( + self, item_id: str + ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: + """ + Query STAC API for a specific item by ID. + + Args: + item_id: STAC item ID to retrieve + + Returns: + Dictionary mapping collection -> scenario -> asset_type -> [paths] + Each scenario also includes 'stac_items' key with list of STAC item IDs + """ + self._ensure_client() + + try: + logger.info(f"Querying STAC for item ID: {item_id}") + + # Get the specific item + item = next(self.client.search(ids=[item_id]).items(), None) + if not item: + logger.info(f"No item found with ID: {item_id}") + return {} + + # Process the single item using the existing format_results method + grouped = self._format_results([item]) + + if not grouped: + logger.info( + f"No valid scenarios found after processing item {item_id}" + ) + return {} + + return dictify(grouped) + + except requests.RequestException as rex: + logger.error(f"STAC request failed for item {item_id}: {rex}") + raise + except Exception as ex: + logger.error(f"Unexpected error querying item {item_id}: {ex}") + raise + def save_results(self, results: Dict, output_path: str): """Save query results to JSON file.""" with open(output_path, "w") as f: From ce9ab2490f698c90943286aa543604d76f129885 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Thu, 28 Aug 2025 12:18:52 -0400 Subject: [PATCH 2/3] Edit data_service.py data_service.py now has credential refresh. This was necessary because couldn't get fsspec to use IAM credentials directly so was obtaining credentials from the IAM using boto3 and passing them to fsspec. But then those credentials were expiring. Pass through pass through additional metadata fields gauge, hucs, and stac_items from STAC queries Add aoi_is_item flag to be able to query directly by stac item append_file_to_uri() method supporting atomic append operations for S3 and local files DataServiceException class for better specific error handling. This is so that a pipeline can be reported as failed if there is a data service issue --- src/data_service.py | 351 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 299 insertions(+), 52 deletions(-) diff --git a/src/data_service.py b/src/data_service.py index f789af3..552a8ae 100644 --- a/src/data_service.py +++ b/src/data_service.py @@ -4,11 +4,15 @@ import os import shutil import tempfile +import threading +import time from contextlib import suppress +from datetime import datetime, timedelta from io import StringIO from pathlib import Path from typing import Any, Dict, List, Optional +import boto3 import fsspec import geopandas as gpd from botocore.exceptions import ClientError, NoCredentialsError @@ -20,12 +24,33 @@ from stac_querier import StacQuerier +class DataServiceException(Exception): + """Custom exception for data service operations.""" + + pass + + class DataService: """Service to query data sources and interact with S3 via fsspec.""" - def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collections: Optional[List[str]] = None): + def __init__( + self, + config: AppConfig, + hand_index_path: str, + benchmark_collections: Optional[List[str]] = None, + aoi_is_item: bool = False, + ): self.config = config - # Configure S3 filesystem options from environment + self.aoi_is_item = aoi_is_item + + # Initialize credential refresh tracking + self._credential_lock = threading.Lock() + self._last_credential_refresh = datetime.now() + self._refresh_thread = None + self._stop_refresh = threading.Event() + self._use_iam_credentials = False + + # Configure S3 filesystem options - try config first, then IAM credentials self._s3_options = {} if config.aws.AWS_ACCESS_KEY_ID: self._s3_options["key"] = config.aws.AWS_ACCESS_KEY_ID @@ -34,6 +59,29 @@ def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collection if config.aws.AWS_SESSION_TOKEN: self._s3_options["token"] = config.aws.AWS_SESSION_TOKEN + # If no explicit credentials in config, try to get from IAM instance profile + if not self._s3_options: + try: + credentials = boto3.Session().get_credentials() + if credentials: + self._s3_options["key"] = credentials.access_key + self._s3_options["secret"] = credentials.secret_key + if credentials.token: + self._s3_options["token"] = credentials.token + self._use_iam_credentials = True + logging.info( + "Using IAM instance profile credentials for S3 access" + ) + + # Start credential refresh thread for IAM credentials + self._start_credential_refresh_thread() + else: + logging.warning( + "No AWS credentials found in config or IAM instance profile" + ) + except Exception as e: + logging.warning(f"Failed to get IAM credentials: {e}") + # Initialize HandIndexQuerier with provided path self.hand_querier = None if hand_index_path: @@ -50,13 +98,61 @@ def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collection collections=benchmark_collections, overlap_threshold_percent=config.stac.overlap_threshold_percent, datetime_filter=config.stac.datetime_filter, + aoi_is_item=aoi_is_item, ) # Initialize FlowfileCombiner self.flowfile_combiner = FlowfileCombiner( - output_dir=config.flow_scenarios.output_dir if config.flow_scenarios else "combined_flowfiles" + output_dir=config.flow_scenarios.output_dir + if config.flow_scenarios + else "combined_flowfiles" ) + def _start_credential_refresh_thread(self): + """Start a background thread to refresh IAM credentials every 4 hours.""" + if self._use_iam_credentials: + self._refresh_thread = threading.Thread( + target=self._credential_refresh_loop, daemon=True + ) + self._refresh_thread.start() + logging.info( + "Started IAM credential refresh thread (4-hour interval)" + ) + + def _credential_refresh_loop(self): + """Background thread loop to refresh credentials every 4 hours.""" + refresh_interval = 4 * 60 * 60 # sec + + while not self._stop_refresh.is_set(): + # Wait for 4 hours or until stop event is set + if self._stop_refresh.wait(refresh_interval): + break + + # Refresh credentials + self._refresh_credentials() + + def _refresh_credentials(self): + """Refresh IAM credentials from boto3 session.""" + try: + logging.info("Refreshing IAM credentials...") + credentials = boto3.Session().get_credentials() + + if credentials: + with self._credential_lock: + self._s3_options["key"] = credentials.access_key + self._s3_options["secret"] = credentials.secret_key + if credentials.token: + self._s3_options["token"] = credentials.token + self._last_credential_refresh = datetime.now() + + logging.info("Successfully refreshed IAM credentials") + else: + logging.error( + "Failed to refresh credentials: No credentials available from boto3 session" + ) + except Exception as e: + logging.error(f"Error refreshing IAM credentials: {e}") + def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame: """ Args: @@ -71,55 +167,91 @@ def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame: """ try: # geopandas can read from both local and S3 paths - # For S3 paths, it will use storage_options; for local paths, it ignores them - gdf = gpd.read_file(file_path, storage_options=self._s3_options if file_path.startswith("s3://") else None) + gdf = gpd.read_file(file_path) if len(gdf) == 0: raise ValueError(f"Empty GeoDataFrame in file: {file_path}") # Ensure CRS is EPSG:4326 if gdf.crs and gdf.crs.to_epsg() != 4326: - logging.info(f"Converting polygon data from {gdf.crs} to EPSG:4326") + logging.info( + f"Converting polygon data from {gdf.crs} to EPSG:4326" + ) gdf = gdf.to_crs("EPSG:4326") elif not gdf.crs: - logging.warning(f"No CRS found in {file_path}, assuming EPSG:4326") + logging.warning( + f"No CRS found in {file_path}, assuming EPSG:4326" + ) gdf.set_crs("EPSG:4326", inplace=True) - logging.info(f"Loaded polygon GeoDataFrame with {len(gdf)} features from {file_path}") + logging.info( + f"Loaded polygon GeoDataFrame with {len(gdf)} features from {file_path}" + ) return gdf except FileNotFoundError: raise FileNotFoundError(f"Polygon data file not found: {file_path}") except Exception as e: - raise ValueError(f"Error loading GeoDataFrame from {file_path}: {e}") + raise ValueError( + f"Error loading GeoDataFrame from {file_path}: {e}" + ) - async def query_stac_for_flow_scenarios(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: + async def query_stac_for_flow_scenarios( + self, + polygon_gdf: gpd.GeoDataFrame, + tags: Optional[Dict[str, str]] = None, + ) -> Dict: """ - Query STAC API for flow scenarios based on polygon. + Query STAC API for flow scenarios based on polygon or item ID. Args: polygon_gdf: GeoDataFrame containing polygon geometry + tags: Optional tags dictionary (required when aoi_is_item is True) Returns: Dictionary with STAC query results and combined flowfiles """ # Generate polygon_id from index or use a default - polygon_id = f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" + polygon_id = ( + f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" + ) if not self.stac_querier or not self.config.stac: - raise RuntimeError("STAC configuration is required but not provided") + raise RuntimeError( + "STAC configuration is required but not provided" + ) - logging.info(f"Querying STAC for flow scenarios for polygon: {polygon_id}") + logging.info( + f"Querying STAC for flow scenarios for polygon: {polygon_id}" + ) try: # Run STAC query in executor to avoid blocking loop = asyncio.get_running_loop() - stac_results = await loop.run_in_executor( - None, - self.stac_querier.query_stac_for_polygon, - polygon_gdf, - None, # roi_geojson not needed since we have polygon_gdf - ) + + if self.aoi_is_item: + # Direct item query mode - extract aoi_name from tags + if not tags or "aoi_name" not in tags: + raise ValueError( + "aoi_name tag is required when --aoi_is_item is used" + ) + + aoi_name = tags["aoi_name"] + logging.info(f"Querying STAC for specific item ID: {aoi_name}") + + stac_results = await loop.run_in_executor( + None, + self.stac_querier.query_stac_by_item_id, + aoi_name, + ) + else: + # Standard spatial query mode + stac_results = await loop.run_in_executor( + None, + self.stac_querier.query_stac_for_polygon, + polygon_gdf, + None, # roi_geojson not needed since we have polygon_gdf + ) if not stac_results: logging.info(f"No STAC results found for polygon {polygon_id}") @@ -128,7 +260,9 @@ async def query_stac_for_flow_scenarios(self, polygon_gdf: gpd.GeoDataFrame) -> # Process flowfiles from STAC results combined_flowfiles = {} if self.flowfile_combiner: - logging.info(f"Processing flowfiles for {len(stac_results)} collections") + logging.info( + f"Processing flowfiles for {len(stac_results)} collections" + ) # Create temporary directory for this polygon's combined flowfiles temp_dir = f"/tmp/flow_scenarios_{polygon_id}" @@ -141,7 +275,9 @@ async def query_stac_for_flow_scenarios(self, polygon_gdf: gpd.GeoDataFrame) -> temp_dir, ) - logging.info(f"Combined flowfiles created for polygon {polygon_id}") + logging.info( + f"Combined flowfiles created for polygon {polygon_id}" + ) return { "scenarios": stac_results, @@ -167,14 +303,18 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: await asyncio.sleep(0.01) # Generate polygon_id from index or use a default - polygon_id = f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" + polygon_id = ( + f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" + ) logging.info(f"Querying catchments for polygon: {polygon_id}") if self.hand_querier: # Use real hand index query try: # Create temporary directory for parquet outputs - with tempfile.TemporaryDirectory(prefix=f"catchments_{polygon_id}_") as temp_dir: + with tempfile.TemporaryDirectory( + prefix=f"catchments_{polygon_id}_" + ) as temp_dir: # Run the query with temporary output directory loop = asyncio.get_running_loop() catchments_result = await loop.run_in_executor( @@ -185,7 +325,9 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: ) if not catchments_result: - logging.info(f"No catchments found for polygon {polygon_id}") + logging.info( + f"No catchments found for polygon {polygon_id}" + ) return {"catchments": {}, "hand_version": "real_query"} # Copy parquet files to a more permanent location if needed @@ -209,14 +351,21 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: logging.info( f"Data service returning {len(catchments)} catchments for polygon {polygon_id} (real query)." ) - return {"catchments": catchments, "hand_version": "real_query"} + return { + "catchments": catchments, + "hand_version": "real_query", + } except Exception as e: - logging.error(f"Error in hand index query for polygon {polygon_id}: {e}") + logging.error( + f"Error in hand index query for polygon {polygon_id}: {e}" + ) raise # No hand querier available - raise RuntimeError(f"Hand index querier is required but not initialized for polygon {polygon_id}") + raise RuntimeError( + f"Hand index querier is required but not initialized for polygon {polygon_id}" + ) async def copy_file_to_uri(self, source_path: str, dest_uri: str): """Copies a file (e.g., parquet) to a URI (local or S3) using fsspec. @@ -229,12 +378,15 @@ async def copy_file_to_uri(self, source_path: str, dest_uri: str): else source_path ) dest_uri_normalized = ( - os.path.abspath(dest_uri) if not dest_uri.startswith(("s3://", "http://", "https://")) else dest_uri + os.path.abspath(dest_uri) + if not dest_uri.startswith(("s3://", "http://", "https://")) + else dest_uri ) # Copy if source is local and destination is S3, or if both are local but paths differ should_copy = ( - not source_path.startswith(("s3://", "http://", "https://")) and dest_uri.startswith("s3://") + not source_path.startswith(("s3://", "http://", "https://")) + and dest_uri.startswith("s3://") ) or ( not source_path.startswith(("s3://", "http://", "https://")) and not dest_uri.startswith(("s3://", "http://", "https://")) @@ -259,24 +411,90 @@ async def copy_file_to_uri(self, source_path: str, dest_uri: str): return dest_uri except Exception as e: logging.exception(f"Failed to copy {source_path} to {dest_uri}") - raise ConnectionError(f"Failed to copy file to {dest_uri}") from e + raise DataServiceException( + f"Failed to copy file to {dest_uri}: {str(e)}" + ) from e else: # If already on S3 or both local with same path, just return the source path - logging.debug(f"No copy needed - using existing path: {source_path}") + logging.debug( + f"No copy needed - using existing path: {source_path}" + ) return source_path def _sync_copy_file(self, source_path: str, dest_uri: str): """Synchronous helper for copying files using fsspec.""" - # Create filesystem based on destination URI + # Copy file using fsspec.open directly (allows fallback to default AWS credentials) + with open(source_path, "rb") as src: + # Use thread-safe credential access for S3 operations + if dest_uri.startswith("s3://"): + with self._credential_lock: + s3_opts = self._s3_options.copy() + with fsspec.open(dest_uri, "wb", **s3_opts) as dst: + dst.write(src.read()) + else: + with fsspec.open(dest_uri, "wb") as dst: + dst.write(src.read()) + + async def append_file_to_uri(self, source_path: str, dest_uri: str): + """Appends a file to a URI (local or S3) using fsspec. + If the destination file doesn't exist, creates it with the source content. + If it exists, appends the source content to it.""" + + logging.debug(f"Appending file from {source_path} to {dest_uri}") + try: + # Check if destination exists + dest_exists = await self.check_file_exists(dest_uri) + + # Run in executor for async operation + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + self._sync_append_file, + source_path, + dest_uri, + dest_exists, + ) + logging.info( + f"Successfully {'appended to' if dest_exists else 'created'} {dest_uri}" + ) + return dest_uri + except Exception as e: + logging.exception(f"Failed to append {source_path} to {dest_uri}") + raise DataServiceException( + f"Failed to append file to {dest_uri}: {str(e)}" + ) from e + + def _sync_append_file( + self, source_path: str, dest_uri: str, dest_exists: bool + ): + """Synchronous helper for appending files using fsspec.""" + # Read source content + with open(source_path, "rb") as src: + source_content = src.read() + + # Use thread-safe credential access for S3 operations if dest_uri.startswith("s3://"): - fs = fsspec.filesystem("s3", **self._s3_options) + with self._credential_lock: + s3_opts = self._s3_options.copy() else: - fs = fsspec.filesystem("file") + s3_opts = {} - # Copy file using fsspec - with open(source_path, "rb") as src: - with fs.open(dest_uri, "wb") as dst: - dst.write(src.read()) + if dest_exists: + # If destination exists, read existing content first + with fsspec.open(dest_uri, "rb", **s3_opts) as dst: + existing_content = dst.read() + + # Write combined content + with fsspec.open(dest_uri, "wb", **s3_opts) as dst: + dst.write(existing_content) + dst.write(source_content) + else: + # If destination doesn't exist, just write the source content + if not dest_uri.startswith(("s3://", "http://", "https://")): + os.makedirs(os.path.dirname(dest_uri), exist_ok=True) + + with fsspec.open(dest_uri, "wb", **s3_opts) as dst: + dst.write(source_content) async def check_file_exists(self, uri: str) -> bool: """Check if a file exists (S3 or local). @@ -295,20 +513,27 @@ async def check_file_exists(self, uri: str) -> bool: uri, ) except Exception as e: - logging.warning(f"Error checking file {uri}: {e}") - return False + logging.exception(f"Error checking file {uri}") + raise DataServiceException( + f"Failed to check file existence {uri}: {str(e)}" + ) from e def _sync_check_file_exists(self, uri: str) -> bool: """Synchronous helper to check if file exists (S3 or local).""" try: if uri.startswith("s3://"): - fs = fsspec.filesystem("s3", **self._s3_options) + # Use thread-safe credential access for S3 operations + with self._credential_lock: + s3_opts = self._s3_options.copy() + fs = fsspec.filesystem("s3", **s3_opts) return fs.exists(uri) else: # Local file return Path(uri).exists() except (FileNotFoundError, NoCredentialsError, ClientError) as e: - logging.debug(f"File {uri} does not exist or is not accessible: {e}") + logging.debug( + f"File {uri} does not exist or is not accessible: {e}" + ) return False except Exception as e: logging.warning(f"Unexpected error checking file {uri}: {e}") @@ -349,7 +574,9 @@ async def validate_files(self, uris: List[str]) -> List[str]: for uri in missing_uris: logging.info(f" Missing: {uri}") - logging.info(f"Validated files: {len(valid_uris)} exist, {len(missing_uris)} missing") + logging.info( + f"Validated files: {len(valid_uris)} exist, {len(missing_uris)} missing" + ) return valid_uris def find_metrics_files(self, base_path: str) -> List[str]: @@ -365,26 +592,46 @@ def find_metrics_files(self, base_path: str) -> List[str]: try: if base_path.startswith("s3://"): - fs = fsspec.filesystem("s3", **self._s3_options) - # Use glob to find all metrics.csv files recursively - pattern = f"{base_path.rstrip('/')}/*/*/metrics.csv" + # Use thread-safe credential access for S3 operations + with self._credential_lock: + s3_opts = self._s3_options.copy() + fs = fsspec.filesystem("s3", **s3_opts) + # Use glob to find all metrics.csv files recursively (including tagged versions) + pattern = f"{base_path.rstrip('/')}/**/*__metrics.csv" metrics_files = fs.glob(pattern) # Add s3:// prefix back since glob strips it metrics_files = [f"s3://{path}" for path in metrics_files] else: base_path_obj = Path(base_path) if base_path_obj.exists(): - metrics_files = [str(p) for p in base_path_obj.glob("*/*/metrics.csv")] + metrics_files = [ + str(p) for p in base_path_obj.glob("**/*__metrics.csv") + ] - logging.info(f"Found {len(metrics_files)} metrics.csv files in {base_path}") + logging.info( + f"Found {len(metrics_files)} metrics.csv files in {base_path}" + ) return metrics_files except Exception as e: - logging.error(f"Error finding metrics files in {base_path}: {e}") - return [] + logging.exception(f"Error finding metrics files in {base_path}") + raise DataServiceException( + f"Failed to find metrics files in {base_path}: {str(e)}" + ) from e def cleanup(self): - """Clean up resources, including HandIndexQuerier connection.""" + """Clean up resources, including HandIndexQuerier connection and refresh thread.""" + # Stop the credential refresh thread if it's running + if self._refresh_thread and self._refresh_thread.is_alive(): + logging.info("Stopping credential refresh thread...") + self._stop_refresh.set() + self._refresh_thread.join(timeout=5) + if self._refresh_thread.is_alive(): + logging.warning( + "Credential refresh thread did not stop gracefully" + ) + + # Clean up HandIndexQuerier if self.hand_querier: self.hand_querier.close() self.hand_querier = None From 5c6fe489f5ee9950d087170da8e1c9567004411b Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Wed, 10 Sep 2025 11:48:15 -0400 Subject: [PATCH 3/3] Fix line lengths --- src/data_service.py | 139 +++++++++------------------------- src/hand_index_querier.py | 6 +- src/stac_querier.py | 155 +++++++++----------------------------- 3 files changed, 75 insertions(+), 225 deletions(-) diff --git a/src/data_service.py b/src/data_service.py index 552a8ae..5d0867e 100644 --- a/src/data_service.py +++ b/src/data_service.py @@ -69,16 +69,12 @@ def __init__( if credentials.token: self._s3_options["token"] = credentials.token self._use_iam_credentials = True - logging.info( - "Using IAM instance profile credentials for S3 access" - ) + logging.info("Using IAM instance profile credentials for S3 access") # Start credential refresh thread for IAM credentials self._start_credential_refresh_thread() else: - logging.warning( - "No AWS credentials found in config or IAM instance profile" - ) + logging.warning("No AWS credentials found in config or IAM instance profile") except Exception as e: logging.warning(f"Failed to get IAM credentials: {e}") @@ -103,21 +99,15 @@ def __init__( # Initialize FlowfileCombiner self.flowfile_combiner = FlowfileCombiner( - output_dir=config.flow_scenarios.output_dir - if config.flow_scenarios - else "combined_flowfiles" + output_dir=config.flow_scenarios.output_dir if config.flow_scenarios else "combined_flowfiles" ) def _start_credential_refresh_thread(self): """Start a background thread to refresh IAM credentials every 4 hours.""" if self._use_iam_credentials: - self._refresh_thread = threading.Thread( - target=self._credential_refresh_loop, daemon=True - ) + self._refresh_thread = threading.Thread(target=self._credential_refresh_loop, daemon=True) self._refresh_thread.start() - logging.info( - "Started IAM credential refresh thread (4-hour interval)" - ) + logging.info("Started IAM credential refresh thread (4-hour interval)") def _credential_refresh_loop(self): """Background thread loop to refresh credentials every 4 hours.""" @@ -147,9 +137,7 @@ def _refresh_credentials(self): logging.info("Successfully refreshed IAM credentials") else: - logging.error( - "Failed to refresh credentials: No credentials available from boto3 session" - ) + logging.error("Failed to refresh credentials: No credentials available from boto3 session") except Exception as e: logging.error(f"Error refreshing IAM credentials: {e}") @@ -174,27 +162,19 @@ def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame: # Ensure CRS is EPSG:4326 if gdf.crs and gdf.crs.to_epsg() != 4326: - logging.info( - f"Converting polygon data from {gdf.crs} to EPSG:4326" - ) + logging.info(f"Converting polygon data from {gdf.crs} to EPSG:4326") gdf = gdf.to_crs("EPSG:4326") elif not gdf.crs: - logging.warning( - f"No CRS found in {file_path}, assuming EPSG:4326" - ) + logging.warning(f"No CRS found in {file_path}, assuming EPSG:4326") gdf.set_crs("EPSG:4326", inplace=True) - logging.info( - f"Loaded polygon GeoDataFrame with {len(gdf)} features from {file_path}" - ) + logging.info(f"Loaded polygon GeoDataFrame with {len(gdf)} features from {file_path}") return gdf except FileNotFoundError: raise FileNotFoundError(f"Polygon data file not found: {file_path}") except Exception as e: - raise ValueError( - f"Error loading GeoDataFrame from {file_path}: {e}" - ) + raise ValueError(f"Error loading GeoDataFrame from {file_path}: {e}") async def query_stac_for_flow_scenarios( self, @@ -212,18 +192,12 @@ async def query_stac_for_flow_scenarios( Dictionary with STAC query results and combined flowfiles """ # Generate polygon_id from index or use a default - polygon_id = ( - f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" - ) + polygon_id = f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" if not self.stac_querier or not self.config.stac: - raise RuntimeError( - "STAC configuration is required but not provided" - ) + raise RuntimeError("STAC configuration is required but not provided") - logging.info( - f"Querying STAC for flow scenarios for polygon: {polygon_id}" - ) + logging.info(f"Querying STAC for flow scenarios for polygon: {polygon_id}") try: # Run STAC query in executor to avoid blocking @@ -232,9 +206,7 @@ async def query_stac_for_flow_scenarios( if self.aoi_is_item: # Direct item query mode - extract aoi_name from tags if not tags or "aoi_name" not in tags: - raise ValueError( - "aoi_name tag is required when --aoi_is_item is used" - ) + raise ValueError("aoi_name tag is required when --aoi_is_item is used") aoi_name = tags["aoi_name"] logging.info(f"Querying STAC for specific item ID: {aoi_name}") @@ -260,9 +232,7 @@ async def query_stac_for_flow_scenarios( # Process flowfiles from STAC results combined_flowfiles = {} if self.flowfile_combiner: - logging.info( - f"Processing flowfiles for {len(stac_results)} collections" - ) + logging.info(f"Processing flowfiles for {len(stac_results)} collections") # Create temporary directory for this polygon's combined flowfiles temp_dir = f"/tmp/flow_scenarios_{polygon_id}" @@ -275,9 +245,7 @@ async def query_stac_for_flow_scenarios( temp_dir, ) - logging.info( - f"Combined flowfiles created for polygon {polygon_id}" - ) + logging.info(f"Combined flowfiles created for polygon {polygon_id}") return { "scenarios": stac_results, @@ -303,18 +271,14 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: await asyncio.sleep(0.01) # Generate polygon_id from index or use a default - polygon_id = ( - f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" - ) + polygon_id = f"polygon_{len(polygon_gdf)}" if len(polygon_gdf) > 0 else "unknown" logging.info(f"Querying catchments for polygon: {polygon_id}") if self.hand_querier: # Use real hand index query try: # Create temporary directory for parquet outputs - with tempfile.TemporaryDirectory( - prefix=f"catchments_{polygon_id}_" - ) as temp_dir: + with tempfile.TemporaryDirectory(prefix=f"catchments_{polygon_id}_") as temp_dir: # Run the query with temporary output directory loop = asyncio.get_running_loop() catchments_result = await loop.run_in_executor( @@ -325,9 +289,7 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: ) if not catchments_result: - logging.info( - f"No catchments found for polygon {polygon_id}" - ) + logging.info(f"No catchments found for polygon {polygon_id}") return {"catchments": {}, "hand_version": "real_query"} # Copy parquet files to a more permanent location if needed @@ -357,15 +319,11 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict: } except Exception as e: - logging.error( - f"Error in hand index query for polygon {polygon_id}: {e}" - ) + logging.error(f"Error in hand index query for polygon {polygon_id}: {e}") raise # No hand querier available - raise RuntimeError( - f"Hand index querier is required but not initialized for polygon {polygon_id}" - ) + raise RuntimeError(f"Hand index querier is required but not initialized for polygon {polygon_id}") async def copy_file_to_uri(self, source_path: str, dest_uri: str): """Copies a file (e.g., parquet) to a URI (local or S3) using fsspec. @@ -378,15 +336,12 @@ async def copy_file_to_uri(self, source_path: str, dest_uri: str): else source_path ) dest_uri_normalized = ( - os.path.abspath(dest_uri) - if not dest_uri.startswith(("s3://", "http://", "https://")) - else dest_uri + os.path.abspath(dest_uri) if not dest_uri.startswith(("s3://", "http://", "https://")) else dest_uri ) # Copy if source is local and destination is S3, or if both are local but paths differ should_copy = ( - not source_path.startswith(("s3://", "http://", "https://")) - and dest_uri.startswith("s3://") + not source_path.startswith(("s3://", "http://", "https://")) and dest_uri.startswith("s3://") ) or ( not source_path.startswith(("s3://", "http://", "https://")) and not dest_uri.startswith(("s3://", "http://", "https://")) @@ -411,14 +366,10 @@ async def copy_file_to_uri(self, source_path: str, dest_uri: str): return dest_uri except Exception as e: logging.exception(f"Failed to copy {source_path} to {dest_uri}") - raise DataServiceException( - f"Failed to copy file to {dest_uri}: {str(e)}" - ) from e + raise DataServiceException(f"Failed to copy file to {dest_uri}: {str(e)}") from e else: # If already on S3 or both local with same path, just return the source path - logging.debug( - f"No copy needed - using existing path: {source_path}" - ) + logging.debug(f"No copy needed - using existing path: {source_path}") return source_path def _sync_copy_file(self, source_path: str, dest_uri: str): @@ -454,19 +405,13 @@ async def append_file_to_uri(self, source_path: str, dest_uri: str): dest_uri, dest_exists, ) - logging.info( - f"Successfully {'appended to' if dest_exists else 'created'} {dest_uri}" - ) + logging.info(f"Successfully {'appended to' if dest_exists else 'created'} {dest_uri}") return dest_uri except Exception as e: logging.exception(f"Failed to append {source_path} to {dest_uri}") - raise DataServiceException( - f"Failed to append file to {dest_uri}: {str(e)}" - ) from e + raise DataServiceException(f"Failed to append file to {dest_uri}: {str(e)}") from e - def _sync_append_file( - self, source_path: str, dest_uri: str, dest_exists: bool - ): + def _sync_append_file(self, source_path: str, dest_uri: str, dest_exists: bool): """Synchronous helper for appending files using fsspec.""" # Read source content with open(source_path, "rb") as src: @@ -514,9 +459,7 @@ async def check_file_exists(self, uri: str) -> bool: ) except Exception as e: logging.exception(f"Error checking file {uri}") - raise DataServiceException( - f"Failed to check file existence {uri}: {str(e)}" - ) from e + raise DataServiceException(f"Failed to check file existence {uri}: {str(e)}") from e def _sync_check_file_exists(self, uri: str) -> bool: """Synchronous helper to check if file exists (S3 or local).""" @@ -531,9 +474,7 @@ def _sync_check_file_exists(self, uri: str) -> bool: # Local file return Path(uri).exists() except (FileNotFoundError, NoCredentialsError, ClientError) as e: - logging.debug( - f"File {uri} does not exist or is not accessible: {e}" - ) + logging.debug(f"File {uri} does not exist or is not accessible: {e}") return False except Exception as e: logging.warning(f"Unexpected error checking file {uri}: {e}") @@ -574,9 +515,7 @@ async def validate_files(self, uris: List[str]) -> List[str]: for uri in missing_uris: logging.info(f" Missing: {uri}") - logging.info( - f"Validated files: {len(valid_uris)} exist, {len(missing_uris)} missing" - ) + logging.info(f"Validated files: {len(valid_uris)} exist, {len(missing_uris)} missing") return valid_uris def find_metrics_files(self, base_path: str) -> List[str]: @@ -604,20 +543,14 @@ def find_metrics_files(self, base_path: str) -> List[str]: else: base_path_obj = Path(base_path) if base_path_obj.exists(): - metrics_files = [ - str(p) for p in base_path_obj.glob("**/*__metrics.csv") - ] + metrics_files = [str(p) for p in base_path_obj.glob("**/*__metrics.csv")] - logging.info( - f"Found {len(metrics_files)} metrics.csv files in {base_path}" - ) + logging.info(f"Found {len(metrics_files)} metrics.csv files in {base_path}") return metrics_files except Exception as e: logging.exception(f"Error finding metrics files in {base_path}") - raise DataServiceException( - f"Failed to find metrics files in {base_path}: {str(e)}" - ) from e + raise DataServiceException(f"Failed to find metrics files in {base_path}: {str(e)}") from e def cleanup(self): """Clean up resources, including HandIndexQuerier connection and refresh thread.""" @@ -627,9 +560,7 @@ def cleanup(self): self._stop_refresh.set() self._refresh_thread.join(timeout=5) if self._refresh_thread.is_alive(): - logging.warning( - "Credential refresh thread did not stop gracefully" - ) + logging.warning("Credential refresh thread did not stop gracefully") # Clean up HandIndexQuerier if self.hand_querier: diff --git a/src/hand_index_querier.py b/src/hand_index_querier.py index ac896f2..8fea0dc 100644 --- a/src/hand_index_querier.py +++ b/src/hand_index_querier.py @@ -47,7 +47,7 @@ def _ensure_connection(self): self.con.execute("LOAD httpfs;") self.con.execute("INSTALL aws;") self.con.execute("LOAD aws;") - + # Try using DuckDB's credential chain provider instead of manual credentials # This should work better with temporary credentials try: @@ -55,7 +55,7 @@ def _ensure_connection(self): CREATE SECRET s3_secret ( TYPE S3, PROVIDER CREDENTIAL_CHAIN, - REGION '{self.s3_region or 'us-east-1'}' + REGION '{self.s3_region or "us-east-1"}' ); """) logger.info("Using DuckDB credential chain for S3 access") @@ -70,7 +70,7 @@ def _ensure_connection(self): self.con.execute(f"SET s3_secret_access_key='{self.credentials.secret_key}';") if self.credentials.token: self.con.execute(f"SET s3_session_token='{self.credentials.token}';") - + self.con.execute("SET memory_limit = '7GB';") self.con.execute("SET temp_directory = '/tmp';") diff --git a/src/stac_querier.py b/src/stac_querier.py index d728c26..3509bfe 100644 --- a/src/stac_querier.py +++ b/src/stac_querier.py @@ -66,9 +66,7 @@ def compare_versions(v1: str, v2: str) -> int: class Interval: start: datetime end: datetime - assets: Dict[str, List[str]] = field( - default_factory=lambda: defaultdict(list) - ) + assets: Dict[str, List[str]] = field(default_factory=lambda: defaultdict(list)) class StacQuerier: @@ -108,20 +106,15 @@ def __init__( ] self.NWS_USGS_MAGS = ["action", "minor", "moderate", "major"] self.NWS_USGS_SPEC = [ - (re.compile(rf"^{mag}_(extent_raster|flow_file)$"), mag, r"\1") - for mag in self.NWS_USGS_MAGS - ] - self.RIPPLE_SPEC: List[Tuple[Pattern, str, str]] = [ - (re.compile(r"^(\d+yr)_extent$"), r"\1", "extents") + (re.compile(rf"^{mag}_(extent_raster|flow_file)$"), mag, r"\1") for mag in self.NWS_USGS_MAGS ] + self.RIPPLE_SPEC: List[Tuple[Pattern, str, str]] = [(re.compile(r"^(\d+yr)_extent$"), r"\1", "extents")] CollectionConfig = Dict[ str, Tuple[ Callable[[Any], str], # grouping fn → group_id - Dict[ - str, Callable[[str, Any], bool] - ], # asset_type → test(key, asset) + Dict[str, Callable[[str, Any], bool]], # asset_type → test(key, asset) ], ] @@ -129,28 +122,21 @@ def __init__( "gfm-collection": ( lambda item: str(item.properties.get("dfo_event_id", item.id)), { - "extents": lambda k, a: k.endswith( - "_Observed_Water_Extent" - ), + "extents": lambda k, a: k.endswith("_Observed_Water_Extent"), "flowfiles": lambda k, a: k.endswith("_flowfile"), }, ), "gfm-expanded-collection": ( self._group_gfm_expanded_initial, { - "extents": lambda k, a: k.endswith( - "_Observed_Water_Extent" - ), - "flowfiles": lambda k, a: k.endswith("_flowfile") - or k == "NWM_ANA_flowfile", + "extents": lambda k, a: k.endswith("_Observed_Water_Extent"), + "flowfiles": lambda k, a: k.endswith("_flowfile") or k == "NWM_ANA_flowfile", }, ), "hwm-collection": ( lambda item: item.id, { - "points": lambda k, a: k == "data" - and a.media_type - and "geopackage" in a.media_type, + "points": lambda k, a: k == "data" and a.media_type and "geopackage" in a.media_type, "flowfiles": lambda k, a: k.endswith("-flowfile"), }, ), @@ -166,9 +152,7 @@ def _ensure_client(self): logger.error(f"Could not open STAC API: {e}") raise - def _filter_items_by_geometry( - self, items: List[Any], query_polygon: Polygon - ) -> List[Any]: + def _filter_items_by_geometry(self, items: List[Any], query_polygon: Polygon) -> List[Any]: """ Filter STAC items using geometry relationships similar to HAND query filtering. @@ -214,18 +198,14 @@ def _filter_items_by_geometry( # Compute overlap percentage relative to item's area if not contains_query and not within_query: - intersection_area = item_geom.intersection( - query_polygon - ).area + intersection_area = item_geom.intersection(query_polygon).area item_area = item_geom.area if item_area > 0: overlap_pct = (intersection_area / item_area) * 100 else: overlap_pct = 0.0 else: - overlap_pct = ( - 100.0 # Contains or within means 100% relevant - ) + overlap_pct = 100.0 # Contains or within means 100% relevant # Apply selection criteria if contains_query: @@ -241,9 +221,7 @@ def _filter_items_by_geometry( stats["removed_count"] += 1 except Exception as e: - logger.warning( - f"Error processing item geometry for {getattr(item, 'id', 'unknown')}: {e}" - ) + logger.warning(f"Error processing item geometry for {getattr(item, 'id', 'unknown')}: {e}") # Include item if geometry processing fails filtered_items.append(item) @@ -287,9 +265,7 @@ def _group_gfm_expanded_initial(self, item: Any) -> str: fmt = "%Y-%m-%dT%H:%M:%SZ" return f"{start.strftime(fmt)}/{end.strftime(fmt)}" - def _format_results( - self, item_iter: Any - ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: + def _format_results(self, item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[str]]]]: results = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) seen = set() ripple_cache: Dict[str, List[str]] = {} @@ -304,9 +280,7 @@ def _format_results( try: source = item.properties.get("source", "") hucs = tuple(item.properties.get("hucs", [])) - flows2fim_version = item.properties.get( - "flows2fim_version", "" - ) + flows2fim_version = item.properties.get("flows2fim_version", "") if source and hucs and flows2fim_version: key = (source, hucs) @@ -314,10 +288,7 @@ def _format_results( # If we haven't seen this source+hucs combo, or this version is better if ( key not in ripple_best_items - or compare_versions( - flows2fim_version, ripple_best_items[key][1] - ) - > 0 + or compare_versions(flows2fim_version, ripple_best_items[key][1]) > 0 ): ripple_best_items[key] = ( item.id, @@ -330,9 +301,7 @@ def _format_results( ) continue except Exception as e: - logger.warning( - f"Error processing Ripple item {item.id}: {e}" - ) + logger.warning(f"Error processing Ripple item {item.id}: {e}") if item.id in seen: continue @@ -372,11 +341,7 @@ def _format_results( "nws-fim-collection", "usgs-fim-collection", ]: - specs = ( - self.BLE_SPEC - if coll == "ble-collection" - else self.NWS_USGS_SPEC - ) + specs = self.BLE_SPEC if coll == "ble-collection" else self.NWS_USGS_SPEC found = set() for k, a in item.assets.items(): @@ -399,10 +364,7 @@ def _format_results( "usgs-fim-collection", ]: gauge = item.properties.get("gauge") - if ( - gauge - and gauge_info[collection_key][gid] is None - ): + if gauge and gauge_info[collection_key][gid] is None: gauge_info[collection_key][gid] = gauge # Extract hucs for all asset-level collections (BLE, NWS, USGS) @@ -423,9 +385,7 @@ def _format_results( m = re.search(r"flows_(\d+)_yr_", ak) ri = f"{m.group(1)}yr" if m else None if ri and aa.media_type == "text/csv": - logger.info( - f"Caching Ripple flowfile for {ri}: {aa.href}" - ) + logger.info(f"Caching Ripple flowfile for {ri}: {aa.href}") ripple_cache.setdefault(ri, []).append(aa.href) except Exception as e: logger.warning(f"Ripple cache failed: {e}") @@ -465,19 +425,10 @@ def _format_results( # Extract "100yr" from "SourceA-100yr" base_gid = composite_gid.split("-", 1)[-1] if base_gid in ripple_cache: - logger.info( - f"Adding cached flowfiles for {composite_gid}" - ) + logger.info(f"Adding cached flowfiles for {composite_gid}") for href in ripple_cache[base_gid]: - if ( - href - not in results[collection_key][composite_gid][ - "flowfiles" - ] - ): - results[collection_key][composite_gid][ - "flowfiles" - ].append(href) + if href not in results[collection_key][composite_gid]["flowfiles"]: + results[collection_key][composite_gid]["flowfiles"].append(href) # 4) fallback else: @@ -501,11 +452,7 @@ def _format_results( for k, a in item.assets.items(): if not a.href: continue - if ( - "extent" in k - and a.media_type - and "tiff" in a.media_type - ): + if "extent" in k and a.media_type and "tiff" in a.media_type: bkt["extents"].append(a.href) elif "flow" in k and a.media_type and "csv" in a.media_type: bkt["flowfiles"].append(a.href) @@ -513,15 +460,9 @@ def _format_results( # Add item IDs, gauge information, and hucs to results for collection_key in results: for scenario_id in results[collection_key]: - results[collection_key][scenario_id]["stac_items"] = list( - item_ids[collection_key][scenario_id] - ) - results[collection_key][scenario_id]["gauge"] = gauge_info[ - collection_key - ][scenario_id] - results[collection_key][scenario_id]["hucs"] = hucs_info[ - collection_key - ][scenario_id] + results[collection_key][scenario_id]["stac_items"] = list(item_ids[collection_key][scenario_id]) + results[collection_key][scenario_id]["gauge"] = gauge_info[collection_key][scenario_id] + results[collection_key][scenario_id]["hucs"] = hucs_info[collection_key][scenario_id] logger.info(f"Finished formatting {len(seen)} items.") return results @@ -645,9 +586,7 @@ def query_stac_for_polygon( try: query_polygon = shape(roi_geojson) except Exception as e: - logger.warning( - f"Could not convert ROI to shapely geometry: {e}" - ) + logger.warning(f"Could not convert ROI to shapely geometry: {e}") # Build search kwargs search_kw = { @@ -656,13 +595,9 @@ def query_stac_for_polygon( } # KLUDGE: Hardcode 2022 date range for GFM collections. this will be removed after trial batches with gfm data are completed. - if self.collections is not None and any( - "gfm" in coll.lower() for coll in self.collections - ): + if self.collections is not None and any("gfm" in coll.lower() for coll in self.collections): search_kw["datetime"] = "2022-01-01T00:00:00Z/2022-12-31T23:59:59Z" - logger.info( - "KLUDGE: Overriding datetime filter to 2022 for GFM collections" - ) + logger.info("KLUDGE: Overriding datetime filter to 2022 for GFM collections") # Only include collections if specified, otherwise query all available if self.collections is not None: @@ -671,14 +606,10 @@ def query_stac_for_polygon( search_kw = {k: v for k, v in search_kw.items() if v is not None} try: - collections_msg = search_kw.get( - "collections", "all available collections" - ) + collections_msg = search_kw.get("collections", "all available collections") logger.info(f"Searching collections {collections_msg}") search = self.client.search(**search_kw) - items = list( - search.items() - ) # Convert to list for geometry filtering + items = list(search.items()) # Convert to list for geometry filtering if not items: logger.info("STAC query returned no items") @@ -686,9 +617,7 @@ def query_stac_for_polygon( # Apply geometry filtering if query polygon is provided if query_polygon and self.overlap_threshold_percent: - logger.info( - f"Applying geometry filtering with {self.overlap_threshold_percent}% overlap threshold" - ) + logger.info(f"Applying geometry filtering with {self.overlap_threshold_percent}% overlap threshold") items = self._filter_items_by_geometry(items, query_polygon) if not items: logger.info("No items remained after geometry filtering") @@ -697,20 +626,14 @@ def query_stac_for_polygon( grouped = self._format_results(items) if not grouped: - logger.info( - "No valid scenarios found after processing STAC items" - ) + logger.info("No valid scenarios found after processing STAC items") return {} if "gfm_expanded" in grouped: - grouped["gfm_expanded"] = self._merge_gfm_expanded( - grouped["gfm_expanded"] - ) + grouped["gfm_expanded"] = self._merge_gfm_expanded(grouped["gfm_expanded"]) # Truncate timestamp ranges to just the first timestamp to get rid of double directory writing issue truncated = {} - for scenario_key, scenario_data in grouped[ - "gfm_expanded" - ].items(): + for scenario_key, scenario_data in grouped["gfm_expanded"].items(): # If key contains a slash (timestamp range), take only the first part if "/" in scenario_key: first_timestamp = scenario_key.split("/")[0] @@ -728,9 +651,7 @@ def query_stac_for_polygon( logger.error(f"Unexpected error: {ex}") raise - def query_stac_by_item_id( - self, item_id: str - ) -> Dict[str, Dict[str, Dict[str, List[str]]]]: + def query_stac_by_item_id(self, item_id: str) -> Dict[str, Dict[str, Dict[str, List[str]]]]: """ Query STAC API for a specific item by ID. @@ -756,9 +677,7 @@ def query_stac_by_item_id( grouped = self._format_results([item]) if not grouped: - logger.info( - f"No valid scenarios found after processing item {item_id}" - ) + logger.info(f"No valid scenarios found after processing item {item_id}") return {} return dictify(grouped)