diff --git a/src/guidellm/backends/__init__.py b/src/guidellm/backends/__init__.py index 6ee4e82bc..0a966f8f4 100644 --- a/src/guidellm/backends/__init__.py +++ b/src/guidellm/backends/__init__.py @@ -9,8 +9,6 @@ handlers for processing streaming and non-streaming API responses. """ -from guidellm.extras.vllm import HAS_VLLM - from .backend import Backend, BackendArgs from .openai import ( AudioRequestHandler, @@ -20,13 +18,7 @@ OpenAIRequestHandlerFactory, TextCompletionsRequestHandler, ) - -# Conditionally import VLLM backend if available -if HAS_VLLM: - from .vllm_python import VLLMPythonBackend, VLLMResponseHandler -else: - VLLMPythonBackend = None # type: ignore[assignment, misc] - VLLMResponseHandler = None # type: ignore[assignment, misc] +from .vllm_python import VLLMPythonBackend, VLLMResponseHandler __all__ = [ "AudioRequestHandler", @@ -37,8 +29,6 @@ "OpenAIRequestHandler", "OpenAIRequestHandlerFactory", "TextCompletionsRequestHandler", + "VLLMPythonBackend", + "VLLMResponseHandler", ] - -# Conditionally add VLLM backend and handler to exports -if HAS_VLLM: - __all__.extend(["VLLMPythonBackend", "VLLMResponseHandler"]) diff --git a/src/guidellm/backends/vllm_python/vllm.py b/src/guidellm/backends/vllm_python/vllm.py index b6ce15f5a..d4134b886 100644 --- a/src/guidellm/backends/vllm_python/vllm.py +++ b/src/guidellm/backends/vllm_python/vllm.py @@ -22,13 +22,7 @@ from guidellm.backends.backend import Backend, BackendArgs from guidellm.backends.vllm_python.vllm_response import VLLMResponseHandler -from guidellm.extras.vllm import ( - HAS_VLLM, - AsyncEngineArgs, - AsyncLLMEngine, - RequestOutput, - SamplingParams, -) +from guidellm.extras import vllm from guidellm.logger import logger from guidellm.schemas import ( GenerationRequest, @@ -36,22 +30,7 @@ RequestInfo, StandardBaseModel, ) - -try: - from guidellm.extras.audio import _decode_audio - - HAS_AUDIO = True -except ImportError: - _decode_audio = None # type: ignore[assignment] - HAS_AUDIO = False - 
-try: - from guidellm.extras.vision import image_dict_to_pil - - HAS_VISION = True -except ImportError: - image_dict_to_pil = None # type: ignore[assignment] - HAS_VISION = False +from guidellm.utils import audio, vision # Sentinel for "chat template not yet resolved" cache. _CHAT_TEMPLATE_UNSET: object = object() @@ -137,14 +116,6 @@ class _ResolvedRequest(StandardBaseModel): ) -def _check_vllm_available() -> None: - """Check if vllm is available and raise helpful error if not.""" - if not HAS_VLLM: - raise ImportError( - "vllm is not installed. Install vllm to use the vllm python backend." - ) - - def _has_jinja2_markers(s: str) -> bool: """Return True if the string contains Jinja2 template syntax ({{, {%, or {#).""" return "{{" in s or "{%" in s or "{#" in s @@ -179,13 +150,12 @@ def __init__( """ Initialize VLLM Python backend with model and configuration. """ - _check_vllm_available() super().__init__(arguments) self._args = arguments # Runtime state self._in_process = False - self._engine: AsyncLLMEngine | None = None + self._engine: vllm.AsyncLLMEngine | None = None self._resolved_chat_template: str | None | object = _CHAT_TEMPLATE_UNSET @property @@ -214,8 +184,8 @@ async def process_startup(self): if self._in_process: raise RuntimeError("Backend already started up for process.") - engine_args = AsyncEngineArgs(**self._args.vllm_config) # type: ignore[misc] - self._engine = AsyncLLMEngine.from_engine_args(engine_args) # type: ignore[misc] + engine_args = vllm.AsyncEngineArgs(**self._args.vllm_config) + self._engine = vllm.AsyncLLMEngine.from_engine_args(engine_args) self._in_process = True async def process_shutdown(self): @@ -264,7 +234,7 @@ async def default_model(self) -> str: """ return self._args.model - def _validate_backend_initialized(self) -> AsyncLLMEngine: + def _validate_backend_initialized(self) -> vllm.AsyncLLMEngine: """ Validate that the backend is initialized and return the engine. 
@@ -304,14 +274,9 @@ def _build_multi_modal_data_from_columns( # noqa: C901, PLR0912 for item in image_items: if not item or not isinstance(item, dict): continue - if not HAS_VISION or image_dict_to_pil is None: - raise ImportError( - "Image column support requires guidellm[vision]. " - "Install with: pip install 'guidellm[vision]'" - ) # Convert raw image dicts into PIL Images as required by vLLM's vision # processor - pil_image = image_dict_to_pil(item) + pil_image = vision.image_dict_to_pil(item) if "image" not in multi_modal_data: multi_modal_data["image"] = pil_image else: @@ -334,15 +299,10 @@ def _build_multi_modal_data_from_columns( # noqa: C901, PLR0912 else: audio_bytes = first.get("audio") if isinstance(audio_bytes, bytes) and len(audio_bytes) > 0: - if not HAS_AUDIO or _decode_audio is None: - raise ImportError( - "Audio column support requires guidellm[audio]. " - "Install with: pip install 'guidellm[audio]'" - ) try: # Decode raw audio bytes into an array since vLLM audio models # expect either raw numpy arrays or specific tensor formats - audio_samples = _decode_audio(audio_bytes) + audio_samples = audio._decode_audio(audio_bytes) # noqa: SLF001 # torchcodec decodes audio on CPU, so .data is always # a CPU torch.Tensor. .cpu() is a no-op on CPU tensors. audio_array = audio_samples.data.cpu().numpy() @@ -675,7 +635,7 @@ def _update_token_timing( request_info.timings.last_token_iteration = iter_time request_info.timings.token_iterations += iterations - def _text_from_output(self, output: RequestOutput | None) -> str: + def _text_from_output(self, output: vllm.RequestOutput | None) -> str: """ Extract generated text from VLLM RequestOutput. 
@@ -688,7 +648,7 @@ def _text_from_output(self, output: RequestOutput | None) -> str: def _stream_usage_tokens( self, - output: RequestOutput, + output: vllm.RequestOutput, request_info: RequestInfo, ) -> tuple[int, int]: """ @@ -714,7 +674,7 @@ def _stream_usage_tokens( def _usage_from_output( self, - output: RequestOutput | None, + output: vllm.RequestOutput | None, *, request_info: RequestInfo | None = None, ) -> dict[str, int] | None: @@ -749,7 +709,7 @@ def _build_final_response( self, request: GenerationRequest, request_info: RequestInfo, - final_output: RequestOutput | None, + final_output: vllm.RequestOutput | None, stream: bool, text: str = "", ) -> tuple[GenerationResponse, RequestInfo] | None: @@ -776,7 +736,7 @@ def _build_final_response( def _create_sampling_params( self, max_tokens_override: int | None = None, - ) -> SamplingParams: + ) -> vllm.SamplingParams: """ Create VLLM SamplingParams. @@ -794,7 +754,7 @@ def _create_sampling_params( params["max_tokens"] = max_tokens_override params["ignore_eos"] = True - return SamplingParams(**params) # type: ignore[misc] + return vllm.SamplingParams(**params) def _raise_generation_error(self, exc: BaseException) -> None: """Re-raise generation failure with context. 
@@ -839,7 +799,7 @@ async def _run_generation( request_info: RequestInfo, stream: bool, generate_input: str | dict[str, Any], - sampling_params: SamplingParams, + sampling_params: vllm.SamplingParams, request_id: str, state: dict[str, Any], ) -> AsyncIterator[tuple[GenerationResponse, RequestInfo]]: diff --git a/src/guidellm/data/preprocessors/encoders.py b/src/guidellm/data/preprocessors/encoders.py index 57326367f..614ed090a 100644 --- a/src/guidellm/data/preprocessors/encoders.py +++ b/src/guidellm/data/preprocessors/encoders.py @@ -9,6 +9,8 @@ PreprocessorRegistry, ) from guidellm.data.schemas import DataPreprocessorArgs +from guidellm.utils import audio as guidellm_audio +from guidellm.utils import vision as guidellm_vision __all__ = ["MediaEncoder"] @@ -41,24 +43,6 @@ def __init__( ) -> None: self.config = config - @staticmethod - def encode_audio(*args, **kwargs): - from guidellm.extras.audio import encode_audio - - return encode_audio(*args, **kwargs) - - @staticmethod - def encode_image(*args, **kwargs): - from guidellm.extras.vision import encode_image - - return encode_image(*args, **kwargs) - - @staticmethod - def encode_video(*args, **kwargs): - from guidellm.extras.vision import encode_video - - return encode_video(*args, **kwargs) - def __call__(self, items: list[dict[str, list[Any]]]) -> list[dict[str, list[Any]]]: return [self.encode_turn(item) for item in items] @@ -70,7 +54,7 @@ def encode_turn(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]: continue encoded_audio.append( - self.encode_audio(audio, **self.config.audio_kwargs) + guidellm_audio.encode_audio(audio, **self.config.audio_kwargs) ) columns["audio_column"] = encoded_audio @@ -81,7 +65,7 @@ def encode_turn(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]: continue encoded_images.append( - self.encode_image(image, **self.config.image_kwargs) + guidellm_vision.encode_image(image, **self.config.image_kwargs) ) columns["image_column"] = encoded_images @@ -92,7 +76,7 
@@ def encode_turn(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]: continue encoded_videos.append( - self.encode_video(video, **self.config.video_kwargs) + guidellm_vision.encode_video(video, **self.config.video_kwargs) ) columns["video_column"] = encoded_videos diff --git a/src/guidellm/extras/__init__.py b/src/guidellm/extras/__init__.py index 80a9a3ea2..5a3855989 100644 --- a/src/guidellm/extras/__init__.py +++ b/src/guidellm/extras/__init__.py @@ -1,4 +1,41 @@ """ Code that depends on optional dependencies. -Each submodule should be deferred imported. + +All dependent code should import in one of two ways: + +1. import guidellm.extras +2. from guidellm.extras import submodule + +As most of the codebase eager imports, importing specific functions or classes may cause +ImportErrors if the optional dependencies are missing. Importing from the module or +submodule level ensures errors are deferred to calling point. + +CRITICAL: Import Pattern for Lazy-Loaded Dependencies +====================================================== + +When importing from extras modules, use module imports to preserve lazy loading: + +CORRECT: + import guidellm.extras.audio as libs + decoder = libs.AudioDecoder(...) + +WRONG: + from guidellm.extras.audio import AudioDecoder + decoder = AudioDecoder(...) + +The from-import bypasses lazy loading and fails immediately if dependencies are missing. +Module imports defer errors until attribute access, allowing graceful error messages. + +Architecture: utils.audio/vision contain implementations; extras.audio/vision export +only external library classes (torchcodec, PIL). Implementations use module imports. 
""" + +import guidellm.utils.lazy_loader as lazy + +submodules = ["vllm", "vision", "audio"] + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, + submodules=submodules, + lazy_submodules=True, # Only import submodules when accessed +) diff --git a/src/guidellm/extras/audio.py b/src/guidellm/extras/audio.py index fe05f2275..0d56d82f2 100644 --- a/src/guidellm/extras/audio.py +++ b/src/guidellm/extras/audio.py @@ -1,214 +1,13 @@ from __future__ import annotations -from pathlib import Path -from typing import Any, Literal - -import httpx -import numpy as np -import torch - -try: - from torchcodec import AudioSamples - from torchcodec.decoders import AudioDecoder - from torchcodec.encoders import AudioEncoder -except ImportError as e: - raise ImportError("Please install guidellm[audio] to use audio features") from e - -__all__ = [ - "encode_audio", - "is_url", -] - - -def is_url(text: Any) -> bool: - return isinstance(text, str) and text.startswith(("http://", "https://")) - - -def encode_audio( - audio: AudioDecoder - | bytes - | str - | Path - | np.ndarray - | torch.Tensor - | dict[str, Any], - sample_rate: int | None = None, - file_name: str = "audio.wav", - encode_sample_rate: int = 16000, - max_duration: float | None = None, - mono: bool = True, - audio_format: str = "mp3", - bitrate: str = "64k", -) -> dict[ - Literal[ - "type", - "audio", - "format", - "mimetype", - "audio_samples", - "audio_seconds", - "audio_bytes", - "file_name", - ], - str | int | float | bytes | None, -]: - """Decode audio (if necessary) and re-encode to specified format.""" - samples = _decode_audio(audio, sample_rate=sample_rate, max_duration=max_duration) - - bitrate_val = ( - int(bitrate.rstrip("k")) * 1000 if bitrate.endswith("k") else int(bitrate) - ) - format_val = audio_format.lower() - - encoded_audio = _encode_audio( - samples=samples, - resample_rate=encode_sample_rate, - bitrate=bitrate_val, - audio_format=format_val, - mono=mono, - ) - - return { - "type": 
"audio_file", - "audio": encoded_audio, - "file_name": get_file_name(audio) - if isinstance(audio, str | Path) - else file_name, - "format": audio_format, - "mimetype": f"audio/{format_val}", - "audio_samples": samples.sample_rate, - "audio_seconds": samples.duration_seconds, - "audio_bytes": len(encoded_audio), - } - - -def _decode_audio( # noqa: C901, PLR0912 - audio: AudioDecoder - | bytes - | str - | Path - | np.ndarray - | torch.Tensor - | dict[str, Any], - sample_rate: int | None = None, - max_duration: float | None = None, -) -> AudioSamples: - """Decode audio from various input types into AudioSamples.""" - # If input is a dict, unwrap it into a function call - if isinstance(audio, dict): - sample_rate = audio.get("sample_rate", audio.get("sampling_rate", sample_rate)) - if "data" not in audio and "url" not in audio: - raise ValueError( - f"Audio dict must contain either 'data' or 'url' keys, got {audio}" - ) - audio_data = audio["data"] if "data" in audio else audio.get("url") - if audio_data is None: - raise ValueError( - f"Audio dict must contain either 'data' or 'url' keys, got {audio}" - ) - return _decode_audio( - audio=audio_data, - sample_rate=sample_rate, - max_duration=max_duration, - ) - - # Convert numpy array to torch tensor and re-call - if isinstance(audio, np.ndarray): - return _decode_audio( - audio=torch.from_numpy(audio), - sample_rate=sample_rate, - max_duration=max_duration, - ) - - samples: AudioSamples - - data: torch.Tensor | bytes - # HF datasets return AudioDecoder for audio column - if isinstance(audio, AudioDecoder): - samples = audio.get_samples_played_in_range(stop_seconds=max_duration) - elif isinstance(audio, torch.Tensor): - # If float stream assume decoded audio - if torch.is_floating_point(audio): - if sample_rate is None: - raise ValueError("Sample rate must be set for decoded audio") - - full_duration = audio.shape[1] / sample_rate - # If max_duration is set, trim the audio to that duration - if max_duration is not None: 
- num_samples = int(max_duration * sample_rate) - duration = min(max_duration, full_duration) - data = audio[:, :num_samples] - else: - duration = full_duration - data = audio - - samples = AudioSamples( - data=data, - pts_seconds=0.0, - duration_seconds=duration, - sample_rate=sample_rate, - ) - # If bytes tensor assume encoded audio - elif audio.dtype == torch.uint8: - decoder = AudioDecoder( - source=audio, - sample_rate=sample_rate, - ) - samples = decoder.get_samples_played_in_range(stop_seconds=max_duration) - - else: - raise ValueError(f"Unsupported audio type: {type(audio)}") - - # If bytes, assume encoded audio - elif isinstance(audio, bytes): - decoder = AudioDecoder( - source=audio, - sample_rate=sample_rate, - ) - samples = decoder.get_samples_played_in_range(stop_seconds=max_duration) - - # If str or Path, assume file path or URL to encoded audio - elif isinstance(audio, str | Path): - if isinstance(audio, str) and is_url(audio): - response = httpx.get(audio) - response.raise_for_status() - data = response.content - else: - if not Path(audio).exists(): - raise ValueError(f"Audio file does not exist: {audio}") - data = Path(audio).read_bytes() - decoder = AudioDecoder( - source=data, - ) - samples = decoder.get_samples_played_in_range(stop_seconds=max_duration) - else: - raise ValueError(f"Unsupported audio type: {type(audio)}") - - return samples - - -def _encode_audio( - samples: AudioSamples, - resample_rate: int | None = None, - bitrate: int = 64000, - audio_format: str = "mp3", - mono: bool = True, -) -> bytes: - encoder = AudioEncoder( - samples=samples.data, - sample_rate=samples.sample_rate, - ) - - audio_tensor = encoder.to_tensor( - format=audio_format, - bit_rate=bitrate if audio_format == "mp3" else None, - num_channels=1 if mono else None, - sample_rate=resample_rate, - ) - - return audio_tensor.numpy().tobytes() - - -def get_file_name(path: Path | str) -> str: - """Get file name from path.""" - return Path(path).name +import 
guidellm.utils.lazy_loader as lazy + +__getattr__, __dir__, __all__ = lazy.attach_extras( + __name__, + attrs={ + "AudioSamples": lazy.ExtraAttr("torchcodec"), + "AudioDecoder": lazy.ExtraAttr("torchcodec.decoders"), + "AudioEncoder": lazy.ExtraAttr("torchcodec.encoders"), + }, + error_message="Please install guidellm[audio] to use audio features", +) diff --git a/src/guidellm/extras/audio.pyi b/src/guidellm/extras/audio.pyi new file mode 100644 index 000000000..d1e7e1ee7 --- /dev/null +++ b/src/guidellm/extras/audio.pyi @@ -0,0 +1,3 @@ +from torchcodec import AudioSamples as AudioSamples +from torchcodec.decoders import AudioDecoder as AudioDecoder +from torchcodec.encoders import AudioEncoder as AudioEncoder diff --git a/src/guidellm/extras/vision.py b/src/guidellm/extras/vision.py index d28cfa97e..3a28edb3d 100644 --- a/src/guidellm/extras/vision.py +++ b/src/guidellm/extras/vision.py @@ -1,279 +1,12 @@ from __future__ import annotations -import base64 -import io -from pathlib import Path -from typing import Any, Literal - -import httpx -import numpy as np - -try: - from PIL import Image as PILImage -except ImportError as e: - raise ImportError( - "Please install guidellm[vision] to use image/video features" - ) from e - -__all__ = [ - "encode_image", - "encode_video", - "get_file_format", - "image_dict_to_pil", - "is_url", - "resize_image", -] - - -def is_url(text: Any) -> bool: - return isinstance(text, str) and text.startswith(("http://", "https://")) - - -def encode_image( - image: bytes | str | Path | np.ndarray | PILImage.Image, - width: int | None = None, - height: int | None = None, - max_size: int | None = None, - max_width: int | None = None, - max_height: int | None = None, - encode_type: Literal["base64", "url"] | None = "base64", -) -> dict[Literal["type", "image", "image_pixels", "image_bytes"], str | int | None]: - """ - Input image types: - - bytes: raw image bytes, decoded with Pillow - - str: file path on disk, url, or already base64 encoded 
image string - - pathlib.Path: file path on disk - - np.ndarray: image array, decoded with Pillow - - PIL.Image.Image: Pillow image - - datasets.Image: HuggingFace datasets Image object - - max_size: maximum size of the longest edge of the image - max_width: maximum width of the image - max_height: maximum height of the image - - encode_type: None to return the supported format - (url for url, base64 string for others) - "base64" to return base64 encoded string (or download URL and encode) - "url" to return url (only if input is url, otherwise fails) - - Returns a str of either: - - image url - - "data:image/{type};base64, {data}" string - """ - if isinstance(image, str) and is_url(image): - if encode_type == "base64": - response = httpx.get(image) - response.raise_for_status() - return encode_image( - image=response.content, - max_size=max_size, - max_width=max_width, - max_height=max_height, - encode_type="base64", - ) - - if any([width, height, max_size, max_width, max_height]): - raise ValueError(f"Cannot resize image {image} when encode_type is 'url'") - - return { - "type": "image_url", - "image": image, - "image_pixels": None, - "image_bytes": None, - } - - decoded_image: PILImage.Image - - if isinstance(image, bytes): - decoded_image = PILImage.open(io.BytesIO(image)) - elif isinstance(image, str) and image.startswith("data:image/"): - _, encoded = image.split(",", 1) - image_data = base64.b64decode(encoded) - decoded_image = PILImage.open(io.BytesIO(image_data)) - elif isinstance(image, str | Path): - decoded_image = PILImage.open(image) - elif isinstance(image, np.ndarray): - decoded_image = PILImage.fromarray(image) - elif isinstance(image, PILImage.Image): - decoded_image = image - else: - raise ValueError(f"Unsupported image type: {type(image)} for {image}") - - output_image = resize_image( - decoded_image, - width=width, - height=height, - max_width=max_width, - max_height=max_height, - max_size=max_size, - ) - if output_image.mode != "RGB": - 
output_image = output_image.convert("RGB") - - buffer = io.BytesIO() - output_image.save(buffer, format="JPEG") - image_bytes = buffer.getvalue() - image_base64 = base64.b64encode(image_bytes).decode("utf-8") - - return { - "type": "image_base64", - "image": f"data:image/jpeg;base64,{image_base64}", - "image_pixels": output_image.width * output_image.height, - "image_bytes": len(image_bytes), - } - - -def resize_image( - image: PILImage.Image, - width: int | None = None, - height: int | None = None, - max_width: int | None = None, - max_height: int | None = None, - max_size: int | None = None, -) -> PILImage.Image: - if not isinstance(image, PILImage.Image): - raise ValueError(f"Unsupported image type: {type(image)}") - - if width is not None and height is not None: - return image.resize((width, height), PILImage.Resampling.BILINEAR) - - orig_w, orig_h = image.size - aspect = orig_w / orig_h - - if width is not None: - target_w = width - target_h = round(width / aspect) - elif height is not None: - target_h = height - target_w = round(height * aspect) - else: - target_w, target_h = orig_w, orig_h - - # Normalize max_size → max_width/max_height - if max_size is not None: - max_width = max_width or max_size - max_height = max_height or max_size - - # Apply max constraints (preserve aspect ratio) - if max_width or max_height: - scale_w = max_width / target_w if max_width else 1.0 - scale_h = max_height / target_h if max_height else 1.0 - scale = min(scale_w, scale_h, 1.0) # never upscale - target_w = round(target_w * scale) - target_h = round(target_h * scale) - - if (target_w, target_h) != (orig_w, orig_h): - image = image.resize((target_w, target_h), PILImage.Resampling.BILINEAR) - - return image - - -def image_dict_to_pil(item: dict[str, Any]) -> PILImage.Image: - """ - Decode an encoded image column item to a PIL Image for vLLM multi_modal_data. - - The item must have an "image" key with either a data URL (data:image/...;base64,...) - or an HTTP(S) URL. 
For data URLs the image is base64-decoded; for URLs the - image is fetched with httpx. - - :param item: Dict with "image" key (data URL or URL string) - :return: PIL Image in RGB if needed - :raises ValueError: If item has no "image" or unsupported format - """ - image_spec = item.get("image") - if not image_spec or not isinstance(image_spec, str): - raise ValueError( - "Encoded image item must have an 'image' key with a data URL or URL string." - ) - if image_spec.startswith("data:image/"): - _, encoded = image_spec.split(",", 1) - data = base64.b64decode(encoded) - decoded_image = PILImage.open(io.BytesIO(data)) - elif image_spec.startswith(("http://", "https://")): - response = httpx.get(image_spec) - response.raise_for_status() - decoded_image = PILImage.open(io.BytesIO(response.content)) - else: - raise ValueError( - "Encoded image 'image' value must be a data:image/... URL or " - f"http(s) URL, got: {image_spec[:80]!r}..." - ) - if decoded_image.mode != "RGB": - decoded_image = decoded_image.convert("RGB") # type: ignore[assignment] - # convert() returns Image; PILImage.open() may be ImageFile - return decoded_image - - -def encode_video( - video: bytes | str | Path, - encode_type: Literal["base64", "url"] | None = "base64", -) -> dict[ - Literal["type", "video", "video_frames", "video_seconds", "video_bytes"], - str | int | float | None, -]: - """ - Input video types: - - bytes: raw video bytes - - str: file path on disk, url, or already base64 encoded video string - - pathlib.Path: file path on disk - - datasets.Video: HuggingFace datasets Video object - - encode_type: None to return the supported format - (url for url, base64 string for others) - "base64" to return base64 encoded string (or download URL and encode) - "url" to return url (only if input is url, otherwise fails) - - Returns a str of either: - - video url - - "data:video/{type};base64, {data}" string - """ - if isinstance(video, str) and is_url(video): - if encode_type == "base64": - response 
= httpx.get(video) - response.raise_for_status() - return encode_video(video=response.content, encode_type="base64") - - return { - "type": "video_url", - "video": video, - "video_frames": None, - "video_seconds": None, - "video_bytes": None, - } - - if isinstance(video, str) and video.startswith("data:video/"): - data_str = video.split(",", 1)[1] - - return { - "type": "video_base64", - "video": video, - "video_frames": None, - "video_seconds": None, - "video_bytes": len(data_str) * 3 // 4, # base64 to bytes - } - - if isinstance(video, str | Path): - path = Path(video) - video_bytes = path.read_bytes() - video_format = get_file_format(path) - elif isinstance(video, bytes): - video_bytes = video - video_format = "unknown" - else: - raise ValueError(f"Unsupported video type: {type(video)} for {video}") - - video_base64 = base64.b64encode(video_bytes).decode("utf-8") - - return { - "type": "video_base64", - "video": f"data:video/{video_format};base64,{video_base64}", - "video_frames": None, - "video_seconds": None, - "video_bytes": len(video_bytes), - } - - -def get_file_format(path: Path | str) -> str: - """Get file format from path extension.""" - suffix = Path(path).suffix.lower() - return suffix[1:] if suffix.startswith(".") else "unknown" +import guidellm.utils.lazy_loader as lazy + +__getattr__, __dir__, __all__ = lazy.attach_extras( + __name__, + attrs={ + "PILImage": lazy.ExtraAttr("PIL", alias="Image"), + "Image": lazy.ExtraAttr("PIL.Image", alias="Image"), + }, + error_message="Please install guidellm[vision] to use image/video features", +) diff --git a/src/guidellm/extras/vision.pyi b/src/guidellm/extras/vision.pyi new file mode 100644 index 000000000..ec7c49e9f --- /dev/null +++ b/src/guidellm/extras/vision.pyi @@ -0,0 +1,4 @@ +from PIL import Image as _PILImage +from PIL.Image import Image as Image + +PILImage = _PILImage diff --git a/src/guidellm/extras/vllm.py b/src/guidellm/extras/vllm.py index a415e966f..85877721a 100644 --- 
a/src/guidellm/extras/vllm.py +++ b/src/guidellm/extras/vllm.py @@ -1,13 +1,11 @@ -try: - from vllm import SamplingParams - from vllm.engine.arg_utils import AsyncEngineArgs - from vllm.engine.async_llm_engine import AsyncLLMEngine - from vllm.outputs import RequestOutput +""" +vLLM wrapper with same interface as vLLM. +""" - HAS_VLLM = True -except ImportError: - AsyncLLMEngine = None # type: ignore[assignment, misc] - AsyncEngineArgs = None # type: ignore[assignment, misc] - SamplingParams = None # type: ignore[assignment, misc] - RequestOutput = None # type: ignore[assignment, misc] - HAS_VLLM = False +import guidellm.utils.lazy_loader as lazy + +__getattr__, __dir__, __all__ = lazy.attach_extras( + __name__, + package="vllm", + error_message="Please install vllm to use vLLM features", +) diff --git a/src/guidellm/extras/vllm.pyi b/src/guidellm/extras/vllm.pyi new file mode 100644 index 000000000..175bbee52 --- /dev/null +++ b/src/guidellm/extras/vllm.pyi @@ -0,0 +1,4 @@ +from vllm import AsyncEngineArgs as AsyncEngineArgs +from vllm import AsyncLLMEngine as AsyncLLMEngine +from vllm import RequestOutput as RequestOutput +from vllm import SamplingParams as SamplingParams diff --git a/src/guidellm/utils/audio.py b/src/guidellm/utils/audio.py new file mode 100644 index 000000000..e8fdbd2ce --- /dev/null +++ b/src/guidellm/utils/audio.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal + +import httpx +import numpy as np +import torch + +# CRITICAL: Use 'import ... 
as libs' pattern to preserve lazy loading +# This defers errors until attributes are actually accessed +import guidellm.extras.audio as libs + +__all__ = [ + "encode_audio", + "is_url", +] + + +def is_url(text: Any) -> bool: + return isinstance(text, str) and text.startswith(("http://", "https://")) + + +def encode_audio( + audio: libs.AudioDecoder + | bytes + | str + | Path + | np.ndarray + | torch.Tensor + | dict[str, Any], + sample_rate: int | None = None, + file_name: str = "audio.wav", + encode_sample_rate: int = 16000, + max_duration: float | None = None, + mono: bool = True, + audio_format: str = "mp3", + bitrate: str = "64k", +) -> dict[ + Literal[ + "type", + "audio", + "format", + "mimetype", + "audio_samples", + "audio_seconds", + "audio_bytes", + "file_name", + ], + str | int | float | bytes | None, +]: + """Decode audio (if necessary) and re-encode to specified format.""" + samples = _decode_audio(audio, sample_rate=sample_rate, max_duration=max_duration) + + bitrate_val = ( + int(bitrate.rstrip("k")) * 1000 if bitrate.endswith("k") else int(bitrate) + ) + format_val = audio_format.lower() + + encoded_audio = _encode_audio( + samples=samples, + resample_rate=encode_sample_rate, + bitrate=bitrate_val, + audio_format=format_val, + mono=mono, + ) + + return { + "type": "audio_file", + "audio": encoded_audio, + "file_name": get_file_name(audio) + if isinstance(audio, str | Path) + else file_name, + "format": audio_format, + "mimetype": f"audio/{format_val}", + "audio_samples": samples.sample_rate, + "audio_seconds": samples.duration_seconds, + "audio_bytes": len(encoded_audio), + } + + +def _decode_audio( # noqa: C901, PLR0912 + audio: libs.AudioDecoder + | bytes + | str + | Path + | np.ndarray + | torch.Tensor + | dict[str, Any], + sample_rate: int | None = None, + max_duration: float | None = None, +) -> libs.AudioSamples: + """Decode audio from various input types into AudioSamples.""" + # If input is a dict, unwrap it into a function call + if 
isinstance(audio, dict): + sample_rate = audio.get("sample_rate", audio.get("sampling_rate", sample_rate)) + if "data" not in audio and "url" not in audio: + raise ValueError( + f"Audio dict must contain either 'data' or 'url' keys, got {audio}" + ) + audio_data = audio["data"] if "data" in audio else audio.get("url") + if audio_data is None: + raise ValueError( + f"Audio dict must contain either 'data' or 'url' keys, got {audio}" + ) + return _decode_audio( + audio=audio_data, + sample_rate=sample_rate, + max_duration=max_duration, + ) + + # Convert numpy array to torch tensor and re-call + if isinstance(audio, np.ndarray): + return _decode_audio( + audio=torch.from_numpy(audio), + sample_rate=sample_rate, + max_duration=max_duration, + ) + + data: torch.Tensor | bytes + # HF datasets return AudioDecoder for audio column + if isinstance(audio, libs.AudioDecoder): + samples = audio.get_samples_played_in_range(stop_seconds=max_duration) + elif isinstance(audio, torch.Tensor): + # If float stream assume decoded audio + if torch.is_floating_point(audio): + if sample_rate is None: + raise ValueError("Sample rate must be set for decoded audio") + + full_duration = audio.shape[1] / sample_rate + # If max_duration is set, trim the audio to that duration + if max_duration is not None: + num_samples = int(max_duration * sample_rate) + duration = min(max_duration, full_duration) + data = audio[:, :num_samples] + else: + duration = full_duration + data = audio + + samples = libs.AudioSamples( + data=data, + pts_seconds=0.0, + duration_seconds=duration, + sample_rate=sample_rate, + ) + # If bytes tensor assume encoded audio + elif audio.dtype == torch.uint8: + decoder = libs.AudioDecoder( + source=audio, + sample_rate=sample_rate, + ) + samples = decoder.get_samples_played_in_range(stop_seconds=max_duration) + + else: + raise ValueError(f"Unsupported audio type: {type(audio)}") + + # If bytes, assume encoded audio + elif isinstance(audio, bytes): + decoder = libs.AudioDecoder( 
+ source=audio, + sample_rate=sample_rate, + ) + samples = decoder.get_samples_played_in_range(stop_seconds=max_duration) + + # If str or Path, assume file path or URL to encoded audio + elif isinstance(audio, str | Path): + if isinstance(audio, str) and is_url(audio): + response = httpx.get(audio) + response.raise_for_status() + data = response.content + else: + if not Path(audio).exists(): + raise ValueError(f"Audio file does not exist: {audio}") + data = Path(audio).read_bytes() + decoder = libs.AudioDecoder( + source=data, + ) + samples = decoder.get_samples_played_in_range(stop_seconds=max_duration) + else: + raise ValueError(f"Unsupported audio type: {type(audio)}") + + return samples + + +def _encode_audio( + samples: libs.AudioSamples, + resample_rate: int | None = None, + bitrate: int = 64000, + audio_format: str = "mp3", + mono: bool = True, +) -> bytes: + encoder = libs.AudioEncoder( + samples=samples.data, + sample_rate=samples.sample_rate, + ) + + audio_tensor = encoder.to_tensor( + format=audio_format, + bit_rate=bitrate if audio_format == "mp3" else None, + num_channels=1 if mono else None, + sample_rate=resample_rate, + ) + + return audio_tensor.numpy().tobytes() + + +def get_file_name(path: Path | str) -> str: + """Get file name from path.""" + return Path(path).name diff --git a/src/guidellm/utils/lazy_loader.py b/src/guidellm/utils/lazy_loader.py new file mode 100644 index 000000000..e5503e889 --- /dev/null +++ b/src/guidellm/utils/lazy_loader.py @@ -0,0 +1,503 @@ +# ruff: noqa: PGH004 +# ruff: noqa +""" +lazy_loader +=========== + +Makes it easy to load subpackages and functions on demand. + +File uses code adapted from code with the following license: + +BSD 3-Clause License + +Copyright (c) 2022--2023, Scientific Python project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. 
Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import ast +import importlib +import importlib.util +import os +import sys +import threading +import types +import typing +import warnings + +__version__ = "0.6rc0.dev0" +__all__ = ["ExtraAttr", "attach", "attach_extras", "attach_stub", "load"] + + +class ExtraAttr(typing.NamedTuple): + """Descriptor for a lazily imported attribute in :func:`attach_extras`. + + :param source: Dotted module path to import from. + :param alias: Attribute name inside *source*. When ``None`` (the + default), the dictionary key passed to ``attach_extras`` is used. 
+ """ + + source: str + alias: str | None = None + + +threadlock = threading.Lock() + + +def attach( + package_name, + submodules=None, + submod_attrs=None, + lazy_submodules=False, +): + """Attach lazily loaded submodules, functions, or other attributes. + + Typically, modules import submodules and attributes as follows:: + + import mysubmodule + import anothersubmodule + + from .foo import someattr + + The idea is to replace a package's `__getattr__`, `__dir__`, and + `__all__`, such that all imports work exactly the way they would + with normal imports, except that the import occurs upon first use. + + The typical way to call this function, replacing the above imports, is:: + + __getattr__, __dir__, __all__ = lazy.attach( + __name__, ["mysubmodule", "anothersubmodule"], {"foo": ["someattr"]} + ) + + Parameters + ---------- + package_name : str + Typically use ``__name__``. + submodules : set + List of submodules to attach. + submod_attrs : dict + Dictionary of submodule -> list of attributes / functions. + These attributes are imported as they are used. + lazy_submodules : bool + Whether to lazily load submodules. If set to `True`, submodules are + returned as lazy proxies. Note that attribute access from + submod_attrs will trigger the import of the submodule. 
+ + Returns + ------- + __getattr__, __dir__, __all__ + + """ + if submod_attrs is None: + submod_attrs = {} + + if submodules is None: + submodules = set() + else: + submodules = set(submodules) + + attr_to_modules = { + attr: mod for mod, attrs in submod_attrs.items() for attr in attrs + } + + __all__ = sorted(submodules | attr_to_modules.keys()) + + def __getattr__(name): + if name in submodules: + submod_path = f"{package_name}.{name}" + if lazy_submodules: + return load(submod_path, suppress_warning=True) + else: + return importlib.import_module(submod_path) + elif name in attr_to_modules: + submod_path = f"{package_name}.{attr_to_modules[name]}" + submod = importlib.import_module(submod_path) + attr = getattr(submod, name) + + # If the attribute lives in a file (module) with the same + # name as the attribute, ensure that the attribute and *not* + # the module is accessible on the package. + if name == attr_to_modules[name]: + pkg = sys.modules[package_name] + pkg.__dict__[name] = attr + + return attr + else: + raise AttributeError(f"No {package_name} attribute {name}") + + def __dir__(): + return __all__.copy() + + eager_import = os.environ.get("EAGER_IMPORT", "") not in ("0", "") + if eager_import: + for attr in set(attr_to_modules.keys()) | submodules: + __getattr__(attr) + + return __getattr__, __dir__, __all__.copy() + + +def _attach_extras_attrs(module_name, attrs, error_message): + _attr_map = {} + for export_name, spec in attrs.items(): + source_attr = spec.alias if spec.alias is not None else export_name + _attr_map[export_name] = (spec.source, source_attr) + + _all = sorted(_attr_map.keys()) + + def __getattr__(name): + if name not in _attr_map: + raise AttributeError(f"module {module_name!r} has no attribute {name!r}") + source_mod_name, source_attr = _attr_map[name] + try: + source_mod = importlib.import_module(source_mod_name) + except ImportError as exc: + raise AttributeError(error_message) from exc + try: + value = getattr(source_mod, 
source_attr) + except AttributeError: + try: + value = importlib.import_module(f"{source_mod_name}.{source_attr}") + except ImportError: + raise AttributeError(error_message) from None + if module_name in sys.modules: + sys.modules[module_name].__dict__[name] = value + return value + + def __dir__(): + return _all.copy() + + return __getattr__, __dir__, _all.copy() + + +def _attach_extras_package(module_name, package, error_message): + _cached_pkg = {} + + def _get_package(): + if "mod" not in _cached_pkg: + try: + _cached_pkg["mod"] = importlib.import_module(package) + except ImportError as exc: + raise AttributeError(error_message) from exc + return _cached_pkg["mod"] + + def __getattr__(name): + pkg = _get_package() + try: + value = getattr(pkg, name) + except AttributeError: + raise AttributeError( + f"module {module_name!r} has no attribute {name!r}" + ) from None + if module_name in sys.modules: + sys.modules[module_name].__dict__[name] = value + return value + + def __dir__(): + try: + return list(dir(_get_package())) + except AttributeError: + return [] + + return __getattr__, __dir__, [] + + +def attach_extras( + module_name, + *, + attrs=None, + package=None, + error_message="Required optional dependency is not installed", +): + """Attach lazily loaded attributes from optional external packages. + + Designed for 'extras' modules that re-export symbols from optional + dependencies. The resulting module is always safe to import; errors + are deferred until an attribute is actually accessed. + + Exactly one of ``attrs`` or ``package`` must be provided. + + :param module_name: Typically use ``__name__``. + :param attrs: Map of exported names to :class:`ExtraAttr` descriptors. + Each value specifies the *source* module and an optional *alias* + (the attribute name inside *source*, when it differs from the + dictionary key). + :param package: Name of a package whose public attributes should be + proxied wholesale. 
+ :param error_message: Human-readable message included in the + ``AttributeError`` raised when the optional dependency is not + installed. + :returns: ``(__getattr__, __dir__, __all__)`` + """ + if (attrs is None) == (package is None): + raise ValueError("attach_extras() requires exactly one of 'attrs' or 'package'") + + if attrs is not None: + return _attach_extras_attrs(module_name, attrs, error_message) + + return _attach_extras_package(module_name, package, error_message) + + +class DelayedImportErrorModule(types.ModuleType): + def __init__(self, frame_data, *args, message, **kwargs): + self.__frame_data = frame_data + self.__message = message + super().__init__(*args, **kwargs) + + def __getattr__(self, x): + fd = self.__frame_data + raise ModuleNotFoundError( + f"{self.__message}\n\n" + "This error is lazily reported, having originally occurred in\n" + f" File {fd['filename']}, line {fd['lineno']}, in {fd['function']}\n\n" + f"----> {''.join(fd['code_context'] or '').strip()}" + ) + + +def load(fullname, *, require=None, error_on_import=False, suppress_warning=False): + """Return a lazily imported proxy for a module. + + We often see the following pattern:: + + def myfunc(): + import numpy as np + np.norm(...) + .... + + Putting the import inside the function prevents, in this case, + `numpy`, from being imported at function definition time. + That saves time if `myfunc` ends up not being called. + + This `load` function returns a proxy module that, upon access, imports + the actual module. So the idiom equivalent to the above example is:: + + np = lazy.load("numpy") + + def myfunc(): + np.norm(...) + .... + + The initial import time is fast because the actual import is delayed + until the first attribute is requested. The overall import time may + decrease as well for users that don't make use of large portions + of your library. 
+ + Warning + ------- + While lazily loading *sub*packages technically works, it causes the + package (that contains the subpackage) to be eagerly loaded even + if the package is already lazily loaded. + So, you probably shouldn't use subpackages with this `load` feature. + Instead you should encourage the package maintainers to use the + `lazy_loader.attach` to make their subpackages load lazily. + + Parameters + ---------- + fullname : str + The full name of the module or submodule to import. For example:: + + sp = lazy.load("scipy") # import scipy as sp + + require : str + A dependency requirement as defined in PEP-508. For example:: + + "numpy >=1.24" + + If defined, the proxy module will raise an error if the installed + version does not satisfy the requirement. + + error_on_import : bool + Whether to raise import errors immediately instead of postponing them + until the module is accessed. If set to `True`, import errors are raised + as soon as `load` is called; if `False` (the default), they are raised + when an attribute of the returned proxy is first accessed. + + suppress_warning : bool + Whether to prevent emitting a warning when loading subpackages. + If set to `True`, no warning will occur. + + Returns + ------- + pm : importlib.util._LazyModule + Proxy module. Can be used like any regularly imported module. + Actual loading of the module occurs upon first attribute request. + + """ + with threadlock: + module = sys.modules.get(fullname) + have_module = module is not None + + # Most common, short-circuit + if have_module and require is None: + return module + + if not suppress_warning and "." in fullname: + msg = ( + "subpackages can technically be lazily loaded, but it causes the " + "package to be eagerly loaded even if it is already lazily loaded. " + "So, you probably shouldn't use subpackages with this lazy feature."
+ ) + warnings.warn(msg, RuntimeWarning) + + spec = None + + if not have_module: + spec = importlib.util.find_spec(fullname) + have_module = spec is not None + + if not have_module: + not_found_message = f"No module named '{fullname}'" + elif require is not None: + try: + have_module = _check_requirement(require) + except ModuleNotFoundError as e: + # Every fragment of the implicitly concatenated message that + # interpolates a value needs its own ``f`` prefix. + raise ValueError( + f"Found module '{fullname}' but cannot test " + f"requirement '{require}'. " + "Requirements must match distribution name, not module name." + ) from e + + not_found_message = f"No distribution can be found matching '{require}'" + + if not have_module: + if error_on_import: + raise ModuleNotFoundError(not_found_message) + import inspect + + parent = inspect.stack()[1] + frame_data = { + "filename": parent.filename, + "lineno": parent.lineno, + "function": parent.function, + "code_context": parent.code_context, + } + del parent + return DelayedImportErrorModule( + frame_data, + "DelayedImportErrorModule", + message=not_found_message, + ) + + if spec is not None and spec.loader is not None: + module = importlib.util.module_from_spec(spec) + sys.modules[fullname] = module + + loader = importlib.util.LazyLoader(spec.loader) + loader.exec_module(module) + + return module + + +def _check_requirement(require: str) -> bool: + """Verify that a package requirement is satisfied + + If the package is required, a ``ModuleNotFoundError`` is raised + by ``importlib.metadata``. + + Parameters + ---------- + require : str + A dependency requirement as defined in PEP-508 + + Returns + ------- + satisfied : bool + True if the installed version of the dependency matches + the specified version, False otherwise.
+ """ + import importlib.metadata + + import packaging.requirements + + req = packaging.requirements.Requirement(require) + return req.specifier.contains( + importlib.metadata.version(req.name), + prereleases=True, + ) + + +class _StubVisitor(ast.NodeVisitor): + """AST visitor to parse a stub file for submodules and submod_attrs.""" + + def __init__(self): + self._submodules = set() + self._submod_attrs = {} + + def visit_ImportFrom(self, node: ast.ImportFrom): + if node.level != 1: + raise ValueError( + "Only within-module imports are supported (`from .* import`)" + ) + if node.module: + attrs: list = self._submod_attrs.setdefault(node.module, []) + aliases = [alias.name for alias in node.names] + if "*" in aliases: + raise ValueError( + "lazy stub loader does not support star import " + f"`from {node.module} import *`" + ) + attrs.extend(aliases) + else: + self._submodules.update(alias.name for alias in node.names) + + +def attach_stub(package_name: str, filename: str): + """Attach lazily loaded submodules, functions from a type stub. + + This is a variant on ``attach`` that will parse a `.pyi` stub file to + infer ``submodules`` and ``submod_attrs``. This allows static type checkers + to find imports, while still providing lazy loading at runtime. + + Parameters + ---------- + package_name : str + Typically use ``__name__``. + filename : str + Path to `.py` file which has an adjacent `.pyi` file. + Typically use ``__file__``. + + Returns + ------- + __getattr__, __dir__, __all__ + The same output as ``attach``. + + Raises + ------ + ValueError + If a stub file is not found for `filename`, or if the stubfile is formatted + incorrectly (e.g.
if it contains a relative import from outside of the module) + """ + stubfile = ( + filename if filename.endswith("i") else f"{os.path.splitext(filename)[0]}.pyi" + ) + + if not os.path.exists(stubfile): + raise ValueError(f"Cannot load imports from non-existent stub {stubfile!r}") + + with open(stubfile) as f: + stub_node = ast.parse(f.read()) + + visitor = _StubVisitor() + visitor.visit(stub_node) + return attach(package_name, visitor._submodules, visitor._submod_attrs) diff --git a/src/guidellm/utils/vision.py b/src/guidellm/utils/vision.py new file mode 100644 index 000000000..6a9dd6ddc --- /dev/null +++ b/src/guidellm/utils/vision.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import base64 +import io +from pathlib import Path +from typing import Any, Literal + +import httpx +import numpy as np + +# CRITICAL: Use 'import ... as libs' pattern to preserve lazy loading +# This defers errors until attributes are actually accessed +import guidellm.extras.vision as libs + +__all__ = [ + "encode_image", + "encode_video", + "get_file_format", + "image_dict_to_pil", + "is_url", + "resize_image", +] + + +def is_url(text: Any) -> bool: + return isinstance(text, str) and text.startswith(("http://", "https://")) + + +def encode_image( + image: bytes | str | Path | np.ndarray | libs.Image, + width: int | None = None, + height: int | None = None, + max_size: int | None = None, + max_width: int | None = None, + max_height: int | None = None, + encode_type: Literal["base64", "url"] | None = "base64", +) -> dict[Literal["type", "image", "image_pixels", "image_bytes"], str | int | None]: + """ + Input image types: + - bytes: raw image bytes, decoded with Pillow + - str: file path on disk, url, or already base64 encoded image string + - pathlib.Path: file path on disk + - np.ndarray: image array, decoded with Pillow + - PIL.Image.Image: Pillow image + - datasets.Image: HuggingFace datasets Image object + + max_size: maximum size of the longest edge of the image +
max_width: maximum width of the image + max_height: maximum height of the image + + encode_type: None to return the supported format + (url for url, base64 string for others) + "base64" to return base64 encoded string (or download URL and encode) + "url" to return url (only if input is url, otherwise fails) + + Returns a dict whose "image" value is either: + - image url + - "data:image/{type};base64, {data}" string + """ + if isinstance(image, str) and is_url(image): + if encode_type == "base64": + response = httpx.get(image) + response.raise_for_status() + # Forward width/height as well so an explicit resize request is + # not silently dropped when the image is fetched from a URL. + return encode_image( + image=response.content, + width=width, + height=height, + max_size=max_size, + max_width=max_width, + max_height=max_height, + encode_type="base64", + ) + + if any([width, height, max_size, max_width, max_height]): + raise ValueError(f"Cannot resize image {image} when encode_type is 'url'") + + return { + "type": "image_url", + "image": image, + "image_pixels": None, + "image_bytes": None, + } + + decoded_image: libs.Image + + if isinstance(image, bytes): + decoded_image = libs.PILImage.open(io.BytesIO(image)) + elif isinstance(image, str) and image.startswith("data:image/"): + _, encoded = image.split(",", 1) + image_data = base64.b64decode(encoded) + decoded_image = libs.PILImage.open(io.BytesIO(image_data)) + elif isinstance(image, str | Path): + decoded_image = libs.PILImage.open(image) + elif isinstance(image, np.ndarray): + decoded_image = libs.PILImage.fromarray(image) + elif isinstance(image, libs.Image): + decoded_image = image + else: + raise ValueError(f"Unsupported image type: {type(image)} for {image}") + + output_image = resize_image( + decoded_image, + width=width, + height=height, + max_width=max_width, + max_height=max_height, + max_size=max_size, + ) + if output_image.mode != "RGB": + output_image = output_image.convert("RGB") + + buffer = io.BytesIO() + output_image.save(buffer, format="JPEG") + image_bytes = buffer.getvalue() + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + + return { + "type":
"image_base64", + "image": f"data:image/jpeg;base64,{image_base64}", + "image_pixels": output_image.width * output_image.height, + "image_bytes": len(image_bytes), + } + + +def resize_image( + image: libs.Image, + width: int | None = None, + height: int | None = None, + max_width: int | None = None, + max_height: int | None = None, + max_size: int | None = None, +) -> libs.Image: + if not isinstance(image, libs.Image): + raise ValueError(f"Unsupported image type: {type(image)}") + + if width is not None and height is not None: + return image.resize((width, height), libs.PILImage.Resampling.BILINEAR) + + orig_w, orig_h = image.size + aspect = orig_w / orig_h + + if width is not None: + target_w = width + target_h = round(width / aspect) + elif height is not None: + target_h = height + target_w = round(height * aspect) + else: + target_w, target_h = orig_w, orig_h + + # Normalize max_size → max_width/max_height + if max_size is not None: + max_width = max_width or max_size + max_height = max_height or max_size + + # Apply max constraints (preserve aspect ratio) + if max_width or max_height: + scale_w = max_width / target_w if max_width else 1.0 + scale_h = max_height / target_h if max_height else 1.0 + scale = min(scale_w, scale_h, 1.0) # never upscale + target_w = round(target_w * scale) + target_h = round(target_h * scale) + + if (target_w, target_h) != (orig_w, orig_h): + image = image.resize((target_w, target_h), libs.PILImage.Resampling.BILINEAR) + + return image + + +def image_dict_to_pil(item: dict[str, Any]) -> libs.Image: + """ + Decode an encoded image column item to a PIL Image for vLLM multi_modal_data. + + The item must have an "image" key with either a data URL (data:image/...;base64,...) + or an HTTP(S) URL. For data URLs the image is base64-decoded; for URLs the + image is fetched with httpx. 
+ + :param item: Dict with "image" key (data URL or URL string) + :return: PIL Image in RGB if needed + :raises ValueError: If item has no "image" or unsupported format + """ + image_spec = item.get("image") + if not image_spec or not isinstance(image_spec, str): + raise ValueError( + "Encoded image item must have an 'image' key with a data URL or URL string." + ) + if image_spec.startswith("data:image/"): + _, encoded = image_spec.split(",", 1) + data = base64.b64decode(encoded) + decoded_image = libs.PILImage.open(io.BytesIO(data)) + elif image_spec.startswith(("http://", "https://")): + response = httpx.get(image_spec) + response.raise_for_status() + decoded_image = libs.PILImage.open(io.BytesIO(response.content)) + else: + raise ValueError( + "Encoded image 'image' value must be a data:image/... URL or " + f"http(s) URL, got: {image_spec[:80]!r}..." + ) + if decoded_image.mode != "RGB": + decoded_image = decoded_image.convert("RGB") # type: ignore[assignment] + # convert() returns Image; PILImage.open() may be ImageFile + return decoded_image + + +def encode_video( + video: bytes | str | Path, + encode_type: Literal["base64", "url"] | None = "base64", +) -> dict[ + Literal["type", "video", "video_frames", "video_seconds", "video_bytes"], + str | int | float | None, +]: + """ + Input video types: + - bytes: raw video bytes + - str: file path on disk, url, or already base64 encoded video string + - pathlib.Path: file path on disk + - datasets.Video: HuggingFace datasets Video object + + encode_type: None to return the supported format + (url for url, base64 string for others) + "base64" to return base64 encoded string (or download URL and encode) + "url" to return url (only if input is url, otherwise fails) + + Returns a str of either: + - video url + - "data:video/{type};base64, {data}" string + """ + if isinstance(video, str) and is_url(video): + if encode_type == "base64": + response = httpx.get(video) + response.raise_for_status() + return 
encode_video(video=response.content, encode_type="base64") + + return { + "type": "video_url", + "video": video, + "video_frames": None, + "video_seconds": None, + "video_bytes": None, + } + + if isinstance(video, str) and video.startswith("data:video/"): + data_str = video.split(",", 1)[1] + + return { + "type": "video_base64", + "video": video, + "video_frames": None, + "video_seconds": None, + "video_bytes": len(data_str) * 3 // 4, # base64 to bytes + } + + if isinstance(video, str | Path): + path = Path(video) + video_bytes = path.read_bytes() + video_format = get_file_format(path) + elif isinstance(video, bytes): + video_bytes = video + video_format = "unknown" + else: + raise ValueError(f"Unsupported video type: {type(video)} for {video}") + + video_base64 = base64.b64encode(video_bytes).decode("utf-8") + + return { + "type": "video_base64", + "video": f"data:video/{video_format};base64,{video_base64}", + "video_frames": None, + "video_seconds": None, + "video_bytes": len(video_bytes), + } + + +def get_file_format(path: Path | str) -> str: + """Get file format from path extension.""" + suffix = Path(path).suffix.lower() + return suffix[1:] if suffix.startswith(".") else "unknown" diff --git a/tests/mocks/fake_pkg/__init__.py b/tests/mocks/fake_pkg/__init__.py new file mode 100644 index 000000000..189b91df4 --- /dev/null +++ b/tests/mocks/fake_pkg/__init__.py @@ -0,0 +1,8 @@ +# ruff: noqa: PGH004 +# ruff: noqa + +import guidellm.utils.lazy_loader as lazy + +__getattr__, __lazy_dir__, __all__ = lazy.attach( + __name__, submod_attrs={"some_func": ["some_func", "aux_func"]} +) diff --git a/tests/mocks/fake_pkg/__init__.pyi b/tests/mocks/fake_pkg/__init__.pyi new file mode 100644 index 000000000..81c1b54ec --- /dev/null +++ b/tests/mocks/fake_pkg/__init__.pyi @@ -0,0 +1,4 @@ +# ruff: noqa: PGH004 +# ruff: noqa + +from .some_func import aux_func, some_func diff --git a/tests/mocks/fake_pkg/some_func.py b/tests/mocks/fake_pkg/some_func.py new file mode 100644 index 
000000000..5b7c10e76 --- /dev/null +++ b/tests/mocks/fake_pkg/some_func.py @@ -0,0 +1,6 @@ +def some_func(): + """Function with same name as submodule.""" + + +def aux_func(): + """Auxiliary function.""" diff --git a/tests/unit/backends/test_backend.py b/tests/unit/backends/test_backend.py index 14c4155e4..898016764 100644 --- a/tests/unit/backends/test_backend.py +++ b/tests/unit/backends/test_backend.py @@ -481,9 +481,8 @@ def test_vllm_python_backend_registered(self): ) assert Backend.is_registered("vllm_python") - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - args = VLLMPythonBackendArgs(model="test-model") - backend = Backend.create(args) + args = VLLMPythonBackendArgs(model="test-model") + backend = Backend.create(args) assert isinstance(backend, VLLMPythonBackend) assert backend._args.model == "test-model" assert backend.kind == "vllm_python" diff --git a/tests/unit/backends/vllm_python/test_vllm.py b/tests/unit/backends/vllm_python/test_vllm.py index dfce9062a..5ab1d5227 100644 --- a/tests/unit/backends/vllm_python/test_vllm.py +++ b/tests/unit/backends/vllm_python/test_vllm.py @@ -11,7 +11,7 @@ import asyncio from types import SimpleNamespace -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import numpy as np import pytest @@ -61,13 +61,9 @@ def _mock_audio_decode_result(audio_array: np.ndarray) -> Mock: @pytest.fixture def backend(): """VLLMPythonBackend instance without requiring vllm to be installed.""" - with ( - patch("guidellm.backends.vllm_python.vllm._check_vllm_available"), - patch( - "guidellm.backends.vllm_python.vllm.SamplingParams", - _fake_sampling_params, - ), - ): + mock_vllm_extras = MagicMock() + mock_vllm_extras.SamplingParams = _fake_sampling_params + with patch("guidellm.backends.vllm_python.vllm.vllm", mock_vllm_extras): yield _make_vllm_backend(model="test-model") @@ -83,10 +79,7 @@ def test_text_column_resolves_to_prompt(self, backend): Request with text_column 
resolves to a prompt string via plain format. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_plain = _make_vllm_backend( - model="test-model", request_format="plain" - ) + backend_plain = _make_vllm_backend(model="test-model", request_format="plain") request = GenerationRequest(columns={"text_column": ["hello"]}) resolved = backend_plain._resolve_request(request) assert isinstance(resolved, _ResolvedRequest) @@ -100,10 +93,9 @@ def test_stream_false_propagated(self): When backend.stream=False, resolved.stream is False. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend = _make_vllm_backend( - model="test-model", stream=False, request_format="plain" - ) + backend = _make_vllm_backend( + model="test-model", stream=False, request_format="plain" + ) request = GenerationRequest(columns={"text_column": ["hello"]}) resolved = backend._resolve_request(request) assert resolved.stream is False @@ -114,8 +106,7 @@ def test_prefix_and_text_columns_build_messages(self): Columns with prefix_column and text_column are formatted into prompt. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend = _make_vllm_backend(model="test-model", request_format="plain") + backend = _make_vllm_backend(model="test-model", request_format="plain") request = GenerationRequest( columns={ "prefix_column": ["System prompt"], @@ -131,10 +122,7 @@ def test_text_only_no_media_multi_modal_data_none(self, backend): Request with only text columns leaves multi_modal_data None. 
## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_plain = _make_vllm_backend( - model="test-model", request_format="plain" - ) + backend_plain = _make_vllm_backend(model="test-model", request_format="plain") request = GenerationRequest(columns={"text_column": ["hello"]}) resolved = backend_plain._resolve_request(request) assert resolved.multi_modal_data is None @@ -154,7 +142,7 @@ def test_audio_column_only_resolves_with_placeholder_prompt(self, backend): } ) with patch( - "guidellm.backends.vllm_python.vllm._decode_audio", + "guidellm.utils.audio._decode_audio", return_value=mock_decode_result, ): resolved = backend._resolve_request(request) @@ -172,8 +160,7 @@ def test_image_column_resolves_with_multi_modal_data(self): ## WRITTEN BY AI ## """ mock_pil = Mock() - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend = _make_vllm_backend(model="test-model", request_format="plain") + backend = _make_vllm_backend(model="test-model", request_format="plain") request = GenerationRequest( columns={ "text_column": ["Describe this"], @@ -183,7 +170,7 @@ def test_image_column_resolves_with_multi_modal_data(self): } ) with patch( - "guidellm.backends.vllm_python.vllm.image_dict_to_pil", + "guidellm.utils.vision.image_dict_to_pil", return_value=mock_pil, ): resolved = backend._resolve_request(request) @@ -224,10 +211,9 @@ def fake_apply_chat_template( mock_tokenizer = Mock() mock_tokenizer.apply_chat_template = fake_apply_chat_template - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend = _make_vllm_backend( - model="test-model", request_format="default-template" - ) + backend = _make_vllm_backend( + model="test-model", request_format="default-template" + ) backend._engine = Mock() backend._engine.tokenizer = mock_tokenizer @@ -238,7 +224,7 @@ def fake_apply_chat_template( } ) with patch( - "guidellm.backends.vllm_python.vllm._decode_audio", + 
"guidellm.utils.audio._decode_audio", return_value=mock_decode_result, ): resolved = backend._resolve_request(request) @@ -263,8 +249,7 @@ def test_audio_and_text_plain_format_uses_placeholder_string(self): mock_audio_array = np.array([0.0, 0.1], dtype=np.float32) mock_decode_result = _mock_audio_decode_result(mock_audio_array) - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend = _make_vllm_backend(model="test-model", request_format="plain") + backend = _make_vllm_backend(model="test-model", request_format="plain") request = GenerationRequest( columns={ @@ -273,7 +258,7 @@ def test_audio_and_text_plain_format_uses_placeholder_string(self): } ) with patch( - "guidellm.backends.vllm_python.vllm._decode_audio", + "guidellm.utils.audio._decode_audio", return_value=mock_decode_result, ): resolved = backend._resolve_request(request) @@ -303,11 +288,10 @@ def test_build_placeholder_prefix_image_override(self): _build_placeholder_prefix uses image_placeholder override. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_custom = _make_vllm_backend( - model="Qwen/Qwen3-VL-2B-Instruct", - image_placeholder=("<|vision_start|><|image_pad|><|vision_end|>"), - ) + backend_custom = _make_vllm_backend( + model="Qwen/Qwen3-VL-2B-Instruct", + image_placeholder=("<|vision_start|><|image_pad|><|vision_end|>"), + ) result = backend_custom._build_placeholder_prefix({"image": Mock()}) assert result == ("<|vision_start|><|image_pad|><|vision_end|>\n") @@ -387,11 +371,10 @@ def test_build_placeholder_prefix_audio_override(self): _build_placeholder_prefix uses audio_placeholder override. 
## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_custom = _make_vllm_backend( - model="zai-org/GLM-ASR-Nano-2512", - audio_placeholder=("<|begin_of_audio|><|pad|><|end_of_audio|>"), - ) + backend_custom = _make_vllm_backend( + model="zai-org/GLM-ASR-Nano-2512", + audio_placeholder=("<|begin_of_audio|><|pad|><|end_of_audio|>"), + ) result = backend_custom._build_placeholder_prefix( {"audio": np.array([0.0], dtype=np.float32)} ) @@ -551,10 +534,7 @@ def test_request_format_plain_produces_concatenated_prompt(self): With request_format=plain, _resolve_request produces plain concatenation. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_plain = _make_vllm_backend( - model="test-model", request_format="plain" - ) + backend_plain = _make_vllm_backend(model="test-model", request_format="plain") request = GenerationRequest( columns={ "text_column": ["Hello"], @@ -572,12 +552,11 @@ def test_request_format_chat_completions_raises_not_a_template(self): ValueError with message that includes received value and allowed options. 
## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_api = _make_vllm_backend( - model="test-model", request_format="chat_completions" - ) - backend_api._engine = Mock() - backend_api._engine.tokenizer = Mock() + backend_api = _make_vllm_backend( + model="test-model", request_format="chat_completions" + ) + backend_api._engine = Mock() + backend_api._engine.tokenizer = Mock() request = GenerationRequest(columns={"text_column": ["Hi"]}) with pytest.raises(ValueError) as exc_info: backend_api._resolve_request(request) @@ -594,12 +573,11 @@ def test_request_format_default_template_uses_apply_chat_template(self): """ mock_tokenizer = Mock() mock_tokenizer.apply_chat_template.return_value = "formatted_prompt" - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_default = _make_vllm_backend( - model="test-model", request_format="default-template" - ) - backend_default._engine = Mock() - backend_default._engine.tokenizer = mock_tokenizer + backend_default = _make_vllm_backend( + model="test-model", request_format="default-template" + ) + backend_default._engine = Mock() + backend_default._engine.tokenizer = mock_tokenizer request = GenerationRequest(columns={"text_column": ["Hi"]}) resolved = backend_default._resolve_request(request) assert resolved.prompt == "formatted_prompt" @@ -616,10 +594,9 @@ def test_request_format_none_uses_apply_chat_template(self): """ mock_tokenizer = Mock() mock_tokenizer.apply_chat_template.return_value = "default_prompt" - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_none = _make_vllm_backend(model="test-model") - backend_none._engine = Mock() - backend_none._engine.tokenizer = mock_tokenizer + backend_none = _make_vllm_backend(model="test-model") + backend_none._engine = Mock() + backend_none._engine.tokenizer = mock_tokenizer request = GenerationRequest(columns={"text_column": ["Hi"]}) resolved = 
backend_none._resolve_request(request) assert resolved.prompt == "default_prompt" @@ -633,13 +610,12 @@ def test_request_format_custom_template_string_sets_tokenizer_and_applies(self): """ mock_tokenizer = Mock() mock_tokenizer.apply_chat_template.return_value = "custom_prompt" - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_custom = _make_vllm_backend( - model="test-model", - request_format="{{ messages[0]['content'] }}", - ) - backend_custom._engine = Mock() - backend_custom._engine.tokenizer = mock_tokenizer + backend_custom = _make_vllm_backend( + model="test-model", + request_format="{{ messages[0]['content'] }}", + ) + backend_custom._engine = Mock() + backend_custom._engine.tokenizer = mock_tokenizer request = GenerationRequest(columns={"text_column": ["Hi"]}) resolved = backend_custom._resolve_request(request) assert resolved.prompt == "custom_prompt" @@ -656,12 +632,11 @@ def test_request_format_custom_template_from_file(self, tmp_path): template_file.write_text("Custom: {{ messages[0]['content'] }}") mock_tokenizer = Mock() mock_tokenizer.apply_chat_template.return_value = "Custom: Hi" - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_file = _make_vllm_backend( - model="test-model", request_format=str(template_file) - ) - backend_file._engine = Mock() - backend_file._engine.tokenizer = mock_tokenizer + backend_file = _make_vllm_backend( + model="test-model", request_format=str(template_file) + ) + backend_file._engine = Mock() + backend_file._engine.tokenizer = mock_tokenizer request = GenerationRequest(columns={"text_column": ["Hi"]}) resolved = backend_file._resolve_request(request) assert resolved.prompt == "Custom: Hi" @@ -679,12 +654,11 @@ def test_request_format_file_template_cached_on_second_request(self, tmp_path): ) mock_tokenizer = Mock() mock_tokenizer.apply_chat_template.return_value = "Hi" - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - 
backend_file = _make_vllm_backend( - model="test-model", request_format=str(template_file) - ) - backend_file._engine = Mock() - backend_file._engine.tokenizer = mock_tokenizer + backend_file = _make_vllm_backend( + model="test-model", request_format=str(template_file) + ) + backend_file._engine = Mock() + backend_file._engine.tokenizer = mock_tokenizer request = GenerationRequest(columns={"text_column": ["Hi"]}) backend_file._resolve_request(request) first_template = mock_tokenizer.chat_template @@ -700,12 +674,11 @@ def test_request_format_file_with_no_markers_raises(self, tmp_path): """ no_markers_file = tmp_path / "plain.txt" no_markers_file.write_text("just plain text") - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_file = _make_vllm_backend( - model="test-model", request_format=str(no_markers_file) - ) - backend_file._engine = Mock() - backend_file._engine.tokenizer = Mock() + backend_file = _make_vllm_backend( + model="test-model", request_format=str(no_markers_file) + ) + backend_file._engine = Mock() + backend_file._engine.tokenizer = Mock() request = GenerationRequest(columns={"text_column": ["Hi"]}) with pytest.raises(ValueError) as exc_info: backend_file._resolve_request(request) @@ -718,12 +691,11 @@ def test_request_format_invalid_jinja2_string_raises(self): request_format with invalid Jinja2 syntax raises ValueError. 
## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_bad = _make_vllm_backend( - model="test-model", request_format="{{ unclosed" - ) - backend_bad._engine = Mock() - backend_bad._engine.tokenizer = Mock() + backend_bad = _make_vllm_backend( + model="test-model", request_format="{{ unclosed" + ) + backend_bad._engine = Mock() + backend_bad._engine.tokenizer = Mock() request = GenerationRequest(columns={"text_column": ["Hi"]}) with pytest.raises(ValueError) as exc_info: backend_bad._resolve_request(request) @@ -736,11 +708,10 @@ def test_request_format_stored_on_backend(self): Custom request_format is stored on the backend, not in vllm_config. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_custom = _make_vllm_backend( - model="test-model", - request_format="/path/to/template.jinja", - ) + backend_custom = _make_vllm_backend( + model="test-model", + request_format="/path/to/template.jinja", + ) assert backend_custom._args.request_format == "/path/to/template.jinja" assert "chat_template" not in backend_custom._args.vllm_config @@ -750,10 +721,7 @@ def test_request_format_plain_not_in_vllm_config(self): request_format=plain does not add chat_template to vllm_config. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_plain = _make_vllm_backend( - model="test-model", request_format="plain" - ) + backend_plain = _make_vllm_backend(model="test-model", request_format="plain") assert backend_plain._args.request_format == "plain" assert "chat_template" not in backend_plain._args.vllm_config @@ -763,10 +731,9 @@ def test_request_format_default_template_not_in_vllm_config(self): request_format=default-template does not add chat_template to vllm_config. 
## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_def = _make_vllm_backend( - model="test-model", request_format="default-template" - ) + backend_def = _make_vllm_backend( + model="test-model", request_format="default-template" + ) assert backend_def._args.request_format == "default-template" assert "chat_template" not in backend_def._args.vllm_config @@ -776,8 +743,7 @@ def test_vllm_config_empty_uses_vllm_defaults(self): With vllm_config empty or None, backend only sets model; no extra keys. ## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend_empty = _make_vllm_backend(model="test-model", vllm_config={}) + backend_empty = _make_vllm_backend(model="test-model", vllm_config={}) assert backend_empty._args.vllm_config.get("model") == "test-model" assert "tensor_parallel_size" not in backend_empty._args.vllm_config assert "gpu_memory_utilization" not in backend_empty._args.vllm_config @@ -948,17 +914,9 @@ async def test_process_startup_success(self): ## WRITTEN BY AI ## """ mock_engine = Mock() - with ( - patch("guidellm.backends.vllm_python.vllm._check_vllm_available"), - patch( - "guidellm.backends.vllm_python.vllm.AsyncEngineArgs", - return_value=Mock(), - ), - patch( - "guidellm.backends.vllm_python.vllm.AsyncLLMEngine" - ) as mock_engine_cls, - ): - mock_engine_cls.from_engine_args = Mock(return_value=mock_engine) + with patch("guidellm.backends.vllm_python.vllm.vllm") as mock_vllm: + mock_vllm.AsyncEngineArgs.return_value = Mock() + mock_vllm.AsyncLLMEngine.from_engine_args = Mock(return_value=mock_engine) backend = _make_vllm_backend(model="test-model") await backend.process_startup() assert backend._engine is mock_engine @@ -972,17 +930,9 @@ async def test_process_startup_idempotency_raises(self): ## WRITTEN BY AI ## """ mock_engine = Mock() - with ( - patch("guidellm.backends.vllm_python.vllm._check_vllm_available"), - patch( - 
"guidellm.backends.vllm_python.vllm.AsyncEngineArgs", - return_value=Mock(), - ), - patch( - "guidellm.backends.vllm_python.vllm.AsyncLLMEngine" - ) as mock_engine_cls, - ): - mock_engine_cls.from_engine_args = Mock(return_value=mock_engine) + with patch("guidellm.backends.vllm_python.vllm.vllm") as mock_vllm: + mock_vllm.AsyncEngineArgs.return_value = Mock() + mock_vllm.AsyncLLMEngine.from_engine_args = Mock(return_value=mock_engine) backend = _make_vllm_backend(model="test-model") await backend.process_startup() with pytest.raises(RuntimeError, match="Backend already started up"): @@ -996,17 +946,9 @@ async def test_process_shutdown_success(self): ## WRITTEN BY AI ## """ mock_engine = Mock() - with ( - patch("guidellm.backends.vllm_python.vllm._check_vllm_available"), - patch( - "guidellm.backends.vllm_python.vllm.AsyncEngineArgs", - return_value=Mock(), - ), - patch( - "guidellm.backends.vllm_python.vllm.AsyncLLMEngine" - ) as mock_engine_cls, - ): - mock_engine_cls.from_engine_args = Mock(return_value=mock_engine) + with patch("guidellm.backends.vllm_python.vllm.vllm") as mock_vllm: + mock_vllm.AsyncEngineArgs.return_value = Mock() + mock_vllm.AsyncLLMEngine.from_engine_args = Mock(return_value=mock_engine) backend = _make_vllm_backend(model="test-model") await backend.process_startup() await backend.process_shutdown() @@ -1021,10 +963,9 @@ async def test_process_shutdown_not_started_raises(self): Raise RuntimeError when not started. 
## WRITTEN BY AI ## """ - with patch("guidellm.backends.vllm_python.vllm._check_vllm_available"): - backend = _make_vllm_backend(model="test-model") - backend._in_process = False - backend._engine = None + backend = _make_vllm_backend(model="test-model") + backend._in_process = False + backend._engine = None with pytest.raises(RuntimeError, match="Backend not started up"): await backend.process_shutdown() @@ -1448,7 +1389,7 @@ async def mock_generate(prompt, sampling_params, request_id): request.output_metrics = UsageMetrics() with patch( - "guidellm.backends.vllm_python.vllm._decode_audio", + "guidellm.utils.audio._decode_audio", return_value=mock_decode_result, ): backend._engine = Mock() diff --git a/tests/unit/extras/test_audio.py b/tests/unit/extras/test_audio.py index b7f783693..f4334807b 100644 --- a/tests/unit/extras/test_audio.py +++ b/tests/unit/extras/test_audio.py @@ -7,7 +7,7 @@ import pytest import torch -from guidellm.extras.audio import encode_audio +from guidellm.utils import audio as _audio_mod @pytest.fixture @@ -51,7 +51,7 @@ def real_wav_file(): def test_encode_audio_with_tensor_input(sample_audio_tensor): - result = encode_audio( + result = _audio_mod.encode_audio( audio=sample_audio_tensor, sample_rate=16000, audio_format="mp3", @@ -71,7 +71,7 @@ def test_encode_audio_with_tensor_input(sample_audio_tensor): def test_encode_audio_with_numpy_array(sample_audio_tensor): numpy_audio = sample_audio_tensor.numpy() - result = encode_audio(audio=numpy_audio, sample_rate=16000) + result = _audio_mod.encode_audio(audio=numpy_audio, sample_rate=16000) assert result["type"] == "audio_file" assert isinstance(result["audio"], bytes) @@ -79,7 +79,9 @@ def test_encode_audio_with_numpy_array(sample_audio_tensor): def test_encode_audio_with_real_file_path(real_wav_file): - result = encode_audio(audio=real_wav_file, sample_rate=16000, max_duration=1.0) + result = _audio_mod.encode_audio( + audio=real_wav_file, sample_rate=16000, max_duration=1.0 + ) assert 
result["type"] == "audio_file" assert isinstance(result["audio"], bytes) @@ -93,7 +95,7 @@ def test_encode_audio_with_real_file_path(real_wav_file): def test_encode_audio_with_dict_input_complete(): audio_dict = {"data": torch.randn(1, 16000), "sample_rate": 16000} - result = encode_audio(audio=audio_dict) + result = _audio_mod.encode_audio(audio=audio_dict) assert result["type"] == "audio_file" assert result["audio_bytes"] > 0 @@ -102,7 +104,7 @@ def test_encode_audio_with_dict_input_complete(): @patch("httpx.get") -@patch("guidellm.extras.audio._encode_audio") +@patch("guidellm.utils.audio._encode_audio") def test_encode_audio_with_url(mock_http_get, sample_audio_tensor): # mock http get response mock_response = MagicMock() @@ -111,20 +113,24 @@ def test_encode_audio_with_url(mock_http_get, sample_audio_tensor): mock_http_get.return_value = mock_response # mock decode - return sample audio tensor - with patch("guidellm.extras.audio._decode_audio") as mock_decoder: + with patch("guidellm.utils.audio._decode_audio") as mock_decoder: mock_audio_result = MagicMock() mock_audio_result.data = sample_audio_tensor mock_audio_result.sample_rate = 16000 mock_decoder.return_value = mock_audio_result - result = encode_audio(audio="https://example.com/audio.wav", sample_rate=16000) + result = _audio_mod.encode_audio( + audio="https://example.com/audio.wav", sample_rate=16000 + ) assert result["type"] == "audio_file" def test_encode_audio_with_max_duration(sample_audio_tensor): long_audio = torch.randn(1, 32000) - result = encode_audio(audio=long_audio, sample_rate=16000, max_duration=1.0) + result = _audio_mod.encode_audio( + audio=long_audio, sample_rate=16000, max_duration=1.0 + ) assert result["audio_seconds"] == 1.0 @@ -133,7 +139,7 @@ def test_encode_audio_different_formats(sample_audio_tensor): formats = ["mp3", "wav", "flac"] for fmt in formats: - result = encode_audio( + result = _audio_mod.encode_audio( audio=sample_audio_tensor, sample_rate=16000, audio_format=fmt ) 
@@ -146,7 +152,7 @@ def test_encode_audio_resampling(sample_audio_tensor): original_rate = 16000 target_rate = 8000 - result = encode_audio( + result = _audio_mod.encode_audio( audio=sample_audio_tensor, sample_rate=original_rate, encode_sample_rate=target_rate, @@ -157,17 +163,17 @@ def test_encode_audio_resampling(sample_audio_tensor): def test_encode_audio_error_handling(): with pytest.raises(ValueError): - encode_audio(audio=123) + _audio_mod.encode_audio(audio=123) with pytest.raises(ValueError): - encode_audio(audio=torch.randn(1, 16000), sample_rate=None) + _audio_mod.encode_audio(audio=torch.randn(1, 16000), sample_rate=None) with pytest.raises(ValueError): - encode_audio(audio="/nonexistent/path/audio.wav") + _audio_mod.encode_audio(audio="/nonexistent/path/audio.wav") def test_audio_quality_preservation(sample_audio_tensor): - result = encode_audio( + result = _audio_mod.encode_audio( audio=sample_audio_tensor, sample_rate=16000, audio_format="mp3", @@ -181,7 +187,7 @@ def test_end_to_end_audio_processing(sample_audio_tensor): original_samples = sample_audio_tensor.shape[1] original_duration = original_samples / 16000 - result = encode_audio( + result = _audio_mod.encode_audio( audio=sample_audio_tensor, sample_rate=16000, audio_format="mp3", diff --git a/tests/unit/extras/test_vision.py b/tests/unit/extras/test_vision.py index 65fe8e69d..95467f3bf 100644 --- a/tests/unit/extras/test_vision.py +++ b/tests/unit/extras/test_vision.py @@ -7,12 +7,7 @@ import pytest from PIL import Image -from guidellm.extras.vision import ( - encode_image, - encode_video, - get_file_format, - resize_image, -) +from guidellm.utils import vision as _vision_mod @pytest.fixture @@ -58,7 +53,7 @@ def sample_video_file(): def test_encode_image_base64(sample_image_bytes: bytes): - result = encode_image(sample_image_bytes, encode_type="base64") + result = _vision_mod.encode_image(sample_image_bytes, encode_type="base64") assert result["type"] == "image_base64" assert "image" in 
result assert result["image_bytes"] > 0 @@ -66,7 +61,9 @@ def test_encode_image_base64(sample_image_bytes: bytes): def test_encode_image_url(): - result = encode_image(image="https://example.com/vision.jpg", encode_type="url") + result = _vision_mod.encode_image( + image="https://example.com/vision.jpg", encode_type="url" + ) assert result["type"] == "image_url" assert result["image"] == "https://example.com/vision.jpg" @@ -78,7 +75,7 @@ def test_resize_image(sample_image_array: np.ndarray): original_height, original_width = sample_image_array.shape[:2] new_width, new_height = 100, 100 - resized_image = resize_image( + resized_image = _vision_mod.resize_image( pil_image, # Pass PIL Image instead of numpy array width=new_width, height=new_height, @@ -88,12 +85,12 @@ def test_resize_image(sample_image_array: np.ndarray): def test_get_file_format(sample_jpeg_file): - file_format = get_file_format(sample_jpeg_file) + file_format = _vision_mod.get_file_format(sample_jpeg_file) assert file_format == "jpg" def test_encode_video_with_fixture(sample_video_file): - result = encode_video(video=sample_video_file, encode_type="base64") + result = _vision_mod.encode_video(video=sample_video_file, encode_type="base64") assert result["type"] == "video_base64" assert result["video"].startswith("data:video/mp4;base64,") @@ -111,7 +108,7 @@ def test_encode_video_with_url_base64(): mock_response.raise_for_status = MagicMock() mock_get.return_value = mock_response - result = encode_video(video=test_url, encode_type="base64") + result = _vision_mod.encode_video(video=test_url, encode_type="base64") mock_get.assert_called_once_with(test_url) assert result["type"] == "video_base64" @@ -124,7 +121,7 @@ def test_encode_video_with_url_base64(): def test_encode_video_with_url_url_encoding(): """Test encoding a video URL with url encoding""" test_url = "https://example.com/video.mp4" - result = encode_video(video=test_url, encode_type="url") + result = _vision_mod.encode_video(video=test_url, 
encode_type="url") assert result["type"] == "video_url" assert result["video"] == test_url @@ -139,7 +136,7 @@ def test_encode_video_with_base64_string(): base64_video = base64.b64encode(test_video_bytes).decode("utf-8") data_url = f"data:video/mp4;base64,{base64_video}" - result = encode_video(video=data_url, encode_type="base64") + result = _vision_mod.encode_video(video=data_url, encode_type="base64") assert result["type"] == "video_base64" assert result["video"] == data_url @@ -149,7 +146,7 @@ def test_encode_video_with_base64_string(): def test_encode_video_with_file_path(sample_video_file): - result = encode_video(video=sample_video_file, encode_type="base64") + result = _vision_mod.encode_video(video=sample_video_file, encode_type="base64") assert result["type"] == "video_base64" assert result["video"].startswith("data:video/mp4;base64,") @@ -172,7 +169,7 @@ def test_encode_video_with_path_object(): f.flush() temp_path = Path(f.name) - result = encode_video(video=temp_path, encode_type="base64") + result = _vision_mod.encode_video(video=temp_path, encode_type="base64") assert result["type"] == "video_base64" assert result["video"].startswith("data:video/avi;base64,") @@ -188,7 +185,7 @@ def test_encode_video_with_raw_bytes(): """Test encoding video from raw bytes""" video_bytes = b"raw video bytes content" - result = encode_video(video=video_bytes, encode_type="base64") + result = _vision_mod.encode_video(video=video_bytes, encode_type="base64") assert result["type"] == "video_base64" assert result["video"].startswith("data:video/unknown;base64,") @@ -212,14 +209,14 @@ def test_encode_video_url_with_http_error(): mock_get.return_value = mock_response with pytest.raises(Exception, match="HTTP Error"): - encode_video(video=test_url, encode_type="base64") + _vision_mod.encode_video(video=test_url, encode_type="base64") def test_encode_video_with_none_encode_type(): """Test encoding with None encode_type""" video_bytes = b"test video" - result = 
encode_video(video=video_bytes, encode_type=None) + result = _vision_mod.encode_video(video=video_bytes, encode_type=None) # Should default to base64 encoding assert result["type"] == "video_base64" @@ -229,7 +226,9 @@ def test_encode_video_with_none_encode_type(): def test_encode_video_with_unsupported_type(): """Test encoding with unsupported video type""" with pytest.raises(ValueError, match="Unsupported video type"): - encode_video(video=123, encode_type="base64") # int is not supported + _vision_mod.encode_video( + video=123, encode_type="base64" + ) # int is not supported def test_encode_video_file_not_found(): @@ -237,7 +236,7 @@ def test_encode_video_file_not_found(): non_existent_path = "/path/that/does/not/exist/video.mp4" with pytest.raises(FileNotFoundError): - encode_video(video=non_existent_path, encode_type="base64") + _vision_mod.encode_video(video=non_existent_path, encode_type="base64") def test_encode_video_base64_correctness(): @@ -246,7 +245,7 @@ def test_encode_video_base64_correctness(): test_bytes = b"Hello World" expected_base64 = base64.b64encode(test_bytes).decode("utf-8") - result = encode_video(video=test_bytes, encode_type="base64") + result = _vision_mod.encode_video(video=test_bytes, encode_type="base64") base64_part = result["video"].split(",", 1)[1] assert base64_part == expected_base64 @@ -257,7 +256,7 @@ def test_encode_video_data_url_format(): """Test that data URL format is correct""" video_bytes = b"test video data" - result = encode_video(video=video_bytes, encode_type="base64") + result = _vision_mod.encode_video(video=video_bytes, encode_type="base64") assert result["video"].startswith("data:video/unknown;base64,") # Verify the format is exactly as expected @@ -270,7 +269,7 @@ def test_encode_video_data_url_format(): # Additional test for edge cases def test_encode_video_empty_bytes(): """Test encoding empty video bytes""" - result = encode_video(video=b"", encode_type="base64") + result = 
_vision_mod.encode_video(video=b"", encode_type="base64") assert result["type"] == "video_base64" assert result["video"] == "data:video/unknown;base64," @@ -281,7 +280,7 @@ def test_encode_video_large_content(): """Test encoding with larger video content""" large_content = b"x" * 1024 * 1024 # 1MB of data - result = encode_video(video=large_content, encode_type="base64") + result = _vision_mod.encode_video(video=large_content, encode_type="base64") assert result["type"] == "video_base64" assert result["video_bytes"] == len(large_content) diff --git a/tests/unit/utils/test_lazy_loader.py b/tests/unit/utils/test_lazy_loader.py new file mode 100644 index 000000000..981c5af97 --- /dev/null +++ b/tests/unit/utils/test_lazy_loader.py @@ -0,0 +1,432 @@ +# ruff: noqa: PGH004 +# ruff: noqa +""" +Tests.Mocks for the lazy loading utilities from lazy_loader package. + +File uses code adapted from code with the following license: + +BSD 3-Clause License + +Copyright (c) 2022--2023, Scientific Python project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +import importlib +import os +import sys +import types +from unittest import mock + +import pytest + +import guidellm.utils.lazy_loader as lazy + + +@pytest.fixture +def clean_fake_pkg(): + yield + sys.modules.pop("tests.mocks.fake_pkg.some_func", None) + sys.modules.pop("tests.mocks.fake_pkg", None) + sys.modules.pop("tests.mocks", None) + + +@pytest.mark.parametrize("attempt", [1, 2]) +def test_cleanup_fixture(clean_fake_pkg, attempt): + assert "tests.mocks.fake_pkg" not in sys.modules + assert "tests.mocks.fake_pkg.some_func" not in sys.modules + from tests.mocks import fake_pkg + + assert "tests.mocks.fake_pkg" in sys.modules + assert "tests.mocks.fake_pkg.some_func" not in sys.modules + assert isinstance(fake_pkg.some_func, types.FunctionType) + assert "tests.mocks.fake_pkg.some_func" in sys.modules + + +def test_lazy_import_basics(): + math = lazy.load("math") + anything_not_real = lazy.load("anything_not_real") + + # Now test that accessing attributes does what it should + assert math.sin(math.pi) == pytest.approx(0, 1e-6) + # poor-mans pytest.raises for testing errors on attribute access + with pytest.raises(ModuleNotFoundError): + anything_not_real.pi + assert isinstance(anything_not_real, lazy.DelayedImportErrorModule) + # see if it changes for second access + with pytest.raises(ModuleNotFoundError): + anything_not_real.pi + + +def test_lazy_import_subpackages(): + with pytest.warns(RuntimeWarning): + hp = lazy.load("html.parser") + 
assert "html" in sys.modules + assert type(sys.modules["html"]) is type(pytest) + assert isinstance(hp, importlib.util._LazyModule) + assert "html.parser" in sys.modules + assert sys.modules["html.parser"] == hp + + +def test_lazy_import_impact_on_sys_modules(): + math = lazy.load("math") + anything_not_real = lazy.load("anything_not_real") + + assert isinstance(math, types.ModuleType) + assert "math" in sys.modules + assert isinstance(anything_not_real, lazy.DelayedImportErrorModule) + assert "anything_not_real" not in sys.modules + + # only do this if numpy is installed + pytest.importorskip("numpy") + np = lazy.load("numpy") + assert isinstance(np, types.ModuleType) + assert "numpy" in sys.modules + + np.pi # trigger load of numpy + + assert isinstance(np, types.ModuleType) + assert "numpy" in sys.modules + + +def test_lazy_import_nonbuiltins(): + np = lazy.load("numpy") + sp = lazy.load("scipy") + if not isinstance(np, lazy.DelayedImportErrorModule): + assert np.sin(np.pi) == pytest.approx(0, 1e-6) + if isinstance(sp, lazy.DelayedImportErrorModule): + with pytest.raises(ModuleNotFoundError): + sp.pi + + +@pytest.mark.parametrize("lazy_submodules", [False, True]) +def test_lazy_attach(lazy_submodules): + name = "mymod" + submods = ["mysubmodule", "anothersubmodule"] + myall = {"not_real_submod": ["some_var_or_func"]} + + locls = { + "attach": lazy.attach, + "name": name, + "submods": submods, + "myall": myall, + "lazy_submods": lazy_submodules, + } + s = ( + "__getattr__, __lazy_dir__, __all__ = " + "attach(name, submods, myall, lazy_submodules=lazy_submods)" + ) + + exec(s, {}, locls) + expected = { + "attach": lazy.attach, + "name": name, + "submods": submods, + "myall": myall, + "lazy_submods": lazy_submodules, + "__getattr__": None, + "__lazy_dir__": None, + "__all__": None, + } + assert locls.keys() == expected.keys() + for k, v in expected.items(): + if v is not None: + assert locls[k] == v + + # Exercise __getattr__, though it will just error + with 
pytest.raises(ImportError): + locls["__getattr__"]("mysubmodule") + + # Attribute is supposed to be imported, error on submodule load + with pytest.raises(ImportError): + locls["__getattr__"]("some_var_or_func") + + # Attribute is unknown, raise AttributeError + with pytest.raises(AttributeError): + locls["__getattr__"]("unknown_attr") + + +def test_lazy_attach_noattrs(): + name = "mymod" + submods = ["mysubmodule", "anothersubmodule"] + _, _, all_ = lazy.attach(name, submods) + + assert all_ == sorted(submods) + + +@pytest.mark.parametrize("lazy_submodules", [False, True]) +def test_lazy_attach_returns_copies(lazy_submodules): + _get, _dir, _all = lazy.attach( + __name__, + ["my_submodule", "another_submodule"], + {"foo": ["some_attr"]}, + lazy_submodules=lazy_submodules, + ) + assert _dir() is not _dir() + assert _dir() == _all + assert _dir() is not _all + + expected = ["another_submodule", "my_submodule", "some_attr"] + assert _dir() == expected + assert _all == expected + assert _dir() is not _all + + _dir().append("modify_returned_list") + assert _dir() == expected + assert _all == expected + assert _dir() is not _all + + _all.append("modify_returned_all") + assert _dir() == expected + assert _all == [*expected, "modify_returned_all"] + + +@pytest.mark.parametrize("eager_import", [False, True]) +def test_attach_same_module_and_attr_name(clean_fake_pkg, eager_import): + env = {} + if eager_import: + env["EAGER_IMPORT"] = "1" + + with mock.patch.dict(os.environ, env): + from tests.mocks import fake_pkg + + # Grab attribute twice, to ensure that importing it does not + # override function by module + assert isinstance(fake_pkg.some_func, types.FunctionType) + assert isinstance(fake_pkg.some_func, types.FunctionType) + + # Ensure imports from submodule still work + from tests.mocks.fake_pkg.some_func import some_func + + assert isinstance(some_func, types.FunctionType) + + +FAKE_STUB = """ +from . 
import rank +from ._gaussian import gaussian +from .edges import sobel, scharr, prewitt, roberts +""" + + +def test_stub_loading(tmp_path): + stub = tmp_path / "stub.pyi" + stub.write_text(FAKE_STUB) + _get, _dir, _all = lazy.attach_stub("my_module", str(stub)) + expect = {"gaussian", "sobel", "scharr", "prewitt", "roberts", "rank"} + assert set(_dir()) == set(_all) == expect + + +def test_stub_loading_parity(): + from tests.mocks import fake_pkg + + from_stub = lazy.attach_stub(fake_pkg.__name__, fake_pkg.__file__) + stub_getter, stub_dir, stub_all = from_stub + assert stub_all == fake_pkg.__all__ + assert stub_dir() == fake_pkg.__lazy_dir__() + assert stub_getter("some_func") == fake_pkg.some_func + + +def test_stub_loading_errors(tmp_path): + stub = tmp_path / "stub.pyi" + stub.write_text("from ..mod import func\n") + + with pytest.raises(ValueError, match="Only within-module imports are supported"): + lazy.attach_stub("name", str(stub)) + + with pytest.raises(ValueError, match="Cannot load imports from non-existent stub"): + lazy.attach_stub("name", "not a file") + + stub2 = tmp_path / "stub2.pyi" + stub2.write_text("from .mod import *\n") + with pytest.raises(ValueError, match=r".*does not support star import"): + lazy.attach_stub("name", str(stub2)) + + +def test_require_kwarg(): + # Test with a module that definitely exists, behavior hinges on requirement + with mock.patch("importlib.metadata.version") as version: + version.return_value = "1.0.0" + math = lazy.load("math", require="somepkg >= 2.0") + assert isinstance(math, lazy.DelayedImportErrorModule) + + math = lazy.load("math", require="somepkg >= 1.0") + assert math.sin(math.pi) == pytest.approx(0, 1e-6) + + # We can fail even after a successful import + math = lazy.load("math", require="somepkg >= 2.0") + assert isinstance(math, lazy.DelayedImportErrorModule) + + # Eager failure + with pytest.raises(ModuleNotFoundError): + lazy.load("math", require="somepkg >= 2.0", error_on_import=True) + + # When a 
module can be loaded but the version can't be checked, + # raise a ValueError + with pytest.raises(ValueError): + lazy.load("math", require="somepkg >= 1.0") + + +# ── attach_extras tests ────────────────────────────────────────────── + + +def test_attach_extras_attrs_installed_package(): + """## WRITTEN BY AI ##""" + ga, gd, gall = lazy.attach_extras( + "test_extras_attrs", + attrs={ + "sin": lazy.ExtraAttr("math"), + "mypi": lazy.ExtraAttr("math", alias="pi"), + }, + error_message="install math", + ) + assert ga("sin") is __import__("math").sin + assert ga("mypi") == __import__("math").pi + assert gall == ["mypi", "sin"] + assert gd() == ["mypi", "sin"] + + +def test_attach_extras_attrs_alias(): + """## WRITTEN BY AI ##""" + ga, _, _ = lazy.attach_extras( + "test_extras_alias", + attrs={"my_sep": lazy.ExtraAttr("os.path", alias="sep")}, + error_message="install os", + ) + import os.path + + assert ga("my_sep") == os.path.sep + + +def test_attach_extras_attrs_submodule_fallback(): + """## WRITTEN BY AI ##""" + ga, _, _ = lazy.attach_extras( + "test_extras_submod", + attrs={"path": lazy.ExtraAttr("os", alias="path")}, + error_message="install os", + ) + import os.path + + assert ga("path") is os.path + + +def test_attach_extras_attrs_missing_package(): + """## WRITTEN BY AI ##""" + ga, _, _ = lazy.attach_extras( + "test_extras_missing", + attrs={"Foo": lazy.ExtraAttr("nonexistent_pkg_12345")}, + error_message="install nonexistent_pkg_12345", + ) + with pytest.raises(AttributeError, match="install nonexistent_pkg_12345"): + ga("Foo") + + +def test_attach_extras_attrs_unknown_attr(): + """## WRITTEN BY AI ##""" + ga, _, _ = lazy.attach_extras( + "test_extras_unknown", + attrs={"sin": lazy.ExtraAttr("math")}, + error_message="install math", + ) + with pytest.raises(AttributeError, match="has no attribute"): + ga("nonexistent") + + +def test_attach_extras_attrs_caching(): + """## WRITTEN BY AI ##""" + mod_name = "test_extras_caching_mod" + mod = 
types.ModuleType(mod_name) + sys.modules[mod_name] = mod + try: + ga, _, _ = lazy.attach_extras( + mod_name, + attrs={"pi": lazy.ExtraAttr("math")}, + error_message="install math", + ) + mod.__getattr__ = ga + result = ga("pi") + assert result == __import__("math").pi + assert "pi" in mod.__dict__ + assert mod.__dict__["pi"] == __import__("math").pi + finally: + del sys.modules[mod_name] + + +def test_attach_extras_attrs_dir_returns_copies(): + """## WRITTEN BY AI ##""" + _, gd, gall = lazy.attach_extras( + "test_extras_copies", + attrs={"sin": lazy.ExtraAttr("math"), "pi": lazy.ExtraAttr("math")}, + error_message="install math", + ) + assert gd() == gall + assert gd() is not gd() + gd().append("extra") + assert gd() == ["pi", "sin"] + + +def test_attach_extras_package_installed(): + """## WRITTEN BY AI ##""" + ga, gd, _ = lazy.attach_extras( + "test_extras_pkg", + package="math", + error_message="install math", + ) + assert ga("sin") is __import__("math").sin + assert ga("pi") == __import__("math").pi + assert "sin" in gd() + + +def test_attach_extras_package_missing(): + """## WRITTEN BY AI ##""" + ga, gd, gall = lazy.attach_extras( + "test_extras_pkg_missing", + package="nonexistent_pkg_12345", + error_message="install nonexistent", + ) + with pytest.raises(AttributeError, match="install nonexistent"): + ga("anything") + assert gd() == [] + assert gall == [] + + +def test_attach_extras_package_bad_attr(): + """## WRITTEN BY AI ##""" + ga, _, _ = lazy.attach_extras( + "test_extras_pkg_bad", + package="math", + error_message="install math", + ) + with pytest.raises(AttributeError, match="has no attribute"): + ga("totally_not_in_math_12345") + + +def test_attach_extras_mutual_exclusion(): + """## WRITTEN BY AI ##""" + with pytest.raises(ValueError, match="exactly one"): + lazy.attach_extras("x", attrs={"a": lazy.ExtraAttr("b")}, package="c") + with pytest.raises(ValueError, match="exactly one"): + lazy.attach_extras("x")