From 26f0c8cff2c29c603950652ba21e00b00f760394 Mon Sep 17 00:00:00 2001 From: "dylan.lee" Date: Sat, 2 Aug 2025 11:35:51 -0400 Subject: [PATCH 1/5] Limit the number of concurrently dispatched jobs Added a semaphore to NomadJobManager so that the dispatch job requests don't exceed the urllib3 pool size limit used by the nomad python library --- src/default_config.py | 1 + src/load_config.py | 6 ++++++ src/main.py | 1 + src/nomad_job_manager.py | 12 ++++++++++-- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/default_config.py b/src/default_config.py index 0df7a5d..42be8af 100644 --- a/src/default_config.py +++ b/src/default_config.py @@ -40,3 +40,4 @@ # General defaults FIM_TYPE = "extent" HTTP_CONNECTION_LIMIT = 100 +NOMAD_MAX_CONCURRENT_DISPATCH = 8 # Slightly below urllib3's default pool size of 10 diff --git a/src/load_config.py b/src/load_config.py index b50089e..0ba6f60 100644 --- a/src/load_config.py +++ b/src/load_config.py @@ -105,6 +105,12 @@ class Defaults(BaseModel): gt=0, description="Max concurrent outgoing HTTP connections", ) + nomad_max_concurrent_dispatch: int = Field( + default_factory=lambda: int(os.getenv("NOMAD_MAX_CONCURRENT_DISPATCH", str(default_config.NOMAD_MAX_CONCURRENT_DISPATCH))), + gt=0, + le=10, + description="Max concurrent Nomad API calls (should be less than urllib3's pool size of 10)", + ) class AppConfig(BaseModel): diff --git a/src/main.py b/src/main.py index 1f6fd3a..afc646f 100644 --- a/src/main.py +++ b/src/main.py @@ -346,6 +346,7 @@ async def _main(): token=cfg.nomad.token, session=session, log_db=log_db, + max_concurrent_dispatch=cfg.defaults.nomad_max_concurrent_dispatch, ) await nomad.start() diff --git a/src/nomad_job_manager.py b/src/nomad_job_manager.py index afbd8ee..318ca36 100644 --- a/src/nomad_job_manager.py +++ b/src/nomad_job_manager.py @@ -77,12 +77,14 @@ def __init__( token: Optional[str] = None, session: Optional[aiohttp.ClientSession] = None, log_db: Optional[Any] = None, + 
max_concurrent_dispatch: int = 8, ): self.nomad_addr = nomad_addr self.namespace = namespace self.token = token self.session = session self.log_db = log_db + self.max_concurrent_dispatch = max_concurrent_dispatch parsed = urlparse(str(nomad_addr)) self.client = nomad.Nomad( @@ -93,6 +95,10 @@ def __init__( namespace=namespace or None, ) + # Create semaphore to limit concurrent Nomad API calls + self._api_semaphore = asyncio.Semaphore(max_concurrent_dispatch) + logger.info(f"Nomad API concurrency limited to {max_concurrent_dispatch} calls") + # Track active jobs self._active_jobs: Dict[str, JobTracker] = {} self._monitoring_task: Optional[asyncio.Task] = None @@ -126,8 +132,10 @@ async def stop(self): @_nomad_retry async def _nomad_call(self, func, *args, **kwargs): - """Execute a Nomad API call with retry logic.""" - return await asyncio.to_thread(func, *args, **kwargs) + """Execute a Nomad API call with retry logic and concurrency limiting.""" + async with self._api_semaphore: + logger.debug(f"Nomad API call: {func.__name__} (semaphore: {self._api_semaphore._value}/{self.max_concurrent_dispatch})") + return await asyncio.to_thread(func, *args, **kwargs) async def dispatch_and_track( self, From cba61b222473f8f147b3d57b6c3172b5fea62016 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Thu, 28 Aug 2025 11:54:05 -0400 Subject: [PATCH 2/5] Stagger job submissions at each pipeline stage --- src/pipeline_stages.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/pipeline_stages.py b/src/pipeline_stages.py index c28ec45..2f6262a 100644 --- a/src/pipeline_stages.py +++ b/src/pipeline_stages.py @@ -1,16 +1,23 @@ import asyncio import logging +import random from abc import ABC, abstractmethod from typing import Any, Dict, List from pydantic import BaseModel, field_serializer from load_config import AppConfig +from nomad_job_manager import NomadError from pipeline_utils import PathFactory, PipelineResult logger = 
logging.getLogger(__name__) +def stagger_delay() -> float: + """Generate a random delay for job submission staggering.""" + return random.uniform(0.1, 2) + + class DispatchMetaBase(BaseModel): """ Common parameters for all dispatched jobs. @@ -158,10 +165,17 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: ) result.set_path("inundation", f"catchment_{catch_id}", output_path) - task = asyncio.create_task(self._process_catchment(result, catch_id, catchment_info, output_path)) + task = asyncio.create_task( + self._process_catchment( + result, catch_id, catchment_info, output_path + ) + ) tasks.append(task) task_metadata.append((result, catch_id, output_path)) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + task_results = await asyncio.gather(*tasks, return_exceptions=True) # Group results by scenario and validate outputs @@ -279,6 +293,9 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: ) hand_tasks.append(hand_task) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + # Benchmark mosaic benchmark_output_path = self.path_factory.benchmark_mosaic_path( result.collection_name, result.scenario_name @@ -300,6 +317,9 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: benchmark_tasks.append(benchmark_task) task_results.append(result) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + if not hand_tasks: self.log_stage_complete("Mosaic", 0, len(valid_results)) return [] @@ -379,6 +399,9 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: tasks.append(task) task_results.append(result) + # stagger delay to spread load on Nomad + await asyncio.sleep(stagger_delay()) + if not tasks: self.log_stage_complete("Agreement", 0, len(valid_results)) return [] From 3d17cff61f351cf95c2e9bc7a44a4c0558130a03 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Thu, 28 Aug 2025 12:02:05 
-0400 Subject: [PATCH 3/5] Update nomad_job_manager.py to use polling and add lost job tracking Edited nomad_job_manager.py to work with polling instead of the Nomad event stream. Also modified pipeline_stages.py so that a stage is marked failed if a job in that stage is lost by the Nomad API --- src/nomad_job_manager.py | 536 ++++++++++++++++++++++----------------- src/pipeline_stages.py | 246 +++++++++++++++--- 2 files changed, 509 insertions(+), 273 deletions(-) diff --git a/src/nomad_job_manager.py b/src/nomad_job_manager.py index 318ca36..f524fe8 100644 --- a/src/nomad_job_manager.py +++ b/src/nomad_job_manager.py @@ -1,14 +1,14 @@ import asyncio import json import logging +import random import time from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple from urllib.parse import urlparse -import aiohttp import nomad from tenacity import ( retry, @@ -21,36 +21,56 @@ class JobStatus(Enum): + """Internal job status enumeration aligned with Nomad ClientStatus values.""" + DISPATCHED = "dispatched" - ALLOCATED = "allocated" + PENDING = "pending" RUNNING = "running" SUCCEEDED = "succeeded" FAILED = "failed" LOST = "lost" + STOPPED = "stopped" + CANCELLED = "cancelled" UNKNOWN = "unknown" # Mapping from Nomad statuses to our internal statuses NOMAD_STATUS_MAP = { - "pending": JobStatus.ALLOCATED, + "pending": JobStatus.PENDING, "running": JobStatus.RUNNING, "complete": JobStatus.SUCCEEDED, "failed": JobStatus.FAILED, "lost": JobStatus.LOST, - "dead": JobStatus.FAILED, + "dead": JobStatus.STOPPED, + "stopped": JobStatus.STOPPED, + "cancelled": JobStatus.CANCELLED, } class NomadError(Exception): + """Base exception for the NomadJobManager.""" + pass class JobNotFoundError(NomadError): + """Raised when a specific job cannot be found in Nomad.""" + pass class JobDispatchError(NomadError): - pass 
+ """ + Raised when a job fails to dispatch or fails during execution. + Contains enriched details about the failure. + """ + + def __init__(self, message: str, **kwargs): + super().__init__(message) + self.job_id: Optional[str] = kwargs.get("job_id") + self.allocation_id: Optional[str] = kwargs.get("allocation_id") + self.exit_code: Optional[int] = kwargs.get("exit_code") + self.original_error: Optional[Exception] = kwargs.get("original_error") @dataclass @@ -69,73 +89,90 @@ class JobTracker: timestamp: Optional[datetime] = None +class _NomadAPIClient: + """A wrapper for Nomad API calls that provides retries and concurrency limiting.""" + + def __init__(self, client: nomad.Nomad, max_concurrency: int): + self.client = client + self.semaphore = asyncio.Semaphore(max_concurrency) + self._nomad_retry = retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=30), + retry=retry_if_exception_type( + ( + nomad.api.exceptions.BaseNomadException, + nomad.api.exceptions.URLNotFoundNomadException, + ) + ), + ) + + async def _call(self, func: Callable, *args, **kwargs) -> Any: + async with self.semaphore: + return await asyncio.to_thread(func, *args, **kwargs) + + async def dispatch_job( + self, job_name: str, prefix: str, meta: Optional[Dict[str, str]] + ) -> str: + wrapped_call = self._nomad_retry(self._call) + result = await wrapped_call( + self.client.job.dispatch_job, + id_=job_name, + payload=None, + meta=meta, + id_prefix_template=prefix.replace(":", "-"), + ) + return result["DispatchedJobID"] + + async def get_job(self, job_id: str) -> Dict[str, Any]: + wrapped_call = self._nomad_retry(self._call) + return await wrapped_call(self.client.job.get_job, job_id) + + async def get_allocations(self, job_id: str) -> List[Dict[str, Any]]: + wrapped_call = self._nomad_retry(self._call) + return await wrapped_call(self.client.job.get_allocations, job_id) + + class NomadJobManager: + """ + Manages the lifecycle of parameterized Nomad jobs using a 
polling-based + strategy with exponential backoff. + """ + def __init__( self, nomad_addr: str, namespace: str = "default", token: Optional[str] = None, - session: Optional[aiohttp.ClientSession] = None, log_db: Optional[Any] = None, - max_concurrent_dispatch: int = 8, + max_concurrent_dispatch: int = 10, ): self.nomad_addr = nomad_addr self.namespace = namespace self.token = token - self.session = session self.log_db = log_db - self.max_concurrent_dispatch = max_concurrent_dispatch - parsed = urlparse(str(nomad_addr)) - self.client = nomad.Nomad( + parsed = urlparse(nomad_addr) + nomad_client = nomad.Nomad( host=parsed.hostname, port=parsed.port, verify=False, - token=token or None, - namespace=namespace or None, + token=token, + namespace=namespace, ) - - # Create semaphore to limit concurrent Nomad API calls - self._api_semaphore = asyncio.Semaphore(max_concurrent_dispatch) - logger.info(f"Nomad API concurrency limited to {max_concurrent_dispatch} calls") - - # Track active jobs + self.api = _NomadAPIClient(nomad_client, max_concurrent_dispatch) self._active_jobs: Dict[str, JobTracker] = {} - self._monitoring_task: Optional[asyncio.Task] = None self._shutdown_event = asyncio.Event() - self._event_index = 0 - async def start(self): - if not self._monitoring_task: - self._monitoring_task = asyncio.create_task(self._monitor_events()) - logger.info("Started Nomad job manager") + """Starts the job manager (no background tasks in this version).""" + logger.info("Starting NomadJobManager (polling-based).") + # No background tasks to start in a pure polling model async def stop(self): + """Stops the job manager.""" + logger.info("Stopping NomadJobManager.") self._shutdown_event.set() - if self._monitoring_task: - await self._monitoring_task - self._monitoring_task = None - logger.info("Stopped Nomad job manager") - - # Retry decorator for Nomad API calls - _nomad_retry = retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=2, max=30), - 
retry=retry_if_exception_type( - ( - nomad.api.exceptions.BaseNomadException, - nomad.api.exceptions.URLNotFoundNomadException, - ) - ), - ) - - @_nomad_retry - async def _nomad_call(self, func, *args, **kwargs): - """Execute a Nomad API call with retry logic and concurrency limiting.""" - async with self._api_semaphore: - logger.debug(f"Nomad API call: {func.__name__} (semaphore: {self._api_semaphore._value}/{self.max_concurrent_dispatch})") - return await asyncio.to_thread(func, *args, **kwargs) + # No background tasks to stop async def dispatch_and_track( self, @@ -144,222 +181,257 @@ async def dispatch_and_track( meta: Optional[Dict[str, str]] = None, ) -> Tuple[str, int]: """ - Dispatch a job and track it to completion. - - Returns: - Tuple of (job_id, exit_code) + Dispatches a job and polls its status until completion using + an exponential backoff strategy. """ - job_id = await self._dispatch_job(job_name, prefix, meta) + try: + job_id = await self.api.dispatch_job(job_name, prefix, meta) + except Exception as e: + # Extract more detailed error information + error_msg = str(e) + + # Special handling for RetryError that wraps BaseNomadException + if ( + "RetryError" in str(type(e)) + and "BaseNomadException" in error_msg + ): + # Try to extract the cause from RetryError + if hasattr(e, "__cause__") and e.__cause__: + error_msg = f"{error_msg} - Cause: {str(e.__cause__)}" + elif hasattr(e, "last_attempt") and hasattr( + e.last_attempt, "exception" + ): + # For tenacity RetryError + last_exception = e.last_attempt.exception() + if last_exception: + error_msg = f"RetryError after 3 attempts - Last error: {str(last_exception)}" + + logger.error(f"Failed to dispatch job {job_name}: {error_msg}") + raise JobDispatchError( + f"Failed to dispatch job {job_name}: {error_msg}" + ) from e tracker = JobTracker( - job_id=job_id, - task_name=job_name, - stage=meta.get("stage") if meta else None, + job_id=job_id, task_name=job_name, stage=(meta or {}).get("stage") ) 
self._active_jobs[job_id] = tracker - - # Update database if available - await self._update_job_status(tracker) + await self._update_db_status(tracker) try: - await self._wait_for_allocation(tracker) + # Polling loop with exponential backoff + poll_delay = 1.0 # Start with a 1-second delay + backoff_factor = 1.5 + jitter_factor = 0.25 # Apply up to 25% jitter + max_poll_delay = 30.0 # Max delay of 30s - await self._wait_for_completion(tracker) + while not tracker.completion_event.is_set(): + jitter = poll_delay * jitter_factor * random.uniform(-1, 1) + await asyncio.sleep(max(0, poll_delay + jitter)) - if tracker.error: - raise tracker.error + await self._poll_job_and_update_tracker(tracker) - return job_id, tracker.exit_code or 0 + # Increase delay for the next poll + poll_delay = min(poll_delay * backoff_factor, max_poll_delay) + if tracker.error: + raise JobDispatchError( + f"Job {job_id} failed with exit code {tracker.exit_code}: {tracker.error}", + job_id=job_id, + allocation_id=tracker.allocation_id, + exit_code=tracker.exit_code, + original_error=tracker.error, + ) + return job_id, tracker.exit_code or 0 finally: self._active_jobs.pop(job_id, None) - async def _dispatch_job( - self, - job_name: str, - prefix: str, - meta: Optional[Dict[str, str]] = None, - ) -> str: - """Dispatch a parameterized job.""" - payload = {"Meta": meta} if meta else {} - + async def _poll_job_and_update_tracker(self, tracker: JobTracker): + """Polls a single job and updates its tracker if the status has changed.""" try: - logger.debug(f"Dispatching job {job_name} with meta: {meta}") - result = await self._nomad_call( - self.client.job.dispatch_job, - id_=job_name, - payload=None, - meta=meta, - id_prefix_template=prefix, - ) - job_id = result["DispatchedJobID"] - logger.info(f"Dispatched job {job_id} from {job_name}") - return job_id - - except Exception as e: - logger.error(f"Failed to dispatch job {job_name}: {e}") - logger.error(f"Payload was: {payload}") - raise 
JobDispatchError(f"Failed to dispatch job {job_name}: {e}") + # First, check if the job itself still exists. This avoids polling for allocations of a job that has been purged. + await self.api.get_job(tracker.job_id) + + allocations = await self.api.get_allocations(tracker.job_id) + + time_since_dispatch = time.time() - tracker.dispatch_time + if ( + not allocations and time_since_dispatch > 28800 + ): # allow 8 hrs to allocate for jobs that take a while to go from pending to dispatched + logger.warning( + f"No allocations found for job {tracker.job_id} after 30 minutes, marking as LOST." + ) + tracker.status = JobStatus.LOST + tracker.error = Exception( + "Job lost: No allocations found after timeout." + ) + tracker.completion_event.set() + await self._update_db_status(tracker) + return - async def _wait_for_allocation(self, tracker: JobTracker): - while True: - if tracker.allocation_id: + if not allocations: + logger.debug(f"Polling {tracker.job_id}: No allocations yet.") return - if tracker.status in (JobStatus.FAILED, JobStatus.LOST): - raise JobDispatchError(f"Job {tracker.job_id} failed during allocation") - await asyncio.sleep(1) - - async def _wait_for_completion(self, tracker: JobTracker): - await tracker.completion_event.wait() - - async def _monitor_events(self): - """Monitor Nomad event stream for job updates.""" - retry_count = 0 - max_retries = 5 - - while not self._shutdown_event.is_set(): - try: - await self._process_event_stream() - retry_count = 0 - except Exception as e: - retry_count += 1 - if retry_count >= max_retries: - logger.error(f"Event stream failed {max_retries} times, stopping monitor") - break - - wait_time = min(2**retry_count, 60) - logger.warning(f"Event stream error: {e}, retrying in {wait_time}s") - await asyncio.sleep(wait_time) - - async def _process_event_stream(self): - if not self.session: - self.session = aiohttp.ClientSession() - - url = f"{self.nomad_addr}/v1/event/stream" - params = { - "index": self._event_index, - 
"namespace": self.namespace, - } - - headers = {} - if self.token: - headers["X-Nomad-Token"] = self.token - - async with self.session.get(url, params=params, headers=headers, timeout=None) as response: - response.raise_for_status() - - buffer = b"" - async for chunk in response.content.iter_chunked(8 * 1024): - if self._shutdown_event.is_set(): - break - - buffer += chunk - while b"\n" in buffer: - line_bytes, buffer = buffer.split(b"\n", 1) - line = line_bytes.decode("utf-8").strip() - - if not line or line == "{}": - continue - - try: - data = json.loads(line) - if "Index" in data: - self._event_index = data["Index"] - - events = data.get("Events", []) - for event in events: - await self._handle_event(event) - - except json.JSONDecodeError: - logger.debug(f"Failed to parse event line: {line}") - except Exception as e: - logger.error(f"Error handling event: {e}") - - async def _handle_event(self, event: Dict[str, Any]): - """Handle a single Nomad event.""" - topic = event.get("Topic") - if topic != "Allocation": - return - payload = event.get("Payload", {}).get("Allocation", {}) - job_id = payload.get("JobID") + latest_alloc = max( + allocations, key=lambda a: a.get("CreateTime", 0) + ) + await self._update_tracker_from_allocation(tracker, latest_alloc) + + except ( + nomad.api.exceptions.URLNotFoundNomadException, + nomad.api.exceptions.BaseNomadException, + ) as e: + # Extract more detailed error information + error_msg = str(e) + if hasattr(e, "__cause__") and e.__cause__: + error_msg = f"{error_msg} - Cause: {str(e.__cause__)}" + + logger.error( + f"Polling failed for job {tracker.job_id}, it may have been purged. Marking as LOST. 
Error: {error_msg}" + ) + tracker.status = JobStatus.LOST + tracker.error = e + tracker.completion_event.set() + await self._update_db_status(tracker) + except Exception as e: + # Check if this is a RetryError wrapping a Nomad exception + if "RetryError" in str(type(e)): + # Try to extract the wrapped exception + wrapped_exception = None + if hasattr(e, "last_attempt") and hasattr(e.last_attempt, "exception"): + wrapped_exception = e.last_attempt.exception() + elif hasattr(e, "__cause__") and e.__cause__: + wrapped_exception = e.__cause__ + + # Check if the wrapped exception is a URLNotFoundNomadException + if wrapped_exception and "URLNotFoundNomadException" in str(type(wrapped_exception)): + logger.error( + f"Polling failed for job {tracker.job_id} after retries, job may have been purged. Marking as LOST. Error: {e}" + ) + tracker.status = JobStatus.LOST + tracker.error = e + tracker.completion_event.set() + await self._update_db_status(tracker) + return + + logger.warning( + f"Unexpected error while polling for job {tracker.job_id}: {e}" + ) - if not job_id or job_id not in self._active_jobs: + async def _update_tracker_from_allocation( + self, tracker: JobTracker, allocation: Dict[str, Any] + ): + """Updates a tracker's state based on a Nomad allocation object.""" + if tracker.completion_event.is_set(): return - tracker = self._active_jobs[job_id] - allocation_id = payload.get("ID") - client_status = payload.get("ClientStatus", "").lower() - - # Update allocation ID - if not tracker.allocation_id and allocation_id: - tracker.allocation_id = allocation_id - logger.info(f"Job {job_id} allocated: {allocation_id}") - - # Map status old_status = tracker.status + client_status = allocation.get("ClientStatus", "").lower() new_status = NOMAD_STATUS_MAP.get(client_status, JobStatus.UNKNOWN) - if new_status != JobStatus.UNKNOWN and new_status != old_status: - tracker.status = new_status - tracker.timestamp = datetime.now(timezone.utc) + if new_status == old_status: + 
return # No change - # Extract exit code for completed jobs - if client_status == "complete": - task_states = payload.get("TaskStates", {}) - for task_state in task_states.values(): - if task_state.get("FinishedAt"): - tracker.exit_code = task_state.get("ExitCode", 0) - break + logger.info( + f"Job {tracker.job_id} status change: {old_status.name} -> {new_status.name}" + ) + tracker.status = new_status + tracker.timestamp = datetime.now(timezone.utc) + if not tracker.allocation_id: + tracker.allocation_id = allocation.get("ID") + + is_terminal = new_status in ( + JobStatus.SUCCEEDED, + JobStatus.FAILED, + JobStatus.LOST, + JobStatus.STOPPED, + JobStatus.CANCELLED, + ) - # Update database - await self._update_job_status(tracker) + if is_terminal: + self._extract_final_state(tracker, allocation) + logger.info( + f"Job {tracker.job_id} completed with status JobStatus.{new_status.name} " + f"(exit_code: {tracker.exit_code})" + ) + tracker.completion_event.set() - if new_status in (JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.LOST): - tracker.completion_event.set() - logger.info(f"Job {job_id} completed with status {new_status}") + await self._update_db_status(tracker) - async def _update_job_status(self, tracker: JobTracker): + def _extract_final_state( + self, tracker: JobTracker, allocation: Dict[str, Any] + ): + """Extracts the exit code and error message from a terminal allocation.""" + task_states = allocation.get("TaskStates", {}) + for task_name, task_state in task_states.items(): + if task_state.get("FinishedAt"): + if tracker.status == JobStatus.FAILED: + tracker.exit_code = task_state.get("ExitCode", 1) + failure_details = [] + # Search for a meaningful event message + for event in reversed(task_state.get("Events", [])): + if event.get("Type") in [ + "Terminated", + "Task Failed", + "Driver Failure", + "Killing", + ]: + reason = event.get( + "DisplayMessage", "No reason provided." 
+ ) + failure_details.append( + f"Task '{task_name}' failed: {reason}" + ) + break + if not failure_details: + failure_details.append( + f"Task '{task_name}' failed with exit code {tracker.exit_code}." + ) + tracker.error = Exception("; ".join(failure_details)) + elif tracker.status == JobStatus.CANCELLED: + tracker.exit_code = task_state.get("ExitCode", 0) + # Search for cancellation reason + cancellation_reason = "Job was cancelled" + for event in reversed(task_state.get("Events", [])): + if event.get("Type") in ["Killing", "Terminated"]: + reason = event.get("DisplayMessage", "") + if reason: + cancellation_reason = f"Job cancelled: {reason}" + break + tracker.error = Exception(cancellation_reason) + elif tracker.status == JobStatus.STOPPED: + tracker.exit_code = task_state.get("ExitCode", 0) + # Search for stop reason + stop_reason = "Job was stopped" + for event in reversed(task_state.get("Events", [])): + if event.get("Type") in [ + "Killing", + "Terminated", + "Not Restarting", + ]: + reason = event.get("DisplayMessage", "") + if reason: + stop_reason = f"Job stopped: {reason}" + break + tracker.error = Exception(stop_reason) + else: + tracker.exit_code = task_state.get("ExitCode", 0) + return + + async def _update_db_status(self, tracker: JobTracker): + """Updates an external database with the latest job status.""" + logger.info( + f"Job {tracker.job_id} status updated: JobStatus.{tracker.status.name}" + ) if not self.log_db: return - try: - status_str = tracker.status.value - - # Map internal status to database status if needed - if tracker.status == JobStatus.ALLOCATED: - status_str = "allocated" - elif tracker.status == JobStatus.SUCCEEDED: - status_str = "succeeded" - await self.log_db.update_job_status( job_id=tracker.job_id, - status=status_str, + status=tracker.status.value, stage=tracker.stage or "unknown", ) - except Exception as e: - logger.error(f"Failed to update job status in database: {e}") - - async def get_job_status(self, job_id: str) -> 
Dict[str, Any]: - if job_id in self._active_jobs: - tracker = self._active_jobs[job_id] - return { - "job_id": job_id, - "status": tracker.status.value, - "allocation_id": tracker.allocation_id, - "exit_code": tracker.exit_code, - "dispatch_time": tracker.dispatch_time, - } - - # Query Nomad for job status - try: - job = await self._nomad_call(self.client.job.get_job, job_id) - status = job.get("Status", "unknown").lower() - return { - "job_id": job_id, - "status": NOMAD_STATUS_MAP.get(status, JobStatus.UNKNOWN).value, - } - except Exception as e: - logger.error(f"Failed to get job status: {e}") - raise JobNotFoundError(f"Job {job_id} not found") + logger.error( + f"Failed to update job status in DB for {tracker.job_id}: {e}" + ) diff --git a/src/pipeline_stages.py b/src/pipeline_stages.py index 2f6262a..c6bbbb2 100644 --- a/src/pipeline_stages.py +++ b/src/pipeline_stages.py @@ -43,6 +43,7 @@ class InundationDispatchMeta(DispatchMetaBase): class MosaicDispatchMeta(DispatchMetaBase): raster_paths: List[str] output_path: str + clip_geometry_path: str = "" @field_serializer("raster_paths", mode="plain") def _ser_raster(self, v: List[str], info): @@ -83,14 +84,18 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: pass @abstractmethod - def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: + def filter_inputs( + self, results: List[PipelineResult] + ) -> List[PipelineResult]: """Filter which results can be processed by this stage.""" pass def log_stage_start(self, stage_name: str, input_count: int): logger.debug(f"{stage_name}: Starting with {input_count} inputs") - def log_stage_complete(self, stage_name: str, success_count: int, total_count: int): + def log_stage_complete( + self, stage_name: str, success_count: int, total_count: int + ): logger.debug(f"{stage_name}: {success_count}/{total_count} succeeded") def _get_base_meta_kwargs(self) -> Dict[str, str]: @@ -142,10 +147,14 @@ def __init__( tags: Dict[str, 
str], catchments: Dict[str, Dict[str, Any]], ): - super().__init__(config, nomad_service, data_service, path_factory, tags) + super().__init__( + config, nomad_service, data_service, path_factory, tags + ) self.catchments = catchments - def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: + def filter_inputs( + self, results: List[PipelineResult] + ) -> List[PipelineResult]: """All results are valid for inundation stage.""" return results @@ -159,15 +168,32 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: task_metadata = [] for result in valid_results: + # Copy flowfile once per scenario + flowfile_s3_path = await self.data_svc.copy_file_to_uri( + result.flowfile_path, + self.path_factory.flowfile_path( + result.collection_name, result.scenario_name + ), + ) + logger.debug( + f"[{result.scenario_id}] Copied flowfile to S3: {flowfile_s3_path}" + ) + for catch_id, catchment_info in self.catchments.items(): output_path = self.path_factory.inundation_output_path( result.collection_name, result.scenario_name, catch_id ) - result.set_path("inundation", f"catchment_{catch_id}", output_path) + result.set_path( + "inundation", f"catchment_{catch_id}", output_path + ) task = asyncio.create_task( self._process_catchment( - result, catch_id, catchment_info, output_path + result, + catch_id, + catchment_info, + output_path, + flowfile_s3_path, ) ) tasks.append(task) @@ -180,14 +206,24 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: # Group results by scenario and validate outputs scenario_outputs = {} - for (result, catch_id, output_path), task_result in zip(task_metadata, task_results): + lost_jobs = [] + for (result, catch_id, output_path), task_result in zip( + task_metadata, task_results + ): if result.scenario_id not in scenario_outputs: scenario_outputs[result.scenario_id] = [] if not isinstance(task_result, Exception): scenario_outputs[result.scenario_id].append(output_path) else: - 
logger.error(f"[{result.scenario_id}] Catchment {catch_id} inundation failed: {task_result}") + logger.error( + f"[{result.scenario_id}] Catchment {catch_id} inundation failed: {task_result}" + ) + # Check if this is a LOST job (job was purged from Nomad) + error_str = str(task_result) + if ("Job lost" in error_str or + "URLNotFoundNomadException" in error_str): + lost_jobs.append(f"{result.scenario_id}/{catch_id}") # Validate files and update results updated_results = [] @@ -196,20 +232,38 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: if outputs: valid_outputs = await self.data_svc.validate_files(outputs) if valid_outputs: - result.set_path("inundation", "valid_outputs", valid_outputs) + result.set_path( + "inundation", "valid_outputs", valid_outputs + ) result.status = "inundation_complete" updated_results.append(result) - logger.debug(f"[{result.scenario_id}] {len(valid_outputs)}/{len(outputs)} inundation outputs valid") + logger.debug( + f"[{result.scenario_id}] {len(valid_outputs)}/{len(outputs)} inundation outputs valid" + ) else: result.mark_failed("No valid inundation outputs") else: result.mark_failed("No inundation outputs produced") - self.log_stage_complete("Inundation", len(updated_results), len(valid_results)) + self.log_stage_complete( + "Inundation", len(updated_results), len(valid_results) + ) + + # Raise exception if any inundation jobs were lost + if lost_jobs: + error_msg = f"Inundation stage had {len(lost_jobs)} lost job(s): {', '.join(lost_jobs)}" + logger.error(error_msg) + raise NomadError(error_msg) + return updated_results async def _process_catchment( - self, result: PipelineResult, catch_id: str, catchment_info: Dict[str, Any], output_path: str + self, + result: PipelineResult, + catch_id: str, + catchment_info: Dict[str, Any], + output_path: str, + flowfile_s3_path: str, ) -> str: """Process a single catchment for a scenario.""" # Copy files to S3 @@ -222,14 +276,17 @@ async def _process_catchment( 
local_parquet, self.path_factory.catchment_parquet_path(catch_id), ) - flowfile_s3_path = await self.data_svc.copy_file_to_uri( - result.flowfile_path, self.path_factory.flowfile_path(result.collection_name, result.scenario_name) - ) - meta = self._create_inundation_meta(parquet_path, flowfile_s3_path, output_path) + meta = self._create_inundation_meta( + parquet_path, flowfile_s3_path, output_path + ) - # add catchment internal tag - internal_tags = {"catchment": str(catch_id)[:7]} + # add bench_src, scenario, and catchment internal tags + internal_tags = { + "bench_src": result.collection_name, + "scenario": result.scenario_name, + "catchment": str(catch_id)[:7], + } tags_str = self._create_tags_str(internal_tags) job_id, _ = await self.nomad.dispatch_and_track( @@ -237,7 +294,9 @@ async def _process_catchment( prefix=tags_str, meta=meta.model_dump(), ) - logger.debug(f"[{result.scenario_id}/{catch_id}] inundator done → {job_id}") + logger.debug( + f"[{result.scenario_id}/{catch_id}] inundator done → {job_id}" + ) return job_id def _create_inundation_meta( @@ -254,15 +313,46 @@ def _create_inundation_meta( class MosaicStage(PipelineStage): """Stage that creates HAND and benchmark mosaics from inundation outputs.""" - def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: + def __init__( + self, + config: AppConfig, + nomad_service, + data_service, + path_factory: PathFactory, + tags: Dict[str, str], + aoi_path: str, + ): + super().__init__( + config, nomad_service, data_service, path_factory, tags + ) + self.aoi_path = aoi_path + self._clip_geometry_s3_path = None + + def filter_inputs( + self, results: List[PipelineResult] + ) -> List[PipelineResult]: """Filter results that have valid inundation outputs.""" - return [r for r in results if r.status == "inundation_complete" and r.get_path("inundation", "valid_outputs")] + return [ + r + for r in results + if r.status == "inundation_complete" + and r.get_path("inundation", "valid_outputs") 
+ ] async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: """Run mosaic jobs for scenarios with valid inundation outputs.""" valid_results = self.filter_inputs(results) self.log_stage_start("Mosaic", len(valid_results)) + # Copy AOI file to S3 once for all mosaic jobs + if valid_results and self.aoi_path: + self._clip_geometry_s3_path = await self.data_svc.copy_file_to_uri( + self.aoi_path, self.path_factory.aoi_path() + ) + logger.debug( + f"Copied AOI file to S3: {self._clip_geometry_s3_path}" + ) + hand_tasks = [] benchmark_tasks = [] task_results = [] @@ -272,16 +362,25 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: benchmark_rasters = result.benchmark_rasters if not valid_outputs or not benchmark_rasters: - logger.warning(f"[{result.scenario_id}] Skipping mosaic - missing inputs") + logger.warning( + f"[{result.scenario_id}] Skipping mosaic - missing inputs" + ) continue # HAND mosaic - hand_output_path = self.path_factory.hand_mosaic_path(result.collection_name, result.scenario_name) + hand_output_path = self.path_factory.hand_mosaic_path( + result.collection_name, result.scenario_name + ) result.set_path("mosaic", "hand", hand_output_path) - hand_meta = self._create_mosaic_meta(valid_outputs, hand_output_path) + hand_meta = self._create_mosaic_meta( + valid_outputs, hand_output_path + ) # add cand_src and scenario internal tags for HAND mosaic - hand_internal_tags = {"cand_src": "hand", "scenario": result.scenario_name} + hand_internal_tags = { + "cand_src": "hand", + "scenario": result.scenario_name, + } hand_tags_str = self._create_tags_str(hand_internal_tags) hand_task = asyncio.create_task( @@ -301,10 +400,15 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: result.collection_name, result.scenario_name ) result.set_path("mosaic", "benchmark", benchmark_output_path) - benchmark_meta = self._create_mosaic_meta(benchmark_rasters, benchmark_output_path) + benchmark_meta = 
self._create_mosaic_meta( + benchmark_rasters, benchmark_output_path + ) # add bench_src and scenario internal tags for benchmark mosaic - benchmark_internal_tags = {"bench_src": result.collection_name, "scenario": result.scenario_name} + benchmark_internal_tags = { + "bench_src": result.collection_name, + "scenario": result.scenario_name, + } benchmark_tags_str = self._create_tags_str(benchmark_internal_tags) benchmark_task = asyncio.create_task( @@ -325,33 +429,63 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: return [] hand_results = await asyncio.gather(*hand_tasks, return_exceptions=True) - benchmark_results = await asyncio.gather(*benchmark_tasks, return_exceptions=True) + benchmark_results = await asyncio.gather( + *benchmark_tasks, return_exceptions=True + ) # Update results based on success/failure successful_results = [] - for result, hand_result, benchmark_result in zip(task_results, hand_results, benchmark_results): + failed_scenarios = [] + + for result, hand_result, benchmark_result in zip( + task_results, hand_results, benchmark_results + ): hand_failed = isinstance(hand_result, Exception) benchmark_failed = isinstance(benchmark_result, Exception) if hand_failed: - logger.error(f"[{result.scenario_id}] HAND mosaic failed: {hand_result}") + logger.error( + f"[{result.scenario_id}] HAND mosaic failed: {hand_result}" + ) if benchmark_failed: - logger.error(f"[{result.scenario_id}] Benchmark mosaic failed: {benchmark_result}") + logger.error( + f"[{result.scenario_id}] Benchmark mosaic failed: {benchmark_result}" + ) if not hand_failed and not benchmark_failed: result.status = "mosaic_complete" successful_results.append(result) logger.debug(f"[{result.scenario_id}] Mosaic stage complete") else: + failure_type = [] + if hand_failed: + failure_type.append("HAND") + if benchmark_failed: + failure_type.append("benchmark") result.mark_failed("One or both mosaics failed") + failed_scenarios.append( + f"{result.scenario_id} ({', 
'.join(failure_type)})" + ) + + self.log_stage_complete( + "Mosaic", len(successful_results), len(valid_results) + ) + + # Raise exception if any mosaic jobs failed + if failed_scenarios: + error_msg = f"Mosaic stage failed for {len(failed_scenarios)} scenario(s): {', '.join(failed_scenarios)}" + logger.error(error_msg) + raise NomadError(error_msg) - self.log_stage_complete("Mosaic", len(successful_results), len(valid_results)) return successful_results - def _create_mosaic_meta(self, raster_paths: List[str], output_path: str) -> MosaicDispatchMeta: + def _create_mosaic_meta( + self, raster_paths: List[str], output_path: str + ) -> MosaicDispatchMeta: return MosaicDispatchMeta( raster_paths=raster_paths, output_path=output_path, + clip_geometry_path=self._clip_geometry_s3_path or "", **self._get_base_meta_kwargs(), ) @@ -359,7 +493,9 @@ def _create_mosaic_meta(self, raster_paths: List[str], output_path: str) -> Mosa class AgreementStage(PipelineStage): """Stage that creates agreement maps from HAND and benchmark mosaics.""" - def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: + def filter_inputs( + self, results: List[PipelineResult] + ) -> List[PipelineResult]: """Filter results that have valid mosaic outputs.""" return [r for r in results if r.status == "mosaic_complete"] @@ -375,18 +511,28 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: hand_mosaic = result.get_path("mosaic", "hand") benchmark_mosaic = result.get_path("mosaic", "benchmark") - agreement_output_path = self.path_factory.agreement_map_path(result.collection_name, result.scenario_name) - metrics_output_path = self.path_factory.agreement_metrics_path(result.collection_name, result.scenario_name) + agreement_output_path = self.path_factory.agreement_map_path( + result.collection_name, result.scenario_name + ) + metrics_output_path = self.path_factory.agreement_metrics_path( + result.collection_name, result.scenario_name + ) 
result.set_path("agreement", "map", agreement_output_path) result.set_path("agreement", "metrics", metrics_output_path) meta = self._create_agreement_meta( - hand_mosaic, benchmark_mosaic, agreement_output_path, metrics_output_path + hand_mosaic, + benchmark_mosaic, + agreement_output_path, + metrics_output_path, ) # Create tags string with bench_src and scenario internal tags for agreement - agreement_internal_tags = {"bench_src": result.collection_name, "scenario": result.scenario_name} + agreement_internal_tags = { + "bench_src": result.collection_name, + "scenario": result.scenario_name, + } agreement_tags_str = self._create_tags_str(agreement_internal_tags) task = asyncio.create_task( @@ -410,20 +556,38 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: # Update results based on success/failure successful_results = [] + failed_scenarios = [] + for result, job_result in zip(task_results, task_job_results): if isinstance(job_result, Exception): result.mark_failed(f"Agreement job failed: {job_result}") - logger.error(f"[{result.scenario_id}] Agreement job failed: {job_result}") + logger.error( + f"[{result.scenario_id}] Agreement job failed: {job_result}" + ) + failed_scenarios.append(result.scenario_id) else: result.mark_completed() successful_results.append(result) logger.debug(f"[{result.scenario_id}] Pipeline complete") - self.log_stage_complete("Agreement", len(successful_results), len(valid_results)) + self.log_stage_complete( + "Agreement", len(successful_results), len(valid_results) + ) + + # Raise exception if any agreement jobs failed + if failed_scenarios: + error_msg = f"Agreement stage failed for {len(failed_scenarios)} scenario(s): {', '.join(failed_scenarios)}" + logger.error(error_msg) + raise NomadError(error_msg) + return successful_results def _create_agreement_meta( - self, candidate_path: str, benchmark_path: str, output_path: str, metrics_path: str = "" + self, + candidate_path: str, + benchmark_path: str, + 
output_path: str, + metrics_path: str = "", ) -> AgreementDispatchMeta: return AgreementDispatchMeta( candidate_path=candidate_path, From bf25ead6ef0c93f1c4f7734fbb46766809bfb84d Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Thu, 28 Aug 2025 12:38:18 -0400 Subject: [PATCH 4/5] Revert src/main.py in this branch to main branch's src/main.py When I was cherry-picking changes in other files in src into this branch, I accidentally copied over main.py changes. Reverted the changes since src/main.py changes are being tracked in the update-entrypoint branch --- src/default_config.py | 2 +- src/main.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/default_config.py b/src/default_config.py index 42be8af..abb4b1e 100644 --- a/src/default_config.py +++ b/src/default_config.py @@ -40,4 +40,4 @@ # General defaults FIM_TYPE = "extent" HTTP_CONNECTION_LIMIT = 100 -NOMAD_MAX_CONCURRENT_DISPATCH = 8 # Slightly below urllib3's default pool size of 10 +NOMAD_MAX_CONCURRENT_DISPATCH = 5 # Keep this on the smaller side so that there aren't too many API calls being generated by many pipeline jobs running on a single client.
diff --git a/src/main.py b/src/main.py index afc646f..1f6fd3a 100644 --- a/src/main.py +++ b/src/main.py @@ -346,7 +346,6 @@ async def _main(): token=cfg.nomad.token, session=session, log_db=log_db, - max_concurrent_dispatch=cfg.defaults.nomad_max_concurrent_dispatch, ) await nomad.start() From bf9f645c20bec3e3aaae3363d74c3789f65f3928 Mon Sep 17 00:00:00 2001 From: Dylan Lee Date: Wed, 10 Sep 2025 15:18:08 -0400 Subject: [PATCH 5/5] Reformat line lengths --- src/load_config.py | 4 +- src/nomad_job_manager.py | 73 +++++--------------- src/pipeline_stages.py | 146 ++++++++++----------------------------- 3 files changed, 59 insertions(+), 164 deletions(-) diff --git a/src/load_config.py b/src/load_config.py index 0ba6f60..7c3bc38 100644 --- a/src/load_config.py +++ b/src/load_config.py @@ -106,7 +106,9 @@ class Defaults(BaseModel): description="Max concurrent outgoing HTTP connections", ) nomad_max_concurrent_dispatch: int = Field( - default_factory=lambda: int(os.getenv("NOMAD_MAX_CONCURRENT_DISPATCH", str(default_config.NOMAD_MAX_CONCURRENT_DISPATCH))), + default_factory=lambda: int( + os.getenv("NOMAD_MAX_CONCURRENT_DISPATCH", str(default_config.NOMAD_MAX_CONCURRENT_DISPATCH)) + ), gt=0, le=10, description="Max concurrent Nomad API calls (should be less than urllib3's pool size of 10)", diff --git a/src/nomad_job_manager.py b/src/nomad_job_manager.py index f524fe8..7d39335 100644 --- a/src/nomad_job_manager.py +++ b/src/nomad_job_manager.py @@ -110,9 +110,7 @@ async def _call(self, func: Callable, *args, **kwargs) -> Any: async with self.semaphore: return await asyncio.to_thread(func, *args, **kwargs) - async def dispatch_job( - self, job_name: str, prefix: str, meta: Optional[Dict[str, str]] - ) -> str: + async def dispatch_job(self, job_name: str, prefix: str, meta: Optional[Dict[str, str]]) -> str: wrapped_call = self._nomad_retry(self._call) result = await wrapped_call( self.client.job.dispatch_job, @@ -191,29 +189,20 @@ async def dispatch_and_track( error_msg 
= str(e) # Special handling for RetryError that wraps BaseNomadException - if ( - "RetryError" in str(type(e)) - and "BaseNomadException" in error_msg - ): + if "RetryError" in str(type(e)) and "BaseNomadException" in error_msg: # Try to extract the cause from RetryError if hasattr(e, "__cause__") and e.__cause__: error_msg = f"{error_msg} - Cause: {str(e.__cause__)}" - elif hasattr(e, "last_attempt") and hasattr( - e.last_attempt, "exception" - ): + elif hasattr(e, "last_attempt") and hasattr(e.last_attempt, "exception"): # For tenacity RetryError last_exception = e.last_attempt.exception() if last_exception: error_msg = f"RetryError after 3 attempts - Last error: {str(last_exception)}" logger.error(f"Failed to dispatch job {job_name}: {error_msg}") - raise JobDispatchError( - f"Failed to dispatch job {job_name}: {error_msg}" - ) from e + raise JobDispatchError(f"Failed to dispatch job {job_name}: {error_msg}") from e - tracker = JobTracker( - job_id=job_id, task_name=job_name, stage=(meta or {}).get("stage") - ) + tracker = JobTracker(job_id=job_id, task_name=job_name, stage=(meta or {}).get("stage")) self._active_jobs[job_id] = tracker await self._update_db_status(tracker) @@ -257,13 +246,9 @@ async def _poll_job_and_update_tracker(self, tracker: JobTracker): if ( not allocations and time_since_dispatch > 28800 ): # allow 8 hrs to allocate for jobs that take a while to go from pending to dispatched - logger.warning( - f"No allocations found for job {tracker.job_id} after 30 minutes, marking as LOST." - ) + logger.warning(f"No allocations found for job {tracker.job_id} after 30 minutes, marking as LOST.") tracker.status = JobStatus.LOST - tracker.error = Exception( - "Job lost: No allocations found after timeout." 
- ) + tracker.error = Exception("Job lost: No allocations found after timeout.") tracker.completion_event.set() await self._update_db_status(tracker) return @@ -272,9 +257,7 @@ async def _poll_job_and_update_tracker(self, tracker: JobTracker): logger.debug(f"Polling {tracker.job_id}: No allocations yet.") return - latest_alloc = max( - allocations, key=lambda a: a.get("CreateTime", 0) - ) + latest_alloc = max(allocations, key=lambda a: a.get("CreateTime", 0)) await self._update_tracker_from_allocation(tracker, latest_alloc) except ( @@ -302,7 +285,7 @@ async def _poll_job_and_update_tracker(self, tracker: JobTracker): wrapped_exception = e.last_attempt.exception() elif hasattr(e, "__cause__") and e.__cause__: wrapped_exception = e.__cause__ - + # Check if the wrapped exception is a URLNotFoundNomadException if wrapped_exception and "URLNotFoundNomadException" in str(type(wrapped_exception)): logger.error( @@ -313,14 +296,10 @@ async def _poll_job_and_update_tracker(self, tracker: JobTracker): tracker.completion_event.set() await self._update_db_status(tracker) return - - logger.warning( - f"Unexpected error while polling for job {tracker.job_id}: {e}" - ) - async def _update_tracker_from_allocation( - self, tracker: JobTracker, allocation: Dict[str, Any] - ): + logger.warning(f"Unexpected error while polling for job {tracker.job_id}: {e}") + + async def _update_tracker_from_allocation(self, tracker: JobTracker, allocation: Dict[str, Any]): """Updates a tracker's state based on a Nomad allocation object.""" if tracker.completion_event.is_set(): return @@ -332,9 +311,7 @@ async def _update_tracker_from_allocation( if new_status == old_status: return # No change - logger.info( - f"Job {tracker.job_id} status change: {old_status.name} -> {new_status.name}" - ) + logger.info(f"Job {tracker.job_id} status change: {old_status.name} -> {new_status.name}") tracker.status = new_status tracker.timestamp = datetime.now(timezone.utc) if not tracker.allocation_id: @@ -358,9 
+335,7 @@ async def _update_tracker_from_allocation( await self._update_db_status(tracker) - def _extract_final_state( - self, tracker: JobTracker, allocation: Dict[str, Any] - ): + def _extract_final_state(self, tracker: JobTracker, allocation: Dict[str, Any]): """Extracts the exit code and error message from a terminal allocation.""" task_states = allocation.get("TaskStates", {}) for task_name, task_state in task_states.items(): @@ -376,17 +351,11 @@ def _extract_final_state( "Driver Failure", "Killing", ]: - reason = event.get( - "DisplayMessage", "No reason provided." - ) - failure_details.append( - f"Task '{task_name}' failed: {reason}" - ) + reason = event.get("DisplayMessage", "No reason provided.") + failure_details.append(f"Task '{task_name}' failed: {reason}") break if not failure_details: - failure_details.append( - f"Task '{task_name}' failed with exit code {tracker.exit_code}." - ) + failure_details.append(f"Task '{task_name}' failed with exit code {tracker.exit_code}.") tracker.error = Exception("; ".join(failure_details)) elif tracker.status == JobStatus.CANCELLED: tracker.exit_code = task_state.get("ExitCode", 0) @@ -420,9 +389,7 @@ def _extract_final_state( async def _update_db_status(self, tracker: JobTracker): """Updates an external database with the latest job status.""" - logger.info( - f"Job {tracker.job_id} status updated: JobStatus.{tracker.status.name}" - ) + logger.info(f"Job {tracker.job_id} status updated: JobStatus.{tracker.status.name}") if not self.log_db: return try: @@ -432,6 +399,4 @@ async def _update_db_status(self, tracker: JobTracker): stage=tracker.stage or "unknown", ) except Exception as e: - logger.error( - f"Failed to update job status in DB for {tracker.job_id}: {e}" - ) + logger.error(f"Failed to update job status in DB for {tracker.job_id}: {e}") diff --git a/src/pipeline_stages.py b/src/pipeline_stages.py index c6bbbb2..db8ac09 100644 --- a/src/pipeline_stages.py +++ b/src/pipeline_stages.py @@ -84,18 +84,14 @@ async 
def run(self, results: List[PipelineResult]) -> List[PipelineResult]: pass @abstractmethod - def filter_inputs( - self, results: List[PipelineResult] - ) -> List[PipelineResult]: + def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: """Filter which results can be processed by this stage.""" pass def log_stage_start(self, stage_name: str, input_count: int): logger.debug(f"{stage_name}: Starting with {input_count} inputs") - def log_stage_complete( - self, stage_name: str, success_count: int, total_count: int - ): + def log_stage_complete(self, stage_name: str, success_count: int, total_count: int): logger.debug(f"{stage_name}: {success_count}/{total_count} succeeded") def _get_base_meta_kwargs(self) -> Dict[str, str]: @@ -147,14 +143,10 @@ def __init__( tags: Dict[str, str], catchments: Dict[str, Dict[str, Any]], ): - super().__init__( - config, nomad_service, data_service, path_factory, tags - ) + super().__init__(config, nomad_service, data_service, path_factory, tags) self.catchments = catchments - def filter_inputs( - self, results: List[PipelineResult] - ) -> List[PipelineResult]: + def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: """All results are valid for inundation stage.""" return results @@ -171,21 +163,15 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: # Copy flowfile once per scenario flowfile_s3_path = await self.data_svc.copy_file_to_uri( result.flowfile_path, - self.path_factory.flowfile_path( - result.collection_name, result.scenario_name - ), - ) - logger.debug( - f"[{result.scenario_id}] Copied flowfile to S3: {flowfile_s3_path}" + self.path_factory.flowfile_path(result.collection_name, result.scenario_name), ) + logger.debug(f"[{result.scenario_id}] Copied flowfile to S3: {flowfile_s3_path}") for catch_id, catchment_info in self.catchments.items(): output_path = self.path_factory.inundation_output_path( result.collection_name, result.scenario_name, catch_id 
) - result.set_path( - "inundation", f"catchment_{catch_id}", output_path - ) + result.set_path("inundation", f"catchment_{catch_id}", output_path) task = asyncio.create_task( self._process_catchment( @@ -207,22 +193,17 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: # Group results by scenario and validate outputs scenario_outputs = {} lost_jobs = [] - for (result, catch_id, output_path), task_result in zip( - task_metadata, task_results - ): + for (result, catch_id, output_path), task_result in zip(task_metadata, task_results): if result.scenario_id not in scenario_outputs: scenario_outputs[result.scenario_id] = [] if not isinstance(task_result, Exception): scenario_outputs[result.scenario_id].append(output_path) else: - logger.error( - f"[{result.scenario_id}] Catchment {catch_id} inundation failed: {task_result}" - ) + logger.error(f"[{result.scenario_id}] Catchment {catch_id} inundation failed: {task_result}") # Check if this is a LOST job (job was purged from Nomad) error_str = str(task_result) - if ("Job lost" in error_str or - "URLNotFoundNomadException" in error_str): + if "Job lost" in error_str or "URLNotFoundNomadException" in error_str: lost_jobs.append(f"{result.scenario_id}/{catch_id}") # Validate files and update results @@ -232,29 +213,23 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: if outputs: valid_outputs = await self.data_svc.validate_files(outputs) if valid_outputs: - result.set_path( - "inundation", "valid_outputs", valid_outputs - ) + result.set_path("inundation", "valid_outputs", valid_outputs) result.status = "inundation_complete" updated_results.append(result) - logger.debug( - f"[{result.scenario_id}] {len(valid_outputs)}/{len(outputs)} inundation outputs valid" - ) + logger.debug(f"[{result.scenario_id}] {len(valid_outputs)}/{len(outputs)} inundation outputs valid") else: result.mark_failed("No valid inundation outputs") else: result.mark_failed("No inundation outputs 
produced") - self.log_stage_complete( - "Inundation", len(updated_results), len(valid_results) - ) - + self.log_stage_complete("Inundation", len(updated_results), len(valid_results)) + # Raise exception if any inundation jobs were lost if lost_jobs: error_msg = f"Inundation stage had {len(lost_jobs)} lost job(s): {', '.join(lost_jobs)}" logger.error(error_msg) raise NomadError(error_msg) - + return updated_results async def _process_catchment( @@ -277,9 +252,7 @@ async def _process_catchment( self.path_factory.catchment_parquet_path(catch_id), ) - meta = self._create_inundation_meta( - parquet_path, flowfile_s3_path, output_path - ) + meta = self._create_inundation_meta(parquet_path, flowfile_s3_path, output_path) # add bench_src, scenario, and catchment internal tags internal_tags = { @@ -294,9 +267,7 @@ async def _process_catchment( prefix=tags_str, meta=meta.model_dump(), ) - logger.debug( - f"[{result.scenario_id}/{catch_id}] inundator done → {job_id}" - ) + logger.debug(f"[{result.scenario_id}/{catch_id}] inundator done → {job_id}") return job_id def _create_inundation_meta( @@ -322,22 +293,13 @@ def __init__( tags: Dict[str, str], aoi_path: str, ): - super().__init__( - config, nomad_service, data_service, path_factory, tags - ) + super().__init__(config, nomad_service, data_service, path_factory, tags) self.aoi_path = aoi_path self._clip_geometry_s3_path = None - def filter_inputs( - self, results: List[PipelineResult] - ) -> List[PipelineResult]: + def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: """Filter results that have valid inundation outputs.""" - return [ - r - for r in results - if r.status == "inundation_complete" - and r.get_path("inundation", "valid_outputs") - ] + return [r for r in results if r.status == "inundation_complete" and r.get_path("inundation", "valid_outputs")] async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: """Run mosaic jobs for scenarios with valid inundation outputs.""" 
@@ -349,9 +311,7 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: self._clip_geometry_s3_path = await self.data_svc.copy_file_to_uri( self.aoi_path, self.path_factory.aoi_path() ) - logger.debug( - f"Copied AOI file to S3: {self._clip_geometry_s3_path}" - ) + logger.debug(f"Copied AOI file to S3: {self._clip_geometry_s3_path}") hand_tasks = [] benchmark_tasks = [] @@ -362,19 +322,13 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: benchmark_rasters = result.benchmark_rasters if not valid_outputs or not benchmark_rasters: - logger.warning( - f"[{result.scenario_id}] Skipping mosaic - missing inputs" - ) + logger.warning(f"[{result.scenario_id}] Skipping mosaic - missing inputs") continue # HAND mosaic - hand_output_path = self.path_factory.hand_mosaic_path( - result.collection_name, result.scenario_name - ) + hand_output_path = self.path_factory.hand_mosaic_path(result.collection_name, result.scenario_name) result.set_path("mosaic", "hand", hand_output_path) - hand_meta = self._create_mosaic_meta( - valid_outputs, hand_output_path - ) + hand_meta = self._create_mosaic_meta(valid_outputs, hand_output_path) # add cand_src and scenario internal tags for HAND mosaic hand_internal_tags = { @@ -400,9 +354,7 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: result.collection_name, result.scenario_name ) result.set_path("mosaic", "benchmark", benchmark_output_path) - benchmark_meta = self._create_mosaic_meta( - benchmark_rasters, benchmark_output_path - ) + benchmark_meta = self._create_mosaic_meta(benchmark_rasters, benchmark_output_path) # add bench_src and scenario internal tags for benchmark mosaic benchmark_internal_tags = { @@ -429,28 +381,20 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: return [] hand_results = await asyncio.gather(*hand_tasks, return_exceptions=True) - benchmark_results = await asyncio.gather( - *benchmark_tasks, 
return_exceptions=True - ) + benchmark_results = await asyncio.gather(*benchmark_tasks, return_exceptions=True) # Update results based on success/failure successful_results = [] failed_scenarios = [] - for result, hand_result, benchmark_result in zip( - task_results, hand_results, benchmark_results - ): + for result, hand_result, benchmark_result in zip(task_results, hand_results, benchmark_results): hand_failed = isinstance(hand_result, Exception) benchmark_failed = isinstance(benchmark_result, Exception) if hand_failed: - logger.error( - f"[{result.scenario_id}] HAND mosaic failed: {hand_result}" - ) + logger.error(f"[{result.scenario_id}] HAND mosaic failed: {hand_result}") if benchmark_failed: - logger.error( - f"[{result.scenario_id}] Benchmark mosaic failed: {benchmark_result}" - ) + logger.error(f"[{result.scenario_id}] Benchmark mosaic failed: {benchmark_result}") if not hand_failed and not benchmark_failed: result.status = "mosaic_complete" @@ -463,13 +407,9 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: if benchmark_failed: failure_type.append("benchmark") result.mark_failed("One or both mosaics failed") - failed_scenarios.append( - f"{result.scenario_id} ({', '.join(failure_type)})" - ) + failed_scenarios.append(f"{result.scenario_id} ({', '.join(failure_type)})") - self.log_stage_complete( - "Mosaic", len(successful_results), len(valid_results) - ) + self.log_stage_complete("Mosaic", len(successful_results), len(valid_results)) # Raise exception if any mosaic jobs failed if failed_scenarios: @@ -479,9 +419,7 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: return successful_results - def _create_mosaic_meta( - self, raster_paths: List[str], output_path: str - ) -> MosaicDispatchMeta: + def _create_mosaic_meta(self, raster_paths: List[str], output_path: str) -> MosaicDispatchMeta: return MosaicDispatchMeta( raster_paths=raster_paths, output_path=output_path, @@ -493,9 +431,7 @@ def 
_create_mosaic_meta( class AgreementStage(PipelineStage): """Stage that creates agreement maps from HAND and benchmark mosaics.""" - def filter_inputs( - self, results: List[PipelineResult] - ) -> List[PipelineResult]: + def filter_inputs(self, results: List[PipelineResult]) -> List[PipelineResult]: """Filter results that have valid mosaic outputs.""" return [r for r in results if r.status == "mosaic_complete"] @@ -511,12 +447,8 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: hand_mosaic = result.get_path("mosaic", "hand") benchmark_mosaic = result.get_path("mosaic", "benchmark") - agreement_output_path = self.path_factory.agreement_map_path( - result.collection_name, result.scenario_name - ) - metrics_output_path = self.path_factory.agreement_metrics_path( - result.collection_name, result.scenario_name - ) + agreement_output_path = self.path_factory.agreement_map_path(result.collection_name, result.scenario_name) + metrics_output_path = self.path_factory.agreement_metrics_path(result.collection_name, result.scenario_name) result.set_path("agreement", "map", agreement_output_path) result.set_path("agreement", "metrics", metrics_output_path) @@ -561,18 +493,14 @@ async def run(self, results: List[PipelineResult]) -> List[PipelineResult]: for result, job_result in zip(task_results, task_job_results): if isinstance(job_result, Exception): result.mark_failed(f"Agreement job failed: {job_result}") - logger.error( - f"[{result.scenario_id}] Agreement job failed: {job_result}" - ) + logger.error(f"[{result.scenario_id}] Agreement job failed: {job_result}") failed_scenarios.append(result.scenario_id) else: result.mark_completed() successful_results.append(result) logger.debug(f"[{result.scenario_id}] Pipeline complete") - self.log_stage_complete( - "Agreement", len(successful_results), len(valid_results) - ) + self.log_stage_complete("Agreement", len(successful_results), len(valid_results)) # Raise exception if any agreement jobs failed if 
failed_scenarios: