Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,41 @@ curl "http://localhost:8002/tmf-api/aiInferenceJob"
- Model pre-loaded at startup to avoid initialization overhead
- Single high-accuracy multilingual model for consistent global results

### Dynamic Batching (v2.9+)
Concurrent prediction requests are coalesced into padded batches by the
`DynamicBatcher` in `src/api_interface/services/batching_service.py`. This
raises GPU utilisation dramatically — a T4 handles a batch of 32 SMS in
roughly the same wall-time as a single message, so per-message throughput
scales near-linearly with batch size up to saturation.

Tuning knobs (all overridable via `OTS_` env vars):

| Setting | Default | Description |
|---|---|---|
| `batching_enabled` | `true` | Master switch. Disable for single-request debugging. |
| `max_batch_size` | `32` | Maximum requests per forward pass. |
| `batch_wait_ms` | `15` | Max time to wait collecting a partial batch. |
| `max_text_length` | `96` | Token truncation (was 512; typical SMS = 20–60 tokens). |
| `use_fp16` | `true` | FP16 weights on CUDA. Ignored on CPU/MPS. |

Example:
```bash
OTS_MAX_BATCH_SIZE=16 OTS_BATCH_WAIT_MS=20 \
uvicorn src.api_interface.main:app --host 0.0.0.0 --port 8002
```

### Observability: /metrics
Prometheus-compatible metrics are exposed at `GET /metrics` (no extra
dependency). Useful series:

- `ots_requests_total` / `ots_batches_total` — throughput counters
- `ots_inference_seconds_total` — total GPU time
- `ots_queue_depth` — current batcher backlog (adaptive-concurrency signal)
- `ots_last_batch_size` / `ots_batch_size_bucket{le="N"}` — batch efficiency
- `ots_api_info{device,fp16,max_text_length,version}` — build info

The `ots-bridge` can scrape this to drive adaptive concurrency limits.

### Development Workflow
1. Dataset preparation in CSV format with `text,label` columns
2. Model training using framework-specific training scripts
Expand Down
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,12 @@ docker-compose up -d
- Technical details and performance indicators

### Performance
- **Inference Speed**: 54 messages/second (Apple Silicon M1 Pro)
- **Response Time**: <200ms typical
- **Inference Speed**: 54 messages/second (Apple Silicon M1 Pro, single-request)
- **Dynamic Batching**: Coalesces concurrent requests into padded GPU batches — on NVIDIA T4 (FP16, batch=32) this unlocks hundreds of MPS per instance
- **Response Time**: <200ms typical (single-request); per-message cost drops sharply under load thanks to batching
- **Languages**: 104+ supported via mBERT
- **Accuracy**: Production-ready classification
- **Tuning**: `OTS_MAX_BATCH_SIZE`, `OTS_BATCH_WAIT_MS`, `OTS_MAX_TEXT_LENGTH`, `OTS_USE_FP16` env vars

## 🧪 Testing

Expand Down Expand Up @@ -400,7 +402,8 @@ spec:

### Health Checks
- **API Health**: `GET /health`
- **Model Status**: `GET /model/status`
- **Model Status**: `GET /model/status`
- **Prometheus Metrics**: `GET /metrics` — batcher throughput, queue depth, batch-size histogram, inference time
- **System Metrics**: Built-in performance monitoring

### Logs
Expand Down
20 changes: 18 additions & 2 deletions src/api_interface/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,28 @@ class Settings(BaseSettings):


# Processing
max_text_length: int = 512
# SMS payloads are typically 20-60 tokens. The previous default of 512
# padded every request to the full sequence length which wasted ~10x the
# GPU FLOPs per forward pass. 96 tokens covers even long-form phishing
# content with headroom; outliers are truncated safely.
max_text_length: int = 96
default_model: str = "ots-mbert"
default_mbert_version: str = "multilingual"

# Device Configuration
device: str = "cpu" # Will be overridden by auto-detection

# FP16 inference on CUDA. Ignored on CPU/MPS. Tensor-core GPUs (T4, A10,
# L4, A100) gain ~2x throughput with no measurable accuracy loss for
# classification heads.
use_fp16: bool = True

# Dynamic batching configuration. Coalesces concurrent single-message
# requests into padded batches to raise GPU utilization. Disable for
# single-request debugging.
batching_enabled: bool = True
max_batch_size: int = 32
batch_wait_ms: int = 15

# Logging
log_level: str = "INFO"
Expand Down
19 changes: 18 additions & 1 deletion src/api_interface/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from .config.settings import settings
from .utils.logging import setup_logging, logger
from .utils.exceptions import OpenTextShieldException
from .services.batching_service import get_batcher, init_batcher
from .services.model_loader import model_manager
from .middleware.security import setup_cors_middleware
from .routers import health, prediction, feedback, audit, tmforum_event
from .routers import health, metrics, prediction, feedback, audit, tmforum_event

# Import TMForum AI Inference Job components (optional - may not be available in all deployments)
try:
Expand Down Expand Up @@ -54,6 +55,12 @@ async def lifespan(app: FastAPI):
)
logger.info("All models loaded successfully")

