From 25f5e5fcd9d32e236a353892c562e43a87a469d3 Mon Sep 17 00:00:00 2001 From: Sahil Batra Date: Sun, 14 Jun 2026 19:50:59 -0400 Subject: [PATCH 1/2] support micro batching --- .github/workflows/deploy-dataflow.yaml | 2 + Readme.md | 6 +- docs/scale_inference_batching.md | 59 ++++++---- src/dataflow/iris_streaming_pipeline.py | 146 ++++++++++++++---------- 4 files changed, 128 insertions(+), 85 deletions(-) diff --git a/.github/workflows/deploy-dataflow.yaml b/.github/workflows/deploy-dataflow.yaml index 2e67e2d..a74a20b 100644 --- a/.github/workflows/deploy-dataflow.yaml +++ b/.github/workflows/deploy-dataflow.yaml @@ -78,6 +78,8 @@ jobs: --project_id ${{ env.PROJECT_ID }} \ --region ${{ env.REGION }} \ --service_url ${{ steps.service_url.outputs.url }} \ + --batch_size 50 \ + --max_batch_duration_secs 1.0 \ --runner DataflowRunner \ --job_name $JOB_NAME \ --temp_location ${{ env.TEMP_LOCATION }} \ diff --git a/Readme.md b/Readme.md index 5534f16..45b5702 100644 --- a/Readme.md +++ b/Readme.md @@ -275,10 +275,14 @@ The project uses a **blessed model pattern** for production deployments: Real-time inference is handled through: 1. **Data Ingestion**: Pub/Sub receives real-time inference requests -2. **Stream Processing**: Dataflow processes messages and calls FastAPI services +2. **Stream Processing**: Dataflow processes messages with micro-batching and calls FastAPI services 3. **Model Serving**: Cloud Run hosts FastAPI containers with blessed models 4. **Results Storage**: Predictions are written to BigQuery for monitoring +Streaming supports **micro-batching** via Beam's `BatchElements` with `max_batch_duration_secs`. Up to 50 messages are grouped into a single `/predict` call, reducing HTTP overhead by ~10-50x. At low traffic, partial batches flush after 1 second so no message waits indefinitely. Both `--batch_size` and `--max_batch_duration_secs` are tunable via CLI args. 
+ +For high-volume workloads, the pipeline also uses **async HTTP** (`aiohttp`) to overlap multiple batch calls concurrently within a single worker, providing an additional ~2-4x throughput improvement on top of batching. + ### Key Benefits - **Cost Effective**: Cloud Run FastAPI services cost ~90% less than Vertex AI endpoints diff --git a/docs/scale_inference_batching.md b/docs/scale_inference_batching.md index beeaa0c..96317cf 100644 --- a/docs/scale_inference_batching.md +++ b/docs/scale_inference_batching.md @@ -10,7 +10,7 @@ The current Dataflow streaming pipeline (`iris_streaming_pipeline.py`) makes **o This means throughput is capped at ~10-30 messages/sec/worker, dominated by network latency — not model inference time. The FastAPI server already accepts `List[Dict]` and calls `model.predict(df)` on the full batch at once, so sending 50 instances costs nearly the same server-side time as sending 1. Two changes fix this: -1. **Micro-batching** — use Beam windowing + `GroupIntoBatches` to collect ~50 messages before calling `/predict` once +1. **Micro-batching** — use `BatchElements` with `max_batch_duration_secs` to collect up to 50 messages before calling `/predict` once, flushing partial batches after 1 second at low traffic 2. **Async HTTP** — use `aiohttp` instead of `requests` to overlap network I/O across concurrent batch calls within a worker ## Current pipeline flow @@ -24,17 +24,19 @@ Pub/Sub → Parse JSON → [1 msg] → HTTP POST (1 instance) → Add Metadata ## Target pipeline flow ``` -Pub/Sub → Parse JSON → Window(5s) → BatchElements(50) → [50 msgs] → async HTTP POST (50 instances) - → unbundle → Add Metadata → BigQuery +Pub/Sub → Parse JSON → BatchElements(max=50, flush=1s) → [up to 50 msgs] → async HTTP POST + → unbundle → Add Metadata → BigQuery ``` +No `FixedWindows` needed. `BatchElements` with `max_batch_duration_secs=1` uses Beam's stateful processing (State & Timers API) to batch across bundles. 
At high traffic, batches fill to 50 and flush immediately with near-zero latency. At low traffic, partial batches flush after at most 1 second — avoiding the problem where default `BatchElements` only batches within a bundle (which on Dataflow streaming is often size 1, making it a no-op). + ## Changes -All changes are in a single file: `src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py` +All changes are in a single file: `src/dataflow/iris_streaming_pipeline.py` -### 1. Add windowing + batching to the pipeline +### 1. Add micro-batching to the pipeline -Insert a fixed window and `BatchElements` between Parse and the HTTP call. This tells Beam: "collect messages arriving within a 5-second window, group them into batches of up to 50, then process each batch." +Insert `BatchElements` with `max_batch_duration_secs` between Parse and the HTTP call. This tells Beam: "collect up to 50 messages, flushing immediately when full or after 1 second if the batch is still partial." ```python from apache_beam.transforms.util import BatchElements @@ -44,8 +46,11 @@ predictions = ( pipeline | "Read from Pub/Sub" >> ReadFromPubSub(topic=known_args.input_topic) | "Parse JSON" >> beam.ParDo(ParsePubSubMessage()) - | "Window" >> beam.WindowInto(window.FixedWindows(5)) - | "Batch Elements" >> BatchElements(min_batch_size=1, max_batch_size=50) + | "Batch Elements" >> BatchElements( + min_batch_size=1, + max_batch_size=50, + max_batch_duration_secs=1, + ) | "Call FastAPI Batch" >> beam.ParDo( BatchCallFastAPIService(known_args.service_url) ) @@ -55,17 +60,20 @@ predictions = ( ``` **Why these parameters:** -- `FixedWindows(5)` — 5-second windows. Short enough for near-real-time latency, long enough to accumulate a batch under moderate load. At low traffic, `min_batch_size=1` ensures messages don't wait forever. - `max_batch_size=50` — matches a reasonable HTTP payload size. The FastAPI server builds a pandas DataFrame from the instances; 50 rows is trivial. 
Going higher (500+) risks HTTP timeouts and large retry payloads. +- `max_batch_duration_secs=1` — activates the **stateful implementation** (Beam State & Timers API, requires Beam 2.52+). Without this, default `BatchElements` only batches within a single bundle — on Dataflow streaming, bundles are frequently size 1 at low throughput, making the transform a no-op. With it, elements are batched across bundles and partial batches flush after 1 second. The timing is best-effort — actual hold time may slightly exceed 1s. - `min_batch_size=1` — ensures single messages still flow through at low traffic instead of blocking until 50 arrive. +- **No `FixedWindows` needed** — the `max_batch_duration_secs` timer replaces the fixed window. This avoids adding an artificial 5-second latency floor. At high traffic, batches fill to 50 and flush immediately. At low traffic, worst case is ~1 second. + +**Tradeoff:** The stateful path requires internal keying, which triggers a shuffle (network transfer between workers). For small payloads like iris features this is negligible, but worth noting for larger payloads. **Adding CLI args for tuning:** ```python parser.add_argument("--batch_size", type=int, default=50, help="Max instances per /predict call") -parser.add_argument("--window_seconds", type=int, default=5, - help="Fixed window duration in seconds") +parser.add_argument("--max_batch_duration_secs", type=float, default=1.0, + help="Max seconds to buffer a partial batch before flushing") ``` ### 2. Replace `CallFastAPIService` with `BatchCallFastAPIService` @@ -279,7 +287,7 @@ Alternatively, since `pyproject.toml` already declares `aiohttp` and the pipelin ### 5. Remove old `CallFastAPIService` class -Delete the old single-element DoFn (lines 66-142) since it's fully replaced by `BatchCallFastAPIService`. +Delete the old single-element `CallFastAPIService` DoFn since it's fully replaced by `BatchCallFastAPIService`. 
## Implementation options @@ -287,7 +295,7 @@ You can do this in two phases or all at once: | Phase | Change | Throughput gain | Complexity | |---|---|---|---| -| **Phase 1** | Batching + `requests.Session` (step 1-2) | ~10-50x (amortize network latency over 50 msgs) | Low — straightforward Beam transforms | +| **Phase 1** | Stateful micro-batching + `requests.Session` (step 1-2) | ~10-50x (amortize network latency over 50 msgs) | Low — straightforward Beam transforms | | **Phase 2** | Async HTTP with `aiohttp` (step 3) | ~2-4x on top of phase 1 (overlap concurrent batches) | Medium — async event loop management | Phase 1 alone gives the biggest win. Phase 2 adds concurrency on top. If traffic is moderate (~100s msgs/sec), phase 1 may be sufficient. @@ -296,32 +304,35 @@ Phase 1 alone gives the biggest win. Phase 2 adds concurrency on top. If traffic | File | Action | |---|---| -| `src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py` | Add windowing + BatchElements, replace `CallFastAPIService` with `BatchCallFastAPIService`, add `--batch_size` and `--window_seconds` CLI args | -| `scripts/deploy_dataflow_streaming.sh` | Add `--batch_size` and `--window_seconds` flags (optional, defaults are fine) | -| `.github/workflows/deploy-dataflow.yaml` | Same — add flags if overriding defaults | +| `src/dataflow/iris_streaming_pipeline.py` | Add `BatchElements` with `max_batch_duration_secs`, replace `CallFastAPIService` with `BatchCallFastAPIService`, add `--batch_size` and `--max_batch_duration_secs` CLI args | +| `.github/workflows/deploy-dataflow.yaml` | Add `--batch_size` and `--max_batch_duration_secs` flags if overriding defaults | No changes to `fastapi_server.py` — it already handles batched instances. ## Verification -1. **Unit test locally with DirectRunner:** +1. 
**Unit test locally with DirectRunner (staging):** ```bash - python src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py \ + python src/dataflow/iris_streaming_pipeline.py \ --input_topic projects/deeplearning-sahil/topics/iris-inference-data \ - --output_table deeplearning-sahil:ml_dataset.iris_predictions_streaming \ + --output_table deeplearning-sahil:ml_dataset.iris_predictions_streaming_staging \ --project_id deeplearning-sahil \ --region us-central1 \ - --service_url https://iris-classifier-xgboost-service-zoxyfmo73q-uc.a.run.app \ + --service_url https://iris-classifier-xgboost-service-staging-zoxyfmo73q-uc.a.run.app \ --runner DirectRunner \ --streaming ``` - Run the pubsub producer in parallel and verify predictions land in BigQuery. + Run the pubsub producer in parallel and verify predictions land in the staging BigQuery table. -2. **Check batch sizes in logs:** Look for `Batch prediction failed (N instances)` log pattern — if N is consistently 50, batching is working. At low traffic, N will be smaller (down to 1) which is expected. +2. **Deploy to staging via GitHub Action:** Trigger the `Deploy Dataflow Streaming` workflow with `environment=staging` to validate on Dataflow workers. -3. **Compare throughput:** Before and after, monitor Dataflow job metrics: +3. **Check batch sizes in logs:** Look for the `Batch prediction failed (N instances)` log pattern. This message is emitted only when a `/predict` call fails, so on a healthy pipeline it will be absent; when it does appear, N consistently at 50 indicates batching is working. At low traffic, N will be smaller (down to 1), which is expected. + +4. **Compare throughput:** Before and after, monitor Dataflow job metrics: - Elements processed/sec (should increase ~10-50x) - System lag (should decrease — messages spend less time waiting) - Worker CPU (should decrease — less time idle on network I/O) -4. **Verify no data loss:** Count Pub/Sub acked messages vs BigQuery rows inserted over a time window. Should match. +5. **Verify no data loss:** Count Pub/Sub acked messages vs BigQuery rows inserted over a time window. Should match.
+ +6. **Promote to prod:** After staging validation, trigger the workflow with `environment=prod`. diff --git a/src/dataflow/iris_streaming_pipeline.py b/src/dataflow/iris_streaming_pipeline.py index 563090d..33ed05b 100644 --- a/src/dataflow/iris_streaming_pipeline.py +++ b/src/dataflow/iris_streaming_pipeline.py @@ -5,13 +5,15 @@ import json import argparse +import asyncio +import logging from typing import Any, Dict, List -import requests import time +import aiohttp import apache_beam as beam from apache_beam.options.pipeline_options import PipelineOptions -from apache_beam.transforms import window +from apache_beam.transforms.util import BatchElements from apache_beam.io import ReadFromPubSub, WriteToBigQuery from ml_pipelines_kfp.log import get_logger @@ -62,78 +64,89 @@ def process(self, element): logger.error(f"Error parsing message: {e}, message: {element}") -class CallFastAPIService(beam.DoFn): - """Call FastAPI ML service for inference.""" +class BatchCallFastAPIService(beam.DoFn): + """Call FastAPI with a batch of instances using async HTTP.""" - def __init__(self, service_url: str): + def __init__(self, service_url, max_concurrent=4): self.service_url = service_url self.predict_url = f"{service_url}/predict" + self.max_concurrent = max_concurrent - def process(self, element): - import time + def setup(self): + self._loop = asyncio.new_event_loop() + self._connector = aiohttp.TCPConnector(limit=self.max_concurrent) + self._session = aiohttp.ClientSession(connector=self._connector) + + def teardown(self): + self._loop.run_until_complete(self._session.close()) + self._loop.close() + + def process(self, batch): + results = self._loop.run_until_complete(self._call_async(batch)) + yield from results + + async def _call_async(self, batch): from datetime import datetime - import requests start_time = time.time() - try: - payload = { - "instances": [ - { - "SepalLengthCm": element["sepal_length"], - "SepalWidthCm": element["sepal_width"], - "PetalLengthCm": 
element["petal_length"], - "PetalWidthCm": element["petal_width"], - } - ] + instances = [ + { + "SepalLengthCm": e["sepal_length"], + "SepalWidthCm": e["sepal_width"], + "PetalLengthCm": e["petal_length"], + "PetalWidthCm": e["petal_width"], } + for e in batch + ] - response = requests.post(self.predict_url, json=payload, timeout=30) - response.raise_for_status() + try: + async with self._session.post( + self.predict_url, + json={"instances": instances}, + timeout=aiohttp.ClientTimeout(total=30), + ) as response: + response.raise_for_status() + result_data = await response.json() - result_data = response.json() predictions = result_data.get("predictions", []) - - if predictions: - prediction_result = predictions[0] - predicted_class = str(prediction_result.get("prediction", "unknown")) - else: - predicted_class = "unknown" - processing_time = time.time() - start_time - result = { - "sepal_length": element["sepal_length"], - "sepal_width": element["sepal_width"], - "petal_length": element["petal_length"], - "petal_width": element["petal_width"], - "timestamp": element.get("timestamp", datetime.utcnow().isoformat()), - "sample_id": element.get("sample_id", 0), - "prediction": predicted_class, - "prediction_timestamp": datetime.utcnow().isoformat(), - "model_service": self.service_url, - "processing_time": processing_time, - } - - logger.info( - f"Prediction for sample {element.get('sample_id')}: {predicted_class}" - ) - yield result + results = [] + for element, pred in zip(batch, predictions): + predicted_class = str(pred.get("prediction", "unknown")) + results.append({ + "sepal_length": element["sepal_length"], + "sepal_width": element["sepal_width"], + "petal_length": element["petal_length"], + "petal_width": element["petal_width"], + "timestamp": element.get("timestamp", datetime.utcnow().isoformat()), + "sample_id": element.get("sample_id", 0), + "prediction": predicted_class, + "prediction_timestamp": datetime.utcnow().isoformat(), + "model_service": 
self.service_url, + "processing_time": processing_time / len(batch), + }) + return results except Exception as e: - logger.error(f"Error calling FastAPI service: {e}, element: {element}") - yield { - "sepal_length": element.get("sepal_length", 0.0), - "sepal_width": element.get("sepal_width", 0.0), - "petal_length": element.get("petal_length", 0.0), - "petal_width": element.get("petal_width", 0.0), - "timestamp": element.get("timestamp", datetime.utcnow().isoformat()), - "sample_id": element.get("sample_id", 0), - "prediction": "ERROR", - "prediction_timestamp": datetime.utcnow().isoformat(), - "model_service": f"ERROR: {str(e)}", - "processing_time": time.time() - start_time, - } + logging.error(f"Batch prediction failed ({len(batch)} instances): {e}") + processing_time = time.time() - start_time + return [ + { + "sepal_length": el.get("sepal_length", 0.0), + "sepal_width": el.get("sepal_width", 0.0), + "petal_length": el.get("petal_length", 0.0), + "petal_width": el.get("petal_width", 0.0), + "timestamp": el.get("timestamp", datetime.utcnow().isoformat()), + "sample_id": el.get("sample_id", 0), + "prediction": "ERROR", + "prediction_timestamp": datetime.utcnow().isoformat(), + "model_service": f"ERROR: {str(e)}", + "processing_time": processing_time, + } + for el in batch + ] class AddProcessingMetadata(beam.DoFn): @@ -164,6 +177,14 @@ def run_pipeline(argv=None): parser.add_argument("--project_id", required=True, help="Project ID") parser.add_argument("--region", required=True, help="GCP Region") parser.add_argument("--service_url", required=True, help="FastAPI service URL") + parser.add_argument( + "--batch_size", type=int, default=50, + help="Max instances per /predict call", + ) + parser.add_argument( + "--max_batch_duration_secs", type=float, default=1.0, + help="Max seconds to buffer a partial batch before flushing", + ) known_args, pipeline_args = parser.parse_known_args(argv) logger.info(f"Known args: {known_args}") @@ -183,8 +204,13 @@ def 
run_pipeline(argv=None): pipeline | "Read from Pub/Sub" >> ReadFromPubSub(topic=known_args.input_topic) | "Parse JSON" >> beam.ParDo(ParsePubSubMessage()) - | "Call FastAPI Service" - >> beam.ParDo(CallFastAPIService(known_args.service_url)) + | "Batch Elements" >> BatchElements( + min_batch_size=1, + max_batch_size=known_args.batch_size, + max_batch_duration_secs=known_args.max_batch_duration_secs, + ) + | "Call FastAPI Batch" + >> beam.ParDo(BatchCallFastAPIService(known_args.service_url)) | "Add Metadata" >> beam.ParDo(AddProcessingMetadata()) | "Write to BigQuery" >> WriteToBigQuery( From 8c833dcd876d34b7f9e36ace9b1724fb9cad150d Mon Sep 17 00:00:00 2001 From: Sahil Batra Date: Sun, 14 Jun 2026 19:55:01 -0400 Subject: [PATCH 2/2] Add --no_wait flag so deploy exits after submitting streaming job The `with beam.Pipeline()` context manager calls wait_until_finish(), which blocks forever for streaming pipelines. Use explicit pipeline.run() with conditional wait gated on --no_wait. Co-Authored-By: Claude Opus 4.6 --- src/dataflow/iris_streaming_pipeline.py | 56 ++++++++++++++----------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/src/dataflow/iris_streaming_pipeline.py b/src/dataflow/iris_streaming_pipeline.py index 33ed05b..e5421ac 100644 --- a/src/dataflow/iris_streaming_pipeline.py +++ b/src/dataflow/iris_streaming_pipeline.py @@ -185,6 +185,10 @@ def run_pipeline(argv=None): "--max_batch_duration_secs", type=float, default=1.0, help="Max seconds to buffer a partial batch before flushing", ) + parser.add_argument( + "--no_wait", action="store_true", + help="Submit the job and exit without waiting for it to finish", + ) known_args, pipeline_args = parser.parse_known_args(argv) logger.info(f"Known args: {known_args}") @@ -198,31 +202,35 @@ def run_pipeline(argv=None): google_cloud_options.project = known_args.project_id google_cloud_options.region = known_args.region - with beam.Pipeline(options=pipeline_options) as pipeline: - - predictions = ( - 
pipeline - | "Read from Pub/Sub" >> ReadFromPubSub(topic=known_args.input_topic) - | "Parse JSON" >> beam.ParDo(ParsePubSubMessage()) - | "Batch Elements" >> BatchElements( - min_batch_size=1, - max_batch_size=known_args.batch_size, - max_batch_duration_secs=known_args.max_batch_duration_secs, - ) - | "Call FastAPI Batch" - >> beam.ParDo(BatchCallFastAPIService(known_args.service_url)) - | "Add Metadata" >> beam.ParDo(AddProcessingMetadata()) - | "Write to BigQuery" - >> WriteToBigQuery( - table=known_args.output_table, - schema=PREDICTION_SCHEMA, - write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND, - create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, - additional_bq_parameters={ - "timePartitioning": {"type": "DAY", "field": "prediction_timestamp"} - }, - ) + pipeline = beam.Pipeline(options=pipeline_options) + + predictions = ( + pipeline + | "Read from Pub/Sub" >> ReadFromPubSub(topic=known_args.input_topic) + | "Parse JSON" >> beam.ParDo(ParsePubSubMessage()) + | "Batch Elements" >> BatchElements( + min_batch_size=1, + max_batch_size=known_args.batch_size, + max_batch_duration_secs=known_args.max_batch_duration_secs, ) + | "Call FastAPI Batch" + >> beam.ParDo(BatchCallFastAPIService(known_args.service_url)) + | "Add Metadata" >> beam.ParDo(AddProcessingMetadata()) + | "Write to BigQuery" + >> WriteToBigQuery( + table=known_args.output_table, + schema=PREDICTION_SCHEMA, + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + additional_bq_parameters={ + "timePartitioning": {"type": "DAY", "field": "prediction_timestamp"} + }, + ) + ) + + result = pipeline.run() + if not known_args.no_wait: + result.wait_until_finish() if __name__ == "__main__":