diff --git a/docs/architecture/api.md b/docs/architecture/api.md
index ca251fe..637636a 100644
--- a/docs/architecture/api.md
+++ b/docs/architecture/api.md
@@ -9,7 +9,7 @@ yields `None`. Import it from `batchling`.
- Install HTTP hooks once (idempotent).
- Construct a `Batcher` with configuration such as `batch_size`,
`batch_window_seconds`, `batch_poll_interval_seconds`, `dry_run`,
- and `cache`.
+ `cache`, and `live_display`.
- Configure `batchling` logging defaults with Python's stdlib `logging`
(`WARNING` by default).
- Return a `BatchingContext` to scope batching to a context manager.
@@ -26,6 +26,12 @@ yields `None`. Import it from `batchling`.
- **`cache` behavior**: when `cache=True` (default), intercepted requests are fingerprinted
and looked up in a persistent request cache. Cache hits bypass queueing and resume polling
from an existing provider batch when not in dry-run mode.
+- **`live_display` behavior**: `live_display` is a boolean.
+ When `True` (default), Rich panel rendering runs in auto mode and is enabled
+ only when `stderr` is a TTY, the terminal is not `dumb`, and `CI` is unset.
+ If auto mode disables Rich, context-level progress is logged at `INFO` on
+ polling events.
+ When `False`, both the live display and the fallback progress logs are disabled.
- **Outputs**: `BatchingContext[None]` instance that yields `None`.
- **Logging**: lifecycle milestones are emitted at `INFO`, problems at
`WARNING`/`ERROR`, and high-volume diagnostics at `DEBUG`. Request payloads
@@ -43,7 +49,7 @@ Behavior:
- CLI options map directly to `batchify` arguments:
`batch_size`, `batch_window_seconds`, `batch_poll_interval_seconds`, `dry_run`,
- and `cache`.
+ `cache`, and `live_display`.
- Script target must use `module_path:function_name` syntax.
- Forwarded callable arguments are mapped as:
positional tokens are passed as positional arguments;
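The auto-mode terminal detection documented for `live_display` can be sketched as follows. This is a hypothetical stand-in, not the library's internal implementation; it only assumes the documented TTY / `TERM=dumb` / `CI` rules:

```python
import os
import sys


def should_enable_live_display(*, enabled: bool) -> bool:
    """Return True only when an interactive terminal is detected."""
    if not enabled:
        return False
    if not sys.stderr.isatty():  # panels render on stderr, so require a TTY there
        return False
    if os.environ.get("TERM") == "dumb":  # dumb terminals cannot render panels
        return False
    if "CI" in os.environ:  # CI environments fall back to INFO logs
        return False
    return True
```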
diff --git a/docs/architecture/context.md b/docs/architecture/context.md
index 78f67a9..8e67530 100644
--- a/docs/architecture/context.md
+++ b/docs/architecture/context.md
@@ -9,6 +9,7 @@ a context variable.
- Activate the `active_batcher` context for the duration of a context block.
- Yield `None` for scope-only lifecycle control.
- Support sync and async context manager patterns for cleanup and context scoping.
+- Start and stop optional Rich live activity display while the context is active.
## Flow summary
@@ -16,7 +17,12 @@ a context variable.
2. `__enter__`/`__aenter__` set the active batcher for the entire context block.
3. `__exit__` resets the context and schedules `batcher.close()` if an event loop is
running (otherwise it warns).
-4. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work.
+4. If `live_display=True`, the context attempts to start Rich panel rendering at
+ enter time when terminal auto-detection passes (`TTY`, non-`dumb`, non-`CI`);
+ otherwise it registers an `INFO` logging fallback that emits progress at poll time.
+5. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work.
+6. The live display listener is removed and the panel is stopped when context cleanup
+ finishes.
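A minimal sketch of that enter/exit lifecycle, with hypothetical names (the real context additionally manages the Rich panel, the fallback logger, and a heartbeat task):

```python
import asyncio


class MiniBatchingContext:
    """Toy async context mirroring the flow summary above."""

    def __init__(self, *, batcher, live_display: bool = True) -> None:
        self.batcher = batcher
        self.live_display = live_display
        self.display_active = False

    async def __aenter__(self) -> None:
        if self.live_display:
            self.display_active = True  # stand-in for starting the panel or fallback
        return None

    async def __aexit__(self, *exc_info) -> None:
        try:
            await self.batcher.close()  # flush pending work
        finally:
            self.display_active = False  # listener removed, panel stopped last
```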
## Code reference
diff --git a/docs/cli.md b/docs/cli.md
index 692a1ec..9c05cda 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -64,6 +64,29 @@ batchling generate_product_images.py:main
That's it! Just run that command and you save 50% off your workflow.
+## Live visibility panel
+
+The CLI also exposes the live Rich panel control:
+
+```bash
+batchling generate_product_images.py:main --live-display
+```
+
+`--live-display`/`--no-live-display` is a paired boolean flag:
+
+- `--live-display` (default): auto mode, Rich panel only in interactive terminals
+ (`TTY`, non-`dumb`, non-`CI`). If Rich auto-disables, progress is emitted as
+ `INFO` logs on polling events.
+- `--no-live-display`: disable both Rich panel and fallback progress logs.
+
+When enabled, the panel shows overall context progress:
+`completed_samples / total_samples`, completion percentage, and `Time Elapsed`
+since the first batch seen in the context.
+It also shows request counters and a queue summary table with one row per
+`(provider, endpoint, model)`, including `progress` as
+`completed/total (percentage)` where `completed` is terminal batches and
+`total` is `running + completed`.
+
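The `completed/total (percentage)` progress format described above can be reproduced with a small helper. This is illustrative only; the actual panel is rendered by Rich:

```python
def format_progress(*, completed: int, total: int) -> str:
    """Render progress as 'completed/total (percentage)'."""
    if total == 0:
        return "0/0 (0.0%)"  # avoid division by zero before any batch is seen
    percent = completed / total * 100
    return f"{completed}/{total} ({percent:.1f}%)"
```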
## Next Steps
If you haven't yet, look at how you can:
diff --git a/docs/python-sdk.md b/docs/python-sdk.md
index 4a00228..4208075 100644
--- a/docs/python-sdk.md
+++ b/docs/python-sdk.md
@@ -79,6 +79,30 @@ async def main():
That's it! Update three lines of code and you save 50% off your workflow.
+## Live visibility panel
+
+You can toggle live visibility behavior while the context is active:
+
+```py
+async with batchify(live_display=True):
+ generated_images = await asyncio.gather(*tasks)
+```
+
+`live_display` accepts a boolean:
+
+- `True` (default): auto mode, Rich panel only in interactive terminals
+ (`TTY`, non-`dumb`, non-`CI`). If Rich auto-disables, progress is emitted as
+ `INFO` logs on polling events.
+- `False`: disable both Rich panel and fallback progress logs.
+
+When enabled, the panel shows context-level progress:
+`completed_samples / total_samples`, completion percentage, and `Time Elapsed`
+since the first batch seen in the context.
+It also shows request counters and a queue summary table with one row per
+`(provider, endpoint, model)`, including `progress` as
+`completed/total (percentage)` where `completed` is terminal batches and
+`total` is `running + completed`.
+
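The queue summary table's per-row computation, where `completed` counts terminal batches and `total` is `running + completed`, can be sketched like this (a hypothetical helper, not part of the SDK's API):

```python
from collections import defaultdict


def queue_summary_rows(batches):
    """Aggregate (provider, endpoint, model, is_terminal) tuples into progress rows."""
    rows: dict[tuple[str, str, str], list[int]] = defaultdict(lambda: [0, 0])
    for provider, endpoint, model, is_terminal in batches:
        key = (provider, endpoint, model)
        rows[key][1] += 1  # total = running + completed
        if is_terminal:
            rows[key][0] += 1  # completed = terminal batches
    return {
        key: f"{done}/{total} ({done / total * 100:.0f}%)"
        for key, (done, total) in rows.items()
    }
```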
You can now run this script normally using python and start saving money:
```bash
diff --git a/examples/racing.py b/examples/racing.py
new file mode 100644
index 0000000..22ede9a
--- /dev/null
+++ b/examples/racing.py
@@ -0,0 +1,278 @@
+import asyncio
+import os
+import time
+import typing as t
+from dataclasses import dataclass
+
+from dotenv import load_dotenv
+from groq import AsyncGroq
+from mistralai import Mistral
+from openai import AsyncOpenAI
+from together import AsyncTogether
+
+from batchling import batchify
+
+load_dotenv()
+
+
+@dataclass
+class ProviderRaceResult:
+ """One provider completion entry in completion order."""
+
+ model: str
+ elapsed_seconds: float
+ answer: str
+
+
+ProviderRequestBuilder = t.Callable[[], t.Coroutine[t.Any, t.Any, tuple[str, str]]]
+
+
+async def run_openai_request(*, prompt: str) -> tuple[str, str]:
+ """
+ Send one OpenAI request.
+
+ Parameters
+ ----------
+ prompt : str
+ User prompt sent to the provider.
+
+ Returns
+ -------
+ tuple[str, str]
+ ``(model_name, answer_text)``.
+ """
+ client = AsyncOpenAI(api_key=os.getenv(key="OPENAI_API_KEY"))
+ response = await client.responses.create(
+ input=prompt,
+ model="gpt-4o-mini",
+ )
+ content = response.output[-1].content
+ return response.model, content[0].text
+
+
+async def run_groq_request(*, prompt: str) -> tuple[str, str]:
+ """
+ Send one Groq request.
+
+ Parameters
+ ----------
+ prompt : str
+ User prompt sent to the provider.
+
+ Returns
+ -------
+ tuple[str, str]
+ ``(model_name, answer_text)``.
+ """
+ client = AsyncGroq(api_key=os.getenv(key="GROQ_API_KEY"))
+ response = await client.chat.completions.create(
+ model="llama-3.1-8b-instant",
+ messages=[
+ {
+ "role": "user",
+ "content": prompt,
+ }
+ ],
+ )
+ return response.model, response.choices[0].message.content
+
+
+async def run_mistral_request(*, prompt: str) -> tuple[str, str]:
+ """
+ Send one Mistral request.
+
+ Parameters
+ ----------
+ prompt : str
+ User prompt sent to the provider.
+
+ Returns
+ -------
+ tuple[str, str]
+ ``(model_name, answer_text)``.
+ """
+ client = Mistral(api_key=os.getenv(key="MISTRAL_API_KEY"))
+ response = await client.chat.complete_async(
+ model="mistral-medium-2505",
+ messages=[
+ {
+ "role": "user",
+ "content": prompt,
+ }
+ ],
+ stream=False,
+ response_format={"type": "text"},
+ )
+ return response.model, str(object=response.choices[0].message.content)
+
+
+async def run_together_request(*, prompt: str) -> tuple[str, str]:
+ """
+ Send one Together request.
+
+ Parameters
+ ----------
+ prompt : str
+ User prompt sent to the provider.
+
+ Returns
+ -------
+ tuple[str, str]
+ ``(model_name, answer_text)``.
+ """
+ client = AsyncTogether(api_key=os.getenv(key="TOGETHER_API_KEY"))
+ response = await client.chat.completions.create(
+ model="google/gemma-3n-E4B-it",
+ messages=[
+ {
+ "role": "user",
+ "content": prompt,
+ }
+ ],
+ )
+ return response.model, response.choices[0].message.content
+
+
+async def run_doubleword_request(*, prompt: str) -> tuple[str, str]:
+ """
+ Send one Doubleword request.
+
+ Parameters
+ ----------
+ prompt : str
+ User prompt sent to the provider.
+
+ Returns
+ -------
+ tuple[str, str]
+ ``(model_name, answer_text)``.
+ """
+ client = AsyncOpenAI(
+ api_key=os.getenv(key="DOUBLEWORD_API_KEY"),
+ base_url="https://api.doubleword.ai/v1",
+ )
+ response = await client.responses.create(
+ input=prompt,
+ model="openai/gpt-oss-20b",
+ )
+ content = response.output[-1].content
+ return response.model, content[0].text
+
+
+async def run_provider_request(
+ *,
+ request_builder: ProviderRequestBuilder,
+ started_at: float,
+) -> ProviderRaceResult:
+ """
+ Execute one provider request and annotate elapsed time.
+
+ Parameters
+ ----------
+ request_builder : ProviderRequestBuilder
+ Provider request coroutine factory.
+ started_at : float
+ Shared wall-clock start time in ``perf_counter`` seconds.
+
+ Returns
+ -------
+ ProviderRaceResult
+ Result payload with answer and elapsed time.
+ """
+ model, answer = await request_builder()
+ elapsed_seconds = time.perf_counter() - started_at
+ return ProviderRaceResult(
+ model=model,
+ elapsed_seconds=elapsed_seconds,
+ answer=answer,
+ )
+
+
+def build_enabled_request_builders(*, prompt: str) -> list[ProviderRequestBuilder]:
+ """
+ Build one request factory per configured provider.
+
+ Parameters
+ ----------
+ prompt : str
+ Shared text prompt sent to all providers.
+
+ Returns
+ -------
+ list[ProviderRequestBuilder]
+ Enabled provider request factories.
+ """
+ providers: list[tuple[str, ProviderRequestBuilder]] = [
+ (
+ "OPENAI_API_KEY",
+ lambda: run_openai_request(prompt=prompt),
+ ),
+ (
+ "GROQ_API_KEY",
+ lambda: run_groq_request(prompt=prompt),
+ ),
+ (
+ "MISTRAL_API_KEY",
+ lambda: run_mistral_request(prompt=prompt),
+ ),
+ (
+ "TOGETHER_API_KEY",
+ lambda: run_together_request(prompt=prompt),
+ ),
+ (
+ "DOUBLEWORD_API_KEY",
+ lambda: run_doubleword_request(prompt=prompt),
+ ),
+ ]
+ enabled_builders: list[ProviderRequestBuilder] = []
+ for env_var_name, request_builder in providers:
+ api_key = os.getenv(key=env_var_name)
+ if not api_key:
+ continue
+ enabled_builders.append(request_builder)
+ return enabled_builders
+
+
+async def main() -> None:
+ """
+ Run one request per provider and collect completion-order results.
+
+ The race excludes Anthropic, Gemini, and XAI on purpose because their model field
+ extraction differs from the other provider examples.
+ """
+ prompt = "Give one short sentence explaining what asynchronous batching is."
+ request_builders = build_enabled_request_builders(prompt=prompt)
+ if not request_builders:
+ print("No providers configured. Set at least one provider API key in your environment.")
+ return
+
+ started_at = time.perf_counter()
+ tasks = [
+ asyncio.create_task(
+ run_provider_request(
+ request_builder=request_builder,
+ started_at=started_at,
+ )
+ )
+ for request_builder in request_builders
+ ]
+
+ completion_order_register: list[ProviderRaceResult] = []
+ for task in asyncio.as_completed(tasks):
+ result = await task
+ completion_order_register.append(result)
+
+ for index, result in enumerate(completion_order_register, start=1):
+ print(f"{index}. model={result.model}")
+ print(f" elapsed={result.elapsed_seconds:.2f}s")
+ print(f" answer={result.answer}\n")
+
+
+async def run_with_batchify() -> None:
+ """Run the provider race inside ``batchify`` for direct script execution."""
+ async with batchify():
+ await main()
+
+
+if __name__ == "__main__":
+ asyncio.run(run_with_batchify())
diff --git a/pyproject.toml b/pyproject.toml
index ba8394a..b7c4e15 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ classifiers = [
dependencies = [
"aiohttp>=3.13.3",
"httpx>=0.28.1",
+ "rich>=14.2.0",
"typer>=0.20.0",
]
diff --git a/requirements.txt b/requirements.txt
index d0e7873..99d8d5c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -645,6 +645,7 @@ requests-toolbelt==1.0.0
respx==0.22.0
rich==14.2.0
# via
+ # batchling
# cyclopts
# fastmcp
# instructor
diff --git a/src/batchling/api.py b/src/batchling/api.py
index a00afb9..46a2f88 100644
--- a/src/batchling/api.py
+++ b/src/batchling/api.py
@@ -16,6 +16,7 @@ def batchify(
batch_poll_interval_seconds: float = 10.0,
dry_run: bool = False,
cache: bool = True,
+ live_display: bool = True,
) -> BatchingContext:
"""
Context manager used to activate batching for a scoped context.
@@ -37,6 +38,11 @@ def batchify(
cache : bool, optional
If ``True``, enable persistent request cache lookups.
This parameter allows to skip the batch submission and go straight to the polling phase for requests that have already been sent.
+ live_display : bool, optional
+ Enable live display behavior while the context is active.
+ When ``True``, Rich panel rendering is attempted with terminal auto-detection.
+ If terminal auto-detection disables Rich (non-TTY, ``TERM=dumb``, or ``CI``),
+ progress is logged at ``INFO`` on polling events.
Returns
-------
@@ -65,4 +71,5 @@ def batchify(
# 3. Return BatchingContext with no yielded target.
return BatchingContext(
batcher=batcher,
+ live_display=live_display,
)
diff --git a/src/batchling/cli/main.py b/src/batchling/cli/main.py
index d266302..6cba6dd 100644
--- a/src/batchling/cli/main.py
+++ b/src/batchling/cli/main.py
@@ -75,7 +75,8 @@ async def run_script_with_batchify(
batch_poll_interval_seconds: float,
dry_run: bool,
cache: bool,
-):
+ live_display: bool,
+) -> None:
"""
Execute a Python script under a batchify context.
@@ -97,6 +98,8 @@ async def run_script_with_batchify(
Dry run mode passed to ``batchify``.
cache : bool
Cache mode passed to ``batchify``.
+ live_display : bool
+ Live display toggle passed to ``batchify``.
"""
if not module_path.exists():
typer.echo(f"Script not found: {module_path}")
@@ -112,6 +115,7 @@ async def run_script_with_batchify(
batch_poll_interval_seconds=batch_poll_interval_seconds,
dry_run=dry_run,
cache=cache,
+ live_display=live_display,
):
ns = runpy.run_path(path_name=script_path_as_posix, run_name="batchling.runtime")
func = ns.get(func_name)
@@ -155,6 +159,16 @@ def main(
bool,
typer.Option("--cache/--no-cache", help="Enable persistent request caching"),
] = True,
+ live_display: t.Annotated[
+ bool,
+ typer.Option(
+ "--live-display/--no-live-display",
+ help=(
+ "Enable auto live display. When disabled by terminal auto-detection, "
+ "fallback polling progress is logged at INFO."
+ ),
+ ),
+ ] = True,
):
"""Run a script under ``batchify``."""
try:
@@ -173,5 +187,6 @@ def main(
batch_poll_interval_seconds=batch_poll_interval_seconds,
dry_run=dry_run,
cache=cache,
+ live_display=live_display,
)
)
diff --git a/src/batchling/context.py b/src/batchling/context.py
index 0916481..82ccf12 100644
--- a/src/batchling/context.py
+++ b/src/batchling/context.py
@@ -3,11 +3,55 @@
"""
import asyncio
+import logging
import typing as t
import warnings
-from batchling.core import Batcher
+from batchling.core import Batcher, BatcherEvent
from batchling.hooks import active_batcher
+from batchling.logging import log_info
+from batchling.progress_state import BatchProgressState
+from batchling.rich_display import (
+ BatcherRichDisplay,
+ should_enable_live_display,
+)
+
+log = logging.getLogger(name=__name__)
+
+
+class _PollingProgressLogger:
+ """INFO logger fallback used when Rich live display auto-disables."""
+
+ def __init__(self) -> None:
+ self._progress_state = BatchProgressState()
+
+ def on_event(self, event: BatcherEvent) -> None:
+ """
+ Consume one lifecycle event and log progress on poll events.
+
+ Parameters
+ ----------
+ event : BatcherEvent
+ Lifecycle event emitted by ``Batcher``.
+ """
+ self._progress_state.on_event(event=event)
+
+ event_type = str(object=event.get("event_type", "unknown"))
+ if event_type != "batch_polled":
+ return
+
+ completed_samples, total_samples, percent = self._progress_state.compute_progress()
+ _, _, _, in_progress_samples = self._progress_state.compute_request_metrics()
+ log_info(
+ logger=log,
+ event="Live display fallback progress",
+ batch_id=event.get("batch_id"),
+ status=event.get("status"),
+ completed_samples=completed_samples,
+ total_samples=total_samples,
+ percent=f"{percent:.1f}",
+ in_progress_samples=in_progress_samples,
+ )
class BatchingContext:
@@ -18,9 +62,16 @@ class BatchingContext:
----------
batcher : Batcher
Batcher instance used for the scope of the context manager.
+ live_display : bool, optional
+ Whether to enable auto live display behavior for the context.
"""
- def __init__(self, *, batcher: "Batcher") -> None:
+ def __init__(
+ self,
+ *,
+ batcher: "Batcher",
+ live_display: bool = True,
+ ) -> None:
"""
Initialize the context manager.
@@ -28,10 +79,130 @@ def __init__(self, *, batcher: "Batcher") -> None:
----------
batcher : Batcher
Batcher instance used for the scope of the context manager.
+ live_display : bool, optional
+ Whether to enable auto live display behavior for the context.
"""
self._self_batcher = batcher
+ self._self_live_display_enabled = live_display
+ self._self_live_display: BatcherRichDisplay | None = None
+ self._self_live_display_heartbeat_task: asyncio.Task[None] | None = None
+ self._self_polling_progress_logger: _PollingProgressLogger | None = None
self._self_context_token: t.Any | None = None
+ def _start_polling_progress_logger(self) -> None:
+ """
+ Start the INFO polling progress fallback listener.
+
+ Notes
+ -----
+ Fallback listener errors are downgraded to warnings.
+ """
+ if self._self_polling_progress_logger is not None:
+ return
+ try:
+ listener = _PollingProgressLogger()
+ self._self_batcher._add_event_listener(listener=listener.on_event)
+ self._self_polling_progress_logger = listener
+ log_info(
+ logger=log,
+ event=(
+ "Live display disabled by terminal auto-detection; "
+ "using polling progress INFO logs"
+ ),
+ )
+ except Exception as error:
+ warnings.warn(
+ message=f"Failed to start batchling polling progress logs: {error}",
+ category=UserWarning,
+ stacklevel=2,
+ )
+
+ def _start_live_display(self) -> None:
+ """
+ Start the Rich live display when enabled.
+
+ Notes
+ -----
+ Display errors are downgraded to warnings to avoid breaking batching.
+ """
+ if self._self_live_display is not None or self._self_polling_progress_logger is not None:
+ return
+ if not self._self_live_display_enabled:
+ return
+ if not should_enable_live_display(enabled=self._self_live_display_enabled):
+ self._start_polling_progress_logger()
+ return
+ try:
+ display = BatcherRichDisplay()
+ self._self_batcher._add_event_listener(listener=display.on_event)
+ display.start()
+ self._self_live_display = display
+ self._start_live_display_heartbeat()
+ except Exception as error:
+ warnings.warn(
+ message=f"Failed to start batchling live display: {error}",
+ category=UserWarning,
+ stacklevel=2,
+ )
+
+ async def _run_live_display_heartbeat(self) -> None:
+ """
+ Periodically refresh the live display while the context is active.
+ """
+ try:
+ while self._self_live_display is not None:
+ self._self_live_display.refresh()
+ await asyncio.sleep(1.0)
+ except asyncio.CancelledError:
+ raise
+
+ def _start_live_display_heartbeat(self) -> None:
+ """
+ Start the 1-second live display heartbeat when an event loop exists.
+ """
+ if self._self_live_display is None:
+ return
+ if self._self_live_display_heartbeat_task is not None:
+ return
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ return
+ self._self_live_display_heartbeat_task = loop.create_task(
+ coro=self._run_live_display_heartbeat()
+ )
+
+ def _stop_live_display(self) -> None:
+ """
+ Stop and unregister the Rich live display.
+
+ Notes
+ -----
+ Display shutdown errors are downgraded to warnings.
+ """
+ display = self._self_live_display
+ fallback_listener = self._self_polling_progress_logger
+ if display is None and fallback_listener is None:
+ return
+ self._self_live_display = None
+ self._self_polling_progress_logger = None
+ heartbeat_task = self._self_live_display_heartbeat_task
+ self._self_live_display_heartbeat_task = None
+ if heartbeat_task is not None and not heartbeat_task.done():
+ heartbeat_task.cancel()
+ try:
+ if display is not None:
+ self._self_batcher._remove_event_listener(listener=display.on_event)
+ display.stop()
+ if fallback_listener is not None:
+ self._self_batcher._remove_event_listener(listener=fallback_listener.on_event)
+ except Exception as error:
+ warnings.warn(
+ message=f"Failed to stop batchling live display: {error}",
+ category=UserWarning,
+ stacklevel=2,
+ )
+
def __enter__(self) -> None:
"""
Enter the synchronous context manager and activate the batcher.
@@ -42,6 +213,7 @@ def __enter__(self) -> None:
``None`` for scoped activation.
"""
self._self_context_token = active_batcher.set(self._self_batcher)
+ self._start_live_display()
return None
def __exit__(
@@ -67,7 +239,8 @@ def __exit__(
self._self_context_token = None
try:
loop = asyncio.get_running_loop()
- loop.create_task(coro=self._self_batcher.close())
+ close_task = loop.create_task(coro=self._self_batcher.close())
+ close_task.add_done_callback(self._on_sync_close_done)
except RuntimeError:
warnings.warn(
message=(
@@ -78,6 +251,18 @@ def __exit__(
category=UserWarning,
stacklevel=2,
)
+ self._stop_live_display()
+
+ def _on_sync_close_done(self, _: asyncio.Task[None]) -> None:
+ """
+ Callback run when sync-context close task completes.
+
+ Parameters
+ ----------
+ _ : asyncio.Task[None]
+ Completed close task.
+ """
+ self._stop_live_display()
async def __aenter__(self) -> None:
"""
@@ -89,6 +274,7 @@ async def __aenter__(self) -> None:
``None`` for scoped activation.
"""
self._self_context_token = active_batcher.set(self._self_batcher)
+ self._start_live_display()
return None
async def __aexit__(
@@ -112,4 +298,7 @@ async def __aexit__(
if self._self_context_token is not None:
active_batcher.reset(self._self_context_token)
self._self_context_token = None
- await self._self_batcher.close()
+ try:
+ await self._self_batcher.close()
+ finally:
+ self._stop_live_display()
diff --git a/src/batchling/core.py b/src/batchling/core.py
index 65dc7aa..f27c451 100644
--- a/src/batchling/core.py
+++ b/src/batchling/core.py
@@ -31,6 +31,59 @@
CACHE_RETENTION_SECONDS = 30 * 24 * 60 * 60
+class BatcherEvent(t.TypedDict, total=False):
+ """
+ Lifecycle event emitted by ``Batcher`` for optional observers.
+
+ event_type : str
+ Event identifier.
+ timestamp : float
+ Event timestamp in UNIX seconds.
+ provider : str
+ Provider name.
+ endpoint : str
+ Request endpoint.
+ model : str
+ Request model.
+ queue_key : QueueKey
+ Queue identifier.
+ batch_id : str
+ Provider batch identifier.
+ status : str
+ Provider batch status.
+ request_count : int
+ Number of requests in the event scope.
+ pending_count : int
+ Current queue pending size.
+ custom_id : str
+ Custom request identifier.
+ source : str
+ Event source subsystem.
+ error : str
+ Error text.
+ missing_count : int
+ Number of missing results.
+ """
+
+ event_type: str
+ timestamp: float
+ provider: str
+ endpoint: str
+ model: str
+ queue_key: QueueKey
+ batch_id: str
+ status: str
+ request_count: int
+ pending_count: int
+ custom_id: str
+ source: str
+ error: str
+ missing_count: int
+
+
+BatcherEventListener = t.Callable[[BatcherEvent], None]
+
+
@dataclass
class _PendingRequest:
# FIXME: _PendingRequest can use a generic type to match any request from:
@@ -135,6 +188,7 @@ def __init__(
self._resumed_poll_tasks: set[asyncio.Task[None]] = set()
self._resumed_batches: dict[ResumedBatchKey, _ResumedBatch] = {}
self._resumed_lock = asyncio.Lock()
+ self._event_listeners: set[BatcherEventListener] = set()
self._cache_store: RequestCacheStore | None = None
if self._cache_enabled:
@@ -163,6 +217,73 @@ def __init__(
),
)
+ def _add_event_listener(
+ self,
+ *,
+ listener: BatcherEventListener,
+ ) -> None:
+ """
+ Register a listener receiving batch lifecycle events.
+
+ Parameters
+ ----------
+ listener : BatcherEventListener
+ Observer callback.
+ """
+ self._event_listeners.add(listener)
+
+ def _remove_event_listener(
+ self,
+ *,
+ listener: BatcherEventListener,
+ ) -> None:
+ """
+ Unregister a previously registered lifecycle listener.
+
+ Parameters
+ ----------
+ listener : BatcherEventListener
+ Observer callback.
+ """
+ self._event_listeners.discard(listener)
+
+ def _emit_event(
+ self,
+ *,
+ event_type: str,
+ **payload: t.Any,
+ ) -> None:
+ """
+ Emit a lifecycle event to all registered listeners.
+
+ Parameters
+ ----------
+ event_type : str
+ Event identifier.
+ **payload : typing.Any
+ Event payload.
+ """
+ if not self._event_listeners:
+ return
+ event = t.cast(
+ typ=BatcherEvent,
+ val={
+ "event_type": event_type,
+ "timestamp": time.time(),
+ **payload,
+ },
+ )
+ for listener in list(self._event_listeners):
+ try:
+ listener(event)
+ except Exception as error:
+ log_debug(
+ logger=log,
+ event="Batcher event listener failed",
+ listener=repr(listener),
+ error=str(object=error),
+ )
+
@staticmethod
def _format_queue_key(*, queue_key: QueueKey) -> str:
"""
@@ -353,6 +474,16 @@ async def _try_submit_from_cache(
batch_id=cache_entry.batch_id,
custom_id=cache_entry.custom_id,
)
+ cache_source = "cache_dry_run" if self._dry_run else "resumed_poll"
+ self._emit_event(
+ event_type="cache_hit_routed",
+ provider=provider_name,
+ endpoint=endpoint,
+ model=model_name,
+ batch_id=cache_entry.batch_id,
+ custom_id=cache_entry.custom_id,
+ source=cache_source,
+ )
if self._dry_run:
dry_run_request = _PendingRequest(
custom_id=cache_entry.custom_id,
@@ -401,6 +532,15 @@ async def _enqueue_pending_request(
queue = self._pending_by_provider.setdefault(queue_key, [])
queue.append(request)
pending_count = len(queue)
+ self._emit_event(
+ event_type="request_queued",
+ provider=queue_key[0],
+ endpoint=queue_key[1],
+ model=queue_key[2],
+ queue_key=queue_key,
+ pending_count=pending_count,
+ custom_id=request.custom_id,
+ )
if pending_count == 1:
self._window_tasks[queue_key] = asyncio.create_task(
@@ -610,7 +750,6 @@ async def _attach_cached_request(
future=future,
)
)
-
if should_start_poller:
task = asyncio.create_task(
coro=self._poll_cached_batch(resume_key=resume_key),
@@ -686,6 +825,15 @@ async def _window_timer(self, *, queue_key: QueueKey) -> None:
queue_key=queue_name,
error=str(object=e),
)
+ self._emit_event(
+ event_type="window_timer_error",
+ provider=provider_name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ error=str(object=e),
+ source="window_timer",
+ )
await self._fail_pending_provider_requests(
queue_key=queue_key,
error=e,
@@ -722,6 +870,14 @@ async def _submit_requests(
queue_key=queue_name,
request_count=len(requests),
)
+ self._emit_event(
+ event_type="batch_submitting",
+ provider=provider_name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ request_count=len(requests),
+ )
task = asyncio.create_task(
coro=self._process_batch(queue_key=queue_key, requests=requests),
name=f"batch_submit_{queue_name}_{uuid.uuid4()}",
@@ -876,6 +1032,16 @@ async def _process_batch(
try:
if self._dry_run:
dry_run_batch_id = f"dryrun-{uuid.uuid4()}"
+ self._emit_event(
+ event_type="batch_processing",
+ provider=provider.name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ request_count=len(requests),
+ batch_id=dry_run_batch_id,
+ source="dry_run",
+ )
active_batch = _ActiveBatch(
batch_id=dry_run_batch_id,
output_file_id="",
@@ -900,6 +1066,17 @@ async def _process_batch(
batch_id=dry_run_batch_id,
request_count=len(requests),
)
+ self._emit_event(
+ event_type="batch_terminal",
+ provider=provider.name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ request_count=len(requests),
+ batch_id=dry_run_batch_id,
+ status="simulated",
+ source="dry_run",
+ )
return
log_info(
@@ -911,11 +1088,30 @@ async def _process_batch(
queue_key=queue_name,
request_count=len(requests),
)
+ self._emit_event(
+ event_type="batch_processing",
+ provider=provider.name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ request_count=len(requests),
+ source="submit",
+ )
batch_submission = await provider.process_batch(
requests=requests,
client_factory=self._client_factory,
queue_key=queue_key,
)
+ self._emit_event(
+ event_type="batch_processing",
+ provider=provider.name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ request_count=len(requests),
+ batch_id=batch_submission.batch_id,
+ source="poll_start",
+ )
self._write_cache_entries(
queue_key=queue_key,
requests=requests,
@@ -948,6 +1144,15 @@ async def _process_batch(
queue_key=queue_name,
error=str(object=e),
)
+ self._emit_event(
+ event_type="batch_failed",
+ provider=provider_name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ request_count=len(requests),
+ error=str(object=e),
+ )
for req in requests:
if not req.future.done():
req.future.set_exception(e)
@@ -1056,6 +1261,13 @@ async def _poll_batch(
api_headers=api_headers,
batch_id=active_batch.batch_id,
)
+ self._emit_event(
+ event_type="batch_polled",
+ provider=provider.name,
+ batch_id=active_batch.batch_id,
+ status=poll_snapshot.status,
+ source="active_poll",
+ )
active_batch.output_file_id = poll_snapshot.output_file_id
active_batch.error_file_id = poll_snapshot.error_file_id
@@ -1067,6 +1279,13 @@ async def _poll_batch(
batch_id=active_batch.batch_id,
status=poll_snapshot.status,
)
+ self._emit_event(
+ event_type="batch_terminal",
+ provider=provider.name,
+ batch_id=active_batch.batch_id,
+ status=poll_snapshot.status,
+ source="active_poll",
+ )
await self._resolve_batch_results(
base_url=base_url,
api_headers=api_headers,
@@ -1105,7 +1324,21 @@ async def _poll_cached_batch(
api_headers=resumed_batch.api_headers,
batch_id=batch_id,
)
+ self._emit_event(
+ event_type="batch_polled",
+ provider=provider.name,
+ batch_id=batch_id,
+ status=poll_snapshot.status,
+ source="resumed_poll",
+ )
if poll_snapshot.status in provider.batch_terminal_states:
+ self._emit_event(
+ event_type="batch_terminal",
+ provider=provider.name,
+ batch_id=batch_id,
+ status=poll_snapshot.status,
+ source="resumed_poll",
+ )
await self._resolve_cached_batch_results(
resume_key=resume_key,
output_file_id=poll_snapshot.output_file_id,
@@ -1117,6 +1350,13 @@ async def _poll_cached_batch(
except asyncio.CancelledError:
raise
except Exception as error:
+ self._emit_event(
+ event_type="batch_failed",
+ provider=provider.name,
+ batch_id=batch_id,
+ error=str(object=error),
+ source="resumed_poll",
+ )
await self._fail_resumed_batch_requests(
resume_key=resume_key,
error=error,
@@ -1199,6 +1439,12 @@ async def _resolve_cached_batch_results(
pending.future.set_result(resolved_response)
if missing_hashes:
+ self._emit_event(
+ event_type="missing_results",
+ batch_id=batch_id,
+ missing_count=len(missing_hashes),
+ source="resumed_results",
+ )
_ = self._invalidate_cache_hashes(request_hashes=missing_hashes)
async def _fail_resumed_batch_requests(
@@ -1405,6 +1651,12 @@ def _fail_missing_results(
batch_id=active_batch.batch_id,
missing_count=len(missing),
)
+ self._emit_event(
+ event_type="missing_results",
+ batch_id=active_batch.batch_id,
+ missing_count=len(missing),
+ source="results",
+ )
error = RuntimeError(f"Missing results for {len(missing)} request(s)")
for custom_id in missing:
pending = active_batch.requests.get(custom_id)
@@ -1447,6 +1699,15 @@ async def close(self) -> None:
queue_key=queue_name,
request_count=len(requests),
)
+ self._emit_event(
+ event_type="final_flush_submitting",
+ provider=provider_name,
+ endpoint=queue_endpoint,
+ model=model_name,
+ queue_key=queue_key,
+ request_count=len(requests),
+ source="close",
+ )
await self._submit_requests(
queue_key=queue_key,
requests=requests,
diff --git a/src/batchling/progress_state.py b/src/batchling/progress_state.py
new file mode 100644
index 0000000..4a9b0c6
--- /dev/null
+++ b/src/batchling/progress_state.py
@@ -0,0 +1,223 @@
+"""Shared progress-state tracking for live display and fallback logging."""
+
+from __future__ import annotations
+
+import time
+import typing as t
+from dataclasses import dataclass
+
+from batchling.core import BatcherEvent
+
+
+@dataclass
+class _TrackedBatch:
+ """In-memory batch state used for aggregate progress computations."""
+
+ batch_id: str
+ provider: str = "-"
+ endpoint: str = "-"
+ model: str = "-"
+ size: int = 0
+ completed: bool = False
+ terminal: bool = False
+
+
+class BatchProgressState:
+ """
+ Track batch lifecycle state and compute shared aggregate metrics.
+
+ Parameters
+ ----------
+ now_fn : typing.Callable[[], float] | None, optional
+ Clock function used for elapsed-time calculations.
+ """
+
+ def __init__(
+ self,
+ *,
+ now_fn: t.Callable[[], float] | None = None,
+ ) -> None:
+ self._now_fn = now_fn or time.time
+ self._batches: dict[str, _TrackedBatch] = {}
+ self._cached_samples = 0
+ self._first_batch_created_at: float | None = None
+
+ def on_event(self, *, event: BatcherEvent) -> None:
+ """
+ Update tracked state from one lifecycle event.
+
+ Parameters
+ ----------
+ event : BatcherEvent
+ Lifecycle event emitted by ``Batcher``.
+ """
+ event_type = str(object=event.get("event_type", "unknown"))
+ source = str(object=event.get("source", ""))
+ batch_id = event.get("batch_id")
+
+ if batch_id is None:
+ return
+
+ batch = self._get_or_create_batch(batch_id=str(object=batch_id))
+ self._update_batch_identity(batch=batch, event=event)
+
+ if event_type == "batch_processing":
+ request_count = event.get("request_count")
+ if isinstance(request_count, int):
+ batch.size = max(batch.size, request_count)
+ batch.terminal = False
+ return
+
+ if event_type == "batch_polled":
+ batch.terminal = False
+ return
+
+ if event_type == "batch_terminal":
+ status = str(object=event.get("status", "completed"))
+ batch.completed = self._status_counts_as_completed(status=status)
+ batch.terminal = True
+ return
+
+ if event_type == "batch_failed":
+ batch.completed = False
+ batch.terminal = True
+ return
+
+ if event_type == "cache_hit_routed" and source == "resumed_poll":
+ batch.size += 1
+ self._cached_samples += 1
+ batch.terminal = False
+
+ def compute_progress(self) -> tuple[int, int, float]:
+ """
+ Compute aggregate sample progress from tracked batches.
+
+ Returns
+ -------
+ tuple[int, int, float]
+ ``(completed_samples, total_samples, percent)``.
+ """
+ total_samples = sum(batch.size for batch in self._batches.values())
+ completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed)
+ if total_samples <= 0:
+ return 0, 0, 0.0
+ percent = (completed_samples / total_samples) * 100.0
+ return completed_samples, total_samples, percent
+
+ def compute_request_metrics(self) -> tuple[int, int, int, int]:
+ """
+ Compute aggregate request counters from tracked batches.
+
+ Returns
+ -------
+ tuple[int, int, int, int]
+ ``(total_samples, cached_samples, completed_samples, in_progress_samples)``.
+ """
+ total_samples = sum(batch.size for batch in self._batches.values())
+ completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed)
+ in_progress_samples = sum(
+ batch.size for batch in self._batches.values() if not batch.terminal
+ )
+ return total_samples, self._cached_samples, completed_samples, in_progress_samples
+
+ def compute_queue_batch_counts(self) -> list[tuple[str, str, str, int, int]]:
+ """
+ Aggregate queue-level running and terminal batch counts.
+
+ Returns
+ -------
+ list[tuple[str, str, str, int, int]]
+ Sorted rows as ``(provider, endpoint, model, running, completed)``.
+ """
+ counts_by_queue: dict[tuple[str, str, str], list[int]] = {}
+ for batch in self._batches.values():
+ queue_key = (batch.provider, batch.endpoint, batch.model)
+ counters = counts_by_queue.setdefault(queue_key, [0, 0])
+ if batch.terminal:
+ counters[1] += 1
+ else:
+ counters[0] += 1
+
+ rows = [
+ (provider, endpoint, model, counters[0], counters[1])
+ for (provider, endpoint, model), counters in counts_by_queue.items()
+ ]
+ return sorted(rows, key=lambda row: (row[0], row[1], row[2]))
+
+ def compute_elapsed_seconds(self) -> int:
+ """
+ Compute elapsed seconds since the first tracked batch in this context.
+
+ Returns
+ -------
+ int
+ Elapsed seconds.
+ """
+ if self._first_batch_created_at is None:
+ return 0
+ return max(0, int(self._now_fn() - self._first_batch_created_at))
+
+ def _get_or_create_batch(self, *, batch_id: str) -> _TrackedBatch:
+ """
+ Get or create one tracked batch record.
+
+ Parameters
+ ----------
+ batch_id : str
+ Provider batch identifier.
+
+ Returns
+ -------
+ _TrackedBatch
+ Mutable tracked batch.
+ """
+ batch = self._batches.get(batch_id)
+ if batch is None:
+ batch = _TrackedBatch(batch_id=batch_id)
+ self._batches[batch_id] = batch
+ if self._first_batch_created_at is None:
+ self._first_batch_created_at = self._now_fn()
+ return batch
+
+ @staticmethod
+ def _update_batch_identity(*, batch: _TrackedBatch, event: BatcherEvent) -> None:
+ """
+ Update batch metadata from lifecycle event payload.
+
+ Parameters
+ ----------
+ batch : _TrackedBatch
+ Mutable tracked batch.
+ event : BatcherEvent
+ Lifecycle event payload.
+ """
+ provider = event.get("provider")
+ endpoint = event.get("endpoint")
+ model = event.get("model")
+ if provider is not None:
+ batch.provider = str(object=provider)
+ if endpoint is not None:
+ batch.endpoint = str(object=endpoint)
+ if model is not None:
+ batch.model = str(object=model)
+
+ @staticmethod
+ def _status_counts_as_completed(*, status: str) -> bool:
+ """
+ Determine whether a terminal status counts as completed samples.
+
+ Parameters
+ ----------
+ status : str
+ Terminal provider status.
+
+ Returns
+ -------
+ bool
+ ``True`` when the terminal state should contribute to completed samples.
+ """
+ lowered_status = status.lower()
+ negative_markers = ("fail", "error", "cancel", "expired", "timeout")
+ if any(marker in lowered_status for marker in negative_markers):
+ return False
+ return True
diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py
new file mode 100644
index 0000000..0060755
--- /dev/null
+++ b/src/batchling/rich_display.py
@@ -0,0 +1,330 @@
+"""Rich live display for batch lifecycle visibility."""
+
+from __future__ import annotations
+
+import os
+import sys
+import time
+
+from rich.console import Console, Group
+from rich.live import Live
+from rich.panel import Panel
+from rich.progress import BarColumn, Progress, TextColumn
+from rich.table import Table
+from rich.text import Text
+
+from batchling.core import BatcherEvent
+from batchling.progress_state import BatchProgressState
+
+
+class BatcherRichDisplay:
+ """
+ Render context-level sample progress through a Rich ``Live`` panel.
+
+ Progress is computed across all tracked batches as:
+ ``sum(size of completed batches) / sum(size of all tracked batches)``.
+
+ Parameters
+ ----------
+ refresh_per_second : float, optional
+ Refresh rate for Rich live updates.
+ console : Console | None, optional
+ Rich console to render to. Defaults to ``Console(stderr=True)``.
+ """
+
+ def __init__(
+ self,
+ *,
+ refresh_per_second: float = 1.0,
+ console: Console | None = None,
+ ) -> None:
+ self._console = console or Console(stderr=True)
+ self._refresh_per_second = refresh_per_second
+ self._progress_state = BatchProgressState(now_fn=time.time)
+ self._live: Live | None = None
+
+ def start(self) -> None:
+ """Start the live panel if not already running."""
+ if self._live is not None:
+ return
+ self._live = Live(
+ renderable=self._render(),
+ console=self._console,
+ refresh_per_second=self._refresh_per_second,
+ transient=False,
+ )
+ self._live.start(refresh=True)
+
+ def stop(self) -> None:
+ """Stop the live panel if running."""
+ if self._live is None:
+ return
+ self._live.stop()
+ self._live = None
+
+ def on_event(self, event: BatcherEvent) -> None:
+ """
+ Consume one batch lifecycle event and refresh the panel.
+
+ Parameters
+ ----------
+ event : BatcherEvent
+ Lifecycle event emitted by ``Batcher``.
+ """
+ self._progress_state.on_event(event=event)
+ self.refresh()
+
+ def refresh(self) -> None:
+ """Force one live-panel refresh when running."""
+ if self._live is None:
+ return
+ self._live.update(renderable=self._render(), refresh=True)
+
+ def _compute_progress(self) -> tuple[int, int, float]:
+ """
+ Compute aggregate context progress from tracked batches.
+
+ Returns
+ -------
+ tuple[int, int, float]
+ ``(completed_samples, total_samples, percent)``.
+ """
+ return self._progress_state.compute_progress()
+
+ def _compute_request_metrics(self) -> tuple[int, int, int, int]:
+ """
+ Compute aggregate request counters shown under the progress bar.
+
+ Returns
+ -------
+ tuple[int, int, int, int]
+ ``(total_samples, cached_samples, completed_samples, in_progress_samples)``.
+ """
+ return self._progress_state.compute_request_metrics()
+
+ def _compute_elapsed_seconds(self) -> int:
+ """
+ Compute elapsed seconds since the first batch was created in this context.
+
+ Returns
+ -------
+ int
+ Elapsed seconds.
+ """
+ return self._progress_state.compute_elapsed_seconds()
+
+ @staticmethod
+ def _format_elapsed(*, elapsed_seconds: int) -> str:
+ """
+ Format elapsed seconds as ``HH:MM:SS``.
+
+ Parameters
+ ----------
+ elapsed_seconds : int
+ Elapsed seconds.
+
+ Returns
+ -------
+ str
+ Formatted duration.
+ """
+ hours = elapsed_seconds // 3600
+ minutes = (elapsed_seconds % 3600) // 60
+ seconds = elapsed_seconds % 60
+ return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
+
+ def _render(self) -> Panel:
+ """Build the current Rich panel renderable."""
+ progress_bar = self._build_progress_bar()
+ requests_line = self._build_requests_line()
+ queue_summary_table = self._build_queue_summary_table()
+ return Panel(
+ renderable=Group(progress_bar, requests_line, queue_summary_table),
+ title="batchling context progress",
+ border_style="cyan",
+ )
+
+ def _build_progress_bar(self) -> Progress:
+ """Build aggregate context progress as a Rich progress bar."""
+ completed_samples, total_samples, _ = self._compute_progress()
+ elapsed_seconds = self._compute_elapsed_seconds()
+ elapsed_label = self._format_elapsed(elapsed_seconds=elapsed_seconds)
+ sample_width = max(1, len(str(object=total_samples)))
+
+ progress = Progress(
+ BarColumn(bar_width=None),
+ TextColumn(
+ text_format=(
+ f"[bold green]{{task.fields[completed_samples]:>{sample_width}}}[/bold green]/"
+ f"[bold cyan]{{task.fields[total_samples]:>{sample_width}}}[/bold cyan] "
+ "([bold green]{task.percentage:.1f}%[/bold green])"
+ )
+ ),
+ TextColumn(text_format=f"Time Elapsed: [bold magenta]{elapsed_label}[/bold magenta]"),
+ expand=True,
+ )
+ display_total = max(total_samples, 1)
+ _ = progress.add_task(
+ description="samples",
+ total=display_total,
+ completed=min(completed_samples, display_total),
+ completed_samples=completed_samples,
+ total_samples=total_samples,
+ )
+ return progress
+
+ def _build_requests_line(self) -> Text:
+ """
+ Build one-line request metrics shown under the progress bar.
+
+ Returns
+ -------
+ Text
+ Styled metrics line.
+ """
+ total_samples, cached_samples, completed_samples, in_progress_samples = (
+ self._compute_request_metrics()
+ )
+ line = Text()
+ line.append(text="Requests", style="bold white")
+ line.append(text=": ", style="white")
+ line.append(text="Total", style="grey70")
+ line.append(text=": ", style="grey70")
+ line.append(text=str(object=total_samples), style="bold cyan")
+ line.append(text=" - ", style="grey70")
+ line.append(text="Cached", style="grey70")
+ line.append(text=": ", style="grey70")
+ line.append(text=str(object=cached_samples), style="bold magenta")
+ line.append(text=" - ", style="grey70")
+ line.append(text="Completed", style="grey70")
+ line.append(text=": ", style="grey70")
+ line.append(text=str(object=completed_samples), style="bold green")
+ line.append(text=" - ", style="grey70")
+ line.append(text="In Progress", style="grey70")
+ line.append(text=": ", style="grey70")
+ line.append(text=str(object=in_progress_samples), style="bold yellow")
+ return line
+
+ def _compute_queue_batch_counts(self) -> list[tuple[str, str, str, int, int]]:
+ """
+ Aggregate queue-level running and terminal batch counts.
+
+ Returns
+ -------
+ list[tuple[str, str, str, int, int]]
+ Sorted rows as ``(provider, endpoint, model, running, completed)``.
+ """
+ return self._progress_state.compute_queue_batch_counts()
+
+ def _build_queue_summary_table(self) -> Table:
+ """
+ Build queue-level table with per-queue progress summary.
+
+ Returns
+ -------
+ Table
+ Queue summary table.
+ """
+ queue_rows = self._compute_queue_batch_counts()
+ table = Table(expand=False)
+ table.add_column(
+ header="provider",
+ style="bold blue",
+ width=12,
+ no_wrap=True,
+ overflow="ellipsis",
+ )
+ table.add_column(
+ header="endpoint",
+ width=34,
+ no_wrap=True,
+ overflow="ellipsis",
+ )
+ table.add_column(
+ header="model",
+ style="bold magenta",
+ width=28,
+ no_wrap=True,
+ overflow="ellipsis",
+ )
+ table.add_column(
+ header="progress",
+ justify="right",
+ width=16,
+ no_wrap=True,
+ overflow="ellipsis",
+ )
+
+ if not queue_rows:
+ table.add_row("-", "-", "-", self._format_queue_progress(running=0, completed=0))
+ return table
+
+ for provider, endpoint, model, running, completed in queue_rows:
+ table.add_row(
+ provider,
+ endpoint,
+ model,
+ self._format_queue_progress(
+ running=running,
+ completed=completed,
+ ),
+ )
+ return table
+
+ @staticmethod
+ def _format_queue_progress(*, running: int, completed: int) -> Text:
+ """
+ Format one queue progress cell as ``completed/total (percent)``.
+
+ Parameters
+ ----------
+ running : int
+ Number of non-terminal batches.
+ completed : int
+ Number of terminal batches.
+
+ Returns
+ -------
+ Text
+ Formatted queue progress.
+ """
+ total = running + completed
+ if total <= 0:
+ percent = 0.0
+ else:
+ percent = (completed / total) * 100.0
+ count_width = max(1, len(str(object=total)))
+ progress = Text()
+ progress.append(text=f"{completed:>{count_width}}", style="bold green")
+ progress.append(text="/", style="white")
+ progress.append(text=f"{total:>{count_width}}", style="bold cyan")
+ progress.append(text=" (", style="white")
+ progress.append(text=f"{percent:.1f}%", style="bold green")
+ progress.append(text=")", style="white")
+ return progress
+
+
+def should_enable_live_display(*, enabled: bool) -> bool:
+ """
+ Resolve whether the Rich live panel should be enabled.
+
+ Parameters
+ ----------
+ enabled : bool
+ Requested live display toggle.
+
+ Returns
+ -------
+ bool
+ ``True`` when the live panel should run.
+ """
+ if not enabled:
+ return False
+
+ stderr_stream = sys.stderr
+ is_tty = bool(getattr(stderr_stream, "isatty", lambda: False)())
+ terminal_name = str(object=os.environ.get("TERM", "")).lower()
+ is_dumb_terminal = terminal_name in {"", "dumb"}
+ is_ci = bool(os.environ.get("CI"))
+
+ return is_tty and not is_dumb_terminal and not is_ci
diff --git a/tests/test_api.py b/tests/test_api.py
index 03da464..446d2b6 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -60,6 +60,24 @@ async def test_batchify_configures_cache_flag(reset_hooks, reset_context):
assert wrapped._self_batcher._cache_enabled is False
+@pytest.mark.asyncio
+async def test_batchify_forwards_live_display_flag(reset_hooks, reset_context):
+ """Test that batchify forwards live display flag to BatchingContext."""
+ wrapped = batchify(
+ live_display=False,
+ )
+
+ assert wrapped._self_live_display_enabled is False
+
+
+@pytest.mark.asyncio
+async def test_batchify_live_display_defaults_to_true(reset_hooks, reset_context):
+ """Test that live display defaults to enabled."""
+ wrapped = batchify()
+
+ assert wrapped._self_live_display_enabled is True
+
+
@pytest.mark.asyncio
async def test_batchify_returns_context_manager(reset_hooks, reset_context):
"""Test that batchify returns a BatchingContext."""
diff --git a/tests/test_cli_script_runner.py b/tests/test_cli_script_runner.py
index 0472062..e7426d7 100644
--- a/tests/test_cli_script_runner.py
+++ b/tests/test_cli_script_runner.py
@@ -61,6 +61,7 @@ def fake_batchify(**kwargs):
}
assert captured_batchify_kwargs["dry_run"] is True
assert captured_batchify_kwargs["cache"] is True
+ assert captured_batchify_kwargs["live_display"] is True
def test_run_script_with_cache_option(tmp_path: Path, monkeypatch):
@@ -92,6 +93,38 @@ def fake_batchify(**kwargs):
assert result.exit_code == 0
assert captured_batchify_kwargs["cache"] is False
+ assert captured_batchify_kwargs["live_display"] is True
+
+
+def test_run_script_with_no_live_display_option(tmp_path: Path, monkeypatch):
+ script_path = tmp_path / "script.py"
+ script_path.write_text(
+ "\n".join(
+ [
+ "async def foo(*args, **kwargs):",
+ " return None",
+ ]
+ )
+ + "\n"
+ )
+ captured_batchify_kwargs: dict = {}
+
+ def fake_batchify(**kwargs):
+ captured_batchify_kwargs.update(kwargs)
+ return DummyAsyncBatchifyContext()
+
+ monkeypatch.setattr(cli_main, "batchify", fake_batchify)
+
+ result = runner.invoke(
+ app,
+ [
+ f"{script_path.as_posix()}:foo",
+ "--no-live-display",
+ ],
+ )
+
+ assert result.exit_code == 0
+ assert captured_batchify_kwargs["live_display"] is False
def test_batch_size_flag_scope_for_cli_and_target_function(tmp_path: Path, monkeypatch):
diff --git a/tests/test_context.py b/tests/test_context.py
index e967a57..72c6eb9 100644
--- a/tests/test_context.py
+++ b/tests/test_context.py
@@ -3,6 +3,7 @@
"""
import asyncio
+import logging
import typing as t
import warnings
from unittest.mock import AsyncMock, patch
@@ -96,3 +97,122 @@ async def test_batching_context_without_target(batcher: Batcher, reset_context:
async with context as active_target:
assert active_batcher.get() is batcher
assert active_target is None
+
+
+@pytest.mark.asyncio
+async def test_batching_context_starts_and_stops_live_display(
+ batcher: Batcher,
+ reset_context: None,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test that async context starts and stops live display listeners."""
+
+ class DummyDisplay:
+ """Simple display stub."""
+
+ def __init__(self) -> None:
+ self.started = False
+ self.stopped = False
+
+ def start(self) -> None:
+ self.started = True
+
+ def stop(self) -> None:
+ self.stopped = True
+
+ def on_event(self, event: dict[str, t.Any]) -> None:
+ del event
+
+ dummy_display = DummyDisplay()
+ monkeypatch.setattr("batchling.context.BatcherRichDisplay", lambda: dummy_display)
+ monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True)
+ context = BatchingContext(
+ batcher=batcher,
+ live_display=True,
+ )
+
+ with patch.object(target=batcher, attribute="close", new_callable=AsyncMock):
+ async with context:
+ assert dummy_display.started is True
+ assert context._self_live_display_heartbeat_task is not None
+ assert not context._self_live_display_heartbeat_task.done()
+ assert dummy_display.stopped is True
+ assert context._self_live_display_heartbeat_task is None
+
+
+def test_batching_context_sync_stops_live_display_without_loop(
+ batcher: Batcher,
+ reset_context: None,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test sync context stops display when no event loop is running."""
+
+ class DummyDisplay:
+ """Simple display stub."""
+
+ def __init__(self) -> None:
+ self.stopped = False
+
+ def start(self) -> None:
+ return None
+
+ def stop(self) -> None:
+ self.stopped = True
+
+ def on_event(self, event: dict[str, t.Any]) -> None:
+ del event
+
+ dummy_display = DummyDisplay()
+ monkeypatch.setattr("batchling.context.BatcherRichDisplay", lambda: dummy_display)
+ monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True)
+
+ context = BatchingContext(
+ batcher=batcher,
+ live_display=True,
+ )
+
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter(action="always")
+ with context:
+ pass
+
+ assert dummy_display.stopped is True
+ assert context._self_live_display_heartbeat_task is None
+
+
+def test_batching_context_uses_polling_progress_fallback_when_auto_disabled(
+ batcher: Batcher,
+ caplog: pytest.LogCaptureFixture,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ """Test live-display fallback logs progress at poll time when Rich is disabled."""
+ monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: False)
+ context = BatchingContext(
+ batcher=batcher,
+ live_display=True,
+ )
+
+ caplog.set_level(level=logging.INFO, logger="batchling.context")
+
+ context._start_live_display()
+ assert context._self_live_display is None
+ assert context._self_polling_progress_logger is not None
+ assert any("using polling progress INFO logs" in record.message for record in caplog.records)
+
+ batcher._emit_event(
+ event_type="batch_processing",
+ batch_id="batch-1",
+ request_count=4,
+ source="poll_start",
+ )
+ batcher._emit_event(
+ event_type="batch_polled",
+ batch_id="batch-1",
+ status="in_progress",
+ source="active_poll",
+ )
+
+ assert any("Live display fallback progress" in record.message for record in caplog.records)
+
+ context._stop_live_display()
+ assert context._self_polling_progress_logger is None
diff --git a/tests/test_core.py b/tests/test_core.py
index a15d5be..960ba6e 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -10,7 +10,7 @@
import pytest
from batchling.cache import CacheEntry
-from batchling.core import Batcher, _ActiveBatch, _PendingRequest
+from batchling.core import Batcher, BatcherEvent, _ActiveBatch, _PendingRequest
from batchling.providers.anthropic import AnthropicProvider
from batchling.providers.base import PollSnapshot, ProviderRequestSpec, ResumeContext
from batchling.providers.gemini import GeminiProvider
@@ -1948,3 +1948,142 @@ async def test_dry_run_cache_hit_is_read_only(mock_openai_api_transport: httpx.M
assert dry_run_entry is not None
assert dry_run_entry.created_at == original_created_at
await dry_run_batcher.close()
+
+
+@pytest.mark.asyncio
+async def test_submit_emits_lifecycle_events(batcher: Batcher, provider: OpenAIProvider) -> None:
+ """Test submit emits queue, submit, poll, and terminal lifecycle events."""
+ events: list[BatcherEvent] = []
+ batcher._add_event_listener(listener=events.append)
+
+ _ = await batcher.submit(
+ client_type="httpx",
+ method="POST",
+ url="api.openai.com",
+ endpoint="/v1/chat/completions",
+ provider=provider,
+ headers={"Authorization": "Bearer token"},
+ body=b'{"model":"model-a","messages":[]}',
+ )
+
+ event_types = {str(event["event_type"]) for event in events}
+ assert "request_queued" in event_types
+ assert "batch_submitting" in event_types
+ assert "batch_processing" in event_types
+ assert "batch_polled" in event_types
+ assert "batch_terminal" in event_types
+
+
+@pytest.mark.asyncio
+async def test_dry_run_emits_terminal_without_poll(provider: OpenAIProvider) -> None:
+ """Test dry-run emits terminal lifecycle without provider polling events."""
+ dry_run_batcher = Batcher(
+ batch_size=1,
+ batch_window_seconds=1.0,
+ dry_run=True,
+ cache=False,
+ )
+ events: list[BatcherEvent] = []
+ dry_run_batcher._add_event_listener(listener=events.append)
+
+ _ = await dry_run_batcher.submit(
+ client_type="httpx",
+ method="POST",
+ url="api.openai.com",
+ endpoint="/v1/chat/completions",
+ provider=provider,
+ headers={"Authorization": "Bearer token"},
+ body=b'{"model":"model-a","messages":[]}',
+ )
+
+ event_types = {str(event["event_type"]) for event in events}
+ assert "request_queued" in event_types
+ assert "batch_submitting" in event_types
+ assert "batch_terminal" in event_types
+ assert "batch_polled" not in event_types
+
+ await dry_run_batcher.close()
+
+
+@pytest.mark.asyncio
+async def test_cache_hit_emits_cache_hit_routed(
+ mock_openai_api_transport: httpx.MockTransport,
+) -> None:
+ """Test cache-hit routing emits dedicated lifecycle event."""
+ writer_batcher = Batcher(
+ batch_size=2,
+ batch_window_seconds=0.1,
+ cache=True,
+ )
+ writer_batcher._client_factory = lambda: httpx.AsyncClient(transport=mock_openai_api_transport)
+ writer_batcher._poll_interval_seconds = 0.01
+ provider = OpenAIProvider()
+
+ _ = await writer_batcher.submit(
+ client_type="httpx",
+ method="POST",
+ url="api.openai.com",
+ endpoint="/v1/chat/completions",
+ provider=provider,
+ headers={"Authorization": "Bearer token"},
+ body=b'{"model":"model-a","messages":[]}',
+ )
+ await writer_batcher.close()
+
+ dry_run_batcher = Batcher(
+ batch_size=2,
+ batch_window_seconds=0.1,
+ cache=True,
+ dry_run=True,
+ )
+ events: list[BatcherEvent] = []
+ dry_run_batcher._add_event_listener(listener=events.append)
+
+ _ = await dry_run_batcher.submit(
+ client_type="httpx",
+ method="POST",
+ url="api.openai.com",
+ endpoint="/v1/chat/completions",
+ provider=OpenAIProvider(),
+ headers={"Authorization": "Bearer token"},
+ body=b'{"model":"model-a","messages":[]}',
+ )
+
+ assert any(
+ event["event_type"] == "cache_hit_routed" and event.get("source") == "cache_dry_run"
+ for event in events
+ )
+ await dry_run_batcher.close()
+
+
+def test_fail_missing_results_emits_lifecycle_event() -> None:
+ """Test missing output mapping emits missing-results lifecycle event."""
+ batcher = Batcher(
+ batch_size=2,
+ batch_window_seconds=1.0,
+ cache=False,
+ )
+ loop = asyncio.new_event_loop()
+ future: asyncio.Future[t.Any] = loop.create_future()
+ request = _PendingRequest(
+ custom_id="custom-1",
+ queue_key=("openai", "/v1/chat/completions", "model-a"),
+ params={},
+ provider=OpenAIProvider(),
+ future=future,
+ request_hash="hash-1",
+ )
+ active_batch = _ActiveBatch(
+ batch_id="batch-1",
+ output_file_id="file-1",
+ error_file_id="",
+ requests={"custom-1": request},
+ )
+
+ events: list[BatcherEvent] = []
+ batcher._add_event_listener(listener=events.append)
+
+ batcher._fail_missing_results(active_batch=active_batch, seen=set())
+
+ assert any(event["event_type"] == "missing_results" for event in events)
+ loop.close()
diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py
new file mode 100644
index 0000000..2d37959
--- /dev/null
+++ b/tests/test_rich_display.py
@@ -0,0 +1,370 @@
+"""Tests for Rich live display helpers."""
+
+import io
+
+from rich.console import Console
+from rich.text import Text
+
+import batchling.rich_display as rich_display
+
+
+def test_should_enable_live_display_auto_in_interactive_terminal(
+ monkeypatch,
+) -> None:
+ """Test auto mode enables display in interactive terminals."""
+
+ class DummyStderr:
+ def isatty(self) -> bool:
+ return True
+
+ monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr())
+ monkeypatch.setenv("TERM", "xterm-256color")
+ monkeypatch.delenv("CI", raising=False)
+
+ assert rich_display.should_enable_live_display(enabled=True) is True
+
+
+def test_should_enable_live_display_auto_disabled_in_ci(monkeypatch) -> None:
+ """Test auto mode disables display in CI environments."""
+
+ class DummyStderr:
+ def isatty(self) -> bool:
+ return True
+
+ monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr())
+ monkeypatch.setenv("TERM", "xterm-256color")
+ monkeypatch.setenv("CI", "true")
+
+ assert rich_display.should_enable_live_display(enabled=True) is False
+
+
+def test_should_enable_live_display_disabled_by_flag(monkeypatch) -> None:
+ """Test explicit disable always returns False."""
+
+ class DummyStderr:
+ def isatty(self) -> bool:
+ return True
+
+ monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr())
+ monkeypatch.setenv("TERM", "xterm-256color")
+ monkeypatch.delenv("CI", raising=False)
+
+ assert rich_display.should_enable_live_display(enabled=False) is False
+
+
+def test_batcher_rich_display_computes_context_progress() -> None:
+ """Test context progress is derived from completed batch sizes."""
+ display = rich_display.BatcherRichDisplay(
+ console=Console(file=io.StringIO(), force_terminal=False),
+ )
+
+ processing_event: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 1.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+ "batch_id": "batch-1",
+ "request_count": 3,
+ "source": "poll_start",
+ }
+ terminal_event: rich_display.BatcherEvent = {
+ "event_type": "batch_terminal",
+ "timestamp": 2.0,
+ "provider": "openai",
+ "batch_id": "batch-1",
+ "status": "completed",
+ "source": "active_poll",
+ }
+ failed_batch_event: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 3.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+ "batch_id": "batch-2",
+ "request_count": 2,
+ "source": "poll_start",
+ }
+ failed_terminal_event: rich_display.BatcherEvent = {
+ "event_type": "batch_terminal",
+ "timestamp": 4.0,
+ "provider": "openai",
+ "batch_id": "batch-2",
+ "status": "failed",
+ "source": "active_poll",
+ }
+
+ display.on_event(processing_event)
+ display.on_event(terminal_event)
+ display.on_event(failed_batch_event)
+ display.on_event(failed_terminal_event)
+
+ completed_samples, total_samples, percent = display._compute_progress()
+ assert completed_samples == 3
+ assert total_samples == 5
+ assert percent == 60.0
+
+
+def test_batcher_rich_display_tracks_resumed_batch_progress() -> None:
+ """Test resumed cache-hit routing contributes to total and completion."""
+ display = rich_display.BatcherRichDisplay(
+ console=Console(file=io.StringIO(), force_terminal=False),
+ )
+
+ cache_event: rich_display.BatcherEvent = {
+ "event_type": "cache_hit_routed",
+ "timestamp": 1.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "batch_id": "batch-cached-1",
+ "source": "resumed_poll",
+ "custom_id": "custom-1",
+ }
+ terminal_event: rich_display.BatcherEvent = {
+ "event_type": "batch_terminal",
+ "timestamp": 2.0,
+ "provider": "openai",
+ "batch_id": "batch-cached-1",
+ "status": "completed",
+ "source": "resumed_poll",
+ }
+
+ display.on_event(cache_event)
+ display.on_event(cache_event)
+ display.on_event(terminal_event)
+
+ completed_samples, total_samples, percent = display._compute_progress()
+ assert completed_samples == 2
+ assert total_samples == 2
+ assert percent == 100.0
+
+
+def test_batcher_rich_display_elapsed_uses_first_batch_time(monkeypatch) -> None:
+ """Test elapsed timer starts from first batch seen in the context."""
+ current_time = {"value": 100.0}
+
+ def fake_time() -> float:
+ return current_time["value"]
+
+ monkeypatch.setattr(rich_display.time, "time", fake_time)
+
+ display = rich_display.BatcherRichDisplay(
+ console=Console(file=io.StringIO(), force_terminal=False),
+ )
+
+ first_batch_event: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 100.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+ "batch_id": "batch-1",
+ "request_count": 1,
+ "source": "poll_start",
+ }
+ display.on_event(first_batch_event)
+
+ current_time["value"] = 127.0
+ assert display._compute_elapsed_seconds() == 27
+ assert display._format_elapsed(elapsed_seconds=27) == "00:00:27"
+
+
+def test_batcher_rich_display_elapsed_starts_with_cache_batch(monkeypatch) -> None:
+ """Test elapsed timer also starts when the first batch comes from cache routing."""
+ current_time = {"value": 200.0}
+
+ def fake_time() -> float:
+ return current_time["value"]
+
+ monkeypatch.setattr(rich_display.time, "time", fake_time)
+
+ display = rich_display.BatcherRichDisplay(
+ console=Console(file=io.StringIO(), force_terminal=False),
+ )
+
+ cache_event: rich_display.BatcherEvent = {
+ "event_type": "cache_hit_routed",
+ "timestamp": 200.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "batch_id": "batch-cached-1",
+ "source": "resumed_poll",
+ "custom_id": "custom-1",
+ }
+ display.on_event(cache_event)
+
+ current_time["value"] = 206.0
+ assert display._compute_elapsed_seconds() == 6
+
+
+def test_batcher_rich_display_request_metrics_line() -> None:
+ """Test requests metrics aggregate total/cached/completed/in-progress samples."""
+ display = rich_display.BatcherRichDisplay(
+ console=Console(file=io.StringIO(), force_terminal=False),
+ )
+
+ processing_event_batch_1: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 1.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+ "batch_id": "batch-1",
+ "request_count": 3,
+ "source": "poll_start",
+ }
+ terminal_event_batch_1: rich_display.BatcherEvent = {
+ "event_type": "batch_terminal",
+ "timestamp": 2.0,
+ "provider": "openai",
+ "batch_id": "batch-1",
+ "status": "completed",
+ "source": "active_poll",
+ }
+ processing_event_batch_2: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 3.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+ "batch_id": "batch-2",
+ "request_count": 2,
+ "source": "poll_start",
+ }
+ cache_event_batch_3: rich_display.BatcherEvent = {
+ "event_type": "cache_hit_routed",
+ "timestamp": 4.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "batch_id": "batch-3",
+ "source": "resumed_poll",
+ "custom_id": "custom-1",
+ }
+ terminal_event_batch_3: rich_display.BatcherEvent = {
+ "event_type": "batch_terminal",
+ "timestamp": 5.0,
+ "provider": "openai",
+ "batch_id": "batch-3",
+ "status": "completed",
+ "source": "resumed_poll",
+ }
+
+ display.on_event(processing_event_batch_1)
+ display.on_event(terminal_event_batch_1)
+ display.on_event(processing_event_batch_2)
+ display.on_event(cache_event_batch_3)
+ display.on_event(cache_event_batch_3)
+ display.on_event(terminal_event_batch_3)
+
+ total_samples, cached_samples, completed_samples, in_progress_samples = (
+ display._compute_request_metrics()
+ )
+ assert total_samples == 7
+ assert cached_samples == 2
+ assert completed_samples == 5
+ assert in_progress_samples == 2
+
+
+def test_batcher_rich_display_queue_table_progress_column() -> None:
+ """Test queue table progress column derives from completed/total counts."""
+ display = rich_display.BatcherRichDisplay(
+ console=Console(file=io.StringIO(), force_terminal=False),
+ )
+
+ queue_event_batch_1: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 1.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+ "batch_id": "batch-1",
+ "request_count": 1,
+ "source": "poll_start",
+ }
+ queue_event_batch_2: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 2.0,
+ "provider": "openai",
+ "endpoint": "/v1/chat/completions",
+ "model": "model-a",
+ "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+ "batch_id": "batch-2",
+ "request_count": 1,
+ "source": "poll_start",
+ }
+ terminal_event_batch_2: rich_display.BatcherEvent = {
+ "event_type": "batch_terminal",
+ "timestamp": 3.0,
+ "provider": "openai",
+ "batch_id": "batch-2",
+ "status": "completed",
+ "source": "active_poll",
+ }
+ other_queue_event: rich_display.BatcherEvent = {
+ "event_type": "batch_processing",
+ "timestamp": 4.0,
+ "provider": "groq",
+ "endpoint": "/openai/v1/chat/completions",
+ "model": "llama-3.1-8b-instant",
+ "queue_key": ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant"),
+ "batch_id": "batch-3",
+ "request_count": 1,
+ "source": "poll_start",
+ }
+
+ display.on_event(queue_event_batch_1)
+ display.on_event(queue_event_batch_2)
+ display.on_event(terminal_event_batch_2)
+ display.on_event(other_queue_event)
+
+ queue_rows = display._compute_queue_batch_counts()
+ assert queue_rows == [
+ ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant", 1, 0),
+ ("openai", "/v1/chat/completions", "model-a", 1, 1),
+ ]
+
+ table = display._build_queue_summary_table()
+ assert table.columns[0].width == 12
+ assert table.columns[1].width == 34
+ assert table.columns[2].width == 28
+ assert table.columns[3].width == 16
+ assert table.columns[0]._cells == ["groq", "openai"]
+ progress_cells = table.columns[3]._cells
+ assert isinstance(progress_cells[0], Text)
+ assert isinstance(progress_cells[1], Text)
+ assert progress_cells[0].plain == "0/1 (0.0%)"
+ assert progress_cells[1].plain == "1/2 (50.0%)"
+
+
+def test_batcher_rich_display_queue_table_empty_state() -> None:
+ """Test queue table renders default row when no batches are tracked."""
+ display = rich_display.BatcherRichDisplay(
+ console=Console(file=io.StringIO(), force_terminal=False),
+ )
+
+ table = display._build_queue_summary_table()
+ assert table.columns[0]._cells == ["-"]
+ assert table.columns[1]._cells == ["-"]
+ assert table.columns[2]._cells == ["-"]
+ progress_cells = table.columns[3]._cells
+ assert isinstance(progress_cells[0], Text)
+ assert progress_cells[0].plain == "0/0 (0.0%)"
+
+
+def test_batcher_rich_display_queue_progress_pads_to_total_width() -> None:
+ """Test queue progress keeps parenthesis anchor stable for large totals."""
+ progress_text = rich_display.BatcherRichDisplay._format_queue_progress(
+ running=99,
+ completed=1,
+ )
+ assert progress_text.plain == " 1/100 (1.0%)"
diff --git a/uv.lock b/uv.lock
index 9e3a08a..cf0005b 100644
--- a/uv.lock
+++ b/uv.lock
@@ -267,6 +267,7 @@ source = { editable = "." }
dependencies = [
{ name = "aiohttp" },
{ name = "httpx" },
+ { name = "rich" },
{ name = "typer" },
]
@@ -304,6 +305,7 @@ dev = [
requires-dist = [
{ name = "aiohttp", specifier = ">=3.13.3" },
{ name = "httpx", specifier = ">=0.28.1" },
+ { name = "rich", specifier = ">=14.2.0" },
{ name = "typer", specifier = ">=0.20.0" },
]