Skip to content
This repository was archived by the owner on Feb 26, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 209 additions & 31 deletions src/data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
import os
import shutil
import tempfile
import threading
import time
from contextlib import suppress
from datetime import datetime, timedelta
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Optional

import boto3
import fsspec
import geopandas as gpd
from botocore.exceptions import ClientError, NoCredentialsError
Expand All @@ -20,12 +24,33 @@
from stac_querier import StacQuerier


class DataServiceException(Exception):
    """Custom exception for data service operations.

    Raised by DataService file-transfer and lookup helpers so callers can
    distinguish service-level failures from generic exceptions.
    """
    # NOTE: the docstring alone is a valid class body; the former `pass`
    # statement was redundant and has been removed.


class DataService:
"""Service to query data sources and interact with S3 via fsspec."""

def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collections: Optional[List[str]] = None):
def __init__(
self,
config: AppConfig,
hand_index_path: str,
benchmark_collections: Optional[List[str]] = None,
aoi_is_item: bool = False,
):
self.config = config
# Configure S3 filesystem options from environment
self.aoi_is_item = aoi_is_item

# Initialize credential refresh tracking
self._credential_lock = threading.Lock()
self._last_credential_refresh = datetime.now()
self._refresh_thread = None
self._stop_refresh = threading.Event()
self._use_iam_credentials = False

# Configure S3 filesystem options - try config first, then IAM credentials
self._s3_options = {}
if config.aws.AWS_ACCESS_KEY_ID:
self._s3_options["key"] = config.aws.AWS_ACCESS_KEY_ID
Expand All @@ -34,6 +59,25 @@ def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collection
if config.aws.AWS_SESSION_TOKEN:
self._s3_options["token"] = config.aws.AWS_SESSION_TOKEN

# If no explicit credentials in config, try to get from IAM instance profile
if not self._s3_options:
try:
credentials = boto3.Session().get_credentials()
if credentials:
self._s3_options["key"] = credentials.access_key
self._s3_options["secret"] = credentials.secret_key
if credentials.token:
self._s3_options["token"] = credentials.token
self._use_iam_credentials = True
logging.info("Using IAM instance profile credentials for S3 access")

# Start credential refresh thread for IAM credentials
self._start_credential_refresh_thread()
else:
logging.warning("No AWS credentials found in config or IAM instance profile")
except Exception as e:
logging.warning(f"Failed to get IAM credentials: {e}")

# Initialize HandIndexQuerier with provided path
self.hand_querier = None
if hand_index_path:
Expand All @@ -50,13 +94,53 @@ def __init__(self, config: AppConfig, hand_index_path: str, benchmark_collection
collections=benchmark_collections,
overlap_threshold_percent=config.stac.overlap_threshold_percent,
datetime_filter=config.stac.datetime_filter,
aoi_is_item=aoi_is_item,
)

# Initialize FlowfileCombiner
self.flowfile_combiner = FlowfileCombiner(
output_dir=config.flow_scenarios.output_dir if config.flow_scenarios else "combined_flowfiles"
)

def _start_credential_refresh_thread(self):
"""Start a background thread to refresh IAM credentials every 4 hours."""
if self._use_iam_credentials:
self._refresh_thread = threading.Thread(target=self._credential_refresh_loop, daemon=True)
self._refresh_thread.start()
logging.info("Started IAM credential refresh thread (4-hour interval)")

def _credential_refresh_loop(self):
"""Background thread loop to refresh credentials every 4 hours."""
refresh_interval = 4 * 60 * 60 # sec

while not self._stop_refresh.is_set():
# Wait for 4 hours or until stop event is set
if self._stop_refresh.wait(refresh_interval):
break

# Refresh credentials
self._refresh_credentials()

def _refresh_credentials(self):
"""Refresh IAM credentials from boto3 session."""
try:
logging.info("Refreshing IAM credentials...")
credentials = boto3.Session().get_credentials()

if credentials:
with self._credential_lock:
self._s3_options["key"] = credentials.access_key
self._s3_options["secret"] = credentials.secret_key
if credentials.token:
self._s3_options["token"] = credentials.token
self._last_credential_refresh = datetime.now()

logging.info("Successfully refreshed IAM credentials")
else:
logging.error("Failed to refresh credentials: No credentials available from boto3 session")
except Exception as e:
logging.error(f"Error refreshing IAM credentials: {e}")

def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame:
"""
Args:
Expand All @@ -71,8 +155,7 @@ def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame:
"""
try:
# geopandas can read from both local and S3 paths
# For S3 paths, it will use storage_options; for local paths, it ignores them
gdf = gpd.read_file(file_path, storage_options=self._s3_options if file_path.startswith("s3://") else None)
gdf = gpd.read_file(file_path)

