diff --git a/README.md b/README.md index be2660c..219087a 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,11 @@ with fastsafe_open(filenames=[filename], nogds=True, device="cpu", debug_log=Tru t = f.get_tensor(key).clone().detach() # clone if t is used outside ``` +## Configuration + +`AutoLoader` supports file-based configuration for loader type, pipeline mode, copy settings, and more. +See [Configuration Guide](./docs/configuration.md) for defaults, examples, and all available options. + ## Development ### Pre-commit Hooks diff --git a/docs/configuration.md b/docs/configuration.md new file mode 100644 index 0000000..62ba3f9 --- /dev/null +++ b/docs/configuration.md @@ -0,0 +1,136 @@ +# Configuration Guide + +## Configuration Discovery + +`AutoLoader` loads configuration in the following priority (highest first): + +1. **Environment variable** — `FASTSAFETENSORS_CONFIG=/path/to/config.json` +2. **Default path** — `./fastsafetensors.json` in the working directory (if it exists) +3. **Built-in defaults** — `LoaderConfig()` dataclass defaults + +All fields are optional. Unspecified fields fall back to built-in defaults. + +## Default Configuration + +When no config file is found, `AutoLoader` uses these defaults: + +```json +{ + "loader": "base", + "framework": "pytorch", + "parallel": { + "use_pipeline": false + }, + "debug": { + "debug_log": false, + "set_numa": true, + "disable_cache": true + } +} +``` + +The base loader extension defaults to `copier_type: "gds"` (GPU Direct Storage). + +## queue_size Semantics + +| `queue_size` | Mode | GPU Memory | Behavior | +|---|---|---|---| +| `-1` | Fully serial | 1 batch | `copy_files → broadcast → copy_files → ...` | +| `0` | Unbuffered pipeline | Up to 2 batches | 1 batch copying + 1 batch broadcasting concurrently | +| `>0` | Buffered pipeline | Up to `queue_size+1` batches | Producer fills queue while consumer broadcasts | + +`use_pipeline: false` forces `queue_size=-1` (serial, minimal GPU memory). 
+ +## Configuration Examples + +### 1. Minimal — All Defaults (no config file needed) + +```python +from fastsafetensors import SingleGroup, AutoLoader + +pg = SingleGroup() +loader = AutoLoader(pg, files, device="cuda:0") +for key, tensor in loader.iterate_weights(): + process(key, tensor) +loader.close() +``` + +No config file. Uses `loader="base"`, `gds`, serial mode. + +### 2. Base Loader with GDS + +```json +{ + "loader": "base", + "base": { + "copier_type": "gds" + } +} +``` + +Enables GPU Direct Storage for NVMe-to-GPU transfers, bypassing host CPU/memory. + +### 3. Base Loader with Pipeline Mode + +```json +{ + "parallel": { + "use_pipeline": true, + "max_concurrent_producers": 1, + "queue_size": 0, + "use_tqdm_on_load": true + } +} +``` + +Overlaps `copy_files` with `broadcast` for higher throughput. + +### 4. 3FS Loader + +```json +{ + "loader": "3fs", + "3fs": { + "mount_point": "/mnt/3fs", + "entries": 64, + "io_depth": 0, + "buffer_size": 67108864 + } +} +``` + +Uses ThreeFSLoader with 3FS USRBIO backend. + +### 5. Full Reference + +```json +{ + "loader": "base", + "framework": "pytorch", + "base": { + "copier_type": "gds", + "bbuf_size_kb": 16384, + "max_threads": 16 + }, + "3fs": { + "mount_point": "/mnt/3fs", + "entries": 64, + "io_depth": 0, + "buffer_size": 67108864 + }, + "parallel": { + "use_pipeline": false, + "max_concurrent_producers": 1, + "queue_size": 0, + "use_tqdm_on_load": true + }, + "debug": { + "debug_log": false, + "set_numa": true, + "disable_cache": true + } +} +``` + +Each loader type has its own extension section (e.g., `base`, `3fs`). +Adding a new loader only requires a new section — no changes to `config.py`. diff --git a/examples/run_auto_loader.py b/examples/run_auto_loader.py new file mode 100644 index 0000000..9a7b0e5 --- /dev/null +++ b/examples/run_auto_loader.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Example: Using AutoLoader with automatic configuration. 
+ +AutoLoader automatically creates the appropriate loader based on +configuration discovered from: + 1. FASTSAFETENSORS_CONFIG environment variable -> config file + 2. ./fastsafetensors.json (default path, if it exists) + 3. Built-in defaults (loader="base", copier_type="gds") + +No manual loader creation is needed. The constructor itself emits a +``logger.info`` summary of the effective configuration, so callers do +not need to print it manually. +""" + +import argparse +import logging + +from fastsafetensors import AutoLoader, SingleGroup + +logger = logging.getLogger(__name__) + + +def main(): + logging.basicConfig( + level=logging.INFO, format="%(levelname)s %(name)s: %(message)s" + ) + parser = argparse.ArgumentParser(description="AutoLoader example") + parser.add_argument("files", nargs="+", help="safetensors file paths") + parser.add_argument("--device", default="cpu", help="target device (default: cpu)") + args = parser.parse_args() + + pg = SingleGroup() + + # --- Way 1: Pure defaults --- + # Uses loader="base", copier_type="gds" by default. + # No config file needed. 
+ logger.info("=== Way 1: Default config ===") + loader = AutoLoader(pg, args.files, device=args.device) + for key, tensor in loader.iterate_weights(): + logger.info(" %s: shape=%s", key, tensor.shape) + loader.close() + + # --- Way 2: Config file in working directory --- + # Place a fastsafetensors.json in the working directory: + # + # { + # "loader": "3fs", + # "3fs": { + # "mount_point": "/mnt/3fs" + # } + # } + # + # Then just run: + # loader = AutoLoader(pg, args.files, device=args.device) + logger.info( + "=== Way 2: Config file (auto-discovered from ./fastsafetensors.json) ===" + ) + logger.info(" (Place fastsafetensors.json in your working directory)") + + # --- Way 3: Environment variable --- + # export FASTSAFETENSORS_CONFIG=/path/to/your/config.json + # Then just run: + # loader = AutoLoader(pg, args.files, device=args.device) + logger.info("=== Way 3: Environment variable ===") + logger.info(" export FASTSAFETENSORS_CONFIG=/path/to/config.json") + + +if __name__ == "__main__": + main() diff --git a/fastsafetensors/__init__.py b/fastsafetensors/__init__.py index e6a60c3..8792da9 100644 --- a/fastsafetensors/__init__.py +++ b/fastsafetensors/__init__.py @@ -10,6 +10,8 @@ TensorFrame, get_device_numa_node, ) +from .config import LoaderConfig, load_config from .file_buffer import FilesBufferOnDevice -from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader, fastsafe_open +from .loader import SafeTensorsFileLoader, fastsafe_open from .parallel_loader import ParallelLoader +from .unified_loader import AutoLoader diff --git a/fastsafetensors/config.py b/fastsafetensors/config.py new file mode 100644 index 0000000..9b6bfb6 --- /dev/null +++ b/fastsafetensors/config.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from dataclasses import dataclass, field, fields +from typing import Any, Dict + +from .common import init_logger + +logger = init_logger(__name__) + +CONFIG_ENV_VAR = "FASTSAFETENSORS_CONFIG" 
DEFAULT_CONFIG_PATH = "fastsafetensors.json"


@dataclass
class LoaderConfig:
    """Configuration for fastsafetensors unified loader.

    Core fields live as dataclass attributes. Per-loader extension settings
    (e.g., ``base.copier_type``, ``3fs.mount_point``) are stored in
    ``_extensions`` and accessed via :meth:`get_extension_config`.
    """

    # Loader selection and cross-loader common settings.
    loader: str = "base"
    framework: str = "pytorch"
    debug_log: bool = False
    set_numa: bool = True
    disable_cache: bool = True

    # Pipeline (copy/broadcast overlap) settings; see create_parallel_kwargs().
    use_pipeline: bool = False
    max_concurrent_producers: int = 1
    queue_size: int = 0
    use_tqdm_on_load: bool = True

    # Extension sections keyed by loader name (e.g. "base", "3fs").
    _extensions: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    def __post_init__(self):
        # Only a single producer is supported today: broadcast batches must
        # be consumed in strict order across all ranks.
        if self.max_concurrent_producers != 1:
            raise ValueError(
                f"max_concurrent_producers must be 1 "
                f"(got {self.max_concurrent_producers}). "
                "Concurrent producers > 1 are not yet supported because broadcast "
                "batches must be processed in strict order across all ranks."
            )

    # Top-level dict sections whose keys are flattened into dataclass fields.
    # NOTE: there is no "copy" group; only these two are recognized.
    _COMMON_GROUPS = {"parallel", "debug"}
    # Cross-loader common fields stripped from extension sections.
    _COMMON_FIELDS_FOR_EXTENSION = {
        "framework",
        "debug_log",
        "set_numa",
        "disable_cache",
    }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LoaderConfig":
        """Create from dict. ``parallel``/``debug`` groups are
        flattened; other dict-valued keys become extension sections."""
        valid_fields = {f.name for f in fields(cls) if not f.name.startswith("_")}
        flat: Dict[str, Any] = {}
        extensions: Dict[str, Dict[str, Any]] = {}

        for key, value in data.items():
            if key in cls._COMMON_GROUPS and isinstance(value, dict):
                # parallel / debug -> flatten into top-level fields
                for sub_key, sub_value in value.items():
                    if sub_key in valid_fields:
                        flat[sub_key] = sub_value
                    else:
                        logger.debug(
                            "Ignoring unknown config field: %s.%s", key, sub_key
                        )
            elif isinstance(value, dict):
                # Any other dict-valued top-level key is treated as an
                # extension section (e.g., base / 3fs / oss / s3).
                extensions[key] = dict(value)
            elif key in valid_fields:
                flat[key] = value
            else:
                logger.debug("Ignoring unknown config field: %s", key)

        flat["_extensions"] = extensions
        return cls(**flat)

    def get_extension_config(self, name: str) -> Dict[str, Any]:
        """Return a shallow copy of the extension section for *name*,
        with cross-loader common fields stripped."""
        raw = self._extensions.get(name, {})
        return {
            k: v for k, v in raw.items() if k not in self._COMMON_FIELDS_FOR_EXTENSION
        }

    @classmethod
    def _from_json(cls, path: str) -> "LoaderConfig":
        # Private implementation detail; callers go through from_file().
        with open(path, "r") as f:
            data = json.load(f)

        logger.info("Loaded config from JSON: %s", path)
        return cls.from_dict(data)

    @classmethod
    def from_file(cls, path: str) -> "LoaderConfig":
        """Load configuration from a JSON file."""
        return cls._from_json(path)

    def create_parallel_kwargs(self) -> Dict[str, Any]:
        """Build keyword arguments for PipelineParallel from pipeline settings."""
        if not self.use_pipeline:
            # queue_size=-1: fully serial (copy_files → broadcast → copy_files),
            # only 1 batch in GPU memory at a time.
            return {"queue_size": -1}
        return {
            "max_concurrent_producers": self.max_concurrent_producers,
            "queue_size": self.queue_size,
            "use_tqdm_on_load": self.use_tqdm_on_load,
        }


def load_config() -> LoaderConfig:
    """Load config: env var > default path > defaults."""
    # 1. Environment variable (highest priority; a missing file is an error)
    env_path = os.environ.get(CONFIG_ENV_VAR)
    if env_path is not None:
        if not os.path.isfile(env_path):
            raise FileNotFoundError(
                f"Config file specified by {CONFIG_ENV_VAR} not found: {env_path}"
            )
        logger.info("Loading config from %s=%s", CONFIG_ENV_VAR, env_path)
        return LoaderConfig.from_file(env_path)

    # 2. Default path
    if os.path.isfile(DEFAULT_CONFIG_PATH):
        logger.info("Loading config from default path: %s", DEFAULT_CONFIG_PATH)
        return LoaderConfig.from_file(DEFAULT_CONFIG_PATH)

    # 3. Defaults
    logger.debug("No config file found, using defaults")
    return LoaderConfig()
Defaults + logger.debug("No config file found, using defaults") + return LoaderConfig() diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py index 83515d8..fad3ad2 100644 --- a/fastsafetensors/loader.py +++ b/fastsafetensors/loader.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, OrderedDict, Tuple, Union from . import cpp as fstcpp from .common import ( @@ -9,7 +9,6 @@ TensorFrame, get_device_numa_node, init_logger, - set_debug, ) from .copier import CopierConstructFunc, CopierType, create_copier_constructor from .copier.unified import is_unified_memory_system @@ -39,6 +38,15 @@ class BaseSafeTensorsFileLoader: framework (str): Deep learning framework to use ("pytorch" or "paddle"). """ + @classmethod + def process_extension_config( + cls, ext_config: Mapping[str, Any], **kwargs: Any + ) -> Dict[str, Any]: + """Translate extension config into ``__init__`` kwargs. + Default: shallow copy as-is. Subclasses override to remap fields. + ``kwargs`` carries runtime context (e.g. ``hf_weights_files``).""" + return dict(ext_config) + def __init__( self, pg: Optional[Any], @@ -179,6 +187,16 @@ class SafeTensorsFileLoader(BaseSafeTensorsFileLoader): >> loader.close() """ + @classmethod + def process_extension_config( + cls, ext_config: Mapping[str, Any], **kwargs: Any + ) -> Dict[str, Any]: + """Map ``copier_type`` to ``nogds`` flag; pass rest through.""" + out = dict(ext_config) + copier_type = out.pop("copier_type", "gds") + out["nogds"] = copier_type != "gds" + return out + def __init__( self, pg: Optional[Any], diff --git a/fastsafetensors/parallel_loader.py b/fastsafetensors/parallel_loader.py index ec3cebb..ed49606 100644 --- a/fastsafetensors/parallel_loader.py +++ b/fastsafetensors/parallel_loader.py @@ -19,8 +19,8 @@ def tqdm(iterable, *args, **kwargs): return iterable -from . 
import BaseSafeTensorsFileLoader, SafeTensorsFileLoader from . import cpp as fstcpp +from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader def enable_tqdm(use_tqdm_on_load: bool): @@ -133,11 +133,21 @@ def __init__( loader: BaseSafeTensorsFileLoader, hf_weights_files: List[str], max_concurrent_producers: int = 1, - queue_size: int = 0, # Changed default to 0 for unbuffered behavior + # queue_size semantics: + # -1 : fully serial — copy_files → broadcast → copy_files (1 batch in GPU mem) + # 0 : unbuffered pipeline — 1 copying + 1 broadcasting concurrently (2 batches) + # >0 : buffered pipeline — up to (queue_size+1) batches in GPU mem + queue_size: int = 0, use_tqdm_on_load: bool = True, **kwargs, ): + if max_concurrent_producers != 1: + raise ValueError( + f"max_concurrent_producers must be 1 (got {max_concurrent_producers}). " + "Concurrent producers > 1 are not yet supported because broadcast " + "batches must be processed in strict order across all ranks." + ) self.loader = loader self.hf_weights_files = hf_weights_files self.max_concurrent_producers = max_concurrent_producers diff --git a/fastsafetensors/threefs_loader.py b/fastsafetensors/threefs_loader.py index 85b6c5a..aaa0e7a 100644 --- a/fastsafetensors/threefs_loader.py +++ b/fastsafetensors/threefs_loader.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, List, Optional +from typing import Any, Dict, List, Mapping, Optional from . import cpp as fstcpp from .common import init_logger @@ -12,25 +12,30 @@ class ThreeFSLoader(BaseSafeTensorsFileLoader): - """Load .safetensors files using 3FS USRBIO for high-performance I/O. - - Args: - pg (Optional[Any]): Process group-like objects for distributed loading. - device (str): Target device where tensors will be loaded (CPU, CUDA, etc.). - mount_point (str): 3FS mount point path (e.g., "/mnt/3fs"). - debug_log (bool): Enable detailed debug logging. - disable_cache (bool): Whether to disable caching of loaded tensors. 
- framework (str): Deep learning framework to use ("pytorch" or "paddle"). - **kwargs: Additional arguments passed to BaseSafeTensorsFileLoader. - - Examples: - >>> from fastsafetensors.threefs_loader import ThreeFSLoader - >>> loader = ThreeFSLoader(None, device="cuda:0", mount_point="/mnt/3fs") - >>> loader.add_filenames({0: ["/mnt/3fs/model.safetensors"]}) - >>> bufs = loader.copy_files_to_device() - >>> tensor = bufs.get_tensor("weight") - >>> loader.close() - """ + """Load .safetensors files using 3FS USRBIO for high-performance I/O.""" + + @classmethod + def process_extension_config( + cls, ext_config: Mapping[str, Any], **kwargs: Any + ) -> Dict[str, Any]: + """Infer ``mount_point`` from file paths if not explicitly configured.""" + out = dict(ext_config) + if not out.get("mount_point") or str(out["mount_point"]).strip() == "": + files = kwargs.get("hf_weights_files", []) + if files: + try: + from fastsafetensor_3fs_reader import extract_mount_point + + out["mount_point"] = extract_mount_point(files[0]) + logger.info( + "Inferred 3FS mount_point=%s from file paths", + out["mount_point"], + ) + except ImportError: + logger.debug( + "fastsafetensor_3fs_reader not available, using default mount_point" + ) + return out def __init__( self, @@ -40,6 +45,7 @@ def __init__( debug_log: bool = False, disable_cache: bool = True, framework: str = "pytorch", + set_numa: bool = True, **kwargs, ): self.framework = get_framework_op(framework) @@ -55,7 +61,7 @@ def __init__( pg, self.device, copier_type="3fs", - set_numa=True, + set_numa=set_numa, disable_cache=disable_cache, framework=framework, mount_point=mount_point, diff --git a/fastsafetensors/unified_loader.py b/fastsafetensors/unified_loader.py new file mode 100644 index 0000000..a183c3d --- /dev/null +++ b/fastsafetensors/unified_loader.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Generator, List, Optional, Tuple, Type + +from .common import init_logger +from .config 
from .config import LoaderConfig, load_config
from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader
from .parallel_loader import PipelineParallel
from .threefs_loader import ThreeFSLoader

logger = init_logger(__name__)

# Used by _resolve_loader_class() via globals().
__all__ = ["AutoLoader", "SafeTensorsFileLoader", "ThreeFSLoader"]

# Maps loader name -> module-level class attribute name. Strings (not class
# references) so that unittest.mock.patch on this module intercepts
# construction.
_LOADER_REGISTRY: Dict[str, str] = {
    "base": "SafeTensorsFileLoader",
    "3fs": "ThreeFSLoader",
}


def _resolve_loader_class(loader_name: str) -> Type[BaseSafeTensorsFileLoader]:
    """Resolve *loader_name* to a loader class via the registry and globals()."""
    class_attr = _LOADER_REGISTRY.get(loader_name)
    if class_attr is None:
        raise ValueError(
            f"Unknown loader type: {loader_name!r}. "
            f"Available: {list(_LOADER_REGISTRY.keys())}"
        )
    resolved = globals().get(class_attr)
    if resolved is None:
        raise ValueError(
            f"Loader class '{class_attr}' for loader type {loader_name!r} "
            f"is not imported in this module."
        )
    return resolved


class AutoLoader:
    """Config-driven parallel loader.

    Dispatches to the loader class registered in ``_LOADER_REGISTRY``
    based on ``LoaderConfig.loader``.

    Usage::

        loader = AutoLoader(pg, files, device="cuda:0")
        for key, tensor in loader.iterate_weights():
            process(key, tensor)
        loader.close()
    """

    def __init__(
        self,
        pg: Optional[Any],
        hf_weights_files: List[str],
        device: str = "cpu",
    ):
        cfg = load_config()
        self._config = cfg
        loader_cls = _resolve_loader_class(cfg.loader)

        # Cross-loader settings shared by every loader implementation.
        shared_kwargs: Dict[str, Any] = {
            "framework": cfg.framework,
            "debug_log": cfg.debug_log,
            "set_numa": cfg.set_numa,
            "disable_cache": cfg.disable_cache,
        }

        # Loader-specific settings, translated by the loader class itself.
        ext_kwargs = loader_cls.process_extension_config(
            cfg.get_extension_config(cfg.loader),
            hf_weights_files=hf_weights_files,
        )

        self._loader = loader_cls(pg, device=device, **shared_kwargs, **ext_kwargs)  # type: ignore[arg-type]

        self._pipeline = PipelineParallel(
            pg=pg,
            loader=self._loader,
            hf_weights_files=hf_weights_files,
            **cfg.create_parallel_kwargs(),
        )

        self._log_config_summary(device, len(hf_weights_files), ext_kwargs)

    def _log_config_summary(
        self, device: str, num_files: int, ext_kwargs: Dict[str, Any]
    ) -> None:
        """Emit a single INFO line summarizing the effective configuration."""
        cfg = self._config
        summary = [
            f"loader={cfg.loader}",
            f"framework={cfg.framework}",
            f"device={device}",
            f"files={num_files}",
        ]
        # Extension config (dynamic -- no hardcoded field names)
        summary.extend(f"{k}={v}" for k, v in ext_kwargs.items())
        summary.extend(
            [
                f"max_concurrent_producers={cfg.max_concurrent_producers}",
                f"queue_size={cfg.queue_size}",
                f"use_tqdm_on_load={cfg.use_tqdm_on_load}",
            ]
        )
        logger.info("AutoLoader initialized: %s", ", ".join(summary))

    @property
    def config(self) -> LoaderConfig:
        """The effective LoaderConfig discovered at construction time."""
        return self._config

    def iterate_weights(self) -> Generator[Tuple[str, Any], None, None]:
        """Yield (key, tensor) pairs from the underlying pipeline."""
        return self._pipeline.iterate_weights()

    def close(self):
        # PipelineParallel.close() already closes the underlying loader;
        # do NOT call self._loader.close() to avoid double-close.
        self._pipeline.close()
+ self._pipeline.close() diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..6bee254 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for fastsafetensors.config module.""" + +import json +import os +import tempfile + +import pytest + +from fastsafetensors.config import ( + CONFIG_ENV_VAR, + DEFAULT_CONFIG_PATH, + LoaderConfig, + load_config, +) + + +class TestLoaderConfigDefaults: + """Test LoaderConfig default values.""" + + def test_default_values(self): + config = LoaderConfig() + assert config.loader == "base" + assert config.framework == "pytorch" + assert config.debug_log is False + assert config.set_numa is True + assert config.disable_cache is True + assert config.max_concurrent_producers == 1 + assert config.queue_size == 0 + assert config.use_tqdm_on_load is True + # Extension fields removed from top-level; _extensions should be empty + assert config._extensions == {} + + def test_loader_field_default(self): + config = LoaderConfig() + assert config.loader == "base" + + def test_loader_field_3fs(self): + config = LoaderConfig(loader="3fs") + assert config.loader == "3fs" + + def test_no_extension_specific_fields(self): + """Verify that extension-specific fields are no longer top-level.""" + config = LoaderConfig() + assert not hasattr(config, "copier_type") + assert not hasattr(config, "bbuf_size_kb") + assert not hasattr(config, "max_threads") + assert not hasattr(config, "mount_point") + assert not hasattr(config, "entries") + assert not hasattr(config, "io_depth") + assert not hasattr(config, "buffer_size") + + +class TestLoaderConfigFromDict: + """Test LoaderConfig.from_dict().""" + + def test_flat_dict(self): + data = { + "loader": "3fs", + "framework": "paddle", + "max_concurrent_producers": 1, + } + config = LoaderConfig.from_dict(data) + assert config.loader == "3fs" + assert config.framework == "paddle" + assert config.max_concurrent_producers == 
1 + # Other fields should be defaults + assert config.debug_log is False + + def test_nested_dict(self): + data = { + "loader": "3fs", + "parallel": { + "max_concurrent_producers": 1, + "queue_size": 3, + }, + "debug": { + "debug_log": True, + }, + } + config = LoaderConfig.from_dict(data) + assert config.loader == "3fs" + assert config.max_concurrent_producers == 1 + assert config.queue_size == 3 + assert config.debug_log is True + + def test_nested_base_stored_as_extension(self): + """base section should be stored in _extensions, not flattened.""" + data = { + "loader": "base", + "base": { + "copier_type": "gds", + "bbuf_size_kb": 8192, + }, + } + config = LoaderConfig.from_dict(data) + assert config.loader == "base" + assert config._extensions["base"] == { + "copier_type": "gds", + "bbuf_size_kb": 8192, + } + + def test_nested_3fs_stored_as_extension(self): + """3fs section should be stored in _extensions, not flattened.""" + data = { + "loader": "3fs", + "3fs": { + "mount_point": "/data/3fs", + "entries": 128, + }, + } + config = LoaderConfig.from_dict(data) + assert config.loader == "3fs" + assert config._extensions["3fs"] == { + "mount_point": "/data/3fs", + "entries": 128, + } + + def test_unknown_scalar_fields_ignored(self): + data = { + "loader": "base", + "unknown_field": "should_be_ignored", + "another_unknown": 42, + } + config = LoaderConfig.from_dict(data) + assert config.loader == "base" + assert not hasattr(config, "unknown_field") + + def test_unknown_dict_fields_stored_as_extension(self): + """Any unknown dict-typed top-level field is stored as extension.""" + data = { + "oss": { + "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", + "bucket": "my-bucket", + }, + } + config = LoaderConfig.from_dict(data) + assert config._extensions["oss"] == { + "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", + "bucket": "my-bucket", + } + + def test_unknown_nested_fields_ignored(self): + data = { + "parallel": { + "max_concurrent_producers": 1, + 
"nonexistent": True, + }, + } + config = LoaderConfig.from_dict(data) + assert config.max_concurrent_producers == 1 + + def test_empty_dict(self): + config = LoaderConfig.from_dict({}) + assert config == LoaderConfig() + + def test_multiple_extensions(self): + """Multiple extension sections can coexist.""" + data = { + "loader": "base", + "base": {"copier_type": "gds"}, + "3fs": {"mount_point": "/mnt/3fs"}, + "oss": {"bucket": "test"}, + } + config = LoaderConfig.from_dict(data) + assert len(config._extensions) == 3 + assert "base" in config._extensions + assert "3fs" in config._extensions + assert "oss" in config._extensions + + +class TestExtensionConfig: + """Test LoaderConfig.get_extension_config().""" + + def test_get_existing_extension(self): + data = { + "base": {"copier_type": "gds", "bbuf_size_kb": 8192}, + } + config = LoaderConfig.from_dict(data) + ext = config.get_extension_config("base") + assert ext == {"copier_type": "gds", "bbuf_size_kb": 8192} + + def test_get_3fs_extension(self): + data = { + "3fs": {"mount_point": "/data/3fs", "entries": 128}, + } + config = LoaderConfig.from_dict(data) + ext = config.get_extension_config("3fs") + assert ext == {"mount_point": "/data/3fs", "entries": 128} + + def test_get_nonexistent_extension(self): + config = LoaderConfig() + ext = config.get_extension_config("nonexistent") + assert ext == {} + + def test_common_fields_stripped(self): + """Common loader fields (framework/debug_log/set_numa/disable_cache) + should be stripped from extension config.""" + data = { + "3fs": { + "mount_point": "/data/3fs", + "framework": "paddle", # common field, should be stripped + "debug_log": True, # common field, should be stripped + }, + } + config = LoaderConfig.from_dict(data) + ext = config.get_extension_config("3fs") + assert ext == {"mount_point": "/data/3fs"} + assert "framework" not in ext + assert "debug_log" not in ext + + def test_shallow_copy(self): + """Returned dict should be a shallow copy; mutation should not 
affect original.""" + data = { + "base": {"copier_type": "gds"}, + } + config = LoaderConfig.from_dict(data) + ext = config.get_extension_config("base") + ext["copier_type"] = "nogds" + ext["extra"] = "injected" + # Original should be unaffected + original = config.get_extension_config("base") + assert original == {"copier_type": "gds"} + assert "extra" not in original + + +class TestLoaderConfigFromFile: + """Test LoaderConfig.from_file() (public entry point). + + Only JSON format is supported. ``_from_json`` is a private implementation + detail; we verify through the public ``from_file``. + """ + + def test_from_file_json(self): + data = { + "loader": "base", + "base": {"copier_type": "gds"}, + "parallel": { + "max_concurrent_producers": 1, + }, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + path = f.name + + try: + config = LoaderConfig.from_file(path) + assert config.loader == "base" + assert config.get_extension_config("base")["copier_type"] == "gds" + assert config.max_concurrent_producers == 1 + finally: + os.unlink(path) + + def test_from_file_json_not_found(self): + with pytest.raises(FileNotFoundError): + LoaderConfig.from_file("/nonexistent/path.json") + + def test_from_file_json_3fs(self): + data = { + "loader": "3fs", + "3fs": {"mount_point": "/data/3fs"}, + "parallel": {"max_concurrent_producers": 1}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + path = f.name + + try: + config = LoaderConfig.from_file(path) + assert config.loader == "3fs" + assert config.get_extension_config("3fs")["mount_point"] == "/data/3fs" + assert config.max_concurrent_producers == 1 + finally: + os.unlink(path) + + def test_from_file_json_empty(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump({}, f) + f.flush() + path = f.name + + try: + config = LoaderConfig.from_file(path) + 
assert config == LoaderConfig() + finally: + os.unlink(path) + + def test_auto_detect_json_extension(self): + data = {"loader": "base", "base": {"copier_type": "gds"}} + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + path = f.name + + try: + config = LoaderConfig.from_file(path) + assert config.get_extension_config("base")["copier_type"] == "gds" + finally: + os.unlink(path) + + def test_auto_detect_json_3fs_extension(self): + data = {"loader": "3fs"} + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + path = f.name + + try: + config = LoaderConfig.from_file(path) + assert config.loader == "3fs" + finally: + os.unlink(path) + + +class TestLoaderConfigPublicAPI: + """Verify that removed helpers are truly gone from the public surface.""" + + def test_merge_removed(self): + assert not hasattr(LoaderConfig, "merge") + + def test_to_dict_removed(self): + assert not hasattr(LoaderConfig, "to_dict") + + def test_to_yaml_removed(self): + assert not hasattr(LoaderConfig, "to_yaml") + + def test_yaml_support_removed(self): + # Both public ``from_yaml`` and private ``_from_yaml`` should be removed. 
+ assert not hasattr(LoaderConfig, "from_yaml") + assert not hasattr(LoaderConfig, "_from_yaml") + + def test_from_json_is_private(self): + assert not hasattr(LoaderConfig, "from_json") + assert hasattr(LoaderConfig, "_from_json") + + def test_create_base_loader_kwargs_removed(self): + """create_base_loader_kwargs should no longer exist.""" + assert not hasattr(LoaderConfig, "create_base_loader_kwargs") + + def test_create_threefs_loader_kwargs_removed(self): + """create_threefs_loader_kwargs should no longer exist.""" + assert not hasattr(LoaderConfig, "create_threefs_loader_kwargs") + + def test_get_extension_config_exists(self): + """get_extension_config should be a public method.""" + assert hasattr(LoaderConfig, "get_extension_config") + + +class TestLoaderConfigKwargsHelpers: + """Test remaining create_*_kwargs helper methods.""" + + def test_create_parallel_kwargs_pipeline_enabled(self): + config = LoaderConfig( + use_pipeline=True, + max_concurrent_producers=1, + queue_size=2, + use_tqdm_on_load=False, + ) + kwargs = config.create_parallel_kwargs() + assert kwargs == { + "max_concurrent_producers": 1, + "queue_size": 2, + "use_tqdm_on_load": False, + } + + def test_create_parallel_kwargs_pipeline_disabled(self): + config = LoaderConfig( + use_pipeline=False, + max_concurrent_producers=1, + queue_size=2, + use_tqdm_on_load=False, + ) + kwargs = config.create_parallel_kwargs() + assert kwargs == {"queue_size": -1} + + def test_max_concurrent_producers_validation(self): + """max_concurrent_producers != 1 should raise ValueError.""" + with pytest.raises(ValueError, match="max_concurrent_producers must be 1"): + LoaderConfig(max_concurrent_producers=2) + with pytest.raises(ValueError, match="max_concurrent_producers must be 1"): + LoaderConfig(max_concurrent_producers=0) + with pytest.raises(ValueError, match="max_concurrent_producers must be 1"): + LoaderConfig.from_dict({"max_concurrent_producers": 4}) + with pytest.raises(ValueError, 
match="max_concurrent_producers must be 1"): + LoaderConfig.from_dict({"parallel": {"max_concurrent_producers": 2}}) + + +class TestLoadConfig: + """Test load_config() priority-based discovery. + + Priority order (high -> low): + 1. FASTSAFETENSORS_CONFIG environment variable + 2. Default path (./fastsafetensors.json) + 3. LoaderConfig defaults + + ``load_config()`` no longer accepts a ``config_path`` argument; callers + should drive file discovery via the environment variable or the default + working-directory path. + """ + + def test_load_config_signature_has_no_params(self): + import inspect + + sig = inspect.signature(load_config) + assert list(sig.parameters.keys()) == [] + + def test_env_var(self, monkeypatch): + data = {"loader": "3fs", "3fs": {"mount_point": "/data/3fs"}} + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + path = f.name + + try: + monkeypatch.setenv(CONFIG_ENV_VAR, path) + config = load_config() + assert config.loader == "3fs" + assert config.get_extension_config("3fs")["mount_point"] == "/data/3fs" + finally: + os.unlink(path) + + def test_env_var_not_found(self, monkeypatch): + monkeypatch.setenv(CONFIG_ENV_VAR, "/nonexistent/config.json") + with pytest.raises(FileNotFoundError): + load_config() + + def test_no_config_uses_defaults(self, monkeypatch, tmp_path): + """Without env var and no default config file in CWD, use defaults. + + Switch CWD to an empty ``tmp_path`` so the default path + (``./fastsafetensors.json``) is guaranteed to be absent, regardless of + whether the caller ran pytest from a directory that contains one. 
+ """ + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + monkeypatch.chdir(tmp_path) + assert not os.path.exists(DEFAULT_CONFIG_PATH) + config = load_config() + assert config == LoaderConfig() diff --git a/tests/test_unified_loader.py b/tests/test_unified_loader.py new file mode 100644 index 0000000..7baef3b --- /dev/null +++ b/tests/test_unified_loader.py @@ -0,0 +1,568 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for fastsafetensors.unified_loader module.""" + +import json +import os +import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from fastsafetensors.config import CONFIG_ENV_VAR, LoaderConfig +from fastsafetensors.loader import ( + BaseSafeTensorsFileLoader, +) +from fastsafetensors.loader import SafeTensorsFileLoader as RealSafeLoader +from fastsafetensors.threefs_loader import ThreeFSLoader as RealThreeFSLoader +from fastsafetensors.unified_loader import _LOADER_REGISTRY, AutoLoader + + +class TestAutoLoaderSignature: + """Test AutoLoader constructor signature.""" + + def test_signature_has_pg_files_device(self): + """Constructor should accept (pg, hf_weights_files, device).""" + import inspect + + sig = inspect.signature(AutoLoader.__init__) + params = list(sig.parameters.keys()) + assert "self" in params + assert "pg" in params + assert "hf_weights_files" in params + assert "device" in params + + def test_signature_no_loader_param(self): + """Constructor should NOT have a 'loader' positional param.""" + import inspect + + sig = inspect.signature(AutoLoader.__init__) + params = list(sig.parameters.keys()) + # 'loader' is now a config field, not a constructor arg + assert "loader" not in params + + def test_signature_no_config_param(self): + """Constructor should NOT have 'config' or 'config_path' params.""" + import inspect + + sig = inspect.signature(AutoLoader.__init__) + params = list(sig.parameters.keys()) + assert "config" not in params + assert "config_path" not in params + + def 
test_signature_no_kwargs(self): + """Constructor should NOT accept **kwargs.""" + import inspect + + sig = inspect.signature(AutoLoader.__init__) + params = sig.parameters + for p in params.values(): + assert p.kind != inspect.Parameter.VAR_KEYWORD + + +class TestLoaderRegistry: + """Test _LOADER_REGISTRY contains expected entries.""" + + def test_base_registered(self): + assert "base" in _LOADER_REGISTRY + + def test_3fs_registered(self): + assert "3fs" in _LOADER_REGISTRY + + def test_threefs_not_registered(self): + """Old 'threefs' name should NOT be in the registry.""" + assert "threefs" not in _LOADER_REGISTRY + + +class TestAutoLoaderBaseLoader: + """Test AutoLoader creates SafeTensorsFileLoader for loader='base'.""" + + def _make_base_config(self, copier_type="nogds"): + """Helper to create a LoaderConfig with base extension.""" + cfg = LoaderConfig(loader="base") + cfg._extensions["base"] = {"copier_type": copier_type} + return cfg + + def test_default_creates_base_loader(self, monkeypatch): + """Default config (loader='base') should create SafeTensorsFileLoader.""" + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + + with ( + patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel"), + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + mock_load.return_value = self._make_base_config("nogds") + mock_stfl.process_extension_config = RealSafeLoader.process_extension_config + mock_stfl.return_value = MagicMock() + + unified = AutoLoader(None, ["file1.safetensors"], device="cpu") + + mock_stfl.assert_called_once() + call_kwargs = mock_stfl.call_args + assert call_kwargs[1]["nogds"] is True + assert unified.config.loader == "base" + + def test_gds_creates_base_loader_with_nogds_false(self, monkeypatch): + """loader='base' with copier_type='gds' should pass nogds=False.""" + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + + with ( + 
patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel"), + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + mock_load.return_value = self._make_base_config("gds") + mock_stfl.process_extension_config = RealSafeLoader.process_extension_config + mock_stfl.return_value = MagicMock() + + AutoLoader(None, ["file1.safetensors"], device="cuda:0") + + call_kwargs = mock_stfl.call_args + assert call_kwargs[1]["nogds"] is False + + def test_base_no_extension_uses_defaults(self, monkeypatch): + """loader='base' without base extension section should still work.""" + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + + with ( + patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel"), + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + # No _extensions set -> get_extension_config returns {} + mock_load.return_value = LoaderConfig(loader="base") + mock_stfl.process_extension_config = RealSafeLoader.process_extension_config + mock_stfl.return_value = MagicMock() + + AutoLoader(None, ["file1.safetensors"], device="cpu") + + mock_stfl.assert_called_once() + call_kwargs = mock_stfl.call_args + # process_extension_config with {} -> nogds=False (default copier_type="gds") + assert call_kwargs[1]["nogds"] is False + + +class TestAutoLoader3FSLoader: + """Test AutoLoader creates ThreeFSLoader for loader='3fs'.""" + + def _make_3fs_config(self, mount_point="/data/3fs"): + """Helper to create a LoaderConfig with 3fs extension.""" + cfg = LoaderConfig(loader="3fs") + cfg._extensions["3fs"] = {"mount_point": mount_point} + return cfg + + def test_3fs_creates_threefs_loader(self, monkeypatch): + """loader='3fs' should create ThreeFSLoader.""" + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + + with ( + patch("fastsafetensors.unified_loader.ThreeFSLoader") as mock_3fs, + 
patch("fastsafetensors.unified_loader.PipelineParallel"), + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + mock_load.return_value = self._make_3fs_config("/data/3fs") + mock_3fs.process_extension_config = ( + RealThreeFSLoader.process_extension_config + ) + mock_3fs.return_value = MagicMock() + + unified = AutoLoader(None, ["file1.safetensors"], device="cuda:0") + + mock_3fs.assert_called_once() + call_kwargs = mock_3fs.call_args + assert call_kwargs[1]["mount_point"] == "/data/3fs" + assert unified.config.loader == "3fs" + + def test_3fs_common_kwargs_passed(self, monkeypatch): + """Common kwargs (framework, debug_log, etc.) should be passed to ThreeFSLoader.""" + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + + with ( + patch("fastsafetensors.unified_loader.ThreeFSLoader") as mock_3fs, + patch("fastsafetensors.unified_loader.PipelineParallel"), + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + cfg = LoaderConfig(loader="3fs", framework="paddle", debug_log=True) + cfg._extensions["3fs"] = {"mount_point": "/mnt/3fs"} + mock_load.return_value = cfg + mock_3fs.process_extension_config = ( + RealThreeFSLoader.process_extension_config + ) + mock_3fs.return_value = MagicMock() + + AutoLoader(None, ["file1.safetensors"], device="cuda:0") + + call_kwargs = mock_3fs.call_args[1] + assert call_kwargs["framework"] == "paddle" + assert call_kwargs["debug_log"] is True + assert call_kwargs["mount_point"] == "/mnt/3fs" + + +class TestAutoLoaderUnknownLoader: + """Test AutoLoader raises ValueError for unknown loader type.""" + + def test_unknown_loader_raises(self): + with patch("fastsafetensors.unified_loader.load_config") as mock_load: + mock_load.return_value = LoaderConfig(loader="nonexistent") + + with pytest.raises(ValueError, match="Unknown loader type"): + AutoLoader(None, ["file1.safetensors"]) + + +class TestAutoLoaderPipelineCreation: + """Test that AutoLoader creates PipelineParallel correctly.""" + + def 
test_pipeline_receives_parallel_kwargs(self): + """PipelineParallel should receive parallel kwargs from config.""" + with ( + patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel") as mock_pp, + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + mock_load.return_value = LoaderConfig( + use_pipeline=True, + max_concurrent_producers=1, + queue_size=3, + use_tqdm_on_load=False, + ) + mock_stfl.return_value = MagicMock() + + AutoLoader(None, ["file1.safetensors"]) + + mock_pp.assert_called_once() + call_kwargs = mock_pp.call_args[1] + assert call_kwargs["max_concurrent_producers"] == 1 + assert call_kwargs["queue_size"] == 3 + assert call_kwargs["use_tqdm_on_load"] is False + + +class TestAutoLoaderConfigDiscovery: + """Test config file discovery (env var > default path > defaults).""" + + def test_env_var_config(self, monkeypatch): + """FASTSAFETENSORS_CONFIG env var should point to config file.""" + data = { + "loader": "3fs", + "3fs": {"mount_point": "/data/3fs"}, + } + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(data, f) + f.flush() + path = f.name + + try: + monkeypatch.setenv(CONFIG_ENV_VAR, path) + + with ( + patch("fastsafetensors.unified_loader.ThreeFSLoader") as mock_3fs, + patch("fastsafetensors.unified_loader.PipelineParallel"), + ): + mock_3fs.return_value = MagicMock() + unified = AutoLoader(None, ["file1.safetensors"]) + assert unified.config.loader == "3fs" + assert ( + unified.config.get_extension_config("3fs")["mount_point"] + == "/data/3fs" + ) + finally: + os.unlink(path) + + def test_no_config_uses_defaults(self, monkeypatch): + """Without config file, should use LoaderConfig defaults.""" + monkeypatch.delenv(CONFIG_ENV_VAR, raising=False) + # cd to a temp dir where no fastsafetensors.json exists + with tempfile.TemporaryDirectory() as tmpdir: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + with ( 
+ patch( + "fastsafetensors.unified_loader.SafeTensorsFileLoader" + ) as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel"), + ): + mock_stfl.return_value = MagicMock() + unified = AutoLoader(None, ["file1.safetensors"]) + assert unified.config == LoaderConfig() + finally: + os.chdir(old_cwd) + + +class TestAutoLoaderRemovedAPIs: + """Verify that previously-public surface (context manager / __iter__ / + loader property) is no longer exposed.""" + + def test_no_context_manager(self): + assert not hasattr(AutoLoader, "__enter__") + assert not hasattr(AutoLoader, "__exit__") + + def test_no_iter(self): + # __iter__ should be removed; AutoLoader is not iterable. + assert "__iter__" not in vars(AutoLoader) + + def test_no_loader_property(self): + # The public ``loader`` property is removed; ``_loader`` remains + # as an internal attribute only. + assert "loader" not in vars(AutoLoader) + + +class TestAutoLoaderProperties: + """Test AutoLoader properties.""" + + def test_config_property(self): + with ( + patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel"), + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + cfg = LoaderConfig(loader="base") + cfg._extensions["base"] = {"copier_type": "gds"} + mock_load.return_value = cfg + mock_stfl.process_extension_config = RealSafeLoader.process_extension_config + mock_stfl.return_value = MagicMock() + + unified = AutoLoader(None, ["file1.safetensors"]) + assert unified.config is cfg + + +class TestAutoLoaderClose: + """Verify AutoLoader.close() does not double-close the underlying loader. + + PipelineParallel.close() already closes its inner loader, so + AutoLoader.close() must only call pipeline.close() and must NOT call + self._loader.close() directly. 
+ """ + + def test_close_does_not_double_close_underlying_loader(self): + with ( + patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel") as mock_pp, + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + mock_load.return_value = LoaderConfig() + mock_loader_instance = MagicMock() + mock_stfl.return_value = mock_loader_instance + mock_pipeline_instance = MagicMock() + mock_pp.return_value = mock_pipeline_instance + + unified = AutoLoader(None, ["file1.safetensors"]) + unified.close() + + mock_pipeline_instance.close.assert_called_once() + mock_loader_instance.close.assert_not_called() + + +class TestProcessExtensionConfig: + """Test that process_extension_config is correctly invoked via AutoLoader.""" + + def test_base_process_extension_config_called(self): + """SafeTensorsFileLoader.process_extension_config should be called.""" + with ( + patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl, + patch("fastsafetensors.unified_loader.PipelineParallel"), + patch("fastsafetensors.unified_loader.load_config") as mock_load, + ): + cfg = LoaderConfig(loader="base") + cfg._extensions["base"] = {"copier_type": "gds", "bbuf_size_kb": 8192} + mock_load.return_value = cfg + mock_stfl.process_extension_config = RealSafeLoader.process_extension_config + mock_stfl.return_value = MagicMock() + + AutoLoader(None, ["file1.safetensors"], device="cpu") + + call_kwargs = mock_stfl.call_args[1] + # process_extension_config should have mapped copier_type -> nogds + assert call_kwargs["nogds"] is False + assert call_kwargs["bbuf_size_kb"] == 8192 + assert "copier_type" not in call_kwargs + + +class TestThreeFSProcessExtensionConfig: + """Test ThreeFSLoader.process_extension_config mount_point inference.""" + + def test_explicit_mount_point_preserved(self): + """Explicit mount_point should NOT be overridden.""" + result = RealThreeFSLoader.process_extension_config( + 
{"mount_point": "/custom/3fs"}, + hf_weights_files=["/mnt/x/a.safetensors"], + ) + assert result["mount_point"] == "/custom/3fs" + + def test_no_mount_point_no_files_no_inference(self): + """No mount_point key + no files -> no inference, key absent.""" + result = RealThreeFSLoader.process_extension_config({}) + assert "mount_point" not in result + + def test_no_mount_point_with_files_triggers_inference(self): + """No mount_point key + files present -> inference attempted.""" + with patch( + "fastsafetensors.threefs_loader.extract_mount_point", + create=True, + ) as mock_extract: + # Patch the import inside process_extension_config + mock_module = MagicMock() + mock_module.extract_mount_point.return_value = "/inferred/3fs" + with patch.dict("sys.modules", {"fastsafetensor_3fs_reader": mock_module}): + result = RealThreeFSLoader.process_extension_config( + {}, + hf_weights_files=["/inferred/3fs/model.safetensors"], + ) + assert result["mount_point"] == "/inferred/3fs" + + def test_empty_mount_point_triggers_inference(self): + """Empty string mount_point should trigger inference.""" + mock_module = MagicMock() + mock_module.extract_mount_point.return_value = "/inferred/3fs" + with patch.dict("sys.modules", {"fastsafetensor_3fs_reader": mock_module}): + result = RealThreeFSLoader.process_extension_config( + {"mount_point": ""}, + hf_weights_files=["/inferred/3fs/model.safetensors"], + ) + assert result["mount_point"] == "/inferred/3fs" + + def test_whitespace_mount_point_triggers_inference(self): + """Whitespace-only mount_point should trigger inference.""" + mock_module = MagicMock() + mock_module.extract_mount_point.return_value = "/inferred/3fs" + with patch.dict("sys.modules", {"fastsafetensor_3fs_reader": mock_module}): + result = RealThreeFSLoader.process_extension_config( + {"mount_point": " "}, + hf_weights_files=["/inferred/3fs/model.safetensors"], + ) + assert result["mount_point"] == "/inferred/3fs" + + def test_import_error_graceful_degradation(self): + 
"""If fastsafetensor_3fs_reader is not importable, should not crash.""" + # Remove the module from sys.modules to force ImportError + import sys + + saved = sys.modules.pop("fastsafetensor_3fs_reader", None) + try: + with patch.dict("sys.modules", {"fastsafetensor_3fs_reader": None}): + # None in sys.modules causes ImportError on import + result = RealThreeFSLoader.process_extension_config( + {}, + hf_weights_files=["/mnt/3fs/model.safetensors"], + ) + # Should not crash; mount_point may or may not be set + # but definitely should not raise + assert isinstance(result, dict) + finally: + if saved is not None: + sys.modules["fastsafetensor_3fs_reader"] = saved + + def test_other_fields_pass_through(self): + """Non-mount_point fields should pass through unchanged.""" + result = RealThreeFSLoader.process_extension_config( + {"entries": 64, "io_depth": 0}, + ) + assert result["entries"] == 64 + assert result["io_depth"] == 0 + + def test_none_mount_point_triggers_inference(self): + """None mount_point should be treated as empty -> trigger inference.""" + mock_module = MagicMock() + mock_module.extract_mount_point.return_value = "/inferred/3fs" + with patch.dict("sys.modules", {"fastsafetensor_3fs_reader": mock_module}): + result = RealThreeFSLoader.process_extension_config( + {"mount_point": None}, + hf_weights_files=["/inferred/3fs/model.safetensors"], + ) + # None is not a string, .strip() will raise AttributeError + # unless the code handles it. This tests the current behavior. 
            assert isinstance(result, dict)


class TestKwargsPassedToProcessExtensionConfig:
    """Verify hf_weights_files reaches process_extension_config via AutoLoader."""

    def test_hf_weights_files_passed_to_base(self):
        """Base loader's process_extension_config should receive hf_weights_files kwarg."""
        received_kwargs = {}

        # Spy wraps the real classmethod: record the kwargs, then delegate
        # so the normal copier_type -> nogds mapping still happens.
        def spy_process_ext(ext_config, **kwargs):
            received_kwargs.update(kwargs)
            return RealSafeLoader.process_extension_config(ext_config, **kwargs)

        with (
            patch("fastsafetensors.unified_loader.SafeTensorsFileLoader") as mock_stfl,
            patch("fastsafetensors.unified_loader.PipelineParallel"),
            patch("fastsafetensors.unified_loader.load_config") as mock_load,
        ):
            mock_load.return_value = LoaderConfig(loader="base")
            mock_stfl.process_extension_config = spy_process_ext
            mock_stfl.return_value = MagicMock()

            AutoLoader(None, ["a.safetensors", "b.safetensors"], device="cpu")

            assert "hf_weights_files" in received_kwargs
            assert received_kwargs["hf_weights_files"] == [
                "a.safetensors",
                "b.safetensors",
            ]

    def test_hf_weights_files_passed_to_3fs(self):
        """3FS loader's process_extension_config should receive hf_weights_files kwarg."""
        received_kwargs = {}

        def spy_process_ext(ext_config, **kwargs):
            received_kwargs.update(kwargs)
            return RealThreeFSLoader.process_extension_config(ext_config, **kwargs)

        with (
            patch("fastsafetensors.unified_loader.ThreeFSLoader") as mock_3fs,
            patch("fastsafetensors.unified_loader.PipelineParallel"),
            patch("fastsafetensors.unified_loader.load_config") as mock_load,
        ):
            cfg = LoaderConfig(loader="3fs")
            cfg._extensions["3fs"] = {"mount_point": "/data/3fs"}
            mock_load.return_value = cfg
            mock_3fs.process_extension_config = spy_process_ext
            mock_3fs.return_value = MagicMock()

            AutoLoader(None, ["/data/3fs/m1.safetensors"], device="cuda:0")

            assert "hf_weights_files" in received_kwargs
            assert received_kwargs["hf_weights_files"] == ["/data/3fs/m1.safetensors"]


class TestInitExports:
    """Verify public API exports from fastsafetensors package."""

    def test_auto_loader_importable(self):
        """AutoLoader should be importable from top-level package."""
        from fastsafetensors import AutoLoader as AL

        # The re-export must be the very same class object.
        assert AL is not None
        assert AL is AutoLoader

    def test_load_config_importable(self):
        """LoaderConfig and load_config should be importable from top-level package."""
        from fastsafetensors import LoaderConfig as LC
        from fastsafetensors import load_config as lc

        assert LC is not None
        assert lc is not None


class TestBaseProcessExtensionConfigKwargs:
    """Verify BaseSafeTensorsFileLoader.process_extension_config accepts **kwargs."""

    def test_accepts_extra_kwargs(self):
        """Should accept arbitrary kwargs without error."""
        processed = BaseSafeTensorsFileLoader.process_extension_config(
            {"key1": "val1"},
            hf_weights_files=["a.safetensors"],
            some_other_kwarg=42,
        )
        # Extra keyword arguments are tolerated and the payload survives.
        assert processed == {"key1": "val1"}

    def test_kwargs_ignored_in_output(self):
        """Extra kwargs should not appear in output dict."""
        processed = BaseSafeTensorsFileLoader.process_extension_config(
            {"copier_type": "gds"},
            hf_weights_files=["a.safetensors"],
        )
        # Only the extension payload comes back; kwargs never leak through.
        assert processed == {"copier_type": "gds"}
        assert "hf_weights_files" not in processed