diff --git a/docs/architecture/api.md b/docs/architecture/api.md index ca251fe..637636a 100644 --- a/docs/architecture/api.md +++ b/docs/architecture/api.md @@ -9,7 +9,7 @@ yields `None`. Import it from `batchling`. - Install HTTP hooks once (idempotent). - Construct a `Batcher` with configuration such as `batch_size`, `batch_window_seconds`, `batch_poll_interval_seconds`, `dry_run`, - and `cache`. + `cache`, and `live_display`. - Configure `batchling` logging defaults with Python's stdlib `logging` (`WARNING` by default). - Return a `BatchingContext` to scope batching to a context manager. @@ -26,6 +26,12 @@ yields `None`. Import it from `batchling`. - **`cache` behavior**: when `cache=True` (default), intercepted requests are fingerprinted and looked up in a persistent request cache. Cache hits bypass queueing and resume polling from an existing provider batch when not in dry-run mode. +- **`live_display` behavior**: `live_display` is a boolean. + When `True` (default), Rich panel rendering runs in auto mode and is enabled + only when `stderr` is a TTY, terminal is not `dumb`, and `CI` is not set. + If auto mode disables Rich, context-level progress is logged at `INFO` on + polling events. + When `False`, live display and fallback progress logs are both disabled. - **Outputs**: `BatchingContext[None]` instance that yields `None`. - **Logging**: lifecycle milestones are emitted at `INFO`, problems at `WARNING`/`ERROR`, and high-volume diagnostics at `DEBUG`. Request payloads @@ -43,7 +49,7 @@ Behavior: - CLI options map directly to `batchify` arguments: `batch_size`, `batch_window_seconds`, `batch_poll_interval_seconds`, `dry_run`, - and `cache`. + `cache`, and `live_display`. - Script target must use `module_path:function_name` syntax. 
- Forwarded callable arguments are mapped as: positional tokens are passed as positional arguments; diff --git a/docs/architecture/context.md b/docs/architecture/context.md index 78f67a9..8e67530 100644 --- a/docs/architecture/context.md +++ b/docs/architecture/context.md @@ -9,6 +9,7 @@ a context variable. - Activate the `active_batcher` context for the duration of a context block. - Yield `None` for scope-only lifecycle control. - Support sync and async context manager patterns for cleanup and context scoping. +- Start and stop optional Rich live activity display while the context is active. ## Flow summary @@ -16,7 +17,12 @@ a context variable. 2. `__enter__`/`__aenter__` set the active batcher for the entire context block. 3. `__exit__` resets the context and schedules `batcher.close()` if an event loop is running (otherwise it warns). -4. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work. +4. If `live_display=True`, the context attempts to start Rich panel rendering at + enter-time when terminal auto-detection passes (`TTY`, non-`dumb`, non-`CI`). + Otherwise it registers an `INFO` logging fallback that emits progress at poll-time. +5. `__aexit__` resets the context and awaits `batcher.close()` to flush pending work. +6. The live display listener is removed and the panel is stopped when context cleanup + finishes. ## Code reference diff --git a/docs/cli.md b/docs/cli.md index 692a1ec..9c05cda 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -64,6 +64,29 @@ batchling generate_product_images.py:main That's it! Just run that command and you save 50% off your workflow. +## Live visibility panel + +The CLI also exposes the live Rich panel control: + +```bash +batchling generate_product_images.py:main --live-display +``` + +`--live-display` is a boolean flag pair: + +- `--live-display` (default): auto mode, Rich panel only in interactive terminals + (`TTY`, non-`dumb`, non-`CI`). 
If Rich auto-disables, progress is emitted as + `INFO` logs on polling events. +- `--no-live-display`: disable both Rich panel and fallback progress logs. + +When enabled, the panel shows overall context progress: +`completed_samples / total_samples`, completion percentage, and `Time Elapsed` +since the first batch seen in the context. +It also shows request counters and a queue summary table with one row per +`(provider, endpoint, model)`, including `progress` as +`completed/total (percentage)` where `completed` is terminal batches and +`total` is `running + completed`. + ## Next Steps If you haven't yet, look at how you can: diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 4a00228..4208075 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -79,6 +79,30 @@ async def main(): That's it! Update three lines of code and you save 50% off your workflow. +## Live visibility panel + +You can toggle live visibility behavior while the context is active: + +```py +async with batchify(live_display=True): + generated_images = await asyncio.gather(*tasks) +``` + +`live_display` accepts a boolean: + +- `True` (default): auto mode, Rich panel only in interactive terminals + (`TTY`, non-`dumb`, non-`CI`). If Rich auto-disables, progress is emitted as + `INFO` logs on polling events. +- `False`: disable both Rich panel and fallback progress logs. + +When enabled, the panel shows context-level progress only: +`completed_samples / total_samples`, completion percentage, and `Time Elapsed` +since the first batch seen in the context. +It also shows request counters and a queue summary table with one row per +`(provider, endpoint, model)`, including `progress` as +`completed/total (percentage)` where `completed` is terminal batches and +`total` is `running + completed`. 
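The auto-mode rules and the per-queue `progress` value described above can be sketched in a few lines. This is only an illustration of the documented behavior, not the library's implementation: `resolve_display_mode`, `resolve_display_mode_from_env`, and `format_queue_progress` are hypothetical names, the one-decimal percentage format is an assumption, and treating an empty `CI` value as "not set" is also an assumption.

```python
from __future__ import annotations

import os
import sys


def resolve_display_mode(
    *,
    live_display: bool,
    stderr_is_tty: bool,
    term: str | None,
    ci: str | None,
) -> str:
    """Hypothetical helper: return "off", "rich", or "info_logs"."""
    # live_display=False disables both the panel and the fallback logs.
    if not live_display:
        return "off"
    # Documented auto mode: stderr is a TTY, TERM is not "dumb", CI is not set.
    # Assumption: an empty CI value counts as "not set".
    interactive = stderr_is_tty and term != "dumb" and not ci
    return "rich" if interactive else "info_logs"


def resolve_display_mode_from_env(*, live_display: bool = True) -> str:
    """Hypothetical wrapper that reads the real process environment."""
    return resolve_display_mode(
        live_display=live_display,
        stderr_is_tty=sys.stderr.isatty(),
        term=os.environ.get("TERM"),
        ci=os.environ.get("CI"),
    )


def format_queue_progress(*, running: int, completed: int) -> str:
    """Format one queue row as completed/total (percentage)."""
    # total is running + completed, where completed counts terminal batches.
    total = running + completed
    percent = (completed / total) * 100.0 if total else 0.0
    # Assumption: one decimal place; the real panel format may differ.
    return f"{completed}/{total} ({percent:.1f}%)"
```

Keeping the decision a pure function of its inputs makes the three outcomes easy to check without a real terminal; only the wrapper touches `sys.stderr` and `os.environ`.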
+ You can now run this script normally using python and start saving money: ```bash diff --git a/examples/racing.py b/examples/racing.py new file mode 100644 index 0000000..22ede9a --- /dev/null +++ b/examples/racing.py @@ -0,0 +1,278 @@ +import asyncio +import os +import time +import typing as t +from dataclasses import dataclass + +from dotenv import load_dotenv +from groq import AsyncGroq +from mistralai import Mistral +from openai import AsyncOpenAI +from together import AsyncTogether + +from batchling import batchify + +load_dotenv() + + +@dataclass +class ProviderRaceResult: + """One provider completion entry in completion order.""" + + model: str + elapsed_seconds: float + answer: str + + +ProviderRequestBuilder = t.Callable[[], t.Coroutine[t.Any, t.Any, tuple[str, str]]] + + +async def run_openai_request(*, prompt: str) -> tuple[str, str]: + """ + Send one OpenAI request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncOpenAI(api_key=os.getenv(key="OPENAI_API_KEY")) + response = await client.responses.create( + input=prompt, + model="gpt-4o-mini", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_groq_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Groq request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncGroq(api_key=os.getenv(key="GROQ_API_KEY")) + response = await client.chat.completions.create( + model="llama-3.1-8b-instant", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, response.choices[0].message.content + + +async def run_mistral_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Mistral request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. 
+ + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = Mistral(api_key=os.getenv(key="MISTRAL_API_KEY")) + response = await client.chat.complete_async( + model="mistral-medium-2505", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + stream=False, + response_format={"type": "text"}, + ) + return response.model, str(object=response.choices[0].message.content) + + +async def run_together_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Together request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncTogether(api_key=os.getenv(key="TOGETHER_API_KEY")) + response = await client.chat.completions.create( + model="google/gemma-3n-E4B-it", + messages=[ + { + "role": "user", + "content": prompt, + } + ], + ) + return response.model, response.choices[0].message.content + + +async def run_doubleword_request(*, prompt: str) -> tuple[str, str]: + """ + Send one Doubleword request. + + Parameters + ---------- + prompt : str + User prompt sent to the provider. + + Returns + ------- + tuple[str, str] + ``(model_name, answer_text)``. + """ + client = AsyncOpenAI( + api_key=os.getenv(key="DOUBLEWORD_API_KEY"), + base_url="https://api.doubleword.ai/v1", + ) + response = await client.responses.create( + input=prompt, + model="openai/gpt-oss-20b", + ) + content = response.output[-1].content + return response.model, content[0].text + + +async def run_provider_request( + *, + request_builder: ProviderRequestBuilder, + started_at: float, +) -> ProviderRaceResult: + """ + Execute one provider request and annotate elapsed time. + + Parameters + ---------- + request_builder : ProviderRequestBuilder + Provider request coroutine factory. + started_at : float + Shared wall-clock start time in ``perf_counter`` seconds. + + Returns + ------- + ProviderRaceResult + Result payload with answer and elapsed time. 
+ """ + model, answer = await request_builder() + elapsed_seconds = time.perf_counter() - started_at + return ProviderRaceResult( + model=model, + elapsed_seconds=elapsed_seconds, + answer=answer, + ) + + +def build_enabled_request_builders(*, prompt: str) -> list[ProviderRequestBuilder]: + """ + Build one request factory per configured provider. + + Parameters + ---------- + prompt : str + Shared text prompt sent to all providers. + + Returns + ------- + list[ProviderRequestBuilder] + Enabled provider request factories. + """ + providers: list[tuple[str, ProviderRequestBuilder]] = [ + ( + "OPENAI_API_KEY", + lambda: run_openai_request(prompt=prompt), + ), + ( + "GROQ_API_KEY", + lambda: run_groq_request(prompt=prompt), + ), + ( + "MISTRAL_API_KEY", + lambda: run_mistral_request(prompt=prompt), + ), + ( + "TOGETHER_API_KEY", + lambda: run_together_request(prompt=prompt), + ), + ( + "DOUBLEWORD_API_KEY", + lambda: run_doubleword_request(prompt=prompt), + ), + ] + enabled_builders: list[ProviderRequestBuilder] = [] + for env_var_name, request_builder in providers: + api_key = os.getenv(key=env_var_name) + if not api_key: + continue + enabled_builders.append(request_builder) + return enabled_builders + + +async def main() -> None: + """ + Run one request per provider and collect completion-order results. + + The race excludes Anthropic, Gemini, and XAI on purpose because their model field + extraction differs from the other provider examples. + """ + prompt = "Give one short sentence explaining what asynchronous batching is." + request_builders = build_enabled_request_builders(prompt=prompt) + if not request_builders: + print("No providers configured. 
Set at least one provider API key in your environment.") + return + + started_at = time.perf_counter() + tasks = [ + asyncio.create_task( + run_provider_request( + request_builder=request_builder, + started_at=started_at, + ) + ) + for request_builder in request_builders + ] + + completion_order_register: list[ProviderRaceResult] = [] + for task in asyncio.as_completed(tasks): + result = await task + completion_order_register.append(result) + + for index, result in enumerate(completion_order_register, start=1): + print(f"{index}. model={result.model}") + print(f" elapsed={result.elapsed_seconds:.2f}s") + print(f" answer={result.answer}\n") + + +async def run_with_batchify() -> None: + """Run the provider race inside ``batchify`` for direct script execution.""" + async with batchify(): + await main() + + +if __name__ == "__main__": + asyncio.run(run_with_batchify()) diff --git a/pyproject.toml b/pyproject.toml index ba8394a..b7c4e15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ classifiers = [ dependencies = [ "aiohttp>=3.13.3", "httpx>=0.28.1", + "rich>=14.2.0", "typer>=0.20.0", ] diff --git a/requirements.txt b/requirements.txt index d0e7873..99d8d5c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -645,6 +645,7 @@ requests-toolbelt==1.0.0 respx==0.22.0 rich==14.2.0 # via + # batchling # cyclopts # fastmcp # instructor diff --git a/src/batchling/api.py b/src/batchling/api.py index a00afb9..46a2f88 100644 --- a/src/batchling/api.py +++ b/src/batchling/api.py @@ -16,6 +16,7 @@ def batchify( batch_poll_interval_seconds: float = 10.0, dry_run: bool = False, cache: bool = True, + live_display: bool = True, ) -> BatchingContext: """ Context manager used to activate batching for a scoped context.
@@ -37,6 +38,11 @@ def batchify( cache : bool, optional If ``True``, enable persistent request cache lookups.
This parameter allows skipping the batch submission and going straight to the polling phase for requests that have already been sent. + live_display : bool, optional + Enable live display behavior while the context is active.
+ When ``True``, Rich panel rendering is attempted with terminal auto-detection. + If terminal auto-detection disables Rich (non-TTY, ``TERM=dumb``, or ``CI``), + progress is logged at ``INFO`` on polling events. Returns ------- @@ -65,4 +71,5 @@ def batchify( # 3. Return BatchingContext with no yielded target. return BatchingContext( batcher=batcher, + live_display=live_display, ) diff --git a/src/batchling/cli/main.py b/src/batchling/cli/main.py index d266302..6cba6dd 100644 --- a/src/batchling/cli/main.py +++ b/src/batchling/cli/main.py @@ -75,7 +75,8 @@ async def run_script_with_batchify( batch_poll_interval_seconds: float, dry_run: bool, cache: bool, -): + live_display: bool, +) -> None: """ Execute a Python script under a batchify context. @@ -97,6 +98,8 @@ async def run_script_with_batchify( Dry run mode passed to ``batchify``. cache : bool Cache mode passed to ``batchify``. + live_display : bool + Live display toggle passed to ``batchify``. """ if not module_path.exists(): typer.echo(f"Script not found: {module_path}") @@ -112,6 +115,7 @@ async def run_script_with_batchify( batch_poll_interval_seconds=batch_poll_interval_seconds, dry_run=dry_run, cache=cache, + live_display=live_display, ): ns = runpy.run_path(path_name=script_path_as_posix, run_name="batchling.runtime") func = ns.get(func_name) @@ -155,6 +159,16 @@ def main( bool, typer.Option("--cache/--no-cache", help="Enable persistent request caching"), ] = True, + live_display: t.Annotated[ + bool, + typer.Option( + "--live-display/--no-live-display", + help=( + "Enable auto live display. When disabled by terminal auto-detection, " + "fallback polling progress is logged at INFO." 
+ ), + ), + ] = True, ): """Run a script under ``batchify``.""" try: @@ -173,5 +187,6 @@ def main( batch_poll_interval_seconds=batch_poll_interval_seconds, dry_run=dry_run, cache=cache, + live_display=live_display, ) ) diff --git a/src/batchling/context.py b/src/batchling/context.py index 0916481..82ccf12 100644 --- a/src/batchling/context.py +++ b/src/batchling/context.py @@ -3,11 +3,55 @@ """ import asyncio +import logging import typing as t import warnings -from batchling.core import Batcher +from batchling.core import Batcher, BatcherEvent from batchling.hooks import active_batcher +from batchling.logging import log_info +from batchling.progress_state import BatchProgressState +from batchling.rich_display import ( + BatcherRichDisplay, + should_enable_live_display, +) + +log = logging.getLogger(name=__name__) + + +class _PollingProgressLogger: + """INFO logger fallback used when Rich live display auto-disables.""" + + def __init__(self) -> None: + self._progress_state = BatchProgressState() + + def on_event(self, event: BatcherEvent) -> None: + """ + Consume one lifecycle event and log progress on poll events. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + self._progress_state.on_event(event=event) + + event_type = str(object=event.get("event_type", "unknown")) + if event_type != "batch_polled": + return + + completed_samples, total_samples, percent = self._progress_state.compute_progress() + _, _, _, in_progress_samples = self._progress_state.compute_request_metrics() + log_info( + logger=log, + event="Live display fallback progress", + batch_id=event.get("batch_id"), + status=event.get("status"), + completed_samples=completed_samples, + total_samples=total_samples, + percent=f"{percent:.1f}", + in_progress_samples=in_progress_samples, + ) class BatchingContext: @@ -18,9 +62,16 @@ class BatchingContext: ---------- batcher : Batcher Batcher instance used for the scope of the context manager. 
+ live_display : bool, optional + Whether to enable auto live display behavior for the context. """ - def __init__(self, *, batcher: "Batcher") -> None: + def __init__( + self, + *, + batcher: "Batcher", + live_display: bool = True, + ) -> None: """ Initialize the context manager. @@ -28,10 +79,130 @@ def __init__(self, *, batcher: "Batcher") -> None: ---------- batcher : Batcher Batcher instance used for the scope of the context manager. + live_display : bool, optional + Whether to enable auto live display behavior for the context. """ self._self_batcher = batcher + self._self_live_display_enabled = live_display + self._self_live_display: BatcherRichDisplay | None = None + self._self_live_display_heartbeat_task: asyncio.Task[None] | None = None + self._self_polling_progress_logger: _PollingProgressLogger | None = None self._self_context_token: t.Any | None = None + def _start_polling_progress_logger(self) -> None: + """ + Start the INFO polling progress fallback listener. + + Notes + ----- + Fallback listener errors are downgraded to warnings. + """ + if self._self_polling_progress_logger is not None: + return + try: + listener = _PollingProgressLogger() + self._self_batcher._add_event_listener(listener=listener.on_event) + self._self_polling_progress_logger = listener + log_info( + logger=log, + event=( + "Live display disabled by terminal auto-detection; " + "using polling progress INFO logs" + ), + ) + except Exception as error: + warnings.warn( + message=f"Failed to start batchling polling progress logs: {error}", + category=UserWarning, + stacklevel=2, + ) + + def _start_live_display(self) -> None: + """ + Start the Rich live display when enabled. + + Notes + ----- + Display errors are downgraded to warnings to avoid breaking batching. 
+ """ + if self._self_live_display is not None or self._self_polling_progress_logger is not None: + return + if not self._self_live_display_enabled: + return + if not should_enable_live_display(enabled=self._self_live_display_enabled): + self._start_polling_progress_logger() + return + try: + display = BatcherRichDisplay() + self._self_batcher._add_event_listener(listener=display.on_event) + display.start() + self._self_live_display = display + self._start_live_display_heartbeat() + except Exception as error: + warnings.warn( + message=f"Failed to start batchling live display: {error}", + category=UserWarning, + stacklevel=2, + ) + + async def _run_live_display_heartbeat(self) -> None: + """ + Periodically refresh the live display while the context is active. + """ + try: + while self._self_live_display is not None: + self._self_live_display.refresh() + await asyncio.sleep(1.0) + except asyncio.CancelledError: + raise + + def _start_live_display_heartbeat(self) -> None: + """ + Start the 1-second live display heartbeat when an event loop exists. + """ + if self._self_live_display is None: + return + if self._self_live_display_heartbeat_task is not None: + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + self._self_live_display_heartbeat_task = loop.create_task( + coro=self._run_live_display_heartbeat() + ) + + def _stop_live_display(self) -> None: + """ + Stop and unregister the Rich live display. + + Notes + ----- + Display shutdown errors are downgraded to warnings. 
+ """ + display = self._self_live_display + fallback_listener = self._self_polling_progress_logger + if display is None and fallback_listener is None: + return + self._self_live_display = None + self._self_polling_progress_logger = None + heartbeat_task = self._self_live_display_heartbeat_task + self._self_live_display_heartbeat_task = None + if heartbeat_task is not None and not heartbeat_task.done(): + heartbeat_task.cancel() + try: + if display is not None: + self._self_batcher._remove_event_listener(listener=display.on_event) + display.stop() + if fallback_listener is not None: + self._self_batcher._remove_event_listener(listener=fallback_listener.on_event) + except Exception as error: + warnings.warn( + message=f"Failed to stop batchling live display: {error}", + category=UserWarning, + stacklevel=2, + ) + def __enter__(self) -> None: """ Enter the synchronous context manager and activate the batcher. @@ -42,6 +213,7 @@ def __enter__(self) -> None: ``None`` for scoped activation. """ self._self_context_token = active_batcher.set(self._self_batcher) + self._start_live_display() return None def __exit__( @@ -67,7 +239,8 @@ def __exit__( self._self_context_token = None try: loop = asyncio.get_running_loop() - loop.create_task(coro=self._self_batcher.close()) + close_task = loop.create_task(coro=self._self_batcher.close()) + close_task.add_done_callback(self._on_sync_close_done) except RuntimeError: warnings.warn( message=( @@ -78,6 +251,18 @@ def __exit__( category=UserWarning, stacklevel=2, ) + self._stop_live_display() + + def _on_sync_close_done(self, _: asyncio.Task[None]) -> None: + """ + Callback run when sync-context close task completes. + + Parameters + ---------- + _ : asyncio.Task[None] + Completed close task. + """ + self._stop_live_display() async def __aenter__(self) -> None: """ @@ -89,6 +274,7 @@ async def __aenter__(self) -> None: ``None`` for scoped activation. 
""" self._self_context_token = active_batcher.set(self._self_batcher) + self._start_live_display() return None async def __aexit__( @@ -112,4 +298,7 @@ async def __aexit__( if self._self_context_token is not None: active_batcher.reset(self._self_context_token) self._self_context_token = None - await self._self_batcher.close() + try: + await self._self_batcher.close() + finally: + self._stop_live_display() diff --git a/src/batchling/core.py b/src/batchling/core.py index 65dc7aa..f27c451 100644 --- a/src/batchling/core.py +++ b/src/batchling/core.py @@ -31,6 +31,59 @@ CACHE_RETENTION_SECONDS = 30 * 24 * 60 * 60 +class BatcherEvent(t.TypedDict, total=False): + """ + Lifecycle event emitted by ``Batcher`` for optional observers. + + event_type : str + Event identifier. + timestamp : float + Event timestamp in UNIX seconds. + provider : str + Provider name. + endpoint : str + Request endpoint. + model : str + Request model. + queue_key : QueueKey + Queue identifier. + batch_id : str + Provider batch identifier. + status : str + Provider batch status. + request_count : int + Number of requests in the event scope. + pending_count : int + Current queue pending size. + custom_id : str + Custom request identifier. + source : str + Event source subsystem. + error : str + Error text. + missing_count : int + Number of missing results. 
+ """ + + event_type: str + timestamp: float + provider: str + endpoint: str + model: str + queue_key: QueueKey + batch_id: str + status: str + request_count: int + pending_count: int + custom_id: str + source: str + error: str + missing_count: int + + +BatcherEventListener = t.Callable[[BatcherEvent], None] + + @dataclass class _PendingRequest: # FIXME: _PendingRequest can use a generic type to match any request from: @@ -135,6 +188,7 @@ def __init__( self._resumed_poll_tasks: set[asyncio.Task[None]] = set() self._resumed_batches: dict[ResumedBatchKey, _ResumedBatch] = {} self._resumed_lock = asyncio.Lock() + self._event_listeners: set[BatcherEventListener] = set() self._cache_store: RequestCacheStore | None = None if self._cache_enabled: @@ -163,6 +217,73 @@ def __init__( ), ) + def _add_event_listener( + self, + *, + listener: BatcherEventListener, + ) -> None: + """ + Register a listener receiving batch lifecycle events. + + Parameters + ---------- + listener : BatcherEventListener + Observer callback. + """ + self._event_listeners.add(listener) + + def _remove_event_listener( + self, + *, + listener: BatcherEventListener, + ) -> None: + """ + Unregister a previously registered lifecycle listener. + + Parameters + ---------- + listener : BatcherEventListener + Observer callback. + """ + self._event_listeners.discard(listener) + + def _emit_event( + self, + *, + event_type: str, + **payload: t.Any, + ) -> None: + """ + Emit a lifecycle event to all registered listeners. + + Parameters + ---------- + event_type : str + Event identifier. + **payload : typing.Any + Event payload. 
+ """ + if not self._event_listeners: + return + event = t.cast( + typ=BatcherEvent, + val={ + "event_type": event_type, + "timestamp": time.time(), + **payload, + }, + ) + for listener in list(self._event_listeners): + try: + listener(event) + except Exception as error: + log_debug( + logger=log, + event="Batcher event listener failed", + listener=repr(listener), + error=str(object=error), + ) + @staticmethod def _format_queue_key(*, queue_key: QueueKey) -> str: """ @@ -353,6 +474,16 @@ async def _try_submit_from_cache( batch_id=cache_entry.batch_id, custom_id=cache_entry.custom_id, ) + cache_source = "cache_dry_run" if self._dry_run else "resumed_poll" + self._emit_event( + event_type="cache_hit_routed", + provider=provider_name, + endpoint=endpoint, + model=model_name, + batch_id=cache_entry.batch_id, + custom_id=cache_entry.custom_id, + source=cache_source, + ) if self._dry_run: dry_run_request = _PendingRequest( custom_id=cache_entry.custom_id, @@ -401,6 +532,15 @@ async def _enqueue_pending_request( queue = self._pending_by_provider.setdefault(queue_key, []) queue.append(request) pending_count = len(queue) + self._emit_event( + event_type="request_queued", + provider=queue_key[0], + endpoint=queue_key[1], + model=queue_key[2], + queue_key=queue_key, + pending_count=pending_count, + custom_id=request.custom_id, + ) if pending_count == 1: self._window_tasks[queue_key] = asyncio.create_task( @@ -610,7 +750,6 @@ async def _attach_cached_request( future=future, ) ) - if should_start_poller: task = asyncio.create_task( coro=self._poll_cached_batch(resume_key=resume_key), @@ -686,6 +825,15 @@ async def _window_timer(self, *, queue_key: QueueKey) -> None: queue_key=queue_name, error=str(object=e), ) + self._emit_event( + event_type="window_timer_error", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + error=str(object=e), + source="window_timer", + ) await self._fail_pending_provider_requests( queue_key=queue_key, 
error=e, @@ -722,6 +870,14 @@ async def _submit_requests( queue_key=queue_name, request_count=len(requests), ) + self._emit_event( + event_type="batch_submitting", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + ) task = asyncio.create_task( coro=self._process_batch(queue_key=queue_key, requests=requests), name=f"batch_submit_{queue_name}_{uuid.uuid4()}", @@ -876,6 +1032,16 @@ async def _process_batch( try: if self._dry_run: dry_run_batch_id = f"dryrun-{uuid.uuid4()}" + self._emit_event( + event_type="batch_processing", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + batch_id=dry_run_batch_id, + source="dry_run", + ) active_batch = _ActiveBatch( batch_id=dry_run_batch_id, output_file_id="", @@ -900,6 +1066,17 @@ async def _process_batch( batch_id=dry_run_batch_id, request_count=len(requests), ) + self._emit_event( + event_type="batch_terminal", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + batch_id=dry_run_batch_id, + status="simulated", + source="dry_run", + ) return log_info( @@ -911,11 +1088,30 @@ async def _process_batch( queue_key=queue_name, request_count=len(requests), ) + self._emit_event( + event_type="batch_processing", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + source="submit", + ) batch_submission = await provider.process_batch( requests=requests, client_factory=self._client_factory, queue_key=queue_key, ) + self._emit_event( + event_type="batch_processing", + provider=provider.name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + batch_id=batch_submission.batch_id, + source="poll_start", + ) self._write_cache_entries( queue_key=queue_key, requests=requests, @@ -948,6 
+1144,15 @@ async def _process_batch( queue_key=queue_name, error=str(object=e), ) + self._emit_event( + event_type="batch_failed", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + error=str(object=e), + ) for req in requests: if not req.future.done(): req.future.set_exception(e) @@ -1056,6 +1261,13 @@ async def _poll_batch( api_headers=api_headers, batch_id=active_batch.batch_id, ) + self._emit_event( + event_type="batch_polled", + provider=provider.name, + batch_id=active_batch.batch_id, + status=poll_snapshot.status, + source="active_poll", + ) active_batch.output_file_id = poll_snapshot.output_file_id active_batch.error_file_id = poll_snapshot.error_file_id @@ -1067,6 +1279,13 @@ async def _poll_batch( batch_id=active_batch.batch_id, status=poll_snapshot.status, ) + self._emit_event( + event_type="batch_terminal", + provider=provider.name, + batch_id=active_batch.batch_id, + status=poll_snapshot.status, + source="active_poll", + ) await self._resolve_batch_results( base_url=base_url, api_headers=api_headers, @@ -1105,7 +1324,21 @@ async def _poll_cached_batch( api_headers=resumed_batch.api_headers, batch_id=batch_id, ) + self._emit_event( + event_type="batch_polled", + provider=provider.name, + batch_id=batch_id, + status=poll_snapshot.status, + source="resumed_poll", + ) if poll_snapshot.status in provider.batch_terminal_states: + self._emit_event( + event_type="batch_terminal", + provider=provider.name, + batch_id=batch_id, + status=poll_snapshot.status, + source="resumed_poll", + ) await self._resolve_cached_batch_results( resume_key=resume_key, output_file_id=poll_snapshot.output_file_id, @@ -1117,6 +1350,13 @@ async def _poll_cached_batch( except asyncio.CancelledError: raise except Exception as error: + self._emit_event( + event_type="batch_failed", + provider=provider.name, + batch_id=batch_id, + error=str(object=error), + source="resumed_poll", + ) await 
self._fail_resumed_batch_requests( resume_key=resume_key, error=error, @@ -1199,6 +1439,12 @@ async def _resolve_cached_batch_results( pending.future.set_result(resolved_response) if missing_hashes: + self._emit_event( + event_type="missing_results", + batch_id=batch_id, + missing_count=len(missing_hashes), + source="resumed_results", + ) _ = self._invalidate_cache_hashes(request_hashes=missing_hashes) async def _fail_resumed_batch_requests( @@ -1405,6 +1651,12 @@ def _fail_missing_results( batch_id=active_batch.batch_id, missing_count=len(missing), ) + self._emit_event( + event_type="missing_results", + batch_id=active_batch.batch_id, + missing_count=len(missing), + source="results", + ) error = RuntimeError(f"Missing results for {len(missing)} request(s)") for custom_id in missing: pending = active_batch.requests.get(custom_id) @@ -1447,6 +1699,15 @@ async def close(self) -> None: queue_key=queue_name, request_count=len(requests), ) + self._emit_event( + event_type="final_flush_submitting", + provider=provider_name, + endpoint=queue_endpoint, + model=model_name, + queue_key=queue_key, + request_count=len(requests), + source="close", + ) await self._submit_requests( queue_key=queue_key, requests=requests, diff --git a/src/batchling/progress_state.py b/src/batchling/progress_state.py new file mode 100644 index 0000000..4a9b0c6 --- /dev/null +++ b/src/batchling/progress_state.py @@ -0,0 +1,223 @@ +"""Shared progress-state tracking for live display and fallback logging.""" + +from __future__ import annotations + +import time +import typing as t +from dataclasses import dataclass + +from batchling.core import BatcherEvent + + +@dataclass +class _TrackedBatch: + """In-memory batch state used for aggregate progress computations.""" + + batch_id: str + provider: str = "-" + endpoint: str = "-" + model: str = "-" + size: int = 0 + completed: bool = False + terminal: bool = False + + +class BatchProgressState: + """ + Track batch lifecycle state and compute shared 
aggregate metrics. + + Parameters + ---------- + now_fn : typing.Callable[[], float] | None, optional + Clock function used for elapsed-time calculations. + """ + + def __init__( + self, + *, + now_fn: t.Callable[[], float] | None = None, + ) -> None: + self._now_fn = now_fn or time.time + self._batches: dict[str, _TrackedBatch] = {} + self._cached_samples = 0 + self._first_batch_created_at: float | None = None + + def on_event(self, *, event: BatcherEvent) -> None: + """ + Update tracked state from one lifecycle event. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + event_type = str(object=event.get("event_type", "unknown")) + source = str(object=event.get("source", "")) + batch_id = event.get("batch_id") + + if batch_id is None: + return + + batch = self._get_or_create_batch(batch_id=str(object=batch_id)) + self._update_batch_identity(batch=batch, event=event) + + if event_type == "batch_processing": + request_count = event.get("request_count") + if isinstance(request_count, int): + batch.size = max(batch.size, request_count) + batch.terminal = False + return + + if event_type == "batch_polled": + batch.terminal = False + return + + if event_type == "batch_terminal": + status = str(object=event.get("status", "completed")) + batch.completed = self._status_counts_as_completed(status=status) + batch.terminal = True + return + + if event_type == "batch_failed": + batch.completed = False + batch.terminal = True + return + + if event_type == "cache_hit_routed" and source == "resumed_poll": + batch.size += 1 + self._cached_samples += 1 + batch.terminal = False + + def compute_progress(self) -> tuple[int, int, float]: + """ + Compute aggregate sample progress from tracked batches. + + Returns + ------- + tuple[int, int, float] + ``(completed_samples, total_samples, percent)``. 
+ """ + total_samples = sum(batch.size for batch in self._batches.values()) + completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) + if total_samples <= 0: + return 0, 0, 0.0 + percent = (completed_samples / total_samples) * 100.0 + return completed_samples, total_samples, percent + + def compute_request_metrics(self) -> tuple[int, int, int, int]: + """ + Compute aggregate request counters from tracked batches. + + Returns + ------- + tuple[int, int, int, int] + ``(total_samples, cached_samples, completed_samples, in_progress_samples)``. + """ + total_samples = sum(batch.size for batch in self._batches.values()) + completed_samples = sum(batch.size for batch in self._batches.values() if batch.completed) + in_progress_samples = sum( + batch.size for batch in self._batches.values() if not batch.terminal + ) + return total_samples, self._cached_samples, completed_samples, in_progress_samples + + def compute_queue_batch_counts(self) -> list[tuple[str, str, str, int, int]]: + """ + Aggregate queue-level running and terminal batch counts. + + Returns + ------- + list[tuple[str, str, str, int, int]] + Sorted rows as ``(provider, endpoint, model, running, completed)``. + """ + counts_by_queue: dict[tuple[str, str, str], list[int]] = {} + for batch in self._batches.values(): + queue_key = (batch.provider, batch.endpoint, batch.model) + counters = counts_by_queue.setdefault(queue_key, [0, 0]) + if batch.terminal: + counters[1] += 1 + else: + counters[0] += 1 + + rows = [ + (provider, endpoint, model, counters[0], counters[1]) + for (provider, endpoint, model), counters in counts_by_queue.items() + ] + return sorted(rows, key=lambda row: (row[0], row[1], row[2])) + + def compute_elapsed_seconds(self) -> int: + """ + Compute elapsed seconds since first tracked batch in this context. + + Returns + ------- + int + Elapsed seconds. 
+ """ + if self._first_batch_created_at is None: + return 0 + return max(0, int(self._now_fn() - self._first_batch_created_at)) + + def _get_or_create_batch(self, *, batch_id: str) -> _TrackedBatch: + """ + Get or create one tracked batch record. + + Parameters + ---------- + batch_id : str + Provider batch identifier. + + Returns + ------- + _TrackedBatch + Mutable tracked batch. + """ + batch = self._batches.get(batch_id) + if batch is None: + batch = _TrackedBatch(batch_id=batch_id) + self._batches[batch_id] = batch + if self._first_batch_created_at is None: + self._first_batch_created_at = self._now_fn() + return batch + + @staticmethod + def _update_batch_identity(*, batch: _TrackedBatch, event: BatcherEvent) -> None: + """ + Update batch metadata from lifecycle event payload. + + Parameters + ---------- + batch : _TrackedBatch + Mutable tracked batch. + event : BatcherEvent + Lifecycle event payload. + """ + provider = event.get("provider") + endpoint = event.get("endpoint") + model = event.get("model") + if provider is not None: + batch.provider = str(object=provider) + if endpoint is not None: + batch.endpoint = str(object=endpoint) + if model is not None: + batch.model = str(object=model) + + @staticmethod + def _status_counts_as_completed(*, status: str) -> bool: + """ + Determine whether a terminal status counts as completed samples. + + Parameters + ---------- + status : str + Terminal provider status. + + Returns + ------- + bool + ``True`` when terminal state should contribute to completed samples. 
+ """ + lowered_status = status.lower() + negative_markers = ("fail", "error", "cancel", "expired", "timeout") + if any(marker in lowered_status for marker in negative_markers): + return False + return True diff --git a/src/batchling/rich_display.py b/src/batchling/rich_display.py new file mode 100644 index 0000000..0060755 --- /dev/null +++ b/src/batchling/rich_display.py @@ -0,0 +1,330 @@ +"""Rich live display for batch lifecycle visibility.""" + +from __future__ import annotations + +import os +import sys +import time + +from rich.console import Console, Group +from rich.live import Live +from rich.panel import Panel +from rich.progress import BarColumn, Progress, TextColumn +from rich.table import Table +from rich.text import Text + +from batchling.core import BatcherEvent +from batchling.progress_state import BatchProgressState + + +class BatcherRichDisplay: + """ + Render context-level sample progress through a Rich ``Live`` panel. + + Progress is computed from tracked sent batches as: + ``sum(size of completed batches) / sum(size of all tracked batches)``. + + Parameters + ---------- + refresh_per_second : float, optional + Refresh rate for Rich live updates. + console : Console | None, optional + Rich console to render to. Defaults to ``Console(stderr=True)``. 
+ """ + + def __init__( + self, + *, + refresh_per_second: float = 1.0, + console: Console | None = None, + ) -> None: + self._console = console or Console(stderr=True) + self._refresh_per_second = refresh_per_second + self._progress_state = BatchProgressState(now_fn=time.time) + self._live: Live | None = None + + def start(self) -> None: + """Start the live panel if not already running.""" + if self._live is not None: + return + self._live = Live( + renderable=self._render(), + console=self._console, + refresh_per_second=self._refresh_per_second, + transient=False, + ) + self._live.start(refresh=True) + + def stop(self) -> None: + """Stop the live panel if running.""" + if self._live is None: + return + self._live.stop() + self._live = None + + def on_event(self, event: BatcherEvent) -> None: + """ + Consume one batch lifecycle event and refresh the panel. + + Parameters + ---------- + event : BatcherEvent + Lifecycle event emitted by ``Batcher``. + """ + self._progress_state.on_event(event=event) + self.refresh() + + def refresh(self) -> None: + """Force one live-panel refresh when running.""" + if self._live is None: + return + self._live.update(renderable=self._render(), refresh=True) + + def _compute_progress(self) -> tuple[int, int, float]: + """ + Compute aggregate context progress from tracked batches. + + Returns + ------- + tuple[int, int, float] + ``(completed_samples, total_samples, percent)``. + """ + return self._progress_state.compute_progress() + + def _compute_request_metrics(self) -> tuple[int, int, int, int]: + """ + Compute aggregate request counters shown under the progress bar. + + Returns + ------- + tuple[int, int, int, int] + ``(total_samples, cached_samples, completed_samples, in_progress_samples)``. + """ + return self._progress_state.compute_request_metrics() + + def _compute_elapsed_seconds(self) -> int: + """ + Compute elapsed seconds since first batch creation in this context. + + Returns + ------- + int + Elapsed seconds. 
+ """ + return self._progress_state.compute_elapsed_seconds() + + @staticmethod + def _format_elapsed(*, elapsed_seconds: int) -> str: + """ + Format elapsed seconds as ``HH:MM:SS``. + + Parameters + ---------- + elapsed_seconds : int + Elapsed seconds. + + Returns + ------- + str + Formatted duration. + """ + hours = elapsed_seconds // 3600 + minutes = (elapsed_seconds % 3600) // 60 + seconds = elapsed_seconds % 60 + return f"{hours:02d}:{minutes:02d}:{seconds:02d}" + + def _render(self) -> Panel: + """Build the current Rich panel renderable.""" + progress_bar = self._build_progress_bar() + requests_line = self._build_requests_line() + queue_summary_table = self._build_queue_summary_table() + return Panel( + renderable=Group(progress_bar, requests_line, queue_summary_table), + title="batchling context progress", + border_style="cyan", + ) + + def _build_progress_bar(self) -> Progress: + """Build aggregate context progress as a Rich progress bar.""" + completed_samples, total_samples, _ = self._compute_progress() + elapsed_seconds = self._compute_elapsed_seconds() + elapsed_label = self._format_elapsed(elapsed_seconds=elapsed_seconds) + sample_width = max(1, len(str(object=total_samples))) + + progress = Progress( + BarColumn(bar_width=None), + TextColumn( + text_format=( + f"[bold green]{{task.fields[completed_samples]:>{sample_width}}}[/bold green]/" + f"[bold cyan]{{task.fields[total_samples]:>{sample_width}}}[/bold cyan] " + "([bold green]{task.percentage:.1f}%[/bold green])" + ) + ), + TextColumn(text_format=f"Time Elapsed: [bold magenta]{elapsed_label}[/bold magenta]"), + expand=True, + ) + display_total = max(total_samples, 1) + _ = progress.add_task( + description="samples", + total=display_total, + completed=min(completed_samples, display_total), + completed_samples=completed_samples, + total_samples=total_samples, + ) + return progress + + def _build_requests_line(self) -> Text: + """ + Build one-line request metrics shown under the progress bar. 
+ + Returns + ------- + Text + Styled metrics line. + """ + total_samples, cached_samples, completed_samples, in_progress_samples = ( + self._compute_request_metrics() + ) + line = Text() + line.append(text="Requests", style="bold white") + line.append(text=": ", style="white") + line.append(text="Total", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=total_samples), style="bold cyan") + line.append(text=" - ", style="grey70") + line.append(text="Cached", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=cached_samples), style="bold magenta") + line.append(text=" - ", style="grey70") + line.append(text="Completed", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=completed_samples), style="bold green") + line.append(text=" - ", style="grey70") + line.append(text="In Progress", style="grey70") + line.append(text=": ", style="grey70") + line.append(text=str(object=in_progress_samples), style="bold yellow") + return line + + def _compute_queue_batch_counts(self) -> list[tuple[str, str, str, int, int]]: + """ + Aggregate queue-level running and terminal batch counts. + + Returns + ------- + list[tuple[str, str, str, int, int]] + Sorted rows as ``(provider, endpoint, model, running, completed)``. + """ + return self._progress_state.compute_queue_batch_counts() + + def _build_queue_summary_table(self) -> Table: + """ + Build queue-level table with per-queue progress summary. + + Returns + ------- + Table + Queue summary table. 
+ """ + queue_rows = self._compute_queue_batch_counts() + table = Table(expand=False) + table.add_column( + header="provider", + style="bold blue", + width=12, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="endpoint", + width=34, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="model", + style="bold magenta", + width=28, + no_wrap=True, + overflow="ellipsis", + ) + table.add_column( + header="progress", + justify="right", + width=16, + no_wrap=True, + overflow="ellipsis", + ) + + if not queue_rows: + table.add_row("-", "-", "-", self._format_queue_progress(running=0, completed=0)) + return table + + for provider, endpoint, model, running, completed in queue_rows: + table.add_row( + provider, + endpoint, + model, + self._format_queue_progress( + running=running, + completed=completed, + ), + ) + return table + + @staticmethod + def _format_queue_progress(*, running: int, completed: int) -> Text: + """ + Format one queue progress cell as ``completed/total (percent)``. + + Parameters + ---------- + running : int + Number of non-terminal batches. + completed : int + Number of terminal batches. + + Returns + ------- + Text + Formatted queue progress. + """ + total = running + completed + if total <= 0: + percent = 0.0 + else: + percent = (completed / total) * 100.0 + count_width = max(1, len(str(object=total))) + progress = Text() + progress.append(text=f"{completed:>{count_width}}", style="bold green") + progress.append(text="/", style="white") + progress.append(text=f"{total:>{count_width}}", style="bold cyan") + progress.append(text=" (", style="white") + progress.append(text=f"{percent:.1f}%", style="bold green") + progress.append(text=")", style="white") + return progress + + +def should_enable_live_display(*, enabled: bool) -> bool: + """ + Resolve if the Rich live panel should be enabled. + + Parameters + ---------- + enabled : bool + Requested live display toggle. 
+ + Returns + ------- + bool + ``True`` when the live panel should run. + """ + if not enabled: + return False + + stderr_stream = sys.stderr + is_tty = bool(getattr(stderr_stream, "isatty", lambda: False)()) + terminal_name = str(object=os.environ.get("TERM", "")).lower() + is_dumb_terminal = terminal_name in {"", "dumb"} + is_ci = bool(os.environ.get("CI")) + + return is_tty and not is_dumb_terminal and not is_ci diff --git a/tests/test_api.py b/tests/test_api.py index 03da464..446d2b6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -60,6 +60,24 @@ async def test_batchify_configures_cache_flag(reset_hooks, reset_context): assert wrapped._self_batcher._cache_enabled is False +@pytest.mark.asyncio +async def test_batchify_forwards_live_display_flag(reset_hooks, reset_context): + """Test that batchify forwards live display flag to BatchingContext.""" + wrapped = batchify( + live_display=False, + ) + + assert wrapped._self_live_display_enabled is False + + +@pytest.mark.asyncio +async def test_batchify_live_display_defaults_to_true(reset_hooks, reset_context): + """Test that live display defaults to enabled.""" + wrapped = batchify() + + assert wrapped._self_live_display_enabled is True + + @pytest.mark.asyncio async def test_batchify_returns_context_manager(reset_hooks, reset_context): """Test that batchify returns a BatchingContext.""" diff --git a/tests/test_cli_script_runner.py b/tests/test_cli_script_runner.py index 0472062..e7426d7 100644 --- a/tests/test_cli_script_runner.py +++ b/tests/test_cli_script_runner.py @@ -61,6 +61,7 @@ def fake_batchify(**kwargs): } assert captured_batchify_kwargs["dry_run"] is True assert captured_batchify_kwargs["cache"] is True + assert captured_batchify_kwargs["live_display"] is True def test_run_script_with_cache_option(tmp_path: Path, monkeypatch): @@ -92,6 +93,38 @@ def fake_batchify(**kwargs): assert result.exit_code == 0 assert captured_batchify_kwargs["cache"] is False + assert 
captured_batchify_kwargs["live_display"] is True + + +def test_run_script_with_no_live_display_option(tmp_path: Path, monkeypatch): + script_path = tmp_path / "script.py" + script_path.write_text( + "\n".join( + [ + "async def foo(*args, **kwargs):", + " return None", + ] + ) + + "\n" + ) + captured_batchify_kwargs: dict = {} + + def fake_batchify(**kwargs): + captured_batchify_kwargs.update(kwargs) + return DummyAsyncBatchifyContext() + + monkeypatch.setattr(cli_main, "batchify", fake_batchify) + + result = runner.invoke( + app, + [ + f"{script_path.as_posix()}:foo", + "--no-live-display", + ], + ) + + assert result.exit_code == 0 + assert captured_batchify_kwargs["live_display"] is False def test_batch_size_flag_scope_for_cli_and_target_function(tmp_path: Path, monkeypatch): diff --git a/tests/test_context.py b/tests/test_context.py index e967a57..72c6eb9 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -3,6 +3,7 @@ """ import asyncio +import logging import typing as t import warnings from unittest.mock import AsyncMock, patch @@ -96,3 +97,122 @@ async def test_batching_context_without_target(batcher: Batcher, reset_context: async with context as active_target: assert active_batcher.get() is batcher assert active_target is None + + +@pytest.mark.asyncio +async def test_batching_context_starts_and_stops_live_display( + batcher: Batcher, + reset_context: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that async context starts and stops live display listeners.""" + + class DummyDisplay: + """Simple display stub.""" + + def __init__(self) -> None: + self.started = False + self.stopped = False + + def start(self) -> None: + self.started = True + + def stop(self) -> None: + self.stopped = True + + def on_event(self, event: dict[str, t.Any]) -> None: + del event + + dummy_display = DummyDisplay() + monkeypatch.setattr("batchling.context.BatcherRichDisplay", lambda: dummy_display) + 
monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True) + context = BatchingContext( + batcher=batcher, + live_display=True, + ) + + with patch.object(target=batcher, attribute="close", new_callable=AsyncMock): + async with context: + assert dummy_display.started is True + assert context._self_live_display_heartbeat_task is not None + assert not context._self_live_display_heartbeat_task.done() + assert dummy_display.stopped is True + assert context._self_live_display_heartbeat_task is None + + +def test_batching_context_sync_stops_live_display_without_loop( + batcher: Batcher, + reset_context: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test sync context stops display when no event loop is running.""" + + class DummyDisplay: + """Simple display stub.""" + + def __init__(self) -> None: + self.stopped = False + + def start(self) -> None: + return None + + def stop(self) -> None: + self.stopped = True + + def on_event(self, event: dict[str, t.Any]) -> None: + del event + + dummy_display = DummyDisplay() + monkeypatch.setattr("batchling.context.BatcherRichDisplay", lambda: dummy_display) + monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: True) + + context = BatchingContext( + batcher=batcher, + live_display=True, + ) + + with warnings.catch_warnings(record=True): + warnings.simplefilter(action="always") + with context: + pass + + assert dummy_display.stopped is True + assert context._self_live_display_heartbeat_task is None + + +def test_batching_context_uses_polling_progress_fallback_when_auto_disabled( + batcher: Batcher, + caplog: pytest.LogCaptureFixture, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test live-display fallback logs progress at poll time when Rich is disabled.""" + monkeypatch.setattr("batchling.context.should_enable_live_display", lambda **_kwargs: False) + context = BatchingContext( + batcher=batcher, + live_display=True, + ) + + 
caplog.set_level(level=logging.INFO, logger="batchling.context") + + context._start_live_display() + assert context._self_live_display is None + assert context._self_polling_progress_logger is not None + assert any("using polling progress INFO logs" in record.message for record in caplog.records) + + batcher._emit_event( + event_type="batch_processing", + batch_id="batch-1", + request_count=4, + source="poll_start", + ) + batcher._emit_event( + event_type="batch_polled", + batch_id="batch-1", + status="in_progress", + source="active_poll", + ) + + assert any("Live display fallback progress" in record.message for record in caplog.records) + + context._stop_live_display() + assert context._self_polling_progress_logger is None diff --git a/tests/test_core.py b/tests/test_core.py index a15d5be..960ba6e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -10,7 +10,7 @@ import pytest from batchling.cache import CacheEntry -from batchling.core import Batcher, _ActiveBatch, _PendingRequest +from batchling.core import Batcher, BatcherEvent, _ActiveBatch, _PendingRequest from batchling.providers.anthropic import AnthropicProvider from batchling.providers.base import PollSnapshot, ProviderRequestSpec, ResumeContext from batchling.providers.gemini import GeminiProvider @@ -1948,3 +1948,142 @@ async def test_dry_run_cache_hit_is_read_only(mock_openai_api_transport: httpx.M assert dry_run_entry is not None assert dry_run_entry.created_at == original_created_at await dry_run_batcher.close() + + +@pytest.mark.asyncio +async def test_submit_emits_lifecycle_events(batcher: Batcher, provider: OpenAIProvider) -> None: + """Test submit emits queue, submit, poll, and terminal lifecycle events.""" + events: list[BatcherEvent] = [] + batcher._add_event_listener(listener=events.append) + + _ = await batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + headers={"Authorization": "Bearer token"}, + 
body=b'{"model":"model-a","messages":[]}', + ) + + event_types = {str(event["event_type"]) for event in events} + assert "request_queued" in event_types + assert "batch_submitting" in event_types + assert "batch_processing" in event_types + assert "batch_polled" in event_types + assert "batch_terminal" in event_types + + +@pytest.mark.asyncio +async def test_dry_run_emits_terminal_without_poll(provider: OpenAIProvider) -> None: + """Test dry-run emits terminal lifecycle without provider polling events.""" + dry_run_batcher = Batcher( + batch_size=1, + batch_window_seconds=1.0, + dry_run=True, + cache=False, + ) + events: list[BatcherEvent] = [] + dry_run_batcher._add_event_listener(listener=events.append) + + _ = await dry_run_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + + event_types = {str(event["event_type"]) for event in events} + assert "request_queued" in event_types + assert "batch_submitting" in event_types + assert "batch_terminal" in event_types + assert "batch_polled" not in event_types + + await dry_run_batcher.close() + + +@pytest.mark.asyncio +async def test_cache_hit_emits_cache_hit_routed( + mock_openai_api_transport: httpx.MockTransport, +) -> None: + """Test cache-hit routing emits dedicated lifecycle event.""" + writer_batcher = Batcher( + batch_size=2, + batch_window_seconds=0.1, + cache=True, + ) + writer_batcher._client_factory = lambda: httpx.AsyncClient(transport=mock_openai_api_transport) + writer_batcher._poll_interval_seconds = 0.01 + provider = OpenAIProvider() + + _ = await writer_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=provider, + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + await writer_batcher.close() + + 
dry_run_batcher = Batcher( + batch_size=2, + batch_window_seconds=0.1, + cache=True, + dry_run=True, + ) + events: list[BatcherEvent] = [] + dry_run_batcher._add_event_listener(listener=events.append) + + _ = await dry_run_batcher.submit( + client_type="httpx", + method="POST", + url="api.openai.com", + endpoint="/v1/chat/completions", + provider=OpenAIProvider(), + headers={"Authorization": "Bearer token"}, + body=b'{"model":"model-a","messages":[]}', + ) + + assert any( + event["event_type"] == "cache_hit_routed" and event.get("source") == "cache_dry_run" + for event in events + ) + await dry_run_batcher.close() + + +def test_fail_missing_results_emits_lifecycle_event() -> None: + """Test missing output mapping emits missing-results lifecycle event.""" + batcher = Batcher( + batch_size=2, + batch_window_seconds=1.0, + cache=False, + ) + loop = asyncio.new_event_loop() + future: asyncio.Future[t.Any] = loop.create_future() + request = _PendingRequest( + custom_id="custom-1", + queue_key=("openai", "/v1/chat/completions", "model-a"), + params={}, + provider=OpenAIProvider(), + future=future, + request_hash="hash-1", + ) + active_batch = _ActiveBatch( + batch_id="batch-1", + output_file_id="file-1", + error_file_id="", + requests={"custom-1": request}, + ) + + events: list[BatcherEvent] = [] + batcher._add_event_listener(listener=events.append) + + batcher._fail_missing_results(active_batch=active_batch, seen=set()) + + assert any(event["event_type"] == "missing_results" for event in events) + loop.close() diff --git a/tests/test_rich_display.py b/tests/test_rich_display.py new file mode 100644 index 0000000..2d37959 --- /dev/null +++ b/tests/test_rich_display.py @@ -0,0 +1,370 @@ +"""Tests for Rich live display helpers.""" + +import io + +from rich.console import Console +from rich.text import Text + +import batchling.rich_display as rich_display + + +def test_should_enable_live_display_auto_in_interactive_terminal( + monkeypatch, +) -> None: + """Test auto mode 
enables display in interactive terminals.""" + + class DummyStderr: + def isatty(self) -> bool: + return True + + monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr()) + monkeypatch.setenv("TERM", "xterm-256color") + monkeypatch.delenv("CI", raising=False) + + assert rich_display.should_enable_live_display(enabled=True) is True + + +def test_should_enable_live_display_auto_disabled_in_ci(monkeypatch) -> None: + """Test auto mode disables display in CI environments.""" + + class DummyStderr: + def isatty(self) -> bool: + return True + + monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr()) + monkeypatch.setenv("TERM", "xterm-256color") + monkeypatch.setenv("CI", "true") + + assert rich_display.should_enable_live_display(enabled=True) is False + + +def test_should_enable_live_display_disabled_by_flag(monkeypatch) -> None: + """Test explicit disable always returns False.""" + + class DummyStderr: + def isatty(self) -> bool: + return True + + monkeypatch.setattr(rich_display.sys, "stderr", DummyStderr()) + monkeypatch.setenv("TERM", "xterm-256color") + monkeypatch.delenv("CI", raising=False) + + assert rich_display.should_enable_live_display(enabled=False) is False + + +def test_batcher_rich_display_computes_context_progress() -> None: + """Test context progress is derived from completed batch sizes.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + processing_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 1.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-1", + "request_count": 3, + "source": "poll_start", + } + terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 2.0, + "provider": "openai", + "batch_id": "batch-1", + "status": "completed", + "source": "active_poll", + } + 
failed_batch_event: rich_display.BatcherEvent = { + "event_type": "batch_processing", + "timestamp": 3.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "queue_key": ("openai", "/v1/chat/completions", "model-a"), + "batch_id": "batch-2", + "request_count": 2, + "source": "poll_start", + } + failed_terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 4.0, + "provider": "openai", + "batch_id": "batch-2", + "status": "failed", + "source": "active_poll", + } + + display.on_event(processing_event) + display.on_event(terminal_event) + display.on_event(failed_batch_event) + display.on_event(failed_terminal_event) + + completed_samples, total_samples, percent = display._compute_progress() + assert completed_samples == 3 + assert total_samples == 5 + assert percent == 60.0 + + +def test_batcher_rich_display_tracks_resumed_batch_progress() -> None: + """Test resumed cache-hit routing contributes to total and completion.""" + display = rich_display.BatcherRichDisplay( + console=Console(file=io.StringIO(), force_terminal=False), + ) + + cache_event: rich_display.BatcherEvent = { + "event_type": "cache_hit_routed", + "timestamp": 1.0, + "provider": "openai", + "endpoint": "/v1/chat/completions", + "model": "model-a", + "batch_id": "batch-cached-1", + "source": "resumed_poll", + "custom_id": "custom-1", + } + terminal_event: rich_display.BatcherEvent = { + "event_type": "batch_terminal", + "timestamp": 2.0, + "provider": "openai", + "batch_id": "batch-cached-1", + "status": "completed", + "source": "resumed_poll", + } + + display.on_event(cache_event) + display.on_event(cache_event) + display.on_event(terminal_event) + + completed_samples, total_samples, percent = display._compute_progress() + assert completed_samples == 2 + assert total_samples == 2 + assert percent == 100.0 + + +def test_batcher_rich_display_elapsed_uses_first_batch_time(monkeypatch) -> None: + """Test elapsed timer starts from 
first batch seen in the context."""
+    current_time = {"value": 100.0}
+
+    def fake_time() -> float:
+        return current_time["value"]
+
+    monkeypatch.setattr(rich_display.time, "time", fake_time)
+
+    display = rich_display.BatcherRichDisplay(
+        console=Console(file=io.StringIO(), force_terminal=False),
+    )
+
+    first_batch_event: rich_display.BatcherEvent = {
+        "event_type": "batch_processing",
+        "timestamp": 100.0,
+        "provider": "openai",
+        "endpoint": "/v1/chat/completions",
+        "model": "model-a",
+        "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+        "batch_id": "batch-1",
+        "request_count": 1,
+        "source": "poll_start",
+    }
+    display.on_event(first_batch_event)
+
+    current_time["value"] = 127.0
+    assert display._compute_elapsed_seconds() == 27
+    assert display._format_elapsed(elapsed_seconds=27) == "00:00:27"
+
+
+def test_batcher_rich_display_elapsed_starts_with_cache_batch(monkeypatch) -> None:
+    """Test elapsed timer also starts when the first batch comes from cache routing."""
+    current_time = {"value": 200.0}
+
+    def fake_time() -> float:
+        return current_time["value"]
+
+    monkeypatch.setattr(rich_display.time, "time", fake_time)
+
+    display = rich_display.BatcherRichDisplay(
+        console=Console(file=io.StringIO(), force_terminal=False),
+    )
+
+    cache_event: rich_display.BatcherEvent = {
+        "event_type": "cache_hit_routed",
+        "timestamp": 200.0,
+        "provider": "openai",
+        "endpoint": "/v1/chat/completions",
+        "model": "model-a",
+        "batch_id": "batch-cached-1",
+        "source": "resumed_poll",
+        "custom_id": "custom-1",
+    }
+    display.on_event(cache_event)
+
+    current_time["value"] = 206.0
+    assert display._compute_elapsed_seconds() == 6
+
+
+def test_batcher_rich_display_request_metrics_line() -> None:
+    """Test requests metrics aggregate total/cached/completed/in-progress samples."""
+    display = rich_display.BatcherRichDisplay(
+        console=Console(file=io.StringIO(), force_terminal=False),
+    )
+
+    processing_event_batch_1: rich_display.BatcherEvent = {
+        "event_type": "batch_processing",
+        "timestamp": 1.0,
+        "provider": "openai",
+        "endpoint": "/v1/chat/completions",
+        "model": "model-a",
+        "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+        "batch_id": "batch-1",
+        "request_count": 3,
+        "source": "poll_start",
+    }
+    terminal_event_batch_1: rich_display.BatcherEvent = {
+        "event_type": "batch_terminal",
+        "timestamp": 2.0,
+        "provider": "openai",
+        "batch_id": "batch-1",
+        "status": "completed",
+        "source": "active_poll",
+    }
+    processing_event_batch_2: rich_display.BatcherEvent = {
+        "event_type": "batch_processing",
+        "timestamp": 3.0,
+        "provider": "openai",
+        "endpoint": "/v1/chat/completions",
+        "model": "model-a",
+        "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+        "batch_id": "batch-2",
+        "request_count": 2,
+        "source": "poll_start",
+    }
+    cache_event_batch_3: rich_display.BatcherEvent = {
+        "event_type": "cache_hit_routed",
+        "timestamp": 4.0,
+        "provider": "openai",
+        "endpoint": "/v1/chat/completions",
+        "model": "model-a",
+        "batch_id": "batch-3",
+        "source": "resumed_poll",
+        "custom_id": "custom-1",
+    }
+    terminal_event_batch_3: rich_display.BatcherEvent = {
+        "event_type": "batch_terminal",
+        "timestamp": 5.0,
+        "provider": "openai",
+        "batch_id": "batch-3",
+        "status": "completed",
+        "source": "resumed_poll",
+    }
+
+    display.on_event(processing_event_batch_1)
+    display.on_event(terminal_event_batch_1)
+    display.on_event(processing_event_batch_2)
+    display.on_event(cache_event_batch_3)
+    display.on_event(cache_event_batch_3)
+    display.on_event(terminal_event_batch_3)
+
+    total_samples, cached_samples, completed_samples, in_progress_samples = (
+        display._compute_request_metrics()
+    )
+    assert total_samples == 7
+    assert cached_samples == 2
+    assert completed_samples == 5
+    assert in_progress_samples == 2
+
+
+def test_batcher_rich_display_queue_table_progress_column() -> None:
+    """Test queue table progress column derives from completed/total counts."""
+    display = rich_display.BatcherRichDisplay(
+        console=Console(file=io.StringIO(), force_terminal=False),
+    )
+
+    queue_event_batch_1: rich_display.BatcherEvent = {
+        "event_type": "batch_processing",
+        "timestamp": 1.0,
+        "provider": "openai",
+        "endpoint": "/v1/chat/completions",
+        "model": "model-a",
+        "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+        "batch_id": "batch-1",
+        "request_count": 1,
+        "source": "poll_start",
+    }
+    queue_event_batch_2: rich_display.BatcherEvent = {
+        "event_type": "batch_processing",
+        "timestamp": 2.0,
+        "provider": "openai",
+        "endpoint": "/v1/chat/completions",
+        "model": "model-a",
+        "queue_key": ("openai", "/v1/chat/completions", "model-a"),
+        "batch_id": "batch-2",
+        "request_count": 1,
+        "source": "poll_start",
+    }
+    terminal_event_batch_2: rich_display.BatcherEvent = {
+        "event_type": "batch_terminal",
+        "timestamp": 3.0,
+        "provider": "openai",
+        "batch_id": "batch-2",
+        "status": "completed",
+        "source": "active_poll",
+    }
+    other_queue_event: rich_display.BatcherEvent = {
+        "event_type": "batch_processing",
+        "timestamp": 4.0,
+        "provider": "groq",
+        "endpoint": "/openai/v1/chat/completions",
+        "model": "llama-3.1-8b-instant",
+        "queue_key": ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant"),
+        "batch_id": "batch-3",
+        "request_count": 1,
+        "source": "poll_start",
+    }
+
+    display.on_event(queue_event_batch_1)
+    display.on_event(queue_event_batch_2)
+    display.on_event(terminal_event_batch_2)
+    display.on_event(other_queue_event)
+
+    queue_rows = display._compute_queue_batch_counts()
+    assert queue_rows == [
+        ("groq", "/openai/v1/chat/completions", "llama-3.1-8b-instant", 1, 0),
+        ("openai", "/v1/chat/completions", "model-a", 1, 1),
+    ]
+
+    table = display._build_queue_summary_table()
+    assert table.columns[0].width == 12
+    assert table.columns[1].width == 34
+    assert table.columns[2].width == 28
+    assert table.columns[3].width == 16
+    assert table.columns[0]._cells == ["groq", "openai"]
+    progress_cells = table.columns[3]._cells
+    assert isinstance(progress_cells[0], Text)
+    assert isinstance(progress_cells[1], Text)
+    assert progress_cells[0].plain == "0/1 (0.0%)"
+    assert progress_cells[1].plain == "1/2 (50.0%)"
+
+
+def test_batcher_rich_display_queue_table_empty_state() -> None:
+    """Test queue table renders default row when no batches are tracked."""
+    display = rich_display.BatcherRichDisplay(
+        console=Console(file=io.StringIO(), force_terminal=False),
+    )
+
+    table = display._build_queue_summary_table()
+    assert table.columns[0]._cells == ["-"]
+    assert table.columns[1]._cells == ["-"]
+    assert table.columns[2]._cells == ["-"]
+    progress_cells = table.columns[3]._cells
+    assert isinstance(progress_cells[0], Text)
+    assert progress_cells[0].plain == "0/0 (0.0%)"
+
+
+def test_batcher_rich_display_queue_progress_pads_to_total_width() -> None:
+    """Test queue progress keeps parenthesis anchor stable for large totals."""
+    progress_text = rich_display.BatcherRichDisplay._format_queue_progress(
+        running=99,
+        completed=1,
+    )
+    assert progress_text.plain == " 1/100 (1.0%)"
diff --git a/uv.lock b/uv.lock
index 9e3a08a..cf0005b 100644
--- a/uv.lock
+++ b/uv.lock
@@ -267,6 +267,7 @@ source = { editable = "." }
 dependencies = [
     { name = "aiohttp" },
     { name = "httpx" },
+    { name = "rich" },
     { name = "typer" },
 ]
 
@@ -304,6 +305,7 @@ dev = [
 requires-dist = [
     { name = "aiohttp", specifier = ">=3.13.3" },
     { name = "httpx", specifier = ">=0.28.1" },
+    { name = "rich", specifier = ">=14.2.0" },
     { name = "typer", specifier = ">=0.20.0" },
 ]