From 67c44d97157b6f1e1870a42fa4da1c36d94e5746 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 17:17:07 -0800 Subject: [PATCH 01/20] batcher: emit queue/poll lifecycle events for live UI --- src/batchling/core.py | 263 +++++++++++++++++++++++++++++++++++++++++- tests/test_core.py | 141 +++++++++++++++++++++- 2 files changed, 402 insertions(+), 2 deletions(-) diff --git a/src/batchling/core.py b/src/batchling/core.py index 65dc7aa..f27c451 100644 --- a/src/batchling/core.py +++ b/src/batchling/core.py @@ -31,6 +31,59 @@ CACHE_RETENTION_SECONDS = 30 * 24 * 60 * 60 +class BatcherEvent(t.TypedDict, total=False): + """ + Lifecycle event emitted by ``Batcher`` for optional observers. + + event_type : str + Event identifier. + timestamp : float + Event timestamp in UNIX seconds. + provider : str + Provider name. + endpoint : str + Request endpoint. + model : str + Request model. + queue_key : QueueKey + Queue identifier. + batch_id : str + Provider batch identifier. + status : str + Provider batch status. + request_count : int + Number of requests in the event scope. + pending_count : int + Current queue pending size. + custom_id : str + Custom request identifier. + source : str + Event source subsystem. + error : str + Error text. + missing_count : int + Number of missing results. + """ + + event_type: str + timestamp: float + provider: str + endpoint: str + model: str + queue_key: QueueKey + batch_id: str + status: str + request_count: int + pending_count: int + custom_id: str + source: str + error: str + missing_count: int + + +BatcherEventListener = t.Callable[[BatcherEvent], None] + + @dataclass class _PendingRequest: # FIXME: _PendingRequest can use a generic type to match any request from: @@ -135,6 +188,7 @@ def __init__( self._resumed_poll_tasks: set[asyncio.Task[None]] = set() self._resumed_batches: dict[ResumedBatchKey, _ResumedBatch] = {} self._resumed_lock = asyncio.Lock() + self._event_listeners: set[BatcherEventListener] = set() self._cache_store: RequestCacheStore | None = None if self._cache_enabled: @@ -163,6 +217,73 @@ def __init__( ), ) + def _add_event_listener( + self, + *, + listener: BatcherEventListener, + ) -> None: + """ + Register a listener receiving batch lifecycle events. + + Parameters + ---------- + listener : BatcherEventListener + Observer callback. + """ + self._event_listeners.add(listener) + + def _remove_event_listener( + self, + *, + listener: BatcherEventListener, + ) -> None: + """ + Unregister a previously registered lifecycle listener. + + Parameters + ---------- + listener : BatcherEventListener + Observer callback. + """ + self._event_listeners.discard(listener) + + def _emit_event( + self, + *, + event_type: str, + **payload: t.Any, + ) -> None: + """ + Emit a lifecycle event to all registered listeners. + + Parameters + ---------- + event_type : str + Event identifier. + **payload : typing.Any + Event payload. + """ + if not self._event_listeners: + return + event = t.cast( + typ=BatcherEvent, + val={ + "event_type": event_type, + "timestamp": time.time(), + **payload, + }, + ) + for listener in list(self._event_listeners): + try: + listener(event) + except Exception as error: + log_debug( + logger=log, + event="Batcher event listener failed", + listener=repr(listener), + error=str(object=error), + ) + @staticmethod def _format_queue_key(*, queue_key: QueueKey) -> str: """ @@ -353,6 +474,16 @@ async def _try_submit_from_cache( batch_id=cache_entry.batch_id, custom_id=cache_entry.custom_id, ) + cache_source = "cache_dry_run" if self._dry_run else "resumed_poll" + self._emit_event( + event_type="cache_hit_routed", + provider=provider_name, + endpoint=endpoint, + model=model_name, + batch_id=cache_entry.batch_id, + custom_id=cache_entry.custom_id, + source=cache_source, + ) if self._dry_run: dry_run_request = _PendingRequest( custom_id=cache_entry.custom_id, @@ -401,6 +532,15 @@ async def _enqueue_pending_request( queue = self._pending_by_provider.setdefault(queue_key, []) queue.append(request) pending_count = len(queue) + self._emit_event( + event_type="request_queued", + provider=queue_key[0], + endpoint=queue_key[1], + model=queue_key[2], + queue_key=queue_key, + pending_count=pending_count, + custom_id=request.custom_id, + ) if pending_count == 1: self._window_tasks[queue_key] = asyncio.create_task( @@ -610,7 +750,6 @@ async def _attach_cached_request( future=future, ) ) - if should_start_poller: task = asyncio.create_task( coro=self._poll_cached_batch(resume_key=resume_key), @@ -686,6 +825,15 @@ async def _window_timer(self, *, queue_key: QueueKey) -> None: queue_key=queue_name, error=str(object=e), ) + self._emit_event( + event_type="window_timer_error", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + error=str(object=e), + source="window_timer", + ) await self._fail_pending_provider_requests( queue_key=queue_key, error=e, @@ -722,6 +870,14 @@ async def _submit_requests( queue_key=queue_name, request_count=len(requests), ) + self._emit_event( + event_type="batch_submitting", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + ) task = asyncio.create_task( coro=self._process_batch(queue_key=queue_key, requests=requests), name=f"batch_submit_{queue_name}_{uuid.uuid4()}", @@ -876,6 +1032,16 @@ async def _process_batch( try: if self._dry_run: dry_run_batch_id = f"dryrun-{uuid.uuid4()}" + self._emit_event( + event_type="batch_processing", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + batch_id=dry_run_batch_id, + source="dry_run", + ) active_batch = _ActiveBatch( batch_id=dry_run_batch_id, output_file_id="", @@ -900,6 +1066,17 @@ async def _process_batch( batch_id=dry_run_batch_id, request_count=len(requests), ) + self._emit_event( + event_type="batch_terminal", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + batch_id=dry_run_batch_id, + status="simulated", + source="dry_run", + ) return log_info( @@ -911,11 +1088,30 @@ async def _process_batch( queue_key=queue_name, request_count=len(requests), ) + self._emit_event( + event_type="batch_processing", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + source="submit", + ) batch_submission = await provider.process_batch( requests=requests, client_factory=self._client_factory, queue_key=queue_key, ) + self._emit_event( + event_type="batch_processing", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + batch_id=batch_submission.batch_id, + source="poll_start", + ) self._write_cache_entries( queue_key=queue_key, requests=requests, @@ -948,6 +1144,15 @@ async def _process_batch( queue_key=queue_name, error=str(object=e), ) + self._emit_event( + event_type="batch_failed", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + error=str(object=e), + ) for req in requests: if not req.future.done(): req.future.set_exception(e) @@ -1056,6 +1261,13 @@ async def _poll_batch( api_headers=api_headers, batch_id=active_batch.batch_id, ) + self._emit_event( + event_type="batch_polled", + provider=provider.name, + batch_id=active_batch.batch_id, + status=poll_snapshot.status, + source="active_poll", + ) active_batch.output_file_id = poll_snapshot.output_file_id active_batch.error_file_id = poll_snapshot.error_file_id @@ -1067,6 +1279,13 @@ async def _poll_batch( batch_id=active_batch.batch_id, status=poll_snapshot.status, ) + self._emit_event( + event_type="batch_terminal", + provider=provider.name, + batch_id=active_batch.batch_id, + status=poll_snapshot.status, + source="active_poll", + ) await self._resolve_batch_results( base_url=base_url, api_headers=api_headers, @@ -1105,7 +1324,21 @@ async def _poll_cached_batch( api_headers=resumed_batch.api_headers, batch_id=batch_id, ) + self._emit_event( + event_type="batch_polled", + provider=provider.name, + batch_id=batch_id, + status=poll_snapshot.status, + source="resumed_poll", + ) if poll_snapshot.status in provider.batch_terminal_states: + self._emit_event( + event_type="batch_terminal", + provider=provider.name, + batch_id=batch_id, + status=poll_snapshot.status, + source="resumed_poll", + ) await self._resolve_cached_batch_results( resume_key=resume_key, output_file_id=poll_snapshot.output_file_id, @@ -1117,6 +1350,13 @@ async def _poll_cached_batch( except asyncio.CancelledError: raise except Exception as error: + self._emit_event( + event_type="batch_failed", + provider=provider.name, + batch_id=batch_id, + error=str(object=error), + source="resumed_poll", + ) await self._fail_resumed_batch_requests( resume_key=resume_key, error=error, @@ -1199,6 +1439,12 @@ async def _resolve_cached_batch_results( pending.future.set_result(resolved_response) if missing_hashes: + self._emit_event( + event_type="missing_results", + batch_id=batch_id, + missing_count=len(missing_hashes), + source="resumed_results", + ) _ = self._invalidate_cache_hashes(request_hashes=missing_hashes) async def _fail_resumed_batch_requests( @@ -1405,6 +1651,12 @@ def _fail_missing_results( batch_id=active_batch.batch_id, missing_count=len(missing), ) + self._emit_event( + event_type="missing_results", + batch_id=active_batch.batch_id, + missing_count=len(missing), + source="results", + ) error = RuntimeError(f"Missing results for {len(missing)} request(s)") for custom_id in missing: pending = active_batch.requests.get(custom_id) @@ -1447,6 +1699,15 @@ async def close(self) -> None: queue_key=queue_name, request_count=len(requests), ) + self._emit_event( + event_type="final_flush_submitting", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + source="close", + ) await self._submit_requests( queue_key=queue_key, requests=requests, diff --git a/tests/test_core.py b/tests/test_core.py index a15d5be..960ba6e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -10,7 +10,7 @@ import pytest from batchling.cache import CacheEntry -from batchling.core import Batcher, _ActiveBatch, _PendingRequest +from batchling.core import Batcher, BatcherEvent, _ActiveBatch, _PendingRequest from batchling.providers.anthropic import AnthropicProvider from batchling.providers.base import PollSnapshot, ProviderRequestSpec, ResumeContext from batchling.providers.gemini import GeminiProvider @@ -1948,3 +1948,142 @@ async def test_dry_run_cache_hit_is_read_only(mock_openai_api_transport: httpx.M assert dry_run_entry is not None assert dry_run_entry.created_at == original_created_at await dry_run_batcher.close() + + +@pytest.mark.asyncio +async def test_submit_emits_lifecycle_events(batcher: Batcher, provider: OpenAIProvider) -> None: + """Test submit emits queue, submit, poll, and terminal lifecycle events.""" + events: list[BatcherEvent] = [] + batcher._add_event_listener(listener=events.append) + + _ = await batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + + event_types = {str(event["event_type"]) for event in events} + assert "request_queued" in event_types + assert "batch_submitting" in event_types + assert "batch_processing" in event_types + assert "batch_polled" in event_types + assert "batch_terminal" in event_types + + +@pytest.mark.asyncio +async def test_dry_run_emits_terminal_without_poll(provider: OpenAIProvider) -> None: + """Test dry-run emits terminal lifecycle without provider polling events.""" + dry_run_batcher = Batcher( + batch_size=1, + batch_window_seconds=1.0, + dry_run=True, + cache=False, + ) + events: list[BatcherEvent] = [] + dry_run_batcher._add_event_listener(listener=events.append) + + _ = await dry_run_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + + event_types = {str(event["event_type"]) for event in events} + assert "request_queued" in event_types + assert "batch_submitting" in event_types + assert "batch_terminal" in event_types + assert "batch_polled" not in event_types + + await dry_run_batcher.close() + + +@pytest.mark.asyncio +async def test_cache_hit_emits_cache_hit_routed( + mock_openai_api_transport: httpx.MockTransport, +) -> None: + """Test cache-hit routing emits dedicated lifecycle event.""" + writer_batcher = Batcher( + batch_size=2, + batch_window_seconds=0.1, + cache=True, + ) + writer_batcher._client_factory = lambda: httpx.AsyncClient(transport=mock_openai_api_transport) + writer_batcher._poll_interval_seconds = 0.01 + provider = OpenAIProvider() + + _ = await writer_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + await writer_batcher.close() + + dry_run_batcher = Batcher( + batch_size=2, + batch_window_seconds=0.1, + cache=True, + dry_run=True, + ) + events: list[BatcherEvent] = [] + dry_run_batcher._add_event_listener(listener=events.append) + + _ = await dry_run_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=OpenAIProvider(), + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + + assert any( + event["event_type"] == "cache_hit_routed" and event.get("source") == "cache_dry_run" + for event in events + ) + await dry_run_batcher.close() + + +def test_fail_missing_results_emits_lifecycle_event() -> None: + """Test missing output mapping emits missing-results lifecycle event.""" + batcher = Batcher( + batch_size=2, + batch_window_seconds=1.0, + cache=False, + ) + loop = asyncio.new_event_loop() + future: asyncio.Future[t.Any] = loop.create_future() + request = _PendingRequest( + custom_id="custom-1", + queue_key=("openai", "/v1/chat/completions", "model-a"), + params={}, + provider=OpenAIProvider(), + future=future, + request_hash="hash-1", + ) + active_batch = _ActiveBatch( + batch_id="batch-1", + output_file_id="file-1", + error_file_id="", + requests={"custom-1": request}, + ) + + events: list[BatcherEvent] = [] + batcher._add_event_listener(listener=events.append) + + batcher._fail_missing_results(active_batch=active_batch, seen=set()) + + assert any(event["event_type"] == "missing_results" for event in events) + loop.close() From 214e4f675991a611791f156d63de554ec3e270b0 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 17:17:31 -0800 Subject: [PATCH 02/20] ui: add rich live batching display renderer --- pyproject.toml | 1 + requirements.txt | 1 + src/batchling/rich_display.py | 278 ++++++++++++++++++++++++++++++++++ tests/test_rich_display.py | 100 ++++++++++++ uv.lock | 2 + 5 files changed, 382 insertions(+) create mode 100644 src/batchling/rich_display.py create mode 100644 tests/test_rich_display.py diff --git a/pyproject.toml b/pyproject.toml index ba8394a..b7c4e15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ classifiers = [ dependencies = [ "aiohttp>=3.13.3", "httpx>=0.28.1", + "rich>=14.2.0", "typer>=0.20.0", ] diff --git a/requirements.txt b/requirements.txt index d0e7873..99d8d5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -645,6 +645,7 @@ requests-toolbelt==1.0.0 respx==0.22.0 rich==14.2.0 # via + # batchling # cyclopts # fastmcp # instructor diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py new file mode 100644 index 0000000..0170329 --- /dev/null +++ b/src/batchling/rich_display.py @@ -0,0 +1,278 @@ +"""Rich live display for batch lifecycle visibility.""" + +from __future__ import annotations + +import os +import sys +import time +import typing as t +from collections import deque +from dataclasses import dataclass + +from rich.console import Console, Group +from rich.live import Live +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from batchling.core import BatcherEvent, QueueKey + +LiveDisplayMode = t.Literal["auto", "on", "off"] + + +@dataclass +class _QueueActivity: + """In-memory queue activity snapshot for rendering.""" + + pending_count: int = 0 + active_batches: int = 0 + submitted_batches: int = 0 + last_status: str = "-" + last_batch_id: str = "-" + + +class BatcherRichDisplay: + """ + Render queue and lifecycle activity through a Rich ``Live`` panel. + + Parameters + ---------- + max_events : int, optional + Maximum number of lifecycle events kept in the rolling feed. + refresh_per_second : float, optional + Refresh rate for Rich live updates. + console : Console | None, optional + Rich console to render to. Defaults to ``Console(stderr=True)``. + """ + + def __init__( + self, + *, + max_events: int = 20, + refresh_per_second: float = 8.0, + console: Console | None = None, + ) -> None: + self._console = console or Console(stderr=True) + self._max_events = max_events + self._refresh_per_second = refresh_per_second + self._events: deque[BatcherEvent] = deque(maxlen=max_events) + self._queues: dict[QueueKey, _QueueActivity] = {} + self._live: Live | None = None + + def start(self) -> None: + """Start the live panel if not already running.""" + if self._live is not None: + return + self._live = Live( + renderable=self._render(), + console=self._console, + refresh_per_second=self._refresh_per_second, + transient=False, + ) + self._live.start(refresh=True) + + def stop(self) -> None: + """Stop the live panel if running.""" + if self._live is None: + return + self._live.stop() + self._live = None + + def on_event(self, event: BatcherEvent) -> None: + """ + Consume one batch lifecycle event and refresh the panel. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + event_type = str(object=event.get("event_type", "unknown")) + queue_key = self._resolve_queue_key(event=event) + + if queue_key is not None: + queue_activity = self._queues.setdefault(queue_key, _QueueActivity()) + + if event_type == "request_queued": + queue_activity.pending_count = int( + event.get("pending_count", queue_activity.pending_count) + ) + elif event_type in {"batch_submitting", "final_flush_submitting"}: + queue_activity.pending_count = 0 + queue_activity.submitted_batches += 1 + elif event_type == "batch_processing": + source = str(object=event.get("source", "")) + if source in {"dry_run", "poll_start"}: + queue_activity.active_batches += 1 + elif event_type in {"batch_terminal", "batch_failed"}: + queue_activity.active_batches = max(0, queue_activity.active_batches - 1) + + event_status = event.get("status") + if event_status is not None: + queue_activity.last_status = str(object=event_status) + event_batch_id = event.get("batch_id") + if event_batch_id is not None: + queue_activity.last_batch_id = str(object=event_batch_id) + + self._events.append(event) + if self._live is not None: + self._live.update(renderable=self._render(), refresh=True) + + def _resolve_queue_key(self, *, event: BatcherEvent) -> QueueKey | None: + """ + Resolve queue key from event payload. + + Parameters + ---------- + event : BatcherEvent + Event payload. + + Returns + ------- + QueueKey | None + Queue key when available. + """ + raw_queue_key = event.get("queue_key") + if isinstance(raw_queue_key, tuple) and len(raw_queue_key) == 3: + provider, endpoint, model = raw_queue_key + return str(provider), str(endpoint), str(model) + + provider = event.get("provider") + endpoint = event.get("endpoint") + model = event.get("model") + if provider is None or endpoint is None or model is None: + return None + + return str(provider), str(endpoint), str(model) + + def _render(self) -> Panel: + """Build the current Rich panel renderable.""" + queue_table = self._build_queue_table() + event_table = self._build_event_table() + return Panel( + renderable=Group(queue_table, event_table), + title="batchling live activity", + border_style="cyan", + ) + + def _build_queue_table(self) -> Table: + """Build queue activity table.""" + table = Table(title="Queues", expand=True) + table.add_column(header="Provider", style="bold") + table.add_column(header="Endpoint") + table.add_column(header="Model") + table.add_column(header="Pending", justify="right") + table.add_column(header="Active", justify="right") + table.add_column(header="Submitted", justify="right") + table.add_column(header="Last Status") + table.add_column(header="Batch ID") + + if not self._queues: + table.add_row("-", "-", "-", "0", "0", "0", "-", "-") + return table + + for queue_key in sorted(self._queues.keys()): + provider, endpoint, model = queue_key + queue_activity = self._queues[queue_key] + table.add_row( + provider, + endpoint, + model, + str(queue_activity.pending_count), + str(queue_activity.active_batches), + str(queue_activity.submitted_batches), + queue_activity.last_status, + queue_activity.last_batch_id, + ) + return table + + def _build_event_table(self) -> Table: + """Build rolling lifecycle event feed.""" + table = Table(title=f"Recent Events (max {self._max_events})", expand=True) + table.add_column(header="Time", width=8) + table.add_column(header="Event") + table.add_column(header="Details") + + if not self._events: + table.add_row("-", "-", "No lifecycle events yet") + return table + + for event in reversed(self._events): + event_timestamp = float(event.get("timestamp", time.time())) + event_time = time.strftime("%H:%M:%S", time.localtime(event_timestamp)) + event_type = str(object=event.get("event_type", "unknown")) + details = self._format_event_details(event=event) + table.add_row(event_time, event_type, details) + + return table + + def _format_event_details(self, *, event: BatcherEvent) -> Text: + """ + Build compact details text for one event row. + + Parameters + ---------- + event : BatcherEvent + Event payload. + + Returns + ------- + Text + Formatted details. + """ + details: list[str] = [] + provider = event.get("provider") + endpoint = event.get("endpoint") + model = event.get("model") + batch_id = event.get("batch_id") + status = event.get("status") + pending_count = event.get("pending_count") + request_count = event.get("request_count") + error = event.get("error") + + if provider is not None: + details.append(f"provider={provider}") + if endpoint is not None: + details.append(f"endpoint={endpoint}") + if model is not None: + details.append(f"model={model}") + if pending_count is not None: + details.append(f"pending={pending_count}") + if request_count is not None: + details.append(f"requests={request_count}") + if batch_id is not None: + details.append(f"batch_id={batch_id}") + if status is not None: + details.append(f"status={status}") + if error is not None: + details.append(f"error={error}") + + return Text(" ".join(details) if details else "-") + + +def should_enable_live_display(*, mode: LiveDisplayMode) -> bool: + """ + Resolve if the Rich live panel should be enabled. + + Parameters + ---------- + mode : LiveDisplayMode + Desired display mode. + + Returns + ------- + bool + ``True`` when the live panel should run. + """ + if mode == "on": + return True + if mode == "off": + return False + + stderr_stream = sys.stderr + is_tty = bool(getattr(stderr_stream, "isatty", lambda: False)()) + terminal_name = str(object=os.environ.get("TERM", "")).lower() + is_dumb_terminal = terminal_name in {"", "dumb"} + is_ci = bool(os.environ.get("CI")) + + return is_tty and not is_dumb_terminal and not is_ci diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py new file mode 100644 index 0000000..04a1557 --- /dev/null +++ b/tests/test_rich_display.py @@ -0,0 +1,100 @@ +"""Tests for Rich live display helpers.""" + +import io + +from rich.console import Console + +import batchling.rich_display as rich_display + + +def test_should_enable_live_display_auto_in_interactive_terminal( + monkeypatch, +) -> None: + """Test auto mode enables display in interactive terminals.""" + + class DummyStderr: + def isatty(self) -> bool: + return True + + monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr()) + monkeypatch.setenv("TERM", "xterm-256color") + monkeypatch.delenv("CI", raising=False) + + assert rich_display.should_enable_live_display(mode="auto") is True + + +def test_should_enable_live_display_auto_disabled_in_ci(monkeypatch) -> None: + """Test auto mode disables display in CI environments.""" + + class DummyStderr: + def isatty(self) -> bool: + return True + + monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr()) + monkeypatch.setenv("TERM", "xterm-256color") + monkeypatch.setenv("CI", "true") + + assert rich_display.should_enable_live_display(mode="auto") is False + + +def test_batcher_rich_display_consumes_events() -> None: + """Test Rich display event handling and rendering lifecycle.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + display.start() + display.on_event( + { + "event_type": "request_queued", + "timestamp": 1.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "pending_count": 1, + } + ) + display.on_event( + { + "event_type": "batch_submitting", + "timestamp": 2.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "request_count": 1, + } + ) + display.on_event( + { + "event_type": "batch_processing", + "timestamp": 3.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-1", + "source": "poll_start", + } + ) + display.on_event( + { + "event_type": "batch_terminal", + "timestamp": 4.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-1", + "status": "completed", + } + ) + display.stop() + + queue_activity = display._queues[("openai", "/v1/chat/completions", "model-a")] + assert queue_activity.pending_count == 0 + assert queue_activity.active_batches == 0 + assert queue_activity.submitted_batches == 1 + assert queue_activity.last_status == "completed" + assert queue_activity.last_batch_id == "batch-1" diff --git a/uv.lock b/uv.lock index 9e3a08a..cf0005b 100644 --- a/uv.lock +++ b/uv.lock @@ -267,6 +267,7 @@ source = { editable = "." } dependencies = [ { name = "aiohttp" }, { name = "httpx" }, + { name = "rich" }, { name = "typer" }, ] @@ -304,6 +305,7 @@ dev = [ requires-dist = [ { name = "aiohttp", specifier = ">=3.13.3" }, { name = "httpx", specifier = ">=0.28.1" }, + { name = "rich", specifier = ">=14.2.0" }, { name = "typer", specifier = ">=0.20.0" }, ] From 85c557614748f28676a21b9a924a0a68e1506ba3 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 17:17:41 -0800 Subject: [PATCH 03/20] batchify/context: add live_display tri-state API and context wiring --- src/batchling/api.py | 6 +++ src/batchling/context.py | 86 ++++++++++++++++++++++++++++++++++++++-- tests/test_api.py | 18 +++++++++ tests/test_context.py | 77 +++++++++++++++++++++++++++++++++++ 4 files changed, 184 insertions(+), 3 deletions(-) diff --git a/src/batchling/api.py b/src/batchling/api.py index a00afb9..6d8b1bc 100644 --- a/src/batchling/api.py +++ b/src/batchling/api.py @@ -8,6 +8,7 @@ from batchling.core import Batcher from batchling.hooks import install_hooks from batchling.logging import setup_logging +from batchling.rich_display import LiveDisplayMode def batchify( @@ -16,6 +17,7 @@ def batchify( batch_poll_interval_seconds: float = 10.0, dry_run: bool = False, cache: bool = True, + live_display: LiveDisplayMode = "auto", ) -> BatchingContext: """ Context manager used to activate batching for a scoped context.
@@ -37,6 +39,9 @@ def batchify( cache : bool, optional If ``True``, enable persistent request cache lookups.
This parameter allows to skip the batch submission and go straight to the polling phase for requests that have already been sent. + live_display : {"auto", "on", "off"}, optional + Toggle the Rich live panel shown while the context is active.
+ ``"auto"`` enables the panel only in interactive terminals. Returns ------- @@ -65,4 +70,5 @@ def batchify( # 3. Return BatchingContext with no yielded target. return BatchingContext( batcher=batcher, + live_display=live_display, ) diff --git a/src/batchling/context.py b/src/batchling/context.py index 0916481..8e9f844 100644 --- a/src/batchling/context.py +++ b/src/batchling/context.py @@ -8,6 +8,11 @@ from batchling.core import Batcher from batchling.hooks import active_batcher +from batchling.rich_display import ( + BatcherRichDisplay, + LiveDisplayMode, + should_enable_live_display, +) class BatchingContext: @@ -18,9 +23,16 @@ class BatchingContext: ---------- batcher : Batcher Batcher instance used for the scope of the context manager. + live_display : LiveDisplayMode, optional + Live display mode used when entering the context. """ - def __init__(self, *, batcher: "Batcher") -> None: + def __init__( + self, + *, + batcher: "Batcher", + live_display: LiveDisplayMode = "auto", + ) -> None: """ Initialize the context manager. @@ -28,10 +40,60 @@ def __init__(self, *, batcher: "Batcher") -> None: ---------- batcher : Batcher Batcher instance used for the scope of the context manager. + live_display : LiveDisplayMode, optional + Live display mode used when entering the context. """ self._self_batcher = batcher + self._self_live_display_mode = live_display + self._self_live_display: BatcherRichDisplay | None = None self._self_context_token: t.Any | None = None + def _start_live_display(self) -> None: + """ + Start the Rich live display when enabled. + + Notes + ----- + Display errors are downgraded to warnings to avoid breaking batching. + """ + if self._self_live_display is not None: + return + if not should_enable_live_display(mode=self._self_live_display_mode): + return + try: + display = BatcherRichDisplay() + self._self_batcher._add_event_listener(listener=display.on_event) + display.start() + self._self_live_display = display + except Exception as error: + warnings.warn( + message=f"Failed to start batchling live display: {error}", + category=UserWarning, + stacklevel=2, + ) + + def _stop_live_display(self) -> None: + """ + Stop and unregister the Rich live display. + + Notes + ----- + Display shutdown errors are downgraded to warnings. + """ + if self._self_live_display is None: + return + display = self._self_live_display + self._self_live_display = None + try: + self._self_batcher._remove_event_listener(listener=display.on_event) + display.stop() + except Exception as error: + warnings.warn( + message=f"Failed to stop batchling live display: {error}", + category=UserWarning, + stacklevel=2, + ) + def __enter__(self) -> None: """ Enter the synchronous context manager and activate the batcher. @@ -42,6 +104,7 @@ def __enter__(self) -> None: ``None`` for scoped activation. """ self._self_context_token = active_batcher.set(self._self_batcher) + self._start_live_display() return None def __exit__( @@ -67,7 +130,8 @@ def __exit__( self._self_context_token = None try: loop = asyncio.get_running_loop() - loop.create_task(coro=self._self_batcher.close()) + close_task = loop.create_task(coro=self._self_batcher.close()) + close_task.add_done_callback(self._on_sync_close_done) except RuntimeError: warnings.warn( message=( @@ -78,6 +142,18 @@ def __exit__( category=UserWarning, stacklevel=2, ) + self._stop_live_display() + + def _on_sync_close_done(self, _: asyncio.Task[None]) -> None: + """ + Callback run when sync-context close task completes. + + Parameters + ---------- + _ : asyncio.Task[None] + Completed close task. + """ + self._stop_live_display() async def __aenter__(self) -> None: """ @@ -89,6 +165,7 @@ async def __aenter__(self) -> None: ``None`` for scoped activation. """ self._self_context_token = active_batcher.set(self._self_batcher) + self._start_live_display() return None async def __aexit__( @@ -112,4 +189,7 @@ async def __aexit__( if self._self_context_token is not None: active_batcher.reset(self._self_context_token) self._self_context_token = None - await self._self_batcher.close() + try: + await self._self_batcher.close() + finally: + self._stop_live_display() diff --git a/tests/test_api.py b/tests/test_api.py index 03da464..b87d54f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -60,6 +60,24 @@ async def test_batchify_configures_cache_flag(reset_hooks, reset_context): assert wrapped._self_batcher._cache_enabled is False +@pytest.mark.asyncio +async def test_batchify_forwards_live_display_mode(reset_hooks, reset_context): + """Test that batchify forwards live display mode to BatchingContext.""" + wrapped = batchify( + live_display="off", + ) + + assert wrapped._self_live_display_mode == "off" + + +@pytest.mark.asyncio +async def test_batchify_live_display_defaults_to_auto(reset_hooks, reset_context): + """Test that live display mode defaults to auto.""" + wrapped = batchify() + + assert wrapped._self_live_display_mode == "auto" + + @pytest.mark.asyncio async def test_batchify_returns_context_manager(reset_hooks, reset_context): """Test that batchify returns a BatchingContext.""" diff --git a/tests/test_context.py b/tests/test_context.py index e967a57..b0f828d 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -96,3 +96,80 @@ async def test_batching_context_without_target(batcher: Batcher, reset_context: async with context as active_target: assert active_batcher.get() is batcher assert active_target is None + + +@pytest.mark.asyncio +async def test_batching_context_starts_and_stops_live_display( + batcher: Batcher, + reset_context: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that async context starts and stops live display listeners.""" + + class DummyDisplay: + """Simple display stub.""" + + def __init__(self) -> None: + self.started = False + self.stopped = False + + def start(self) -> None: + self.started = True + + def stop(self) -> None: + self.stopped = True + + def on_event(self, event: dict[str, t.Any]) -> None: + del event + + dummy_display = DummyDisplay() + monkeypatch.setattr("batchling.context.BatcherRichDisplay", lambda: dummy_display) + monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True) + context = BatchingContext( + batcher=batcher, + live_display="on", + ) + + with patch.object(target=batcher, attribute="close", new_callable=AsyncMock): + async with context: + assert dummy_display.started is True + assert dummy_display.stopped is True + + +def test_batching_context_sync_stops_live_display_without_loop( + batcher: Batcher, + reset_context: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test sync context stops display when no event loop is running.""" + + class DummyDisplay: + """Simple display stub.""" + + def __init__(self) -> None: + self.stopped = False + + def start(self) -> None: + return None + + def stop(self) -> None: + self.stopped = True + + def on_event(self, event: dict[str, t.Any]) -> None: + del event + + dummy_display = DummyDisplay() + monkeypatch.setattr("batchling.context.BatcherRichDisplay", lambda: dummy_display) + monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True) + + context = BatchingContext( + batcher=batcher, + live_display="on", + ) + + with warnings.catch_warnings(record=True): + warnings.simplefilter(action="always") + with context: + pass + + assert dummy_display.stopped is True From 314c7c2238fd1c4f11535162a8e55da77bbc8fd7 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 17:17:52 -0800 Subject: [PATCH 04/20] cli/tests/docs: expose live_display flag and document behavior --- docs/architecture/api.md | 7 +++++-- docs/architecture/context.md | 7 ++++++- docs/cli.md | 14 ++++++++++++++ docs/python-sdk.md | 15 +++++++++++++++ src/batchling/cli/main.py | 17 ++++++++++++++++- tests/test_cli_script_runner.py | 34 +++++++++++++++++++++++++++++++++ 6 files changed, 90 insertions(+), 4 deletions(-) diff --git a/docs/architecture/api.md b/docs/architecture/api.md index ca251fe..11dfe92 100644 --- a/docs/architecture/api.md +++ b/docs/architecture/api.md @@ -9,7 +9,7 @@ yields `None`. Import it from `batchling`. - Install HTTP hooks once (idempotent). - Construct a `Batcher` with configuration such as `batch_size`, `batch_window_seconds`, `batch_poll_interval_seconds`, `dry_run`, - and `cache`. + `cache`, and `live_display`. - Configure `batchling` logging defaults with Python's stdlib `logging` (`WARNING` by default). - Return a `BatchingContext` to scope batching to a context manager. @@ -26,6 +26,9 @@ yields `None`. Import it from `batchling`. - **`cache` behavior**: when `cache=True` (default), intercepted requests are fingerprinted and looked up in a persistent request cache. Cache hits bypass queueing and resume polling from an existing provider batch when not in dry-run mode. +- **`live_display` behavior**: `live_display` accepts `auto`, `on`, or `off`. + In `auto`, the Rich panel is enabled only when `stderr` is a TTY, terminal + is not `dumb`, and `CI` is not set. - **Outputs**: `BatchingContext[None]` instance that yields `None`. - **Logging**: lifecycle milestones are emitted at `INFO`, problems at `WARNING`/`ERROR`, and high-volume diagnostics at `DEBUG`. Request payloads @@ -43,7 +46,7 @@ Behavior: - CLI options map directly to `batchify` arguments: `batch_size`, `batch_window_seconds`, `batch_poll_interval_seconds`, `dry_run`, - and `cache`. + `cache`, and `live_display`. - Script target must use `module_path:function_name` syntax. - Forwarded callable arguments are mapped as: positional tokens are passed as positional arguments; diff --git a/docs/architecture/context.md b/docs/architecture/context.md index 78f67a9..53b409d 100644 --- a/docs/architecture/context.md +++ b/docs/architecture/context.md @@ -9,6 +9,7 @@ a context variable. - Activate the `active_batcher` context for the duration of a context block. - Yield `None` for scope-only lifecycle control. - Support sync and async context manager patterns for cleanup and context scoping. +- Start and stop optional Rich live activity display while the context is active. ## Flow summary @@ -16,7 +17,11 @@ a context variable. 2. `__enter__`/`__aenter__` set the active batcher for the entire context block. 3. `__exit__` resets the context and schedules `batcher.close()` if an event loop is running (otherwise it warns). -4. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work. +4. If `live_display` is enabled, the context registers a lifecycle listener and starts + the Rich panel at enter-time. +5. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work. +6. The live display listener is removed and the panel is stopped when context cleanup + finishes. ## Code reference diff --git a/docs/cli.md b/docs/cli.md index 692a1ec..3a7d456 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -64,6 +64,20 @@ batchling generate_product_images.py:main That's it! Just run that command and you save 50% off your workflow. +## Live visibility panel + +The CLI also exposes the live Rich panel control: + +```bash +batchling generate_product_images.py:main --live-display auto +``` + +`--live-display` accepts: + +- `auto` (default): only in interactive terminals (`TTY`, non-`dumb`, non-`CI`) +- `on`: always render the panel +- `off`: never render the panel + ## Next Steps If you haven't yet, look at how you can: diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 4a00228..7038d9f 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -79,6 +79,21 @@ async def main(): That's it! Update three lines of code and you save 50% off your workflow. +## Live visibility panel + +You can enable a Rich live panel while the context is active: + +```py +async with batchify(live_display="on"): + generated_images = await asyncio.gather(*tasks) +``` + +`live_display` accepts: + +- `auto` (default): only in interactive terminals (`TTY`, non-`dumb`, non-`CI`) +- `on`: always render the panel +- `off`: never render the panel + You can now run this script normally using python and start saving money: ```bash diff --git a/src/batchling/cli/main.py b/src/batchling/cli/main.py index d266302..2f3bafd 100644 --- a/src/batchling/cli/main.py +++ b/src/batchling/cli/main.py @@ -7,6 +7,7 @@ import typer from batchling import batchify +from batchling.rich_display import LiveDisplayMode # syncify = lambda f: wraps(f)(lambda *args, **kwargs: asyncio.run(f(*args, **kwargs))) @@ -75,7 +76,8 @@ async def run_script_with_batchify( batch_poll_interval_seconds: float, dry_run: bool, cache: bool, -): + live_display: LiveDisplayMode, +) -> None: """ Execute a Python script under a batchify context. @@ -97,6 +99,8 @@ async def run_script_with_batchify( Dry run mode passed to ``batchify``. cache : bool Cache mode passed to ``batchify``. + live_display : {"auto", "on", "off"} + Live display mode passed to ``batchify``. """ if not module_path.exists(): typer.echo(f"Script not found: {module_path}") @@ -112,6 +116,7 @@ async def run_script_with_batchify( batch_poll_interval_seconds=batch_poll_interval_seconds, dry_run=dry_run, cache=cache, + live_display=live_display, ): ns = runpy.run_path(path_name=script_path_as_posix, run_name="batchling.runtime") func = ns.get(func_name) @@ -155,6 +160,15 @@ def main( bool, typer.Option("--cache/--no-cache", help="Enable persistent request caching"), ] = True, + live_display: t.Annotated[ + LiveDisplayMode, + typer.Option( + help=( + "Show the live Rich panel: auto (interactive terminals only), " + "on (always), off (never)" + ) + ), + ] = "auto", ): """Run a script under ``batchify``.""" try: @@ -173,5 +187,6 @@ def main( batch_poll_interval_seconds=batch_poll_interval_seconds, dry_run=dry_run, cache=cache, + live_display=live_display, ) ) diff --git a/tests/test_cli_script_runner.py b/tests/test_cli_script_runner.py index 0472062..4cdcca1 100644 --- a/tests/test_cli_script_runner.py +++ b/tests/test_cli_script_runner.py @@ -61,6 +61,7 @@ def fake_batchify(**kwargs): } assert captured_batchify_kwargs["dry_run"] is True assert captured_batchify_kwargs["cache"] is True + assert captured_batchify_kwargs["live_display"] == "auto" def test_run_script_with_cache_option(tmp_path: Path, monkeypatch): @@ -92,6 +93,39 @@ def fake_batchify(**kwargs): assert result.exit_code == 0 assert captured_batchify_kwargs["cache"] is False + assert captured_batchify_kwargs["live_display"] == "auto" + + +def test_run_script_with_live_display_option(tmp_path: Path, monkeypatch): + script_path = tmp_path / "script.py" + script_path.write_text( + "\n".join( + [ + "async def foo(*args, **kwargs):", + " return None", + ] + ) + + "\n" + ) + captured_batchify_kwargs: dict = {} + + def fake_batchify(**kwargs): + captured_batchify_kwargs.update(kwargs) + return DummyAsyncBatchifyContext() + + monkeypatch.setattr(cli_main, "batchify", fake_batchify) + + result = runner.invoke( + app, + [ + f"{script_path.as_posix()}:foo", + "--live-display", + "off", + ], + ) + + assert result.exit_code == 0 + assert captured_batchify_kwargs["live_display"] == "off" def test_batch_size_flag_scope_for_cli_and_target_function(tmp_path: Path, monkeypatch): From 6ebd2093e247fbd3546131ce1beae16fd2cb232d Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 17:24:57 -0800 Subject: [PATCH 05/20] ui: render live display as sent-batches table --- docs/cli.md | 3 + docs/python-sdk.md | 3 + src/batchling/rich_display.py | 253 ++++++++++++++-------------------- tests/test_rich_display.py | 118 ++++++++-------- 4 files changed, 174 insertions(+), 203 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index 3a7d456..0426cf8 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -78,6 +78,9 @@ batchling generate_product_images.py:main --live-display auto - `on`: always render the panel - `off`: never render the panel +When enabled, the panel shows every sent batch with: +`batch_id`, `provider`, `endpoint`, `model`, `size`, and latest status. + ## Next Steps If you haven't yet, look at how you can: diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 7038d9f..41f9bb4 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -94,6 +94,9 @@ async with batchify(live_display="on"): - `on`: always render the panel - `off`: never render the panel +When enabled, the panel focuses on sent batches and shows one row per batch: +`batch_id`, `provider`, `endpoint`, `model`, `size`, and latest status. + You can now run this script normally using python and start saving money: ```bash diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 0170329..9d927e6 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -6,39 +6,37 @@ import sys import time import typing as t -from collections import deque from dataclasses import dataclass -from rich.console import Console, Group +from rich.console import Console from rich.live import Live from rich.panel import Panel from rich.table import Table -from rich.text import Text -from batchling.core import BatcherEvent, QueueKey +from batchling.core import BatcherEvent LiveDisplayMode = t.Literal["auto", "on", "off"] @dataclass -class _QueueActivity: - """In-memory queue activity snapshot for rendering.""" +class _BatchActivity: + """In-memory batch activity snapshot for rendering.""" - pending_count: int = 0 - active_batches: int = 0 - submitted_batches: int = 0 - last_status: str = "-" - last_batch_id: str = "-" + batch_id: str + provider: str = "-" + endpoint: str = "-" + model: str = "-" + size: int = 0 + latest_status: str = "submitted" + updated_at: float = 0.0 class BatcherRichDisplay: """ - Render queue and lifecycle activity through a Rich ``Live`` panel. + Render sent-batch lifecycle activity through a Rich ``Live`` panel. Parameters ---------- - max_events : int, optional - Maximum number of lifecycle events kept in the rolling feed. refresh_per_second : float, optional Refresh rate for Rich live updates. console : Console | None, optional @@ -48,15 +46,12 @@ class BatcherRichDisplay: def __init__( self, *, - max_events: int = 20, refresh_per_second: float = 8.0, console: Console | None = None, ) -> None: self._console = console or Console(stderr=True) - self._max_events = max_events self._refresh_per_second = refresh_per_second - self._events: deque[BatcherEvent] = deque(maxlen=max_events) - self._queues: dict[QueueKey, _QueueActivity] = {} + self._batches: dict[str, _BatchActivity] = {} self._live: Live | None = None def start(self) -> None: @@ -88,167 +83,127 @@ def on_event(self, event: BatcherEvent) -> None: Lifecycle event emitted by ``Batcher``. """ event_type = str(object=event.get("event_type", "unknown")) - queue_key = self._resolve_queue_key(event=event) - - if queue_key is not None: - queue_activity = self._queues.setdefault(queue_key, _QueueActivity()) - - if event_type == "request_queued": - queue_activity.pending_count = int( - event.get("pending_count", queue_activity.pending_count) - ) - elif event_type in {"batch_submitting", "final_flush_submitting"}: - queue_activity.pending_count = 0 - queue_activity.submitted_batches += 1 - elif event_type == "batch_processing": - source = str(object=event.get("source", "")) - if source in {"dry_run", "poll_start"}: - queue_activity.active_batches += 1 - elif event_type in {"batch_terminal", "batch_failed"}: - queue_activity.active_batches = max(0, queue_activity.active_batches - 1) - - event_status = event.get("status") - if event_status is not None: - queue_activity.last_status = str(object=event_status) - event_batch_id = event.get("batch_id") - if event_batch_id is not None: - queue_activity.last_batch_id = str(object=event_batch_id) - - self._events.append(event) + source = str(object=event.get("source", "")) + batch_id = event.get("batch_id") + + if batch_id is not None and event_type == "batch_processing": + batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) + request_count = event.get("request_count") + if isinstance(request_count, int): + batch.size = max(batch.size, request_count) + if source == "dry_run": + batch.latest_status = "simulated" + else: + batch.latest_status = "submitted" + elif batch_id is not None and event_type == "batch_polled": + batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) + status = event.get("status") + if status is not None: + batch.latest_status = str(object=status) + elif batch_id is not None and event_type == "batch_terminal": + batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) + status = event.get("status") + if status is not None: + batch.latest_status = str(object=status) + elif batch_id is not None and event_type == "batch_failed": + batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) + batch.latest_status = "failed" + elif batch_id is not None and event_type == "cache_hit_routed" and source == "resumed_poll": + batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) + batch.size += 1 + if batch.latest_status == "submitted": + batch.latest_status = "resumed" + if self._live is not None: self._live.update(renderable=self._render(), refresh=True) - def _resolve_queue_key(self, *, event: BatcherEvent) -> QueueKey | None: + def _get_or_create_batch(self, *, batch_id: str) -> _BatchActivity: """ - Resolve queue key from event payload. + Fetch or create batch display state. Parameters ---------- - event : BatcherEvent - Event payload. + batch_id : str + Provider batch identifier. Returns ------- - QueueKey | None - Queue key when available. + _BatchActivity + Batch display state. + """ + batch = self._batches.get(batch_id) + if batch is None: + batch = _BatchActivity(batch_id=batch_id) + self._batches[batch_id] = batch + batch.updated_at = time.time() + return batch + + @staticmethod + def _update_batch_identity(*, batch: _BatchActivity, event: BatcherEvent) -> None: """ - raw_queue_key = event.get("queue_key") - if isinstance(raw_queue_key, tuple) and len(raw_queue_key) == 3: - provider, endpoint, model = raw_queue_key - return str(provider), str(endpoint), str(model) + Update provider/endpoint/model metadata from one lifecycle event. + Parameters + ---------- + batch : _BatchActivity + Mutable batch row. + event : BatcherEvent + Lifecycle event payload. + """ provider = event.get("provider") endpoint = event.get("endpoint") model = event.get("model") - if provider is None or endpoint is None or model is None: - return None - - return str(provider), str(endpoint), str(model) + if provider is not None: + batch.provider = str(object=provider) + if endpoint is not None: + batch.endpoint = str(object=endpoint) + if model is not None: + batch.model = str(object=model) def _render(self) -> Panel: """Build the current Rich panel renderable.""" - queue_table = self._build_queue_table() - event_table = self._build_event_table() + table = self._build_batches_table() return Panel( - renderable=Group(queue_table, event_table), - title="batchling live activity", + renderable=table, + title="batchling sent batches", border_style="cyan", ) - def _build_queue_table(self) -> Table: - """Build queue activity table.""" - table = Table(title="Queues", expand=True) - table.add_column(header="Provider", style="bold") + def _build_batches_table(self) -> Table: + """Build sent-batches activity table.""" + table = Table(title="Sent Batches", expand=True) + table.add_column(header="Batch ID", style="bold") + table.add_column(header="Provider") table.add_column(header="Endpoint") table.add_column(header="Model") - table.add_column(header="Pending", justify="right") - table.add_column(header="Active", justify="right") - table.add_column(header="Submitted", justify="right") - table.add_column(header="Last Status") - table.add_column(header="Batch ID") - - if not self._queues: - table.add_row("-", "-", "-", "0", "0", "0", "-", "-") + table.add_column(header="Size", justify="right") + table.add_column(header="Latest Status") + + if not self._batches: + table.add_row("-", "-", "-", "-", "0", "waiting") return table - for queue_key in sorted(self._queues.keys()): - provider, endpoint, model = queue_key - queue_activity = self._queues[queue_key] + ordered_batches = sorted( + self._batches.values(), + key=lambda batch: batch.updated_at, + reverse=True, + ) + for batch in ordered_batches: table.add_row( - provider, - endpoint, - model, - str(queue_activity.pending_count), - str(queue_activity.active_batches), - str(queue_activity.submitted_batches), - queue_activity.last_status, - queue_activity.last_batch_id, + batch.batch_id, + batch.provider, + batch.endpoint, + batch.model, + str(batch.size), + batch.latest_status, ) return table - def _build_event_table(self) -> Table: - """Build rolling lifecycle event feed.""" - table = Table(title=f"Recent Events (max {self._max_events})", expand=True) - table.add_column(header="Time", width=8) - table.add_column(header="Event") - table.add_column(header="Details") - - if not self._events: - table.add_row("-", "-", "No lifecycle events yet") - return table - - for event in reversed(self._events): - event_timestamp = float(event.get("timestamp", time.time())) - event_time = time.strftime("%H:%M:%S", time.localtime(event_timestamp)) - event_type = str(object=event.get("event_type", "unknown")) - details = self._format_event_details(event=event) - table.add_row(event_time, event_type, details) - - return table - - def _format_event_details(self, *, event: BatcherEvent) -> Text: - """ - Build compact details text for one event row. - - Parameters - ---------- - event : BatcherEvent - Event payload. - - Returns - ------- - Text - Formatted details. - """ - details: list[str] = [] - provider = event.get("provider") - endpoint = event.get("endpoint") - model = event.get("model") - batch_id = event.get("batch_id") - status = event.get("status") - pending_count = event.get("pending_count") - request_count = event.get("request_count") - error = event.get("error") - - if provider is not None: - details.append(f"provider={provider}") - if endpoint is not None: - details.append(f"endpoint={endpoint}") - if model is not None: - details.append(f"model={model}") - if pending_count is not None: - details.append(f"pending={pending_count}") - if request_count is not None: - details.append(f"requests={request_count}") - if batch_id is not None: - details.append(f"batch_id={batch_id}") - if status is not None: - details.append(f"status={status}") - if error is not None: - details.append(f"error={error}") - - return Text(" ".join(details) if details else "-") - def should_enable_live_display(*, mode: LiveDisplayMode) -> bool: """ diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index 04a1557..20658a1 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -37,64 +37,74 @@ def isatty(self) -> bool: assert rich_display.should_enable_live_display(mode="auto") is False -def test_batcher_rich_display_consumes_events() -> None: - """Test Rich display event handling and rendering lifecycle.""" +def test_batcher_rich_display_shows_sent_batches() -> None: + """Test sent-batch table tracks batch metadata and latest status.""" display = rich_display.BatcherRichDisplay( console=Console(file=io.StringIO(), force_terminal=False), ) display.start() - display.on_event( - { - "event_type": "request_queued", - "timestamp": 1.0, - "provider": "openai", - "endpoint": "/v1/chat/completions", - "model": "model-a", - "queue_key": ("openai", "/v1/chat/completions", "model-a"), - "pending_count": 1, - } - ) - display.on_event( - { - "event_type": "batch_submitting", - "timestamp": 2.0, - "provider": "openai", - "endpoint": "/v1/chat/completions", - "model": "model-a", - "queue_key": ("openai", "/v1/chat/completions", "model-a"), - "request_count": 1, - } - ) - display.on_event( - { - "event_type": "batch_processing", - "timestamp": 3.0, - "provider": "openai", - "endpoint": "/v1/chat/completions", - "model": "model-a", - "queue_key": ("openai", "/v1/chat/completions", "model-a"), - "batch_id": "batch-1", - "source": "poll_start", - } - ) - display.on_event( - { - "event_type": "batch_terminal", - "timestamp": 4.0, - "provider": "openai", - "endpoint": "/v1/chat/completions", - "model": "model-a", - "queue_key": ("openai", "/v1/chat/completions", "model-a"), - "batch_id": "batch-1", - "status": "completed", - } - ) + processing_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 1.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-1", + "request_count": 3, + "source": "poll_start", + } + polled_event: rich_display.BatcherEvent = { + "event_type": "batch_polled", + "timestamp": 2.0, + "provider": "openai", + "batch_id": "batch-1", + "status": "running", + "source": "active_poll", + } + terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 3.0, + "provider": "openai", + "batch_id": "batch-1", + "status": "completed", + "source": "active_poll", + } + + display.on_event(processing_event) + display.on_event(polled_event) + display.on_event(terminal_event) display.stop() - queue_activity = display._queues[("openai", "/v1/chat/completions", "model-a")] - assert queue_activity.pending_count == 0 - assert queue_activity.active_batches == 0 - assert queue_activity.submitted_batches == 1 - assert queue_activity.last_status == "completed" - assert queue_activity.last_batch_id == "batch-1" + batch = display._batches["batch-1"] + assert batch.provider == "openai" + assert batch.endpoint == "/v1/chat/completions" + assert batch.model == "model-a" + assert batch.size == 3 + assert batch.latest_status == "completed" + + +def test_batcher_rich_display_tracks_resumed_batch_size() -> None: + """Test resumed cache-hit routing increments displayed batch size.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + cache_event: rich_display.BatcherEvent = { + "event_type": "cache_hit_routed", + "timestamp": 1.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "batch_id": "batch-cached-1", + "source": "resumed_poll", + "custom_id": "custom-1", + } + + display.on_event(cache_event) + display.on_event(cache_event) + + batch = display._batches["batch-cached-1"] + assert batch.size == 2 + assert batch.latest_status == "resumed" From bb1d8269c3d2556f0dcc373c37bc0c25d438063a Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 18:30:17 -0800 Subject: [PATCH 06/20] ui: show context progress from completed batch sizes --- docs/cli.md | 4 +- docs/python-sdk.md | 4 +- src/batchling/rich_display.py | 117 ++++++++++++++++------------------ tests/test_rich_display.py | 64 ++++++++++++------- 4 files changed, 101 insertions(+), 88 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index 0426cf8..4631408 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -78,8 +78,8 @@ batchling generate_product_images.py:main --live-display auto - `on`: always render the panel - `off`: never render the panel -When enabled, the panel shows every sent batch with: -`batch_id`, `provider`, `endpoint`, `model`, `size`, and latest status. +When enabled, the panel shows overall context progress: +`completed_samples / total_samples` and completion percentage. ## Next Steps diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 41f9bb4..047daa4 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -94,8 +94,8 @@ async with batchify(live_display="on"): - `on`: always render the panel - `off`: never render the panel -When enabled, the panel focuses on sent batches and shows one row per batch: -`batch_id`, `provider`, `endpoint`, `model`, `size`, and latest status. +When enabled, the panel shows context-level progress only: +`completed_samples / total_samples` and completion percentage. You can now run this script normally using python and start saving money: diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 9d927e6..1626402 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -20,20 +20,21 @@ @dataclass class _BatchActivity: - """In-memory batch activity snapshot for rendering.""" + """In-memory batch activity snapshot for progress aggregation.""" batch_id: str - provider: str = "-" - endpoint: str = "-" - model: str = "-" size: int = 0 latest_status: str = "submitted" + completed: bool = False updated_at: float = 0.0 class BatcherRichDisplay: """ - Render sent-batch lifecycle activity through a Rich ``Live`` panel. + Render context-level sample progress through a Rich ``Live`` panel. + + Progress is computed from tracked sent batches as: + ``sum(size of completed batches) / sum(size of all tracked batches)``. Parameters ---------- @@ -88,33 +89,26 @@ def on_event(self, event: BatcherEvent) -> None: if batch_id is not None and event_type == "batch_processing": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) request_count = event.get("request_count") if isinstance(request_count, int): batch.size = max(batch.size, request_count) - if source == "dry_run": - batch.latest_status = "simulated" - else: - batch.latest_status = "submitted" + batch.latest_status = "simulated" if source == "dry_run" else "submitted" elif batch_id is not None and event_type == "batch_polled": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) status = event.get("status") if status is not None: batch.latest_status = str(object=status) elif batch_id is not None and event_type == "batch_terminal": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) - status = event.get("status") - if status is not None: - batch.latest_status = str(object=status) + status = str(object=event.get("status", "completed")) + batch.latest_status = status + batch.completed = self._status_counts_as_completed(status=status) elif batch_id is not None and event_type == "batch_failed": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) batch.latest_status = "failed" + batch.completed = False elif batch_id is not None and event_type == "cache_hit_routed" and source == "resumed_poll": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) batch.size += 1 if batch.latest_status == "submitted": batch.latest_status = "resumed" @@ -144,64 +138,65 @@ def _get_or_create_batch(self, *, batch_id: str) -> _BatchActivity: return batch @staticmethod - def _update_batch_identity(*, batch: _BatchActivity, event: BatcherEvent) -> None: + def _status_counts_as_completed(*, status: str) -> bool: """ - Update provider/endpoint/model metadata from one lifecycle event. + Determine whether a terminal status counts as completed samples. Parameters ---------- - batch : _BatchActivity - Mutable batch row. - event : BatcherEvent - Lifecycle event payload. + status : str + Terminal provider status. + + Returns + ------- + bool + ``True`` when terminal state should contribute to completed samples. """ - provider = event.get("provider") - endpoint = event.get("endpoint") - model = event.get("model") - if provider is not None: - batch.provider = str(object=provider) - if endpoint is not None: - batch.endpoint = str(object=endpoint) - if model is not None: - batch.model = str(object=model) + lowered_status = status.lower() + negative_markers = ("fail", "error", "cancel", "expired", "timeout") + if any(marker in lowered_status for marker in negative_markers): + return False + return True + + def _compute_progress(self) -> tuple[int, int, float]: + """ + Compute aggregate context progress from tracked batches. + + Returns + ------- + tuple[int, int, float] + ``(completed_samples, total_samples, percent)``. + """ + total_samples = sum(batch.size for batch in self._batches.values()) + completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) + if total_samples <= 0: + return 0, 0, 0.0 + percent = (completed_samples / total_samples) * 100 + return completed_samples, total_samples, percent def _render(self) -> Panel: """Build the current Rich panel renderable.""" - table = self._build_batches_table() + table = self._build_progress_table() return Panel( renderable=table, - title="batchling sent batches", + title="batchling context progress", border_style="cyan", ) - def _build_batches_table(self) -> Table: - """Build sent-batches activity table.""" - table = Table(title="Sent Batches", expand=True) - table.add_column(header="Batch ID", style="bold") - table.add_column(header="Provider") - table.add_column(header="Endpoint") - table.add_column(header="Model") - table.add_column(header="Size", justify="right") - table.add_column(header="Latest Status") - - if not self._batches: - table.add_row("-", "-", "-", "-", "0", "waiting") - return table - - ordered_batches = sorted( - self._batches.values(), - key=lambda batch: batch.updated_at, - reverse=True, + def _build_progress_table(self) -> Table: + """Build aggregate context progress table.""" + completed_samples, total_samples, percent = self._compute_progress() + + table = Table(title="Overall Progress", expand=True) + table.add_column(header="Completed Samples", justify="right") + table.add_column(header="Total Samples", justify="right") + table.add_column(header="Completion", justify="right") + + table.add_row( + str(completed_samples), + str(total_samples), + f"{percent:.1f}%", ) - for batch in ordered_batches: - table.add_row( - batch.batch_id, - batch.provider, - batch.endpoint, - batch.model, - str(batch.size), - batch.latest_status, - ) return table diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index 20658a1..b879ffe 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -37,13 +37,12 @@ def isatty(self) -> bool: assert rich_display.should_enable_live_display(mode="auto") is False -def test_batcher_rich_display_shows_sent_batches() -> None: - """Test sent-batch table tracks batch metadata and latest status.""" +def test_batcher_rich_display_computes_context_progress() -> None: + """Test context progress is derived from completed batch sizes.""" display = rich_display.BatcherRichDisplay( console=Console(file=io.StringIO(), force_terminal=False), ) - display.start() processing_event: rich_display.BatcherEvent = { "event_type": "batch_processing", "timestamp": 1.0, @@ -55,38 +54,47 @@ def test_batcher_rich_display_shows_sent_batches() -> None: "request_count": 3, "source": "poll_start", } - polled_event: rich_display.BatcherEvent = { - "event_type": "batch_polled", + terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", "timestamp": 2.0, "provider": "openai", "batch_id": "batch-1", - "status": "running", + "status": "completed", "source": "active_poll", } - terminal_event: rich_display.BatcherEvent = { - "event_type": "batch_terminal", + failed_batch_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", "timestamp": 3.0, "provider": "openai", - "batch_id": "batch-1", - "status": "completed", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-2", + "request_count": 2, + "source": "poll_start", + } + failed_terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 4.0, + "provider": "openai", + "batch_id": "batch-2", + "status": "failed", "source": "active_poll", } display.on_event(processing_event) - display.on_event(polled_event) display.on_event(terminal_event) - display.stop() + display.on_event(failed_batch_event) + display.on_event(failed_terminal_event) - batch = display._batches["batch-1"] - assert batch.provider == "openai" - assert batch.endpoint == "/v1/chat/completions" - assert batch.model == "model-a" - assert batch.size == 3 - assert batch.latest_status == "completed" + completed_samples, total_samples, percent = display._compute_progress() + assert completed_samples == 3 + assert total_samples == 5 + assert percent == 60.0 -def test_batcher_rich_display_tracks_resumed_batch_size() -> None: - """Test resumed cache-hit routing increments displayed batch size.""" +def test_batcher_rich_display_tracks_resumed_batch_progress() -> None: + """Test resumed cache-hit routing contributes to total and completion.""" display = rich_display.BatcherRichDisplay( console=Console(file=io.StringIO(), force_terminal=False), ) @@ -101,10 +109,20 @@ def test_batcher_rich_display_tracks_resumed_batch_size() -> None: "source": "resumed_poll", "custom_id": "custom-1", } + terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 2.0, + "provider": "openai", + "batch_id": "batch-cached-1", + "status": "completed", + "source": "resumed_poll", + } display.on_event(cache_event) display.on_event(cache_event) + display.on_event(terminal_event) - batch = display._batches["batch-cached-1"] - assert batch.size == 2 - assert batch.latest_status == "resumed" + completed_samples, total_samples, percent = display._compute_progress() + assert completed_samples == 2 + assert total_samples == 2 + assert percent == 100.0 From d2765339423bd034e08a542b62f5341b9f592f69 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 18:37:19 -0800 Subject: [PATCH 07/20] ui: render context metrics as rich progress bar --- src/batchling/rich_display.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 1626402..4d57014 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -11,7 +11,7 @@ from rich.console import Console from rich.live import Live from rich.panel import Panel -from rich.table import Table +from rich.progress import BarColumn, Progress, TextColumn from batchling.core import BatcherEvent @@ -176,28 +176,33 @@ def _compute_progress(self) -> tuple[int, int, float]: def _render(self) -> Panel: """Build the current Rich panel renderable.""" - table = self._build_progress_table() + progress_bar = self._build_progress_bar() return Panel( - renderable=table, + renderable=progress_bar, title="batchling context progress", border_style="cyan", ) - def _build_progress_table(self) -> Table: - """Build aggregate context progress table.""" + def _build_progress_bar(self) -> Progress: + """Build aggregate context progress as a Rich progress bar.""" completed_samples, total_samples, percent = self._compute_progress() - table = Table(title="Overall Progress", expand=True) - table.add_column(header="Completed Samples", justify="right") - table.add_column(header="Total Samples", justify="right") - table.add_column(header="Completion", justify="right") - - table.add_row( - str(completed_samples), - str(total_samples), - f"{percent:.1f}%", + progress = Progress( + TextColumn(text_format="[bold]Progress[/bold]"), + BarColumn(), + TextColumn(text_format="{task.fields[completed_samples]}/{task.fields[total_samples]}"), + TextColumn(text_format=f"{percent:.1f}%"), + expand=True, + ) + display_total = max(total_samples, 1) + _ = progress.add_task( + description="samples", + total=display_total, + completed=min(completed_samples, display_total), + completed_samples=completed_samples, + total_samples=total_samples, ) - return table + return progress def should_enable_live_display(*, mode: LiveDisplayMode) -> bool: From f8f14c087f11bba4c04e20d4034a645d796055bc Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 18:40:14 -0800 Subject: [PATCH 08/20] ui: left-align progress bar and inline metrics --- src/batchling/rich_display.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 4d57014..9509faa 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -185,13 +185,16 @@ def _render(self) -> Panel: def _build_progress_bar(self) -> Progress: """Build aggregate context progress as a Rich progress bar.""" - completed_samples, total_samples, percent = self._compute_progress() + completed_samples, total_samples, _ = self._compute_progress() progress = Progress( - TextColumn(text_format="[bold]Progress[/bold]"), - BarColumn(), - TextColumn(text_format="{task.fields[completed_samples]}/{task.fields[total_samples]}"), - TextColumn(text_format=f"{percent:.1f}%"), + BarColumn(bar_width=None), + TextColumn( + text_format=( + "{task.fields[completed_samples]}/" + "{task.fields[total_samples]} ({task.percentage:.1f}%)" + ) + ), expand=True, ) display_total = max(total_samples, 1) From b4994482091d9ca9de715339688e9924081008a2 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 18:45:09 -0800 Subject: [PATCH 09/20] ui: add elapsed timer next to progress metrics --- docs/cli.md | 3 +- docs/python-sdk.md | 3 +- src/batchling/rich_display.py | 42 +++++++++++++++++++++++- tests/test_rich_display.py | 60 +++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 3 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index 4631408..a9dcb6b 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -79,7 +79,8 @@ batchling generate_product_images.py:main --live-display auto - `off`: never render the panel When enabled, the panel shows overall context progress: -`completed_samples / total_samples` and completion percentage. +`completed_samples / total_samples`, completion percentage, and `Time Elapsed` +since the first batch seen in the context. ## Next Steps diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 047daa4..19711c9 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -95,7 +95,8 @@ async with batchify(live_display="on"): - `off`: never render the panel When enabled, the panel shows context-level progress only: -`completed_samples / total_samples` and completion percentage. +`completed_samples / total_samples`, completion percentage, and `Time Elapsed` +since the first batch seen in the context. You can now run this script normally using python and start saving money: diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 9509faa..42f0382 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -53,6 +53,7 @@ def __init__( self._console = console or Console(stderr=True) self._refresh_per_second = refresh_per_second self._batches: dict[str, _BatchActivity] = {} + self._first_batch_created_at: float | None = None self._live: Live | None = None def start(self) -> None: @@ -130,11 +131,14 @@ def _get_or_create_batch(self, *, batch_id: str) -> _BatchActivity: _BatchActivity Batch display state. """ + now = time.time() batch = self._batches.get(batch_id) if batch is None: batch = _BatchActivity(batch_id=batch_id) self._batches[batch_id] = batch - batch.updated_at = time.time() + if self._first_batch_created_at is None: + self._first_batch_created_at = now + batch.updated_at = now return batch @staticmethod @@ -174,6 +178,39 @@ def _compute_progress(self) -> tuple[int, int, float]: percent = (completed_samples / total_samples) * 100 return completed_samples, total_samples, percent + def _compute_elapsed_seconds(self) -> int: + """ + Compute elapsed seconds since first batch creation in this context. + + Returns + ------- + int + Elapsed seconds. + """ + if self._first_batch_created_at is None: + return 0 + return max(0, int(time.time() - self._first_batch_created_at)) + + @staticmethod + def _format_elapsed(*, elapsed_seconds: int) -> str: + """ + Format elapsed seconds as ``HH:MM:SS``. + + Parameters + ---------- + elapsed_seconds : int + Elapsed seconds. + + Returns + ------- + str + Formatted duration. + """ + hours = elapsed_seconds // 3600 + minutes = (elapsed_seconds % 3600) // 60 + seconds = elapsed_seconds % 60 + return f"{hours:02d}:{minutes:02d}:{seconds:02d}" + def _render(self) -> Panel: """Build the current Rich panel renderable.""" progress_bar = self._build_progress_bar() @@ -186,6 +223,8 @@ def _render(self) -> Panel: def _build_progress_bar(self) -> Progress: """Build aggregate context progress as a Rich progress bar.""" completed_samples, total_samples, _ = self._compute_progress() + elapsed_seconds = self._compute_elapsed_seconds() + elapsed_label = self._format_elapsed(elapsed_seconds=elapsed_seconds) progress = Progress( BarColumn(bar_width=None), @@ -195,6 +234,7 @@ def _build_progress_bar(self) -> Progress: "{task.fields[total_samples]} ({task.percentage:.1f}%)" ) ), + TextColumn(text_format=f"Time Elapsed: {elapsed_label}"), expand=True, ) display_total = max(total_samples, 1) diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index b879ffe..501dfb2 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -126,3 +126,63 @@ def test_batcher_rich_display_tracks_resumed_batch_progress() -> None: assert completed_samples == 2 assert total_samples == 2 assert percent == 100.0 + + +def test_batcher_rich_display_elapsed_uses_first_batch_time(monkeypatch) -> None: + """Test elapsed timer starts from first batch seen in the context.""" + current_time = {"value": 100.0} + + def fake_time() -> float: + return current_time["value"] + + monkeypatch.setattr(rich_display.time, "time", fake_time) + + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + first_batch_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 100.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-1", + "request_count": 1, + "source": "poll_start", + } + display.on_event(first_batch_event) + + current_time["value"] = 127.0 + assert display._compute_elapsed_seconds() == 27 + assert display._format_elapsed(elapsed_seconds=27) == "00:00:27" + + +def test_batcher_rich_display_elapsed_starts_with_cache_batch(monkeypatch) -> None: + """Test elapsed timer also starts when the first batch comes from cache routing.""" + current_time = {"value": 200.0} + + def fake_time() -> float: + return current_time["value"] + + monkeypatch.setattr(rich_display.time, "time", fake_time) + + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + cache_event: rich_display.BatcherEvent = { + "event_type": "cache_hit_routed", + "timestamp": 200.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "batch_id": "batch-cached-1", + "source": "resumed_poll", + "custom_id": "custom-1", + } + display.on_event(cache_event) + + current_time["value"] = 206.0 + assert display._compute_elapsed_seconds() == 6 From 6d7ca0667cfeeec5e18a7d48e4e125913215ef15 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 18:48:49 -0800 Subject: [PATCH 10/20] ui: lower live display refresh to 1hz --- src/batchling/rich_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 42f0382..1845836 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -47,7 +47,7 @@ class BatcherRichDisplay: def __init__( self, *, - refresh_per_second: float = 8.0, + refresh_per_second: float = 1.0, console: Console | None = None, ) -> None: self._console = console or Console(stderr=True) From 90865080aba3c0965bdeaed55850808a3df72046 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 18:55:00 -0800 Subject: [PATCH 11/20] ui: heartbeat refresh every second without poll countdown --- src/batchling/context.py | 33 +++++++++++++++++++++++++++++++++ src/batchling/rich_display.py | 9 +++++++-- tests/test_context.py | 4 ++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/batchling/context.py b/src/batchling/context.py index 8e9f844..039e6e4 100644 --- a/src/batchling/context.py +++ b/src/batchling/context.py @@ -46,6 +46,7 @@ def __init__( self._self_batcher = batcher self._self_live_display_mode = live_display self._self_live_display: BatcherRichDisplay | None = None + self._self_live_display_heartbeat_task: asyncio.Task[None] | None = None self._self_context_token: t.Any | None = None def _start_live_display(self) -> None: @@ -65,6 +66,7 @@ def _start_live_display(self) -> None: self._self_batcher._add_event_listener(listener=display.on_event) display.start() self._self_live_display = display + self._start_live_display_heartbeat() except Exception as error: warnings.warn( message=f"Failed to start batchling live display: {error}", @@ -72,6 +74,33 @@ def _start_live_display(self) -> None: stacklevel=2, ) + async def _run_live_display_heartbeat(self) -> None: + """ + Periodically refresh the live display while the context is active. + """ + try: + while self._self_live_display is not None: + self._self_live_display.refresh() + await asyncio.sleep(1.0) + except asyncio.CancelledError: + raise + + def _start_live_display_heartbeat(self) -> None: + """ + Start the 1-second live display heartbeat when an event loop exists. + """ + if self._self_live_display is None: + return + if self._self_live_display_heartbeat_task is not None: + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + self._self_live_display_heartbeat_task = loop.create_task( + coro=self._run_live_display_heartbeat() + ) + def _stop_live_display(self) -> None: """ Stop and unregister the Rich live display. @@ -84,6 +113,10 @@ def _stop_live_display(self) -> None: return display = self._self_live_display self._self_live_display = None + heartbeat_task = self._self_live_display_heartbeat_task + self._self_live_display_heartbeat_task = None + if heartbeat_task is not None and not heartbeat_task.done(): + heartbeat_task.cancel() try: self._self_batcher._remove_event_listener(listener=display.on_event) display.stop() diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 1845836..e3f5a38 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -114,8 +114,13 @@ def on_event(self, event: BatcherEvent) -> None: if batch.latest_status == "submitted": batch.latest_status = "resumed" - if self._live is not None: - self._live.update(renderable=self._render(), refresh=True) + self.refresh() + + def refresh(self) -> None: + """Force one live-panel refresh when running.""" + if self._live is None: + return + self._live.update(renderable=self._render(), refresh=True) def _get_or_create_batch(self, *, batch_id: str) -> _BatchActivity: """ diff --git a/tests/test_context.py b/tests/test_context.py index b0f828d..198c378 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -133,7 +133,10 @@ def on_event(self, event: dict[str, t.Any]) -> None: with patch.object(target=batcher, attribute="close", new_callable=AsyncMock): async with context: assert dummy_display.started is True + assert context._self_live_display_heartbeat_task is not None + assert not context._self_live_display_heartbeat_task.done() assert dummy_display.stopped is True + assert context._self_live_display_heartbeat_task is None def test_batching_context_sync_stops_live_display_without_loop( @@ -173,3 +176,4 @@ def on_event(self, event: dict[str, t.Any]) -> None: pass assert dummy_display.stopped is True + assert context._self_live_display_heartbeat_task is None From 4e05899e567e6ec59187e390042b3d01e18de052 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 18:56:16 -0800 Subject: [PATCH 12/20] ui: colorize live progress metrics --- src/batchling/rich_display.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index e3f5a38..01da1ba 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -235,11 +235,12 @@ def _build_progress_bar(self) -> Progress: BarColumn(bar_width=None), TextColumn( text_format=( - "{task.fields[completed_samples]}/" - "{task.fields[total_samples]} ({task.percentage:.1f}%)" + "[bold green]{task.fields[completed_samples]}[/bold green]/" + "[bold cyan]{task.fields[total_samples]}[/bold cyan] " + "([bold yellow]{task.percentage:.1f}%[/bold yellow])" ) ), - TextColumn(text_format=f"Time Elapsed: {elapsed_label}"), + TextColumn(text_format=f"Time Elapsed: [bold magenta]{elapsed_label}[/bold magenta]"), expand=True, ) display_total = max(total_samples, 1) From ea1b16bc0932ef2393383056f4bec21e674d554c Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 19:04:09 -0800 Subject: [PATCH 13/20] ui: add styled request counters under progress bar --- src/batchling/rich_display.py | 62 +++++++++++++++++++++++++++++- tests/test_rich_display.py | 71 +++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 2 deletions(-) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 01da1ba..53ade2c 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -8,10 +8,11 @@ import typing as t from dataclasses import dataclass -from rich.console import Console +from rich.console import Console, Group from rich.live import Live from rich.panel import Panel from rich.progress import BarColumn, Progress, TextColumn +from rich.text import Text from batchling.core import BatcherEvent @@ -26,6 +27,7 @@ class _BatchActivity: size: int = 0 latest_status: str = "submitted" completed: bool = False + terminal: bool = False updated_at: float = 0.0 @@ -53,6 +55,7 @@ def __init__( self._console = console or Console(stderr=True) self._refresh_per_second = refresh_per_second self._batches: dict[str, _BatchActivity] = {} + self._cached_samples = 0 self._first_batch_created_at: float | None = None self._live: Live | None = None @@ -94,25 +97,31 @@ def on_event(self, event: BatcherEvent) -> None: if isinstance(request_count, int): batch.size = max(batch.size, request_count) batch.latest_status = "simulated" if source == "dry_run" else "submitted" + batch.terminal = False elif batch_id is not None and event_type == "batch_polled": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) status = event.get("status") if status is not None: batch.latest_status = str(object=status) + batch.terminal = False elif batch_id is not None and event_type == "batch_terminal": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) status = str(object=event.get("status", "completed")) batch.latest_status = status batch.completed = self._status_counts_as_completed(status=status) + batch.terminal = True elif batch_id is not None and event_type == "batch_failed": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) batch.latest_status = "failed" batch.completed = False + batch.terminal = True elif batch_id is not None and event_type == "cache_hit_routed" and source == "resumed_poll": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) batch.size += 1 + self._cached_samples += 1 if batch.latest_status == "submitted": batch.latest_status = "resumed" + batch.terminal = False self.refresh() @@ -183,6 +192,22 @@ def _compute_progress(self) -> tuple[int, int, float]: percent = (completed_samples / total_samples) * 100 return completed_samples, total_samples, percent + def _compute_request_metrics(self) -> tuple[int, int, int, int]: + """ + Compute aggregate request counters shown under the progress bar. + + Returns + ------- + tuple[int, int, int, int] + ``(total_samples, cached_samples, completed_samples, in_progress_samples)``. + """ + total_samples = sum(batch.size for batch in self._batches.values()) + completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) + in_progress_samples = sum( + batch.size for batch in self._batches.values() if not batch.terminal + ) + return total_samples, self._cached_samples, completed_samples, in_progress_samples + def _compute_elapsed_seconds(self) -> int: """ Compute elapsed seconds since first batch creation in this context. @@ -219,8 +244,9 @@ def _format_elapsed(*, elapsed_seconds: int) -> str: def _render(self) -> Panel: """Build the current Rich panel renderable.""" progress_bar = self._build_progress_bar() + requests_line = self._build_requests_line() return Panel( - renderable=progress_bar, + renderable=Group(progress_bar, requests_line), title="batchling context progress", border_style="cyan", ) @@ -253,6 +279,38 @@ def _build_progress_bar(self) -> Progress: ) return progress + def _build_requests_line(self) -> Text: + """ + Build one-line request metrics shown under the progress bar. + + Returns + ------- + Text + Styled metrics line. + """ + total_samples, cached_samples, completed_samples, in_progress_samples = ( + self._compute_request_metrics() + ) + line = Text() + line.append(text="Requests", style="bold white") + line.append(text=": ", style="white") + line.append(text="Total", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=total_samples), style="bold cyan") + line.append(text=" - ", style="grey70") + line.append(text="Cached", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=cached_samples), style="bold magenta") + line.append(text=" - ", style="grey70") + line.append(text="Completed", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=completed_samples), style="bold green") + line.append(text=" - ", style="grey70") + line.append(text="In Progress", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=in_progress_samples), style="bold yellow") + return line + def should_enable_live_display(*, mode: LiveDisplayMode) -> bool: """ diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index 501dfb2..980f193 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -186,3 +186,74 @@ def fake_time() -> float: current_time["value"] = 206.0 assert display._compute_elapsed_seconds() == 6 + + +def test_batcher_rich_display_request_metrics_line() -> None: + """Test requests metrics aggregate total/cached/completed/in-progress samples.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + processing_event_batch_1: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 1.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-1", + "request_count": 3, + "source": "poll_start", + } + terminal_event_batch_1: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 2.0, + "provider": "openai", + "batch_id": "batch-1", + "status": "completed", + "source": "active_poll", + } + processing_event_batch_2: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 3.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-2", + "request_count": 2, + "source": "poll_start", + } + cache_event_batch_3: rich_display.BatcherEvent = { + "event_type": "cache_hit_routed", + "timestamp": 4.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "batch_id": "batch-3", + "source": "resumed_poll", + "custom_id": "custom-1", + } + terminal_event_batch_3: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 5.0, + "provider": "openai", + "batch_id": "batch-3", + "status": "completed", + "source": "resumed_poll", + } + + display.on_event(processing_event_batch_1) + display.on_event(terminal_event_batch_1) + display.on_event(processing_event_batch_2) + display.on_event(cache_event_batch_3) + display.on_event(cache_event_batch_3) + display.on_event(terminal_event_batch_3) + + total_samples, cached_samples, completed_samples, in_progress_samples = ( + display._compute_request_metrics() + ) + assert total_samples == 7 + assert cached_samples == 2 + assert completed_samples == 5 + assert in_progress_samples == 2 From 6cd27ccfa2a65d8a3e09fd1e2d66f585d94b206e Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 19:14:44 -0800 Subject: [PATCH 14/20] ui: add pending batches table with dataframe-style truncation --- docs/cli.md | 3 + docs/python-sdk.md | 3 + src/batchling/rich_display.py | 117 +++++++++++++++++++++++++++++++++- tests/test_rich_display.py | 80 +++++++++++++++++++++++ 4 files changed, 201 insertions(+), 2 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index a9dcb6b..ed3d08b 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -81,6 +81,9 @@ batchling generate_product_images.py:main --live-display auto When enabled, the panel shows overall context progress: `completed_samples / total_samples`, completion percentage, and `Time Elapsed` since the first batch seen in the context. +It also shows request counters and a pending-batches table (`batch_id`, +`provider`, `endpoint`, `model`, `status`) truncated to 5 rows +(`top 2`, `...`, `last 2`). ## Next Steps diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 19711c9..c9f4bf2 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -97,6 +97,9 @@ async with batchify(live_display="on"): When enabled, the panel shows context-level progress only: `completed_samples / total_samples`, completion percentage, and `Time Elapsed` since the first batch seen in the context. +It also shows request counters and a pending-batches table (`batch_id`, +`provider`, `endpoint`, `model`, `status`) truncated to 5 rows +(`top 2`, `...`, `last 2`). You can now run this script normally using python and start saving money: diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 53ade2c..a132ae3 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -12,6 +12,7 @@ from rich.live import Live from rich.panel import Panel from rich.progress import BarColumn, Progress, TextColumn +from rich.table import Table from rich.text import Text from batchling.core import BatcherEvent @@ -24,10 +25,14 @@ class _BatchActivity: """In-memory batch activity snapshot for progress aggregation.""" batch_id: str + provider: str = "-" + endpoint: str = "-" + model: str = "-" size: int = 0 latest_status: str = "submitted" completed: bool = False terminal: bool = False + created_at: float = 0.0 updated_at: float = 0.0 @@ -93,6 +98,7 @@ def on_event(self, event: BatcherEvent) -> None: if batch_id is not None and event_type == "batch_processing": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) request_count = event.get("request_count") if isinstance(request_count, int): batch.size = max(batch.size, request_count) @@ -100,23 +106,27 @@ def on_event(self, event: BatcherEvent) -> None: batch.terminal = False elif batch_id is not None and event_type == "batch_polled": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) status = event.get("status") if status is not None: batch.latest_status = str(object=status) batch.terminal = False elif batch_id is not None and event_type == "batch_terminal": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) status = str(object=event.get("status", "completed")) batch.latest_status = status batch.completed = self._status_counts_as_completed(status=status) batch.terminal = True elif batch_id is not None and event_type == "batch_failed": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) batch.latest_status = "failed" batch.completed = False batch.terminal = True elif batch_id is not None and event_type == "cache_hit_routed" and source == "resumed_poll": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) batch.size += 1 self._cached_samples += 1 if batch.latest_status == "submitted": @@ -148,13 +158,38 @@ def _get_or_create_batch(self, *, batch_id: str) -> _BatchActivity: now = time.time() batch = self._batches.get(batch_id) if batch is None: - batch = _BatchActivity(batch_id=batch_id) + batch = _BatchActivity( + batch_id=batch_id, + created_at=now, + ) self._batches[batch_id] = batch if self._first_batch_created_at is None: self._first_batch_created_at = now batch.updated_at = now return batch + @staticmethod + def _update_batch_identity(*, batch: _BatchActivity, event: BatcherEvent) -> None: + """ + Update batch metadata from lifecycle event payload. + + Parameters + ---------- + batch : _BatchActivity + Mutable batch row. + event : BatcherEvent + Lifecycle event payload. + """ + provider = event.get("provider") + endpoint = event.get("endpoint") + model = event.get("model") + if provider is not None: + batch.provider = str(object=provider) + if endpoint is not None: + batch.endpoint = str(object=endpoint) + if model is not None: + batch.model = str(object=model) + @staticmethod def _status_counts_as_completed(*, status: str) -> bool: """ @@ -245,8 +280,12 @@ def _render(self) -> Panel: """Build the current Rich panel renderable.""" progress_bar = self._build_progress_bar() requests_line = self._build_requests_line() + pending_batches_line = self._build_pending_batches_line() + pending_batches_table = self._build_pending_batches_table() return Panel( - renderable=Group(progress_bar, requests_line), + renderable=Group( + progress_bar, requests_line, pending_batches_line, pending_batches_table + ), title="batchling context progress", border_style="cyan", ) @@ -311,6 +350,80 @@ def _build_requests_line(self) -> Text: line.append(text=str(object=in_progress_samples), style="bold yellow") return line + def _get_pending_batches(self) -> list[_BatchActivity]: + """ + Return pending (non-terminal) batches sorted by oldest first. + + Returns + ------- + list[_BatchActivity] + Sorted pending batches. + """ + pending_batches = [batch for batch in self._batches.values() if not batch.terminal] + return sorted(pending_batches, key=lambda batch: batch.created_at) + + def _build_pending_batches_line(self) -> Text: + """ + Build one-line pending batches summary shown above the table. + + Returns + ------- + Text + Styled summary line. + """ + pending_count = len(self._get_pending_batches()) + line = Text() + line.append(text="Pending batches", style="bold white") + line.append(text=": ", style="white") + line.append(text=str(object=pending_count), style="bold yellow") + return line + + def _build_pending_batches_table(self) -> Table: + """ + Build pending-batches table with dataframe-style truncation. + + Returns + ------- + Table + Pending batches table. + """ + pending_batches = self._get_pending_batches() + table = Table(expand=True) + table.add_column(header="batch_id", style="bold") + table.add_column(header="provider") + table.add_column(header="endpoint") + table.add_column(header="model") + table.add_column(header="status") + + if not pending_batches: + table.add_row("-", "-", "-", "-", "-") + return table + + display_batches: list[_BatchActivity | None] + if len(pending_batches) <= 5: + display_batches = [*pending_batches] + else: + display_batches = [ + pending_batches[0], + pending_batches[1], + None, + pending_batches[-2], + pending_batches[-1], + ] + + for batch in display_batches: + if batch is None: + table.add_row("...", "...", "...", "...", "...") + continue + table.add_row( + batch.batch_id, + batch.provider, + batch.endpoint, + batch.model, + batch.latest_status, + ) + return table + def should_enable_live_display(*, mode: LiveDisplayMode) -> bool: """ diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index 980f193..ea0e797 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -257,3 +257,83 @@ def test_batcher_rich_display_request_metrics_line() -> None: assert cached_samples == 2 assert completed_samples == 5 assert in_progress_samples == 2 + + +def test_batcher_rich_display_pending_batches_table_truncates_with_ellipsis() -> None: + """Test pending table shows top2/ellipsis/last2 for more than five rows.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + for batch_index in range(1, 7): + processing_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": float(batch_index), + "provider": "openai", + "endpoint": f"/v1/endpoint/{batch_index}", + "model": f"model-{batch_index}", + "queue_key": ("openai", f"/v1/endpoint/{batch_index}", f"model-{batch_index}"), + "batch_id": f"batch-{batch_index}", + "request_count": batch_index, + "source": "poll_start", + } + display.on_event(processing_event) + + table = display._build_pending_batches_table() + batch_id_cells = table.columns[0]._cells + + assert len(batch_id_cells) == 5 + assert batch_id_cells[0] == "batch-1" + assert batch_id_cells[1] == "batch-2" + assert batch_id_cells[2] == "..." + assert batch_id_cells[3] == "batch-5" + assert batch_id_cells[4] == "batch-6" + + +def test_batcher_rich_display_pending_batches_excludes_terminal() -> None: + """Test pending table includes only non-terminal batches.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + pending_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 1.0, + "provider": "openai", + "endpoint": "/v1/pending", + "model": "model-pending", + "queue_key": ("openai", "/v1/pending", "model-pending"), + "batch_id": "batch-pending", + "request_count": 1, + "source": "poll_start", + } + completed_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 2.0, + "provider": "openai", + "endpoint": "/v1/completed", + "model": "model-completed", + "queue_key": ("openai", "/v1/completed", "model-completed"), + "batch_id": "batch-completed", + "request_count": 1, + "source": "poll_start", + } + terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 3.0, + "provider": "openai", + "batch_id": "batch-completed", + "status": "completed", + "source": "active_poll", + } + + display.on_event(pending_event) + display.on_event(completed_event) + display.on_event(terminal_event) + + pending_batches = display._get_pending_batches() + assert len(pending_batches) == 1 + assert pending_batches[0].batch_id == "batch-pending" + + pending_line = display._build_pending_batches_line() + assert "Pending batches: 1" in pending_line.plain From c20d0593f86b115e165bfd298b284dfc3fd4eeda Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 19:26:52 -0800 Subject: [PATCH 15/20] ui: change color to green for completed samples --- src/batchling/rich_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index a132ae3..255fccf 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -302,7 +302,7 @@ def _build_progress_bar(self) -> Progress: text_format=( "[bold green]{task.fields[completed_samples]}[/bold green]/" "[bold cyan]{task.fields[total_samples]}[/bold cyan] " - "([bold yellow]{task.percentage:.1f}%[/bold yellow])" + "([bold green]{task.percentage:.1f}%[/bold green])" ) ), TextColumn(text_format=f"Time Elapsed: [bold magenta]{elapsed_label}[/bold magenta]"), From a971bd0bc79407cd4425ac030778a71a23f74460 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 19:33:38 -0800 Subject: [PATCH 16/20] live display: switch to bool auto mode with polling-log fallback --- docs/architecture/api.md | 9 +- docs/architecture/context.md | 5 +- docs/cli.md | 11 +- docs/python-sdk.md | 13 ++- src/batchling/api.py | 11 +- src/batchling/cli/main.py | 18 +-- src/batchling/context.py | 193 +++++++++++++++++++++++++++++--- src/batchling/rich_display.py | 13 +-- tests/test_api.py | 14 +-- tests/test_cli_script_runner.py | 11 +- tests/test_context.py | 43 ++++++- tests/test_rich_display.py | 18 ++- 12 files changed, 289 insertions(+), 70 deletions(-) diff --git a/docs/architecture/api.md b/docs/architecture/api.md index 11dfe92..637636a 100644 --- a/docs/architecture/api.md +++ b/docs/architecture/api.md @@ -26,9 +26,12 @@ yields `None`. Import it from `batchling`. - **`cache` behavior**: when `cache=True` (default), intercepted requests are fingerprinted and looked up in a persistent request cache. Cache hits bypass queueing and resume polling from an existing provider batch when not in dry-run mode. -- **`live_display` behavior**: `live_display` accepts `auto`, `on`, or `off`. - In `auto`, the Rich panel is enabled only when `stderr` is a TTY, terminal - is not `dumb`, and `CI` is not set. +- **`live_display` behavior**: `live_display` is a boolean. + When `True` (default), Rich panel rendering runs in auto mode and is enabled + only when `stderr` is a TTY, terminal is not `dumb`, and `CI` is not set. + If auto mode disables Rich, context-level progress is logged at `INFO` on + polling events. + When `False`, live display and fallback progress logs are both disabled. - **Outputs**: `BatchingContext[None]` instance that yields `None`. - **Logging**: lifecycle milestones are emitted at `INFO`, problems at `WARNING`/`ERROR`, and high-volume diagnostics at `DEBUG`. Request payloads diff --git a/docs/architecture/context.md b/docs/architecture/context.md index 53b409d..8e67530 100644 --- a/docs/architecture/context.md +++ b/docs/architecture/context.md @@ -17,8 +17,9 @@ a context variable. 2. `__enter__`/`__aenter__` set the active batcher for the entire context block. 3. `__exit__` resets the context and schedules `batcher.close()` if an event loop is running (otherwise it warns). -4. If `live_display` is enabled, the context registers a lifecycle listener and starts - the Rich panel at enter-time. +4. If `live_display=True`, the context attempts to start Rich panel rendering at + enter-time when terminal auto-detection passes (`TTY`, non-`dumb`, non-`CI`). + Otherwise it registers an `INFO` logging fallback that emits progress at poll-time. 5. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work. 6. The live display listener is removed and the panel is stopped when context cleanup finishes. diff --git a/docs/cli.md b/docs/cli.md index ed3d08b..bd59cf2 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -69,14 +69,15 @@ That's it! Just run that command and you save 50% off your workflow. The CLI also exposes the live Rich panel control: ```bash -batchling generate_product_images.py:main --live-display auto +batchling generate_product_images.py:main --live-display ``` -`--live-display` accepts: +`--live-display` is a boolean flag pair: -- `auto` (default): only in interactive terminals (`TTY`, non-`dumb`, non-`CI`) -- `on`: always render the panel -- `off`: never render the panel +- `--live-display` (default): auto mode, Rich panel only in interactive terminals + (`TTY`, non-`dumb`, non-`CI`). If Rich auto-disables, progress is emitted as + `INFO` logs on polling events. +- `--no-live-display`: disable both Rich panel and fallback progress logs. When enabled, the panel shows overall context progress: `completed_samples / total_samples`, completion percentage, and `Time Elapsed` diff --git a/docs/python-sdk.md b/docs/python-sdk.md index c9f4bf2..7b42439 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -81,18 +81,19 @@ That's it! Update three lines of code and you save 50% off your workflow. ## Live visibility panel -You can enable a Rich live panel while the context is active: +You can toggle live visibility behavior while the context is active: ```py -async with batchify(live_display="on"): +async with batchify(live_display=True): generated_images = await asyncio.gather(*tasks) ``` -`live_display` accepts: +`live_display` accepts a boolean: -- `auto` (default): only in interactive terminals (`TTY`, non-`dumb`, non-`CI`) -- `on`: always render the panel -- `off`: never render the panel +- `True` (default): auto mode, Rich panel only in interactive terminals + (`TTY`, non-`dumb`, non-`CI`). If Rich auto-disables, progress is emitted as + `INFO` logs on polling events. +- `False`: disable both Rich panel and fallback progress logs. When enabled, the panel shows context-level progress only: `completed_samples / total_samples`, completion percentage, and `Time Elapsed` diff --git a/src/batchling/api.py b/src/batchling/api.py index 6d8b1bc..46a2f88 100644 --- a/src/batchling/api.py +++ b/src/batchling/api.py @@ -8,7 +8,6 @@ from batchling.core import Batcher from batchling.hooks import install_hooks from batchling.logging import setup_logging -from batchling.rich_display import LiveDisplayMode def batchify( @@ -17,7 +16,7 @@ def batchify( batch_poll_interval_seconds: float = 10.0, dry_run: bool = False, cache: bool = True, - live_display: LiveDisplayMode = "auto", + live_display: bool = True, ) -> BatchingContext: """ Context manager used to activate batching for a scoped context.
@@ -39,9 +38,11 @@ def batchify( cache : bool, optional If ``True``, enable persistent request cache lookups.
This parameter allows to skip the batch submission and go straight to the polling phase for requests that have already been sent. - live_display : {"auto", "on", "off"}, optional - Toggle the Rich live panel shown while the context is active.
- ``"auto"`` enables the panel only in interactive terminals. + live_display : bool, optional + Enable live display behavior while the context is active.
+ When ``True``, Rich panel rendering is attempted with terminal auto-detection. + If terminal auto-detection disables Rich (non-TTY, ``TERM=dumb``, or ``CI``), + progress is logged at ``INFO`` on polling events. Returns ------- diff --git a/src/batchling/cli/main.py b/src/batchling/cli/main.py index 2f3bafd..6cba6dd 100644 --- a/src/batchling/cli/main.py +++ b/src/batchling/cli/main.py @@ -7,7 +7,6 @@ import typer from batchling import batchify -from batchling.rich_display import LiveDisplayMode # syncify = lambda f: wraps(f)(lambda *args, **kwargs: asyncio.run(f(*args, **kwargs))) @@ -76,7 +75,7 @@ async def run_script_with_batchify( batch_poll_interval_seconds: float, dry_run: bool, cache: bool, - live_display: LiveDisplayMode, + live_display: bool, ) -> None: """ Execute a Python script under a batchify context. @@ -99,8 +98,8 @@ async def run_script_with_batchify( Dry run mode passed to ``batchify``. cache : bool Cache mode passed to ``batchify``. - live_display : {"auto", "on", "off"} - Live display mode passed to ``batchify``. + live_display : bool + Live display toggle passed to ``batchify``. """ if not module_path.exists(): typer.echo(f"Script not found: {module_path}") @@ -161,14 +160,15 @@ def main( typer.Option("--cache/--no-cache", help="Enable persistent request caching"), ] = True, live_display: t.Annotated[ - LiveDisplayMode, + bool, typer.Option( + "--live-display/--no-live-display", help=( - "Show the live Rich panel: auto (interactive terminals only), " - "on (always), off (never)" - ) + "Enable auto live display. When disabled by terminal auto-detection, " + "fallback polling progress is logged at INFO." + ), ), - ] = "auto", + ] = True, ): """Run a script under ``batchify``.""" try: diff --git a/src/batchling/context.py b/src/batchling/context.py index 039e6e4..6687a90 100644 --- a/src/batchling/context.py +++ b/src/batchling/context.py @@ -3,17 +3,145 @@ """ import asyncio +import logging import typing as t import warnings +from dataclasses import dataclass -from batchling.core import Batcher +from batchling.core import Batcher, BatcherEvent from batchling.hooks import active_batcher +from batchling.logging import log_info from batchling.rich_display import ( BatcherRichDisplay, - LiveDisplayMode, should_enable_live_display, ) +log = logging.getLogger(name=__name__) + + +@dataclass +class _ProgressLogBatchState: + """Aggregate state used by the polling progress logger fallback.""" + + size: int = 0 + completed: bool = False + terminal: bool = False + + +class _PollingProgressLogger: + """INFO logger fallback used when Rich live display auto-disables.""" + + def __init__(self) -> None: + self._state_by_batch_id: dict[str, _ProgressLogBatchState] = {} + + @staticmethod + def _status_counts_as_completed(*, status: str) -> bool: + """ + Determine whether a terminal status counts as completed samples. + + Parameters + ---------- + status : str + Terminal provider status. + + Returns + ------- + bool + ``True`` when terminal state should contribute to completed samples. + """ + lowered_status = status.lower() + negative_markers = ("fail", "error", "cancel", "expired", "timeout") + if any(marker in lowered_status for marker in negative_markers): + return False + return True + + def _compute_progress(self) -> tuple[int, int, float, int]: + """ + Compute aggregate sample progress from tracked batch states. + + Returns + ------- + tuple[int, int, float, int] + ``(completed_samples, total_samples, percent, in_progress_samples)``. + """ + total_samples = sum(state.size for state in self._state_by_batch_id.values()) + completed_samples = sum( + state.size for state in self._state_by_batch_id.values() if state.completed + ) + in_progress_samples = sum( + state.size for state in self._state_by_batch_id.values() if not state.terminal + ) + if total_samples <= 0: + return 0, 0, 0.0, in_progress_samples + percent = (completed_samples / total_samples) * 100.0 + return completed_samples, total_samples, percent, in_progress_samples + + def _get_or_create_batch_state(self, *, batch_id: str) -> _ProgressLogBatchState: + """ + Get or create progress state for one batch. + + Parameters + ---------- + batch_id : str + Provider batch identifier. + + Returns + ------- + _ProgressLogBatchState + Mutable batch state. + """ + state = self._state_by_batch_id.get(batch_id) + if state is None: + state = _ProgressLogBatchState() + self._state_by_batch_id[batch_id] = state + return state + + def on_event(self, event: BatcherEvent) -> None: + """ + Consume one lifecycle event and log progress on poll events. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + event_type = str(object=event.get("event_type", "unknown")) + batch_id = event.get("batch_id") + if batch_id is not None: + state = self._get_or_create_batch_state(batch_id=str(object=batch_id)) + if event_type == "batch_processing": + request_count = event.get("request_count") + if isinstance(request_count, int): + state.size = max(state.size, request_count) + state.terminal = False + elif event_type == "cache_hit_routed" and str(object=event.get("source", "")) == ( + "resumed_poll" + ): + state.size += 1 + state.terminal = False + elif event_type == "batch_terminal": + status = str(object=event.get("status", "completed")) + state.completed = self._status_counts_as_completed(status=status) + state.terminal = True + elif event_type == "batch_failed": + state.completed = False + state.terminal = True + + if event_type != "batch_polled": + return + + completed_samples, total_samples, percent, in_progress_samples = self._compute_progress() + log_info( + logger=log, + event="Live display fallback progress", + batch_id=event.get("batch_id"), + status=event.get("status"), + completed_samples=completed_samples, + total_samples=total_samples, + percent=f"{percent:.1f}", + in_progress_samples=in_progress_samples, + ) + class BatchingContext: """ @@ -23,15 +151,15 @@ class BatchingContext: ---------- batcher : Batcher Batcher instance used for the scope of the context manager. - live_display : LiveDisplayMode, optional - Live display mode used when entering the context. + live_display : bool, optional + Whether to enable auto live display behavior for the context. """ def __init__( self, *, batcher: "Batcher", - live_display: LiveDisplayMode = "auto", + live_display: bool = True, ) -> None: """ Initialize the context manager. @@ -40,15 +168,44 @@ def __init__( ---------- batcher : Batcher Batcher instance used for the scope of the context manager. - live_display : LiveDisplayMode, optional - Live display mode used when entering the context. + live_display : bool, optional + Whether to enable auto live display behavior for the context. """ self._self_batcher = batcher - self._self_live_display_mode = live_display + self._self_live_display_enabled = live_display self._self_live_display: BatcherRichDisplay | None = None self._self_live_display_heartbeat_task: asyncio.Task[None] | None = None + self._self_polling_progress_logger: _PollingProgressLogger | None = None self._self_context_token: t.Any | None = None + def _start_polling_progress_logger(self) -> None: + """ + Start the INFO polling progress fallback listener. + + Notes + ----- + Fallback listener errors are downgraded to warnings. + """ + if self._self_polling_progress_logger is not None: + return + try: + listener = _PollingProgressLogger() + self._self_batcher._add_event_listener(listener=listener.on_event) + self._self_polling_progress_logger = listener + log_info( + logger=log, + event=( + "Live display disabled by terminal auto-detection; " + "using polling progress INFO logs" + ), + ) + except Exception as error: + warnings.warn( + message=f"Failed to start batchling polling progress logs: {error}", + category=UserWarning, + stacklevel=2, + ) + def _start_live_display(self) -> None: """ Start the Rich live display when enabled. @@ -57,9 +214,12 @@ def _start_live_display(self) -> None: ----- Display errors are downgraded to warnings to avoid breaking batching. """ - if self._self_live_display is not None: + if self._self_live_display is not None or self._self_polling_progress_logger is not None: + return + if not self._self_live_display_enabled: return - if not should_enable_live_display(mode=self._self_live_display_mode): + if not should_enable_live_display(enabled=self._self_live_display_enabled): + self._start_polling_progress_logger() return try: display = BatcherRichDisplay() @@ -109,17 +269,22 @@ def _stop_live_display(self) -> None: ----- Display shutdown errors are downgraded to warnings. """ - if self._self_live_display is None: - return display = self._self_live_display + fallback_listener = self._self_polling_progress_logger + if display is None and fallback_listener is None: + return self._self_live_display = None + self._self_polling_progress_logger = None heartbeat_task = self._self_live_display_heartbeat_task self._self_live_display_heartbeat_task = None if heartbeat_task is not None and not heartbeat_task.done(): heartbeat_task.cancel() try: - self._self_batcher._remove_event_listener(listener=display.on_event) - display.stop() + if display is not None: + self._self_batcher._remove_event_listener(listener=display.on_event) + display.stop() + if fallback_listener is not None: + self._self_batcher._remove_event_listener(listener=fallback_listener.on_event) except Exception as error: warnings.warn( message=f"Failed to stop batchling live display: {error}", diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 255fccf..14f638e 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -5,7 +5,6 @@ import os import sys import time -import typing as t from dataclasses import dataclass from rich.console import Console, Group @@ -17,8 +16,6 @@ from batchling.core import BatcherEvent -LiveDisplayMode = t.Literal["auto", "on", "off"] - @dataclass class _BatchActivity: @@ -425,23 +422,21 @@ def _build_pending_batches_table(self) -> Table: return table -def should_enable_live_display(*, mode: LiveDisplayMode) -> bool: +def should_enable_live_display(*, enabled: bool) -> bool: """ Resolve if the Rich live panel should be enabled. Parameters ---------- - mode : LiveDisplayMode - Desired display mode. + enabled : bool + Requested live display toggle. Returns ------- bool ``True`` when the live panel should run. """ - if mode == "on": - return True - if mode == "off": + if not enabled: return False stderr_stream = sys.stderr diff --git a/tests/test_api.py b/tests/test_api.py index b87d54f..446d2b6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -61,21 +61,21 @@ async def test_batchify_configures_cache_flag(reset_hooks, reset_context): @pytest.mark.asyncio -async def test_batchify_forwards_live_display_mode(reset_hooks, reset_context): - """Test that batchify forwards live display mode to BatchingContext.""" +async def test_batchify_forwards_live_display_flag(reset_hooks, reset_context): + """Test that batchify forwards live display flag to BatchingContext.""" wrapped = batchify( - live_display="off", + live_display=False, ) - assert wrapped._self_live_display_mode == "off" + assert wrapped._self_live_display_enabled is False @pytest.mark.asyncio -async def test_batchify_live_display_defaults_to_auto(reset_hooks, reset_context): - """Test that live display mode defaults to auto.""" +async def test_batchify_live_display_defaults_to_true(reset_hooks, reset_context): + """Test that live display defaults to enabled.""" wrapped = batchify() - assert wrapped._self_live_display_mode == "auto" + assert wrapped._self_live_display_enabled is True @pytest.mark.asyncio diff --git a/tests/test_cli_script_runner.py b/tests/test_cli_script_runner.py index 4cdcca1..e7426d7 100644 --- a/tests/test_cli_script_runner.py +++ b/tests/test_cli_script_runner.py @@ -61,7 +61,7 @@ def fake_batchify(**kwargs): } assert captured_batchify_kwargs["dry_run"] is True assert captured_batchify_kwargs["cache"] is True - assert captured_batchify_kwargs["live_display"] == "auto" + assert captured_batchify_kwargs["live_display"] is True def test_run_script_with_cache_option(tmp_path: Path, monkeypatch): @@ -93,10 +93,10 @@ def fake_batchify(**kwargs): assert result.exit_code == 0 assert captured_batchify_kwargs["cache"] is False - assert captured_batchify_kwargs["live_display"] == "auto" + assert captured_batchify_kwargs["live_display"] is True -def test_run_script_with_live_display_option(tmp_path: Path, monkeypatch): +def test_run_script_with_no_live_display_option(tmp_path: Path, monkeypatch): script_path = tmp_path / "script.py" script_path.write_text( "\n".join( @@ -119,13 +119,12 @@ def fake_batchify(**kwargs): app, [ f"{script_path.as_posix()}:foo", - "--live-display", - "off", + "--no-live-display", ], ) assert result.exit_code == 0 - assert captured_batchify_kwargs["live_display"] == "off" + assert captured_batchify_kwargs["live_display"] is False def test_batch_size_flag_scope_for_cli_and_target_function(tmp_path: Path, monkeypatch): diff --git a/tests/test_context.py b/tests/test_context.py index 198c378..72c6eb9 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -3,6 +3,7 @@ """ import asyncio +import logging import typing as t import warnings from unittest.mock import AsyncMock, patch @@ -127,7 +128,7 @@ def on_event(self, event: dict[str, t.Any]) -> None: monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True) context = BatchingContext( batcher=batcher, - live_display="on", + live_display=True, ) with patch.object(target=batcher, attribute="close", new_callable=AsyncMock): @@ -167,7 +168,7 @@ def on_event(self, event: dict[str, t.Any]) -> None: context = BatchingContext( batcher=batcher, - live_display="on", + live_display=True, ) with warnings.catch_warnings(record=True): @@ -177,3 +178,41 @@ def on_event(self, event: dict[str, t.Any]) -> None: assert dummy_display.stopped is True assert context._self_live_display_heartbeat_task is None + + +def test_batching_context_uses_polling_progress_fallback_when_auto_disabled( + batcher: Batcher, + caplog: pytest.LogCaptureFixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test live-display fallback logs progress at poll time when Rich is disabled.""" + monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: False) + context = BatchingContext( + batcher=batcher, + live_display=True, + ) + + caplog.set_level(level=logging.INFO, logger="batchling.context") + + context._start_live_display() + assert context._self_live_display is None + assert context._self_polling_progress_logger is not None + assert any("using polling progress INFO logs" in record.message for record in caplog.records) + + batcher._emit_event( + event_type="batch_processing", + batch_id="batch-1", + request_count=4, + source="poll_start", + ) + batcher._emit_event( + event_type="batch_polled", + batch_id="batch-1", + status="in_progress", + source="active_poll", + ) + + assert any("Live display fallback progress" in record.message for record in caplog.records) + + context._stop_live_display() + assert context._self_polling_progress_logger is None diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index ea0e797..7c76323 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -20,7 +20,7 @@ def isatty(self) -> bool: monkeypatch.setenv("TERM", "xterm-256color") monkeypatch.delenv("CI", raising=False) - assert rich_display.should_enable_live_display(mode="auto") is True + assert rich_display.should_enable_live_display(enabled=True) is True def test_should_enable_live_display_auto_disabled_in_ci(monkeypatch) -> None: @@ -34,7 +34,21 @@ def isatty(self) -> bool: monkeypatch.setenv("TERM", "xterm-256color") monkeypatch.setenv("CI", "true") - assert rich_display.should_enable_live_display(mode="auto") is False + assert rich_display.should_enable_live_display(enabled=True) is False + + +def test_should_enable_live_display_disabled_by_flag(monkeypatch) -> None: + """Test explicit disable always returns False.""" + + class DummyStderr: + def isatty(self) -> bool: + return True + + monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr()) + monkeypatch.setenv("TERM", "xterm-256color") + monkeypatch.delenv("CI", raising=False) + + assert rich_display.should_enable_live_display(enabled=False) is False def test_batcher_rich_display_computes_context_progress() -> None: From 5a1bbb2fe9755a19804ffe1084f948b70b31eaae Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 19:40:42 -0800 Subject: [PATCH 17/20] examples: add provider race streaming completion example --- examples/providers/provider_race.py | 311 ++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 examples/providers/provider_race.py diff --git a/examples/providers/provider_race.py b/examples/providers/provider_race.py new file mode 100644 index 0000000..ee5c9b8 --- /dev/null +++ b/examples/providers/provider_race.py @@ -0,0 +1,311 @@ +import asyncio +import os +import time +import typing as t +from dataclasses import dataclass + +from anthropic import AsyncAnthropic +from dotenv import load_dotenv +from groq import AsyncGroq +from mistralai import Mistral +from openai import AsyncOpenAI +from together import AsyncTogether + +from batchling import batchify + +load_dotenv() + + +@dataclass +class ProviderRaceResult: + """One provider completion entry in completion order.""" + + model: str + elapsed_seconds: float + answer: str + + +ProviderRequestBuilder = t.Callable[[], t.Coroutine[t.Any, t.Any, tuple[str, str]]] + + +async def run_openai_request(*, prompt: str) -> tuple[str, str]: + """ + Send one OpenAI request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncOpenAI(api_key=os.getenv(key="OPENAI_API_KEY")) + response = await client.responses.create( + input=prompt, + model="gpt-4o-mini", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_anthropic_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Anthropic request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncAnthropic(api_key=os.getenv(key="ANTHROPIC_API_KEY")) + response = await client.messages.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + model="claude-haiku-4-5", + ) + return response.model, response.content[0].text + + +async def run_groq_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Groq request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncGroq(api_key=os.getenv(key="GROQ_API_KEY")) + response = await client.chat.completions.create( + model="llama-3.1-8b-instant", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, response.choices[0].message.content + + +async def run_mistral_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Mistral request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = Mistral(api_key=os.getenv(key="MISTRAL_API_KEY")) + response = await client.chat.complete_async( + model="mistral-medium-2505", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + stream=False, + response_format={"type": "text"}, + ) + return response.model, str(object=response.choices[0].message.content) + + +async def run_together_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Together request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncTogether(api_key=os.getenv(key="TOGETHER_API_KEY")) + response = await client.chat.completions.create( + model="google/gemma-3n-E4B-it", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, response.choices[0].message.content + + +async def run_doubleword_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Doubleword request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncOpenAI( + api_key=os.getenv(key="DOUBLEWORD_API_KEY"), + base_url="https://api.doubleword.ai/v1", + ) + response = await client.responses.create( + input=prompt, + model="openai/gpt-oss-20b", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_provider_request( + *, + request_builder: ProviderRequestBuilder, + started_at: float, +) -> ProviderRaceResult: + """ + Execute one provider request and annotate elapsed time. + + Parameters + ---------- + request_builder : ProviderRequestBuilder + Provider request coroutine factory. + started_at : float + Shared wall-clock start time in ``perf_counter`` seconds. + + Returns + ------- + ProviderRaceResult + Result payload with answer and elapsed time. + """ + model, answer = await request_builder() + elapsed_seconds = time.perf_counter() - started_at + return ProviderRaceResult( + model=model, + elapsed_seconds=elapsed_seconds, + answer=answer, + ) + + +def build_enabled_request_builders(*, prompt: str) -> list[ProviderRequestBuilder]: + """ + Build one request factory per configured provider. + + Parameters + ---------- + prompt : str + Shared text prompt sent to all providers. + + Returns + ------- + list[ProviderRequestBuilder] + Enabled provider request factories. + """ + providers: list[tuple[str, ProviderRequestBuilder]] = [ + ( + "OPENAI_API_KEY", + lambda: run_openai_request(prompt=prompt), + ), + ( + "ANTHROPIC_API_KEY", + lambda: run_anthropic_request(prompt=prompt), + ), + ( + "GROQ_API_KEY", + lambda: run_groq_request(prompt=prompt), + ), + ( + "MISTRAL_API_KEY", + lambda: run_mistral_request(prompt=prompt), + ), + ( + "TOGETHER_API_KEY", + lambda: run_together_request(prompt=prompt), + ), + ( + "DOUBLEWORD_API_KEY", + lambda: run_doubleword_request(prompt=prompt), + ), + ] + enabled_builders: list[ProviderRequestBuilder] = [] + for env_var_name, request_builder in providers: + api_key = os.getenv(key=env_var_name) + if not api_key: + continue + enabled_builders.append(request_builder) + return enabled_builders + + +async def main() -> None: + """ + Run one request per provider and collect completion-order results. + + The race excludes Gemini and XAI on purpose because their model field + extraction differs from the other provider examples. + """ + prompt = "Give one short sentence explaining what asynchronous batching is." + request_builders = build_enabled_request_builders(prompt=prompt) + if not request_builders: + print("No providers configured. Set at least one provider API key in your environment.") + return + + started_at = time.perf_counter() + tasks = [ + asyncio.create_task( + run_provider_request( + request_builder=request_builder, + started_at=started_at, + ) + ) + for request_builder in request_builders + ] + + completion_order_register: list[ProviderRaceResult] = [] + for task in asyncio.as_completed(tasks): + result = await task + completion_order_register.append(result) + + for index, result in enumerate(completion_order_register, start=1): + print(f"{index}. model={result.model}") + print(f" elapsed={result.elapsed_seconds:.2f}s") + print(f" answer={result.answer}\n") + + +async def run_with_batchify() -> None: + """Run the provider race inside ``batchify`` for direct script execution.""" + async with batchify(): + await main() + + +if __name__ == "__main__": + asyncio.run(run_with_batchify()) From 0c8697142a28490b191376581b373ef75bf3c7ae Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 21:30:33 -0800 Subject: [PATCH 18/20] ui: stabilize queue progress formatting and remove stale rich display state --- docs/cli.md | 7 +- docs/python-sdk.md | 7 +- examples/providers/provider_race.py | 311 ---------------------------- src/batchling/rich_display.py | 177 +++++++++------- tests/test_rich_display.py | 123 ++++++----- 5 files changed, 177 insertions(+), 448 deletions(-) delete mode 100644 examples/providers/provider_race.py diff --git a/docs/cli.md b/docs/cli.md index bd59cf2..9c05cda 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -82,9 +82,10 @@ batchling generate_product_images.py:main --live-display When enabled, the panel shows overall context progress: `completed_samples / total_samples`, completion percentage, and `Time Elapsed` since the first batch seen in the context. -It also shows request counters and a pending-batches table (`batch_id`, -`provider`, `endpoint`, `model`, `status`) truncated to 5 rows -(`top 2`, `...`, `last 2`). +It also shows request counters and a queue summary table with one row per +`(provider, endpoint, model)`, including `progress` as +`completed/total (percentage)` where `completed` is terminal batches and +`total` is `running + completed`. ## Next Steps diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 7b42439..4208075 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -98,9 +98,10 @@ async with batchify(live_display=True): When enabled, the panel shows context-level progress only: `completed_samples / total_samples`, completion percentage, and `Time Elapsed` since the first batch seen in the context. -It also shows request counters and a pending-batches table (`batch_id`, -`provider`, `endpoint`, `model`, `status`) truncated to 5 rows -(`top 2`, `...`, `last 2`). +It also shows request counters and a queue summary table with one row per +`(provider, endpoint, model)`, including `progress` as +`completed/total (percentage)` where `completed` is terminal batches and +`total` is `running + completed`. You can now run this script normally using python and start saving money: diff --git a/examples/providers/provider_race.py b/examples/providers/provider_race.py deleted file mode 100644 index ee5c9b8..0000000 --- a/examples/providers/provider_race.py +++ /dev/null @@ -1,311 +0,0 @@ -import asyncio -import os -import time -import typing as t -from dataclasses import dataclass - -from anthropic import AsyncAnthropic -from dotenv import load_dotenv -from groq import AsyncGroq -from mistralai import Mistral -from openai import AsyncOpenAI -from together import AsyncTogether - -from batchling import batchify - -load_dotenv() - - -@dataclass -class ProviderRaceResult: - """One provider completion entry in completion order.""" - - model: str - elapsed_seconds: float - answer: str - - -ProviderRequestBuilder = t.Callable[[], t.Coroutine[t.Any, t.Any, tuple[str, str]]] - - -async def run_openai_request(*, prompt: str) -> tuple[str, str]: - """ - Send one OpenAI request. - - Parameters - ---------- - prompt : str - User prompt sent to the provider. - - Returns - ------- - tuple[str, str] - ``(model_name, answer_text)``. - """ - client = AsyncOpenAI(api_key=os.getenv(key="OPENAI_API_KEY")) - response = await client.responses.create( - input=prompt, - model="gpt-4o-mini", - ) - content = response.output[-1].content - return response.model, content[0].text - - -async def run_anthropic_request(*, prompt: str) -> tuple[str, str]: - """ - Send one Anthropic request. - - Parameters - ---------- - prompt : str - User prompt sent to the provider. - - Returns - ------- - tuple[str, str] - ``(model_name, answer_text)``. - """ - client = AsyncAnthropic(api_key=os.getenv(key="ANTHROPIC_API_KEY")) - response = await client.messages.create( - max_tokens=1024, - messages=[ - { - "role": "user", - "content": prompt, - } - ], - model="claude-haiku-4-5", - ) - return response.model, response.content[0].text - - -async def run_groq_request(*, prompt: str) -> tuple[str, str]: - """ - Send one Groq request. - - Parameters - ---------- - prompt : str - User prompt sent to the provider. - - Returns - ------- - tuple[str, str] - ``(model_name, answer_text)``. - """ - client = AsyncGroq(api_key=os.getenv(key="GROQ_API_KEY")) - response = await client.chat.completions.create( - model="llama-3.1-8b-instant", - messages=[ - { - "role": "user", - "content": prompt, - } - ], - ) - return response.model, response.choices[0].message.content - - -async def run_mistral_request(*, prompt: str) -> tuple[str, str]: - """ - Send one Mistral request. - - Parameters - ---------- - prompt : str - User prompt sent to the provider. - - Returns - ------- - tuple[str, str] - ``(model_name, answer_text)``. - """ - client = Mistral(api_key=os.getenv(key="MISTRAL_API_KEY")) - response = await client.chat.complete_async( - model="mistral-medium-2505", - messages=[ - { - "role": "user", - "content": prompt, - } - ], - stream=False, - response_format={"type": "text"}, - ) - return response.model, str(object=response.choices[0].message.content) - - -async def run_together_request(*, prompt: str) -> tuple[str, str]: - """ - Send one Together request. - - Parameters - ---------- - prompt : str - User prompt sent to the provider. - - Returns - ------- - tuple[str, str] - ``(model_name, answer_text)``. - """ - client = AsyncTogether(api_key=os.getenv(key="TOGETHER_API_KEY")) - response = await client.chat.completions.create( - model="google/gemma-3n-E4B-it", - messages=[ - { - "role": "user", - "content": prompt, - } - ], - ) - return response.model, response.choices[0].message.content - - -async def run_doubleword_request(*, prompt: str) -> tuple[str, str]: - """ - Send one Doubleword request. - - Parameters - ---------- - prompt : str - User prompt sent to the provider. - - Returns - ------- - tuple[str, str] - ``(model_name, answer_text)``. - """ - client = AsyncOpenAI( - api_key=os.getenv(key="DOUBLEWORD_API_KEY"), - base_url="https://api.doubleword.ai/v1", - ) - response = await client.responses.create( - input=prompt, - model="openai/gpt-oss-20b", - ) - content = response.output[-1].content - return response.model, content[0].text - - -async def run_provider_request( - *, - request_builder: ProviderRequestBuilder, - started_at: float, -) -> ProviderRaceResult: - """ - Execute one provider request and annotate elapsed time. - - Parameters - ---------- - request_builder : ProviderRequestBuilder - Provider request coroutine factory. - started_at : float - Shared wall-clock start time in ``perf_counter`` seconds. - - Returns - ------- - ProviderRaceResult - Result payload with answer and elapsed time. - """ - model, answer = await request_builder() - elapsed_seconds = time.perf_counter() - started_at - return ProviderRaceResult( - model=model, - elapsed_seconds=elapsed_seconds, - answer=answer, - ) - - -def build_enabled_request_builders(*, prompt: str) -> list[ProviderRequestBuilder]: - """ - Build one request factory per configured provider. - - Parameters - ---------- - prompt : str - Shared text prompt sent to all providers. - - Returns - ------- - list[ProviderRequestBuilder] - Enabled provider request factories. - """ - providers: list[tuple[str, ProviderRequestBuilder]] = [ - ( - "OPENAI_API_KEY", - lambda: run_openai_request(prompt=prompt), - ), - ( - "ANTHROPIC_API_KEY", - lambda: run_anthropic_request(prompt=prompt), - ), - ( - "GROQ_API_KEY", - lambda: run_groq_request(prompt=prompt), - ), - ( - "MISTRAL_API_KEY", - lambda: run_mistral_request(prompt=prompt), - ), - ( - "TOGETHER_API_KEY", - lambda: run_together_request(prompt=prompt), - ), - ( - "DOUBLEWORD_API_KEY", - lambda: run_doubleword_request(prompt=prompt), - ), - ] - enabled_builders: list[ProviderRequestBuilder] = [] - for env_var_name, request_builder in providers: - api_key = os.getenv(key=env_var_name) - if not api_key: - continue - enabled_builders.append(request_builder) - return enabled_builders - - -async def main() -> None: - """ - Run one request per provider and collect completion-order results. - - The race excludes Gemini and XAI on purpose because their model field - extraction differs from the other provider examples. - """ - prompt = "Give one short sentence explaining what asynchronous batching is." - request_builders = build_enabled_request_builders(prompt=prompt) - if not request_builders: - print("No providers configured. Set at least one provider API key in your environment.") - return - - started_at = time.perf_counter() - tasks = [ - asyncio.create_task( - run_provider_request( - request_builder=request_builder, - started_at=started_at, - ) - ) - for request_builder in request_builders - ] - - completion_order_register: list[ProviderRaceResult] = [] - for task in asyncio.as_completed(tasks): - result = await task - completion_order_register.append(result) - - for index, result in enumerate(completion_order_register, start=1): - print(f"{index}. model={result.model}") - print(f" elapsed={result.elapsed_seconds:.2f}s") - print(f" answer={result.answer}\n") - - -async def run_with_batchify() -> None: - """Run the provider race inside ``batchify`` for direct script execution.""" - async with batchify(): - await main() - - -if __name__ == "__main__": - asyncio.run(run_with_batchify()) diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index 14f638e..b4f4c50 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -26,11 +26,8 @@ class _BatchActivity: endpoint: str = "-" model: str = "-" size: int = 0 - latest_status: str = "submitted" completed: bool = False terminal: bool = False - created_at: float = 0.0 - updated_at: float = 0.0 class BatcherRichDisplay: @@ -99,26 +96,20 @@ def on_event(self, event: BatcherEvent) -> None: request_count = event.get("request_count") if isinstance(request_count, int): batch.size = max(batch.size, request_count) - batch.latest_status = "simulated" if source == "dry_run" else "submitted" batch.terminal = False elif batch_id is not None and event_type == "batch_polled": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) self._update_batch_identity(batch=batch, event=event) - status = event.get("status") - if status is not None: - batch.latest_status = str(object=status) batch.terminal = False elif batch_id is not None and event_type == "batch_terminal": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) self._update_batch_identity(batch=batch, event=event) status = str(object=event.get("status", "completed")) - batch.latest_status = status batch.completed = self._status_counts_as_completed(status=status) batch.terminal = True elif batch_id is not None and event_type == "batch_failed": batch = self._get_or_create_batch(batch_id=str(object=batch_id)) self._update_batch_identity(batch=batch, event=event) - batch.latest_status = "failed" batch.completed = False batch.terminal = True elif batch_id is not None and event_type == "cache_hit_routed" and source == "resumed_poll": @@ -126,8 +117,6 @@ def on_event(self, event: BatcherEvent) -> None: self._update_batch_identity(batch=batch, event=event) batch.size += 1 self._cached_samples += 1 - if batch.latest_status == "submitted": - batch.latest_status = "resumed" batch.terminal = False self.refresh() @@ -152,17 +141,14 @@ def _get_or_create_batch(self, *, batch_id: str) -> _BatchActivity: _BatchActivity Batch display state. """ - now = time.time() batch = self._batches.get(batch_id) if batch is None: batch = _BatchActivity( batch_id=batch_id, - created_at=now, ) self._batches[batch_id] = batch if self._first_batch_created_at is None: - self._first_batch_created_at = now - batch.updated_at = now + self._first_batch_created_at = time.time() return batch @staticmethod @@ -277,12 +263,9 @@ def _render(self) -> Panel: """Build the current Rich panel renderable.""" progress_bar = self._build_progress_bar() requests_line = self._build_requests_line() - pending_batches_line = self._build_pending_batches_line() - pending_batches_table = self._build_pending_batches_table() + queue_summary_table = self._build_queue_summary_table() return Panel( - renderable=Group( - progress_bar, requests_line, pending_batches_line, pending_batches_table - ), + renderable=Group(progress_bar, requests_line, queue_summary_table), title="batchling context progress", border_style="cyan", ) @@ -292,13 +275,14 @@ def _build_progress_bar(self) -> Progress: completed_samples, total_samples, _ = self._compute_progress() elapsed_seconds = self._compute_elapsed_seconds() elapsed_label = self._format_elapsed(elapsed_seconds=elapsed_seconds) + sample_width = max(1, len(str(object=total_samples))) progress = Progress( BarColumn(bar_width=None), TextColumn( text_format=( - "[bold green]{task.fields[completed_samples]}[/bold green]/" - "[bold cyan]{task.fields[total_samples]}[/bold cyan] " + f"[bold green]{{task.fields[completed_samples]:>{sample_width}}}[/bold green]/" + f"[bold cyan]{{task.fields[total_samples]:>{sample_width}}}[/bold cyan] " "([bold green]{task.percentage:.1f}%[/bold green])" ) ), @@ -347,79 +331,116 @@ def _build_requests_line(self) -> Text: line.append(text=str(object=in_progress_samples), style="bold yellow") return line - def _get_pending_batches(self) -> list[_BatchActivity]: + def _compute_queue_batch_counts(self) -> list[tuple[str, str, str, int, int]]: """ - Return pending (non-terminal) batches sorted by oldest first. + Aggregate queue-level running and terminal batch counts. Returns ------- - list[_BatchActivity] - Sorted pending batches. + list[tuple[str, str, str, int, int]] + Sorted rows as ``(provider, endpoint, model, running, completed)``. """ - pending_batches = [batch for batch in self._batches.values() if not batch.terminal] - return sorted(pending_batches, key=lambda batch: batch.created_at) - - def _build_pending_batches_line(self) -> Text: + counts_by_queue: dict[tuple[str, str, str], list[int]] = {} + for batch in self._batches.values(): + queue_key = (batch.provider, batch.endpoint, batch.model) + counters = counts_by_queue.setdefault(queue_key, [0, 0]) + if batch.terminal: + counters[1] += 1 + else: + counters[0] += 1 + + rows = [ + (provider, endpoint, model, counters[0], counters[1]) + for (provider, endpoint, model), counters in counts_by_queue.items() + ] + return sorted(rows, key=lambda row: (row[0], row[1], row[2])) + + def _build_queue_summary_table(self) -> Table: """ - Build one-line pending batches summary shown above the table. + Build queue-level table with per-queue progress summary. Returns ------- - Text - Styled summary line. + Table + Queue summary table. """ - pending_count = len(self._get_pending_batches()) - line = Text() - line.append(text="Pending batches", style="bold white") - line.append(text=": ", style="white") - line.append(text=str(object=pending_count), style="bold yellow") - return line + queue_rows = self._compute_queue_batch_counts() + table = Table(expand=False) + table.add_column( + header="provider", + style="bold blue", + width=12, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="endpoint", + width=34, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="model", + style="bold magenta", + width=28, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="progress", + justify="right", + width=16, + no_wrap=True, + overflow="ellipsis", + ) - def _build_pending_batches_table(self) -> Table: + if not queue_rows: + table.add_row("-", "-", "-", self._format_queue_progress(running=0, completed=0)) + return table + + for provider, endpoint, model, running, completed in queue_rows: + table.add_row( + provider, + endpoint, + model, + self._format_queue_progress( + running=running, + completed=completed, + ), + ) + return table + + @staticmethod + def _format_queue_progress(*, running: int, completed: int) -> Text: """ - Build pending-batches table with dataframe-style truncation. + Format one queue progress cell as ``completed/total (percent)``. + + Parameters + ---------- + running : int + Number of non-terminal batches. + completed : int + Number of terminal batches. Returns ------- - Table - Pending batches table. + Text + Formatted queue progress. """ - pending_batches = self._get_pending_batches() - table = Table(expand=True) - table.add_column(header="batch_id", style="bold") - table.add_column(header="provider") - table.add_column(header="endpoint") - table.add_column(header="model") - table.add_column(header="status") - - if not pending_batches: - table.add_row("-", "-", "-", "-", "-") - return table - - display_batches: list[_BatchActivity | None] - if len(pending_batches) <= 5: - display_batches = [*pending_batches] + total = running + completed + if total <= 0: + percent = 0.0 else: - display_batches = [ - pending_batches[0], - pending_batches[1], - None, - pending_batches[-2], - pending_batches[-1], - ] - - for batch in display_batches: - if batch is None: - table.add_row("...", "...", "...", "...", "...") - continue - table.add_row( - batch.batch_id, - batch.provider, - batch.endpoint, - batch.model, - batch.latest_status, - ) - return table + percent = (completed / total) * 100.0 + count_width = max(1, len(str(object=total))) + progress = Text() + progress.append(text=f"{completed:>{count_width}}", style="bold green") + progress.append(text="/", style="white") + progress.append(text=f"{total:>{count_width}}", style="bold cyan") + progress.append(text=" (", style="white") + progress.append(text=f"{percent:.1f}%", style="bold green") + progress.append(text=")", style="white") + return progress def should_enable_live_display(*, enabled: bool) -> bool: diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py index 7c76323..2d37959 100644 --- a/tests/test_rich_display.py +++ b/tests/test_rich_display.py @@ -3,6 +3,7 @@ import io from rich.console import Console +from rich.text import Text import batchling.rich_display as rich_display @@ -273,81 +274,97 @@ def test_batcher_rich_display_request_metrics_line() -> None: assert in_progress_samples == 2 -def test_batcher_rich_display_pending_batches_table_truncates_with_ellipsis() -> None: - """Test pending table shows top2/ellipsis/last2 for more than five rows.""" +def test_batcher_rich_display_queue_table_progress_column() -> None: + """Test queue table progress column derives from completed/total counts.""" display = rich_display.BatcherRichDisplay( console=Console(file=io.StringIO(), force_terminal=False), ) - for batch_index in range(1, 7): - processing_event: rich_display.BatcherEvent = { - "event_type": "batch_processing", - "timestamp": float(batch_index), - "provider": "openai", - "endpoint": f"/v1/endpoint/{batch_index}", - "model": f"model-{batch_index}", - "queue_key": ("openai", f"/v1/endpoint/{batch_index}", f"model-{batch_index}"), - "batch_id": f"batch-{batch_index}", - "request_count": batch_index, - "source": "poll_start", - } - display.on_event(processing_event) - - table = display._build_pending_batches_table() - batch_id_cells = table.columns[0]._cells - - assert len(batch_id_cells) == 5 - assert batch_id_cells[0] == "batch-1" - assert batch_id_cells[1] == "batch-2" - assert batch_id_cells[2] == "..." - assert batch_id_cells[3] == "batch-5" - assert batch_id_cells[4] == "batch-6" - - -def test_batcher_rich_display_pending_batches_excludes_terminal() -> None: - """Test pending table includes only non-terminal batches.""" - display = rich_display.BatcherRichDisplay( - console=Console(file=io.StringIO(), force_terminal=False), - ) - - pending_event: rich_display.BatcherEvent = { + queue_event_batch_1: rich_display.BatcherEvent = { "event_type": "batch_processing", "timestamp": 1.0, "provider": "openai", - "endpoint": "/v1/pending", - "model": "model-pending", - "queue_key": ("openai", "/v1/pending", "model-pending"), - "batch_id": "batch-pending", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-1", "request_count": 1, "source": "poll_start", } - completed_event: rich_display.BatcherEvent = { + queue_event_batch_2: rich_display.BatcherEvent = { "event_type": "batch_processing", "timestamp": 2.0, "provider": "openai", - "endpoint": "/v1/completed", - "model": "model-completed", - "queue_key": ("openai", "/v1/completed", "model-completed"), - "batch_id": "batch-completed", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-2", "request_count": 1, "source": "poll_start", } - terminal_event: rich_display.BatcherEvent = { + terminal_event_batch_2: rich_display.BatcherEvent = { "event_type": "batch_terminal", "timestamp": 3.0, "provider": "openai", - "batch_id": "batch-completed", + "batch_id": "batch-2", "status": "completed", "source": "active_poll", } + other_queue_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 4.0, + "provider": "groq", + "endpoint": "/openai/v1/chat/completions", + "model": "llama-3.1-8b-instant", + "queue_key": ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant"), + "batch_id": "batch-3", + "request_count": 1, + "source": "poll_start", + } - display.on_event(pending_event) - display.on_event(completed_event) - display.on_event(terminal_event) + display.on_event(queue_event_batch_1) + display.on_event(queue_event_batch_2) + display.on_event(terminal_event_batch_2) + display.on_event(other_queue_event) + + queue_rows = display._compute_queue_batch_counts() + assert queue_rows == [ + ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant", 1, 0), + ("openai", "/v1/chat/completions", "model-a", 1, 1), + ] + + table = display._build_queue_summary_table() + assert table.columns[0].width == 12 + assert table.columns[1].width == 34 + assert table.columns[2].width == 28 + assert table.columns[3].width == 16 + assert table.columns[0]._cells == ["groq", "openai"] + progress_cells = table.columns[3]._cells + assert isinstance(progress_cells[0], Text) + assert isinstance(progress_cells[1], Text) + assert progress_cells[0].plain == "0/1 (0.0%)" + assert progress_cells[1].plain == "1/2 (50.0%)" + + +def test_batcher_rich_display_queue_table_empty_state() -> None: + """Test queue table renders default row when no batches are tracked.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + table = display._build_queue_summary_table() + assert table.columns[0]._cells == ["-"] + assert table.columns[1]._cells == ["-"] + assert table.columns[2]._cells == ["-"] + progress_cells = table.columns[3]._cells + assert isinstance(progress_cells[0], Text) + assert progress_cells[0].plain == "0/0 (0.0%)" - pending_batches = display._get_pending_batches() - assert len(pending_batches) == 1 - assert pending_batches[0].batch_id == "batch-pending" - pending_line = display._build_pending_batches_line() - assert "Pending batches: 1" in pending_line.plain +def test_batcher_rich_display_queue_progress_pads_to_total_width() -> None: + """Test queue progress keeps parenthesis anchor stable for large totals.""" + progress_text = rich_display.BatcherRichDisplay._format_queue_progress( + running=99, + completed=1, + ) + assert progress_text.plain == " 1/100 (1.0%)" From 9c275c4351b4a4ee861ba028e8fae2331a4a94e6 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 21:39:23 -0800 Subject: [PATCH 19/20] examples: racing --- examples/racing.py | 278 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 examples/racing.py diff --git a/examples/racing.py b/examples/racing.py new file mode 100644 index 0000000..22ede9a --- /dev/null +++ b/examples/racing.py @@ -0,0 +1,278 @@ +import asyncio +import os +import time +import typing as t +from dataclasses import dataclass + +from dotenv import load_dotenv +from groq import AsyncGroq +from mistralai import Mistral +from openai import AsyncOpenAI +from together import AsyncTogether + +from batchling import batchify + +load_dotenv() + + +@dataclass +class ProviderRaceResult: + """One provider completion entry in completion order.""" + + model: str + elapsed_seconds: float + answer: str + + +ProviderRequestBuilder = t.Callable[[], t.Coroutine[t.Any, t.Any, tuple[str, str]]] + + +async def run_openai_request(*, prompt: str) -> tuple[str, str]: + """ + Send one OpenAI request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncOpenAI(api_key=os.getenv(key="OPENAI_API_KEY")) + response = await client.responses.create( + input=prompt, + model="gpt-4o-mini", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_groq_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Groq request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncGroq(api_key=os.getenv(key="GROQ_API_KEY")) + response = await client.chat.completions.create( + model="llama-3.1-8b-instant", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, response.choices[0].message.content + + +async def run_mistral_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Mistral request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = Mistral(api_key=os.getenv(key="MISTRAL_API_KEY")) + response = await client.chat.complete_async( + model="mistral-medium-2505", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + stream=False, + response_format={"type": "text"}, + ) + return response.model, str(object=response.choices[0].message.content) + + +async def run_together_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Together request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncTogether(api_key=os.getenv(key="TOGETHER_API_KEY")) + response = await client.chat.completions.create( + model="google/gemma-3n-E4B-it", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, response.choices[0].message.content + + +async def run_doubleword_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Doubleword request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncOpenAI( + api_key=os.getenv(key="DOUBLEWORD_API_KEY"), + base_url="https://api.doubleword.ai/v1", + ) + response = await client.responses.create( + input=prompt, + model="openai/gpt-oss-20b", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_provider_request( + *, + request_builder: ProviderRequestBuilder, + started_at: float, +) -> ProviderRaceResult: + """ + Execute one provider request and annotate elapsed time. + + Parameters + ---------- + request_builder : ProviderRequestBuilder + Provider request coroutine factory. + started_at : float + Shared wall-clock start time in ``perf_counter`` seconds. + + Returns + ------- + ProviderRaceResult + Result payload with answer and elapsed time. + """ + model, answer = await request_builder() + elapsed_seconds = time.perf_counter() - started_at + return ProviderRaceResult( + model=model, + elapsed_seconds=elapsed_seconds, + answer=answer, + ) + + +def build_enabled_request_builders(*, prompt: str) -> list[ProviderRequestBuilder]: + """ + Build one request factory per configured provider. + + Parameters + ---------- + prompt : str + Shared text prompt sent to all providers. + + Returns + ------- + list[ProviderRequestBuilder] + Enabled provider request factories. + """ + providers: list[tuple[str, ProviderRequestBuilder]] = [ + ( + "OPENAI_API_KEY", + lambda: run_openai_request(prompt=prompt), + ), + ( + "GROQ_API_KEY", + lambda: run_groq_request(prompt=prompt), + ), + ( + "MISTRAL_API_KEY", + lambda: run_mistral_request(prompt=prompt), + ), + ( + "TOGETHER_API_KEY", + lambda: run_together_request(prompt=prompt), + ), + ( + "DOUBLEWORD_API_KEY", + lambda: run_doubleword_request(prompt=prompt), + ), + ] + enabled_builders: list[ProviderRequestBuilder] = [] + for env_var_name, request_builder in providers: + api_key = os.getenv(key=env_var_name) + if not api_key: + continue + enabled_builders.append(request_builder) + return enabled_builders + + +async def main() -> None: + """ + Run one request per provider and collect completion-order results. + + The race excludes Anthropic, Gemini, and XAI on purpose because their model field + extraction differs from the other provider examples. + """ + prompt = "Give one short sentence explaining what asynchronous batching is." + request_builders = build_enabled_request_builders(prompt=prompt) + if not request_builders: + print("No providers configured. Set at least one provider API key in your environment.") + return + + started_at = time.perf_counter() + tasks = [ + asyncio.create_task( + run_provider_request( + request_builder=request_builder, + started_at=started_at, + ) + ) + for request_builder in request_builders + ] + + completion_order_register: list[ProviderRaceResult] = [] + for task in asyncio.as_completed(tasks): + result = await task + completion_order_register.append(result) + + for index, result in enumerate(completion_order_register, start=1): + print(f"{index}. model={result.model}") + print(f" elapsed={result.elapsed_seconds:.2f}s") + print(f" answer={result.answer}\n") + + +async def run_with_batchify() -> None: + """Run the provider race inside ``batchify`` for direct script execution.""" + async with batchify(): + await main() + + +if __name__ == "__main__": + asyncio.run(run_with_batchify()) From f18a1ec07cbc8d239f51e82d39c58d0a8bbef9d2 Mon Sep 17 00:00:00 2001 From: Raphael Date: Mon, 2 Mar 2026 21:45:46 -0800 Subject: [PATCH 20/20] fix: refactor shared code --- src/batchling/context.py | 101 +-------------- src/batchling/progress_state.py | 223 ++++++++++++++++++++++++++++++++ src/batchling/rich_display.py | 153 +--------------------- 3 files changed, 236 insertions(+), 241 deletions(-) create mode 100644 src/batchling/progress_state.py diff --git a/src/batchling/context.py b/src/batchling/context.py index 6687a90..82ccf12 100644 --- a/src/batchling/context.py +++ b/src/batchling/context.py @@ -6,11 +6,11 @@ import logging import typing as t import warnings -from dataclasses import dataclass from batchling.core import Batcher, BatcherEvent from batchling.hooks import active_batcher from batchling.logging import log_info +from batchling.progress_state import BatchProgressState from batchling.rich_display import ( BatcherRichDisplay, should_enable_live_display, @@ -19,82 +19,11 @@ log = logging.getLogger(name=__name__) -@dataclass -class _ProgressLogBatchState: - """Aggregate state used by the polling progress logger fallback.""" - - size: int = 0 - completed: bool = False - terminal: bool = False - - class _PollingProgressLogger: """INFO logger fallback used when Rich live display auto-disables.""" def __init__(self) -> None: - self._state_by_batch_id: dict[str, _ProgressLogBatchState] = {} - - @staticmethod - def _status_counts_as_completed(*, status: str) -> bool: - """ - Determine whether a terminal status counts as completed samples. - - Parameters - ---------- - status : str - Terminal provider status. - - Returns - ------- - bool - ``True`` when terminal state should contribute to completed samples. - """ - lowered_status = status.lower() - negative_markers = ("fail", "error", "cancel", "expired", "timeout") - if any(marker in lowered_status for marker in negative_markers): - return False - return True - - def _compute_progress(self) -> tuple[int, int, float, int]: - """ - Compute aggregate sample progress from tracked batch states. - - Returns - ------- - tuple[int, int, float, int] - ``(completed_samples, total_samples, percent, in_progress_samples)``. - """ - total_samples = sum(state.size for state in self._state_by_batch_id.values()) - completed_samples = sum( - state.size for state in self._state_by_batch_id.values() if state.completed - ) - in_progress_samples = sum( - state.size for state in self._state_by_batch_id.values() if not state.terminal - ) - if total_samples <= 0: - return 0, 0, 0.0, in_progress_samples - percent = (completed_samples / total_samples) * 100.0 - return completed_samples, total_samples, percent, in_progress_samples - - def _get_or_create_batch_state(self, *, batch_id: str) -> _ProgressLogBatchState: - """ - Get or create progress state for one batch. - - Parameters - ---------- - batch_id : str - Provider batch identifier. - - Returns - ------- - _ProgressLogBatchState - Mutable batch state. - """ - state = self._state_by_batch_id.get(batch_id) - if state is None: - state = _ProgressLogBatchState() - self._state_by_batch_id[batch_id] = state - return state + self._progress_state = BatchProgressState() def on_event(self, event: BatcherEvent) -> None: """ @@ -105,32 +34,14 @@ def on_event(self, event: BatcherEvent) -> None: event : BatcherEvent Lifecycle event emitted by ``Batcher``. """ - event_type = str(object=event.get("event_type", "unknown")) - batch_id = event.get("batch_id") - if batch_id is not None: - state = self._get_or_create_batch_state(batch_id=str(object=batch_id)) - if event_type == "batch_processing": - request_count = event.get("request_count") - if isinstance(request_count, int): - state.size = max(state.size, request_count) - state.terminal = False - elif event_type == "cache_hit_routed" and str(object=event.get("source", "")) == ( - "resumed_poll" - ): - state.size += 1 - state.terminal = False - elif event_type == "batch_terminal": - status = str(object=event.get("status", "completed")) - state.completed = self._status_counts_as_completed(status=status) - state.terminal = True - elif event_type == "batch_failed": - state.completed = False - state.terminal = True + self._progress_state.on_event(event=event) + event_type = str(object=event.get("event_type", "unknown")) if event_type != "batch_polled": return - completed_samples, total_samples, percent, in_progress_samples = self._compute_progress() + completed_samples, total_samples, percent = self._progress_state.compute_progress() + _, _, _, in_progress_samples = self._progress_state.compute_request_metrics() log_info( logger=log, event="Live display fallback progress", diff --git a/src/batchling/progress_state.py b/src/batchling/progress_state.py new file mode 100644 index 0000000..4a9b0c6 --- /dev/null +++ b/src/batchling/progress_state.py @@ -0,0 +1,223 @@ +"""Shared progress-state tracking for live display and fallback logging.""" + +from __future__ import annotations + +import time +import typing as t +from dataclasses import dataclass + +from batchling.core import BatcherEvent + + +@dataclass +class _TrackedBatch: + """In-memory batch state used for aggregate progress computations.""" + + batch_id: str + provider: str = "-" + endpoint: str = "-" + model: str = "-" + size: int = 0 + completed: bool = False + terminal: bool = False + + +class BatchProgressState: + """ + Track batch lifecycle state and compute shared aggregate metrics. + + Parameters + ---------- + now_fn : typing.Callable[[], float] | None, optional + Clock function used for elapsed-time calculations. + """ + + def __init__( + self, + *, + now_fn: t.Callable[[], float] | None = None, + ) -> None: + self._now_fn = now_fn or time.time + self._batches: dict[str, _TrackedBatch] = {} + self._cached_samples = 0 + self._first_batch_created_at: float | None = None + + def on_event(self, *, event: BatcherEvent) -> None: + """ + Update tracked state from one lifecycle event. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + event_type = str(object=event.get("event_type", "unknown")) + source = str(object=event.get("source", "")) + batch_id = event.get("batch_id") + + if batch_id is None: + return + + batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) + + if event_type == "batch_processing": + request_count = event.get("request_count") + if isinstance(request_count, int): + batch.size = max(batch.size, request_count) + batch.terminal = False + return + + if event_type == "batch_polled": + batch.terminal = False + return + + if event_type == "batch_terminal": + status = str(object=event.get("status", "completed")) + batch.completed = self._status_counts_as_completed(status=status) + batch.terminal = True + return + + if event_type == "batch_failed": + batch.completed = False + batch.terminal = True + return + + if event_type == "cache_hit_routed" and source == "resumed_poll": + batch.size += 1 + self._cached_samples += 1 + batch.terminal = False + + def compute_progress(self) -> tuple[int, int, float]: + """ + Compute aggregate sample progress from tracked batches. + + Returns + ------- + tuple[int, int, float] + ``(completed_samples, total_samples, percent)``. + """ + total_samples = sum(batch.size for batch in self._batches.values()) + completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) + if total_samples <= 0: + return 0, 0, 0.0 + percent = (completed_samples / total_samples) * 100.0 + return completed_samples, total_samples, percent + + def compute_request_metrics(self) -> tuple[int, int, int, int]: + """ + Compute aggregate request counters from tracked batches. + + Returns + ------- + tuple[int, int, int, int] + ``(total_samples, cached_samples, completed_samples, in_progress_samples)``. + """ + total_samples = sum(batch.size for batch in self._batches.values()) + completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) + in_progress_samples = sum( + batch.size for batch in self._batches.values() if not batch.terminal + ) + return total_samples, self._cached_samples, completed_samples, in_progress_samples + + def compute_queue_batch_counts(self) -> list[tuple[str, str, str, int, int]]: + """ + Aggregate queue-level running and terminal batch counts. + + Returns + ------- + list[tuple[str, str, str, int, int]] + Sorted rows as ``(provider, endpoint, model, running, completed)``. + """ + counts_by_queue: dict[tuple[str, str, str], list[int]] = {} + for batch in self._batches.values(): + queue_key = (batch.provider, batch.endpoint, batch.model) + counters = counts_by_queue.setdefault(queue_key, [0, 0]) + if batch.terminal: + counters[1] += 1 + else: + counters[0] += 1 + + rows = [ + (provider, endpoint, model, counters[0], counters[1]) + for (provider, endpoint, model), counters in counts_by_queue.items() + ] + return sorted(rows, key=lambda row: (row[0], row[1], row[2])) + + def compute_elapsed_seconds(self) -> int: + """ + Compute elapsed seconds since first tracked batch in this context. + + Returns + ------- + int + Elapsed seconds. + """ + if self._first_batch_created_at is None: + return 0 + return max(0, int(self._now_fn() - self._first_batch_created_at)) + + def _get_or_create_batch(self, *, batch_id: str) -> _TrackedBatch: + """ + Get or create one tracked batch record. + + Parameters + ---------- + batch_id : str + Provider batch identifier. + + Returns + ------- + _TrackedBatch + Mutable tracked batch. + """ + batch = self._batches.get(batch_id) + if batch is None: + batch = _TrackedBatch(batch_id=batch_id) + self._batches[batch_id] = batch + if self._first_batch_created_at is None: + self._first_batch_created_at = self._now_fn() + return batch + + @staticmethod + def _update_batch_identity(*, batch: _TrackedBatch, event: BatcherEvent) -> None: + """ + Update batch metadata from lifecycle event payload. + + Parameters + ---------- + batch : _TrackedBatch + Mutable tracked batch. + event : BatcherEvent + Lifecycle event payload. + """ + provider = event.get("provider") + endpoint = event.get("endpoint") + model = event.get("model") + if provider is not None: + batch.provider = str(object=provider) + if endpoint is not None: + batch.endpoint = str(object=endpoint) + if model is not None: + batch.model = str(object=model) + + @staticmethod + def _status_counts_as_completed(*, status: str) -> bool: + """ + Determine whether a terminal status counts as completed samples. + + Parameters + ---------- + status : str + Terminal provider status. + + Returns + ------- + bool + ``True`` when terminal state should contribute to completed samples. + """ + lowered_status = status.lower() + negative_markers = ("fail", "error", "cancel", "expired", "timeout") + if any(marker in lowered_status for marker in negative_markers): + return False + return True diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py index b4f4c50..0060755 100644 --- a/src/batchling/rich_display.py +++ b/src/batchling/rich_display.py @@ -5,7 +5,6 @@ import os import sys import time -from dataclasses import dataclass from rich.console import Console, Group from rich.live import Live @@ -15,19 +14,7 @@ from rich.text import Text from batchling.core import BatcherEvent - - -@dataclass -class _BatchActivity: - """In-memory batch activity snapshot for progress aggregation.""" - - batch_id: str - provider: str = "-" - endpoint: str = "-" - model: str = "-" - size: int = 0 - completed: bool = False - terminal: bool = False +from batchling.progress_state import BatchProgressState class BatcherRichDisplay: @@ -53,9 +40,7 @@ def __init__( ) -> None: self._console = console or Console(stderr=True) self._refresh_per_second = refresh_per_second - self._batches: dict[str, _BatchActivity] = {} - self._cached_samples = 0 - self._first_batch_created_at: float | None = None + self._progress_state = BatchProgressState(now_fn=time.time) self._live: Live | None = None def start(self) -> None: @@ -86,39 +71,7 @@ def on_event(self, event: BatcherEvent) -> None: event : BatcherEvent Lifecycle event emitted by ``Batcher``. """ - event_type = str(object=event.get("event_type", "unknown")) - source = str(object=event.get("source", "")) - batch_id = event.get("batch_id") - - if batch_id is not None and event_type == "batch_processing": - batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) - request_count = event.get("request_count") - if isinstance(request_count, int): - batch.size = max(batch.size, request_count) - batch.terminal = False - elif batch_id is not None and event_type == "batch_polled": - batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) - batch.terminal = False - elif batch_id is not None and event_type == "batch_terminal": - batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) - status = str(object=event.get("status", "completed")) - batch.completed = self._status_counts_as_completed(status=status) - batch.terminal = True - elif batch_id is not None and event_type == "batch_failed": - batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) - batch.completed = False - batch.terminal = True - elif batch_id is not None and event_type == "cache_hit_routed" and source == "resumed_poll": - batch = self._get_or_create_batch(batch_id=str(object=batch_id)) - self._update_batch_identity(batch=batch, event=event) - batch.size += 1 - self._cached_samples += 1 - batch.terminal = False - + self._progress_state.on_event(event=event) self.refresh() def refresh(self) -> None: @@ -127,73 +80,6 @@ def refresh(self) -> None: return self._live.update(renderable=self._render(), refresh=True) - def _get_or_create_batch(self, *, batch_id: str) -> _BatchActivity: - """ - Fetch or create batch display state. - - Parameters - ---------- - batch_id : str - Provider batch identifier. - - Returns - ------- - _BatchActivity - Batch display state. - """ - batch = self._batches.get(batch_id) - if batch is None: - batch = _BatchActivity( - batch_id=batch_id, - ) - self._batches[batch_id] = batch - if self._first_batch_created_at is None: - self._first_batch_created_at = time.time() - return batch - - @staticmethod - def _update_batch_identity(*, batch: _BatchActivity, event: BatcherEvent) -> None: - """ - Update batch metadata from lifecycle event payload. - - Parameters - ---------- - batch : _BatchActivity - Mutable batch row. - event : BatcherEvent - Lifecycle event payload. - """ - provider = event.get("provider") - endpoint = event.get("endpoint") - model = event.get("model") - if provider is not None: - batch.provider = str(object=provider) - if endpoint is not None: - batch.endpoint = str(object=endpoint) - if model is not None: - batch.model = str(object=model) - - @staticmethod - def _status_counts_as_completed(*, status: str) -> bool: - """ - Determine whether a terminal status counts as completed samples. - - Parameters - ---------- - status : str - Terminal provider status. - - Returns - ------- - bool - ``True`` when terminal state should contribute to completed samples. - """ - lowered_status = status.lower() - negative_markers = ("fail", "error", "cancel", "expired", "timeout") - if any(marker in lowered_status for marker in negative_markers): - return False - return True - def _compute_progress(self) -> tuple[int, int, float]: """ Compute aggregate context progress from tracked batches. @@ -203,12 +89,7 @@ def _compute_progress(self) -> tuple[int, int, float]: tuple[int, int, float] ``(completed_samples, total_samples, percent)``. """ - total_samples = sum(batch.size for batch in self._batches.values()) - completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) - if total_samples <= 0: - return 0, 0, 0.0 - percent = (completed_samples / total_samples) * 100 - return completed_samples, total_samples, percent + return self._progress_state.compute_progress() def _compute_request_metrics(self) -> tuple[int, int, int, int]: """ @@ -219,12 +100,7 @@ def _compute_request_metrics(self) -> tuple[int, int, int, int]: tuple[int, int, int, int] ``(total_samples, cached_samples, completed_samples, in_progress_samples)``. """ - total_samples = sum(batch.size for batch in self._batches.values()) - completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) - in_progress_samples = sum( - batch.size for batch in self._batches.values() if not batch.terminal - ) - return total_samples, self._cached_samples, completed_samples, in_progress_samples + return self._progress_state.compute_request_metrics() def _compute_elapsed_seconds(self) -> int: """ @@ -235,9 +111,7 @@ def _compute_elapsed_seconds(self) -> int: int Elapsed seconds. """ - if self._first_batch_created_at is None: - return 0 - return max(0, int(time.time() - self._first_batch_created_at)) + return self._progress_state.compute_elapsed_seconds() @staticmethod def _format_elapsed(*, elapsed_seconds: int) -> str: @@ -340,20 +214,7 @@ def _compute_queue_batch_counts(self) -> list[tuple[str, str, str, int, int]]: list[tuple[str, str, str, int, int]] Sorted rows as ``(provider, endpoint, model, running, completed)``. """ - counts_by_queue: dict[tuple[str, str, str], list[int]] = {} - for batch in self._batches.values(): - queue_key = (batch.provider, batch.endpoint, batch.model) - counters = counts_by_queue.setdefault(queue_key, [0, 0]) - if batch.terminal: - counters[1] += 1 - else: - counters[0] += 1 - - rows = [ - (provider, endpoint, model, counters[0], counters[1]) - for (provider, endpoint, model), counters in counts_by_queue.items() - ] - return sorted(rows, key=lambda row: (row[0], row[1], row[2])) + return self._progress_state.compute_queue_batch_counts() def _build_queue_summary_table(self) -> Table: """