Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 161 additions & 71 deletions services/core/files/src/nmp/core/files/app/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from __future__ import annotations

import logging
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass, field
from typing import (
AsyncIterator,
)
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import TypeVar

import aiohttp
from anyio import to_thread
from anyio import sleep, to_thread
from huggingface_hub import HfApi, get_hf_file_metadata, hf_hub_url
from huggingface_hub.utils import (
EntryNotFoundError,
Expand All @@ -39,6 +40,12 @@

logger = logging.getLogger(__name__)

_T = TypeVar("_T")

_HF_TRANSIENT_RETRY_ATTEMPTS = 4
_HF_TRANSIENT_RETRY_INITIAL_DELAY_SECONDS = 0.5
_HF_TRANSIENT_RETRY_MAX_DELAY_SECONDS = 5.0


class HuggingfaceBackendError(StorageBackendError):
"""Raised when there's issues talking to Huggingface."""
Expand Down Expand Up @@ -101,6 +108,90 @@ def raise_for_hf_status(
raise HuggingfaceBackendError(f"HTTP {status_code}{context}")


def _map_hf_http_error(exc: HfHubHTTPError) -> Exception:
    """Translate a ``HfHubHTTPError`` into one of this backend's exceptions.

    The translated exception is *returned*, not raised, so that the caller can
    inspect its type (e.g. to decide whether to retry) before raising it.

    Args:
        exc: The HTTP error raised by ``huggingface_hub``.

    Returns:
        The Huggingface* error raised by ``raise_for_hf_status`` for the failed
        response, or a generic ``HuggingfaceBackendError`` when the exception
        carries no response or ``raise_for_hf_status`` does not raise.
    """
    response = exc.response
    if response is None:
        # No HTTP response attached: nothing to classify on.
        return HuggingfaceBackendError(f"HuggingFace API error: {exc}")

    try:
        raise_for_hf_status(
            response.status_code,
            dict(response.headers),
            str(response.url),
        )
    except (
        HuggingfaceAccessError,
        HuggingfaceConfigError,
        HuggingfaceUnavailableError,
        HuggingfaceBackendError,
    ) as translated:
        # Capture the raised domain error instead of letting it propagate.
        return translated

    # raise_for_hf_status returned normally; fall back to the generic error.
    return HuggingfaceBackendError(f"HuggingFace API error: {exc}")


def _retry_after_seconds(headers: dict[str, str] | None) -> float | None:
if not headers:
return None

raw_value = headers.get("Retry-After") or headers.get("retry-after")
if not raw_value:
return None

try:
return max(0.0, float(raw_value))
except ValueError:
pass

try:
retry_at = parsedate_to_datetime(raw_value)
except (TypeError, ValueError):
return None
if retry_at.tzinfo is None:
retry_at = retry_at.replace(tzinfo=timezone.utc)
return max(0.0, (retry_at - datetime.now(timezone.utc)).total_seconds())


async def _sleep_before_retry(
    *,
    operation: str,
    attempt: int,
    headers: dict[str, str] | None,
    error: Exception,
) -> None:
    """Log a warning and sleep before the next retry of a HuggingFace call.

    The delay is the server's ``Retry-After`` value when one can be parsed
    from ``headers``; otherwise it is an exponential backoff starting at
    ``_HF_TRANSIENT_RETRY_INITIAL_DELAY_SECONDS`` and doubling per attempt.
    Either way it is capped at ``_HF_TRANSIENT_RETRY_MAX_DELAY_SECONDS``.

    Args:
        operation: Human-readable name of the operation, used in the log line.
        attempt: 1-based number of the attempt that just failed.
        headers: Headers of the failed response, if any.
        error: The error that triggered the retry, included in the log line.
    """
    server_hint = _retry_after_seconds(headers)
    if server_hint is None:
        # No usable Retry-After: double the initial delay per failed attempt.
        backoff = _HF_TRANSIENT_RETRY_INITIAL_DELAY_SECONDS * (2 ** (attempt - 1))
        wait_seconds = min(backoff, _HF_TRANSIENT_RETRY_MAX_DELAY_SECONDS)
    else:
        wait_seconds = min(server_hint, _HF_TRANSIENT_RETRY_MAX_DELAY_SECONDS)

    # The log reports the attempt about to start, hence attempt + 1.
    upcoming_attempt = attempt + 1
    logger.warning(
        "Transient HuggingFace error during %s; retrying attempt %s/%s after %.2fs: %s",
        operation,
        upcoming_attempt,
        _HF_TRANSIENT_RETRY_ATTEMPTS,
        wait_seconds,
        error,
    )
    await sleep(wait_seconds)


async def _run_hf_request(operation: str, request: Callable[[], _T]) -> _T:
    """Run a blocking HuggingFace call in a worker thread, retrying transient failures.

    The call is attempted up to ``_HF_TRANSIENT_RETRY_ATTEMPTS`` times. Only
    errors that map to ``HuggingfaceUnavailableError`` are retried; every other
    HTTP error is mapped and raised immediately.

    Args:
        operation: Human-readable name of the operation, used in logs and errors.
        request: Zero-argument callable performing the blocking request.

    Returns:
        Whatever ``request`` returns.

    Raises:
        EntryNotFoundError: Re-raised untouched so callers can handle it.
        Exception: The error produced by ``_map_hf_http_error`` for any other
            ``HfHubHTTPError``, chained to the original exception.
    """
    for attempt in range(1, _HF_TRANSIENT_RETRY_ATTEMPTS + 1):
        try:
            return await to_thread.run_sync(request)
        except EntryNotFoundError:
            # Handled before the HfHubHTTPError clause so it is never mapped or retried.
            raise
        except HfHubHTTPError as exc:
            mapped = _map_hf_http_error(exc)
            attempts_left = attempt < _HF_TRANSIENT_RETRY_ATTEMPTS
            if not (isinstance(mapped, HuggingfaceUnavailableError) and attempts_left):
                raise mapped from exc

            response = exc.response
            response_headers = dict(response.headers) if response is not None else None
            await _sleep_before_retry(
                operation=operation,
                attempt=attempt,
                headers=response_headers,
                error=mapped,
            )

    # Fallback if the loop ends without returning or raising.
    raise HuggingfaceBackendError(f"HuggingFace API error during {operation}")


@dataclass
class HuggingfaceStorageImpl(StorageImpl):
config: HuggingfaceStorageConfig
Expand All @@ -126,28 +217,20 @@ async def resolve_config(self) -> HuggingfaceStorageConfig:
Raises:
HuggingfaceConfigError: If the repository or revision is not found.
"""
try:
info = await to_thread.run_sync(
lambda: self._api.repo_info(
repo_id=self.config.repo_id,
repo_type=self.config.repo_type,
revision=self.config.revision,
)
)
return self.config.model_copy(
update={
"original_revision": self.config.revision,
"revision": info.sha,
}
)
except HfHubHTTPError as exc:
if exc.response is not None:
raise_for_hf_status(
exc.response.status_code,
dict(exc.response.headers),
str(exc.response.url),
)
raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc
info = await _run_hf_request(
"resolve repository revision",
lambda: self._api.repo_info(
repo_id=self.config.repo_id,
repo_type=self.config.repo_type,
revision=self.config.revision,
),
)
return self.config.model_copy(
update={
"original_revision": self.config.revision,
"revision": info.sha,
}
)

def _get_download_url(self, filepath: str) -> str:
"""Generate a download URL for a file in the Huggingface repo."""
Expand All @@ -162,14 +245,18 @@ def _get_download_url(self, filepath: str) -> str:
async def _get_hf_file_metadata(self, filepath: str):
"""Get file metadata from Huggingface for a specific file."""
url = self._get_download_url(filepath)
return await to_thread.run_sync(lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token")))
return await _run_hf_request(
"get file metadata",
lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token")),
)

async def list_files(self, path: str | None = None) -> list[FileInfo]:
"""List files in the Huggingface repository."""
try:
# list_repo_tree returns RepoFile and RepoFolder objects
# We filter for files only (items with size attribute)
items = await to_thread.run_sync(
items = await _run_hf_request(
"list repository tree",
lambda: list(
self._api.list_repo_tree(
repo_id=self.config.repo_id,
Expand All @@ -178,7 +265,7 @@ async def list_files(self, path: str | None = None) -> list[FileInfo]:
path_in_repo=path,
recursive=True,
)
)
),
)
except EntryNotFoundError:
# list_repo_tree expects a directory path. If path points to a file
Expand All @@ -192,14 +279,6 @@ async def list_files(self, path: str | None = None) -> list[FileInfo]:
# Neither a directory nor a file - return empty list
return []
return []
except HfHubHTTPError as exc:
if exc.response is not None:
raise_for_hf_status(
exc.response.status_code,
dict(exc.response.headers),
str(exc.response.url),
)
raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc

file_infos = []
for item in items:
Expand All @@ -225,17 +304,12 @@ async def get_file(self, path: str) -> FileInfo:
url = self._get_download_url(path)

try:
metadata = await to_thread.run_sync(lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token")))
metadata = await _run_hf_request(
"get file metadata",
lambda: get_hf_file_metadata(url=url, token=self.secrets.get("token")),
)
except EntryNotFoundError as exc:
raise NotFoundError(f"File '{path}' not found in {self.config.repo_id}@{self.config.revision}") from exc
except HfHubHTTPError as exc:
if exc.response is not None:
raise_for_hf_status(
exc.response.status_code,
dict(exc.response.headers),
str(exc.response.url),
)
raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc

return FileInfo(path=path, size=metadata.size)

Expand All @@ -254,21 +328,37 @@ async def download(self, path: str, byte_range: ByteRange | None) -> AsyncIterat
headers["Authorization"] = f"Bearer {self.secrets.get('token')}"

async def _download() -> AsyncIterator[bytes]:
session = get_http_session()
try:
async for chunk in download_url_streaming(
url=download_url,
session=session,
headers=headers if headers else None,
byte_range=byte_range,
chunk_size=self.config.read_chunk_size,
):
yield chunk
except aiohttp.ClientResponseError as exc:
response_headers = dict(exc.headers) if exc.headers else None
raise_for_hf_status(exc.status, response_headers, download_url)
except aiohttp.ClientError as exc:
raise HuggingfaceBackendError(f"Network error downloading file {path}") from exc
for attempt in range(1, _HF_TRANSIENT_RETRY_ATTEMPTS + 1):
session = get_http_session()
yielded = False
try:
async for chunk in download_url_streaming(
url=download_url,
session=session,
headers=headers if headers else None,
byte_range=byte_range,
chunk_size=self.config.read_chunk_size,
):
yielded = True
yield chunk
return
except aiohttp.ClientResponseError as exc:
response_headers = dict(exc.headers) if exc.headers else None
try:
raise_for_hf_status(exc.status, response_headers, download_url)
except HuggingfaceUnavailableError as mapped:
if yielded or attempt >= _HF_TRANSIENT_RETRY_ATTEMPTS:
raise mapped from exc
await _sleep_before_retry(
operation="download file",
attempt=attempt,
headers=response_headers,
error=mapped,
)
continue
raise
except aiohttp.ClientError as exc:
raise HuggingfaceBackendError(f"Network error downloading file {path}") from exc

return _download()

Expand All @@ -285,12 +375,13 @@ async def validate_storage(self):
"""
validate_external_host(self.config.endpoint)
try:
repo_info = await to_thread.run_sync(
repo_info = await _run_hf_request(
"validate repository",
lambda: self._api.repo_info(
repo_id=self.config.repo_id,
repo_type=self.config.repo_type,
revision=self.config.revision,
)
),
)

# Verify we can actually download files by checking a file's metadata.
Expand All @@ -299,14 +390,13 @@ async def validate_storage(self):
sibling = repo_info.siblings[0]
await self._get_hf_file_metadata(sibling.rfilename)

except HfHubHTTPError as exc:
if exc.response is not None:
raise_for_hf_status(
exc.response.status_code,
dict(exc.response.headers),
str(exc.response.url),
)
raise HuggingfaceBackendError(f"HuggingFace API error: {exc}") from exc
except (
HuggingfaceAccessError,
HuggingfaceConfigError,
HuggingfaceUnavailableError,
HuggingfaceBackendError,
):
raise
except Exception as exc:
raise HuggingfaceBackendError(
f"Failed to access Huggingface repository {self.config.repo_id}@{self.config.revision}"
Expand Down
Loading
Loading