Merged
2 changes: 2 additions & 0 deletions .github/workflows/deploy-dataflow.yaml
@@ -78,6 +78,8 @@ jobs:
--project_id ${{ env.PROJECT_ID }} \
--region ${{ env.REGION }} \
--service_url ${{ steps.service_url.outputs.url }} \
--batch_size 50 \
--max_batch_duration_secs 1.0 \
--runner DataflowRunner \
--job_name $JOB_NAME \
--temp_location ${{ env.TEMP_LOCATION }} \
6 changes: 5 additions & 1 deletion Readme.md
@@ -275,10 +275,14 @@ The project uses a **blessed model pattern** for production deployments:
Real-time inference is handled through:

1. **Data Ingestion**: Pub/Sub receives real-time inference requests
2. **Stream Processing**: Dataflow processes messages and calls FastAPI services
2. **Stream Processing**: Dataflow processes messages with micro-batching and calls FastAPI services
3. **Model Serving**: Cloud Run hosts FastAPI containers with blessed models
4. **Results Storage**: Predictions are written to BigQuery for monitoring

Streaming supports **micro-batching** via Beam's `BatchElements` with `max_batch_duration_secs`. Up to 50 messages are grouped into a single `/predict` call, cutting per-message HTTP overhead by roughly 10-50x depending on how full the batches are. At low traffic, partial batches flush after about 1 second, so no message waits indefinitely. Both `--batch_size` and `--max_batch_duration_secs` are tunable via CLI args.

For high-volume workloads, the pipeline also uses **async HTTP** (`aiohttp`) to overlap multiple batch calls within a single worker, providing an additional ~2-4x throughput improvement on top of batching.

### Key Benefits

- **Cost Effective**: Cloud Run FastAPI services cost ~90% less than Vertex AI endpoints
59 changes: 35 additions & 24 deletions docs/scale_inference_batching.md
@@ -10,7 +10,7 @@ The current Dataflow streaming pipeline (`iris_streaming_pipeline.py`) makes **o
This means throughput is capped at ~10-30 messages/sec/worker, dominated by network latency — not model inference time. The FastAPI server already accepts `List[Dict]` and calls `model.predict(df)` on the full batch at once, so sending 50 instances costs nearly the same server-side time as sending 1.
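
The latency argument above can be checked with back-of-the-envelope arithmetic. The sketch below is illustrative only: `max_throughput` is a hypothetical helper, and the 50 ms round trip is an assumed figure rather than a measurement from this pipeline. It also assumes server-side time does not grow with batch size, so real gains would land somewhat below the ideal.

```python
def max_throughput(round_trip_secs: float, batch_size: int = 1) -> float:
    """Messages/sec for one worker issuing sequential /predict calls.

    Assumes each call costs one round trip regardless of batch size,
    i.e. network latency dominates and server time is negligible.
    """
    return batch_size / round_trip_secs


# Hypothetical 50 ms round trip (an assumed figure, not a measurement).
single = max_throughput(0.05)                  # one message per call
batched = max_throughput(0.05, batch_size=50)  # 50 messages per call
print(f"{single:.0f} msg/s unbatched, {batched:.0f} msg/s batched")
```

Under these assumptions the ratio between the two numbers equals the batch size, which is where the upper end of the ~10-50x estimate comes from. Partially filled batches at low traffic pull the real ratio down.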

Two changes fix this:
1. **Micro-batching** — use Beam windowing + `GroupIntoBatches` to collect ~50 messages before calling `/predict` once
1. **Micro-batching** — use `BatchElements` with `max_batch_duration_secs` to collect up to 50 messages before calling `/predict` once, flushing partial batches after 1 second at low traffic
2. **Async HTTP** — use `aiohttp` instead of `requests` to overlap network I/O across concurrent batch calls within a worker

## Current pipeline flow
@@ -24,17 +24,19 @@ Pub/Sub → Parse JSON → [1 msg] → HTTP POST (1 instance) → Add Metadata
## Target pipeline flow

```
Pub/Sub → Parse JSON → Window(5s) → BatchElements(50) → [50 msgs] → async HTTP POST (50 instances)
→ unbundle → Add Metadata → BigQuery
Pub/Sub → Parse JSON → BatchElements(max=50, flush=1s) → [up to 50 msgs] → async HTTP POST
→ unbundle → Add Metadata → BigQuery
```

No `FixedWindows` is needed. `BatchElements` with `max_batch_duration_secs=1` uses Beam's stateful processing (State & Timers API) to batch across bundles. At high traffic, batches fill to 50 and flush immediately, adding near-zero latency. At low traffic, partial batches flush after roughly 1 second. This avoids the problem where default `BatchElements` only batches within a bundle, which on Dataflow streaming is often of size 1 and makes the transform a no-op.
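
The flush rule described above (emit when full, or when the timer expires) can be modelled in plain Python. This is a simplified simulation for intuition, not Beam's actual implementation: `micro_batch` is a hypothetical function, arrival times are supplied explicitly instead of coming from a runner, and there is no keying, state, or shuffle.

```python
from typing import Iterable, List, Tuple


def micro_batch(
    events: Iterable[Tuple[float, str]],
    max_batch_size: int = 50,
    max_batch_duration_secs: float = 1.0,
) -> List[List[str]]:
    """Group (arrival_time, message) pairs by a size-or-timeout rule.

    A batch is emitted when it reaches max_batch_size, or once
    max_batch_duration_secs has passed since its first element arrived.
    """
    batches: List[List[str]] = []
    current: List[str] = []
    deadline = 0.0
    for arrival, msg in events:
        # Timer expired before this element arrived: flush the partial batch.
        if current and arrival >= deadline:
            batches.append(current)
            current = []
        if not current:
            # First element of a new batch starts the timer.
            deadline = arrival + max_batch_duration_secs
        current.append(msg)
        # Batch is full: flush immediately.
        if len(current) == max_batch_size:
            batches.append(current)
            current = []
    if current:  # Final timer flush for whatever is left.
        batches.append(current)
    return batches


# High traffic: 120 messages arriving 1 ms apart.
burst = [(i * 0.001, f"m{i}") for i in range(120)]
print([len(b) for b in micro_batch(burst)])

# Low traffic: one message every 0.4 s, so the timer does the flushing.
trickle = [(i * 0.4, f"m{i}") for i in range(6)]
print([len(b) for b in micro_batch(trickle)])
```

The burst case fills batches to the size limit, while the trickle case never reaches it and relies on the timer alone, which is the behaviour the non-stateful path cannot provide.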

## Changes

All changes are in a single file: `src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py`
All changes are in a single file: `src/dataflow/iris_streaming_pipeline.py`

### 1. Add windowing + batching to the pipeline
### 1. Add micro-batching to the pipeline

Insert a fixed window and `BatchElements` between Parse and the HTTP call. This tells Beam: "collect messages arriving within a 5-second window, group them into batches of up to 50, then process each batch."
Insert `BatchElements` with `max_batch_duration_secs` between Parse and the HTTP call. This tells Beam: "collect up to 50 messages, flushing immediately when full or after 1 second if the batch is still partial."

```python
from apache_beam.transforms.util import BatchElements
@@ -44,8 +46,11 @@ predictions = (
pipeline
| "Read from Pub/Sub" >> ReadFromPubSub(topic=known_args.input_topic)
| "Parse JSON" >> beam.ParDo(ParsePubSubMessage())
| "Window" >> beam.WindowInto(window.FixedWindows(5))
| "Batch Elements" >> BatchElements(min_batch_size=1, max_batch_size=50)
| "Batch Elements" >> BatchElements(
min_batch_size=1,
max_batch_size=50,
max_batch_duration_secs=1,
)
| "Call FastAPI Batch" >> beam.ParDo(
BatchCallFastAPIService(known_args.service_url)
)
@@ -55,17 +60,20 @@ predictions = (
```

**Why these parameters:**
- `FixedWindows(5)` — 5-second windows. Short enough for near-real-time latency, long enough to accumulate a batch under moderate load. At low traffic, `min_batch_size=1` ensures messages don't wait forever.
- `max_batch_size=50` — matches a reasonable HTTP payload size. The FastAPI server builds a pandas DataFrame from the instances; 50 rows is trivial. Going higher (500+) risks HTTP timeouts and large retry payloads.
- `max_batch_duration_secs=1` — activates the **stateful implementation** (Beam State & Timers API, requires Beam 2.52+). Without this, default `BatchElements` only batches within a single bundle — on Dataflow streaming, bundles are frequently size 1 at low throughput, making the transform a no-op. With it, elements are batched across bundles and partial batches flush after 1 second. The timing is best-effort — actual hold time may slightly exceed 1s.
- `min_batch_size=1` — ensures single messages still flow through at low traffic instead of blocking until 50 arrive.
- **No `FixedWindows` needed** — the `max_batch_duration_secs` timer replaces the fixed window. This avoids adding an artificial 5-second latency floor. At high traffic, batches fill to 50 and flush immediately. At low traffic, worst case is ~1 second.

**Tradeoff:** The stateful path requires internal keying, which triggers a shuffle (network transfer between workers). For small payloads like iris features this is negligible, but worth noting for larger payloads.

**Adding CLI args for tuning:**

```python
parser.add_argument("--batch_size", type=int, default=50,
help="Max instances per /predict call")
parser.add_argument("--window_seconds", type=int, default=5,
help="Fixed window duration in seconds")
parser.add_argument("--max_batch_duration_secs", type=float, default=1.0,
help="Max seconds to buffer a partial batch before flushing")
```
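
As a rough illustration of how these flags could reach the transform, the sketch below parses them with `argparse` and builds the keyword arguments for `BatchElements`. It assumes the pipeline uses `parse_known_args` (suggested by the `known_args` name in the snippet above, but not confirmed by this diff); `batch_kwargs` is a hypothetical name, and the example argument values are made up.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=50,
                    help="Max instances per /predict call")
parser.add_argument("--max_batch_duration_secs", type=float, default=1.0,
                    help="Max seconds to buffer a partial batch before flushing")

# parse_known_args returns unrecognised flags (e.g. --runner) separately,
# so they can be handed on to Beam's pipeline options untouched.
known_args, pipeline_args = parser.parse_known_args(
    ["--batch_size", "25", "--max_batch_duration_secs", "0.5",
     "--runner", "DirectRunner"]
)

# Keyword arguments that would be passed as BatchElements(**batch_kwargs).
batch_kwargs = {
    "min_batch_size": 1,
    "max_batch_size": known_args.batch_size,
    "max_batch_duration_secs": known_args.max_batch_duration_secs,
}
print(batch_kwargs, pipeline_args)
```

Declaring `--max_batch_duration_secs` as a float matches the `1.0` value passed in the workflow file and allows sub-second flush intervals.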

### 2. Replace `CallFastAPIService` with `BatchCallFastAPIService`
@@ -279,15 +287,15 @@ Alternatively, since `pyproject.toml` already declares `aiohttp` and the pipelin

### 5. Remove old `CallFastAPIService` class

Delete the old single-element DoFn (lines 66-142) since it's fully replaced by `BatchCallFastAPIService`.
Delete the old single-element `CallFastAPIService` DoFn since it's fully replaced by `BatchCallFastAPIService`.

## Implementation options

You can do this in two phases or all at once:

| Phase | Change | Throughput gain | Complexity |
|---|---|---|---|
| **Phase 1** | Batching + `requests.Session` (steps 1-2) | ~10-50x (amortize network latency over 50 msgs) | Low: straightforward Beam transforms |
| **Phase 1** | Stateful micro-batching + `requests.Session` (steps 1-2) | ~10-50x (amortize network latency over 50 msgs) | Low: straightforward Beam transforms |
| **Phase 2** | Async HTTP with `aiohttp` (step 3) | ~2-4x on top of phase 1 (overlap concurrent batches) | Medium: async event loop management |

Phase 1 alone gives the biggest win. Phase 2 adds concurrency on top. If traffic is moderate (hundreds of msgs/sec), phase 1 may be sufficient.
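
To make the phase 2 idea concrete, here is a minimal standard-library sketch of overlapping several batch calls. It is not the PR's `BatchCallFastAPIService`: `fake_predict` stands in for an `aiohttp` POST by sleeping, and `predict_all` and `max_in_flight` are hypothetical names chosen for illustration.

```python
import asyncio
import time
from typing import List


async def fake_predict(batch: List[int], latency_secs: float = 0.05) -> List[int]:
    """Stand-in for an async POST to /predict; sleeps instead of doing I/O."""
    await asyncio.sleep(latency_secs)
    return [x * 2 for x in batch]  # Dummy "predictions", one per instance.


async def predict_all(batches: List[List[int]], max_in_flight: int = 4) -> List[List[int]]:
    """Run batch calls concurrently, at most max_in_flight at a time."""
    semaphore = asyncio.Semaphore(max_in_flight)

    async def bounded(batch: List[int]) -> List[int]:
        async with semaphore:
            return await fake_predict(batch)

    # gather preserves input order, so results line up with batches.
    return await asyncio.gather(*(bounded(b) for b in batches))


batches = [list(range(i, i + 3)) for i in range(0, 12, 3)]  # 4 batches of 3
start = time.perf_counter()
results = asyncio.run(predict_all(batches))
elapsed = time.perf_counter() - start
print(results, f"{elapsed:.2f}s")
```

Because the four simulated calls wait at the same time, total wall time should be close to one round trip rather than four. The semaphore caps concurrency so a burst of batches does not overwhelm the service.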
@@ -296,32 +304,35 @@ Phase 1 alone gives the biggest win. Phase 2 adds concurrency on top. If traffic

| File | Action |
|---|---|
| `src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py` | Add windowing + BatchElements, replace `CallFastAPIService` with `BatchCallFastAPIService`, add `--batch_size` and `--window_seconds` CLI args |
| `scripts/deploy_dataflow_streaming.sh` | Add `--batch_size` and `--window_seconds` flags (optional, defaults are fine) |
| `.github/workflows/deploy-dataflow.yaml` | Same — add flags if overriding defaults |
| `src/dataflow/iris_streaming_pipeline.py` | Add `BatchElements` with `max_batch_duration_secs`, replace `CallFastAPIService` with `BatchCallFastAPIService`, add `--batch_size` and `--max_batch_duration_secs` CLI args |
| `.github/workflows/deploy-dataflow.yaml` | Add `--batch_size` and `--max_batch_duration_secs` flags if overriding defaults |

No changes to `fastapi_server.py` — it already handles batched instances.

## Verification

1. **Unit test locally with DirectRunner:**
1. **Unit test locally with DirectRunner (staging):**
```bash
python src/ml_pipelines_kfp/dataflow/iris_streaming_pipeline.py \
python src/dataflow/iris_streaming_pipeline.py \
--input_topic projects/deeplearning-sahil/topics/iris-inference-data \
--output_table deeplearning-sahil:ml_dataset.iris_predictions_streaming \
--output_table deeplearning-sahil:ml_dataset.iris_predictions_streaming_staging \
--project_id deeplearning-sahil \
--region us-central1 \
--service_url https://iris-classifier-xgboost-service-zoxyfmo73q-uc.a.run.app \
--service_url https://iris-classifier-xgboost-service-staging-zoxyfmo73q-uc.a.run.app \
--runner DirectRunner \
--streaming
```
Run the pubsub producer in parallel and verify predictions land in BigQuery.
Run the pubsub producer in parallel and verify predictions land in the staging BigQuery table.

2. **Check batch sizes in logs:** Look for the `Batch prediction failed (N instances)` log pattern. If N is consistently 50, batching is working. At low traffic, N will be smaller (down to 1), which is expected.
2. **Deploy to staging via GitHub Action:** Trigger the `Deploy Dataflow Streaming` workflow with `environment=staging` to validate on Dataflow workers.

3. **Compare throughput:** Before and after, monitor Dataflow job metrics:
3. **Check batch sizes in logs:** Look for the `Batch prediction failed (N instances)` log pattern. If N is consistently 50, batching is working. At low traffic, N will be smaller (down to 1), which is expected.

4. **Compare throughput:** Before and after, monitor Dataflow job metrics:
- Elements processed/sec (should increase ~10-50x)
- System lag (should decrease — messages spend less time waiting)
- Worker CPU per element (should decrease: fewer HTTP round trips per message)

4. **Verify no data loss:** Count Pub/Sub acked messages vs BigQuery rows inserted over a time window. Should match.
5. **Verify no data loss:** Count Pub/Sub acked messages vs BigQuery rows inserted over a time window. Should match.

6. **Promote to prod:** After staging validation, trigger the workflow with `environment=prod`.
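
For the batch-size log check above, a small script can tally the values of N instead of eyeballing them. This is a sketch under assumptions: `batch_size_histogram` is a hypothetical helper, the sample lines are invented, and only the `Batch prediction failed (N instances)` pattern itself comes from the doc. Exporting the real log lines from Cloud Logging is left out.

```python
import re
from collections import Counter
from typing import Iterable

# Matches the log pattern quoted in the batch-size check above.
BATCH_LOG = re.compile(r"Batch prediction failed \((\d+) instances\)")


def batch_size_histogram(log_lines: Iterable[str]) -> Counter:
    """Count how often each batch size N appears in the given log lines."""
    sizes: Counter = Counter()
    for line in log_lines:
        match = BATCH_LOG.search(line)
        if match:
            sizes[int(match.group(1))] += 1
    return sizes


# Hypothetical log lines; real entries would come from the Dataflow job logs.
sample = [
    "ERROR Batch prediction failed (50 instances): timeout",
    "ERROR Batch prediction failed (50 instances): timeout",
    "ERROR Batch prediction failed (7 instances): HTTP 503",
    "INFO unrelated line",
]
print(batch_size_histogram(sample))
```

A histogram concentrated at the configured `--batch_size` suggests batches are filling, while a spread of small values points to low traffic or timer-driven flushes.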