src/main.py: 100 changes (80 additions, 20 deletions)
@@ -14,7 +14,7 @@
import geopandas as gpd

import default_config
from data_service import DataService
from data_service import DataService, DataServiceException
from load_config import AppConfig, load_config
from metrics_aggregator import MetricsAggregator
from nomad_job_manager import NomadJobManager
@@ -42,13 +42,15 @@ def __init__(
polygon_gdf: gpd.GeoDataFrame,
tags: Dict[str, str],
outputs_path: str,
aoi_path: str,
log_db: Optional[PipelineLogDB] = None,
):
self.config = config
self.nomad = nomad
self.data_svc = data_svc
self.polygon_gdf = polygon_gdf
self.tags = tags
self.aoi_path = aoi_path
self.log_db = log_db
# Ensure the temp directory exists
temp_dir = "/tmp"
@@ -66,11 +68,11 @@ def __init__(
self.benchmark_scenarios: Dict[str, Dict[str, List[str]]] = {}
self.stac_results: Dict[str, Dict[str, Dict[str, List[str]]]] = {}

async def initialize(self) -> None:
"""Query for catchments and flow scenarios."""
async def initialize(self) -> Optional[Dict[str, Any]]:
"""Query for catchments and flow scenarios. Returns early exit info if no data found."""
# Query STAC for flow scenarios (always required)
logger.debug("Querying STAC for flow scenarios")
stac_data = await self.data_svc.query_stac_for_flow_scenarios(self.polygon_gdf)
stac_data = await self.data_svc.query_stac_for_flow_scenarios(self.polygon_gdf, self.tags)
self.flow_scenarios = stac_data.get("combined_flowfiles", {})

# Extract benchmark rasters from STAC scenarios
@@ -92,24 +94,41 @@ async def initialize(self) -> None:

if self.flow_scenarios:
logger.debug(f"Found {len(self.flow_scenarios)} collections")

if not self.flow_scenarios:
raise RuntimeError("No flow scenarios found")
else:
logger.warning("No flow scenarios found in STAC query results")
return {
"status": "no_data",
"message": "No flow scenarios found for the given polygon",
"catchment_count": 0,
"total_scenarios_attempted": 0,
"successful_scenarios": 0,
}

# Query hand index for catchments
logger.debug("Querying hand index for catchments")
data = await self.data_svc.query_for_catchments(self.polygon_gdf)
self.catchments = data.get("catchments", {})

if not self.catchments:
raise RuntimeError("No catchments found")
logger.warning("No catchments found in hand index query results")
return {
"status": "no_data",
"message": "No catchments found for the given polygon",
"catchment_count": 0,
"total_scenarios_attempted": 0,
"successful_scenarios": 0,
}

total_scenarios = sum(len(scenarios) for scenarios in self.flow_scenarios.values())
logger.info(f"Initialization complete: {len(self.catchments)} catchments, " f"{total_scenarios} flow scenarios")
logger.info(f"Initialization complete: {len(self.catchments)} catchments, {total_scenarios} flow scenarios")
return None

async def run(self) -> Dict[str, Any]:
"""Run the pipeline with stage-based parallelism."""
await self.initialize()
early_exit = await self.initialize()
if early_exit is not None:
logger.info(f"Pipeline exiting early: {early_exit['message']}")
return early_exit

# Build scenario results
results = []
@@ -144,13 +163,27 @@ async def run(self) -> Dict[str, Any]:
self.tags,
self.catchments,
)
mosaic_stage = MosaicStage(self.config, self.nomad, self.data_svc, self.path_factory, self.tags)
agreement_stage = AgreementStage(self.config, self.nomad, self.data_svc, self.path_factory, self.tags)
mosaic_stage = MosaicStage(
self.config,
self.nomad,
self.data_svc,
self.path_factory,
self.tags,
self.aoi_path,
)
agreement_stage = AgreementStage(
self.config,
self.nomad,
self.data_svc,
self.path_factory,
self.tags,
)

results = await inundation_stage.run(results)
results = await mosaic_stage.run(results)
results = await agreement_stage.run(results)
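The three stages are chained so that each consumes the list of scenario results produced by the previous one. A minimal sketch of the shared interface this chaining implies; the protocol name and signature below are illustrative assumptions inferred from the calls above, not classes defined in the repository:

from typing import Any, Dict, List, Protocol

class PipelineStage(Protocol):
    # Hypothetical interface inferred from the stage chaining; names are assumptions.
    async def run(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Consume the previous stage's scenario results and return the updated list."""
        ...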

# Save results to JSON file
if results:
try:
results_json_path = self.path_factory.results_json_path()
@@ -170,13 +203,15 @@
}
serializable_results.append(result_dict)

# Write to temporary file first, then copy to final location
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as temp_file:
json.dump(serializable_results, temp_file, indent=2)
temp_json_path = temp_file.name

await self.data_svc.copy_file_to_uri(temp_json_path, results_json_path)
logger.info(f"Results JSON written to {results_json_path}")

# Clean up temp file
if os.path.exists(temp_json_path):
os.unlink(temp_json_path)
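Writing the JSON to a local temporary file and then copying it keeps serialization local even when results_json_path points at a remote URI. A minimal sketch of what such a copy helper could look like, assuming an fsspec-backed implementation; the real copy_file_to_uri in data_service is not shown in this diff and may differ:

import shutil
import fsspec

async def copy_file_to_uri(local_path: str, dest_uri: str) -> None:
    # Hypothetical sketch only; not the actual DataService method.
    with open(local_path, "rb") as src, fsspec.open(dest_uri, "wb") as dest:
        shutil.copyfileobj(src, dest)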

@@ -190,8 +225,10 @@ async def run(self) -> Dict[str, Any]:
outputs_path=str(self.path_factory.base),
stac_results=self.stac_results,
data_service=self.data_svc,
flow_scenarios=self.flow_scenarios,
aoi_name=self.path_factory.aoi_name,
)
metrics_path = aggregator.save_results(self.path_factory.metrics_path())
metrics_path = aggregator.save_results(self.path_factory.results_path())
logger.info(f"Metrics aggregation completed: {metrics_path}")
except Exception as e:
logger.error(f"Metrics aggregation failed: {e}")
@@ -208,6 +245,9 @@ async def run(self) -> Dict[str, Any]:
}
logger.info(f"Pipeline SUCCESS: {len(successful_results)}/{total_attempted} scenarios completed")
return summary
except DataServiceException as e:
logger.error(f"Pipeline FAILED due to data service error: {str(e)}")
return {"status": "failed", "error": str(e), "message": f"Data service error: {str(e)}"}
except Exception as e:
logger.error(f"Pipeline FAILED: {str(e)}")
return {
@@ -261,8 +301,8 @@ def parsed_tags(tag_list):

if tags:
tags_str = ",".join(f"{k}={v}" for k, v in tags.items())
if len(tags_str) > 120:
raise argparse.ArgumentTypeError(f"Tags exceed 120 character limit ({len(tags_str)} chars): {tags_str}")
if len(tags_str) > 150:
raise argparse.ArgumentTypeError(f"Tags exceed 150 character limit ({len(tags_str)} chars): {tags_str}")

return tags
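For reference, a hedged usage sketch of parsed_tags based on the key=value format described in the --tags help text below; the handling of malformed entries is not shown in this hunk and is assumed:

# Assumed behavior, inferred from the help text and the length check above.
tags = parsed_tags(["batch=my_batch", "aoi=texas"])
# tags == {"batch": "my_batch", "aoi": "texas"}
# Inputs whose joined "k=v,..." string exceeds 150 characters raise ArgumentTypeError.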

@@ -282,7 +322,12 @@ def parsed_tags(tag_list):
default=None,
help="Comma-separated list of STAC collections to query (e.g., 'ble-collection,nwm-collection'). Defaults to all available sources.",
)
parser.add_argument("--hand_index_path", type=str, required=True, help="Path to HAND index data (required)")
parser.add_argument(
"--hand_index_path",
type=str,
required=True,
help="Path to HAND index data (required)",
)

parser.add_argument(
"--tags",
@@ -292,6 +337,12 @@ def parsed_tags(tag_list):
help="List of key=value pairs for tagging (e.g., --tags batch=my_batch aoi=texas) These tags are included in job_ids that the pipeline will dispatch.",
)

parser.add_argument(
"--aoi_is_item",
action="store_true",
help="If set, treat the aoi_name tag as a STAC item ID for direct querying instead of performing spatial queries",
)
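The new flag is forwarded to DataService further down (see the constructor call in _main). A minimal sketch of how the flag might switch between an item-ID lookup and a spatial query, assuming a pystac_client-based search; apart from aoi_is_item, every name here is an illustrative assumption rather than the repository's actual API:

from pystac_client import Client

def _search_benchmarks(stac_url, collections, polygon_gdf, aoi_name, aoi_is_item):
    # Hypothetical helper; DataService's real query methods are not shown in this diff.
    client = Client.open(stac_url)
    if aoi_is_item:
        # Treat the aoi_name tag as a STAC item ID and fetch it directly.
        search = client.search(ids=[aoi_name], collections=collections)
    else:
        # Default: spatial query with the AOI polygon geometry.
        geom = polygon_gdf.unary_union.__geo_interface__
        search = client.search(intersects=geom, collections=collections)
    return list(search.items())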

args = parser.parse_args()

if args.tags and args.tags != [""]:
@@ -344,8 +395,8 @@ async def _main():
nomad_addr=cfg.nomad.address,
namespace=cfg.nomad.namespace,
token=cfg.nomad.token,
session=session,
log_db=log_db,
max_concurrent_dispatch=cfg.defaults.nomad_max_concurrent_dispatch,
)
await nomad.start()

@@ -355,7 +406,7 @@ async def _main():
benchmark_collections = [col.strip() for col in args.benchmark_sources.split(",")]
logging.info(f"Using benchmark sources: {benchmark_collections}")

data_svc = DataService(cfg, args.hand_index_path, benchmark_collections)
data_svc = DataService(cfg, args.hand_index_path, benchmark_collections, args.aoi_is_item)

logging.info(f"Loading polygon from: {args.aoi}")
polygon_gdf = data_svc.load_polygon_gdf_from_file(args.aoi)
@@ -374,7 +425,16 @@

logging.info(f"Using HAND index path: {args.hand_index_path}")

pipeline = PolygonPipeline(cfg, nomad, data_svc, polygon_gdf, args.tags, outputs_path, log_db)
pipeline = PolygonPipeline(
cfg,
nomad,
data_svc,
polygon_gdf,
args.tags,
outputs_path,
args.aoi,
log_db,
)
logging.info(f"Started pipeline run for {args.aoi} with outputs to {outputs_path}")

try:
@@ -389,7 +449,7 @@ async def _main():
root_logger.removeHandler(file_handler)

final_log_path = pipeline.path_factory.logs_path()
await data_svc.copy_file_to_uri(temp_log_path, final_log_path)
await data_svc.append_file_to_uri(temp_log_path, final_log_path)
logging.info(f"Logs written to {final_log_path}")

print(json.dumps(summary, indent=2))