From fe4343106e1fbafb5ddaded4a9f3694b229964dc Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 22:18:13 -0800 Subject: [PATCH 1/4] dry-run: raise DryRunEarlyExit and add CLI graceful exit --- src/batchling/__init__.py | 2 + src/batchling/cli/main.py | 28 +++--- src/batchling/core.py | 76 +++++++-------- src/batchling/exceptions.py | 46 +++++++++ tests/test_cli_script_runner.py | 33 +++++++ tests/test_core.py | 166 ++++++++++++++++---------------- 6 files changed, 215 insertions(+), 136 deletions(-) create mode 100644 src/batchling/exceptions.py diff --git a/src/batchling/__init__.py b/src/batchling/__init__.py index ac328d4..bf31715 100644 --- a/src/batchling/__init__.py +++ b/src/batchling/__init__.py @@ -1,5 +1,7 @@ from .api import batchify as batchify +from .exceptions import DryRunEarlyExit as DryRunEarlyExit __all__ = [ "batchify", + "DryRunEarlyExit", ] diff --git a/src/batchling/cli/main.py b/src/batchling/cli/main.py index 6cba6dd..549d728 100644 --- a/src/batchling/cli/main.py +++ b/src/batchling/cli/main.py @@ -7,6 +7,7 @@ import typer from batchling import batchify +from batchling.exceptions import DryRunEarlyExit # syncify = lambda f: wraps(f)(lambda *args, **kwargs: asyncio.run(f(*args, **kwargs))) @@ -177,16 +178,19 @@ def main( raise typer.BadParameter("Script path must be a module path, use 'module:func' syntax") script_args = list(ctx.args) - asyncio.run( - run_script_with_batchify( - module_path=Path(module_path), - func_name=func_name, - script_args=script_args, - batch_size=batch_size, - batch_window_seconds=batch_window_seconds, - batch_poll_interval_seconds=batch_poll_interval_seconds, - dry_run=dry_run, - cache=cache, - live_display=live_display, + try: + asyncio.run( + run_script_with_batchify( + module_path=Path(module_path), + func_name=func_name, + script_args=script_args, + batch_size=batch_size, + batch_window_seconds=batch_window_seconds, + batch_poll_interval_seconds=batch_poll_interval_seconds, + dry_run=dry_run, + cache=cache, + live_display=live_display, + ) ) - ) + except DryRunEarlyExit: + return diff --git a/src/batchling/core.py b/src/batchling/core.py index f27c451..06906ab 100644 --- a/src/batchling/core.py +++ b/src/batchling/core.py @@ -19,6 +19,7 @@ import httpx from batchling.cache import CacheEntry, RequestCacheStore +from batchling.exceptions import DryRunEarlyExit from batchling.logging import log_debug, log_error, log_info, log_warning from batchling.providers import BaseProvider from batchling.providers.base import PollSnapshot, ProviderRequestSpec @@ -493,10 +494,14 @@ async def _try_submit_from_cache( future=future, request_hash=request_hash, ) - future.set_result( - self._build_dry_run_response( - request=dry_run_request, - cache_hit=True, + future.set_exception( + DryRunEarlyExit( + source="cache_dry_run", + provider=dry_run_request.provider.name, + endpoint=dry_run_request.queue_key[1], + model=dry_run_request.queue_key[2], + batch_id=cache_entry.batch_id, + custom_id=dry_run_request.custom_id, ) ) return cache_entry @@ -641,6 +646,8 @@ async def submit( return await future except asyncio.CancelledError: raise + except DryRunEarlyExit: + raise except Exception as error: log_warning( logger=log, @@ -1051,9 +1058,14 @@ async def _process_batch( self._active_batches.append(active_batch) for req in requests: if not req.future.done(): - req.future.set_result( - self._build_dry_run_response( - request=req, + req.future.set_exception( + DryRunEarlyExit( + source="dry_run", + provider=req.provider.name, + endpoint=req.queue_key[1], + model=req.queue_key[2], + batch_id=dry_run_batch_id, + custom_id=req.custom_id, ) ) log_info( @@ -1157,42 +1169,6 @@ async def _process_batch( if not req.future.done(): req.future.set_exception(e) - def _build_dry_run_response( - self, - *, - request: _PendingRequest, - cache_hit: bool = False, - ) -> httpx.Response: - """ - Build a synthetic response for dry-run batch mode. - - Parameters - ---------- - request : _PendingRequest - Pending request metadata. - cache_hit : bool, optional - Whether this dry-run response came from a cache lookup. - - Returns - ------- - httpx.Response - Synthetic successful response. - """ - return httpx.Response( - status_code=200, - headers={ - "x-batchling-dry-run": "1", - "x-batchling-cache-hit": "1" if cache_hit else "0", - }, - json={ - "dry_run": True, - "custom_id": request.custom_id, - "provider": request.provider.name, - "status": "simulated", - "cache_hit": cache_hit, - }, - ) - async def _poll_batch_once( self, *, @@ -1712,3 +1688,17 @@ async def close(self) -> None: queue_key=queue_key, requests=requests, ) + + # Wait for in-flight batch submission and resumed poll tasks so close() + # leaves the batcher in a stable state for summary/report consumers. + while True: + pending_batch_tasks = [task for task in self._batch_tasks if not task.done()] + if not pending_batch_tasks: + break + _ = await asyncio.gather(*pending_batch_tasks, return_exceptions=True) + + while True: + pending_resumed_tasks = [task for task in self._resumed_poll_tasks if not task.done()] + if not pending_resumed_tasks: + break + _ = await asyncio.gather(*pending_resumed_tasks, return_exceptions=True) diff --git a/src/batchling/exceptions.py b/src/batchling/exceptions.py new file mode 100644 index 0000000..3fcef03 --- /dev/null +++ b/src/batchling/exceptions.py @@ -0,0 +1,46 @@ +"""Public exceptions exposed by batchling.""" + +from __future__ import annotations + + +class DryRunEarlyExit(RuntimeError): + """ + Raised when dry-run mode exits before returning a provider response. + + Parameters + ---------- + source : str + Dry-run branch that produced the early exit. + provider : str + Provider name. + endpoint : str + Endpoint key. + model : str + Model key. + batch_id : str + Simulated batch identifier. + custom_id : str + Request custom identifier. + """ + + def __init__( + self, + *, + source: str, + provider: str, + endpoint: str, + model: str, + batch_id: str, + custom_id: str, + ) -> None: + self.source = source + self.provider = provider + self.endpoint = endpoint + self.model = model + self.batch_id = batch_id + self.custom_id = custom_id + message = ( + "Dry run early exit triggered for " + f"{provider}:{endpoint}:{model} (batch_id={batch_id}, custom_id={custom_id}, source={source})" + ) + super().__init__(message) diff --git a/tests/test_cli_script_runner.py b/tests/test_cli_script_runner.py index e7426d7..27fa3d1 100644 --- a/tests/test_cli_script_runner.py +++ b/tests/test_cli_script_runner.py @@ -5,6 +5,7 @@ import batchling.cli.main as cli_main from batchling.cli.main import app +from batchling.exceptions import DryRunEarlyExit runner = CliRunner() @@ -206,3 +207,35 @@ def test_run_script_invalid_script_path(): assert result.exit_code == 1 assert "Script not found" in result.output + + +def test_cli_catches_dry_run_early_exit(tmp_path: Path, monkeypatch) -> None: + script_path = tmp_path / "script.py" + script_path.write_text("async def foo():\n return None\n") + + async def raise_dry_run_early_exit(**kwargs) -> None: + del kwargs + raise DryRunEarlyExit( + source="dry_run", + provider="openai", + endpoint="/v1/chat/completions", + model="model-a", + batch_id="dryrun-1", + custom_id="custom-1", + ) + + monkeypatch.setattr( + cli_main, + "run_script_with_batchify", + raise_dry_run_early_exit, + ) + + result = runner.invoke( + app, + [ + f"{script_path.as_posix()}:foo", + "--dry-run", + ], + ) + + assert result.exit_code == 0 diff --git a/tests/test_core.py b/tests/test_core.py index 960ba6e..daddd62 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -11,6 +11,7 @@ from batchling.cache import CacheEntry from batchling.core import Batcher, BatcherEvent, _ActiveBatch, _PendingRequest +from batchling.exceptions import DryRunEarlyExit from batchling.providers.anthropic import AnthropicProvider from batchling.providers.base import PollSnapshot, ProviderRequestSpec, ResumeContext from batchling.providers.gemini import GeminiProvider @@ -816,43 +817,32 @@ async def test_submit_after_close(batcher: Batcher, provider: OpenAIProvider): @pytest.mark.asyncio -async def test_dry_run_returns_simulated_response(provider: OpenAIProvider): - """Test dry-run returns a simulated response without provider I/O.""" +async def test_dry_run_submit_raises_early_exit(provider: OpenAIProvider): + """Test dry-run submit raises early exit without provider I/O.""" dry_run_batcher = Batcher( - batch_size=3, + batch_size=1, batch_window_seconds=0.5, dry_run=True, cache=False, ) - result = await dry_run_batcher.submit( - client_type="httpx", - method="GET", - url="api.openai.com", - endpoint="/v1/chat/completions", - provider=provider, - body=b'{"model":"model-a","messages":[]}', - headers={"Authorization": "Bearer token"}, - ) - - result = await dry_run_batcher.submit( - client_type="httpx", - method="GET", - url="api.openai.com", - endpoint="/v1/chat/completions", - provider=provider, - body=b'{"model":"model-a","messages":[]}', - headers={"Authorization": "Bearer token"}, - ) + with pytest.raises(DryRunEarlyExit) as error: + _ = await dry_run_batcher.submit( + client_type="httpx", + method="GET", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + body=b'{"model":"model-a","messages":[]}', + headers={"Authorization": "Bearer token"}, + ) - assert isinstance(result, httpx.Response) - assert result.status_code == 200 - assert result.headers.get("x-batchling-dry-run") == "1" - assert result.json()["dry_run"] is True - assert result.json()["provider"] == "openai" - assert result.json()["status"] == "simulated" + assert error.value.source == "dry_run" + assert error.value.provider == "openai" + assert error.value.endpoint == "/v1/chat/completions" + assert error.value.model == "model-a" assert _pending_count(batcher=dry_run_batcher) == 0 - assert len(dry_run_batcher._active_batches) == 2 + assert len(dry_run_batcher._active_batches) == 1 await dry_run_batcher.close() @@ -880,17 +870,17 @@ async def failing_process_batch(*_args, **_kwargs): failing_process_batch, ) - result = await dry_run_batcher.submit( - client_type="httpx", - method="GET", - url="api.openai.com", - endpoint="/v1/chat/completions", - provider=provider, - body=b'{"model":"model-a","messages":[]}', - headers={"Authorization": "Bearer token"}, - ) + with pytest.raises(DryRunEarlyExit): + _ = await dry_run_batcher.submit( + client_type="httpx", + method="GET", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + body=b'{"model":"model-a","messages":[]}', + headers={"Authorization": "Bearer token"}, + ) - assert result.status_code == 200 assert call_count == 0 await dry_run_batcher.close() @@ -934,9 +924,10 @@ async def test_dry_run_still_batches_by_size(provider: OpenAIProvider): body=b'{"model":"model-a","messages":[]}', headers={"Authorization": "Bearer token"}, ), + return_exceptions=True, ) - assert all(isinstance(r, httpx.Response) and r.status_code == 200 for r in results) + assert all(isinstance(result, DryRunEarlyExit) for result in results) assert len(dry_run_batcher._active_batches) == 1 assert _pending_count(batcher=dry_run_batcher) == 0 @@ -966,11 +957,11 @@ async def test_dry_run_close_flushes_pending_requests(provider: OpenAIProvider): await asyncio.sleep(delay=0.05) await dry_run_batcher.close() + assert not dry_run_batcher._batch_tasks + assert not dry_run_batcher._resumed_poll_tasks - result = await task - assert isinstance(result, httpx.Response) - assert result.status_code == 200 - assert result.headers.get("x-batchling-dry-run") == "1" + with pytest.raises(DryRunEarlyExit): + _ = await task assert len(dry_run_batcher._active_batches) == 1 @@ -997,10 +988,11 @@ async def test_homogeneous_provider_same_model_uses_same_queue(): provider=provider, body=b'{"model":"gpt-4o-mini","messages":[]}', ), + return_exceptions=True, ) assert len(batcher._active_batches) == 1 - assert all(result.status_code == 200 for result in results) + assert all(isinstance(result, DryRunEarlyExit) for result in results) assert _pending_count_for_provider(batcher=batcher, provider_name=provider.name) == 0 await batcher.close() @@ -1030,7 +1022,8 @@ async def test_homogeneous_provider_pending_request_stores_queue_key(): assert queue[0].queue_key == _queue_key(provider_name=provider.name, model_name="model-a") await batcher.close() - await task + with pytest.raises(DryRunEarlyExit): + _ = await task @pytest.mark.asyncio @@ -1061,7 +1054,8 @@ async def test_strict_queue_key_stores_provider_endpoint_model(): assert queue[0].queue_key == queue_key await batcher.close() - await task + with pytest.raises(DryRunEarlyExit): + _ = await task def test_gemini_queue_key_extracts_model_from_endpoint() -> None: @@ -1364,7 +1358,7 @@ async def test_homogeneous_provider_different_models_use_distinct_queues(): provider = HomogeneousOpenAIProvider() batcher = Batcher(batch_size=2, batch_window_seconds=10.0, dry_run=True, cache=False) - await asyncio.gather( + results = await asyncio.gather( batcher.submit( client_type="httpx", method="POST", @@ -1397,8 +1391,10 @@ async def test_homogeneous_provider_different_models_use_distinct_queues(): provider=provider, body=b'{"model":"model-b","messages":[]}', ), + return_exceptions=True, ) + assert all(isinstance(result, DryRunEarlyExit) for result in results) assert len(batcher._active_batches) == 2 assert _pending_count_for_provider(batcher=batcher, provider_name=provider.name) == 0 await batcher.close() @@ -1459,7 +1455,8 @@ async def test_strict_queue_key_mixed_models_use_distinct_queues(): assert _queue_key(provider_name=provider.name, model_name="model-b") in batcher._window_tasks await batcher.close() - await asyncio.gather(task_1, task_2) + results = await asyncio.gather(task_1, task_2, return_exceptions=True) + assert all(isinstance(result, DryRunEarlyExit) for result in results) @pytest.mark.asyncio @@ -1509,7 +1506,8 @@ async def test_strict_queue_key_different_endpoints_use_distinct_queues(): ) await batcher.close() - await asyncio.gather(task_1, task_2) + results = await asyncio.gather(task_1, task_2, return_exceptions=True) + assert all(isinstance(result, DryRunEarlyExit) for result in results) @pytest.mark.asyncio @@ -1551,7 +1549,8 @@ async def test_close_flushes_all_model_scoped_queues_for_homogeneous_provider(): ) await batcher.close() - await asyncio.gather(task_1, task_2) + results = await asyncio.gather(task_1, task_2, return_exceptions=True) + assert all(isinstance(result, DryRunEarlyExit) for result in results) assert _pending_count_for_provider(batcher=batcher, provider_name=provider.name) == 0 assert len(batcher._active_batches) == 2 @@ -1931,18 +1930,21 @@ async def test_dry_run_cache_hit_is_read_only(mock_openai_api_transport: httpx.M dry_run_batcher = Batcher(batch_size=2, batch_window_seconds=0.1, cache=True, dry_run=True) dry_run_provider = OpenAIProvider() - dry_run_response = await dry_run_batcher.submit( - client_type="httpx", - method="POST", - url="api.openai.com", - endpoint="/v1/chat/completions", - provider=dry_run_provider, - headers={"Authorization": "Bearer token"}, - body=b'{"model":"model-a","messages":[]}', - ) - assert dry_run_response.headers["x-batchling-dry-run"] == "1" - assert dry_run_response.headers["x-batchling-cache-hit"] == "1" - assert dry_run_response.json()["cache_hit"] is True + with pytest.raises(DryRunEarlyExit) as dry_run_error: + _ = await dry_run_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=dry_run_provider, + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + assert dry_run_error.value.source == "cache_dry_run" + assert dry_run_error.value.batch_id == original_entry.batch_id + assert dry_run_error.value.provider == "openai" + assert dry_run_error.value.endpoint == "/v1/chat/completions" + assert dry_run_error.value.model == "model-a" assert dry_run_batcher._cache_store is not None dry_run_entry = dry_run_batcher._cache_store.get_by_hash(request_hash=request_hash) assert dry_run_entry is not None @@ -1986,15 +1988,16 @@ async def test_dry_run_emits_terminal_without_poll(provider: OpenAIProvider) -> events: list[BatcherEvent] = [] dry_run_batcher._add_event_listener(listener=events.append) - _ = await dry_run_batcher.submit( - client_type="httpx", - method="POST", - url="api.openai.com", - endpoint="/v1/chat/completions", - provider=provider, - headers={"Authorization": "Bearer token"}, - body=b'{"model":"model-a","messages":[]}', - ) + with pytest.raises(DryRunEarlyExit): + _ = await dry_run_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) event_types = {str(event["event_type"]) for event in events} assert "request_queued" in event_types @@ -2039,15 +2042,16 @@ async def test_cache_hit_emits_cache_hit_routed( events: list[BatcherEvent] = [] dry_run_batcher._add_event_listener(listener=events.append) - _ = await dry_run_batcher.submit( - client_type="httpx", - method="POST", - url="api.openai.com", - endpoint="/v1/chat/completions", - provider=OpenAIProvider(), - headers={"Authorization": "Bearer token"}, - body=b'{"model":"model-a","messages":[]}', - ) + with pytest.raises(DryRunEarlyExit): + _ = await dry_run_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=OpenAIProvider(), + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) assert any( event["event_type"] == "cache_hit_routed" and event.get("source") == "cache_dry_run" From 7ca1d65ca67be549df767ca1e5bedb2600a0c102 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 22:18:24 -0800 Subject: [PATCH 2/4] dry-run: add static summary aggregation and teardown reporting --- src/batchling/context.py | 100 +++++++++++++++++++++-- src/batchling/progress_state.py | 136 ++++++++++++++++++++++++++++++++ src/batchling/rich_display.py | 132 ++++++++++++++++++++++++++++++- tests/test_context.py | 49 ++++++++++++ tests/test_rich_display.py | 90 +++++++++++++++++++++ 5 files changed, 501 insertions(+), 6 deletions(-) diff --git a/src/batchling/context.py b/src/batchling/context.py index 82ccf12..b0dc621 100644 --- a/src/batchling/context.py +++ b/src/batchling/context.py @@ -13,6 +13,7 @@ from batchling.progress_state import BatchProgressState from batchling.rich_display import ( BatcherRichDisplay, + DryRunSummaryDisplay, should_enable_live_display, ) @@ -87,8 +88,85 @@ def __init__( self._self_live_display: BatcherRichDisplay | None = None self._self_live_display_heartbeat_task: asyncio.Task[None] | None = None self._self_polling_progress_logger: _PollingProgressLogger | None = None + self._self_dry_run_summary_display: DryRunSummaryDisplay | None = None + self._self_dry_run_summary_printed = False self._self_context_token: t.Any | None = None + def _start_dry_run_summary_listener(self) -> None: + """ + Start dry-run summary listener for static teardown reporting. + + Notes + ----- + Listener errors are downgraded to warnings to avoid breaking batching. + """ + if not self._self_batcher._dry_run: + return + if self._self_dry_run_summary_display is not None: + return + try: + display = DryRunSummaryDisplay() + self._self_batcher._add_event_listener(listener=display.on_event) + self._self_dry_run_summary_display = display + except Exception as error: + warnings.warn( + message=f"Failed to start batchling dry-run summary listener: {error}", + category=UserWarning, + stacklevel=2, + ) + + def _stop_dry_run_summary_listener(self) -> None: + """ + Stop and unregister the dry-run summary listener. + + Notes + ----- + Listener shutdown errors are downgraded to warnings. + """ + display = self._self_dry_run_summary_display + if display is None: + return + self._self_dry_run_summary_display = None + try: + self._self_batcher._remove_event_listener(listener=display.on_event) + except Exception as error: + warnings.warn( + message=f"Failed to stop batchling dry-run summary listener: {error}", + category=UserWarning, + stacklevel=2, + ) + + def _print_dry_run_summary_once(self) -> None: + """ + Print static dry-run summary report exactly once. + + Notes + ----- + Reporting errors are downgraded to warnings. + """ + display = self._self_dry_run_summary_display + if display is None: + return + if self._self_dry_run_summary_printed: + return + try: + display.print_summary() + self._self_dry_run_summary_printed = True + except Exception as error: + warnings.warn( + message=f"Failed to print batchling dry-run summary: {error}", + category=UserWarning, + stacklevel=2, + ) + + def _finalize_context_displays(self) -> None: + """ + Stop live display and finalize dry-run reporting/listeners. + """ + self._stop_live_display() + self._print_dry_run_summary_once() + self._stop_dry_run_summary_listener() + def _start_polling_progress_logger(self) -> None: """ Start the INFO polling progress fallback listener. @@ -213,6 +291,7 @@ def __enter__(self) -> None: ``None`` for scoped activation. """ self._self_context_token = active_batcher.set(self._self_batcher) + self._start_dry_run_summary_listener() self._start_live_display() return None @@ -251,18 +330,28 @@ def __exit__( category=UserWarning, stacklevel=2, ) - self._stop_live_display() + self._finalize_context_displays() - def _on_sync_close_done(self, _: asyncio.Task[None]) -> None: + def _on_sync_close_done(self, close_task: asyncio.Task[None]) -> None: """ Callback run when sync-context close task completes. Parameters ---------- - _ : asyncio.Task[None] + close_task : asyncio.Task[None] Completed close task. """ - self._stop_live_display() + try: + _ = close_task.result() + except asyncio.CancelledError: + pass + except Exception as error: + warnings.warn( + message=f"Failed to close batcher in sync context: {error}", + category=UserWarning, + stacklevel=2, + ) + self._finalize_context_displays() async def __aenter__(self) -> None: """ @@ -274,6 +363,7 @@ async def __aenter__(self) -> None: ``None`` for scoped activation. """ self._self_context_token = active_batcher.set(self._self_batcher) + self._start_dry_run_summary_listener() self._start_live_display() return None @@ -301,4 +391,4 @@ async def __aexit__( try: await self._self_batcher.close() finally: - self._stop_live_display() + self._finalize_context_displays() diff --git a/src/batchling/progress_state.py b/src/batchling/progress_state.py index 4a9b0c6..8be6770 100644 --- a/src/batchling/progress_state.py +++ b/src/batchling/progress_state.py @@ -22,6 +22,14 @@ class _TrackedBatch: terminal: bool = False +@dataclass +class _DryRunQueueSummary: + """Aggregated dry-run counters per queue key.""" + + expected_requests: int = 0 + expected_batches: int = 0 + + class BatchProgressState: """ Track batch lifecycle state and compute shared aggregate metrics. @@ -221,3 +229,131 @@ def _status_counts_as_completed(*, status: str) -> bool: if any(marker in lowered_status for marker in negative_markers): return False return True + + +class DryRunSummaryState: + """ + Aggregate dry-run request and batch estimates from lifecycle events. + + Notes + ----- + This state tracks only dry-run relevant counters and is intended to feed + a static summary rendered at context teardown. + """ + + def __init__(self) -> None: + self._would_batch_requests_total = 0 + self._would_cache_requests_total = 0 + self._queue_counts: dict[tuple[str, str, str], _DryRunQueueSummary] = {} + + def on_event(self, *, event: BatcherEvent) -> None: + """ + Update summary counters using one lifecycle event. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + event_type = str(object=event.get("event_type", "unknown")) + source = str(object=event.get("source", "")) + + if event_type == "request_queued": + queue_key = self._extract_queue_key(event=event) + if queue_key is None: + return + self._would_batch_requests_total += 1 + queue_summary = self._queue_counts.setdefault(queue_key, _DryRunQueueSummary()) + queue_summary.expected_requests += 1 + return + + if event_type == "batch_processing" and source == "dry_run": + queue_key = self._extract_queue_key(event=event) + if queue_key is None: + return + queue_summary = self._queue_counts.setdefault(queue_key, _DryRunQueueSummary()) + queue_summary.expected_batches += 1 + return + + if event_type == "cache_hit_routed" and source == "cache_dry_run": + self._would_cache_requests_total += 1 + + @property + def would_batch_requests_total(self) -> int: + """ + Total number of requests that would have been batched. + + Returns + ------- + int + Global count of queued requests. + """ + return self._would_batch_requests_total + + @property + def would_cache_requests_total(self) -> int: + """ + Total number of dry-run cache-hit requests. + + Returns + ------- + int + Global dry-run cache-hit count. + """ + return self._would_cache_requests_total + + def compute_queue_rows(self) -> list[tuple[str, str, str, int, int]]: + """ + Return sorted per-queue dry-run summary rows. + + Returns + ------- + list[tuple[str, str, str, int, int]] + Rows formatted as + ``(provider, endpoint, model, expected_requests, expected_batches)``. + """ + rows = [ + ( + provider, + endpoint, + model, + queue_summary.expected_requests, + queue_summary.expected_batches, + ) + for (provider, endpoint, model), queue_summary in self._queue_counts.items() + ] + return sorted(rows, key=lambda row: (row[0], row[1], row[2])) + + @staticmethod + def _extract_queue_key(*, event: BatcherEvent) -> tuple[str, str, str] | None: + """ + Extract queue key from lifecycle event payload. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event payload. + + Returns + ------- + tuple[str, str, str] | None + Queue key when available. + """ + queue_key = event.get("queue_key") + if ( + isinstance(queue_key, tuple) + and len(queue_key) == 3 + and all(isinstance(part, str) for part in queue_key) + ): + return queue_key + + provider = event.get("provider") + endpoint = event.get("endpoint") + model = event.get("model") + if provider is None or endpoint is None or model is None: + return None + return ( + str(object=provider), + str(object=endpoint), + str(object=model), + ) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 0060755..69a70cc 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -14,7 +14,7 @@ from rich.text import Text from batchling.core import BatcherEvent -from batchling.progress_state import BatchProgressState +from batchling.progress_state import BatchProgressState, DryRunSummaryState class BatcherRichDisplay: @@ -304,6 +304,136 @@ def _format_queue_progress(*, running: int, completed: int) -> Text: return progress +class DryRunSummaryDisplay: + """ + Render a static Rich report for dry-run planning totals. + + Parameters + ---------- + console : Console | None, optional + Rich console to render to. Defaults to ``Console(stderr=True)``. + """ + + def __init__( + self, + *, + console: Console | None = None, + ) -> None: + self._console = console or Console(stderr=True) + self._summary_state = DryRunSummaryState() + + def on_event(self, event: BatcherEvent) -> None: + """ + Consume one lifecycle event for dry-run summary aggregation. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + self._summary_state.on_event(event=event) + + def print_summary(self) -> None: + """Print the static dry-run report panel.""" + self._console.print(self._render()) + + def _render(self) -> Panel: + """Build the static dry-run summary panel.""" + return Panel( + renderable=Group( + self._build_totals_line(), + self._build_queue_summary_table(), + ), + title="batchling dry run summary", + border_style="yellow", + ) + + def _build_totals_line(self) -> Text: + """ + Build top-level totals line for the dry-run report. + + Returns + ------- + Text + Styled totals text. + """ + line = Text() + line.append(text="Would Batch", style="grey70") + line.append(text=": ", style="grey70") + line.append( + text=str(object=self._summary_state.would_batch_requests_total), + style="bold cyan", + ) + line.append(text=" - ", style="grey70") + line.append(text="Would Cache", style="grey70") + line.append(text=": ", style="grey70") + line.append( + text=str(object=self._summary_state.would_cache_requests_total), + style="bold magenta", + ) + return line + + def _build_queue_summary_table(self) -> Table: + """ + Build queue-level dry-run estimate table. + + Returns + ------- + Table + Queue estimate table. + """ + queue_rows = self._summary_state.compute_queue_rows() + table = Table(expand=False) + table.add_column( + header="provider", + style="bold blue", + width=12, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="endpoint", + width=34, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="model", + style="bold magenta", + width=28, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="expected requests", + justify="right", + width=17, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="expected batches", + justify="right", + width=16, + no_wrap=True, + overflow="ellipsis", + ) + + if not queue_rows: + table.add_row("-", "-", "-", "0", "0") + return table + + for provider, endpoint, model, expected_requests, expected_batches in queue_rows: + table.add_row( + provider, + endpoint, + model, + str(object=expected_requests), + str(object=expected_batches), + ) + return table + + def should_enable_live_display(*, enabled: bool) -> bool: """ Resolve if the Rich live panel should be enabled. diff --git a/tests/test_context.py b/tests/test_context.py index 72c6eb9..15536ff 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -12,7 +12,9 @@ from batchling.context import BatchingContext from batchling.core import Batcher +from batchling.exceptions import DryRunEarlyExit from batchling.hooks import active_batcher +from batchling.providers.openai import OpenAIProvider @pytest.fixture @@ -216,3 +218,50 @@ def test_batching_context_uses_polling_progress_fallback_when_auto_disabled( context._stop_live_display() assert context._self_polling_progress_logger is None + + +@pytest.mark.asyncio +async def test_batching_context_prints_dry_run_report_when_live_display_disabled( + reset_context: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test dry-run report prints on context exit even when live display is disabled.""" + + class DummyDryRunSummaryDisplay: + """Minimal dry-run summary display stub.""" + + def __init__(self) -> None: + self.print_calls = 0 + + def on_event(self, event: dict[str, t.Any]) -> None: + del event + + def print_summary(self) -> None: + self.print_calls += 1 + + dry_run_display = DummyDryRunSummaryDisplay() + monkeypatch.setattr("batchling.context.DryRunSummaryDisplay", lambda: dry_run_display) + + batcher = Batcher( + batch_size=1, + batch_window_seconds=1.0, + dry_run=True, + cache=False, + ) + context = BatchingContext( + batcher=batcher, + live_display=False, + ) + + with pytest.raises(DryRunEarlyExit): + async with context: + _ = await batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=OpenAIProvider(), + body=b'{"model":"model-a","messages":[]}', + ) + + assert dry_run_display.print_calls == 1 diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index 2d37959..2bc29be 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -368,3 +368,93 @@ def test_batcher_rich_display_queue_progress_pads_to_total_width() -> None: completed=1, ) assert progress_text.plain == " 1/100 (1.0%)" + + +def test_dry_run_summary_display_aggregates_totals_and_queues() -> None: + """Test dry-run static summary tracks expected totals and queue estimates.""" + display = rich_display.DryRunSummaryDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + display.on_event( + { + "event_type": "request_queued", + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "custom_id": "1", + } + ) + display.on_event( + { + "event_type": "request_queued", + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "custom_id": "2", + } + ) + display.on_event( + { + "event_type": "request_queued", + "provider": "groq", + "endpoint": "/openai/v1/chat/completions", + "model": "llama-3.1-8b-instant", + "queue_key": ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant"), + "custom_id": "3", + } + ) + display.on_event( + { + "event_type": "batch_processing", + "source": "dry_run", + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "request_count": 2, + } + ) + display.on_event( + { + "event_type": "batch_processing", + "source": "dry_run", + "provider": "groq", + "endpoint": "/openai/v1/chat/completions", + "model": "llama-3.1-8b-instant", + "queue_key": ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant"), + "request_count": 1, + } + ) + display.on_event( + { + "event_type": "cache_hit_routed", + "source": "cache_dry_run", + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "batch_id": "batch-1", + "custom_id": "4", + } + ) + + totals_line = display._build_totals_line() + assert totals_line.plain == "Would Batch: 3 - Would Cache: 1" + + queue_table = display._build_queue_summary_table() + assert queue_table.columns[0]._cells == ["groq", "openai"] + assert queue_table.columns[3]._cells == ["1", "2"] + assert queue_table.columns[4]._cells == ["1", "1"] + + +def test_dry_run_summary_display_renders_empty_state() -> None: + """Test dry-run summary table defaults to zero row without events.""" + display = rich_display.DryRunSummaryDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + queue_table = display._build_queue_summary_table() + assert queue_table.columns[0]._cells == ["-"] + assert queue_table.columns[3]._cells == ["0"] + assert queue_table.columns[4]._cells == ["0"] From 8bd7ae9024f3cf5c639fbf5a0e7300a3f34f137d Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 22:18:30 -0800 Subject: [PATCH 3/4] docs: update dry-run contract and teardown summary behavior --- docs/architecture/api.md | 5 +++-- docs/architecture/context.md | 8 +++++--- docs/architecture/core.md | 6 ++++-- docs/dry-run.md | 36 +++++++++++++++++++++++++++++++++--- src/batchling/api.py | 3 ++- 5 files changed, 47 insertions(+), 11 deletions(-) diff --git a/docs/architecture/api.md b/docs/architecture/api.md index 637636a..f9f1d35 100644 --- a/docs/architecture/api.md +++ b/docs/architecture/api.md @@ -21,8 +21,9 @@ yields `None`. Import it from `batchling`. strict queue key `(provider, endpoint, model)`. - **`dry_run` behavior**: when `dry_run=True`, requests are still intercepted, queued, and grouped using normal window/size triggers, but provider batch submission and polling - are skipped. Requests resolve with synthetic `httpx.Response` objects marked with - `x-batchling-dry-run: 1`. + are skipped. Intercepted requests raise `DryRunEarlyExit` on return instead of + producing synthetic provider responses. A static dry-run summary report is emitted + at context teardown. - **`cache` behavior**: when `cache=True` (default), intercepted requests are fingerprinted and looked up in a persistent request cache. Cache hits bypass queueing and resume polling from an existing provider batch when not in dry-run mode. diff --git a/docs/architecture/context.md b/docs/architecture/context.md index 8e67530..ff98f36 100644 --- a/docs/architecture/context.md +++ b/docs/architecture/context.md @@ -10,6 +10,7 @@ a context variable. - Yield `None` for scope-only lifecycle control. - Support sync and async context manager patterns for cleanup and context scoping. - Start and stop optional Rich live activity display while the context is active. +- In dry-run mode, aggregate and print a static Rich summary at teardown. ## Flow summary @@ -20,9 +21,10 @@ a context variable. 4. If `live_display=True`, the context attempts to start Rich panel rendering at enter-time when terminal auto-detection passes (`TTY`, non-`dumb`, non-`CI`). Otherwise it registers an `INFO` logging fallback that emits progress at poll-time. -5. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work. -6. The live display listener is removed and the panel is stopped when context cleanup - finishes. +5. In dry-run mode, a dedicated summary listener is also registered at enter-time. +6. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work. +7. On teardown, the context prints one static dry-run summary report when dry-run is enabled. +8. Display/listener cleanup runs after close completes. ## Code reference diff --git a/docs/architecture/core.md b/docs/architecture/core.md index 5e8ccbd..07837ca 100644 --- a/docs/architecture/core.md +++ b/docs/architecture/core.md @@ -34,9 +34,11 @@ resolves futures back to callers. 7. `close()` flushes remaining requests and cancels timers. In `dry_run` mode, step 3 and provider polling are bypassed: `_process_batch()` still -creates `_ActiveBatch` for tracking, then resolves each request immediately with a -synthetic `httpx.Response` (`200`) marked with `x-batchling-dry-run: 1`. +creates `_ActiveBatch` for tracking, then resolves each request by raising +`DryRunEarlyExit`. Cache lookups remain enabled in dry-run mode for hit accounting, but cache writes are disabled. +`close()` also waits for in-flight background submission/poll tasks so teardown +reporting has stable totals. ## Extension notes diff --git a/docs/dry-run.md b/docs/dry-run.md index f166b8c..de0109b 100644 --- a/docs/dry-run.md +++ b/docs/dry-run.md @@ -4,9 +4,39 @@ This feature exists for users to be able to debug and better understand what WILL happen when they ultimately disable the flag, giving them the transparency required to be confident in the library. -In practice, the dry run feature deactivates all batch submissions, but everything is done virtually, which means we can count incoming requests, number of batch we would have created, etc.. - -To put it simply, it provides users with an exact breakdown of what their batched inference run would have been for real. +In practice, dry-run deactivates all provider submissions while keeping the +internal batching path active (queueing, windowing, and per-queue grouping). + +To put it simply, it provides users with an exact breakdown of what their +batched inference run would have been for real. + +## Behavior details + +- Requests are still intercepted and grouped by queue key + `(provider, endpoint, model)`. +- Provider submission/polling is skipped. +- Intercepted requests raise `DryRunEarlyExit` instead of returning synthetic + provider responses. +- The CLI catches `DryRunEarlyExit` and exits cleanly after printing the report. +- On context exit, batchling prints a static Rich summary with: + - total requests that would have been batched + - total requests that would have been cache hits + - per-queue expected requests and expected batch counts + +## SDK handling + +When running the SDK directly, catch `DryRunEarlyExit` if you want to continue +control flow after the first intercepted request: + +```python +from batchling import DryRunEarlyExit, batchify + +try: + async with batchify(dry_run=True): + ... +except DryRunEarlyExit: + pass +``` ## Activating dry run diff --git a/src/batchling/api.py b/src/batchling/api.py index 46a2f88..bbfaac0 100644 --- a/src/batchling/api.py +++ b/src/batchling/api.py @@ -34,7 +34,8 @@ def batchify( dry_run : bool, optional If ``True``, intercept and batch requests without sending provider batches.
Use it to debug or before sending big jobs.
- Batched requests resolve to synthetic responses. + Intercepted requests raise ``DryRunEarlyExit`` instead of returning + provider responses. cache : bool, optional If ``True``, enable persistent request cache lookups.
This parameter allows to skip the batch submission and go straight to the polling phase for requests that have already been sent. From 4741caf72dbe987f55553ce669d6a478f28e8070 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 22:47:58 -0800 Subject: [PATCH 4/4] feat: add dry-run rich report --- docs/dry-run.md | 51 +++-- examples/many.py | 368 ++++++++++++++++++++++++++++++++++ src/batchling/context.py | 2 + src/batchling/exceptions.py | 2 +- src/batchling/rich_display.py | 4 +- tests/test_context.py | 42 ++++ tests/test_rich_display.py | 2 +- 7 files changed, 441 insertions(+), 30 deletions(-) create mode 100644 examples/many.py diff --git a/docs/dry-run.md b/docs/dry-run.md index de0109b..492f68a 100644 --- a/docs/dry-run.md +++ b/docs/dry-run.md @@ -10,34 +10,33 @@ internal batching path active (queueing, windowing, and per-queue grouping). To put it simply, it provides users with an exact breakdown of what their batched inference run would have been for real. -## Behavior details - -- Requests are still intercepted and grouped by queue key - `(provider, endpoint, model)`. -- Provider submission/polling is skipped. -- Intercepted requests raise `DryRunEarlyExit` instead of returning synthetic - provider responses. -- The CLI catches `DryRunEarlyExit` and exits cleanly after printing the report. -- On context exit, batchling prints a static Rich summary with: - - total requests that would have been batched - - total requests that would have been cache hits - - per-queue expected requests and expected batch counts - -## SDK handling - -When running the SDK directly, catch `DryRunEarlyExit` if you want to continue -control flow after the first intercepted request: - -```python -from batchling import DryRunEarlyExit, batchify - -try: - async with batchify(dry_run=True): - ... -except DryRunEarlyExit: - pass +Sample output: + +```text +╭────────────────────────────────────────────── batchling dry run summary ───────────────────────────────────────────────╮ +│ Batchable Requests: 8 - Cache Hit Requests: 0 │ +│ ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ │ +│ ┃ provider ┃ endpoint ┃ model ┃ expected reques… ┃ expected batch… ┃ │ +│ ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ +│ │ anthropic │ /v1/messages │ claude-haiku-4-5 │ 1 │ 1 │ │ +│ │ doubleword │ /v1/responses │ openai/gpt-oss-20b │ 1 │ 1 │ │ +│ │ gemini │ /v1beta/models/gemini-2.5-flash-… │ gemini-2.5-flash-lite │ 1 │ 1 │ │ +│ │ groq │ /openai/v1/chat/completions │ llama-3.1-8b-instant │ 1 │ 1 │ │ +│ │ mistral │ /v1/chat/completions │ mistral-medium-2505 │ 1 │ 1 │ │ +│ │ openai │ /v1/responses │ gpt-4o-mini │ 1 │ 1 │ │ +│ │ together │ /v1/chat/completions │ google/gemma-3n-E4B-it │ 1 │ 1 │ │ +│ │ xai │ /v1/chat/completions │ grok-4-1-fast-non-reasoning │ 1 │ 1 │ │ +│ └─────────────┴───────────────────────────────────┴─────────────────────────────┴──────────────────┴─────────────────┘ │ +╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` +## Avoid partial counts + +Dry-run exits as soon as the first intercepted request returns, which can lead +to partial totals if requests are awaited one by one. To let batchling see the +full request set before exit, schedule requests together and await them with +`asyncio.gather`. + ## Activating dry run Dry run is activated by setting up a flag in the CLI or SDK: diff --git a/examples/many.py b/examples/many.py new file mode 100644 index 0000000..34d59f7 --- /dev/null +++ b/examples/many.py @@ -0,0 +1,368 @@ +import asyncio +import os +import typing as t +from dataclasses import dataclass + +from dotenv import load_dotenv + +from batchling import batchify + +load_dotenv() + + +ProviderRequestRunner = t.Callable[..., t.Coroutine[t.Any, t.Any, tuple[str, str]]] + + +@dataclass +class ProviderRequestSpec: + """ + One provider request definition. + + Parameters + ---------- + provider : str + Provider display name. + env_var : str + Environment variable holding the API key. + request_runner : ProviderRequestRunner + Coroutine function sending one request and returning ``(model, answer)``. + """ + + provider: str + env_var: str + request_runner: ProviderRequestRunner + + +async def run_openai_request(*, prompt: str) -> tuple[str, str]: + """ + Send one OpenAI responses request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from openai import AsyncOpenAI + + client = AsyncOpenAI(api_key=os.getenv(key="OPENAI_API_KEY")) + response = await client.responses.create( + input=prompt, + model="gpt-4o-mini", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_anthropic_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Anthropic messages request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from anthropic import AsyncAnthropic + + client = AsyncAnthropic(api_key=os.getenv(key="ANTHROPIC_API_KEY")) + response = await client.messages.create( + model="claude-haiku-4-5", + max_tokens=512, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, response.content[0].text + + +async def run_groq_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Groq chat completion request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from groq import AsyncGroq + + client = AsyncGroq(api_key=os.getenv(key="GROQ_API_KEY")) + response = await client.chat.completions.create( + model="llama-3.1-8b-instant", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, str(object=response.choices[0].message.content) + + +async def run_mistral_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Mistral chat completion request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from mistralai import Mistral + + client = Mistral(api_key=os.getenv(key="MISTRAL_API_KEY")) + response = await client.chat.complete_async( + model="mistral-medium-2505", + stream=False, + response_format={"type": "text"}, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, str(object=response.choices[0].message.content) + + +async def run_together_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Together chat completion request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from together import AsyncTogether + + client = AsyncTogether(api_key=os.getenv(key="TOGETHER_API_KEY")) + response = await client.chat.completions.create( + model="google/gemma-3n-E4B-it", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, str(object=response.choices[0].message.content) + + +async def run_doubleword_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Doubleword responses request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from openai import AsyncOpenAI + + client = AsyncOpenAI( + api_key=os.getenv(key="DOUBLEWORD_API_KEY"), + base_url="https://api.doubleword.ai/v1", + ) + response = await client.responses.create( + input=prompt, + model="openai/gpt-oss-20b", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_xai_request(*, prompt: str) -> tuple[str, str]: + """ + Send one XAI chat completion request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from openai import AsyncOpenAI + + client = AsyncOpenAI( + api_key=os.getenv(key="XAI_API_KEY"), + base_url="https://api.x.ai/v1", + ) + response = await client.chat.completions.create( + model="grok-4-1-fast-non-reasoning", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + model_name = str(object=response.chat_get_completion["model"]) + answer_text = str(object=response.chat_get_completion["choices"][0]["message"]["content"]) + return model_name, answer_text + + +async def run_gemini_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Gemini generate_content request. + + Parameters + ---------- + prompt : str + User question. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + from google import genai + + client = genai.Client(api_key=os.getenv(key="GEMINI_API_KEY")).aio + response = await client.models.generate_content( + model="gemini-2.5-flash-lite", + contents=prompt, + ) + return response.model_version, str(object=response.text) + + +def build_provider_specs() -> list[ProviderRequestSpec]: + """ + Return all provider request specs supported by this example. + + Returns + ------- + list[ProviderRequestSpec] + Provider definitions for one-request execution. + """ + return [ + ProviderRequestSpec( + provider="openai", + env_var="OPENAI_API_KEY", + request_runner=run_openai_request, + ), + ProviderRequestSpec( + provider="anthropic", + env_var="ANTHROPIC_API_KEY", + request_runner=run_anthropic_request, + ), + ProviderRequestSpec( + provider="groq", + env_var="GROQ_API_KEY", + request_runner=run_groq_request, + ), + ProviderRequestSpec( + provider="mistral", + env_var="MISTRAL_API_KEY", + request_runner=run_mistral_request, + ), + ProviderRequestSpec( + provider="together", + env_var="TOGETHER_API_KEY", + request_runner=run_together_request, + ), + ProviderRequestSpec( + provider="doubleword", + env_var="DOUBLEWORD_API_KEY", + request_runner=run_doubleword_request, + ), + ProviderRequestSpec( + provider="xai", + env_var="XAI_API_KEY", + request_runner=run_xai_request, + ), + ProviderRequestSpec( + provider="gemini", + env_var="GEMINI_API_KEY", + request_runner=run_gemini_request, + ), + ] + + +def build_enabled_provider_specs() -> list[ProviderRequestSpec]: + """ + Return provider specs that have API keys configured. + + Returns + ------- + list[ProviderRequestSpec] + Enabled provider definitions. + """ + enabled_specs: list[ProviderRequestSpec] = [] + for spec in build_provider_specs(): + if os.getenv(key=spec.env_var): + enabled_specs.append(spec) + return enabled_specs + + +async def main() -> None: + """ + Ask one question to many providers and collect all answers with gather. + + Notes + ----- + This example intentionally uses ``asyncio.gather(..., return_exceptions=True)`` + so all provider outcomes are collected in one pass. + """ + question = "Give one short sentence explaining what asynchronous batching is." + enabled_specs = build_enabled_provider_specs() + if not enabled_specs: + print("No providers configured. Set at least one provider API key in your environment.") + return + + tasks = [spec.request_runner(prompt=question) for spec in enabled_specs] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for spec, result in zip(enabled_specs, results, strict=True): + if isinstance(result, BaseException): + print(f"{spec.provider} error:\n{type(result).__name__}: {result}\n") + continue + model_name, answer_text = result + print(f"{spec.provider} ({model_name}) answer:\n{answer_text}\n") + + +async def run_with_batchify() -> None: + """Run `main` inside `batchify` for direct script execution.""" + async with batchify(): + await main() + + +if __name__ == "__main__": + asyncio.run(run_with_batchify()) diff --git a/src/batchling/context.py b/src/batchling/context.py index b0dc621..3135300 100644 --- a/src/batchling/context.py +++ b/src/batchling/context.py @@ -205,6 +205,8 @@ def _start_live_display(self) -> None: """ if self._self_live_display is not None or self._self_polling_progress_logger is not None: return + if self._self_batcher._dry_run: + return if not self._self_live_display_enabled: return if not should_enable_live_display(enabled=self._self_live_display_enabled): diff --git a/src/batchling/exceptions.py b/src/batchling/exceptions.py index 3fcef03..ed2e235 100644 --- a/src/batchling/exceptions.py +++ b/src/batchling/exceptions.py @@ -3,7 +3,7 @@ from __future__ import annotations -class DryRunEarlyExit(RuntimeError): +class DryRunEarlyExit(BaseException): """ Raised when dry-run mode exits before returning a provider response. diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 69a70cc..811e417 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -358,14 +358,14 @@ def _build_totals_line(self) -> Text: Styled totals text. """ line = Text() - line.append(text="Would Batch", style="grey70") + line.append(text="Batchable Requests", style="grey70") line.append(text=": ", style="grey70") line.append( text=str(object=self._summary_state.would_batch_requests_total), style="bold cyan", ) line.append(text=" - ", style="grey70") - line.append(text="Would Cache", style="grey70") + line.append(text="Cache Hit Requests", style="grey70") line.append(text=": ", style="grey70") line.append( text=str(object=self._summary_state.would_cache_requests_total), diff --git a/tests/test_context.py b/tests/test_context.py index 15536ff..ba780a5 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -220,6 +220,48 @@ def test_batching_context_uses_polling_progress_fallback_when_auto_disabled( assert context._self_polling_progress_logger is None +def test_batching_context_skips_live_display_in_dry_run( + reset_context: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test dry-run mode does not start live display listeners.""" + + class DummyDisplay: + """Simple display stub.""" + + def __init__(self) -> None: + self.started = False + + def start(self) -> None: + self.started = True + + def stop(self) -> None: + return None + + def on_event(self, event: dict[str, t.Any]) -> None: + del event + + dummy_display = DummyDisplay() + monkeypatch.setattr("batchling.context.BatcherRichDisplay", lambda: dummy_display) + monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True) + + batcher = Batcher( + batch_size=1, + batch_window_seconds=1.0, + dry_run=True, + cache=False, + ) + context = BatchingContext( + batcher=batcher, + live_display=True, + ) + + context._start_live_display() + assert dummy_display.started is False + assert context._self_live_display is None + assert context._self_polling_progress_logger is None + + @pytest.mark.asyncio async def test_batching_context_prints_dry_run_report_when_live_display_disabled( reset_context: None, diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index 2bc29be..4a87b8d 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -441,7 +441,7 @@ def test_dry_run_summary_display_aggregates_totals_and_queues() -> None: ) totals_line = display._build_totals_line() - assert totals_line.plain == "Would Batch: 3 - Would Cache: 1" + assert totals_line.plain == "Batchable Requests: 3 - Cache Hit Requests: 1" queue_table = display._build_queue_summary_table() assert queue_table.columns[0]._cells == ["groq", "openai"]