From b2e345e7092a04ddc42eeabd2966136033a42c97 Mon Sep 17 00:00:00 2001 From: Matt Kornfield Date: Fri, 12 Jun 2026 21:08:45 +0000 Subject: [PATCH] chore: add hf retries to files --- .../core/files/app/backends/huggingface.py | 232 ++++++++++++------ .../files/tests/test_huggingface_backend.py | 54 +++- 2 files changed, 204 insertions(+), 82 deletions(-) diff --git a/services/core/files/src/nmp/core/files/app/backends/huggingface.py b/services/core/files/src/nmp/core/files/app/backends/huggingface.py index 2bf7356a5e..6006f9364d 100644 --- a/services/core/files/src/nmp/core/files/app/backends/huggingface.py +++ b/services/core/files/src/nmp/core/files/app/backends/huggingface.py @@ -6,13 +6,14 @@ from __future__ import annotations import logging +from collections.abc import AsyncIterator, Callable from dataclasses import dataclass, field -from typing import ( - AsyncIterator, -) +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime +from typing import TypeVar import aiohttp -from anyio import to_thread +from anyio import sleep, to_thread from huggingface_hub import HfApi, get_hf_file_metadata, hf_hub_url from huggingface_hub.utils import ( EntryNotFoundError, @@ -39,6 +40,12 @@ logger = logging.getLogger(__name__) +_T = TypeVar("_T") + +_HF_TRANSIENT_RETRY_ATTEMPTS = 4 +_HF_TRANSIENT_RETRY_INITIAL_DELAY_SECONDS = 0.5 +_HF_TRANSIENT_RETRY_MAX_DELAY_SECONDS = 5.0 + class HuggingfaceBackendError(StorageBackendError): """Raised when there's issues talking to Huggingface.""" @@ -101,6 +108,90 @@ def raise_for_hf_status( raise HuggingfaceBackendError(f"HTTP {status_code}{context}") +def _map_hf_http_error(exc: HfHubHTTPError) -> Exception: + if exc.response is not None: + try: + raise_for_hf_status( + exc.response.status_code, + dict(exc.response.headers), + str(exc.response.url), + ) + except ( + HuggingfaceAccessError, + HuggingfaceConfigError, + HuggingfaceUnavailableError, + HuggingfaceBackendError, + ) as mapped: + return mapped + return HuggingfaceBackendError(f"HuggingFace API error: {exc}") + + +def _retry_after_seconds(headers: dict[str, str] | None) -> float | None: + if not headers: + return None + + raw_value = headers.get("Retry-After") or headers.get("retry-after") + if not raw_value: + return None + + try: + return max(0.0, float(raw_value)) + except ValueError: + pass + + try: + retry_at = parsedate_to_datetime(raw_value) + except (TypeError, ValueError): + return None + if retry_at.tzinfo is None: + retry_at = retry_at.replace(tzinfo=timezone.utc) + return max(0.0, (retry_at - datetime.now(timezone.utc)).total_seconds()) + + +async def _sleep_before_retry( + *, + operation: str, + attempt: int, + headers: dict[str, str] | None, + error: Exception, +) -> None: + retry_after = _retry_after_seconds(headers) + delay = ( + min(retry_after, _HF_TRANSIENT_RETRY_MAX_DELAY_SECONDS) + if retry_after is not None + else min( + _HF_TRANSIENT_RETRY_INITIAL_DELAY_SECONDS * (2 ** (attempt - 1)), + _HF_TRANSIENT_RETRY_MAX_DELAY_SECONDS, + ) + ) + logger.warning( + "Transient HuggingFace error during %s; retrying attempt %s/%s after %.2fs: %s", + operation, + attempt + 1, + _HF_TRANSIENT_RETRY_ATTEMPTS, + delay, + error, + ) + await sleep(delay) + + +async def _run_hf_request(operation: str, request: Callable[[], _T]) -> _T: + for attempt in range(1, _HF_TRANSIENT_RETRY_ATTEMPTS + 1): + try: + return await to_thread.run_sync(request) + except EntryNotFoundError: + raise + except HfHubHTTPError as exc: + mapped = _map_hf_http_error(exc) + if isinstance(mapped, HuggingfaceUnavailableError) and attempt < _HF_TRANSIENT_RETRY_ATTEMPTS: + headers = dict(exc.response.headers) if exc.response is not None else None + await _sleep_before_retry(operation=operation, attempt=attempt, headers=headers, error=mapped) + continue + raise mapped from exc + + raise HuggingfaceBackendError(f"HuggingFace API error during {operation}") + + @dataclass class HuggingfaceStorageImpl(StorageImpl): config: HuggingfaceStorageConfig @@ -126,28 +217,20 @@ async def resolve_config(self) -> HuggingfaceStorageConfig: Raises: HuggingfaceConfigError: If the repository or revision is not found. """ - try: - info = await to_thread.run_sync( - lambda: self._api.repo_info( - repo_id=self.config.repo_id, - repo_type=self.config.repo_type, - revision=self.config.revision, - ) - ) - return self.config.model_copy( - update={ - "original_revision": self.config.revision, - "revision": info.sha, - } - ) - except HfHubHTTPError as exc: - if exc.response is not None: - raise_for_hf_status( - exc.response.status_code, - dict(exc.response.headers), - str(exc.response.url), - ) - raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc + info = await _run_hf_request( + "resolve repository revision", + lambda: self._api.repo_info( + repo_id=self.config.repo_id, + repo_type=self.config.repo_type, + revision=self.config.revision, + ), + ) + return self.config.model_copy( + update={ + "original_revision": self.config.revision, + "revision": info.sha, + } + ) def _get_download_url(self, filepath: str) -> str: """Generate a download URL for a file in the Huggingface repo.""" @@ -162,14 +245,18 @@ def _get_download_url(self, filepath: str) -> str: async def _get_hf_file_metadata(self, filepath: str): """Get file metadata from Huggingface for a specific file.""" url = self._get_download_url(filepath) - return await to_thread.run_sync(lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token"))) + return await _run_hf_request( + "get file metadata", + lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token")), + ) async def list_files(self, path: str | None = None) -> list[FileInfo]: """List files in the Huggingface repository.""" try: # list_repo_tree returns RepoFile and RepoFolder objects # We filter for files only (items with size attribute) - items = await to_thread.run_sync( + items = await _run_hf_request( + "list repository tree", lambda: list( self._api.list_repo_tree( repo_id=self.config.repo_id, @@ -178,7 +265,7 @@ async def list_files(self, path: str | None = None) -> list[FileInfo]: path_in_repo=path, recursive=True, ) - ) + ), ) except EntryNotFoundError: # list_repo_tree expects a directory path. If path points to a file @@ -192,14 +279,6 @@ async def list_files(self, path: str | None = None) -> list[FileInfo]: # Neither a directory nor a file - return empty list return [] return [] - except HfHubHTTPError as exc: - if exc.response is not None: - raise_for_hf_status( - exc.response.status_code, - dict(exc.response.headers), - str(exc.response.url), - ) - raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc file_infos = [] for item in items: @@ -225,17 +304,12 @@ async def get_file(self, path: str) -> FileInfo: url = self._get_download_url(path) try: - metadata = await to_thread.run_sync(lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token"))) + metadata = await _run_hf_request( + "get file metadata", + lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token")), + ) except EntryNotFoundError as exc: raise NotFoundError(f"File '{path}' not found in {self.config.repo_id}@{self.config.revision}") from exc - except HfHubHTTPError as exc: - if exc.response is not None: - raise_for_hf_status( - exc.response.status_code, - dict(exc.response.headers), - str(exc.response.url), - ) - raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc return FileInfo(path=path, size=metadata.size) @@ -254,21 +328,37 @@ async def download(self, path: str, byte_range: ByteRange | None) -> AsyncIterat headers["Authorization"] = f"Bearer {self.secrets.get('token')}" async def _download() -> AsyncIterator[bytes]: - session = get_http_session() - try: - async for chunk in download_url_streaming( - url=download_url, - session=session, - headers=headers if headers else None, - byte_range=byte_range, - chunk_size=self.config.read_chunk_size, - ): - yield chunk - except aiohttp.ClientResponseError as exc: - response_headers = dict(exc.headers) if exc.headers else None - raise_for_hf_status(exc.status, response_headers, download_url) - except aiohttp.ClientError as exc: - raise HuggingfaceBackendError(f"Network error downloading file {path}") from exc + for attempt in range(1, _HF_TRANSIENT_RETRY_ATTEMPTS + 1): + session = get_http_session() + yielded = False + try: + async for chunk in download_url_streaming( + url=download_url, + session=session, + headers=headers if headers else None, + byte_range=byte_range, + chunk_size=self.config.read_chunk_size, + ): + yielded = True + yield chunk + return + except aiohttp.ClientResponseError as exc: + response_headers = dict(exc.headers) if exc.headers else None + try: + raise_for_hf_status(exc.status, response_headers, download_url) + except HuggingfaceUnavailableError as mapped: + if yielded or attempt >= _HF_TRANSIENT_RETRY_ATTEMPTS: + raise mapped from exc + await _sleep_before_retry( + operation="download file", + attempt=attempt, + headers=response_headers, + error=mapped, + ) + continue + raise + except aiohttp.ClientError as exc: + raise HuggingfaceBackendError(f"Network error downloading file {path}") from exc return _download() @@ -285,12 +375,13 @@ async def validate_storage(self): """ validate_external_host(self.config.endpoint) try: - repo_info = await to_thread.run_sync( + repo_info = await _run_hf_request( + "validate repository", lambda: self._api.repo_info( repo_id=self.config.repo_id, repo_type=self.config.repo_type, revision=self.config.revision, - ) + ), ) # Verify we can actually download files by checking a file's metadata. @@ -299,14 +390,13 @@ async def validate_storage(self): sibling = repo_info.siblings[0] await self._get_hf_file_metadata(sibling.rfilename) - except HfHubHTTPError as exc: - if exc.response is not None: - raise_for_hf_status( - exc.response.status_code, - dict(exc.response.headers), - str(exc.response.url), - ) - raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc + except ( + HuggingfaceAccessError, + HuggingfaceConfigError, + HuggingfaceUnavailableError, + HuggingfaceBackendError, + ): + raise except Exception as exc: raise HuggingfaceBackendError( f"Failed to access Huggingface repository {self.config.repo_id}@{self.config.revision}" diff --git a/services/core/files/tests/test_huggingface_backend.py b/services/core/files/tests/test_huggingface_backend.py index 629e1135fb..7acf123ad2 100644 --- a/services/core/files/tests/test_huggingface_backend.py +++ b/services/core/files/tests/test_huggingface_backend.py @@ -3,7 +3,7 @@ """Tests for Huggingface storage backend.""" -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import aiohttp import httpx @@ -12,6 +12,7 @@ from nmp.core.files.app.backends.base import ByteRange from nmp.core.files.app.backends.factory import storage_impl_factory from nmp.core.files.app.backends.huggingface import ( + _HF_TRANSIENT_RETRY_ATTEMPTS, HuggingfaceAccessError, HuggingfaceBackendError, HuggingfaceConfigError, @@ -32,6 +33,17 @@ def _hf_http_error_without_response(message: str): return error +def _hf_http_error(status_code: int, message: str = "HuggingFace error", headers: dict[str, str] | None = None): + """Create a Hugging Face HTTP error with a mock response.""" + from huggingface_hub.utils import HfHubHTTPError + + response = Mock() + response.status_code = status_code + response.headers = headers or {} + response.url = "https://huggingface.co/test-org/test-repo" + return HfHubHTTPError(message, response=response) + + @pytest.fixture def mock_httpx_response(): """Create a mock httpx.Response for Huggingface exceptions.""" @@ -585,25 +597,22 @@ async def test_get_file_gated_repo_error(hf_config, mock_httpx_response, hf_secr async def test_get_file_rate_limit_error(hf_config, hf_secrets_empty): """Test get_file when rate limited by HuggingFace.""" - from huggingface_hub.utils import HfHubHTTPError - - # Create a mock response with 429 status code - mock_response = Mock() - mock_response.status_code = 429 - mock_response.headers = {} - mock_response.url = "https://huggingface.co/test-org/test-repo/test.txt" - - mock_error = HfHubHTTPError("Rate limited", response=mock_response) + mock_error = _hf_http_error(429, "Rate limited", headers={"Retry-After": "0"}) with patch("nmp.core.files.app.backends.huggingface.get_hf_file_metadata") as mock_metadata: mock_metadata.side_effect = mock_error impl = HuggingfaceStorageImpl(hf_config, hf_secrets_empty) - with pytest.raises(HuggingfaceUnavailableError) as exc_info: + with ( + patch("nmp.core.files.app.backends.huggingface.sleep", new_callable=AsyncMock) as mock_sleep, + pytest.raises(HuggingfaceUnavailableError) as exc_info, + ): await impl.get_file("test.txt") assert "Rate limited" in str(exc_info.value) + assert mock_metadata.call_count == _HF_TRANSIENT_RETRY_ATTEMPTS + assert mock_sleep.await_count == _HF_TRANSIENT_RETRY_ATTEMPTS - 1 async def test_validate_storage_gated_repo_error(hf_config, mock_hf_api, mock_httpx_response, hf_secrets_empty): @@ -750,6 +759,29 @@ async def test_resolve_config_tag_to_sha(mock_hf_api, hf_secrets_empty): assert resolved_config.revision == "def789abc123456" +async def test_resolve_config_retries_rate_limit_then_succeeds(mock_hf_api, hf_secrets_empty): + """Transient HuggingFace rate limits are retried during revision resolution.""" + mock_repo_info = Mock() + mock_repo_info.sha = "abc123def456789" + mock_hf_api.repo_info.side_effect = [ + _hf_http_error(429, "Rate limited", headers={"Retry-After": "0"}), + mock_repo_info, + ] + + impl = HuggingfaceStorageImpl( + HuggingfaceStorageConfig(repo_id="test-org/test-repo", repo_type="model", revision="main"), + hf_secrets_empty, + ) + + with patch("nmp.core.files.app.backends.huggingface.sleep", new_callable=AsyncMock) as mock_sleep: + resolved_config = await impl.resolve_config() + + assert resolved_config.original_revision == "main" + assert resolved_config.revision == "abc123def456789" + assert mock_hf_api.repo_info.call_count == 2 + mock_sleep.assert_awaited_once_with(0.0) + + async def test_resolve_config_repo_not_found(hf_config, mock_hf_api, mock_httpx_response, hf_secrets_empty): """Test resolve_config raises HuggingfaceConfigError when repo not found.""" from huggingface_hub.utils import HfHubHTTPError