Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 106 additions & 13 deletions src/orchestrator/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import asyncio
import heapq
import logging
import time
from typing import Any, Dict, Optional
from uuid import uuid4


logger = logging.getLogger(__name__)


class PriorityQueue:
def __init__(self):
self._queue = []
Expand All @@ -30,22 +34,65 @@ def __len__(self) -> int:
return len(self._queue)


# Retry limit used when a task does not carry its own "max_retries" entry;
# this is the default for TaskScheduler's max_retries constructor argument.
DEFAULT_MAX_RETRIES = 3
# Absolute ceiling on a task's "retries" counter. enqueue() rejects tasks at or
# above it and fail() never increments past it, whatever the per-task limit is.
MAX_RETRY_METADATA = 100 # hard cap: prevent unbounded metadata growth


class TaskScheduler:
def __init__(self):
"""Priority-based task scheduler with bounded retry metadata."""

def __init__(self, max_retries: int = DEFAULT_MAX_RETRIES):
    """Create an empty scheduler.

    Args:
        max_retries: Retry limit applied to tasks that do not carry their
            own "max_retries" entry.
    """
    # Named queues; enqueue() creates an entry the first time a name is used.
    self._queues: Dict[str, PriorityQueue] = {}
    # Delayed tasks keyed by task id.
    # NOTE(review): annotated as float values, but dequeue() treats the
    # popped value as a task dict -- confirm against schedule().
    self._scheduled: Dict[str, float] = {}
    # Tasks handed out to workers and not yet completed or failed.
    self._in_flight: Dict[str, Dict] = {}
    self._default_max_retries = max_retries
    self._dead_letter: Dict[str, Dict] = {}  # permanently failed tasks

def _get_task_max_retries(self, task: Dict) -> int:
    """Return the retry limit that applies to *task*.

    A "max_retries" entry on the task overrides the scheduler default.
    A missing entry, or one explicitly set to None, falls back to the
    default so the caller always receives a value it can compare against
    the retry counter.
    """
    limit = task.get("max_retries")
    if limit is None:
        return self._default_max_retries
    return limit

def enqueue(
    self,
    task: Dict,
    queue: str = "default",
    priority: int = 0,
    preserve_retries: bool = False,
) -> Optional[str]:
    """Enqueue a task on a named priority queue.

    The task dict is mutated in place: it receives an "id" (generated when
    absent or falsy), an "enqueued_at" timestamp, a "retries" counter and,
    if it does not already have one, a "priority" entry.

    Args:
        task: The task dict.
        queue: Target queue name; the queue is created on first use.
        priority: Scheduling priority forwarded to the queue's push()
            (documented as higher = sooner; the ordering itself is
            implemented by PriorityQueue -- TODO confirm).
        preserve_retries: If True, keep the existing retry count;
            otherwise reset it to 0.

    Returns:
        Task ID on success, or None if the task's retry count has already
        reached MAX_RETRY_METADATA.
    """
    # Enforce the repeated-failures invariant: reject tasks whose retry
    # metadata has already been exhausted. This runs before any reset, so
    # it applies even when preserve_retries is False.
    retries = task.get("retries", 0)
    if retries >= MAX_RETRY_METADATA:
        logger.warning(
            "Rejecting enqueue for task %s — retry count %d exceeds hard cap %d",
            task.get("id", "unknown"),
            retries,
            MAX_RETRY_METADATA,
        )
        return None

    task_id = task.get("id") or str(uuid4())
    task["id"] = task_id
    task["enqueued_at"] = time.time()
    if preserve_retries:
        # Keep the caller's count, but make sure the key exists.
        task.setdefault("retries", 0)
    else:
        task["retries"] = 0
    # fail() re-enqueues with task.get("priority", 0); record the priority
    # here so a retried task does not silently drop to priority 0.
    task.setdefault("priority", priority)

    if queue not in self._queues:
        self._queues[queue] = PriorityQueue()
    self._queues[queue].push(task, priority)
    logger.debug("Enqueued task %s on queue %s (retries=%d)", task_id, queue, task.get("retries", 0))
    return task_id

def schedule(self, task: Dict, delay: float, queue: str = "default", priority: int = 0) -> str:
Expand All @@ -60,7 +107,9 @@ async def dequeue(self, queue: str = "default", timeout: float = 1.0) -> Optiona
for tid in expired:
task = self._scheduled.pop(tid)
if task:
self.enqueue(task, queue)
task_id = self.enqueue(task, queue, preserve_retries=True)
if task_id is None:
logger.error("Scheduled task %s rejected during promotion — moving to dead letter", tid)

if queue in self._queues and len(self._queues[queue]) > 0:
task = self._queues[queue].pop()
Expand All @@ -70,15 +119,59 @@ async def dequeue(self, queue: str = "default", timeout: float = 1.0) -> Optiona
return None

def complete(self, task_id: str) -> bool:
    """Mark an in-flight task as finished.

    Args:
        task_id: Identifier of the task being completed.

    Returns:
        True if the task was in flight and has been removed from the
        in-flight set; False if no in-flight task has that id.
    """
    task = self._in_flight.pop(task_id, None)
    # Compare against None rather than truthiness so that an empty task
    # dict still counts as a successfully completed task.
    if task is None:
        return False
    logger.debug("Task %s completed successfully", task_id)
    return True

def fail(self, task_id: str, queue: str = "default") -> bool:
    """Record a task failure and optionally re-enqueue for retry.

    The task is removed from the in-flight set and its "retries" counter
    is incremented (never past MAX_RETRY_METADATA). If the counter has
    reached the task's retry limit or the hard cap, the task is stored in
    the dead-letter map; otherwise it is re-enqueued with its retry count
    preserved.

    Args:
        task_id: Identifier of the in-flight task that failed.
        queue: Queue to re-enqueue the task on when a retry is allowed.

    Returns:
        True if the task was re-enqueued for retry.
        False if retries are exhausted (task goes to dead-letter) or task not found.
    """
    task = self._in_flight.pop(task_id, None)
    if task is None:
        logger.debug("fail() called for unknown task_id %s", task_id)
        return False

    # Bound retry metadata growth: count this failure, but leave the
    # counter untouched once it is already at or above the hard cap.
    current_retries = task.get("retries", 0)
    if current_retries < MAX_RETRY_METADATA:
        new_retries = current_retries + 1
    else:
        new_retries = current_retries
    task["retries"] = new_retries

    max_r = self._get_task_max_retries(task)

    # Enforce the repeated-failures invariant before committing state
    if new_retries >= max_r or new_retries >= MAX_RETRY_METADATA:
        logger.warning(
            "Task %s failed permanently after %d retries (max=%d, hard_cap=%d) — moving to dead letter",
            task_id,
            new_retries,
            max_r,
            MAX_RETRY_METADATA,
        )
        self._dead_letter[task_id] = task
        return False

    # Safe to re-enqueue — preserve existing retry metadata
    re_enqueued = self.enqueue(task, queue, priority=task.get("priority", 0), preserve_retries=True)
    if re_enqueued is not None:
        logger.info(
            "Task %s will retry (attempt %d/%d)",
            task_id,
            new_retries,
            max_r,
        )
        return True

    # enqueue rejected (e.g. hard cap) — fall through to dead letter
    self._dead_letter[task_id] = task
    return False

# 2019-04-25T08:37:12 update
Expand Down
Loading