Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ with fastsafe_open(filenames=[filename], nogds=True, device="cpu", debug_log=Tru
t = f.get_tensor(key).clone().detach() # clone if t is used outside
```

## Configuration

`AutoLoader` supports file-based configuration for loader type, pipeline mode, copy settings, and more.
See [Configuration Guide](./docs/configuration.md) for defaults, examples, and all available options.

## Development

### Pre-commit Hooks
Expand Down
136 changes: 136 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Configuration Guide

## Configuration Discovery

`AutoLoader` loads configuration in the following priority (highest first):

1. **Environment variable** — `FASTSAFETENSORS_CONFIG=/path/to/config.json`
2. **Default path** — `./fastsafetensors.json` in the working directory (if it exists)
3. **Built-in defaults** — `LoaderConfig()` dataclass defaults

All fields are optional. Unspecified fields fall back to built-in defaults.

## Default Configuration

When no config file is found, `AutoLoader` uses these defaults:

```json
{
"loader": "base",
"framework": "pytorch",
"parallel": {
"use_pipeline": false
},
"debug": {
"debug_log": false,
"set_numa": true,
"disable_cache": true
}
}
```

The base loader extension defaults to `copier_type: "gds"` (GPU Direct Storage).

## queue_size Semantics

| `queue_size` | Mode | GPU Memory | Behavior |
|---|---|---|---|
| `-1` | Fully serial | 1 batch | `copy_files → broadcast → copy_files → ...` |
| `0` | Unbuffered pipeline | Up to 2 batches | 1 batch copying + 1 batch broadcasting concurrently |
| `>0` | Buffered pipeline | Up to `queue_size+1` batches | Producer fills queue while consumer broadcasts |

`use_pipeline: false` forces `queue_size=-1` (serial, minimal GPU memory).

## Configuration Examples

### 1. Minimal — All Defaults (no config file needed)

```python
from fastsafetensors import SingleGroup, AutoLoader

pg = SingleGroup()
loader = AutoLoader(pg, files, device="cuda:0")
for key, tensor in loader.iterate_weights():
process(key, tensor)
loader.close()
```

No config file. Uses `loader="base"`, `gds`, serial mode.

### 2. Base Loader with GDS

```json
{
"loader": "base",
"base": {
"copier_type": "gds"
}
}
```

Enables GPU Direct Storage for NVMe-to-GPU transfers, bypassing host CPU/memory.

### 3. Base Loader with Pipeline Mode

```json
{
"parallel": {
"use_pipeline": true,
"max_concurrent_producers": 1,
"queue_size": 0,
"use_tqdm_on_load": true
}
}
```

Overlaps `copy_files` with `broadcast` for higher throughput.

### 4. 3FS Loader

```json
{
"loader": "3fs",
"3fs": {
"mount_point": "/mnt/3fs",
"entries": 64,
"io_depth": 0,
"buffer_size": 67108864
}
}
```

Uses ThreeFSLoader with 3FS USRBIO backend.

### 5. Full Reference

```json
{
"loader": "base",
"framework": "pytorch",
"base": {
"copier_type": "gds",
"bbuf_size_kb": 16384,
"max_threads": 16
},
"3fs": {
"mount_point": "/mnt/3fs",
"entries": 64,
"io_depth": 0,
"buffer_size": 67108864
},
"parallel": {
"use_pipeline": false,
"max_concurrent_producers": 1,
"queue_size": 0,
"use_tqdm_on_load": true
},
"debug": {
"debug_log": false,
"set_numa": true,
"disable_cache": true
}
}
```

Each loader type has its own extension section (e.g., `base`, `3fs`).
Adding a new loader only requires a new section — no changes to `config.py`.
70 changes: 70 additions & 0 deletions examples/run_auto_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0

"""Example: Using AutoLoader with automatic configuration.

AutoLoader automatically creates the appropriate loader based on
configuration discovered from:
1. FASTSAFETENSORS_CONFIG environment variable -> config file
2. ./fastsafetensors.json (default path, if it exists)
3. Built-in defaults (loader="base", copier_type="gds")

No manual loader creation is needed. The constructor itself emits a
``logger.info`` summary of the effective configuration, so callers do
not need to print it manually.
"""

import argparse
import logging

from fastsafetensors import AutoLoader, SingleGroup

logger = logging.getLogger(__name__)


def main():
    """Demonstrate the three ways AutoLoader discovers its configuration."""
    logging.basicConfig(
        level=logging.INFO, format="%(levelname)s %(name)s: %(message)s"
    )
    cli = argparse.ArgumentParser(description="AutoLoader example")
    cli.add_argument("files", nargs="+", help="safetensors file paths")
    cli.add_argument("--device", default="cpu", help="target device (default: cpu)")
    opts = cli.parse_args()

    group = SingleGroup()

    # Way 1: pure defaults — no config file anywhere.
    # Falls back to loader="base" with copier_type="gds".
    logger.info("=== Way 1: Default config ===")
    auto = AutoLoader(group, opts.files, device=opts.device)
    for name, weight in auto.iterate_weights():
        logger.info(" %s: shape=%s", name, weight.shape)
    auto.close()

    # Way 2: drop a fastsafetensors.json into the working directory, e.g.:
    #
    #   {
    #     "loader": "3fs",
    #     "3fs": {
    #       "mount_point": "/mnt/3fs"
    #     }
    #   }
    #
    # and construct the loader exactly as above:
    #   loader = AutoLoader(pg, args.files, device=args.device)
    logger.info(
        "=== Way 2: Config file (auto-discovered from ./fastsafetensors.json) ==="
    )
    logger.info(" (Place fastsafetensors.json in your working directory)")

    # Way 3: point FASTSAFETENSORS_CONFIG at a config file, e.g.:
    #   export FASTSAFETENSORS_CONFIG=/path/to/your/config.json
    # and again construct the loader exactly as above.
    logger.info("=== Way 3: Environment variable ===")
    logger.info(" export FASTSAFETENSORS_CONFIG=/path/to/config.json")


if __name__ == "__main__":
    main()
4 changes: 3 additions & 1 deletion fastsafetensors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
TensorFrame,
get_device_numa_node,
)
from .config import LoaderConfig, load_config
from .file_buffer import FilesBufferOnDevice
from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader, fastsafe_open
from .loader import SafeTensorsFileLoader, fastsafe_open
from .parallel_loader import ParallelLoader
from .unified_loader import AutoLoader
136 changes: 136 additions & 0 deletions fastsafetensors/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0

import json
import os
from dataclasses import dataclass, field, fields
from typing import Any, Dict

from .common import init_logger

logger = init_logger(__name__)

# Environment variable naming an explicit config file; checked first by load_config().
CONFIG_ENV_VAR = "FASTSAFETENSORS_CONFIG"
# Relative filename probed in the current working directory when the env var is unset.
DEFAULT_CONFIG_PATH = "fastsafetensors.json"


@dataclass
class LoaderConfig:
    """Configuration for fastsafetensors unified loader.

    Core fields live as dataclass attributes. Per-loader extension settings
    (e.g., ``base.copier_type``, ``3fs.mount_point``) are stored in
    ``_extensions`` and accessed via :meth:`get_extension_config`.
    """

    # Loader selection and framework backend.
    loader: str = "base"
    framework: str = "pytorch"

    # Debug / platform toggles (flattened from the "debug" config group).
    debug_log: bool = False
    set_numa: bool = True
    disable_cache: bool = True

    # Pipeline settings (flattened from the "parallel" config group).
    use_pipeline: bool = False
    max_concurrent_producers: int = 1
    queue_size: int = 0
    use_tqdm_on_load: bool = True

    # Extension sections keyed by loader name (e.g. "base", "3fs").
    _extensions: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    def __post_init__(self):
        # Broadcast batches must be consumed in strict order on every rank,
        # so only a single producer is supported for now.
        if self.max_concurrent_producers != 1:
            raise ValueError(
                f"max_concurrent_producers must be 1 "
                f"(got {self.max_concurrent_producers}). "
                "Concurrent producers > 1 are not yet supported because broadcast "
                "batches must be processed in strict order across all ranks."
            )

    # Top-level groups whose sub-keys are flattened into dataclass fields.
    _COMMON_GROUPS = {"parallel", "debug"}
    # Cross-loader common fields that are stripped from extension sections.
    _COMMON_FIELDS_FOR_EXTENSION = {
        "framework",
        "debug_log",
        "set_numa",
        "disable_cache",
    }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "LoaderConfig":
        """Create from dict. ``parallel``/``debug`` groups are flattened;
        other dict-valued keys become extension sections.

        Unknown scalar keys and unknown sub-keys of the common groups are
        ignored (logged at debug level), so older configs keep working.
        """
        # Public dataclass fields only; "_extensions" is filled separately.
        valid_fields = {f.name for f in fields(cls) if not f.name.startswith("_")}
        flat: Dict[str, Any] = {}
        extensions: Dict[str, Dict[str, Any]] = {}

        for key, value in data.items():
            if key in cls._COMMON_GROUPS and isinstance(value, dict):
                # parallel / debug -> flatten into top-level fields
                for sub_key, sub_value in value.items():
                    if sub_key in valid_fields:
                        flat[sub_key] = sub_value
                    else:
                        logger.debug(
                            "Ignoring unknown config field: %s.%s", key, sub_key
                        )
            elif isinstance(value, dict):
                # Any other dict-valued top-level key is treated as an
                # extension section (e.g., base / 3fs / oss / s3).
                extensions[key] = dict(value)
            elif key in valid_fields:
                flat[key] = value
            else:
                logger.debug("Ignoring unknown config field: %s", key)

        flat["_extensions"] = extensions
        return cls(**flat)

    def get_extension_config(self, name: str) -> Dict[str, Any]:
        """Return a shallow copy of the extension section for *name*,
        with cross-loader common fields stripped. Empty dict if absent."""
        raw = self._extensions.get(name, {})
        return {
            k: v for k, v in raw.items() if k not in self._COMMON_FIELDS_FOR_EXTENSION
        }

    @classmethod
    def _from_json(cls, path: str) -> "LoaderConfig":
        """Parse *path* as JSON and build a config via :meth:`from_dict`."""
        with open(path, "r") as f:
            data = json.load(f)

        logger.info("Loaded config from JSON: %s", path)
        return cls.from_dict(data)

    @classmethod
    def from_file(cls, path: str) -> "LoaderConfig":
        """Load configuration from a JSON file."""
        return cls._from_json(path)

    def create_parallel_kwargs(self) -> Dict[str, Any]:
        """Return keyword arguments for the parallel loader.

        With ``use_pipeline=False`` this forces serial mode regardless of
        ``queue_size``; otherwise the pipeline settings are passed through.
        """
        if not self.use_pipeline:
            # queue_size=-1: fully serial (copy_files → broadcast → copy_files),
            # only 1 batch in GPU memory at a time.
            return {"queue_size": -1}
        return {
            "max_concurrent_producers": self.max_concurrent_producers,
            "queue_size": self.queue_size,
            "use_tqdm_on_load": self.use_tqdm_on_load,
        }


def load_config() -> LoaderConfig:
    """Discover and load configuration.

    Priority: the ``FASTSAFETENSORS_CONFIG`` environment variable, then
    ``./fastsafetensors.json`` in the working directory, then built-in
    :class:`LoaderConfig` defaults.

    Raises:
        FileNotFoundError: the env var is set but does not name a file.
    """
    explicit_path = os.environ.get(CONFIG_ENV_VAR)
    if explicit_path is not None:
        # An explicitly requested config that is missing is an error,
        # not a silent fallback.
        if not os.path.isfile(explicit_path):
            raise FileNotFoundError(
                f"Config file specified by {CONFIG_ENV_VAR} not found: {explicit_path}"
            )
        logger.info("Loading config from %s=%s", CONFIG_ENV_VAR, explicit_path)
        return LoaderConfig.from_file(explicit_path)

    if os.path.isfile(DEFAULT_CONFIG_PATH):
        logger.info("Loading config from default path: %s", DEFAULT_CONFIG_PATH)
        return LoaderConfig.from_file(DEFAULT_CONFIG_PATH)

    logger.debug("No config file found, using defaults")
    return LoaderConfig()
Loading
Loading