# Start dynamic batcher after models are loaded so the worker has a
# model to call. Safe no-op when batching is disabled in settings.
batcher = init_batcher()
if batcher is not None:
await batcher.start()

# Initialize TMForum service (if available)
if TMFORUM_AVAILABLE and tmforum_service:
await tmforum_service.initialize()
Expand All @@ -71,6 +78,15 @@ async def lifespan(app: FastAPI):
# Shutdown
logger.info("Shutting down OpenTextShield API...")

# Stop the dynamic batcher first so queued requests are failed cleanly
# before we tear down the rest of the service.
batcher = get_batcher()
if batcher is not None:
try:
await batcher.stop()
except Exception as e:
logger.error(f"Error stopping batcher: {str(e)}")

# Shutdown TMForum service (if available)
if TMFORUM_AVAILABLE and tmforum_service:
try:
Expand Down Expand Up @@ -98,6 +114,7 @@ async def lifespan(app: FastAPI):

# Include routers
app.include_router(health.router)
app.include_router(metrics.router)
app.include_router(prediction.router)
app.include_router(feedback.router)
app.include_router(audit.router)
Expand Down
2 changes: 1 addition & 1 deletion src/api_interface/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Includes both legacy and TMForum-compliant API endpoints.
"""

from . import health, prediction, feedback
from . import health, metrics, prediction, feedback

# TMForum router is optional
try:
Expand Down
128 changes: 128 additions & 0 deletions src/api_interface/routers/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
Prometheus-compatible metrics endpoint.

Exposes batcher and model throughput counters in the text-based exposition
format so operators (including the ots-bridge) can scrape per-instance GPU
utilisation, batch efficiency, and queue pressure without pulling in an
extra client dependency.
"""

from fastapi import APIRouter, Response

from ..services.batching_service import get_batcher
from ..services.model_loader import model_manager
from ..config.settings import settings

router = APIRouter(tags=["Metrics"])

_CONTENT_TYPE = "text/plain; version=0.0.4; charset=utf-8"


def _render() -> str:
lines = []

def emit(name: str, help_text: str, type_: str, value: float, labels: str = "") -> None:
lines.append(f"# HELP {name} {help_text}")
lines.append(f"# TYPE {name} {type_}")
if labels:
lines.append(f"{name}{{{labels}}} {value}")
else:
lines.append(f"{name} {value}")

emit(
"ots_api_info",
"Static build info.",
"gauge",
1,
labels=(
f'version="{settings.api_version}",'
f'device="{model_manager.device.type}",'
f'fp16="{str(settings.use_fp16 and model_manager.device.type == "cuda").lower()}",'
f'max_text_length="{settings.max_text_length}"'
),
)

batcher = get_batcher()
if batcher is None:
emit(
"ots_batching_enabled",
"1 if dynamic batching is active.",
"gauge",
0,
)
return "\n".join(lines) + "\n"

m = batcher.metrics
emit("ots_batching_enabled", "1 if dynamic batching is active.", "gauge", 1)
emit(
"ots_batch_max_size",
"Configured maximum batch size.",
"gauge",
batcher.max_batch_size,
)
emit(
"ots_batch_wait_seconds",
"Configured max wait window before flushing a partial batch.",
"gauge",
batcher.batch_wait_seconds,
)
emit(
"ots_requests_total",
"Total prediction requests processed through the batcher.",
"counter",
m.total_requests,
)
emit(
"ots_batches_total",
"Total batches executed.",
"counter",
m.total_batches,
)
emit(
"ots_inference_seconds_total",
"Total wall-clock seconds spent in model forward passes.",
"counter",
round(m.total_inference_seconds, 6),
)
emit(
"ots_request_wait_seconds_total",
"Sum of end-to-end wait time (queue + inference) across all requests.",
"counter",
round(m.total_wait_seconds, 6),
)
emit(
"ots_queue_depth",
"Current number of requests waiting in the batcher queue.",
"gauge",
m.current_queue_depth,
)
emit(
"ots_last_batch_size",
"Size of the most recently executed batch.",
"gauge",
m.last_batch_size,
)
emit(
"ots_batch_errors_total",
"Total batches that failed with an uncaught exception.",
"counter",
m.errors,
)

# Histogram-style bucket counters for batch sizes (power-of-two buckets).
lines.append("# HELP ots_batch_size_bucket Number of batches per power-of-two size bucket.")
lines.append("# TYPE ots_batch_size_bucket counter")
for bucket, count in sorted(m.batch_size_histogram.items()):
lines.append(f'ots_batch_size_bucket{{le="{bucket}"}} {count}')

return "\n".join(lines) + "\n"


@router.get(
"/metrics",
summary="Prometheus metrics",
description="Prometheus-compatible metrics scrape endpoint.",
response_class=Response,
)
async def metrics_endpoint() -> Response:
return Response(content=_render(), media_type=_CONTENT_TYPE)
Loading