if len(gdf) == 0:
raise ValueError(f"Empty GeoDataFrame in file: {file_path}")
Expand All @@ -93,12 +176,17 @@ def load_polygon_gdf_from_file(self, file_path: str) -> gpd.GeoDataFrame:
except Exception as e:
raise ValueError(f"Error loading GeoDataFrame from {file_path}: {e}")

async def query_stac_for_flow_scenarios(self, polygon_gdf: gpd.GeoDataFrame) -> Dict:
async def query_stac_for_flow_scenarios(
self,
polygon_gdf: gpd.GeoDataFrame,
tags: Optional[Dict[str, str]] = None,
) -> Dict:
"""
Query STAC API for flow scenarios based on polygon.
Query STAC API for flow scenarios based on polygon or item ID.

Args:
polygon_gdf: GeoDataFrame containing polygon geometry
tags: Optional tags dictionary (required when aoi_is_item is True)

Returns:
Dictionary with STAC query results and combined flowfiles
Expand All @@ -114,12 +202,28 @@ async def query_stac_for_flow_scenarios(self, polygon_gdf: gpd.GeoDataFrame) ->
try:
# Run STAC query in executor to avoid blocking
loop = asyncio.get_running_loop()
stac_results = await loop.run_in_executor(
None,
self.stac_querier.query_stac_for_polygon,
polygon_gdf,
None, # roi_geojson not needed since we have polygon_gdf
)

if self.aoi_is_item:
# Direct item query mode - extract aoi_name from tags
if not tags or "aoi_name" not in tags:
raise ValueError("aoi_name tag is required when --aoi_is_item is used")

aoi_name = tags["aoi_name"]
logging.info(f"Querying STAC for specific item ID: {aoi_name}")

stac_results = await loop.run_in_executor(
None,
self.stac_querier.query_stac_by_item_id,
aoi_name,
)
else:
# Standard spatial query mode
stac_results = await loop.run_in_executor(
None,
self.stac_querier.query_stac_for_polygon,
polygon_gdf,
None, # roi_geojson not needed since we have polygon_gdf
)

if not stac_results:
logging.info(f"No STAC results found for polygon {polygon_id}")
Expand Down Expand Up @@ -209,7 +313,10 @@ async def query_for_catchments(self, polygon_gdf: gpd.GeoDataFrame) -> Dict:
logging.info(
f"Data service returning {len(catchments)} catchments for polygon {polygon_id} (real query)."
)
return {"catchments": catchments, "hand_version": "real_query"}
return {
"catchments": catchments,
"hand_version": "real_query",
}

except Exception as e:
logging.error(f"Error in hand index query for polygon {polygon_id}: {e}")
Expand Down Expand Up @@ -259,24 +366,80 @@ async def copy_file_to_uri(self, source_path: str, dest_uri: str):
return dest_uri
except Exception as e:
logging.exception(f"Failed to copy {source_path} to {dest_uri}")
raise ConnectionError(f"Failed to copy file to {dest_uri}") from e
raise DataServiceException(f"Failed to copy file to {dest_uri}: {str(e)}") from e
else:
# If already on S3 or both local with same path, just return the source path
logging.debug(f"No copy needed - using existing path: {source_path}")
return source_path

def _sync_copy_file(self, source_path: str, dest_uri: str):
    """Synchronously copy a local file to ``dest_uri`` using fsspec.

    Streams the data with ``shutil.copyfileobj`` instead of reading the
    whole source into memory, so arbitrarily large files copy with a
    bounded memory footprint.

    Args:
        source_path: Path of the local file to copy.
        dest_uri: Destination URI ("s3://..." or a local path).
    """
    open_kwargs = {}
    if dest_uri.startswith("s3://"):
        # Snapshot credentials under the lock so a concurrent refresh
        # cannot hand us a half-updated key/secret/token triple.
        with self._credential_lock:
            open_kwargs = self._s3_options.copy()
    # For non-S3 destinations fsspec.open falls back to the local filesystem
    # (or default credentials for other remote schemes), as before.
    with open(source_path, "rb") as src, fsspec.open(dest_uri, "wb", **open_kwargs) as dst:
        shutil.copyfileobj(src, dst)

async def append_file_to_uri(self, source_path: str, dest_uri: str):
    """Append a local file's bytes to a URI (local or S3) using fsspec.

    If the destination does not exist it is created with the source
    content; otherwise the source content is appended after the existing
    bytes.

    Args:
        source_path: Local file whose contents are appended.
        dest_uri: Destination URI ("s3://..." or a local path).

    Returns:
        The destination URI on success.

    Raises:
        DataServiceException: If the existence check or the write fails.
    """
    logging.debug(f"Appending file from {source_path} to {dest_uri}")
    try:
        dest_exists = await self.check_file_exists(dest_uri)
        # The blocking fsspec I/O runs in the default executor so the
        # event loop stays responsive.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._sync_append_file, source_path, dest_uri, dest_exists)
        action = "appended to" if dest_exists else "created"
        logging.info(f"Successfully {action} {dest_uri}")
        return dest_uri
    except Exception as exc:
        logging.exception(f"Failed to append {source_path} to {dest_uri}")
        raise DataServiceException(f"Failed to append file to {dest_uri}: {str(exc)}") from exc

def _sync_append_file(self, source_path: str, dest_uri: str, dest_exists: bool):
    """Synchronous helper for appending files using fsspec.

    S3 objects are immutable, so "append" is implemented as
    read-existing + rewrite (existing bytes followed by source bytes).

    Args:
        source_path: Local file whose contents are appended.
        dest_uri: Destination URI ("s3://..." or a local path).
        dest_exists: Whether the destination already exists (checked by
            the async caller before dispatching here).
    """
    # Read source content up front; it is needed in both branches.
    with open(source_path, "rb") as src:
        source_content = src.read()

    # Use thread-safe credential access for S3 operations
    if dest_uri.startswith("s3://"):
        with self._credential_lock:
            s3_opts = self._s3_options.copy()
    else:
        s3_opts = {}

    if dest_exists:
        # Destination exists: read existing content first, then rewrite
        # the object with existing + new bytes.
        with fsspec.open(dest_uri, "rb", **s3_opts) as dst:
            existing_content = dst.read()
        with fsspec.open(dest_uri, "wb", **s3_opts) as dst:
            dst.write(existing_content)
            dst.write(source_content)
    else:
        # Destination doesn't exist: just write the source content.
        if not dest_uri.startswith(("s3://", "http://", "https://")):
            parent_dir = os.path.dirname(dest_uri)
            # Bug fix: dirname("") is returned for bare relative filenames
            # and os.makedirs("") raises FileNotFoundError — only create
            # the parent when there actually is one.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
        with fsspec.open(dest_uri, "wb", **s3_opts) as dst:
            dst.write(source_content)

async def check_file_exists(self, uri: str) -> bool:
"""Check if a file exists (S3 or local).
Expand All @@ -295,14 +458,17 @@ async def check_file_exists(self, uri: str) -> bool:
uri,
)
except Exception as e:
logging.warning(f"Error checking file {uri}: {e}")
return False
logging.exception(f"Error checking file {uri}")
raise DataServiceException(f"Failed to check file existence {uri}: {str(e)}") from e

