diff --git a/.env.example b/.env.example index 24ad4f5..07058f2 100644 --- a/.env.example +++ b/.env.example @@ -26,6 +26,8 @@ SKILLSPECTOR_MODEL= # leave empty to # SKILLSPECTOR_MODEL_REGISTRY=./model_registry.yaml # optional override; defaults to each provider's bundled YAML in src/skillspector/providers/ SKILLSPECTOR_LOG_LEVEL=WARNING # options: DEBUG|INFO|WARNING|ERROR +SKILLSPECTOR_MAX_CONCURRENCY=5 # concurrent LLM requests; increase for faster processing (higher = more likely to hit provider rate limits) + # langchain/langsmith config (all optional) LANGCHAIN_TRACING_V2=false LANGCHAIN_API_KEY= diff --git a/docs/DEVELOPMENT.md b/docs/DEVELOPMENT.md index 0795f09..433f6a0 100644 --- a/docs/DEVELOPMENT.md +++ b/docs/DEVELOPMENT.md @@ -266,6 +266,7 @@ Copy [.env.example](../.env.example) to `.env` in the project root and set value | `OPENAI_BASE_URL` | Override the OpenAI endpoint (e.g. point at Ollama). | `http://localhost:11434/v1` | | `ANTHROPIC_API_KEY` | Credential for `SKILLSPECTOR_PROVIDER=anthropic`. | `sk-ant-...` | | `SKILLSPECTOR_MODEL` | Override the active provider's bundled default model (see [README.md](../README.md) for per-provider defaults). | `gpt-5.2` | +| `SKILLSPECTOR_MAX_CONCURRENCY` | Maximum concurrent LLM requests. Defaults to 5 if unset or not a valid integer; values below 1 are treated as 1. Lower values reduce the chance of hitting provider rate limits but increase scan time. 
| `7` | ### Constants, token budgets, and LLM diff --git a/src/skillspector/llm_analyzer_base.py b/src/skillspector/llm_analyzer_base.py index a678e4e..79a6db4 100644 --- a/src/skillspector/llm_analyzer_base.py +++ b/src/skillspector/llm_analyzer_base.py @@ -34,7 +34,12 @@ from pydantic import BaseModel, Field -from skillspector.llm_utils import get_chat_model +from skillspector.llm_utils import ( + _resolve_max_concurrency, + get_chat_model, + retry_llm_call, + retry_llm_call_sync, +) from skillspector.logging_config import get_logger from skillspector.model_info import get_max_input_tokens from skillspector.models import Finding @@ -353,9 +358,13 @@ def run_batches( len(batch.findings), ) if self._structured_llm: - response = self._structured_llm.invoke(prompt) + response = retry_llm_call_sync( + lambda prompt=prompt: self._structured_llm.invoke(prompt) + ) else: - response = self._llm.invoke(prompt).content + response = retry_llm_call_sync( + lambda prompt=prompt: self._llm.invoke(prompt) + ).content logger.debug("LLM response for %s", batch.file_label) parsed = self.parse_response(response, batch) results.append((batch, parsed)) @@ -365,7 +374,7 @@ async def arun_batches( self, batches: list[Batch], *, - max_concurrency: int = 10, + max_concurrency: int | None = None, **kwargs: object, ) -> list[tuple[Batch, list]]: """Execute LLM calls for all *batches* concurrently. @@ -376,6 +385,8 @@ async def arun_batches( The return type mirrors :meth:`run_batches`. 
""" + max_concurrency = _resolve_max_concurrency(max_concurrency) + sem = asyncio.Semaphore(max_concurrency) async def _process(batch: Batch) -> tuple[Batch, list]: @@ -388,9 +399,9 @@ async def _process(batch: Batch) -> tuple[Batch, list]: len(batch.findings), ) if self._structured_llm: - response = await self._structured_llm.ainvoke(prompt) + response = await retry_llm_call(lambda: self._structured_llm.ainvoke(prompt)) else: - response = (await self._llm.ainvoke(prompt)).content + response = (await retry_llm_call(lambda: self._llm.ainvoke(prompt))).content logger.debug("LLM response for %s", batch.file_label) return (batch, self.parse_response(response, batch)) diff --git a/src/skillspector/llm_utils.py b/src/skillspector/llm_utils.py index 1e03fc1..51625a9 100644 --- a/src/skillspector/llm_utils.py +++ b/src/skillspector/llm_utils.py @@ -29,7 +29,11 @@ from __future__ import annotations +import asyncio +import logging import os +import random +import time from langchain_openai import ChatOpenAI @@ -37,6 +41,8 @@ from skillspector.model_info import get_max_input_tokens, get_max_output_tokens from skillspector.providers import resolve_provider_credentials +logger = logging.getLogger(__name__) + def _resolve_llm_credentials() -> tuple[str, str | None]: """Return ``(api_key, base_url)`` resolved from the environment. 
# HTTP status codes treated as transient: rate limiting, server-side failures,
# and "overloaded" (529).
_RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504, 529})

