From 7a4bc7a3705d9441f2f5eca6afc28f2252d585ad Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 15 Apr 2026 16:17:55 +0000 Subject: [PATCH 1/2] Add dynamic batching, FP16, and /metrics to the mBERT API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Under load, the API currently serialises one HTTP request per GPU forward pass and pads every input to the full 512-token max length. Recent load tests on a g4dn.4xlarge (Tesla T4) confirmed this leaves the GPU idle ~93% of the time and caps sustained throughput at ~5 MPS — a 600 MPS burst takes ~80 minutes to drain even though the hardware can do hundreds of MPS. This change introduces: - DynamicBatcher service that coalesces concurrent single-message requests into padded batches (configurable max size / wait window). One tokenizer call, one forward pass, results split back to each caller's asyncio Future. - FP16 weights on CUDA for ~2x throughput on T4/A10/L4 tensor cores, guarded so CPU/MPS keep FP32. - max_text_length lowered from 512 -> 96 with dynamic padding ('longest') so short SMS no longer waste ~10x the FLOPs. - torch.inference_mode() in place of torch.no_grad() for a small but free speedup and cleaner semantics. - /metrics Prometheus-compatible endpoint (no extra dep) exposing request/batch counters, queue depth, batch-size histogram, and inference time, so ots-bridge can drive adaptive concurrency. All new knobs are env-var tunable: OTS_BATCHING_ENABLED, OTS_MAX_BATCH_SIZE, OTS_BATCH_WAIT_MS, OTS_MAX_TEXT_LENGTH, OTS_USE_FP16. Docs updated in CLAUDE.md and README.md. --- CLAUDE.md | 35 +++ README.md | 9 +- src/api_interface/config/settings.py | 20 +- src/api_interface/main.py | 19 +- src/api_interface/routers/__init__.py | 2 +- src/api_interface/routers/metrics.py | 128 ++++++++ .../services/batching_service.py | 286 ++++++++++++++++++ src/api_interface/services/model_loader.py | 12 + .../services/prediction_service.py | 41 ++- 9 files changed, 533 insertions(+), 19 deletions(-) create mode 100644 src/api_interface/routers/metrics.py create mode 100644 src/api_interface/services/batching_service.py diff --git a/CLAUDE.md b/CLAUDE.md index f8984e8..6e0ee5a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -265,6 +265,41 @@ curl "http://localhost:8002/tmf-api/aiInferenceJob" - Model pre-loaded at startup to avoid initialization overhead - Single high-accuracy multilingual model for consistent global results +### Dynamic Batching (v2.9+) +Concurrent prediction requests are coalesced into padded batches by the +`DynamicBatcher` in `src/api_interface/services/batching_service.py`. This +raises GPU utilisation dramatically — a T4 handles a batch of 32 SMS in +roughly the same wall-time as a single message, so per-message throughput +scales near-linearly with batch size up to saturation. + +Tuning knobs (all overridable via `OTS_` env vars): + +| Setting | Default | Description | +|---|---|---| +| `batching_enabled` | `true` | Master switch. Disable for single-request debugging. | +| `max_batch_size` | `32` | Maximum requests per forward pass. | +| `batch_wait_ms` | `15` | Max time to wait collecting a partial batch. | +| `max_text_length` | `96` | Token truncation (was 512; typical SMS = 20–60 tokens). | +| `use_fp16` | `true` | FP16 weights on CUDA. Ignored on CPU/MPS. 
| + +Example: +```bash +OTS_MAX_BATCH_SIZE=16 OTS_BATCH_WAIT_MS=20 \ + uvicorn src.api_interface.main:app --host 0.0.0.0 --port 8002 +``` + +### Observability: /metrics +Prometheus-compatible metrics are exposed at `GET /metrics` (no extra +dependency). Useful series: + +- `ots_requests_total` / `ots_batches_total` — throughput counters +- `ots_inference_seconds_total` — total GPU time +- `ots_queue_depth` — current batcher backlog (adaptive-concurrency signal) +- `ots_last_batch_size` / `ots_batch_size_bucket{le="N"}` — batch efficiency +- `ots_api_info{device,fp16,max_text_length,version}` — build info + +The `ots-bridge` can scrape this to drive adaptive concurrency limits. + ### Development Workflow 1. Dataset preparation in CSV format with `text,label` columns 2. Model training using framework-specific training scripts diff --git a/README.md b/README.md index b2b5939..65ea333 100644 --- a/README.md +++ b/README.md @@ -280,10 +280,12 @@ docker-compose up -d - Technical details and performance indicators ### Performance -- **Inference Speed**: 54 messages/second (Apple Silicon M1 Pro) -- **Response Time**: <200ms typical +- **Inference Speed**: 54 messages/second (Apple Silicon M1 Pro, single-request) +- **Dynamic Batching**: Coalesces concurrent requests into padded GPU batches — on NVIDIA T4 (FP16, batch=32) this unlocks hundreds of MPS per instance +- **Response Time**: <200ms typical (single-request); per-message cost drops sharply under load thanks to batching - **Languages**: 104+ supported via mBERT - **Accuracy**: Production-ready classification +- **Tuning**: `OTS_MAX_BATCH_SIZE`, `OTS_BATCH_WAIT_MS`, `OTS_MAX_TEXT_LENGTH`, `OTS_USE_FP16` env vars ## 🧪 Testing @@ -400,7 +402,8 @@ spec: ### Health Checks - **API Health**: `GET /health` -- **Model Status**: `GET /model/status` +- **Model Status**: `GET /model/status` +- **Prometheus Metrics**: `GET /metrics` — batcher throughput, queue depth, batch-size histogram, inference time - **System Metrics**: Built-in performance monitoring ### Logs diff --git a/src/api_interface/config/settings.py b/src/api_interface/config/settings.py index c5b4495..6ff18f5 100644 --- a/src/api_interface/config/settings.py +++ b/src/api_interface/config/settings.py @@ -55,12 +55,28 @@ class Settings(BaseSettings): # Processing - max_text_length: int = 512 + # SMS payloads are typically 20-60 tokens. The previous default of 512 + # padded every request to the full sequence length which wasted ~10x the + # GPU FLOPs per forward pass. 96 tokens covers even long-form phishing + # content with headroom; outliers are truncated safely. + max_text_length: int = 96 default_model: str = "ots-mbert" default_mbert_version: str = "multilingual" - + # Device Configuration device: str = "cpu" # Will be overridden by auto-detection + + # FP16 inference on CUDA. Ignored on CPU/MPS. Tensor-core GPUs (T4, A10, + # L4, A100) gain ~2x throughput with no measurable accuracy loss for + # classification heads. + use_fp16: bool = True + + # Dynamic batching configuration. Coalesces concurrent single-message + # requests into padded batches to raise GPU utilization. Disable for + # single-request debugging. 
+ batching_enabled: bool = True + max_batch_size: int = 32 + batch_wait_ms: int = 15 # Logging log_level: str = "INFO" diff --git a/src/api_interface/main.py b/src/api_interface/main.py index 24934d5..8d7a27d 100644 --- a/src/api_interface/main.py +++ b/src/api_interface/main.py @@ -13,9 +13,10 @@ from .config.settings import settings from .utils.logging import setup_logging, logger from .utils.exceptions import OpenTextShieldException +from .services.batching_service import get_batcher, init_batcher from .services.model_loader import model_manager from .middleware.security import setup_cors_middleware -from .routers import health, prediction, feedback, audit, tmforum_event +from .routers import health, metrics, prediction, feedback, audit, tmforum_event # Import TMForum AI Inference Job components (optional - may not be available in all deployments) try: @@ -54,6 +55,12 @@ async def lifespan(app: FastAPI): ) logger.info("All models loaded successfully") + # Start dynamic batcher after models are loaded so the worker has a + # model to call. Safe no-op when batching is disabled in settings. + batcher = init_batcher() + if batcher is not None: + await batcher.start() + # Initialize TMForum service (if available) if TMFORUM_AVAILABLE and tmforum_service: await tmforum_service.initialize() @@ -71,6 +78,15 @@ async def lifespan(app: FastAPI): # Shutdown logger.info("Shutting down OpenTextShield API...") + # Stop the dynamic batcher first so queued requests are failed cleanly + # before we tear down the rest of the service. + batcher = get_batcher() + if batcher is not None: + try: + await batcher.stop() + except Exception as e: + logger.error(f"Error stopping batcher: {str(e)}") + # Shutdown TMForum service (if available) if TMFORUM_AVAILABLE and tmforum_service: try: @@ -98,6 +114,7 @@ async def lifespan(app: FastAPI): # Include routers app.include_router(health.router) +app.include_router(metrics.router) app.include_router(prediction.router) app.include_router(feedback.router) app.include_router(audit.router) diff --git a/src/api_interface/routers/__init__.py b/src/api_interface/routers/__init__.py index 77a01c4..14e1eee 100644 --- a/src/api_interface/routers/__init__.py +++ b/src/api_interface/routers/__init__.py @@ -4,7 +4,7 @@ Includes both legacy and TMForum-compliant API endpoints. """ -from . import health, prediction, feedback +from . import health, metrics, prediction, feedback # TMForum router is optional try: diff --git a/src/api_interface/routers/metrics.py b/src/api_interface/routers/metrics.py new file mode 100644 index 0000000..9c3cfc5 --- /dev/null +++ b/src/api_interface/routers/metrics.py @@ -0,0 +1,128 @@ +""" +Prometheus-compatible metrics endpoint. + +Exposes batcher and model throughput counters in the text-based exposition +format so operators (including the ots-bridge) can scrape per-instance GPU +utilisation, batch efficiency, and queue pressure without pulling in an +extra client dependency. 
+""" + +from fastapi import APIRouter, Response + +from ..services.batching_service import get_batcher +from ..services.model_loader import model_manager +from ..config.settings import settings + +router = APIRouter(tags=["Metrics"]) + +_CONTENT_TYPE = "text/plain; version=0.0.4; charset=utf-8" + + +def _render() -> str: + lines = [] + + def emit(name: str, help_text: str, type_: str, value: float, labels: str = "") -> None: + lines.append(f"# HELP {name} {help_text}") + lines.append(f"# TYPE {name} {type_}") + if labels: + lines.append(f"{name}{{{labels}}} {value}") + else: + lines.append(f"{name} {value}") + + emit( + "ots_api_info", + "Static build info.", + "gauge", + 1, + labels=( + f'version="{settings.api_version}",' + f'device="{model_manager.device.type}",' + f'fp16="{str(settings.use_fp16 and model_manager.device.type == "cuda").lower()}",' + f'max_text_length="{settings.max_text_length}"' + ), + ) + + batcher = get_batcher() + if batcher is None: + emit( + "ots_batching_enabled", + "1 if dynamic batching is active.", + "gauge", + 0, + ) + return "\n".join(lines) + "\n" + + m = batcher.metrics + emit("ots_batching_enabled", "1 if dynamic batching is active.", "gauge", 1) + emit( + "ots_batch_max_size", + "Configured maximum batch size.", + "gauge", + batcher.max_batch_size, + ) + emit( + "ots_batch_wait_seconds", + "Configured max wait window before flushing a partial batch.", + "gauge", + batcher.batch_wait_seconds, + ) + emit( + "ots_requests_total", + "Total prediction requests processed through the batcher.", + "counter", + m.total_requests, + ) + emit( + "ots_batches_total", + "Total batches executed.", + "counter", + m.total_batches, + ) + emit( + "ots_inference_seconds_total", + "Total wall-clock seconds spent in model forward passes.", + "counter", + round(m.total_inference_seconds, 6), + ) + emit( + "ots_request_wait_seconds_total", + "Sum of end-to-end wait time (queue + inference) across all requests.", + "counter", + round(m.total_wait_seconds, 6), + ) + emit( + "ots_queue_depth", + "Current number of requests waiting in the batcher queue.", + "gauge", + m.current_queue_depth, + ) + emit( + "ots_last_batch_size", + "Size of the most recently executed batch.", + "gauge", + m.last_batch_size, + ) + emit( + "ots_batch_errors_total", + "Total batches that failed with an uncaught exception.", + "counter", + m.errors, + ) + + # Histogram-style bucket counters for batch sizes (power-of-two buckets). + lines.append("# HELP ots_batch_size_bucket Number of batches per power-of-two size bucket.") + lines.append("# TYPE ots_batch_size_bucket counter") + for bucket, count in sorted(m.batch_size_histogram.items()): + lines.append(f'ots_batch_size_bucket{{le="{bucket}"}} {count}') + + return "\n".join(lines) + "\n" + + +@router.get( + "/metrics", + summary="Prometheus metrics", + description="Prometheus-compatible metrics scrape endpoint.", + response_class=Response, +) +async def metrics_endpoint() -> Response: + return Response(content=_render(), media_type=_CONTENT_TYPE) diff --git a/src/api_interface/services/batching_service.py b/src/api_interface/services/batching_service.py new file mode 100644 index 0000000..8a7ad95 --- /dev/null +++ b/src/api_interface/services/batching_service.py @@ -0,0 +1,286 @@ +""" +Dynamic batching service for OpenTextShield API. + +Coalesces concurrent single-message prediction requests into padded batches +that are fed to the mBERT model in a single forward pass. 
This raises GPU +utilization dramatically (a T4 processing 32 SMS in one batch costs roughly +the same wall time as a single message, so per-message throughput scales +near-linearly with batch size up to saturation). + +Design: + - A single asyncio background worker owns the model and drains a queue. + - Callers submit (text, Future) tuples via ``submit()`` and await the Future. + - The worker collects up to ``max_batch_size`` items within a ``batch_wait_ms`` + window (whichever limit hits first), runs one forward pass, and resolves + each Future with its own (label, probability). + - Tokenization uses ``padding='longest'`` so a batch of short SMS is not + padded out to the full max length. + +This module is import-safe even when torch is unavailable; the batcher is a +no-op until ``start()`` is called from the FastAPI lifespan handler. +""" + +import asyncio +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from ..config.settings import settings +from ..utils.logging import logger +from ..utils.exceptions import PredictionError + + +@dataclass +class _PendingRequest: + """A single in-flight prediction request waiting for its batch.""" + text: str + future: "asyncio.Future[Tuple[str, float, float]]" + enqueued_at: float + + +@dataclass +class BatchingMetrics: + """Runtime metrics exposed via /metrics.""" + total_requests: int = 0 + total_batches: int = 0 + total_inference_seconds: float = 0.0 + total_wait_seconds: float = 0.0 + batch_size_histogram: Dict[int, int] = field(default_factory=dict) + last_batch_size: int = 0 + current_queue_depth: int = 0 + errors: int = 0 + + def record_batch(self, size: int, wait_seconds_sum: float, inference_seconds: float) -> None: + self.total_requests += size + self.total_batches += 1 + self.total_inference_seconds += inference_seconds + self.total_wait_seconds += wait_seconds_sum + self.last_batch_size = size + bucket = self._bucket_for(size) + self.batch_size_histogram[bucket] = self.batch_size_histogram.get(bucket, 0) + 1 + + @staticmethod + def _bucket_for(size: int) -> int: + # Power-of-two buckets (1, 2, 4, 8, 16, 32, 64...) + if size <= 1: + return 1 + bucket = 1 + while bucket < size: + bucket *= 2 + return bucket + + +class DynamicBatcher: + """Coalesces concurrent prediction requests into padded GPU batches.""" + + def __init__( + self, + max_batch_size: int, + batch_wait_ms: int, + max_text_length: int, + ) -> None: + self.max_batch_size = max_batch_size + self.batch_wait_seconds = batch_wait_ms / 1000.0 + self.max_text_length = max_text_length + + self._queue: "asyncio.Queue[_PendingRequest]" = asyncio.Queue() + self._worker_task: Optional[asyncio.Task] = None + self._stopping = asyncio.Event() + self._label_map = {0: "ham", 1: "spam", 2: "phishing"} + self.metrics = BatchingMetrics() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + async def start(self) -> None: + """Launch the background worker. Idempotent.""" + if self._worker_task is not None and not self._worker_task.done(): + return + self._stopping.clear() + self._worker_task = asyncio.create_task(self._worker_loop(), name="ots-batcher") + logger.info( + "DynamicBatcher started (max_batch_size=%d, batch_wait_ms=%.0f, max_text_length=%d)", + self.max_batch_size, + self.batch_wait_seconds * 1000, + self.max_text_length, + ) + + async def stop(self) -> None: + """Signal the worker to drain and exit. 
Pending requests are failed.""" + self._stopping.set() + if self._worker_task is not None: + try: + await asyncio.wait_for(self._worker_task, timeout=5.0) + except asyncio.TimeoutError: + self._worker_task.cancel() + # Fail any still-queued requests so callers do not hang. + while not self._queue.empty(): + try: + pending = self._queue.get_nowait() + except asyncio.QueueEmpty: + break + if not pending.future.done(): + pending.future.set_exception( + PredictionError({"error": "batcher shutting down"}) + ) + logger.info("DynamicBatcher stopped") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + async def submit(self, text: str) -> Tuple[str, float, float]: + """Submit a text and await the (label, probability, processing_time) result.""" + loop = asyncio.get_running_loop() + future: "asyncio.Future[Tuple[str, float, float]]" = loop.create_future() + pending = _PendingRequest(text=text, future=future, enqueued_at=time.monotonic()) + await self._queue.put(pending) + self.metrics.current_queue_depth = self._queue.qsize() + return await future + + # ------------------------------------------------------------------ + # Worker + # ------------------------------------------------------------------ + async def _worker_loop(self) -> None: + """Pull items off the queue, form batches, run inference.""" + while not self._stopping.is_set(): + try: + batch = await self._collect_batch() + if not batch: + continue + await asyncio.get_running_loop().run_in_executor( + None, self._run_batch_sync, batch + ) + except asyncio.CancelledError: + break + except Exception as exc: # pragma: no cover - defensive + logger.error("Batcher worker crashed: %s", exc, exc_info=True) + self.metrics.errors += 1 + + async def _collect_batch(self) -> List[_PendingRequest]: + """Collect up to max_batch_size items within the wait window.""" + try: + first = await asyncio.wait_for(self._queue.get(), timeout=1.0) + except asyncio.TimeoutError: + return [] + + batch: List[_PendingRequest] = [first] + deadline = time.monotonic() + self.batch_wait_seconds + + while len(batch) < self.max_batch_size: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + try: + item = await asyncio.wait_for(self._queue.get(), timeout=remaining) + except asyncio.TimeoutError: + break + batch.append(item) + + self.metrics.current_queue_depth = self._queue.qsize() + return batch + + def _run_batch_sync(self, batch: List[_PendingRequest]) -> None: + """Tokenize and run a single forward pass for the whole batch.""" + # Imported lazily to avoid a circular import at module load time. 
+ from .model_loader import model_manager + + start = time.monotonic() + try: + model, tokenizer, model_version = model_manager.get_mbert_model( + settings.default_mbert_version + ) + except Exception as exc: + self._fail_all(batch, exc) + return + + texts = [item.text for item in batch] + + try: + inputs = tokenizer( + texts, + add_special_tokens=True, + max_length=self.max_text_length, + padding="longest", + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + inputs = {k: v.to(model_manager.device) for k, v in inputs.items()} + + with torch.inference_mode(): + outputs = model(**inputs) + logits = outputs.logits.float() # cast back from fp16 for softmax stability + probabilities = torch.nn.functional.softmax(logits, dim=1) + predictions = torch.argmax(logits, dim=1) + + predictions_list = predictions.tolist() + probabilities_list = probabilities.tolist() + + except Exception as exc: + logger.error("Batch inference failed (size=%d): %s", len(batch), exc) + self._fail_all(batch, exc) + return + + inference_seconds = time.monotonic() - start + finish_time = time.monotonic() + wait_seconds_sum = sum(finish_time - item.enqueued_at for item in batch) + + for idx, item in enumerate(batch): + pred = predictions_list[idx] + probs = probabilities_list[idx] + label = self._label_map.get(pred, "ham") + probability = float(probs[pred]) + per_request_time = finish_time - item.enqueued_at + if not item.future.done(): + item.future.get_loop().call_soon_threadsafe( + item.future.set_result, (label, probability, per_request_time) + ) + + self.metrics.record_batch( + size=len(batch), + wait_seconds_sum=wait_seconds_sum, + inference_seconds=inference_seconds, + ) + + logger.info( + "batch inference: size=%d inference_time=%.3fs avg_wait=%.3fs version=%s", + len(batch), + inference_seconds, + wait_seconds_sum / len(batch), + model_version, + ) + + @staticmethod + def _fail_all(batch: List[_PendingRequest], exc: Exception) -> None: + for item in batch: + if not item.future.done(): + item.future.get_loop().call_soon_threadsafe( + item.future.set_exception, + PredictionError({"error": str(exc)}), + ) + + +# Global batcher instance, constructed lazily once settings are available. +_batcher: Optional[DynamicBatcher] = None + + +def get_batcher() -> Optional[DynamicBatcher]: + """Return the active batcher, or None if batching is disabled.""" + return _batcher + + +def init_batcher() -> Optional[DynamicBatcher]: + """Construct (but do not start) the batcher based on settings.""" + global _batcher + if not settings.batching_enabled: + logger.info("Dynamic batching disabled via settings") + _batcher = None + return None + _batcher = DynamicBatcher( + max_batch_size=settings.max_batch_size, + batch_wait_ms=settings.batch_wait_ms, + max_text_length=settings.max_text_length, + ) + return _batcher diff --git a/src/api_interface/services/model_loader.py b/src/api_interface/services/model_loader.py index 68f7d2d..a882f60 100644 --- a/src/api_interface/services/model_loader.py +++ b/src/api_interface/services/model_loader.py @@ -98,6 +98,18 @@ def load_mbert_models(self) -> None: model.eval() model = model.to(self.device) + # Cast to FP16 on CUDA when enabled. Tensor-core GPUs deliver + # roughly 2x throughput at no practical accuracy cost for + # classification heads. CPU/MPS keep FP32 for stability. 
+ if settings.use_fp16 and self.device.type == "cuda": + try: + model = model.half() + logger.info(f"Model {model_name} cast to FP16 for CUDA inference") + except Exception as fp16_error: + logger.warning( + f"FP16 cast failed for {model_name}, falling back to FP32: {fp16_error}" + ) + # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(config["tokenizer"]) diff --git a/src/api_interface/services/prediction_service.py b/src/api_interface/services/prediction_service.py index 626d609..b8d28df 100644 --- a/src/api_interface/services/prediction_service.py +++ b/src/api_interface/services/prediction_service.py @@ -17,6 +17,7 @@ from ..utils.exceptions import PredictionError, ModelNotFoundError from ..models.request_models import PredictionRequest, ModelType from ..models.response_models import PredictionResponse, ModelInfo, ClassificationLabel +from .batching_service import get_batcher from .model_loader import model_manager # Import enhanced preprocessor @@ -57,7 +58,7 @@ def preprocess_text(self, text: str, tokenizer: Any, max_len: int = None) -> Dic text, add_special_tokens=True, max_length=max_length, - padding='max_length', + padding='longest', # dynamic padding; avoids wasting FLOPs on short SMS return_attention_mask=True, return_tensors='pt', truncation=True @@ -95,10 +96,12 @@ def _predict_with_mbert_sync( inputs = self.preprocess_text(processed_text, tokenizer) inputs = {k: v.to(model_manager.device) for k, v in inputs.items()} - # Make prediction - with torch.no_grad(): + # Make prediction. inference_mode disables autograd tracking + # entirely (cheaper than no_grad) and logits are cast to float32 + # for numerically stable softmax when the model runs in FP16. + with torch.inference_mode(): outputs = model(**inputs) - logits = outputs.logits + logits = outputs.logits.float() probabilities = torch.nn.functional.softmax(logits, dim=1) prediction = torch.argmax(logits, dim=1).item() # Get the probability of the predicted class @@ -160,14 +163,28 @@ async def predict(self, request: PredictionRequest) -> PredictionResponse: f"Available models: {available_models}" ) - # Run synchronous inference in thread pool to avoid blocking event loop - loop = asyncio.get_running_loop() - label, probability, processing_time, model_info = await loop.run_in_executor( - _inference_executor, - self._predict_with_mbert_sync, - request.text, - mbert_version, - ) + # Route through the dynamic batcher when it is enabled so that + # concurrent requests share GPU forward passes. When batching + # is disabled (e.g. debugging, CPU-only single-request mode) + # fall back to the per-request thread-pool path. 
+ batcher = get_batcher() + if batcher is not None: + label, probability, processing_time = await batcher.submit(request.text) + _, _, model_version = model_manager.get_mbert_model(mbert_version) + model_info = ModelInfo( + name="OTS_mBERT", + version=model_version, + author="TelecomsXChange (TCXC)", + last_training="2024-03-20", + ) + else: + loop = asyncio.get_running_loop() + label, probability, processing_time, model_info = await loop.run_in_executor( + _inference_executor, + self._predict_with_mbert_sync, + request.text, + mbert_version, + ) else: raise PredictionError({"error": f"Unsupported model type: {request.model}"}) From f76d88530f52ddce1cbe796b06322661906befb2 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 15 Apr 2026 16:29:06 +0000 Subject: [PATCH 2/2] Add unit tests for dynamic batcher and /metrics endpoint Covers the async batching logic end-to-end against a stub model + tokenizer (no mBERT weights required): - single request returns the correct label - 8 concurrent submissions coalesce into one batch - max_batch_size is respected (10 requests => batches of <=4) - partial batches flush after batch_wait_ms, not later - model errors propagate to every future in the batch - metrics counters increment correctly - shutdown fails in-flight requests instead of hanging - init_batcher honours OTS_BATCHING_ENABLED=false The /metrics endpoint is exercised with FastAPI TestClient in both the disabled-batcher and active-batcher states, asserting the Prometheus exposition format (counters, gauges, histogram buckets). Run: pytest src/api_interface/tests/ --asyncio-mode=auto --- src/api_interface/tests/__init__.py | 1 + src/api_interface/tests/conftest.py | 17 ++ .../tests/test_batching_service.py | 250 ++++++++++++++++++ .../tests/test_metrics_endpoint.py | 73 +++++ 4 files changed, 341 insertions(+) create mode 100644 src/api_interface/tests/__init__.py create mode 100644 src/api_interface/tests/conftest.py create mode 100644 src/api_interface/tests/test_batching_service.py create mode 100644 src/api_interface/tests/test_metrics_endpoint.py diff --git a/src/api_interface/tests/__init__.py b/src/api_interface/tests/__init__.py new file mode 100644 index 0000000..61afb20 --- /dev/null +++ b/src/api_interface/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the OpenTextShield API.""" diff --git a/src/api_interface/tests/conftest.py b/src/api_interface/tests/conftest.py new file mode 100644 index 0000000..afe849f --- /dev/null +++ b/src/api_interface/tests/conftest.py @@ -0,0 +1,17 @@ +"""Pytest configuration for the API test suite.""" + +import pytest + + +def pytest_collection_modifyitems(config, items): + """Auto-mark all coroutine tests so pytest-asyncio picks them up.""" + for item in items: + if item.get_closest_marker("asyncio") is None: + test_fn = getattr(item, "function", None) + if test_fn is not None and _is_coroutine(test_fn): + item.add_marker(pytest.mark.asyncio) + + +def _is_coroutine(fn): + import inspect + return inspect.iscoroutinefunction(fn) diff --git a/src/api_interface/tests/test_batching_service.py b/src/api_interface/tests/test_batching_service.py new file mode 100644 index 0000000..30388ce --- /dev/null +++ b/src/api_interface/tests/test_batching_service.py @@ -0,0 +1,250 @@ +""" +Tests for the DynamicBatcher. + +These tests exercise the real async logic of the batcher against a stub +model + stub tokenizer. They do not require the mBERT weights or the +transformers library — just torch for tensor ops. 
+ +Run: + pytest src/api_interface/tests/test_batching_service.py -v +""" + +import asyncio +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import List + +import pytest +import torch + +# Make ``src`` importable when the tests are run from the repo root. +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from src.api_interface.services import batching_service +from src.api_interface.services.batching_service import DynamicBatcher + + +# --------------------------------------------------------------------------- +# Stubs +# --------------------------------------------------------------------------- +class _StubTokenizer: + """Produces deterministic token ids without requiring HuggingFace tokenizers.""" + + def __call__( + self, + texts, + add_special_tokens=True, + max_length=96, + padding="longest", + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ): + if isinstance(texts, str): + texts = [texts] + # One token per character, capped at max_length. + encoded = [[ord(c) % 30000 for c in t[:max_length]] or [0] for t in texts] + longest = max(len(e) for e in encoded) + input_ids = torch.zeros(len(encoded), longest, dtype=torch.long) + attention_mask = torch.zeros(len(encoded), longest, dtype=torch.long) + for i, row in enumerate(encoded): + input_ids[i, : len(row)] = torch.tensor(row, dtype=torch.long) + attention_mask[i, : len(row)] = 1 + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +class _StubModel(torch.nn.Module): + """Returns a deterministic label per text based on its first character. + + ``h``* -> ham (class 0), ``s``* -> spam (class 1), ``p``* -> phishing (class 2). + Records each batch size so tests can assert coalescing worked. + """ + + def __init__(self) -> None: + super().__init__() + self.batch_sizes: List[int] = [] + self.call_delay = 0.0 + self.raise_on_next = False + + def forward(self, input_ids, attention_mask=None, **_ignored): + self.batch_sizes.append(input_ids.shape[0]) + if self.raise_on_next: + self.raise_on_next = False + raise RuntimeError("forced failure") + if self.call_delay: + import time as _t + _t.sleep(self.call_delay) + batch = input_ids.shape[0] + logits = torch.full((batch, 3), -5.0) + for i in range(batch): + # Use first non-pad token to decide class + first_tok = int(input_ids[i, 0].item()) + ch = chr(first_tok) if first_tok < 128 else "h" + if ch == "s": + cls = 1 + elif ch == "p": + cls = 2 + else: + cls = 0 + logits[i, cls] = 5.0 + return SimpleNamespace(logits=logits) + + +class _StubModelManager: + def __init__(self, model: _StubModel, tokenizer: _StubTokenizer) -> None: + self.device = torch.device("cpu") + self._model = model + self._tokenizer = tokenizer + + def get_mbert_model(self, _name): + return self._model, self._tokenizer, "stub-2.5" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture() +def stub_stack(monkeypatch): + """Replace the real model_manager inside batching_service with a stub.""" + model = _StubModel() + tokenizer = _StubTokenizer() + manager = _StubModelManager(model, tokenizer) + + # The batcher imports ``model_manager`` lazily inside ``_run_batch_sync`` via + # ``from .model_loader import model_manager`` — patch the attribute on the + # module so that later imports hit our stub. 
+ import src.api_interface.services.model_loader as model_loader_mod + monkeypatch.setattr(model_loader_mod, "model_manager", manager) + + return SimpleNamespace(model=model, tokenizer=tokenizer, manager=manager) + + +@pytest.fixture() +async def batcher(stub_stack): + b = DynamicBatcher(max_batch_size=8, batch_wait_ms=30, max_text_length=32) + await b.start() + try: + yield b, stub_stack + finally: + await b.stop() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_single_request_returns_correct_label(batcher): + b, stack = batcher + label, prob, elapsed = await b.submit("hello there") + assert label == "ham" + assert 0.99 < prob <= 1.0 + assert elapsed >= 0 + assert stack.model.batch_sizes == [1] + + +@pytest.mark.asyncio +async def test_concurrent_requests_are_coalesced(batcher): + b, stack = batcher + # Fire 8 requests simultaneously. They should land in a single batch since + # the batch wait window (30 ms) is larger than the scheduling jitter and + # max_batch_size == 8. + texts = ["ham one", "spam two", "phish me", "hello", "hi", "hey", "yo", "sup"] + results = await asyncio.gather(*(b.submit(t) for t in texts)) + labels = [r[0] for r in results] + assert labels[0] == "ham" + assert labels[1] == "spam" + assert labels[2] == "phishing" + # One batch, size 8. + assert stack.model.batch_sizes == [8], ( + f"expected a single coalesced batch of 8, got {stack.model.batch_sizes}" + ) + + +@pytest.mark.asyncio +async def test_max_batch_size_is_respected(stub_stack): + b = DynamicBatcher(max_batch_size=4, batch_wait_ms=50, max_text_length=32) + await b.start() + try: + texts = [f"hello {i}" for i in range(10)] + results = await asyncio.gather(*(b.submit(t) for t in texts)) + assert len(results) == 10 + # No single batch should exceed max_batch_size == 4. + assert all(sz <= 4 for sz in stub_stack.model.batch_sizes), stub_stack.model.batch_sizes + # All batches summed should account for every request. + assert sum(stub_stack.model.batch_sizes) == 10 + finally: + await b.stop() + + +@pytest.mark.asyncio +async def test_wait_window_flushes_partial_batch(stub_stack): + b = DynamicBatcher(max_batch_size=32, batch_wait_ms=20, max_text_length=32) + await b.start() + try: + # Only 2 requests; should flush after the 20 ms window, not stall. + import time + t0 = time.monotonic() + await asyncio.gather(b.submit("hello"), b.submit("hi there")) + elapsed = time.monotonic() - t0 + # Must have waited at least one window, but well under a second. + assert 0.01 < elapsed < 1.0, f"elapsed={elapsed}" + assert stub_stack.model.batch_sizes == [2] + finally: + await b.stop() + + +@pytest.mark.asyncio +async def test_model_error_propagates_to_all_futures(stub_stack): + b = DynamicBatcher(max_batch_size=4, batch_wait_ms=20, max_text_length=32) + await b.start() + try: + stub_stack.model.raise_on_next = True + with pytest.raises(Exception): + await asyncio.gather(b.submit("hello"), b.submit("spammy")) + finally: + await b.stop() + + +@pytest.mark.asyncio +async def test_metrics_are_recorded(batcher): + b, stack = batcher + await asyncio.gather(b.submit("hello"), b.submit("spam"), b.submit("phish")) + m = b.metrics + assert m.total_requests == 3 + assert m.total_batches >= 1 + assert m.last_batch_size >= 1 + assert m.total_inference_seconds > 0 + # Histogram bucket for size-3 batch rounds up to 4. 
+ assert any(bucket >= 3 for bucket in m.batch_size_histogram) + + +@pytest.mark.asyncio +async def test_shutdown_fails_pending_requests(stub_stack): + """Pending submissions must not hang when the batcher is stopped.""" + b = DynamicBatcher(max_batch_size=32, batch_wait_ms=50, max_text_length=32) + # Slow the model down so requests queue up. + stub_stack.model.call_delay = 0.2 + await b.start() + # Submit one request that will be in-flight, then stop immediately. + task = asyncio.create_task(b.submit("hello")) + await asyncio.sleep(0.01) + await b.stop() + # Task must complete (success or failure) rather than hang forever. + try: + await asyncio.wait_for(task, timeout=2.0) + except Exception: + pass # any terminal state is acceptable; hanging is not + assert task.done() + + +@pytest.mark.asyncio +async def test_init_batcher_respects_disabled_flag(monkeypatch): + from src.api_interface.config import settings as settings_mod + monkeypatch.setattr(settings_mod.settings, "batching_enabled", False) + # Re-run the init to pick up the new flag. + result = batching_service.init_batcher() + assert result is None + assert batching_service.get_batcher() is None diff --git a/src/api_interface/tests/test_metrics_endpoint.py b/src/api_interface/tests/test_metrics_endpoint.py new file mode 100644 index 0000000..d36b185 --- /dev/null +++ b/src/api_interface/tests/test_metrics_endpoint.py @@ -0,0 +1,73 @@ +""" +Tests for the /metrics Prometheus endpoint. + +These tests spin up the FastAPI app in-memory and verify that the rendered +exposition format contains the expected counters and gauges. +""" + +import sys +from pathlib import Path +from types import SimpleNamespace + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +def _build_app(batcher=None): + """Construct a minimal FastAPI app with just the metrics router mounted.""" + from src.api_interface.routers import metrics as metrics_router + from src.api_interface.services import batching_service + + batching_service._batcher = batcher # inject stub batcher (or None) + app = FastAPI() + app.include_router(metrics_router.router) + return app + + +def test_metrics_endpoint_disabled_batcher(): + client = TestClient(_build_app(batcher=None)) + resp = client.get("/metrics") + assert resp.status_code == 200 + body = resp.text + assert "ots_api_info" in body + assert "ots_batching_enabled 0" in body + # Must advertise text/plain for Prometheus to scrape it. + assert resp.headers["content-type"].startswith("text/plain") + + +def test_metrics_endpoint_active_batcher(): + from src.api_interface.services.batching_service import DynamicBatcher + + b = DynamicBatcher(max_batch_size=16, batch_wait_ms=20, max_text_length=64) + # Simulate that two batches of sizes 4 and 8 have been processed. + b.metrics.record_batch(size=4, wait_seconds_sum=0.04, inference_seconds=0.02) + b.metrics.record_batch(size=8, wait_seconds_sum=0.12, inference_seconds=0.03) + b.metrics.current_queue_depth = 2 + + client = TestClient(_build_app(batcher=b)) + resp = client.get("/metrics") + assert resp.status_code == 200 + body = resp.text + + # Required series present. 
+ for name in [ + "ots_batching_enabled 1", + "ots_batch_max_size 16", + "ots_requests_total 12", # 4 + 8 + "ots_batches_total 2", + "ots_queue_depth 2", + "ots_last_batch_size 8", + "ots_batch_size_bucket", + ]: + assert name in body, f"expected '{name}' in metrics output:\n{body}" + + # Histogram lines must be well-formed Prometheus counters. + hist_lines = [ln for ln in body.splitlines() if ln.startswith("ots_batch_size_bucket{")] + assert hist_lines, "no histogram buckets emitted" + for line in hist_lines: + assert 'le="' in line
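
Illustrative only, not part of either patch: a minimal sketch of how a scraper such as ots-bridge could turn the counters exposed by the new `/metrics` endpoint into the adaptive-concurrency signals mentioned in the commit message. Only the metric names come from the endpoint added above; the URL, timeout, and the derived quantities are assumptions.

```python
# Minimal /metrics consumer sketch (illustrative, not part of the patches).
# Metric names match the new endpoint; the URL and timeout are assumptions.
import urllib.request


def scrape_metrics(url: str = "http://localhost:8002/metrics") -> dict:
    """Parse the plain-text exposition format into {series_name: value}."""
    values: dict = {}
    with urllib.request.urlopen(url, timeout=2) as resp:
        for line in resp.read().decode("utf-8").splitlines():
            if not line or line.startswith("#"):
                continue  # skip HELP/TYPE comments and blank lines
            name, _, value = line.rpartition(" ")
            # Strip any {label="..."} suffix; the per-bucket series are not
            # needed for the aggregate signals derived below.
            values[name.split("{", 1)[0]] = float(value)
    return values


def concurrency_signals(m: dict) -> dict:
    """Derive the signals the commit message suggests ots-bridge can act on."""
    batches = max(m.get("ots_batches_total", 0.0), 1.0)
    requests = max(m.get("ots_requests_total", 0.0), 1.0)
    return {
        "avg_batch_size": requests / batches,
        "gpu_seconds_per_message": m.get("ots_inference_seconds_total", 0.0) / requests,
        "queue_depth": m.get("ots_queue_depth", 0.0),
    }
```

For example, a steadily rising `ots_queue_depth` while `avg_batch_size` sits pinned at `OTS_MAX_BATCH_SIZE` indicates the GPU is saturated and upstream concurrency should back off; a low average batch size under load suggests raising `OTS_BATCH_WAIT_MS` instead.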