Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions examples/fix_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
def fix_sten_file(src_file: str, dst_file: str):
pad_key = "p"
pad_value = "P"
src_fd = os.open(src_file, os.O_RDONLY, 0o644)
src_flags = os.O_RDONLY
if sys.platform == "win32" and hasattr(os, "O_BINARY"):
src_flags |= os.O_BINARY
src_fd = os.open(src_file, src_flags, 0o644)
if src_fd < 0:
raise Exception(f"FAIL: open, src_file={src_file}")
meta = SafeTensorsMetadata.from_fd(
Expand Down Expand Up @@ -45,7 +48,10 @@ def fix_sten_file(src_file: str, dst_file: str):
f"dst: filename={dst_file}, header_len={dst_header_len} (pad={head_pad}), size={dst_header_len + meta.size_bytes - meta.header_length}"
)

dst_fd = os.open(dst_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
dst_flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
if sys.platform == "win32" and hasattr(os, "O_BINARY"):
dst_flags |= os.O_BINARY
dst_fd = os.open(dst_file, dst_flags, 0o644)
if dst_fd < 0:
raise Exception(f"FAIL: open, dst_fd={dst_fd}")
os_write_full(
Expand Down Expand Up @@ -121,4 +127,4 @@ def os_sendfile_full(src_fd: int, dst_fd: int, offset: int, length: int):
need_copy = fix_sten_file(f"{dir}/{filename}", f"{outdir}/{filename}")
if need_copy:
print(f"copy: {dir}/{filename} --> {outdir}/{filename}")
shutil.copyfile(f"{dir}/{filename}", f"{outdir}/{filename}")
shutil.copyfile(f"{dir}/{filename}", f"{outdir}/{filename}")
100 changes: 97 additions & 3 deletions fastsafetensors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,84 @@ def get_device_numa_node(device: Optional[int]) -> Optional[int]:
with open(syspath) as f:
return int(f.read().strip())

def resolve_cudart_lib_name() -> str:
    """Resolve the CUDA runtime library name for the current platform.

    On non-Windows platforms this always returns "" so that the compiled-in
    default (GPU_RUNTIME_LIB in cuda_compat.h) is used.  On Windows the
    lookup order is:

    1. The FASTSAFETENSORS_CUDART_LIB environment variable (explicit override).
    2. CUDA_HOME / CUDA_PATH: the name derived from ``nvcc -V`` if that DLL
       exists in ``<cuda_home>/bin``, otherwise any ``cudart64_*.dll`` found
       in that directory, otherwise the nvcc-derived name as a last resort.
    3. Every directory listed on PATH.
    4. The per-version ``bin`` directories under
       ``<ProgramFiles>/NVIDIA GPU Computing Toolkit/CUDA``, newest first.

    Returns:
        Library name string, or "" to use the compiled-in default.
    """
    if sys.platform != "win32":
        # Non-Windows platforms use a version-agnostic library name, so
        # return "" to fall back to GPU_RUNTIME_LIB from cuda_compat.h.
        return ""

    # Allow explicit override via environment variable
    override = os.environ.get("FASTSAFETENSORS_CUDART_LIB", "")
    if override:
        return override

    import glob
    import subprocess

    def _find_cudart_in_dir(d: str) -> str:
        """Scan a directory for cudart64_*.dll files, return the best match."""
        if not os.path.isdir(d):
            return ""
        # Escape the directory part so that '[' / ']' / '*' in a path are
        # matched literally instead of being treated as glob syntax.
        matches = glob.glob(os.path.join(glob.escape(d), "cudart64_*.dll"))
        if not matches:
            return ""
        # NOTE: the candidates are ranked by plain reverse string order of
        # their paths, not by parsed CUDA version.
        matches.sort(reverse=True)
        return os.path.basename(matches[0])

    def _detect_from_nvcc(cuda_home: str) -> str:
        """Try to detect the CUDA major version from nvcc -V output."""
        nvcc = os.path.join(cuda_home, "bin", "nvcc.exe")
        if not os.path.isfile(nvcc):
            return ""
        try:
            output = subprocess.check_output(
                [nvcc, "-V"], universal_newlines=True, stderr=subprocess.STDOUT
            )
            tokens = output.split()
            # The version follows the word "release", e.g. "release 12.6,".
            version_str = tokens[tokens.index("release") + 1].rstrip(",")
        except (OSError, subprocess.SubprocessError, ValueError, IndexError):
            # Best effort only: any failure here falls through to the
            # directory scans below.
            return ""
        cuda_major = version_str.split(".")[0]
        if not cuda_major.isdigit():
            return ""
        return f"cudart64_{cuda_major}.dll"

    def _version_key(name: str):
        """Sort key ordering 'v12.6'-style directory names numerically."""
        try:
            numbers = tuple(int(part) for part in name.lstrip("vV").split("."))
        except ValueError:
            # Entries that are not version directories sort after real ones
            # once the list is reversed.
            return (0, ())
        return (1, numbers)

    # Try to detect from CUDA_HOME / CUDA_PATH
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home:
        bin_dir = os.path.join(cuda_home, "bin")
        nvcc_name = _detect_from_nvcc(cuda_home)
        # The DLL suffix is not always the bare major version (e.g. CUDA 11.x
        # toolkits ship cudart64_110.dll), so only trust the nvcc-derived
        # name directly when that file is really there.
        if nvcc_name and os.path.isfile(os.path.join(bin_dir, nvcc_name)):
            return nvcc_name
        result = _find_cudart_in_dir(bin_dir)
        if result:
            return result
        if nvcc_name:
            # Nothing better was found next to nvcc; keep the previous
            # behaviour of returning the derived name.
            return nvcc_name

    # Scan directories on PATH for cudart64_*.dll
    for d in os.environ.get("PATH", "").split(os.pathsep):
        result = _find_cudart_in_dir(d)
        if result:
            return result

    # Scan common NVIDIA install locations
    program_files = os.environ.get("ProgramFiles", r"C:\Program Files")
    nvidia_base = os.path.join(program_files, "NVIDIA GPU Computing Toolkit", "CUDA")
    if os.path.isdir(nvidia_base):
        # List version directories (e.g. v12.6, v11.8), newest first.  A
        # numeric key is required: plain string order ranks v9.2 above v12.6.
        try:
            versions = sorted(os.listdir(nvidia_base), key=_version_key, reverse=True)
        except OSError:
            versions = []
        for ver_dir in versions:
            result = _find_cudart_in_dir(os.path.join(nvidia_base, ver_dir, "bin"))
            if result:
                return result

    return ""  # fall back to compiled-in default

# keep this for compatibility
class SingleGroup:
Expand Down Expand Up @@ -106,10 +184,21 @@ def __init__(
f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}"
)
self.size_bytes = size_bytes
if start + header_length != size_bytes:
if start + header_length > size_bytes:
raise Exception(
f"MetadataIncompleteBuffer, src={src}, start={start}, header_length={header_length}, size_bytes={size_bytes}"
)
if start + header_length < size_bytes:
# Trailing padding bytes after tensor data are allowed.
# This occurs with sub-byte dtypes (FP4, NF4) where alignment
# padding is added, or when the header is padded for page alignment.
trailing = size_bytes - (start + header_length)
logger = init_logger(__name__)
logger.debug(
"trailing %d bytes after tensor data in %s (alignment padding)",
trailing,
src,
)

@classmethod
def from_buffer(
Expand Down Expand Up @@ -174,7 +263,12 @@ def from_fd(

@classmethod
def from_file(self, filename: str, framework: FrameworkOpBase):
    """Open ``filename`` read-only and build the metadata object from it.

    The file is opened exactly once and the descriptor is always closed
    before returning, including when ``from_fd`` raises.

    Args:
        filename: Path of the file to read.
        framework: Framework object forwarded to ``from_fd``.

    Returns:
        Whatever ``from_fd`` returns for the opened file.

    Raises:
        OSError: If the file cannot be opened.
    """
    flags = os.O_RDONLY
    # On Windows, O_RDONLY defaults to text mode which translates \r\n -> \n,
    # corrupting binary data and causing size mismatches on large files.
    if sys.platform == "win32" and hasattr(os, "O_BINARY"):
        flags |= os.O_BINARY
    fd = os.open(filename, flags, 0o644)
    try:
        return self.from_fd(fd, filename, framework=framework, keep_orig_dict=False)
    finally:
        # Close on every path so a parse failure does not leak the descriptor.
        os.close(fd)
Expand Down Expand Up @@ -334,4 +428,4 @@ def __getitem__(self, _val) -> "TensorFrame":
offsets.append(self.offsets[rdim])
strides.append(self.strides[rdim])
shape.append(self.shape[rdim])
return TensorFrame(self.dtype, shape, self.data_offsets, strides, offsets, True)
return TensorFrame(self.dtype, shape, self.data_offsets, strides, offsets, True)
1 change: 1 addition & 0 deletions fastsafetensors/copier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base import CopierInterface
from .gds import GdsFileCopier
from .nogds import NoGdsFileCopier
from .dstorage import DStorageFileCopier
from .registry import (
CopierConstructFunc,
CopierType,
Expand Down
151 changes: 151 additions & 0 deletions fastsafetensors/copier/dstorage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
import os
import sys
from typing import Dict

from .. import cpp as fstcpp
from ..common import SafeTensorsMetadata, init_logger, is_gpu_found, resolve_cudart_lib_name
from ..frameworks import FrameworkOpBase, TensorBase
from ..st_types import Device, DeviceType, DType
from .base import CopierInterface
from .registry import CopierConstructFunc, register_copier_constructor

logger = init_logger(__name__)

_inited_ds = False

def load_dstorage_dlls() -> None:
    """Download and install DirectStorage DLLs if not already present.

    The DLLs are cached in ``~/.cache/fastsafetensors``.  When any of them is
    missing, the DirectStorage NuGet package is downloaded (URL overridable
    through FASTSAFETENSORS_DSTORAGE_NUPKG_URL), unpacked into a temporary
    directory under the cache, and the DLLs for the running interpreter's
    architecture are copied into the cache.  Every cached DLL is then
    preloaded with ``ctypes.WinDLL``.

    Download, install and preload failures are logged as warnings rather
    than raised.  On non-Windows platforms the function does nothing, since
    ``ctypes.WinDLL`` only exists on Windows.
    """
    if sys.platform != "win32":
        return

    import ctypes
    import io
    import shutil
    import zipfile
    from pathlib import Path
    from urllib.request import Request, urlopen

    cache_dir = Path.home() / ".cache" / "fastsafetensors"
    cache_dir.mkdir(parents=True, exist_ok=True)
    dlls = ["dstoragecore.dll", "dstorage.dll"]
    arch = "x64" if sys.maxsize > 2 ** 32 else "x86"

    # Require every DLL, not just dstorage.dll: a partially populated cache
    # (e.g. after an interrupted install) must trigger a fresh download.
    if not all((cache_dir / dll_name).is_file() for dll_name in dlls):
        logger.info("Downloading fastsafetensors DirectStorage DLL's")

        nupkg_url = os.environ.get("FASTSAFETENSORS_DSTORAGE_NUPKG_URL") or (
            "https://globalcdn.nuget.org/packages/"
            "microsoft.direct3d.directstorage.1.3.0.nupkg"
            "?packageVersion=1.3.0"
        )
        extract_dir = cache_dir / "directstorage"
        dll_src_dir = extract_dir / "native" / "bin" / arch

        try:
            # NOTE(review): the package contents are loaded as native code
            # without any checksum/signature verification; consider pinning
            # a digest for the default URL.
            req = Request(nupkg_url, headers={"User-Agent": "fastsafetensors"})
            with urlopen(req, timeout=60) as resp:
                nupkg_data = resp.read()

            # A .nupkg is a zip archive.
            with zipfile.ZipFile(io.BytesIO(nupkg_data)) as zf:
                zf.extractall(extract_dir)

            for dll_name in dlls:
                src = dll_src_dir / dll_name
                dst = cache_dir / dll_name
                if not src.is_file():
                    raise FileNotFoundError(
                        f"Expected {dll_name} at {src} but not found in NuGet package"
                    )
                shutil.copy2(src, dst)
        except (OSError, zipfile.BadZipFile) as e:
            # URLError and FileNotFoundError are OSError subclasses, so this
            # covers network, filesystem and missing-member failures.
            logger.warning("Failed to download/install DirectStorage DLLs: %s", e)
        finally:
            if extract_dir.is_dir():
                shutil.rmtree(extract_dir, ignore_errors=True)

    # Preload in list order (dstoragecore.dll before dstorage.dll).
    for dll_name in dlls:
        dll_path = cache_dir / dll_name
        if dll_path.is_file():
            try:
                ctypes.WinDLL(str(dll_path.absolute()))
            except OSError as e:
                logger.warning("Failed to preload %s: %s", dll_path, e)


def init_dstorage(device_id: int = 0) -> None:
    """Initialize the DirectStorage backend once per process.

    Later calls return immediately once a previous call has succeeded, so
    ``device_id`` is only used by the first successful initialization.

    Args:
        device_id: Device index passed to ``fstcpp.init_dstorage``.

    Raises:
        RuntimeError: If no GPU runtime is found, the CUDA runtime DLL name
            cannot be resolved, or ``fstcpp.init_dstorage`` does not
            report "ok".
    """
    global _inited_ds
    if _inited_ds:
        return

    from .nogds import load_library_func

    # Preload the DirectStorage DLLs before loading the runtime libraries.
    load_dstorage_dlls()
    load_library_func()
    if not is_gpu_found():
        raise RuntimeError("CUDA runtime not found")

    cudart_name = resolve_cudart_lib_name()
    if not cudart_name:
        raise RuntimeError("Could not find CUDA runtime DLL")

    # NOTE(review): the meaning of the literal second argument (0) is not
    # visible from this file -- confirm against the C++ binding.
    init_status = fstcpp.init_dstorage(device_id, 0, cudart_name)
    if init_status != "ok":
        raise RuntimeError(f"init_dstorage failed: {init_status}")
    _inited_ds = True


class DStorageFileCopier(CopierInterface):
    """Copier that reads files via DirectStorage with double-buffered staging
    into a standard CUDA (gds_device_buffer) destination.

    ``submit_io`` opens the source file and reads its tensor payload into a
    freshly allocated device buffer; ``wait_io`` closes the file and builds
    the tensors from that buffer.  The file handle is released on every
    failure path as well.
    """

    def __init__(
        self,
        metadata: SafeTensorsMetadata,
        device: Device,
        stream_reader: fstcpp.dstorage_stream_reader,
        framework: FrameworkOpBase,
    ):
        """Store the collaborators; no file is opened until ``submit_io``."""
        self.framework = framework
        self.metadata = metadata
        self.device = device
        self.stream_reader = stream_reader
        # Handle of the file currently being read; None while nothing is open.
        self.fh: fstcpp.dstorage_file_handle = None

    def _close_file(self) -> None:
        """Close the current file handle, if any, and forget it."""
        fh, self.fh = self.fh, None
        if fh:
            fh.close()

    def submit_io(self, use_buf_register: bool, max_copy_block_size: int):
        """Read the tensor payload of ``metadata.src`` into device memory.

        Args:
            use_buf_register: Unused by this copier.
            max_copy_block_size: Unused by this copier.

        Returns:
            The buffer allocated through ``framework.alloc_tensor_memory``.

        Raises:
            IOError: If the file cannot be opened via DirectStorage.
            RuntimeError: If ``read_to_cuda`` reports a negative result.
        """
        # Release a handle left over from an earlier submit_io that was never
        # followed by wait_io, so it is not silently overwritten.
        self._close_file()

        # Everything after the header is tensor data.
        total_bytes = self.metadata.size_bytes - self.metadata.header_length

        gbuf = self.framework.alloc_tensor_memory(total_bytes, self.device)

        # Open the file via DirectStorage
        fh = fstcpp.dstorage_file_handle()
        if not fh.open(self.metadata.src):
            raise IOError(f"Failed to open {self.metadata.src} via DirectStorage")
        self.fh = fh

        # DS reads NVMe to staging buf, then cudaMemcpy staging to final CUDA buffer.
        try:
            result = self.stream_reader.read_to_cuda(
                fh,
                gbuf.get_base_address(),
                self.metadata.header_length,
                total_bytes,
            )
        except Exception:
            self._close_file()
            raise
        if result < 0:
            # Fetch the error code before releasing the handle.
            hr = self.stream_reader.last_hresult()
            self._close_file()
            raise RuntimeError(
                f"dstorage_stream_reader.read_to_cuda failed: result={result}, "
                f"HRESULT=0x{hr & 0xFFFFFFFF:08X}"
            )

        return gbuf

    def wait_io(self, gbuf, dtype=DType.AUTO, noalign=False):
        """Close the source file and build tensors from ``gbuf``.

        Args:
            gbuf: Buffer previously returned by ``submit_io``.
            dtype: Forwarded to ``metadata.get_tensors``.
            noalign: Unused by this copier.

        Returns:
            The result of ``metadata.get_tensors``.
        """
        # Dropping the reference after closing makes repeated calls safe.
        self._close_file()
        return self.metadata.get_tensors(
            gbuf, self.device, self.metadata.header_length, dtype=dtype
        )


@register_copier_constructor("dstorage")
def new_dstorage_copier(device: Device, **kwargs) -> CopierConstructFunc:
    """Build the constructor function for DirectStorage file copiers.

    Initializes DirectStorage for ``device`` (index 0 when the device has
    no index) and creates one stream reader that is shared by every copier
    the returned function constructs.

    Raises:
        RuntimeError: If initialization fails or the stream reader is not
            ready.
    """
    device_id = 0 if device.index is None else device.index
    init_dstorage(device_id)

    reader = fstcpp.dstorage_stream_reader()
    if not reader.is_ready():
        raise RuntimeError("dstorage_stream_reader failed to initialize")

    def construct(
        metadata: SafeTensorsMetadata, device: Device, framework: FrameworkOpBase
    ) -> CopierInterface:
        # Each copier gets its own metadata/device but reuses the reader.
        return DStorageFileCopier(metadata, device, reader, framework)

    return construct
15 changes: 9 additions & 6 deletions fastsafetensors/copier/gds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
import platform
from typing import Dict, Optional

from .. import cpp as fstcpp
Expand Down Expand Up @@ -182,7 +183,7 @@ def new_gds_file_copier(
device_is_not_cpu = device.type != DeviceType.CPU
if device_is_not_cpu and not is_gpu_found():
raise Exception(
"[FAIL] GPU runtime library (libcudart.so or libamdhip64.so) does not exist"
"[FAIL] GPU runtime library not found (expected libcudart.so, libamdhip64.so, or cudart64_XX.dll)"
)
nogds = False
if device_is_not_cpu and not nogds:
Expand All @@ -192,10 +193,12 @@ def new_gds_file_copier(
if gds_supported < 0:
raise Exception(f"is_gds_supported({device.index}) failed")
if not fstcpp.is_cufile_found():
warnings.warn(
"libcufile.so does not exist but nogds is False. use nogds=True",
UserWarning,
)
# Windows does not have cuFile, so do not warn about it
if platform.system() != "Windows":
warnings.warn(
"libcufile.so does not exist but nogds is False. use nogds=True",
UserWarning,
)
nogds = True
elif gds_supported == 0:
warnings.warn(
Expand All @@ -222,4 +225,4 @@ def construct_copier(
) -> CopierInterface:
return GdsFileCopier(metadata, device, reader, framework)

return construct_copier
return construct_copier
Loading
Loading