# Lower-cased fragments of an exception's message that mark it as transient.
_RETRYABLE_MESSAGE_PARTS = (
    "timeout",
    "overloaded",
    "service unavailable",
    "bad gateway",
)

# Lower-cased fragments of an exception's class name that mark it as transient.
_RETRYABLE_NAME_PARTS = ("ratelimit", "timeout", "connection")

# Concurrency used when neither an explicit value nor a usable
# SKILLSPECTOR_MAX_CONCURRENCY environment value is available.
_DEFAULT_MAX_CONCURRENCY = 5


def _is_retryable_error(exc: BaseException) -> bool:
    """Return True when *exc* looks like a transient LLM/API failure.

    Shared by the sync and async retry helpers so both classify errors the
    same way. An error is considered retryable when any of these hold:

    * it carries an integer ``status_code`` attribute that is one of the
      retryable HTTP codes;
    * its class name contains ``ratelimit``, ``timeout`` or ``connection``;
    * its message contains a known transient phrase;
    * its message contains a retryable HTTP code as a standalone three-digit
      number.
    """
    import re  # local import: keeps this helper self-contained

    status = getattr(exc, "status_code", None)
    if isinstance(status, int) and status in _RETRYABLE_STATUS_CODES:
        return True

    name = type(exc).__name__.lower()
    if any(part in name for part in _RETRYABLE_NAME_PARTS):
        return True

    message = str(exc).lower()
    if any(part in message for part in _RETRYABLE_MESSAGE_PARTS):
        return True

    # Match status codes only as standalone three-digit numbers so that
    # unrelated figures (e.g. "15000 tokens") are not mistaken for "500".
    return any(
        int(code) in _RETRYABLE_STATUS_CODES
        for code in re.findall(r"(?<!\d)\d{3}(?!\d)", message)
    )


def _backoff_delay(attempt: int, base_delay: float, max_jitter: float) -> float:
    """Return the non-negative wait before retry number ``attempt + 1``.

    The delay is ``base_delay * 2**attempt`` plus uniform random jitter in
    ``[0, max_jitter]``.
    """
    return max(base_delay * 2**attempt + random.uniform(0, max_jitter), 0.0)


def retry_llm_call_sync(
    call_func,
    max_attempts: int = 4,
    *,
    base_delay: float = 1.0,
    max_jitter: float = 1.0,
):
    """Call *call_func*, retrying transient LLM errors with exponential backoff.

    Args:
        call_func: Zero-argument callable performing the LLM request. It is
            invoked again on every retry.
        max_attempts: Total number of attempts. Values below 1 are treated as
            1 so the call is always made at least once.
        base_delay: Seconds multiplied by ``2**attempt`` to form the backoff.
        max_jitter: Upper bound, in seconds, of the random jitter added to
            each wait.

    Returns:
        Whatever *call_func* returns on its first successful invocation.

    Raises:
        Exception: The original error, re-raised unchanged, when it is not
            classified as transient or when the final attempt fails.
    """
    attempts = max(max_attempts, 1)
    for attempt in range(attempts):
        try:
            return call_func()
        except Exception as exc:
            if attempt == attempts - 1 or not _is_retryable_error(exc):
                raise
            wait = _backoff_delay(attempt, base_delay, max_jitter)
            logging.getLogger(__name__).warning(
                "Transient LLM error (%s); retrying in %.1fs (attempt %d/%d)",
                type(exc).__name__,
                wait,
                attempt + 1,
                attempts,
            )
            time.sleep(wait)


async def retry_llm_call(
    coro_func,
    max_attempts: int = 4,
    *,
    base_delay: float = 1.0,
    max_jitter: float = 1.0,
):
    """Await *coro_func*, retrying transient LLM errors with exponential backoff.

    Async counterpart of :func:`retry_llm_call_sync`; waits with
    ``asyncio.sleep`` between attempts instead of ``time.sleep``.

    Args:
        coro_func: Zero-argument callable returning a fresh awaitable for the
            LLM request. It is called again on every retry.
        max_attempts: Total number of attempts. Values below 1 are treated as
            1 so the call is always made at least once.
        base_delay: Seconds multiplied by ``2**attempt`` to form the backoff.
        max_jitter: Upper bound, in seconds, of the random jitter added to
            each wait.

    Returns:
        The awaited result of the first successful attempt.

    Raises:
        Exception: The original error, re-raised unchanged, when it is not
            classified as transient or when the final attempt fails.
    """
    attempts = max(max_attempts, 1)
    for attempt in range(attempts):
        try:
            return await coro_func()
        except Exception as exc:
            if attempt == attempts - 1 or not _is_retryable_error(exc):
                raise
            wait = _backoff_delay(attempt, base_delay, max_jitter)
            logging.getLogger(__name__).warning(
                "Transient LLM error (%s); retrying in %.1fs (attempt %d/%d)",
                type(exc).__name__,
                wait,
                attempt + 1,
                attempts,
            )
            await asyncio.sleep(wait)


