diff --git a/src/default_config.py b/src/default_config.py index 0df7a5d..abb4b1e 100644 --- a/src/default_config.py +++ b/src/default_config.py @@ -40,3 +40,4 @@ # General defaults FIM_TYPE = "extent" HTTP_CONNECTION_LIMIT = 100 +NOMAD_MAX_CONCURRENT_DISPATCH = 5 # Keep this on the smaller side so that there aren't too many API calls being generated by many pipeline jobs running on a single client. diff --git a/src/load_config.py b/src/load_config.py index b50089e..7c3bc38 100644 --- a/src/load_config.py +++ b/src/load_config.py @@ -105,6 +105,14 @@ class Defaults(BaseModel): gt=0, description="Max concurrent outgoing HTTP connections", ) + nomad_max_concurrent_dispatch: int = Field( + default_factory=lambda: int( + os.getenv("NOMAD_MAX_CONCURRENT_DISPATCH", str(default_config.NOMAD_MAX_CONCURRENT_DISPATCH)) + ), + gt=0, + le=10, + description="Max concurrent Nomad API calls (must not exceed urllib3's default pool size of 10)", + ) class AppConfig(BaseModel): diff --git a/src/nomad_job_manager.py b/src/nomad_job_manager.py index afbd8ee..7d39335 100644 --- a/src/nomad_job_manager.py +++ b/src/nomad_job_manager.py @@ -1,14 +1,14 @@ import asyncio import json import logging +import random import time from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple from urllib.parse import urlparse -import aiohttp import nomad from tenacity import ( retry, @@ -21,36 +21,56 @@ class JobStatus(Enum): + """Internal job status enumeration aligned with Nomad ClientStatus values.""" + DISPATCHED = "dispatched" - ALLOCATED = "allocated" + PENDING = "pending" RUNNING = "running" SUCCEEDED = "succeeded" FAILED = "failed" LOST = "lost" + STOPPED = "stopped" + CANCELLED = "cancelled" UNKNOWN = "unknown" # Mapping from Nomad statuses to our internal statuses NOMAD_STATUS_MAP = { - "pending": JobStatus.ALLOCATED, + 
"pending": JobStatus.PENDING, "running": JobStatus.RUNNING, "complete": JobStatus.SUCCEEDED, "failed": JobStatus.FAILED, "lost": JobStatus.LOST, - "dead": JobStatus.FAILED, + "dead": JobStatus.STOPPED, + "stopped": JobStatus.STOPPED, + "cancelled": JobStatus.CANCELLED, } class NomadError(Exception): + """Base exception for the NomadJobManager.""" + pass class JobNotFoundError(NomadError): + """Raised when a specific job cannot be found in Nomad.""" + pass class JobDispatchError(NomadError): - pass + """ + Raised when a job fails to dispatch or fails during execution. + Contains enriched details about the failure. + """ + + def __init__(self, message: str, **kwargs): + super().__init__(message) + self.job_id: Optional[str] = kwargs.get("job_id") + self.allocation_id: Optional[str] = kwargs.get("allocation_id") + self.exit_code: Optional[int] = kwargs.get("exit_code") + self.original_error: Optional[Exception] = kwargs.get("original_error") @dataclass @@ -69,65 +89,88 @@ class JobTracker: timestamp: Optional[datetime] = None +class _NomadAPIClient: + """A wrapper for Nomad API calls that provides retries and concurrency limiting.""" + + def __init__(self, client: nomad.Nomad, max_concurrency: int): + self.client = client + self.semaphore = asyncio.Semaphore(max_concurrency) + self._nomad_retry = retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=30), + retry=retry_if_exception_type( + ( + nomad.api.exceptions.BaseNomadException, + nomad.api.exceptions.URLNotFoundNomadException, + ) + ), + ) + + async def _call(self, func: Callable, *args, **kwargs) -> Any: + async with self.semaphore: + return await asyncio.to_thread(func, *args, **kwargs) + + async def dispatch_job(self, job_name: str, prefix: str, meta: Optional[Dict[str, str]]) -> str: + wrapped_call = self._nomad_retry(self._call) + result = await wrapped_call( + self.client.job.dispatch_job, + id_=job_name, + payload=None, + meta=meta, + id_prefix_template=prefix.replace(":", 
"-"), + ) + return result["DispatchedJobID"] + + async def get_job(self, job_id: str) -> Dict[str, Any]: + wrapped_call = self._nomad_retry(self._call) + return await wrapped_call(self.client.job.get_job, job_id) + + async def get_allocations(self, job_id: str) -> List[Dict[str, Any]]: + wrapped_call = self._nomad_retry(self._call) + return await wrapped_call(self.client.job.get_allocations, job_id) + + class NomadJobManager: + """ + Manages the lifecycle of parameterized Nomad jobs using a polling-based + strategy with exponential backoff. + """ + def __init__( self, nomad_addr: str, namespace: str = "default", token: Optional[str] = None, - session: Optional[aiohttp.ClientSession] = None, log_db: Optional[Any] = None, + max_concurrent_dispatch: int = 10, ): self.nomad_addr = nomad_addr self.namespace = namespace self.token = token - self.session = session self.log_db = log_db - parsed = urlparse(str(nomad_addr)) - self.client = nomad.Nomad( + parsed = urlparse(nomad_addr) + nomad_client = nomad.Nomad( host=parsed.hostname, port=parsed.port, verify=False, - token=token or None, - namespace=namespace or None, + token=token, + namespace=namespace, ) - - # Track active jobs + self.api = _NomadAPIClient(nomad_client, max_concurrent_dispatch) self._active_jobs: Dict[str, JobTracker] = {} - self._monitoring_task: Optional[asyncio.Task] = None self._shutdown_event = asyncio.Event() - self._event_index = 0 - async def start(self): - if not self._monitoring_task: - self._monitoring_task = asyncio.create_task(self._monitor_events()) - logger.info("Started Nomad job manager") + """Starts the job manager (no background tasks in this version).""" + logger.info("Starting NomadJobManager (polling-based).") + # No background tasks to start in a pure polling model async def stop(self): + """Stops the job manager.""" + logger.info("Stopping NomadJobManager.") self._shutdown_event.set() - if self._monitoring_task: - await self._monitoring_task - self._monitoring_task = None - 
logger.info("Stopped Nomad job manager") - - # Retry decorator for Nomad API calls - _nomad_retry = retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=2, max=30), - retry=retry_if_exception_type( - ( - nomad.api.exceptions.BaseNomadException, - nomad.api.exceptions.URLNotFoundNomadException, - ) - ), - ) - - @_nomad_retry - async def _nomad_call(self, func, *args, **kwargs): - """Execute a Nomad API call with retry logic.""" - return await asyncio.to_thread(func, *args, **kwargs) + # No background tasks to stop async def dispatch_and_track( self, @@ -136,222 +179,224 @@ async def dispatch_and_track( meta: Optional[Dict[str, str]] = None, ) -> Tuple[str, int]: """ - Dispatch a job and track it to completion. - - Returns: - Tuple of (job_id, exit_code) + Dispatches a job and polls its status until completion using + an exponential backoff strategy. """ - job_id = await self._dispatch_job(job_name, prefix, meta) - - tracker = JobTracker( - job_id=job_id, - task_name=job_name, - stage=meta.get("stage") if meta else None, - ) + try: + job_id = await self.api.dispatch_job(job_name, prefix, meta) + except Exception as e: + # Extract more detailed error information + error_msg = str(e) + + # Special handling for RetryError that wraps BaseNomadException + if "RetryError" in str(type(e)) and "BaseNomadException" in error_msg: + # Try to extract the cause from RetryError + if hasattr(e, "__cause__") and e.__cause__: + error_msg = f"{error_msg} - Cause: {str(e.__cause__)}" + elif hasattr(e, "last_attempt") and hasattr(e.last_attempt, "exception"): + # For tenacity RetryError + last_exception = e.last_attempt.exception() + if last_exception: + error_msg = f"RetryError after 3 attempts - Last error: {str(last_exception)}" + + logger.error(f"Failed to dispatch job {job_name}: {error_msg}") + raise JobDispatchError(f"Failed to dispatch job {job_name}: {error_msg}") from e + + tracker = JobTracker(job_id=job_id, task_name=job_name, stage=(meta or 
{}).get("stage")) self._active_jobs[job_id] = tracker - - # Update database if available - await self._update_job_status(tracker) + await self._update_db_status(tracker) try: - await self._wait_for_allocation(tracker) + # Polling loop with exponential backoff + poll_delay = 1.0 # Start with a 1-second delay + backoff_factor = 1.5 + jitter_factor = 0.25 # Apply up to 25% jitter + max_poll_delay = 30.0 # Max delay of 30s - await self._wait_for_completion(tracker) + while not tracker.completion_event.is_set(): + jitter = poll_delay * jitter_factor * random.uniform(-1, 1) + await asyncio.sleep(max(0, poll_delay + jitter)) - if tracker.error: - raise tracker.error + await self._poll_job_and_update_tracker(tracker) - return job_id, tracker.exit_code or 0 + # Increase delay for the next poll + poll_delay = min(poll_delay * backoff_factor, max_poll_delay) + if tracker.error: + raise JobDispatchError( + f"Job {job_id} failed with exit code {tracker.exit_code}: {tracker.error}", + job_id=job_id, + allocation_id=tracker.allocation_id, + exit_code=tracker.exit_code, + original_error=tracker.error, + ) + return job_id, tracker.exit_code or 0 finally: self._active_jobs.pop(job_id, None) - async def _dispatch_job( - self, - job_name: str, - prefix: str, - meta: Optional[Dict[str, str]] = None, - ) -> str: - """Dispatch a parameterized job.""" - payload = {"Meta": meta} if meta else {} - + async def _poll_job_and_update_tracker(self, tracker: JobTracker): + """Polls a single job and updates its tracker if the status has changed.""" try: - logger.debug(f"Dispatching job {job_name} with meta: {meta}") - result = await self._nomad_call( - self.client.job.dispatch_job, - id_=job_name, - payload=None, - meta=meta, - id_prefix_template=prefix, - ) - job_id = result["DispatchedJobID"] - logger.info(f"Dispatched job {job_id} from {job_name}") - return job_id - - except Exception as e: - logger.error(f"Failed to dispatch job {job_name}: {e}") - logger.error(f"Payload was: {payload}") - 
raise JobDispatchError(f"Failed to dispatch job {job_name}: {e}") + # First, check if the job itself still exists. This avoids polling for allocations of a job that has been purged. + await self.api.get_job(tracker.job_id) + + allocations = await self.api.get_allocations(tracker.job_id) + + time_since_dispatch = time.time() - tracker.dispatch_time + if ( + not allocations and time_since_dispatch > 28800 + ): # allow 8 hrs to allocate for jobs that take a while to go from pending to dispatched + logger.warning(f"No allocations found for job {tracker.job_id} after 8 hours, marking as LOST.") + tracker.status = JobStatus.LOST + tracker.error = Exception("Job lost: No allocations found after timeout.") + tracker.completion_event.set() + await self._update_db_status(tracker) + return - async def _wait_for_allocation(self, tracker: JobTracker): - while True: - if tracker.allocation_id: + if not allocations: + logger.debug(f"Polling {tracker.job_id}: No allocations yet.") return - if tracker.status in (JobStatus.FAILED, JobStatus.LOST): - raise JobDispatchError(f"Job {tracker.job_id} failed during allocation") - await asyncio.sleep(1) - - async def _wait_for_completion(self, tracker: JobTracker): - await tracker.completion_event.wait() - - async def _monitor_events(self): - """Monitor Nomad event stream for job updates.""" - retry_count = 0 - max_retries = 5 - - while not self._shutdown_event.is_set(): - try: - await self._process_event_stream() - retry_count = 0 - except Exception as e: - retry_count += 1 - if retry_count >= max_retries: - logger.error(f"Event stream failed {max_retries} times, stopping monitor") - break - - wait_time = min(2**retry_count, 60) - logger.warning(f"Event stream error: {e}, retrying in {wait_time}s") - await asyncio.sleep(wait_time) - - async def _process_event_stream(self): - if not self.session: - self.session = aiohttp.ClientSession() - - url = f"{self.nomad_addr}/v1/event/stream" - params = { - "index": self._event_index, - 
"namespace": self.namespace, - } - - headers = {} - if self.token: - headers["X-Nomad-Token"] = self.token - - async with self.session.get(url, params=params, headers=headers, timeout=None) as response: - response.raise_for_status() - - buffer = b"" - async for chunk in response.content.iter_chunked(8 * 1024): - if self._shutdown_event.is_set(): - break - - buffer += chunk - while b"\n" in buffer: - line_bytes, buffer = buffer.split(b"\n", 1) - line = line_bytes.decode("utf-8").strip() - - if not line or line == "{}": - continue - - try: - data = json.loads(line) - if "Index" in data: - self._event_index = data["Index"] - - events = data.get("Events", []) - for event in events: - await self._handle_event(event) - - except json.JSONDecodeError: - logger.debug(f"Failed to parse event line: {line}") - except Exception as e: - logger.error(f"Error handling event: {e}") - - async def _handle_event(self, event: Dict[str, Any]): - """Handle a single Nomad event.""" - topic = event.get("Topic") - if topic != "Allocation": - return - payload = event.get("Payload", {}).get("Allocation", {}) - job_id = payload.get("JobID") + latest_alloc = max(allocations, key=lambda a: a.get("CreateTime", 0)) + await self._update_tracker_from_allocation(tracker, latest_alloc) - if not job_id or job_id not in self._active_jobs: - return + except ( + nomad.api.exceptions.URLNotFoundNomadException, + nomad.api.exceptions.BaseNomadException, + ) as e: + # Extract more detailed error information + error_msg = str(e) + if hasattr(e, "__cause__") and e.__cause__: + error_msg = f"{error_msg} - Cause: {str(e.__cause__)}" - tracker = self._active_jobs[job_id] - allocation_id = payload.get("ID") - client_status = payload.get("ClientStatus", "").lower() - - # Update allocation ID - if not tracker.allocation_id and allocation_id: - tracker.allocation_id = allocation_id - logger.info(f"Job {job_id} allocated: {allocation_id}") + logger.error( + f"Polling failed for job {tracker.job_id}, it may have been 
purged. Marking as LOST. Error: {error_msg}" + ) + tracker.status = JobStatus.LOST + tracker.error = e + tracker.completion_event.set() + await self._update_db_status(tracker) + except Exception as e: + # Check if this is a RetryError wrapping a Nomad exception + if "RetryError" in str(type(e)): + # Try to extract the wrapped exception + wrapped_exception = None + if hasattr(e, "last_attempt") and hasattr(e.last_attempt, "exception"): + wrapped_exception = e.last_attempt.exception() + elif hasattr(e, "__cause__") and e.__cause__: + wrapped_exception = e.__cause__ + + # Check if the wrapped exception is a URLNotFoundNomadException + if wrapped_exception and "URLNotFoundNomadException" in str(type(wrapped_exception)): + logger.error( + f"Polling failed for job {tracker.job_id} after retries, job may have been purged. Marking as LOST. Error: {e}" + ) + tracker.status = JobStatus.LOST + tracker.error = e + tracker.completion_event.set() + await self._update_db_status(tracker) + return + + logger.warning(f"Unexpected error while polling for job {tracker.job_id}: {e}") + + async def _update_tracker_from_allocation(self, tracker: JobTracker, allocation: Dict[str, Any]): + """Updates a tracker's state based on a Nomad allocation object.""" + if tracker.completion_event.is_set(): + return - # Map status old_status = tracker.status + client_status = allocation.get("ClientStatus", "").lower() new_status = NOMAD_STATUS_MAP.get(client_status, JobStatus.UNKNOWN) - if new_status != JobStatus.UNKNOWN and new_status != old_status: - tracker.status = new_status - tracker.timestamp = datetime.now(timezone.utc) - - # Extract exit code for completed jobs - if client_status == "complete": - task_states = payload.get("TaskStates", {}) - for task_state in task_states.values(): - if task_state.get("FinishedAt"): - tracker.exit_code = task_state.get("ExitCode", 0) - break - - # Update database - await self._update_job_status(tracker) + if new_status == old_status: + return # No change + + 
logger.info(f"Job {tracker.job_id} status change: {old_status.name} -> {new_status.name}") + tracker.status = new_status + tracker.timestamp = datetime.now(timezone.utc) + if not tracker.allocation_id: + tracker.allocation_id = allocation.get("ID") + + is_terminal = new_status in ( + JobStatus.SUCCEEDED, + JobStatus.FAILED, + JobStatus.LOST, + JobStatus.STOPPED, + JobStatus.CANCELLED, + ) - if new_status in (JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.LOST): - tracker.completion_event.set() - logger.info(f"Job {job_id} completed with status {new_status}") + if is_terminal: + self._extract_final_state(tracker, allocation) + logger.info( + f"Job {tracker.job_id} completed with status JobStatus.{new_status.name} " + f"(exit_code: {tracker.exit_code})" + ) + tracker.completion_event.set() + + await self._update_db_status(tracker) + + def _extract_final_state(self, tracker: JobTracker, allocation: Dict[str, Any]): + """Extracts the exit code and error message from a terminal allocation.""" + task_states = allocation.get("TaskStates", {}) + for task_name, task_state in task_states.items(): + if task_state.get("FinishedAt"): + if tracker.status == JobStatus.FAILED: + tracker.exit_code = task_state.get("ExitCode", 1) + failure_details = [] + # Search for a meaningful event message + for event in reversed(task_state.get("Events", [])): + if event.get("Type") in [ + "Terminated", + "Task Failed", + "Driver Failure", + "Killing", + ]: + reason = event.get("DisplayMessage", "No reason provided.") + failure_details.append(f"Task '{task_name}' failed: {reason}") + break + if not failure_details: + failure_details.append(f"Task '{task_name}' failed with exit code {tracker.exit_code}.") + tracker.error = Exception("; ".join(failure_details)) + elif tracker.status == JobStatus.CANCELLED: + tracker.exit_code = task_state.get("ExitCode", 0) + # Search for cancellation reason + cancellation_reason = "Job was cancelled" + for event in reversed(task_state.get("Events", [])): + if 
event.get("Type") in ["Killing", "Terminated"]: + reason = event.get("DisplayMessage", "") + if reason: + cancellation_reason = f"Job cancelled: {reason}" + break + tracker.error = Exception(cancellation_reason) + elif tracker.status == JobStatus.STOPPED: + tracker.exit_code = task_state.get("ExitCode", 0) + # Search for stop reason + stop_reason = "Job was stopped" + for event in reversed(task_state.get("Events", [])): + if event.get("Type") in [ + "Killing", + "Terminated", + "Not Restarting", + ]: + reason = event.get("DisplayMessage", "") + if reason: + stop_reason = f"Job stopped: {reason}" + break + tracker.error = Exception(stop_reason) + else: + tracker.exit_code = task_state.get("ExitCode", 0) + return - async def _update_job_status(self, tracker: JobTracker): + async def _update_db_status(self, tracker: JobTracker): + """Updates an external database with the latest job status.""" + logger.info(f"Job {tracker.job_id} status updated: JobStatus.{tracker.status.name}") if not self.log_db: return - try: - status_str = tracker.status.value - - # Map internal status to database status if needed - if tracker.status == JobStatus.ALLOCATED: - status_str = "allocated" - elif tracker.status == JobStatus.SUCCEEDED: - status_str = "succeeded" - await self.log_db.update_job_status( job_id=tracker.job_id, - status=status_str, + status=tracker.status.value, stage=tracker.stage or "unknown", ) - - except Exception as e: - logger.error(f"Failed to update job status in database: {e}") - - async def get_job_status(self, job_id: str) -> Dict[str, Any]: - if job_id in self._active_jobs: - tracker = self._active_jobs[job_id] - return { - "job_id": job_id, - "status": tracker.status.value, - "allocation_id": tracker.allocation_id, - "exit_code": tracker.exit_code, - "dispatch_time": tracker.dispatch_time, - } - - # Query Nomad for job status - try: - job = await self._nomad_call(self.client.job.get_job, job_id) - status = job.get("Status", "unknown").lower() - return { - 
"job_id": job_id, - "status": NOMAD_STATUS_MAP.get(status, JobStatus.UNKNOWN).value, - } except Exception as e: - logger.error(f"Failed to get job status: {e}") - raise JobNotFoundError(f"Job {job_id} not found") + logger.error(f"Failed to update job status in DB for {tracker.job_id}: {e}") diff --git a/src/pipeline_stages.py b/src/pipeline_stages.py index c28ec45..db8ac09 100644 --- a/src/pipeline_stages.py +++ b/src/pipeline_stages.py @@ -1,16 +1,23 @@ import asyncio import logging +import random from abc import ABC, abstractmethod from typing import Any, Dict, List from pydantic import BaseModel, field_serializer from load_config import AppConfig +from nomad_job_manager import NomadError from pipeline_utils import PathFactory, PipelineResult logger = logging.getLogger(__name__) +def stagger_delay() -> float: + """Generate a random delay for job submission staggering.""" + return random.uniform(0.1, 2) + + class DispatchMetaBase(BaseModel): """ Common parameters for all dispatched jobs. 
@@ -36,6 +43,7 @@ class InundationDispatchMeta(DispatchMetaBase): class MosaicDispatchMeta(DispatchMetaBase): raster_paths: List[str] output_path: str + clip_geometry_path: str = "" @field_serializer("raster_paths", mode="plain") def _ser_raster(self, v: List[str], info): @@ -152,20 +160,39 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: task_metadata = [] for result in valid_results: + # Copy flowfile once per scenario + flowfile_s3_path = await self.data_svc.copy_file_to_uri( + result.flowfile_path, + self.path_factory.flowfile_path(result.collection_name, result.scenario_name), + ) + logger.debug(f"[{result.scenario_id}] Copied flowfile to S3: {flowfile_s3_path}") + for catch_id, catchment_info in self.catchments.items(): output_path = self.path_factory.inundation_output_path( result.collection_name, result.scenario_name, catch_id ) result.set_path("inundation", f"catchment_{catch_id}", output_path) - task = asyncio.create_task(self._process_catchment(result, catch_id, catchment_info, output_path)) + task = asyncio.create_task( + self._process_catchment( + result, + catch_id, + catchment_info, + output_path, + flowfile_s3_path, + ) + ) tasks.append(task) task_metadata.append((result, catch_id, output_path)) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + task_results = await asyncio.gather(*tasks, return_exceptions=True) # Group results by scenario and validate outputs scenario_outputs = {} + lost_jobs = [] for (result, catch_id, output_path), task_result in zip(task_metadata, task_results): if result.scenario_id not in scenario_outputs: scenario_outputs[result.scenario_id] = [] @@ -174,6 +201,10 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: scenario_outputs[result.scenario_id].append(output_path) else: logger.error(f"[{result.scenario_id}] Catchment {catch_id} inundation failed: {task_result}") + # Check if this is a LOST job (job was purged from Nomad) + 
error_str = str(task_result) + if "Job lost" in error_str or "URLNotFoundNomadException" in error_str: + lost_jobs.append(f"{result.scenario_id}/{catch_id}") # Validate files and update results updated_results = [] @@ -192,10 +223,22 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: result.mark_failed("No inundation outputs produced") self.log_stage_complete("Inundation", len(updated_results), len(valid_results)) + + # Raise exception if any inundation jobs were lost + if lost_jobs: + error_msg = f"Inundation stage had {len(lost_jobs)} lost job(s): {', '.join(lost_jobs)}" + logger.error(error_msg) + raise NomadError(error_msg) + return updated_results async def _process_catchment( - self, result: PipelineResult, catch_id: str, catchment_info: Dict[str, Any], output_path: str + self, + result: PipelineResult, + catch_id: str, + catchment_info: Dict[str, Any], + output_path: str, + flowfile_s3_path: str, ) -> str: """Process a single catchment for a scenario.""" # Copy files to S3 @@ -208,14 +251,15 @@ async def _process_catchment( local_parquet, self.path_factory.catchment_parquet_path(catch_id), ) - flowfile_s3_path = await self.data_svc.copy_file_to_uri( - result.flowfile_path, self.path_factory.flowfile_path(result.collection_name, result.scenario_name) - ) meta = self._create_inundation_meta(parquet_path, flowfile_s3_path, output_path) - # add catchment internal tag - internal_tags = {"catchment": str(catch_id)[:7]} + # add bench_src, scenario, and catchment internal tags + internal_tags = { + "bench_src": result.collection_name, + "scenario": result.scenario_name, + "catchment": str(catch_id)[:7], + } tags_str = self._create_tags_str(internal_tags) job_id, _ = await self.nomad.dispatch_and_track( @@ -240,6 +284,19 @@ def _create_inundation_meta( class MosaicStage(PipelineStage): """Stage that creates HAND and benchmark mosaics from inundation outputs.""" + def __init__( + self, + config: AppConfig, + nomad_service, + data_service, + 
path_factory: PathFactory, + tags: Dict[str, str], + aoi_path: str, + ): + super().__init__(config, nomad_service, data_service, path_factory, tags) + self.aoi_path = aoi_path + self._clip_geometry_s3_path = None + def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: """Filter results that have valid inundation outputs.""" return [r for r in results if r.status == "inundation_complete" and r.get_path("inundation", "valid_outputs")] @@ -249,6 +306,13 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: valid_results = self.filter_inputs(results) self.log_stage_start("Mosaic", len(valid_results)) + # Copy AOI file to S3 once for all mosaic jobs + if valid_results and self.aoi_path: + self._clip_geometry_s3_path = await self.data_svc.copy_file_to_uri( + self.aoi_path, self.path_factory.aoi_path() + ) + logger.debug(f"Copied AOI file to S3: {self._clip_geometry_s3_path}") + hand_tasks = [] benchmark_tasks = [] task_results = [] @@ -267,7 +331,10 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: hand_meta = self._create_mosaic_meta(valid_outputs, hand_output_path) # add cand_src and scenario internal tags for HAND mosaic - hand_internal_tags = {"cand_src": "hand", "scenario": result.scenario_name} + hand_internal_tags = { + "cand_src": "hand", + "scenario": result.scenario_name, + } hand_tags_str = self._create_tags_str(hand_internal_tags) hand_task = asyncio.create_task( @@ -279,6 +346,9 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: ) hand_tasks.append(hand_task) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + # Benchmark mosaic benchmark_output_path = self.path_factory.benchmark_mosaic_path( result.collection_name, result.scenario_name @@ -287,7 +357,10 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: benchmark_meta = self._create_mosaic_meta(benchmark_rasters, benchmark_output_path) # add 
bench_src and scenario internal tags for benchmark mosaic - benchmark_internal_tags = {"bench_src": result.collection_name, "scenario": result.scenario_name} + benchmark_internal_tags = { + "bench_src": result.collection_name, + "scenario": result.scenario_name, + } benchmark_tags_str = self._create_tags_str(benchmark_internal_tags) benchmark_task = asyncio.create_task( @@ -300,6 +373,9 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: benchmark_tasks.append(benchmark_task) task_results.append(result) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + if not hand_tasks: self.log_stage_complete("Mosaic", 0, len(valid_results)) return [] @@ -309,6 +385,8 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: # Update results based on success/failure successful_results = [] + failed_scenarios = [] + for result, hand_result, benchmark_result in zip(task_results, hand_results, benchmark_results): hand_failed = isinstance(hand_result, Exception) benchmark_failed = isinstance(benchmark_result, Exception) @@ -323,15 +401,29 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: successful_results.append(result) logger.debug(f"[{result.scenario_id}] Mosaic stage complete") else: + failure_type = [] + if hand_failed: + failure_type.append("HAND") + if benchmark_failed: + failure_type.append("benchmark") result.mark_failed("One or both mosaics failed") + failed_scenarios.append(f"{result.scenario_id} ({', '.join(failure_type)})") self.log_stage_complete("Mosaic", len(successful_results), len(valid_results)) + + # Raise exception if any mosaic jobs failed + if failed_scenarios: + error_msg = f"Mosaic stage failed for {len(failed_scenarios)} scenario(s): {', '.join(failed_scenarios)}" + logger.error(error_msg) + raise NomadError(error_msg) + return successful_results def _create_mosaic_meta(self, raster_paths: List[str], output_path: str) -> MosaicDispatchMeta: return 
MosaicDispatchMeta( raster_paths=raster_paths, output_path=output_path, + clip_geometry_path=self._clip_geometry_s3_path or "", **self._get_base_meta_kwargs(), ) @@ -362,11 +454,17 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: result.set_path("agreement", "metrics", metrics_output_path) meta = self._create_agreement_meta( - hand_mosaic, benchmark_mosaic, agreement_output_path, metrics_output_path + hand_mosaic, + benchmark_mosaic, + agreement_output_path, + metrics_output_path, ) # Create tags string with bench_src and scenario internal tags for agreement - agreement_internal_tags = {"bench_src": result.collection_name, "scenario": result.scenario_name} + agreement_internal_tags = { + "bench_src": result.collection_name, + "scenario": result.scenario_name, + } agreement_tags_str = self._create_tags_str(agreement_internal_tags) task = asyncio.create_task( @@ -379,6 +477,9 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: tasks.append(task) task_results.append(result) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + if not tasks: self.log_stage_complete("Agreement", 0, len(valid_results)) return [] @@ -387,20 +488,34 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: # Update results based on success/failure successful_results = [] + failed_scenarios = [] + for result, job_result in zip(task_results, task_job_results): if isinstance(job_result, Exception): result.mark_failed(f"Agreement job failed: {job_result}") logger.error(f"[{result.scenario_id}] Agreement job failed: {job_result}") + failed_scenarios.append(result.scenario_id) else: result.mark_completed() successful_results.append(result) logger.debug(f"[{result.scenario_id}] Pipeline complete") self.log_stage_complete("Agreement", len(successful_results), len(valid_results)) + + # Raise exception if any agreement jobs failed + if failed_scenarios: + error_msg = f"Agreement stage failed 
for {len(failed_scenarios)} scenario(s): {', '.join(failed_scenarios)}" + logger.error(error_msg) + raise NomadError(error_msg) + return successful_results def _create_agreement_meta( - self, candidate_path: str, benchmark_path: str, output_path: str, metrics_path: str = "" + self, + candidate_path: str, + benchmark_path: str, + output_path: str, + metrics_path: str = "", ) -> AgreementDispatchMeta: return AgreementDispatchMeta( candidate_path=candidate_path,