def _sync_check_file_exists(self, uri: str) -> bool:
"""Synchronous helper to check if file exists (S3 or local)."""
try:
if uri.startswith("s3://"):
fs = fsspec.filesystem("s3", **self._s3_options)
# Use thread-safe credential access for S3 operations
with self._credential_lock:
s3_opts = self._s3_options.copy()
fs = fsspec.filesystem("s3", **s3_opts)
return fs.exists(uri)
else:
# Local file
Expand Down Expand Up @@ -365,26 +531,38 @@ def find_metrics_files(self, base_path: str) -> List[str]:

try:
if base_path.startswith("s3://"):
fs = fsspec.filesystem("s3", **self._s3_options)
# Use glob to find all metrics.csv files recursively
pattern = f"{base_path.rstrip('/')}/*/*/metrics.csv"
# Use thread-safe credential access for S3 operations
with self._credential_lock:
s3_opts = self._s3_options.copy()
fs = fsspec.filesystem("s3", **s3_opts)
# Use glob to find all metrics.csv files recursively (including tagged versions)
pattern = f"{base_path.rstrip('/')}/**/*__metrics.csv"
metrics_files = fs.glob(pattern)
# Add s3:// prefix back since glob strips it
metrics_files = [f"s3://{path}" for path in metrics_files]
else:
base_path_obj = Path(base_path)
if base_path_obj.exists():
metrics_files = [str(p) for p in base_path_obj.glob("*/*/metrics.csv")]
metrics_files = [str(p) for p in base_path_obj.glob("**/*__metrics.csv")]

logging.info(f"Found {len(metrics_files)} metrics.csv files in {base_path}")
return metrics_files

except Exception as e:
logging.error(f"Error finding metrics files in {base_path}: {e}")
return []
logging.exception(f"Error finding metrics files in {base_path}")
raise DataServiceException(f"Failed to find metrics files in {base_path}: {str(e)}") from e

def cleanup(self):
"""Clean up resources, including HandIndexQuerier connection."""
"""Clean up resources, including HandIndexQuerier connection and refresh thread."""
# Stop the credential refresh thread if it's running
if self._refresh_thread and self._refresh_thread.is_alive():
logging.info("Stopping credential refresh thread...")
self._stop_refresh.set()
self._refresh_thread.join(timeout=5)
if self._refresh_thread.is_alive():
logging.warning("Credential refresh thread did not stop gracefully")

# Clean up HandIndexQuerier
if self.hand_querier:
self.hand_querier.close()
self.hand_querier = None
Loading