def _resolve_max_concurrency(max_concurrency: int | None = None) -> int:
    """Resolve max concurrency from an explicit value or the environment.

    An explicit *max_concurrency* wins. Otherwise the value is read from
    ``SKILLSPECTOR_MAX_CONCURRENCY``; an unset or blank variable yields the
    default of 5, and a non-integer value logs a warning and also yields 5.

    Returns:
        The resolved concurrency, never lower than 1.
    """
    if max_concurrency is not None:
        return max(max_concurrency, 1)

    raw = (os.environ.get("SKILLSPECTOR_MAX_CONCURRENCY") or "").strip()
    if not raw:
        return _DEFAULT_MAX_CONCURRENCY

    try:
        value = int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Invalid SKILLSPECTOR_MAX_CONCURRENCY=%r; defaulting to %d",
            raw,
            _DEFAULT_MAX_CONCURRENCY,
        )
        return _DEFAULT_MAX_CONCURRENCY

    # Zero or negative values would make the semaphore unusable; clamp to 1.
    return max(value, 1)
Batch(file_path="a.py", content="code") + + results = analyzer.run_batches([batch]) + assert len(results) == 1 + assert results[0][1][0].rule_id == "T-1" + assert calls == 2 + + @patch(MOCK_PATCH_TARGET, _mock_get_chat_model) + async def test_arun_batches_retries_on_429(self) -> None: + """Verify async batch processing retries when 429 is raised.""" + calls = 0 + original_response = LLMAnalysisResult( + findings=[LLMFinding(rule_id="T-1", message="hit", severity="LOW", start_line=1)] + ) + + async def flaky_ainvoke(prompt: str) -> LLMAnalysisResult: + nonlocal calls + calls += 1 + if calls == 1: + raise Exception("429 Too Many Requests") + return original_response + + analyzer = LLMAnalyzerBase(base_prompt="test", model=self.MODEL) + analyzer._structured_llm.ainvoke = flaky_ainvoke + batch = Batch(file_path="a.py", content="code") + + results = await analyzer.arun_batches([batch]) + assert len(results) == 1 + assert results[0][1][0].rule_id == "T-1" + assert calls == 2 + + @patch(MOCK_PATCH_TARGET, _mock_get_chat_model) + def test_run_batches_fails_after_max_retries(self) -> None: + """Verify sync processing fails after exhausting retries.""" + + def always_fails(prompt: str) -> LLMAnalysisResult: + raise Exception("429 Too Many Requests") + + analyzer = LLMAnalyzerBase(base_prompt="test", model=self.MODEL) + analyzer._structured_llm.invoke = always_fails + batch = Batch(file_path="a.py", content="code") + + with pytest.raises(Exception, match="429"): + analyzer.run_batches([batch]) + + @patch(MOCK_PATCH_TARGET, _mock_get_chat_model) + async def test_arun_batches_fails_after_max_retries(self) -> None: + """Verify async processing fails after exhausting retries.""" + + async def always_fails(prompt: str) -> LLMAnalysisResult: + raise Exception("429 Too Many Requests") + + analyzer = LLMAnalyzerBase(base_prompt="test", model=self.MODEL) + analyzer._structured_llm.ainvoke = always_fails + batch = Batch(file_path="a.py", content="code") + + with 
pytest.raises(Exception, match="429"): + await analyzer.arun_batches([batch]) + + @pytest.mark.parametrize( + ("env", "expected"), + [ + ({}, 5), + ({"SKILLSPECTOR_MAX_CONCURRENCY": ""}, 5), + ({"SKILLSPECTOR_MAX_CONCURRENCY": " "}, 5), + ({"SKILLSPECTOR_MAX_CONCURRENCY": "auto"}, 5), + ({"SKILLSPECTOR_MAX_CONCURRENCY": "0"}, 1), + ({"SKILLSPECTOR_MAX_CONCURRENCY": "-3"}, 1), + ({"SKILLSPECTOR_MAX_CONCURRENCY": "8"}, 8), + ], + ) + def test_max_concurrency_env_parsing( + self, + env: dict[str, str], + expected: int, + ) -> None: + """Verify SKILLSPECTOR_MAX_CONCURRENCY defaults and clamps safely.""" + with patch.dict(os.environ, env, clear=True): + assert _resolve_max_concurrency() == expected