diff --git a/examples/fix_alignment.py b/examples/fix_alignment.py index fe7897f..87e2895 100644 --- a/examples/fix_alignment.py +++ b/examples/fix_alignment.py @@ -14,7 +14,10 @@ def fix_sten_file(src_file: str, dst_file: str): pad_key = "p" pad_value = "P" - src_fd = os.open(src_file, os.O_RDONLY, 0o644) + src_flags = os.O_RDONLY + if sys.platform == "win32" and hasattr(os, "O_BINARY"): + src_flags |= os.O_BINARY + src_fd = os.open(src_file, src_flags, 0o644) if src_fd < 0: raise Exception(f"FAIL: open, src_file={src_file}") meta = SafeTensorsMetadata.from_fd( @@ -45,7 +48,10 @@ def fix_sten_file(src_file: str, dst_file: str): f"dst: filename={dst_file}, header_len={dst_header_len} (pad={head_pad}), size={dst_header_len + meta.size_bytes - meta.header_length}" ) - dst_fd = os.open(dst_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644) + dst_flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL + if sys.platform == "win32" and hasattr(os, "O_BINARY"): + dst_flags |= os.O_BINARY + dst_fd = os.open(dst_file, dst_flags, 0o644) if dst_fd < 0: raise Exception(f"FAIL: open, dst_fd={dst_fd}") os_write_full( @@ -121,4 +127,4 @@ def os_sendfile_full(src_fd: int, dst_fd: int, offset: int, length: int): need_copy = fix_sten_file(f"{dir}/{filename}", f"{outdir}/{filename}") if need_copy: print(f"copy: {dir}/{filename} --> {outdir}/{filename}") - shutil.copyfile(f"{dir}/{filename}", f"{outdir}/{filename}") + shutil.copyfile(f"{dir}/{filename}", f"{outdir}/{filename}") \ No newline at end of file diff --git a/fastsafetensors/common.py b/fastsafetensors/common.py index fb994e8..be10a46 100644 --- a/fastsafetensors/common.py +++ b/fastsafetensors/common.py @@ -50,6 +50,84 @@ def get_device_numa_node(device: Optional[int]) -> Optional[int]: with open(syspath) as f: return int(f.read().strip()) +def resolve_cudart_lib_name() -> str: + """Resolve the CUDA runtime library name for the current platform. + + Returns: + Library name string, or "" to use the compiled-in default. + """ + if sys.platform != "win32": + return "" # Non Windows platforms uses version-agnostic, return empty (default to cuda_compat.h GPU_RUNTIME_LIB) + + # Allow explicit override via environment variable + override = os.environ.get("FASTSAFETENSORS_CUDART_LIB", "") + if override: + return override + + import glob + + def _find_cudart_in_dir(d: str) -> str: + """Scan a directory for cudart64_*.dll files, return the best match.""" + if not os.path.isdir(d): + return "" + matches = glob.glob(os.path.join(d, "cudart64_*.dll")) + if matches: + matches.sort(reverse=True) + return os.path.basename(matches[0]) + return "" + + def _detect_from_nvcc(cuda_home: str) -> str: + """Try to detect the CUDA major version from nvcc -V output.""" + nvcc = os.path.join(cuda_home, "bin", "nvcc.exe") + if not os.path.isfile(nvcc): + return "" + try: + import subprocess + output = subprocess.check_output( + [nvcc, "-V"], universal_newlines=True, stderr=subprocess.STDOUT + ) + tokens = output.split() + release_idx = tokens.index("release") + 1 + version_str = tokens[release_idx].rstrip(",") + cuda_major = version_str.split(".")[0] + return f"cudart64_{cuda_major}.dll" + except Exception: + return "" + + # Try to detect from CUDA_HOME / CUDA_PATH + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home: + result = _detect_from_nvcc(cuda_home) + if result: + return result + result = _find_cudart_in_dir(os.path.join(cuda_home, "bin")) + if result: + return result + + # Scan directories on PATH for cudart64_*.dll + path_dirs = os.environ.get("PATH", "").split(os.pathsep) + for d in path_dirs: + result = _find_cudart_in_dir(d) + if result: + return result + + # Scan common NVIDIA install locations + program_files = os.environ.get("ProgramFiles", r"C:\Program Files") + nvidia_base = os.path.join(program_files, "NVIDIA GPU Computing Toolkit", "CUDA") + if os.path.isdir(nvidia_base): + # List version directories (e.g. v12.6, v11.8), newest first + try: + versions = sorted(os.listdir(nvidia_base), reverse=True) + except OSError: + versions = [] + for ver_dir in versions: + result = _find_cudart_in_dir( + os.path.join(nvidia_base, ver_dir, "bin") + ) + if result: + return result + + return "" # fall back to compiled-in default # keep this for compatibility class SingleGroup: @@ -106,10 +184,21 @@ def __init__( f"validate(tensor {k}): TensorInvalidInfo, e-s={e-s}, nbytes={nbytes}, src={src}" ) self.size_bytes = size_bytes - if start + header_length != size_bytes: + if start + header_length > size_bytes: raise Exception( f"MetadataIncompleteBuffer, src={src}, start={start}, header_length={header_length}, size_bytes={size_bytes}" ) + if start + header_length < size_bytes: + # Trailing padding bytes after tensor data are allowed. + # This occurs with sub-byte dtypes (FP4, NF4) where alignment + # padding is added, or when the header is padded for page alignment. + trailing = size_bytes - (start + header_length) + logger = init_logger(__name__) + logger.debug( + "trailing %d bytes after tensor data in %s (alignment padding)", + trailing, + src, + ) @classmethod def from_buffer( @@ -174,7 +263,12 @@ def from_fd( @classmethod def from_file(self, filename: str, framework: FrameworkOpBase): - fd = os.open(filename, os.O_RDONLY, 0o644) + flags = os.O_RDONLY + # On Windows, O_RDONLY defaults to text mode which translates \r\n -> \n, + # corrupting binary data and causing size mismatches on large files. + if sys.platform == "win32" and hasattr(os, "O_BINARY"): + flags |= os.O_BINARY + fd = os.open(filename, flags, 0o644) ret = self.from_fd(fd, filename, framework=framework, keep_orig_dict=False) os.close(fd) return ret @@ -334,4 +428,4 @@ def __getitem__(self, _val) -> "TensorFrame": offsets.append(self.offsets[rdim]) strides.append(self.strides[rdim]) shape.append(self.shape[rdim]) - return TensorFrame(self.dtype, shape, self.data_offsets, strides, offsets, True) + return TensorFrame(self.dtype, shape, self.data_offsets, strides, offsets, True) \ No newline at end of file diff --git a/fastsafetensors/copier/__init__.py b/fastsafetensors/copier/__init__.py index ddfa359..0df3362 100644 --- a/fastsafetensors/copier/__init__.py +++ b/fastsafetensors/copier/__init__.py @@ -3,6 +3,7 @@ from .base import CopierInterface from .gds import GdsFileCopier from .nogds import NoGdsFileCopier +from .dstorage import DStorageFileCopier from .registry import ( CopierConstructFunc, CopierType, diff --git a/fastsafetensors/copier/dstorage.py b/fastsafetensors/copier/dstorage.py new file mode 100644 index 0000000..e209ab4 --- /dev/null +++ b/fastsafetensors/copier/dstorage.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import sys +from typing import Dict + +from .. import cpp as fstcpp +from ..common import SafeTensorsMetadata, init_logger, is_gpu_found, resolve_cudart_lib_name +from ..frameworks import FrameworkOpBase, TensorBase +from ..st_types import Device, DeviceType, DType +from .base import CopierInterface +from .registry import CopierConstructFunc, register_copier_constructor + +logger = init_logger(__name__) + +_inited_ds = False + +def load_dstorage_dlls() -> None: + """Download and install DirectStorage DLLs if not already present.""" + import ctypes + import io + from pathlib import Path + import shutil + import zipfile + from urllib.request import urlopen, Request + from urllib.error import URLError + + cache_dir = Path.home() / ".cache" / "fastsafetensors" + cache_dir.mkdir(parents=True, exist_ok=True) + dstorage_dll = cache_dir / "dstorage.dll" + dlls = ["dstoragecore.dll", "dstorage.dll"] + arch = "x64" if sys.maxsize > 2 ** 32 else "x86" + + if not dstorage_dll.exists(): + logger.info("Downloading fastsafetensors DirectStorage DLL's") + + nupkg_url = os.environ.get("FASTSAFETENSORS_DSTORAGE_NUPKG_URL") or ( + "https://globalcdn.nuget.org/packages/" + "microsoft.direct3d.directstorage.1.3.0.nupkg" + "?packageVersion=1.3.0" + ) + extract_dir = cache_dir / "directstorage" + dll_src_dir = extract_dir / "native" / "bin" / arch + + try: + req = Request(nupkg_url, headers={"User-Agent": "fastsafetensors"}) + with urlopen(req, timeout=60) as resp: + nupkg_data = resp.read() + + with zipfile.ZipFile(io.BytesIO(nupkg_data)) as zf: + zf.extractall(extract_dir) + + for dll_name in dlls: + src = dll_src_dir / dll_name + dst = cache_dir / dll_name + if src.is_file(): + shutil.copy2(src, dst) + else: + raise FileNotFoundError(f"Expected {dll_name} at {src} but not found in NuGet package") + except (URLError, OSError, zipfile.BadZipFile, FileNotFoundError) as e: + logger.warning(f"Failed to download/install DirectStorage DLLs: {e}") + finally: + if extract_dir.is_dir(): + shutil.rmtree(extract_dir, ignore_errors=True) + + for dll_name in dlls: + dll_path = cache_dir / dll_name + if dll_path.is_file(): + try: + ctypes.WinDLL(str(dll_path.absolute())) + except OSError as e: + logger.warning(f"Failed to preload {dll_path}: {e}") + + +def init_dstorage(device_id: int = 0) -> None: + global _inited_ds + if not _inited_ds: + from .nogds import load_library_func + load_dstorage_dlls() + load_library_func() + if not is_gpu_found(): + raise RuntimeError("CUDA runtime not found") + cudart_dll = resolve_cudart_lib_name() + if not cudart_dll: + raise RuntimeError("Could not find CUDA runtime DLL") + status = fstcpp.init_dstorage(device_id, 0, cudart_dll) + if status != "ok": + raise RuntimeError(f"init_dstorage failed: {status}") + _inited_ds = True + + +class DStorageFileCopier(CopierInterface): + """Copier that reads files via DirectStorage with double-buffered staging + into a standard CUDA (gds_device_buffer) destination.""" + + def __init__(self, + metadata: SafeTensorsMetadata, + device: Device, + stream_reader: fstcpp.dstorage_stream_reader, + framework: FrameworkOpBase): + self.framework = framework + self.metadata = metadata + self.device = device + self.stream_reader = stream_reader + self.fh: fstcpp.dstorage_file_handle = None + + def submit_io(self, use_buf_register: bool, max_copy_block_size: int): + total_bytes = self.metadata.size_bytes - self.metadata.header_length + + gbuf = self.framework.alloc_tensor_memory(total_bytes, self.device) + + # Open the file via DirectStorage + self.fh = fstcpp.dstorage_file_handle() + if not self.fh.open(self.metadata.src): + raise IOError(f"Failed to open {self.metadata.src} via DirectStorage") + + # DS reads NVMe to staging buf, then cudaMemcpy staging to final CUDA buffer. + result = self.stream_reader.read_to_cuda( + self.fh, + gbuf.get_base_address(), + self.metadata.header_length, + total_bytes, + ) + if result < 0: + hr = self.stream_reader.last_hresult() + raise RuntimeError( + f"dstorage_stream_reader.read_to_cuda failed: result={result}, " + f"HRESULT=0x{hr & 0xFFFFFFFF:08X}" + ) + + return gbuf + + def wait_io(self, gbuf, dtype=DType.AUTO, noalign=False): + if self.fh: + self.fh.close() + return self.metadata.get_tensors( + gbuf, self.device, self.metadata.header_length, dtype=dtype + ) + + +@register_copier_constructor("dstorage") +def new_dstorage_copier(device: Device, **kwargs) -> CopierConstructFunc: + """Factory for DirectStorage file copier.""" + init_dstorage(device.index if device.index is not None else 0) + stream_reader = fstcpp.dstorage_stream_reader() + if not stream_reader.is_ready(): + raise RuntimeError("dstorage_stream_reader failed to initialize") + + def construct(metadata: SafeTensorsMetadata, device: Device, framework: FrameworkOpBase) -> CopierInterface: + return DStorageFileCopier(metadata, device, stream_reader, framework) + + return construct \ No newline at end of file diff --git a/fastsafetensors/copier/gds.py b/fastsafetensors/copier/gds.py index fbb9e63..cfd74e0 100644 --- a/fastsafetensors/copier/gds.py +++ b/fastsafetensors/copier/gds.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import warnings +import platform from typing import Dict, Optional from .. import cpp as fstcpp @@ -182,7 +183,7 @@ def new_gds_file_copier( device_is_not_cpu = device.type != DeviceType.CPU if device_is_not_cpu and not is_gpu_found(): raise Exception( - "[FAIL] GPU runtime library (libcudart.so or libamdhip64.so) does not exist" + "[FAIL] GPU runtime library not found (expected libcudart.so, libamdhip64.so, or cudart64_XX.dll)" ) nogds = False if device_is_not_cpu and not nogds: @@ -192,10 +193,12 @@ def new_gds_file_copier( if gds_supported < 0: raise Exception(f"is_gds_supported({device.index}) failed") if not fstcpp.is_cufile_found(): - warnings.warn( - "libcufile.so does not exist but nogds is False. use nogds=True", - UserWarning, - ) + # Windows does not have cuFile, do not warning about it + if platform.system() != "Windows": + warnings.warn( + "libcufile.so does not exist but nogds is False. use nogds=True", + UserWarning, + ) nogds = True elif gds_supported == 0: warnings.warn( @@ -222,4 +225,4 @@ def construct_copier( ) -> CopierInterface: return GdsFileCopier(metadata, device, reader, framework) - return construct_copier + return construct_copier \ No newline at end of file diff --git a/fastsafetensors/copier/nogds.py b/fastsafetensors/copier/nogds.py index 882771a..b9ed6ae 100644 --- a/fastsafetensors/copier/nogds.py +++ b/fastsafetensors/copier/nogds.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import os +import sys from typing import Dict, List from .. import cpp as fstcpp -from ..common import SafeTensorsMetadata, is_gpu_found +from ..common import SafeTensorsMetadata, is_gpu_found, resolve_cudart_lib_name from ..frameworks import FrameworkOpBase, TensorBase from ..st_types import Device, DeviceType, DType from .base import CopierInterface @@ -22,7 +23,12 @@ def __init__( self.framework = framework self.metadata = metadata self.reader = reader - self.fd = os.open(metadata.src, os.O_RDONLY, 0o644) + flags = os.O_RDONLY + # On Windows, O_RDONLY defaults to text mode which translates \r\n + # and stops at 0x1A (Ctrl+Z), corrupting binary tensor data. + if sys.platform == "win32" and hasattr(os, "O_BINARY"): + flags |= os.O_BINARY + self.fd = os.open(metadata.src, flags, 0o644) if self.fd < 0: raise Exception( f"NoGdsFileCopier.__init__: failed to open, file={metadata.src}" @@ -66,14 +72,13 @@ def wait_io( gbuf, self.device, self.metadata.header_length, dtype=dtype ) - _loaded_library = False - def load_library_func(): global _loaded_library if not _loaded_library: - fstcpp.load_library_functions() + cudart_lib = resolve_cudart_lib_name() + fstcpp.load_library_functions(cudart_lib) _loaded_library = True @@ -88,7 +93,7 @@ def new_nogds_file_copier( device_is_not_cpu = device.type != DeviceType.CPU if device_is_not_cpu and not is_gpu_found(): raise Exception( - "[FAIL] GPU runtime library not found (expected libcudart.so or libamdhip64.so)" + "[FAIL] GPU runtime library not found (expected libcudart.so, libamdhip64.so, or cudart64_XX.dll)" ) device_id = device.index if device.index is not None else 0 @@ -103,4 +108,4 @@ def construct_nogds_copier( ) -> CopierInterface: return NoGdsFileCopier(metadata, device, nogds_reader, framework) - return construct_nogds_copier + return construct_nogds_copier \ No newline at end of file diff --git a/fastsafetensors/cpp.pyi b/fastsafetensors/cpp.pyi index 10cab12..6dfd5d9 100644 --- a/fastsafetensors/cpp.pyi +++ b/fastsafetensors/cpp.pyi @@ -60,7 +60,7 @@ def cpu_malloc(length: int) -> int: ... def cpu_free(addr: int) -> None: ... def gpu_malloc(length: int) -> int: ... def gpu_free(addr: int) -> None: ... -def load_library_functions() -> None: ... +def load_library_functions(cudart_lib_name: str = "") -> None: ... def get_cpp_metrics() -> cpp_metrics: ... def set_gil_release(gil_release: bool) -> None: ... def get_gil_release() -> bool: ... diff --git a/fastsafetensors/cpp/dlfcn.cpp b/fastsafetensors/cpp/dlfcn.cpp new file mode 100644 index 0000000..5b30484 --- /dev/null +++ b/fastsafetensors/cpp/dlfcn.cpp @@ -0,0 +1,922 @@ +/* + * dlfcn-win32 + * Copyright (c) 2007 Ramiro Polla + * Copyright (c) 2015 Tiancheng "Timothy" Gu + * Copyright (c) 2019 Pali Rohár + * Copyright (c) 2020 Ralf Habacker + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifdef _DEBUG +#define _CRTDBG_MAP_ALLOC +#include /* malloc() and free() */ +#include +#else +#include +#endif +#include +#include + +/* Older versions do not have this type */ +#if _WIN32_WINNT < 0x0500 +typedef ULONG ULONG_PTR; +#endif + +/* Older SDK versions do not have these macros */ +#ifndef GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS +#define GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS 0x4 +#endif +#ifndef GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT +#define GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT 0x2 +#endif +#ifndef IMAGE_NT_OPTIONAL_HDR_MAGIC +#ifdef _WIN64 +#define IMAGE_NT_OPTIONAL_HDR_MAGIC 0x20b +#else +#define IMAGE_NT_OPTIONAL_HDR_MAGIC 0x10b +#endif +#endif +#ifndef IMAGE_DIRECTORY_ENTRY_IAT +#define IMAGE_DIRECTORY_ENTRY_IAT 12 +#endif +#ifndef LOAD_WITH_ALTERED_SEARCH_PATH +#define LOAD_WITH_ALTERED_SEARCH_PATH 0x8 +#endif + +#ifdef _MSC_VER +#if _MSC_VER >= 1000 +/* https://docs.microsoft.com/en-us/cpp/intrinsics/returnaddress + * When compiling in C++ mode, it is required to have C declaration for _ReturnAddress. + */ +extern "C" void *_ReturnAddress(void); +#pragma intrinsic( _ReturnAddress ) +#else +/* On older version read return address from the value on stack pointer + 4 of + * the caller. Caller stack pointer is stored in EBP register but only when + * the EBP register is not optimized out. Usage of _alloca() function prevent + * EBP register optimization. Read value of EBP + 4 via inline assembly. And + * because inline assembly does not have a return value, put it into naked + * function which does not have prologue and epilogue and preserve registers. + * When compiling in C++ mode, it is required to have C declaration for _alloca. + */ +extern "C" void *__cdecl _alloca(std::size_t); +__declspec( naked ) static void *_ReturnAddress( void ) { __asm mov eax, [ebp+4] __asm ret } +#define _ReturnAddress( ) ( _alloca(1), _ReturnAddress( ) ) +#endif +#else +/* https://gcc.gnu.org/onlinedocs/gcc/Return-Address.html */ +#ifndef _ReturnAddress +#define _ReturnAddress( ) ( __builtin_extract_return_addr( __builtin_return_address( 0 ) ) ) +#endif +#endif + +#ifdef DLFCN_WIN32_SHARED +#define DLFCN_WIN32_EXPORTS +#endif +#include "dlfcn.h" + +#if defined( _MSC_VER ) && _MSC_VER >= 1300 +/* https://docs.microsoft.com/en-us/cpp/cpp/noinline */ +#define DLFCN_NOINLINE __declspec( noinline ) +#elif defined( __GNUC__ ) && ( ( __GNUC__ > 3 ) || ( __GNUC__ == 3 && __GNUC_MINOR__ >= 1 ) ) +/* https://gcc.gnu.org/onlinedocs/gcc/Common-Function-Attributes.html */ +#define DLFCN_NOINLINE __attribute__(( noinline )) +#else +#define DLFCN_NOINLINE +#endif + +/* All internal helpers go in an anonymous namespace, the C++ idiom for + * file-local linkage (replacing C's `static` on free functions/objects). + */ +namespace { + +void *MyAlloc( std::size_t size ) +{ +#ifdef _DEBUG + return std::malloc( size ); +#else + return LocalAlloc( LPTR, size ); +#endif +} + +void MyFree( void *ptr ) +{ +#ifdef _DEBUG + std::free( ptr ); +#else + LocalFree( static_cast( ptr ) ); +#endif +} + +/* Note: + * MSDN says these functions are not thread-safe. We make no efforts to have + * any kind of thread safety. + */ + +struct local_object { + HMODULE hModule; + local_object *previous; + local_object *next; +}; + +local_object first_object; + +/* These functions implement a double linked list for the local objects. */ +local_object *local_search( HMODULE hModule ) +{ + if( hModule == nullptr ) + return nullptr; + + for( local_object *pobject = &first_object; pobject; pobject = pobject->next ) + if( pobject->hModule == hModule ) + return pobject; + + return nullptr; +} + +BOOL local_add( HMODULE hModule ) +{ + if( hModule == nullptr ) + return TRUE; + + local_object *pobject = local_search( hModule ); + + /* Do not add object again if it's already on the list */ + if( pobject != nullptr ) + return TRUE; + + for( pobject = &first_object; pobject->next; pobject = pobject->next ) + ; + + local_object *nobject = static_cast( MyAlloc( sizeof( local_object ) ) ); + + if( !nobject ) + return FALSE; + + pobject->next = nobject; + nobject->next = nullptr; + nobject->previous = pobject; + nobject->hModule = hModule; + + return TRUE; +} + +void local_rem( HMODULE hModule ) +{ + if( hModule == nullptr ) + return; + + local_object *pobject = local_search( hModule ); + + if( pobject == nullptr ) + return; + + if( pobject->next ) + pobject->next->previous = pobject->previous; + if( pobject->previous ) + pobject->previous->next = pobject->next; + + MyFree( pobject ); +} + +/* POSIX says dlerror( ) doesn't have to be thread-safe, so we use one + * static buffer. + * MSDN says the buffer cannot be larger than 64K bytes, so we set it to + * the limit. + */ +char error_buffer[65535]; +BOOL error_occurred; + +void save_err_str( const char *str, DWORD dwMessageId ) +{ + /* Format error message to: + * "": + */ + std::size_t pos = 0; + error_buffer[pos++] = '"'; + for( std::size_t i = 0; i < sizeof( error_buffer ) - 5 && str[i] != '\0'; i++ ) + error_buffer[pos++] = str[i]; + error_buffer[pos++] = '"'; + error_buffer[pos++] = ':'; + error_buffer[pos++] = ' '; + + DWORD ret = FormatMessageA( FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, nullptr, dwMessageId, + MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), + error_buffer + pos, static_cast( sizeof( error_buffer ) - pos ), nullptr ); + pos += ret; + + /* When FormatMessageA() fails it returns zero and does not touch buffer + * so add trailing null byte */ + if( ret == 0 ) + error_buffer[pos] = '\0'; + + if( pos > 1 ) + { + /* POSIX says the string must not have trailing */ + if( error_buffer[pos-2] == '\r' && error_buffer[pos-1] == '\n' ) + error_buffer[pos-2] = '\0'; + } + + error_occurred = TRUE; +} + +void save_err_ptr_str( const void *ptr, DWORD dwMessageId ) +{ + char ptr_buf[2 + 2 * sizeof( ptr ) + 1]; + + ptr_buf[0] = '0'; + ptr_buf[1] = 'x'; + + for( std::size_t i = 0; i < 2 * sizeof( ptr ); i++ ) + { + char num = static_cast( ( ( reinterpret_cast( ptr ) ) >> ( 8 * sizeof( ptr ) - 4 * ( i + 1 ) ) ) & 0xF ); + ptr_buf[2 + i] = static_cast( num + ( ( num < 0xA ) ? '0' : ( 'A' - 0xA ) ) ); + } + + ptr_buf[2 + 2 * sizeof( ptr )] = 0; + + save_err_str( ptr_buf, dwMessageId ); +} + +UINT MySetErrorMode( UINT uMode ) +{ + using SetThreadErrorModeFn = BOOL (WINAPI *)(DWORD, DWORD *); + static SetThreadErrorModeFn SetThreadErrorModePtr = nullptr; + static BOOL failed = FALSE; + + if( !failed && SetThreadErrorModePtr == nullptr ) + { + HMODULE kernel32 = GetModuleHandleA( "Kernel32.dll" ); + if( kernel32 != nullptr ) + SetThreadErrorModePtr = reinterpret_cast( reinterpret_cast( GetProcAddress( kernel32, "SetThreadErrorMode" ) ) ); + if( SetThreadErrorModePtr == nullptr ) + failed = TRUE; + } + + if( !failed ) + { + DWORD oldMode; + if( !SetThreadErrorModePtr( uMode, &oldMode ) ) + return 0; + else + return oldMode; + } + else + { + return SetErrorMode( uMode ); + } +} + +HMODULE MyGetModuleHandleFromAddress( const void *addr ) +{ + using GetModuleHandleExAFn = BOOL (WINAPI *)(DWORD, LPCSTR, HMODULE *); + static GetModuleHandleExAFn GetModuleHandleExAPtr = nullptr; + static BOOL failed = FALSE; + HMODULE hModule; + + if( !failed && GetModuleHandleExAPtr == nullptr ) + { + HMODULE kernel32 = GetModuleHandleA( "Kernel32.dll" ); + if( kernel32 != nullptr ) + GetModuleHandleExAPtr = reinterpret_cast( reinterpret_cast( GetProcAddress( kernel32, "GetModuleHandleExA" ) ) ); + if( GetModuleHandleExAPtr == nullptr ) + failed = TRUE; + } + + if( !failed ) + { + /* If GetModuleHandleExA is available use it with GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS */ + if( !GetModuleHandleExAPtr( GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, static_cast( addr ), &hModule ) ) + return nullptr; + } + else + { + /* To get HMODULE from address use undocumented hack from https://stackoverflow.com/a/2396380 + * The HMODULE of a DLL is the same value as the module's base address. + */ + MEMORY_BASIC_INFORMATION info; + std::size_t sLen = VirtualQuery( addr, &info, sizeof( info ) ); + if( sLen != sizeof( info ) ) + return nullptr; + hModule = static_cast( info.AllocationBase ); + } + + return hModule; +} + +/* Load Psapi.dll at runtime, this avoids linking caveat */ +BOOL MyEnumProcessModules( HANDLE hProcess, HMODULE *lphModule, DWORD cb, LPDWORD lpcbNeeded ) +{ + using EnumProcessModulesFn = BOOL (WINAPI *)(HANDLE, HMODULE *, DWORD, LPDWORD); + static EnumProcessModulesFn EnumProcessModulesPtr = nullptr; + static BOOL failed = FALSE; + + if( failed ) + return FALSE; + + if( EnumProcessModulesPtr == nullptr ) + { + /* Windows 7 and newer versions have K32EnumProcessModules in Kernel32.dll which is always pre-loaded */ + HMODULE kernel32 = GetModuleHandleA( "Kernel32.dll" ); + if( kernel32 != nullptr ) + EnumProcessModulesPtr = reinterpret_cast( reinterpret_cast( GetProcAddress( kernel32, "K32EnumProcessModules" ) ) ); + + /* Windows Vista and older version have EnumProcessModules in Psapi.dll which needs to be loaded */ + if( EnumProcessModulesPtr == nullptr ) + { + /* Do not let Windows display the critical-error-handler message box */ + UINT uMode = MySetErrorMode( SEM_FAILCRITICALERRORS ); + HMODULE psapi = LoadLibraryA( "Psapi.dll" ); + MySetErrorMode( uMode ); + if( psapi != nullptr ) + { + EnumProcessModulesPtr = reinterpret_cast( reinterpret_cast( GetProcAddress( psapi, "EnumProcessModules" ) ) ); + if( EnumProcessModulesPtr == nullptr ) + FreeLibrary( psapi ); + } + } + + if( EnumProcessModulesPtr == nullptr ) + { + failed = TRUE; + return FALSE; + } + } + + return EnumProcessModulesPtr( hProcess, lphModule, cb, lpcbNeeded ); +} + +} // anonymous namespace + +DLFCN_EXPORT +void *dlopen( const char *file, int mode ) +{ + HMODULE hModule; + + error_occurred = FALSE; + + /* Do not let Windows display the critical-error-handler message box */ + UINT uMode = MySetErrorMode( SEM_FAILCRITICALERRORS ); + + if( file == nullptr ) + { + /* POSIX says that if the value of file is NULL, a handle on a global + * symbol object must be provided. That object must be able to access + * all symbols from the original program file, and any objects loaded + * with the RTLD_GLOBAL flag. + * The return value from GetModuleHandleA( ) allows us to retrieve + * symbols only from the original program file. EnumProcessModules() is + * used to access symbols from other libraries. For objects loaded + * with the RTLD_LOCAL flag, we create our own list later on. They are + * excluded from EnumProcessModules() iteration. + */ + hModule = GetModuleHandleA( nullptr ); + + if( !hModule ) + save_err_str( "(null)", GetLastError( ) ); + } + else + { + char lpFileName[MAX_PATH]; + std::size_t len; + + for( len = 0; ; len++ ) + { + if( file[len] == '\0' ) + break; + } + + if( len >= sizeof( lpFileName ) ) + { + save_err_str( file, ERROR_FILENAME_EXCED_RANGE ); + hModule = nullptr; + } + else + { + /* MSDN says backslashes *must* be used instead of forward slashes. */ + for( std::size_t i = 0; i < len; i++ ) + { + if( file[i] == '/' ) + lpFileName[i] = '\\'; + else + lpFileName[i] = file[i]; + } + lpFileName[len] = '\0'; + + HANDLE hCurrentProc = GetCurrentProcess( ); + DWORD dwProcModsBefore, dwProcModsAfter; + + if( MyEnumProcessModules( hCurrentProc, nullptr, 0, &dwProcModsBefore ) == 0 ) + dwProcModsBefore = 0; + + /* POSIX says the search path is implementation-defined. + * LOAD_WITH_ALTERED_SEARCH_PATH is used to make it behave more closely + * to UNIX's search paths (start with system folders instead of current + * folder). + */ + hModule = LoadLibraryExA( lpFileName, nullptr, LOAD_WITH_ALTERED_SEARCH_PATH ); + + if( !hModule ) + { + save_err_str( lpFileName, GetLastError( ) ); + } + else + { + if( MyEnumProcessModules( hCurrentProc, nullptr, 0, &dwProcModsAfter ) == 0 ) + dwProcModsAfter = 0; + + /* If the object was loaded with RTLD_LOCAL, add it to list of local + * objects, so that its symbols cannot be retrieved even if the handle for + * the original program file is passed. POSIX says that if the same + * file is specified in multiple invocations, and any of them are + * RTLD_GLOBAL, even if any further invocations use RTLD_LOCAL, the + * symbols will remain global. If number of loaded modules was not + * changed after calling LoadLibraryEx(), it means that library was + * already loaded. + */ + if( (mode & RTLD_LOCAL) && dwProcModsBefore != dwProcModsAfter ) + { + if( !local_add( hModule ) ) + { + save_err_str( lpFileName, ERROR_NOT_ENOUGH_MEMORY ); + FreeLibrary( hModule ); + hModule = nullptr; + } + } + else if( !(mode & RTLD_LOCAL) && dwProcModsBefore == dwProcModsAfter ) + { + local_rem( hModule ); + } + } + } + } + + /* Return to previous state of the error-mode bit flags. */ + MySetErrorMode( uMode ); + + return static_cast( hModule ); +} + +DLFCN_EXPORT +int dlclose( void *handle ) +{ + HMODULE hModule = static_cast( handle ); + + error_occurred = FALSE; + + /* dlopen(NULL, ...) does not call LoadLibrary(), so do not call FreeLibrary(). */ + if( hModule == GetModuleHandleA( nullptr ) ) + return 0; + + BOOL ret = FreeLibrary( hModule ); + + /* If the object was loaded with RTLD_LOCAL, remove it from list of local + * objects. + */ + if( ret ) + local_rem( hModule ); + else + save_err_ptr_str( handle, GetLastError( ) ); + + /* dlclose's return value in inverted in relation to FreeLibrary's. */ + ret = !ret; + + return static_cast( ret ); +} + +DLFCN_NOINLINE /* Needed for _ReturnAddress() */ +DLFCN_EXPORT +void *dlsym( void *handle, const char *name ) +{ + error_occurred = FALSE; + + FARPROC symbol = nullptr; + HMODULE hCaller = nullptr; + HMODULE hModule = GetModuleHandleA( nullptr ); + DWORD dwMessageId = 0; + + if( handle == RTLD_DEFAULT ) + { + /* The symbol lookup happens in the normal global scope; that is, + * a search for a symbol using this handle would find the same + * definition as a direct use of this symbol in the program code. + * So use same lookup procedure as when filename is NULL. + */ + handle = hModule; + } + else if( handle == RTLD_NEXT ) + { + /* Specifies the next object after this one that defines name. + * This one refers to the object containing the invocation of dlsym(). + * The next object is the one found upon the application of a load + * order symbol resolution algorithm. To get caller function of dlsym() + * use _ReturnAddress() intrinsic. To get HMODULE of caller function + * use MyGetModuleHandleFromAddress() which calls either standard + * GetModuleHandleExA() function or hack via VirtualQuery(). + */ + hCaller = MyGetModuleHandleFromAddress( _ReturnAddress( ) ); + + if( hCaller == nullptr ) + { + dwMessageId = ERROR_INVALID_PARAMETER; + goto end; + } + } + + if( handle != RTLD_NEXT ) + { + symbol = GetProcAddress( static_cast( handle ), name ); + + if( symbol != nullptr ) + goto end; + } + + /* If the handle for the original program file is passed, also search + * in all globally loaded objects. + */ + + if( hModule == handle || handle == RTLD_NEXT ) + { + HANDLE hCurrentProc = GetCurrentProcess( ); + DWORD dwSize; + + /* GetModuleHandleA( NULL ) only returns the current program file. So + * if we want to get ALL loaded module including those in linked DLLs, + * we have to use EnumProcessModules( ). + */ + if( MyEnumProcessModules( hCurrentProc, nullptr, 0, &dwSize ) != 0 ) + { + HMODULE *modules = static_cast( MyAlloc( dwSize ) ); + if( modules ) + { + DWORD cbNeeded; + if( MyEnumProcessModules( hCurrentProc, modules, dwSize, &cbNeeded ) != 0 && dwSize == cbNeeded ) + { + for( std::size_t i = 0; i < dwSize / sizeof( HMODULE ); i++ ) + { + if( handle == RTLD_NEXT && hCaller ) + { + /* Next modules can be used for RTLD_NEXT */ + if( hCaller == modules[i] ) + hCaller = nullptr; + continue; + } + if( local_search( modules[i] ) ) + continue; + symbol = GetProcAddress( modules[i], name ); + if( symbol != nullptr ) + { + MyFree( modules ); + goto end; + } + } + + } + MyFree( modules ); + } + else + { + dwMessageId = ERROR_NOT_ENOUGH_MEMORY; + goto end; + } + } + } + +end: + if( symbol == nullptr ) + { + if( !dwMessageId ) + dwMessageId = ERROR_PROC_NOT_FOUND; + save_err_str( name, dwMessageId ); + } + + /* Preserve the original C trick: punning a FARPROC through a void* via + * pointer casting. reinterpret_cast between function and object pointers + * is not portable C++, but is well-defined on Windows and matches the + * original implementation's behavior. + */ + return *reinterpret_cast( &symbol ); +} + +DLFCN_EXPORT +char *dlerror( void ) +{ + /* If this is the second consecutive call to dlerror, return NULL */ + if( !error_occurred ) + return nullptr; + + /* POSIX says that invoking dlerror( ) a second time, immediately following + * a prior invocation, shall result in NULL being returned. + */ + error_occurred = FALSE; + + return error_buffer; +} + +/* See https://docs.microsoft.com/en-us/archive/msdn-magazine/2002/march/inside-windows-an-in-depth-look-into-the-win32-portable-executable-file-format-part-2 + * for details */ + +namespace { + +/* Get specific image section */ +BOOL get_image_section( HMODULE module, int index, void **ptr, DWORD *size ) +{ + IMAGE_DOS_HEADER *dosHeader = reinterpret_cast( module ); + + if( dosHeader->e_magic != IMAGE_DOS_SIGNATURE ) + return FALSE; + + IMAGE_NT_HEADERS *ntHeaders = reinterpret_cast( reinterpret_cast( dosHeader ) + dosHeader->e_lfanew ); + + if( ntHeaders->Signature != IMAGE_NT_SIGNATURE ) + return FALSE; + + IMAGE_OPTIONAL_HEADER *optionalHeader = &ntHeaders->OptionalHeader; + + if( optionalHeader->Magic != IMAGE_NT_OPTIONAL_HDR_MAGIC ) + return FALSE; + + if( index < 0 || index >= IMAGE_NUMBEROF_DIRECTORY_ENTRIES || static_cast( index ) >= optionalHeader->NumberOfRvaAndSizes ) + return FALSE; + + if( optionalHeader->DataDirectory[index].Size == 0 || optionalHeader->DataDirectory[index].VirtualAddress == 0 ) + return FALSE; + + if( size != nullptr ) + *size = optionalHeader->DataDirectory[index].Size; + + *ptr = reinterpret_cast( reinterpret_cast( module ) + optionalHeader->DataDirectory[index].VirtualAddress ); + + return TRUE; +} + +/* Return symbol name for a given address from export table */ +const char *get_export_symbol_name( HMODULE module, IMAGE_EXPORT_DIRECTORY *ied, const void *addr, void **func_address ) +{ + void *candidateAddr = nullptr; + int candidateIndex = -1; + BYTE *base = reinterpret_cast( module ); + DWORD *functionAddressesOffsets = reinterpret_cast( base + static_cast( ied->AddressOfFunctions ) ); + DWORD *functionNamesOffsets = reinterpret_cast( base + static_cast( ied->AddressOfNames ) ); + USHORT *functionNameOrdinalsIndexes = reinterpret_cast( base + static_cast( ied->AddressOfNameOrdinals ) ); + + for( DWORD i = 0; i < ied->NumberOfFunctions; i++ ) + { + if( static_cast( base + functionAddressesOffsets[i] ) > addr || candidateAddr >= static_cast( base + functionAddressesOffsets[i] ) ) + continue; + + candidateAddr = static_cast( base + functionAddressesOffsets[i] ); + candidateIndex = i; + } + + if( candidateIndex == -1 ) + return nullptr; + + *func_address = candidateAddr; + + for( DWORD i = 0; i < ied->NumberOfNames; i++ ) + { + if( functionNameOrdinalsIndexes[i] == candidateIndex ) + return reinterpret_cast( base + functionNamesOffsets[i] ); + } + + return nullptr; +} + +BOOL is_valid_address( const void *addr ) +{ + if( addr == nullptr ) + return FALSE; + + /* check valid pointer */ + MEMORY_BASIC_INFORMATION info; + std::size_t result = VirtualQuery( addr, &info, sizeof( info ) ); + + if( result != sizeof( info ) || info.AllocationBase == nullptr || info.State == MEM_FREE || info.State == MEM_RESERVE || info.Protect == 0 || info.Protect == PAGE_NOACCESS ) + return FALSE; + + return TRUE; +} + +#if defined(_M_ARM64) || defined(__aarch64__) +INT64 sign_extend( UINT64 value, UINT bits ) +{ + const UINT left = 64 - bits; + const INT64 m1 = -1; + const INT64 wide = static_cast( value << left ); + const INT64 sign = ( wide < 0 ) ? ( m1 << left ) : 0; + + return value | sign; +} +#endif + +/* Return state if address points to an import thunk + * + * On x86, an import thunk is setup with a 'jmp' instruction followed by an + * absolute address (32bit) or relative offset (64bit) pointing into + * the import address table (iat), which is partially maintained by + * the runtime linker. + * + * On ARM64, an import thunk is also a relative jump pointing into the + * import address table, implemented by the following three instructions: + * + * adrp x16, [page_offset] + * Calculates the page address (aligned to 4KB) the IAT is at, based + * on the value of x16, with page_offset. + * + * ldr x16, [x16, offset] + * Calculates the final IAT address, x16 <- x16 + offset. + * + * br x16 + * Jump to the address in x16. + * + * The register used here is hardcoded to be x16. + */ +BOOL is_import_thunk( const void *addr ) +{ +#if defined(_M_ARM64) || defined(__aarch64__) + ULONG opCode1 = *reinterpret_cast( static_cast( addr ) ); + ULONG opCode2 = *reinterpret_cast( static_cast( addr ) + 4 ); + ULONG opCode3 = *reinterpret_cast( static_cast( addr ) + 8 ); + + return ( opCode1 & 0x9f00001f ) == 0x90000010 /* adrp x16, [page_offset] */ + && ( opCode2 & 0xffe003ff ) == 0xf9400210 /* ldr x16, [x16, offset] */ + && opCode3 == 0xd61f0200 /* br x16 */ + ? TRUE : FALSE; +#elif defined(_M_AMD64) || defined(_M_IX86) || defined(__x86_64__) || defined(__i386__) + return *static_cast( addr ) == 0x25ff ? TRUE : FALSE; +#else + (void) addr; + return FALSE; +#endif +} + +/* Return adress from the import address table (iat), + * if the original address points to a thunk table entry. + */ +void *get_address_from_import_address_table( void *iat, DWORD iat_size, const void *addr ) +{ + const BYTE *thkp = static_cast( addr ); +#if defined(_M_ARM64) || defined(__aarch64__) + /* + * typical import thunk in ARM64: + * 0x7ff772ae78c0 <+25760>: adrp x16, 1 + * 0x7ff772ae78c4 <+25764>: ldr x16, [x16, #0xdc0] + * 0x7ff772ae78c8 <+25768>: br x16 + */ + ULONG opCode1 = *reinterpret_cast( thkp ); + ULONG opCode2 = *reinterpret_cast( thkp + 4 ); + + /* Extract the offset from adrp instruction */ + UINT64 pageLow2 = ( opCode1 >> 29 ) & 3; + UINT64 pageHigh19 = ( opCode1 >> 5 ) & ~( ~0ull << 19 ); + INT64 page = sign_extend( ( pageHigh19 << 2 ) | pageLow2, 21 ) << 12; + + /* Extract the offset from ldr instruction */ + UINT64 offset = ( ( opCode2 >> 10 ) & ~( ~0ull << 12 ) ) << 3; + + /* Calculate the final address */ + const BYTE *ptr = reinterpret_cast( ( reinterpret_cast( thkp ) & ~0xfffull ) + page + offset ); +#elif defined(_M_AMD64) || defined(_M_IX86) || defined(__x86_64__) || defined(__i386__) + /* Get offset from thunk table (after instruction 0xff 0x25) + * 4018c8 <_VirtualQuery>: ff 25 4a 8a 00 00 + */ + ULONG offset = *reinterpret_cast( thkp + 2 ); +#if defined(_M_AMD64) || defined(__x86_64__) + /* On 64 bit the offset is relative + * 4018c8: ff 25 4a 8a 00 00 jmpq *0x8a4a(%rip) # 40a318 <__imp_VirtualQuery> + * And can be also negative (MSVC in WDK) + * 100002f20: ff 25 3a e1 ff ff jmpq *-0x1ec6(%rip) # 0x100001060 + * So cast to signed LONG type + */ + const BYTE *ptr = thkp + 6 + static_cast( offset ); +#else + /* On 32 bit the offset is absolute + * 4019b4: ff 25 90 71 40 00 jmp *0x40719 + */ + const BYTE *ptr = reinterpret_cast( offset ); +#endif +#else + (void) thkp; + (void) iat; + (void) iat_size; + (void) addr; + return nullptr; +#endif + + if( !is_valid_address( ptr ) || ptr < static_cast( iat ) || ptr > static_cast( iat ) + iat_size ) + return nullptr; + + return *reinterpret_cast( ptr ); +} + +/* Holds module filename */ +char module_filename[2*MAX_PATH]; + +BOOL fill_info( const void *addr, Dl_info *info ) +{ + /* Get module of the specified address */ + HMODULE hModule = MyGetModuleHandleFromAddress( addr ); + + if( hModule == nullptr ) + return FALSE; + + DWORD dwSize = GetModuleFileNameA( hModule, module_filename, sizeof( module_filename ) ); + + if( dwSize == 0 || dwSize == sizeof( module_filename ) ) + return FALSE; + + info->dli_fname = module_filename; + info->dli_fbase = static_cast( hModule ); + + /* Find function name and function address in module's export table */ + IMAGE_EXPORT_DIRECTORY *ied; + void *funcAddress = nullptr; + if( get_image_section( hModule, IMAGE_DIRECTORY_ENTRY_EXPORT, reinterpret_cast( &ied ), nullptr ) ) + info->dli_sname = get_export_symbol_name( hModule, ied, addr, &funcAddress ); + else + info->dli_sname = nullptr; + + info->dli_saddr = info->dli_sname == nullptr ? nullptr : funcAddress != nullptr ? funcAddress : const_cast( addr ); + + return TRUE; +} + +} // anonymous namespace + +DLFCN_EXPORT +int dladdr( const void *addr, Dl_info *info ) +{ + if( info == nullptr ) + return 0; + + if( !is_valid_address( addr ) ) + return 0; + + if( is_import_thunk( addr ) ) + { + void *iat; + DWORD iatSize; + + /* Get module of the import thunk address */ + HMODULE hModule = MyGetModuleHandleFromAddress( addr ); + + if( hModule == nullptr ) + return 0; + + if( !get_image_section( hModule, IMAGE_DIRECTORY_ENTRY_IAT, &iat, &iatSize ) ) + { + /* Fallback for cases where the iat is not defined, + * for example i586-mingw32msvc-gcc */ + IMAGE_IMPORT_DESCRIPTOR *iid; + DWORD iidSize; + + if( !get_image_section( hModule, IMAGE_DIRECTORY_ENTRY_IMPORT, reinterpret_cast( &iid ), &iidSize ) ) + return 0; + + if( iid == nullptr || iid->Characteristics == 0 || iid->FirstThunk == 0 ) + return 0; + + iat = reinterpret_cast( reinterpret_cast( hModule ) + static_cast( iid->FirstThunk ) ); + /* We assume that in this case iid and iat's are in linear order */ + iatSize = iidSize - static_cast( static_cast( iat ) - reinterpret_cast( iid ) ); + } + + addr = get_address_from_import_address_table( iat, iatSize, addr ); + + if( !is_valid_address( addr ) ) + return 0; + } + + if( !fill_info( addr, info ) ) + return 0; + + return 1; +} + +#ifdef DLFCN_WIN32_SHARED +BOOL WINAPI DllMain( HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved ) +{ + (void) hinstDLL; + (void) fdwReason; + (void) lpvReserved; + return TRUE; +} +#endif \ No newline at end of file diff --git a/fastsafetensors/cpp/dlfcn.h b/fastsafetensors/cpp/dlfcn.h new file mode 100644 index 0000000..bf5c7d4 --- /dev/null +++ b/fastsafetensors/cpp/dlfcn.h @@ -0,0 +1,94 @@ +/* + * dlfcn-win32 + * Copyright (c) 2007 Ramiro Polla + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef DLFCN_H +#define DLFCN_H + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(DLFCN_WIN32_SHARED) +#if defined(DLFCN_WIN32_EXPORTS) +# define DLFCN_EXPORT __declspec(dllexport) +#else +# define DLFCN_EXPORT __declspec(dllimport) +#endif +#else +# define DLFCN_EXPORT +#endif + +/* Relocations are performed when the object is loaded. */ +#define RTLD_NOW 0 + +/* Relocations are performed at an implementation-defined time. + * Windows API does not support lazy symbol resolving (when first reference + * to a given symbol occurs). So RTLD_LAZY implementation is same as RTLD_NOW. + */ +#define RTLD_LAZY RTLD_NOW + +/* All symbols are available for relocation processing of other modules. */ +#define RTLD_GLOBAL (1 << 1) + +/* All symbols are not made available for relocation processing by other modules. */ +#define RTLD_LOCAL (1 << 2) + +/* These two were added in The Open Group Base Specifications Issue 6. + * Note: All other RTLD_* flags in any dlfcn.h are not standard compliant. + */ + +/* The symbol lookup happens in the normal global scope. */ +#define RTLD_DEFAULT ((void *)0) + +/* Specifies the next object after this one that defines name. */ +#define RTLD_NEXT ((void *)-1) + +/* Structure filled in by dladdr() */ +typedef struct dl_info +{ + const char *dli_fname; /* Filename of defining object (thread unsafe and reused on every call to dladdr) */ + void *dli_fbase; /* Load address of that object */ + const char *dli_sname; /* Name of nearest lower symbol */ + void *dli_saddr; /* Exact value of nearest symbol */ +} Dl_info; + +/* Open a symbol table handle. */ +DLFCN_EXPORT void *dlopen(const char *file, int mode); + +/* Close a symbol table handle. */ +DLFCN_EXPORT int dlclose(void *handle); + +/* Get the address of a symbol from a symbol table handle. */ +DLFCN_EXPORT void *dlsym(void *handle, const char *name); + +/* Get diagnostic information. */ +DLFCN_EXPORT char *dlerror(void); + +/* Translate address to symbolic information (no POSIX standard) */ +DLFCN_EXPORT int dladdr(const void *addr, Dl_info *info); + +#ifdef __cplusplus +} +#endif + +#endif /* DLFCN_H */ diff --git a/fastsafetensors/cpp/dstorage.h b/fastsafetensors/cpp/dstorage.h new file mode 100644 index 0000000..b599351 --- /dev/null +++ b/fastsafetensors/cpp/dstorage.h @@ -0,0 +1,1423 @@ +/*------------------------------------------------------------------------------------- + * + * Copyright (c) Microsoft Corporation + * Licensed under the MIT license + * + *-------------------------------------------------------------------------------------*/ + +#if !defined(__cplusplus) + #error C++11 required +#endif + +#pragma once + +#include +#include +#include "dstorageerr.h" + +#define DSTORAGE_SDK_VERSION 300 + +interface ID3D12Resource; +interface ID3D12Fence; +interface IDStorageStatusArray; + +/// +/// The priority of a DirectStorage queue. +/// +enum DSTORAGE_PRIORITY : INT8 { + DSTORAGE_PRIORITY_LOW = -1, + DSTORAGE_PRIORITY_NORMAL = 0, + DSTORAGE_PRIORITY_HIGH = 1, + DSTORAGE_PRIORITY_REALTIME = 2, + + /// + /// The following values can be used for iterating over all priority levels. + /// + DSTORAGE_PRIORITY_FIRST = DSTORAGE_PRIORITY_LOW, + DSTORAGE_PRIORITY_LAST = DSTORAGE_PRIORITY_REALTIME, + + DSTORAGE_PRIORITY_COUNT = 4 +}; + +/// +/// The minimum valid queue capacity. +/// +#define DSTORAGE_MIN_QUEUE_CAPACITY 0x80 + +/// +/// The maximum valid queue capacity. +/// +#define DSTORAGE_MAX_QUEUE_CAPACITY 0x2000 + +/// +/// The source type of a DirectStorage request. +/// +enum DSTORAGE_REQUEST_SOURCE_TYPE : UINT64 { + /// + /// The source of the DirectStorage request is a file. + /// + DSTORAGE_REQUEST_SOURCE_FILE = 0, + + /// + /// The source of the DirectStorage request is a block of memory. + /// + DSTORAGE_REQUEST_SOURCE_MEMORY = 1, +}; + +/// +/// The destination type of a DirectStorage request. +/// +enum DSTORAGE_REQUEST_DESTINATION_TYPE : UINT64 { + /// + /// The destination of the DirectStorage request is a block of memory. + /// + DSTORAGE_REQUEST_DESTINATION_MEMORY = 0, + + /// + /// The destination of the DirectStorage request is an ID3D12Resource + /// that is a buffer. + /// + DSTORAGE_REQUEST_DESTINATION_BUFFER = 1, + + /// + /// The destination of the DirectStorage request is an ID3D12Resource + /// that is a texture. + /// + DSTORAGE_REQUEST_DESTINATION_TEXTURE_REGION = 2, + + /// + /// The destination of the DirectStorage request is an ID3D12Resource + /// that is a texture that will receive all subresources in a + /// single request. + /// + DSTORAGE_REQUEST_DESTINATION_MULTIPLE_SUBRESOURCES = 3, + + /// + /// The destination of the DirectStorage request is an ID3D12Resource + /// that is tiled. + /// + DSTORAGE_REQUEST_DESTINATION_TILES = 4, + + /// + /// The destination of the DirectStorage request is an ID3D12Resource + /// that is a texture that will receive the number of subresources + /// specified in a single request. + /// + DSTORAGE_REQUEST_DESTINATION_MULTIPLE_SUBRESOURCES_RANGE = 5 +}; + +/// +/// The DSTORAGE_QUEUE_DESC structure contains the properties of a DirectStorage +/// queue for the queue's creation. +/// +struct DSTORAGE_QUEUE_DESC { + /// + /// The source type of requests that this DirectStorage queue can accept. + /// + DSTORAGE_REQUEST_SOURCE_TYPE SourceType; + + /// + /// The maximum number of requests that the queue can hold. + /// + UINT16 Capacity; + + /// + /// The priority of the requests in this queue. + /// + DSTORAGE_PRIORITY Priority; + + /// + /// Optional name of the queue. Used for debugging. + /// + _In_opt_z_ const CHAR *Name; + + /// + /// Optional device to use for writing to destination resources and + /// performing GPU decompression. The destination resource's device + /// must match this device. + /// + /// This member may be null. If you specify a null device, then the + /// destination type must be DSTORAGE_REQUEST_DESTINATION_MEMORY. + /// + ID3D12Device* Device; +}; + +/// +/// The DSTORAGE_QUEUE_INFO structure contains the properties and current state +/// of a DirectStorage queue. +/// +struct DSTORAGE_QUEUE_INFO { + /// + /// The DSTORAGE_QUEUE_DESC structure used for the queue's creation. + /// + DSTORAGE_QUEUE_DESC Desc; + + /// + /// The number of available empty slots. If a queue is empty, then the number + /// of empty slots equals capacity - 1. The reserved slot is used to + /// distinguish between empty and full cases. + /// + UINT16 EmptySlotCount; + + /// + /// The number of entries that would need to be enqueued in order to trigger + /// automatic submission. + /// + UINT16 RequestCountUntilAutoSubmit; +}; + +/// +/// The type of compression format used at the decompression stage. +/// Your application can implement custom decompressors, starting from +/// DSTORAGE_CUSTOM_COMPRESSION_0. +/// +enum DSTORAGE_COMPRESSION_FORMAT : UINT8 { + /// + /// The data is uncompressed. + /// + DSTORAGE_COMPRESSION_FORMAT_NONE = 0, + + /// + /// The data is compressed using the built-in GDEFLATE format. + /// + DSTORAGE_COMPRESSION_FORMAT_GDEFLATE = 1, + + /// + /// The data is stored in an application-defined custom format. The + /// application must use IDStorageCustomDecompressionQueue to implement + /// custom decompression. Additional custom compression formats can be + /// used, for example `(DSTORAGE_CUSTOM_COMPRESSION_0 + 1)`. + DSTORAGE_CUSTOM_COMPRESSION_0 = 0x80, +}; + +/// +/// Options for a DirectStorage request. +/// +struct DSTORAGE_REQUEST_OPTIONS { + /// + /// DSTORAGE_COMPRESSION_FORMAT indicating how the data is compressed. + /// + DSTORAGE_COMPRESSION_FORMAT CompressionFormat : 8; + + /// + /// Reserved fields. Must be 0. + /// + UINT8 Reserved1[7]; + + /// + /// DSTORAGE_REQUEST_SOURCE_TYPE enum value indicating whether the + /// source of the request is a file or a block of memory. + /// + DSTORAGE_REQUEST_SOURCE_TYPE SourceType : 1; + + /// + /// DSTORAGE_REQUEST_DESTINATION_TYPE enum value indicating the + /// destination of the request. Block of memory, resource. + /// + DSTORAGE_REQUEST_DESTINATION_TYPE DestinationType : 7; + + /// + /// Reserved fields. Must be 0. + /// + UINT64 Reserved : 48; +}; + +/// +/// Flags controlling DirectStorage debug layer. +/// +enum DSTORAGE_DEBUG { + /// + /// DirectStorage debug layer is disabled. + /// + DSTORAGE_DEBUG_NONE = 0x00, + + /// + /// Print error information to a debugger. + /// + DSTORAGE_DEBUG_SHOW_ERRORS = 0x01, + + /// + /// Trigger a debug break each time an error is detected. + /// + DSTORAGE_DEBUG_BREAK_ON_ERROR = 0x02, + + /// + /// Include object names in ETW events. + /// + DSTORAGE_DEBUG_RECORD_OBJECT_NAMES = 0x04 +}; +DEFINE_ENUM_FLAG_OPERATORS(DSTORAGE_DEBUG); + +/// +/// Represents a file to be accessed by DirectStorage. +/// +DECLARE_INTERFACE_IID_(IDStorageFile, IUnknown, "5de95e7b-955a-4868-a73c-243b29f4b8da") +{ + /// + /// Closes the file, regardless of the reference count on this object. + /// + /// After an IDStorageFile object is closed, it can no longer be used in + /// DirectStorage requests. This does not modify the reference count on this + /// object; Release() must be called as usual. + /// + virtual void STDMETHODCALLTYPE Close() = 0; + + /// + /// Retrieves file information for an opened file. + /// + /// Receives the file information. + /// Standard HRESULT error code. + virtual HRESULT STDMETHODCALLTYPE GetFileInformation(_Out_ BY_HANDLE_FILE_INFORMATION* info) = 0; +}; + +/// +/// Describes a source for a request with SourceType +/// DSTORAGE_REQUEST_SOURCE_FILE. +/// +struct DSTORAGE_SOURCE_FILE { + /// + /// The file to perform this read request from. + /// + IDStorageFile* Source; + + /// + /// The offset, in bytes, in the file to start the read request at. + /// + UINT64 Offset; + + /// + /// Number of bytes to read from the file. + /// + UINT32 Size; +}; + +/// +/// Describes the source for a request with SourceType +/// DSTORAGE_REQUEST_SOURCE_MEMORY. +/// +struct DSTORAGE_SOURCE_MEMORY { + /// + /// Address of the source buffer to be read from. + /// + void const* Source; + + /// + /// Number of bytes to read from the source buffer. + /// + UINT32 Size; +}; + +/// +/// Describes the destination for a request with DestinationType +/// DSTORAGE_REQUEST_DESTINATION_MEMORY. +/// +struct DSTORAGE_DESTINATION_MEMORY { + /// + /// Address of the buffer to receive the final result of this request. + /// + void* Buffer; + + /// + /// Number of bytes to write to the destination buffer. + /// + UINT32 Size; +}; + +/// +/// Describes the destination for a request with DestinationType +/// DSTORAGE_REQUEST_DESTINATION_BUFFER. +/// +struct DSTORAGE_DESTINATION_BUFFER { + /// + /// Address of the resource to receive the final result of this request. + /// + ID3D12Resource* Resource; + + /// + /// The offset, in bytes, in the buffer resource to write into. + /// + UINT64 Offset; + + /// + /// Number of bytes to write to the destination buffer. + /// + UINT32 Size; +}; + +/// +/// Describes the destination for a request with DestinationType +/// DSTORAGE_REQUEST_DESTINATION_TEXTURE_REGION. +/// +struct DSTORAGE_DESTINATION_TEXTURE_REGION { + /// + /// Address of the resource to receive the final result of this request. + /// + ID3D12Resource* Resource; + + /// + /// Describes the destination texture copy location. The subresource + /// referred to must be in the D3D12_RESOURCE_STATE_COMMON state. + /// + UINT SubresourceIndex; + + /// + /// Coordinates and size of the destination region to copy, in pixels. + /// + D3D12_BOX Region; +}; + +/// +/// Describes the destination for a request with DestinationType +/// DSTORAGE_REQUEST_DESTINATION_MULTIPLE_SUBRESOURCES. +/// +struct DSTORAGE_DESTINATION_MULTIPLE_SUBRESOURCES { + /// + /// Address of the resource to receive the final result of this request. The + /// source is expected to contain full data for all subresources, starting + /// from FirstSubresource. + /// + ID3D12Resource* Resource; + + /// + /// Describes the first subresource of the destination texture copy + /// location. The subresource referred to must be in the + /// D3D12_RESOURCE_STATE_COMMON state. + /// + UINT FirstSubresource; +}; + +/// +/// Describes the destination for a request with DestinationType +/// DSTORAGE_REQUEST_DESTINATION_MULTIPLE_SUBRESOURCES_RANGE. +/// +struct DSTORAGE_DESTINATION_MULTIPLE_SUBRESOURCES_RANGE +{ + /// + /// Address of the resource to receive the final result of this request. The + /// source is expected to contain full data for all subresources, starting + /// from FirstSubresource. + /// + ID3D12Resource* Resource; + + /// + /// Describes the first subresource of the destination texture copy + /// location. The subresource referred to must be in the + /// D3D12_RESOURCE_STATE_COMMON state. + /// + UINT FirstSubresource; + + /// + /// Describes the number of subresources to copy to the destination + /// resource starting from the FirstSubresource specified. + /// + UINT NumSubresources; +}; + +/// +/// Describes the destination for a request with DestinationType +/// DSTORAGE_REQUEST_DESTINATION_TILES. +/// +struct DSTORAGE_DESTINATION_TILES { + /// + /// Address of the resource to receive the final result of this request. The + /// source buffer is expected to contain data arranged as if it were the + /// source to a CopyTiles call with these parameters. + /// + ID3D12Resource* Resource; + + /// + /// The starting coordinates of the tiled region. + /// + D3D12_TILED_RESOURCE_COORDINATE TiledRegionStartCoordinate; + + /// + /// The size of the tiled region. + /// + D3D12_TILE_REGION_SIZE TileRegionSize; +}; + +/// +/// Describes the source specified for a DirectStorage request. For a request, +/// the value of `request.Options.SourceType` determines which of these union +/// fields is active. +/// +union DSTORAGE_SOURCE { + DSTORAGE_SOURCE_MEMORY Memory; + DSTORAGE_SOURCE_FILE File; +}; + +/// +/// Describes the destination for a DirectStorage request. For a request, the +/// value of `request.Options.DestinationType` determines which of these union +/// fields is active. +/// +union DSTORAGE_DESTINATION { + DSTORAGE_DESTINATION_MEMORY Memory; + DSTORAGE_DESTINATION_BUFFER Buffer; + DSTORAGE_DESTINATION_TEXTURE_REGION Texture; + DSTORAGE_DESTINATION_MULTIPLE_SUBRESOURCES MultipleSubresources; + DSTORAGE_DESTINATION_TILES Tiles; + DSTORAGE_DESTINATION_MULTIPLE_SUBRESOURCES_RANGE MultipleSubresourcesRange; +}; + +/// +/// Represents a DirectStorage request. +/// +struct DSTORAGE_REQUEST { + /// + /// Combination of decompression and other options for this request. + /// + DSTORAGE_REQUEST_OPTIONS Options; + + /// + /// The source for this request. + /// + DSTORAGE_SOURCE Source; + + /// + /// The destination for this request. + /// + DSTORAGE_DESTINATION Destination; + + /// + /// The uncompressed size in bytes for the destination for this request. + /// If the request is not compressed, then this can be left as 0. + /// + /// For compressed data, if the destination is memory, then the uncompressed size must + /// exactly equal the destination size. For other destination types, the uncompressed + /// size may be greater than the destination size. + /// + /// If the destination is to memory or buffer, then the destination size should + /// be specified in the corresponding struct (for example, DSTORAGE_DESTINATION_MEMORY). + /// For textures, it's the value of pTotalBytes returned by GetCopyableFootprints. + /// For tiles, it's 64k * number of tiles. + /// + UINT32 UncompressedSize; + + /// + /// An arbitrary UINT64 number used for cancellation matching. + /// + UINT64 CancellationTag; + + /// + /// Optional name of the request. Used for debugging. If specified, the + /// string should be accessible until the request completes. + /// + _In_opt_z_ const CHAR *Name; +}; + +/// +/// Flags controlling the behavior of requests enqueued using EnqueueRequests. +/// +enum DSTORAGE_ENQUEUE_REQUEST_FLAGS : UINT32 +{ + /// + /// Requests wait on the ID3D12Fence before writing to the destination. + /// All processing required for the requests before the write can be + /// completed asynchronously once submitted. This is the default behavior. + /// + DSTORAGE_ENQUEUE_REQUEST_FLAG_NONE = 0, + + /// + /// Requests wait on the ID3D12Fence before utilizing the GPU for any of + /// the requests and before writing to the destination. All processing + /// required for the requests, except GPU work or writing to the + /// destination, can be completed asynchronously once submitted. + /// + DSTORAGE_ENQUEUE_REQUEST_FLAG_FENCE_WAIT_BEFORE_GPU_WORK = 1, + + /// + /// Requests wait on the ID3D12Fence before reading from the source. No + /// processing occurs until the ID3D12Fence is set. + /// + DSTORAGE_ENQUEUE_REQUEST_FLAG_FENCE_WAIT_BEFORE_SOURCE_ACCESS = 2 +}; + +/// +/// The maximum number of characters that will be stored for a request's name. +/// +#define DSTORAGE_REQUEST_MAX_NAME 64 + +/// +/// The type of command that failed, as reported by +/// DSTORAGE_ERROR_FIRST_FAILURE. +/// +enum DSTORAGE_COMMAND_TYPE { + DSTORAGE_COMMAND_TYPE_NONE = -1, + DSTORAGE_COMMAND_TYPE_REQUEST = 0, + DSTORAGE_COMMAND_TYPE_STATUS = 1, + DSTORAGE_COMMAND_TYPE_SIGNAL = 2, + DSTORAGE_COMMAND_TYPE_EVENT = 3, +}; + +/// +/// The parameters passed to the EnqueueRequest call, and optional +/// filename if the request is for a file source. +/// +struct DSTORAGE_ERROR_PARAMETERS_REQUEST { + /// + /// For a file source request, the name of the file the request was + /// targeted to. + /// + WCHAR Filename[MAX_PATH]; + + /// + /// The name of the request if one was specified. + /// + CHAR RequestName[DSTORAGE_REQUEST_MAX_NAME]; + + /// + /// The parameters passed to the EnqueueRequest call. + /// + DSTORAGE_REQUEST Request; +}; + +/// +/// The parameters passed to the EnqueueStatus call. +/// +struct DSTORAGE_ERROR_PARAMETERS_STATUS { + IDStorageStatusArray* StatusArray; + UINT32 Index; +}; + +/// +/// The parameters passed to the EnqueueSignal call. +/// +struct DSTORAGE_ERROR_PARAMETERS_SIGNAL { + ID3D12Fence* Fence; + UINT64 Value; +}; + +/// +/// The parameters passed to the EnqueueSetEvent call. +/// +struct DSTORAGE_ERROR_PARAMETERS_EVENT +{ + HANDLE Handle; +}; + +/// +/// Structure to receive the detailed record of the first failed DirectStorage +/// request. +/// +struct DSTORAGE_ERROR_FIRST_FAILURE { + + /// + /// The HRESULT code of the failure. + /// + HRESULT HResult; + + /// + /// Type of the Enqueue command that caused the failure. + /// + DSTORAGE_COMMAND_TYPE CommandType; + + /// + /// The parameters passed to the Enqueue call. + /// + union + { + DSTORAGE_ERROR_PARAMETERS_REQUEST Request; + DSTORAGE_ERROR_PARAMETERS_STATUS Status; + DSTORAGE_ERROR_PARAMETERS_SIGNAL Signal; + DSTORAGE_ERROR_PARAMETERS_EVENT Event; + }; +}; + +/// +/// Structure to receive the detailed record of a failed DirectStorage request. +/// +struct DSTORAGE_ERROR_RECORD { + /// + /// The number of failed requests in the queue since the last + /// RetrieveErrorRecord call. + /// + UINT32 FailureCount; + + /// + /// Detailed record about the first failed command in the enqueue order. + /// + DSTORAGE_ERROR_FIRST_FAILURE FirstFailure; +}; + + +/// +/// Defines common staging buffer sizes. +/// +enum DSTORAGE_STAGING_BUFFER_SIZE : UINT32 { + /// + /// There is no staging buffer. Use this value to force DirectStorage to + /// deallocate any memory it has allocated for staging buffers. + /// + DSTORAGE_STAGING_BUFFER_SIZE_0 = 0, + + /// + /// The default staging buffer size of 32MB. + /// + DSTORAGE_STAGING_BUFFER_SIZE_32MB = 32 * 1048576, +}; + + +/// +/// Flags used with GetRequests1 when requesting +/// items from the custom decompression queue. +/// +enum DSTORAGE_GET_REQUEST_FLAGS : UINT32 +{ + /// + /// Request entries that use custom decompression formats + /// >= DSTORAGE_CUSTOM_COMPRESSION_0. + /// + DSTORAGE_GET_REQUEST_FLAG_SELECT_CUSTOM = 0x01, + + /// + /// Request entries that use built in compression formats + /// that DirectStorage understands. + /// + DSTORAGE_GET_REQUEST_FLAG_SELECT_BUILTIN = 0x02, + + /// + /// Request all entries. This includes custom decompression and + /// built-in compressed formats. + /// + DSTORAGE_GET_REQUEST_FLAG_SELECT_ALL = (DSTORAGE_GET_REQUEST_FLAG_SELECT_CUSTOM | DSTORAGE_GET_REQUEST_FLAG_SELECT_BUILTIN) +}; +DEFINE_ENUM_FLAG_OPERATORS(DSTORAGE_GET_REQUEST_FLAGS); + +/// +/// Specifies information about a custom decompression request. +/// +enum DSTORAGE_CUSTOM_DECOMPRESSION_FLAGS : UINT32 +{ + /// + /// No additional information. + /// + DSTORAGE_CUSTOM_DECOMPRESSION_FLAG_NONE = 0x00, + + /// + /// The uncompressed destination buffer is located in an + /// upload heap, and is marked as WRITE_COMBINED. + /// + DSTORAGE_CUSTOM_DECOMPRESSION_FLAG_DEST_IN_UPLOAD_HEAP = 0x01, +}; +DEFINE_ENUM_FLAG_OPERATORS(DSTORAGE_CUSTOM_DECOMPRESSION_FLAGS); + +/// +/// A custom decompression request. Use IDStorageCustomDecompressionQueue to +/// retrieve these requests. +/// +struct DSTORAGE_CUSTOM_DECOMPRESSION_REQUEST { + /// + /// An identifier provided by DirectStorage. This should be used to + /// identify the request in DSTORAGE_CUSTOM_DECOMPRESSION_RESULT. This + /// identifier is unique among uncompleted requests, but may be reused after + /// a request has completed. + /// + UINT64 Id; + + /// + /// The compression format. This will be >= DSTORAGE_CUSTOM_COMPRESSION_0 + /// if DSTORAGE_CUSTOM_DECOMPRESSION_CUSTOMONLY is used to retrieve requests. + /// + DSTORAGE_COMPRESSION_FORMAT CompressionFormat; + + /// + /// Reserved for future use. + /// + UINT8 Reserved[3]; + + /// + /// Flags containing additional details about the decompression request. + /// + DSTORAGE_CUSTOM_DECOMPRESSION_FLAGS Flags; + + /// + /// The size of SrcBuffer in bytes. + /// + UINT64 SrcSize; + + /// + /// The compressed source buffer. + /// + void const* SrcBuffer; + + /// + /// The size of DstBuffer in bytes. + /// + UINT64 DstSize; + + /// + /// The uncompressed destination buffer. SrcBuffer should be decompressed to + /// DstBuffer. + /// + void* DstBuffer; +}; + +/// +/// The result of a custom decompression operation. If the request failed, then +/// the Result code is passed back through the standard DirectStorage +/// status/error reporting mechanism. +/// +struct DSTORAGE_CUSTOM_DECOMPRESSION_RESULT { + /// + /// The identifier for the request, from DSTORAGE_CUSTOM_DECOMPRESSION_REQUEST. + /// + UINT64 Id; + + /// + /// The result of this decompression. S_OK indicates success. + /// + HRESULT Result; +}; + +/// +/// A queue of decompression requests. This can be obtained using QueryInterface +/// against the factory. Your application must take requests from this queue, +/// decompress them, and report that decompression is complete. That allows an +/// application to provide its own custom decompression. +/// +DECLARE_INTERFACE_IID_(IDStorageCustomDecompressionQueue, IUnknown, "97179b2f-2c21-49ca-8291-4e1bf4a160df") +{ + /// + /// Obtains an event to wait on. This event is set when there are pending + /// decompression requests. + /// + virtual HANDLE STDMETHODCALLTYPE GetEvent() = 0; + + /// + /// Populates the given array of request structs with new pending requests. + /// Your application must arrange to fulfill all these requests, and then + /// call SetRequestResults to indicate completion. + /// + virtual HRESULT STDMETHODCALLTYPE GetRequests( + _In_ UINT32 maxRequests, + _Out_writes_to_(maxRequests, *numRequests) DSTORAGE_CUSTOM_DECOMPRESSION_REQUEST* requests, + _Out_ UINT32* numRequests) = 0; + + /// + /// Your application calls this to indicate that requests have been + /// completed. + /// + /// The number of results in `results`. + /// An array of results, the size is specified by + /// `numResults.` + /// Standard HRESULT error code. + virtual HRESULT STDMETHODCALLTYPE SetRequestResults( + _In_ UINT32 numResults, + _In_reads_(numResults) DSTORAGE_CUSTOM_DECOMPRESSION_RESULT* results) = 0; +}; + + +/// +/// An extension of IDStorageCustomDecompressionQueue that allows an +/// application to retrieve specific types of custom decompression +/// requests from the decompression queue. +/// +DECLARE_INTERFACE_IID_( + IDStorageCustomDecompressionQueue1, + IDStorageCustomDecompressionQueue, + "0D47C6C9-E61A-4706-93B4-68BFE3F4AA4A") +{ + /// + /// Populates the given array of request structs with new pending requests + /// based on the specified custom decompression request type. + /// The application must arrange to fulfill all these requests, and then + /// call SetRequestResults to indicate completion. + /// + virtual HRESULT STDMETHODCALLTYPE GetRequests1( + _In_ DSTORAGE_GET_REQUEST_FLAGS flags, + _In_ UINT32 maxRequests, + _Out_writes_to_(maxRequests, *numRequests) DSTORAGE_CUSTOM_DECOMPRESSION_REQUEST* requests, + _Out_ UINT32 * numRequests) = 0; +}; + +/// +/// Represents the static DirectStorage object used to create DirectStorage +/// queues, open files for DirectStorage access, and other global operations. +/// +DECLARE_INTERFACE_IID_(IDStorageFactory, IUnknown, "6924ea0c-c3cd-4826-b10a-f64f4ed927c1") +{ + /// + /// Creates a DirectStorage queue object. + /// + /// Descriptor to specify the properties of the queue. + /// Specifies the DirectStorage queue interface, such as + /// __uuidof(IDStorageQueue). + /// Receives the new queue created. + /// Standard HRESULT error code. + virtual HRESULT STDMETHODCALLTYPE CreateQueue(const DSTORAGE_QUEUE_DESC *desc, REFIID riid, _COM_Outptr_ void **ppv) = 0; + + /// + /// Opens a file for DirectStorage access. + /// + /// Path of the file to be opened. + /// Specifies the DirectStorage file interface, such as + /// __uuidof(IDStorageFile). + /// Receives the new file opened. + /// Standard HRESULT error code. + virtual HRESULT STDMETHODCALLTYPE OpenFile(_In_z_ const WCHAR *path, REFIID riid, _COM_Outptr_ void **ppv) = 0; + + /// + /// Creates a DirectStorage status array object. + /// + /// Specifies the number of statuses that the array can + /// hold. + /// Specifies object's name that will appear in + // the ETW events if enabled through the debug layer. This is an optional + // parameter. + /// Specifies the DirectStorage status interface, such as + /// __uuidof(IDStorageStatusArray). + /// Receives the new status array object created. + /// Standard HRESULT error code. + virtual HRESULT STDMETHODCALLTYPE CreateStatusArray(UINT32 capacity, _In_opt_ PCSTR name, REFIID riid, _COM_Outptr_ void **ppv) = 0; + + /// + /// Sets flags used to control the debug layer. + /// + /// A set of flags controlling the debug layer. + virtual void STDMETHODCALLTYPE SetDebugFlags(UINT32 flags) = 0; + + /// + /// Sets the size of staging buffer(s) used to temporarily store content loaded + /// from the storage device before they are decompressed. If only uncompressed + /// memory sourced queues writing to cpu memory destinations are used, then the + /// staging buffer may be 0-sized. + /// + /// Size, in bytes, of each staging buffer used + /// to complete a request. + /// + /// + /// The default staging buffer is DSTORAGE_STAGING_BUFFER_SIZE_32MB. + /// If multiple staging buffers are necessary to complete a request, then each + /// separate staging buffer is allocated to this staging buffer size. + /// + /// If the destination is a GPU resource, then some but not all of the staging + /// buffers will be allocated from VRAM. + /// + /// Requests that exceed the specified size to SetStagingBufferSize will fail. + /// + virtual HRESULT STDMETHODCALLTYPE SetStagingBufferSize(UINT32 size) = 0; +}; + +/// +/// Represents an array of status entries to receive completion results for the +/// read requests before them. +/// +/// +/// A status entry receives completion status for all the requests in the +/// DStorageQueue between where it is enqueued and the previously enqueued +/// status entry. Only when all requests enqueued before the status entry +/// complete (that is, IsComplete for the entry returns true), the status entry +/// can be enqueued again. +/// +DECLARE_INTERFACE_IID_(IDStorageStatusArray, IUnknown, "82397587-7cd5-453b-a02e-31379bd64656") +{ + /// + /// Returns a Boolean value indicating that all requests enqueued prior to the + /// specified status entry have completed. + /// + /// Specifies the index of the status entry to retrieve. + /// Boolean value indicating completion. + /// This is equivalent to `GetHResult(index) != E_PENDING`. + virtual bool STDMETHODCALLTYPE IsComplete(UINT32 index) = 0; + + /// + /// Returns the HRESULT code of all requests between the specified status + /// entry and the status entry enqueued before it. + /// + /// Specifies the index of the status entry to retrieve. + /// HRESULT code of the requests. + /// + /// + /// + /// If any requests have not completed yet, the return value is E_PENDING. + /// + /// + /// If all requests have completed, and there were failure(s), then the return + /// value stores the failure code of the first failed request in the enqueue + /// order. + /// + /// + /// If all requests have completed successfully, then the return value is S_OK. + /// + /// + /// + virtual HRESULT STDMETHODCALLTYPE GetHResult(UINT32 index) = 0; +}; + +/// +/// Represents a DirectStorage queue to perform read operations. +/// +DECLARE_INTERFACE_IID_(IDStorageQueue, IUnknown, "cfdbd83f-9e06-4fda-8ea5-69042137f49b") +{ + /// + /// Enqueues a read request to the queue. The request remains in the queue + /// until Submit is called, or until the queue is half full. + /// If there are no free entries in the queue, then the enqueue operation + /// blocks until one becomes available. + /// + /// The read request to be queued. + virtual void STDMETHODCALLTYPE EnqueueRequest(const DSTORAGE_REQUEST *request) = 0; + + /// + /// Enqueues a status write. The status write happens when all requests + /// before the status write entry complete. If there were failure(s) + /// since the previous status write entry, then the HResult of the enqueued + /// status entry stores the failure code of the first failed request in the + /// enqueue order. + /// If there are no free entries in the queue, then the enqueue operation + /// blocks until one becomes available. + /// + /// IDStorageStatusArray object. + /// Index of the status entry in the + /// IDStorageStatusArray object to receive the status. + virtual void STDMETHODCALLTYPE EnqueueStatus(IDStorageStatusArray *statusArray, UINT32 index) = 0; + + /// + /// Enqueues fence write. The fence write happens when all requests before + /// the fence entry complete. + /// If there are no free entries in the queue, then the enqueue operation will + /// block until one becomes available. + /// + /// An ID3D12Fence to be written. + /// The value to write to the fence. + virtual void STDMETHODCALLTYPE EnqueueSignal(ID3D12Fence *fence, UINT64 value) = 0; + + /// + /// Submits all requests enqueued so far to DirectStorage to be executed. + /// + virtual void STDMETHODCALLTYPE Submit() = 0; + + /// + /// Attempts to cancel a group of previously enqueued read requests. All + /// previously enqueued requests whose CancellationTag matches the formula + /// (CancellationTag & mask) == value will be cancelled. + /// A cancelled request might or might not complete its original read request. + /// A cancelled request is not counted as a failure in either + /// IDStorageStatus or DSTORAGE_ERROR_RECORD. + /// + /// The mask for the cancellation formula. + /// The value for the cancellation formula. + virtual void STDMETHODCALLTYPE CancelRequestsWithTag(UINT64 mask, UINT64 value) = 0; + + /// + /// Closes the DirectStorage queue, regardless of the reference count on this + /// object. + /// + /// After the Close function returns, the queue will no longer complete any + /// more requests, even if some are submitted. This does not modify the + /// reference count on this object; Release() must be called as usual. + /// + virtual void STDMETHODCALLTYPE Close() = 0; + + /// + /// Obtains an event to wait on. When there is any error happening for read + /// requests in this queue, the event will be signaled, and you may call + /// RetrieveErrorRecord to retrieve diagnostic information. + /// + /// HANDLE to an event. + virtual HANDLE STDMETHODCALLTYPE GetErrorEvent() = 0; + + /// + /// When the error event is signaled, this function can be called to + /// retrieve a DSTORAGE_ERROR_RECORD. Once the error record is retrieved, + /// this function should not be called until the next time the error event + /// is signaled. + /// + /// Receives the error record. + virtual void STDMETHODCALLTYPE RetrieveErrorRecord(_Out_ DSTORAGE_ERROR_RECORD *record) = 0; + + /// + /// Obtains information about the queue. It includes the DSTORAGE_QUEUE_DESC + /// structure used for the queue's creation as well as the number of empty slots + /// and number of entries that need to be enqueued to trigger automatic + /// submission. + /// + /// Receives the queue information. + virtual void STDMETHODCALLTYPE Query(_Out_ DSTORAGE_QUEUE_INFO *info) = 0; +}; + +/// +/// Represents a DirectStorage queue to perform read operations. +/// +DECLARE_INTERFACE_IID_(IDStorageQueue1, IDStorageQueue, "dd2f482c-5eff-41e8-9c9e-d2374b278128") +{ + /// + /// Enqueues an operation to set the specified event object to a signaled state. + /// The event object is set when all requests before it complete. + /// If there are no free entries in the queue the enqueue operation will + /// block until one becomes available. + /// + /// A handle to an event object. + virtual void STDMETHODCALLTYPE EnqueueSetEvent(HANDLE handle) = 0; +}; + +/// +/// Flags returned with GetCompressionSupport that describe the features +/// used by the runtime to decompress content. +/// +enum DSTORAGE_COMPRESSION_SUPPORT : UINT32 +{ + /// + /// None + /// + DSTORAGE_COMPRESSION_SUPPORT_NONE = 0x0, + + /// + /// Optimized driver support for GPU decompression will be used. + /// + DSTORAGE_COMPRESSION_SUPPORT_GPU_OPTIMIZED = 0x01, + + /// + /// Built-in GPU decompression fallback shader will be used. This can occur if + /// optimized driver support is not available and the D3D12 device used for this + /// DirectStorage queue supports the required capabilities. + /// + DSTORAGE_COMPRESSION_SUPPORT_GPU_FALLBACK = 0x02, + + /// + /// CPU fallback implementation will be used. + /// This can occur if: + /// * Optimized driver support and built-in GPU decompression is not available. + /// * GPU decompression support has been explicitly disabled using + /// DSTORAGE_CONFIGURATION. + /// * DirectStorage runtime encounters a failure during initialization of its + /// GPU decompression system. + /// + DSTORAGE_COMPRESSION_SUPPORT_CPU_FALLBACK = 0x04, + + /// + /// Executes work on a compute queue. + /// + DSTORAGE_COMPRESSION_SUPPORT_USES_COMPUTE_QUEUE = 0x08, + + /// + /// Executes work on a copy queue. + /// + DSTORAGE_COMPRESSION_SUPPORT_USES_COPY_QUEUE = 0x010, +}; +DEFINE_ENUM_FLAG_OPERATORS(DSTORAGE_COMPRESSION_SUPPORT); + +/// +/// Represents a DirectStorage queue to perform read operations. +/// +DECLARE_INTERFACE_IID_(IDStorageQueue2, IDStorageQueue1, "b1c9d643-3a49-44a2-b46f-653649470d18") +{ + /// + /// Obtains support information about the queue for a specified compression format. + /// It includes the chosen path that the DirectStorage runtime will use for decompression. + /// + /// Specifies the compression format to retrieve information + /// about. + virtual DSTORAGE_COMPRESSION_SUPPORT STDMETHODCALLTYPE GetCompressionSupport(DSTORAGE_COMPRESSION_FORMAT format) = 0; +}; + +/// +/// Represents a DirectStorage queue to perform read operations. +/// +DECLARE_INTERFACE_IID_(IDStorageQueue3, IDStorageQueue2, "deb54c52-eca8-46b3-82a7-031b72262653") +{ + /// + /// Enqueues an array of requests to the queue. The requests will be synchronized + /// with the specified `ID3D12Fence` and processed after the synchronization point. + /// + /// A pointer to an array of requests that will be synchronized + /// with the `ID3D12Fence`. + /// The number of requests in the array pointed to by + /// `requests`. + /// A pointer to an `ID3D12Fence` that will be used to synchronize + /// the processing of the requests pointed to by `requests`. + /// The value the `fence` will wait for. Once the `fence` reaches + /// the specified `value`, the `requests` will start processing past the + /// synchronization point. + /// A flag that specifies the synchronization point for the + /// `requests`. + virtual void STDMETHODCALLTYPE EnqueueRequests( + _In_reads_(numRequests) const DSTORAGE_REQUEST * requests, + UINT numRequests, + _In_opt_ ID3D12Fence* fence, + UINT64 value, + DSTORAGE_ENQUEUE_REQUEST_FLAGS flag) = 0; +}; + +/// +/// Disables built-in decompression. +/// +/// Set NumBuiltInCpuDecompressionThreads in DSTORAGE_CONFIGURATION to +/// this value to disable built-in decompression. No decompression threads +/// will be created and the title is fully responsible for checking +/// the custom decompression queue and pulling off ALL entries to decompress. +/// +#define DSTORAGE_DISABLE_BUILTIN_CPU_DECOMPRESSION -1 + +/// +/// DirectStorage Configuration. Zero initializing this will result in the default values. +/// +struct DSTORAGE_CONFIGURATION { + /// + /// Sets the number of threads to use for submitting IO operations. + /// Specifying 0 means use the system's best guess at a good value. + /// Default == 0. + /// + UINT32 NumSubmitThreads; + + /// + /// Sets the number of threads to be used by the DirectStorage runtime to + /// decompress data using the CPU for built-in compressed formats + /// that cannot be decompressed using the GPU. + /// + /// Specifying 0 means to use the system's best guess at a good value. + /// + /// Specifying DSTORAGE_DISABLE_BUILTIN_CPU_DECOMPRESSION means no decompression + /// threads will be created and the title is fully responsible for checking + /// the custom decompression queue and pulling off ALL entries to decompress. + /// + /// Default == 0. + /// + INT32 NumBuiltInCpuDecompressionThreads; + + /// + /// Forces the use of the IO mapping layer, even when running on an + /// operation system that doesn't require it. This may be useful during + /// development, but should be set to the FALSE for release. Default=FALSE. + /// + BOOL ForceMappingLayer; + + /// + /// Disables the use of the bypass IO optimization, even if it is available. + /// This might be useful during development, but should be set to FALSE + /// for release unless ForceFileBuffering is set to TRUE. + /// Default == FALSE. + /// + BOOL DisableBypassIO; + + /// + /// Disables the reporting of telemetry data when set to TRUE. + /// Telemetry data is enabled by default in the DirectStorage runtime. + /// Default == FALSE. + /// + BOOL DisableTelemetry; + + /// + /// Disables the use of a decompression metacommand, even if one + /// is available. This will force the runtime to use the built-in GPU decompression + /// fallback shader. + /// This may be useful during development, but should be set to the FALSE + /// for release. Default == FALSE. + /// + BOOL DisableGpuDecompressionMetacommand; + + /// + /// Disables the use of GPU based decompression, even if it is available. + /// This will force the runtime to use the CPU. Default=FALSE. + /// + BOOL DisableGpuDecompression; +}; + +/// +/// DirectStorage Configuration. Zero initializing this will result in the default values. +/// +struct DSTORAGE_CONFIGURATION1 +{ + /// + /// Sets the number of threads to use for submitting IO operations. + /// Specifying 0 means use the system's best guess at a good value. + /// Default == 0. + /// + UINT32 NumSubmitThreads; + + /// + /// Sets the number of threads to be used by the DirectStorage runtime to + /// decompress data using the CPU for built-in compressed formats + /// that cannot be decompressed using the GPU. + /// + /// Specifying 0 means to use the system's best guess at a good value. + /// + /// Specifying DSTORAGE_DISABLE_BUILTIN_CPU_DECOMPRESSION means no decompression + /// threads will be created and the title is fully responsible for checking + /// the custom decompression queue and pulling off ALL entries to decompress. + /// + /// Default == 0. + /// + INT32 NumBuiltInCpuDecompressionThreads; + + /// + /// Forces the use of the IO mapping layer, even when running on an + /// operation system that doesn't require it. This may be useful during + /// development, but should be set to the FALSE for release. Default=FALSE. + /// + BOOL ForceMappingLayer; + + /// + /// Disables the use of the bypass IO optimization, even if it is available. + /// This might be useful during development, but should be set to FALSE + /// for release unless ForceFileBuffering is set to TRUE. + /// Default == FALSE. + /// + BOOL DisableBypassIO; + + /// + /// Disables the reporting of telemetry data when set to TRUE. + /// Telemetry data is enabled by default in the DirectStorage runtime. + /// Default == FALSE. + /// + BOOL DisableTelemetry; + + /// + /// Disables the use of a decompression metacommand, even if one + /// is available. This will force the runtime to use the built-in GPU decompression + /// fallback shader. + /// This may be useful during development, but should be set to the FALSE + /// for release. Default == FALSE. + /// + BOOL DisableGpuDecompressionMetacommand; + + /// + /// Disables the use of GPU based decompression, even if it is available. + /// This will force the runtime to use the CPU. Default=FALSE. + /// + BOOL DisableGpuDecompression; + + /// + /// Forces the use of the built-in file caching behaviors supported + /// within the Windows operating system by not setting + /// FILE_FLAG_NO_BUFFERING when opening files. + /// + /// DisableBypassIO must be set to TRUE when using this option or + /// E_DSTORAGE_FILEBUFFERING_REQUIRES_DISABLED_BYPASSIO will be returned. + /// + /// It is the title's responsibility to know when to use this setting. + /// This feature should ONLY be enabled for slower HDD drives that will + /// benefit from the OS file buffering features. + /// + /// WARNING: Enabling file buffering on high speed drives may reduce + /// overall performance when reading from that drive because BypassIO + /// is also disabled. Default=FALSE. + /// + BOOL ForceFileBuffering; +}; + +/// +/// Settings controlling DirectStorage compression codec behavior. +/// +enum DSTORAGE_COMPRESSION : INT32 { + + /// + /// Compress data at a fast rate which may not yield the best + /// compression ratio. + /// + DSTORAGE_COMPRESSION_FASTEST = -1, + + /// + /// Compress data at an average rate with a good compression ratio. + /// + DSTORAGE_COMPRESSION_DEFAULT = 0, + + /// + /// Compress data at slow rate with the best compression ratio. + /// + DSTORAGE_COMPRESSION_BEST_RATIO = 1 +}; + +/// +/// Represents the DirectStorage object for compressing and decompressing the buffers. +/// +/// Use DStorageCreateCompressionCodec to get an instance of this. +/// +/// +DECLARE_INTERFACE_IID_(IDStorageCompressionCodec, IUnknown, "84ef5121-9b43-4d03-b5c1-cc34606b262d") +{ + /// + /// Compresses a buffer of data using a known compression format at the specifed + /// compression level. + /// + /// Points to a buffer containing uncompressed data. + /// Size, in bytes, of the uncompressed data buffer. + /// Specifies the compression settings to use. + /// Points to a buffer where compressed data will be + /// written. + /// Size, in bytes, of the buffer which will receive + /// the compressed data + /// Size, in bytes, of the actual size written to compressedBuffer + /// Standard HRESULT error code. + virtual HRESULT STDMETHODCALLTYPE CompressBuffer( + const void* uncompressedData, + size_t uncompressedDataSize, + DSTORAGE_COMPRESSION compressionSetting, + void* compressedBuffer, + size_t compressedBufferSize, + size_t* compressedDataSize) = 0; + + /// + /// Decompresses data previously compressed using CompressBuffer. + /// + /// Points to a buffer containing compressed data. + /// Size, in bytes, of the compressed data buffer. + /// Points to a buffer where uncompressed data will be + /// written. + /// Size, in bytes, of the buffer which will receive + /// the uncompressed data + /// Size, in bytes, of the actual size written to uncompressedBuffer + /// Standard HRESULT error code. + virtual HRESULT STDMETHODCALLTYPE DecompressBuffer( + const void* compressedData, + size_t compressedDataSize, + void* uncompressedBuffer, + size_t uncompressedBufferSize, + size_t* uncompressedDataSize) = 0; + + /// + /// Returns an upper bound estimated size in bytes required to compress the specified data size. + /// + /// Size, in bytes, of the data to be compressed + virtual size_t STDMETHODCALLTYPE CompressBufferBound(size_t uncompressedDataSize) = 0; +}; + +extern "C" { + +/// +/// Configures DirectStorage. This must be called before the first call to +/// DStorageGetFactory. If this is not called, then default values are used. +/// +/// Specifies the configuration. +/// Standard HRESULT error code. The configuration can only be changed +/// when no queue is created and no files are open, +/// E_DSTORAGE_STAGING_BUFFER_LOCKED is returned if this is not the case. +HRESULT WINAPI DStorageSetConfiguration(DSTORAGE_CONFIGURATION const* configuration); + +/// +/// Configures DirectStorage. This must be called before the first call to +/// DStorageGetFactory. If this is not called, then default values are used. +/// +/// Specifies the configuration. +/// Standard HRESULT error code. The configuration can only be changed +/// when no queue is created and no files are open, +/// E_DSTORAGE_STAGING_BUFFER_LOCKED is returned if this is not the case. +HRESULT WINAPI DStorageSetConfiguration1(DSTORAGE_CONFIGURATION1 const* configuration); + +/// +/// Returns the static DirectStorage factory object used to create DirectStorage queues, +/// open files for DirectStorage access, and other global operations. +/// +/// Specifies the DirectStorage factory interface, such as +/// __uuidof(IDStorageFactory) +/// Receives the DirectStorage factory object. +/// Standard HRESULT error code. +HRESULT WINAPI DStorageGetFactory(REFIID riid, _COM_Outptr_ void** ppv); + +/// +/// Returns an object used to compress/decompress content. +/// Compression codecs are not thread safe so multiple +/// instances are required if the codecs need to be used +/// by multiple threads. +/// +/// Specifies how the data is compressed. +/// Specifies maximum number of threads this codec +/// will use. Specifying 0 means to use the system's best guess at a good value. +/// Specifies the DirectStorage compressor/decompressor interface, such as +/// __uuidof(IDStorageCompressionCodec) +/// Receives the DirectStorage object. +/// Standard HRESULT error code. +HRESULT WINAPI DStorageCreateCompressionCodec(DSTORAGE_COMPRESSION_FORMAT format, UINT32 numThreads, REFIID riid, _COM_Outptr_ void** ppv); + +} diff --git a/fastsafetensors/cpp/dstorage_reader.cpp b/fastsafetensors/cpp/dstorage_reader.cpp new file mode 100644 index 0000000..2811cf9 --- /dev/null +++ b/fastsafetensors/cpp/dstorage_reader.cpp @@ -0,0 +1,406 @@ +#include +#include +#include +#include +#include +#include "dstorage.h" +#include +#include +#include +#pragma once +#include "ext.hpp" + +namespace py = pybind11; + +#pragma comment(lib, "d3d12.lib") +#pragma comment(lib, "dxgi.lib") +#pragma comment(lib, "dxguid.lib") + +static constexpr UINT32 DS_STAGING_BUFFER_BYTES = 256u * 1024u * 1024u; + +static const GUID IID_IDStorageFactory = { + 0x6924ea0c, 0xc3cd, 0x4826, + {0xb1, 0x0a, 0xf6, 0x4f, 0x4e, 0xd9, 0x27, 0xc1} +}; +static const GUID IID_IDStorageFile = { + 0x5de95e7b, 0x955a, 0x4868, + {0xa7, 0x3c, 0x24, 0x3b, 0x29, 0xf4, 0xb8, 0xda} +}; +static const GUID IID_IDStorageQueue = { + 0xcfdbd83f, 0x9e06, 0x4fda, + {0x8e, 0xa5, 0x69, 0x04, 0x21, 0x37, 0xf4, 0x9b} +}; + +// DLL loading +typedef HRESULT (__stdcall *PFN_DStorageGetFactory)(REFIID riid, void** ppv); +static PFN_DStorageGetFactory g_pfnGetFactory = nullptr; + +typedef HRESULT (__stdcall *PFN_DStorageSetConfiguration1)(DSTORAGE_CONFIGURATION1 const*); +static PFN_DStorageSetConfiguration1 g_pfnSetConfig1 = nullptr; + +static bool LoadDirectStorage() { + HMODULE hMod = LoadLibraryA("dstoragecore.dll"); + if (!hMod) return false; + hMod = LoadLibraryA("dstorage.dll"); + if (!hMod) return false; + g_pfnGetFactory = (PFN_DStorageGetFactory)GetProcAddress(hMod, "DStorageGetFactory"); + g_pfnSetConfig1 = (PFN_DStorageSetConfiguration1)GetProcAddress(hMod, "DStorageSetConfiguration1"); + return g_pfnGetFactory != nullptr; +} + +// Global state, D3D12 device + DirectStorage factory +class GlobalDStorageState { + static inline int s_device_id = 0; + +public: + static bool Initialize(int device_id, uintptr_t provided_device, const std::string& cudart_dll) { + if (s_initialized) return true; + std::lock_guard lock(s_mutex); + if (s_initialized) return true; + + if (!LoadDirectStorage()) { + last_error_ = "Failed to load dstorage.dll or dstoragecore.dll"; + return false; + } + + if (!cuda_fns.cudaSetDevice || !cuda_fns.cudaImportExternalMemory || + !cuda_fns.cudaExternalMemoryGetMappedBuffer || !cuda_fns.cudaDestroyExternalMemory) { + last_error_ = "CUDA external memory functions not loaded"; + return false; + } + + cudaError_t err = cuda_fns.cudaSetDevice(device_id); + if (err != cudaSuccess) { + last_error_ = "cudaSetDevice failed: " + std::to_string(err); + return false; + } + s_device_id = device_id; + + if (provided_device) { + s_device = reinterpret_cast(provided_device); + s_device->AddRef(); + } else { + IDXGIFactory1* factory = nullptr; + HRESULT hr = CreateDXGIFactory1(IID_IDXGIFactory1, (void**)&factory); + if (FAILED(hr)) { last_error_ = "CreateDXGIFactory1 failed"; return false; } + + IDXGIAdapter1* adapter = nullptr; + for (UINT i = 0; factory->EnumAdapters1(i, &adapter) != DXGI_ERROR_NOT_FOUND; ++i) { + DXGI_ADAPTER_DESC1 desc; + adapter->GetDesc1(&desc); + if (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) { adapter->Release(); continue; } + break; + } + factory->Release(); + if (!adapter) { last_error_ = "No hardware D3D12 adapter found"; return false; } + HRESULT hr2 = D3D12CreateDevice(adapter, D3D_FEATURE_LEVEL_11_0, IID_ID3D12Device, (void**)&s_device); + adapter->Release(); + if (FAILED(hr2)) { last_error_ = "D3D12CreateDevice failed"; return false; } + } + + HRESULT hr = g_pfnGetFactory(IID_IDStorageFactory, (void**)&s_factory); + if (FAILED(hr)) { + last_error_ = "DStorageGetFactory failed"; + return false; + } + + hr = s_factory->SetStagingBufferSize(DS_STAGING_BUFFER_BYTES); + if (FAILED(hr)) { + last_error_ = "SetStagingBufferSize(" + + std::to_string(DS_STAGING_BUFFER_BYTES) + + ") failed: hr=0x" + std::to_string(hr); + return false; + } + + if (g_pfnSetConfig1) { + DSTORAGE_CONFIGURATION1 config = {}; + config.NumSubmitThreads = 8; + g_pfnSetConfig1(&config); + } + + s_initialized = true; + return true; + } + + static ID3D12Device* GetDevice() { return s_device; } + static IDStorageFactory* GetFactory() { return s_factory; } + static int GetCudaDeviceId() { return s_device_id; } + static const std::string& LastError() { return last_error_; } + + static void Shutdown() { + if (s_factory) { s_factory->Release(); s_factory = nullptr; } + if (s_device) { s_device->Release(); s_device = nullptr; } + s_initialized = false; + } + +private: + static inline bool s_initialized = false; + static inline ID3D12Device* s_device = nullptr; + static inline IDStorageFactory* s_factory = nullptr; + static inline std::string last_error_; + static inline std::mutex s_mutex; +}; + +// dstorage_file_handle, wraps IDStorageFile +class dstorage_file_handle { +public: + bool open(const std::string& path_utf8) { + int wlen = MultiByteToWideChar(CP_UTF8, 0, path_utf8.c_str(), -1, nullptr, 0); + std::vector wpath(wlen); + MultiByteToWideChar(CP_UTF8, 0, path_utf8.c_str(), -1, wpath.data(), wlen); + + HRESULT hr = GlobalDStorageState::GetFactory()->OpenFile( + wpath.data(), IID_IDStorageFile, (void**)&file_); + return SUCCEEDED(hr); + } + + void close() { + if (file_) { file_->Close(); file_->Release(); file_ = nullptr; } + } + + IDStorageFile* get() const { return file_; } + +private: + IDStorageFile* file_ = nullptr; +}; + +// dstorage_stream_reader, double-buffered DS staging to CUDA copy pipeline +class dstorage_stream_reader { +public: + static constexpr uint64_t STAGE_SIZE = 64ULL * 1024 * 1024; + + dstorage_stream_reader() { + auto* dev = GlobalDStorageState::GetDevice(); + auto* factory = GlobalDStorageState::GetFactory(); + if (!dev || !factory) { + fprintf(stderr, "dstorage_stream_reader: GlobalDStorageState not initialized (dev=%p, factory=%p)\n", dev, factory); + return; + } + cuda_fns.cudaSetDevice(GlobalDStorageState::GetCudaDeviceId()); + + // Create the DirectStorage queue first + DSTORAGE_QUEUE_DESC qdesc = {}; + qdesc.SourceType = DSTORAGE_REQUEST_SOURCE_FILE; + qdesc.Capacity = 8192; + qdesc.Priority = DSTORAGE_PRIORITY_HIGH; + qdesc.Name = "fs_stream"; + qdesc.Device = dev; + HRESULT hr = factory->CreateQueue( + &qdesc, IID_IDStorageQueue, (void**)&queue_); + if (FAILED(hr)) { + fprintf(stderr, "dstorage_stream_reader: CreateQueue failed hr=0x%08X\n", (unsigned)hr); + return; + } + + hr = dev->CreateFence(0, D3D12_FENCE_FLAG_NONE, + IID_ID3D12Fence, (void**)&fence_); + if (FAILED(hr)) { + fprintf(stderr, "dstorage_stream_reader: CreateFence failed hr=0x%08X\n", (unsigned)hr); + return; + } + + // Allocate two D3D12 staging buffers with CUDA interop + D3D12_HEAP_PROPERTIES hp = {}; + hp.Type = D3D12_HEAP_TYPE_DEFAULT; + D3D12_RESOURCE_DESC desc = {}; + desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + desc.Width = STAGE_SIZE; + desc.Height = 1; + desc.DepthOrArraySize = 1; + desc.MipLevels = 1; + desc.SampleDesc.Count = 1; + desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + + for (int i = 0; i < 2; i++) { + hr = dev->CreateCommittedResource( + &hp, D3D12_HEAP_FLAG_SHARED, &desc, + D3D12_RESOURCE_STATE_COMMON, nullptr, + IID_ID3D12Resource, (void**)&stage_res_[i]); + if (FAILED(hr)) { + fprintf(stderr, "dstorage_stream_reader: CreateCommittedResource[%d] failed hr=0x%08X\n", i, (unsigned)hr); + return; + } + + hr = dev->CreateSharedHandle(stage_res_[i], nullptr, GENERIC_ALL, + nullptr, &stage_handle_[i]); + if (FAILED(hr)) { + fprintf(stderr, "dstorage_stream_reader: CreateSharedHandle[%d] failed\n", i); + return; + } + + cudaExternalMemoryHandleDesc emd = {}; + emd.type = cudaExternalMemoryHandleTypeD3D12Resource; + emd.handle.win32.handle = stage_handle_[i]; + emd.size = STAGE_SIZE; + emd.flags = cudaExternalMemoryDedicated; + cudaError_t cerr = cuda_fns.cudaImportExternalMemory(&stage_ext_mem_[i], &emd); + if (cerr != cudaSuccess) { + fprintf(stderr, "dstorage_stream_reader: cudaImportExternalMemory[%d] failed err=%d\n", i, cerr); + return; + } + + cudaExternalMemoryBufferDesc ebd = {}; + ebd.size = STAGE_SIZE; + cerr = cuda_fns.cudaExternalMemoryGetMappedBuffer(&stage_cuda_ptr_[i], stage_ext_mem_[i], &ebd); + if (cerr != cudaSuccess) { + fprintf(stderr, "dstorage_stream_reader: cudaExternalMemoryGetMappedBuffer[%d] failed err=%d\n", i, cerr); + return; + } + } + + ready_ = true; + } + + ~dstorage_stream_reader() { close(); } + + bool is_ready() const { return ready_; } + + int64_t read_to_cuda(dstorage_file_handle& fh, + uintptr_t dst_cuda_ptr, + uint64_t file_offset, + uint64_t total_bytes) { + if (!ready_) return -1; + + char* dst = reinterpret_cast(dst_cuda_ptr); + uint64_t remaining = total_bytes; + uint64_t src_off = file_offset; + uint64_t dst_off = 0; + int cur = 0; + + bool prev_pending = false; + int prev_buf = 0; + uint64_t prev_size = 0; + uint64_t prev_dst_off = 0; + uint64_t prev_fence = 0; + + while (remaining > 0 || prev_pending) { + // Kick off DS read for current chunk + uint64_t chunk = 0; + uint64_t cur_fence = 0; + if (remaining > 0) { + chunk = (remaining > STAGE_SIZE) ? STAGE_SIZE : remaining; + + DSTORAGE_REQUEST req = {}; + req.Options.CompressionFormat = DSTORAGE_COMPRESSION_FORMAT_NONE; + req.Options.SourceType = DSTORAGE_REQUEST_SOURCE_FILE; + req.Options.DestinationType = DSTORAGE_REQUEST_DESTINATION_BUFFER; + req.Source.File.Source = fh.get(); + req.Source.File.Offset = src_off; + req.Source.File.Size = static_cast(chunk); + req.Destination.Buffer.Resource = stage_res_[cur]; + req.Destination.Buffer.Offset = 0; + req.Destination.Buffer.Size = static_cast(chunk); + req.UncompressedSize = static_cast(chunk); + + queue_->EnqueueRequest(&req); + cur_fence = ++fence_val_; + queue_->EnqueueSignal(fence_, cur_fence); + queue_->Submit(); + + src_off += chunk; + remaining -= chunk; + } + + // While DS reads into cur, copy previous staging to final CUDA + if (prev_pending) { + wait_fence_internal(prev_fence); + + HANDLE errEvent = queue_->GetErrorEvent(); + if (errEvent && WaitForSingleObject(errEvent, 0) == WAIT_OBJECT_0) { + DSTORAGE_ERROR_RECORD rec = {}; + queue_->RetrieveErrorRecord(&rec); + if (rec.FailureCount > 0) { + last_hresult_ = rec.FirstFailure.HResult; + return -2; + } + } + + cudaError_t cerr = cuda_fns.cudaMemcpy( + dst + prev_dst_off, + stage_cuda_ptr_[prev_buf], + prev_size, + cudaMemcpyDeviceToDevice); + if (cerr != cudaSuccess) { + fprintf(stderr, "dstorage_stream_reader: cudaMemcpy failed err=%d\n", cerr); + return -3; + } + } + + // Current becomes previous + if (chunk > 0) { + prev_pending = true; + prev_buf = cur; + prev_size = chunk; + prev_dst_off = dst_off; + prev_fence = cur_fence; + dst_off += chunk; + cur ^= 1; + } else { + prev_pending = false; + } + } + + return static_cast(total_bytes); + } + + int64_t last_hresult() const { return static_cast(last_hresult_); } + + void close() { + for (int i = 0; i < 2; i++) { + if (stage_cuda_ptr_[i]) { + cuda_fns.cudaDestroyExternalMemory(stage_ext_mem_[i]); + stage_cuda_ptr_[i] = nullptr; + stage_ext_mem_[i] = nullptr; + } + if (stage_res_[i]) { stage_res_[i]->Release(); stage_res_[i] = nullptr; } + if (stage_handle_[i]) { CloseHandle(stage_handle_[i]); stage_handle_[i] = nullptr; } + } + if (queue_) { queue_->Close(); queue_->Release(); queue_ = nullptr; } + if (fence_) { fence_->Release(); fence_ = nullptr; } + ready_ = false; + } + +private: + void wait_fence_internal(uint64_t fval) { + if (fence_->GetCompletedValue() < fval) { + HANDLE evt = CreateEventA(nullptr, FALSE, FALSE, nullptr); + fence_->SetEventOnCompletion(fval, evt); + WaitForSingleObject(evt, INFINITE); + CloseHandle(evt); + } + } + + ID3D12Resource* stage_res_[2] = {}; + HANDLE stage_handle_[2] = {}; + cudaExternalMemory_t stage_ext_mem_[2] = {}; + void* stage_cuda_ptr_[2] = {}; + IDStorageQueue* queue_ = nullptr; + ID3D12Fence* fence_ = nullptr; + uint64_t fence_val_ = 0; + HRESULT last_hresult_ = S_OK; + bool ready_ = false; +}; + +void init_dstorage_bindings(py::module_& m) { + m.def("init_dstorage", [](int device_id, uintptr_t d3d12_ptr, const std::string& cudart_dll) -> std::string { + if (GlobalDStorageState::Initialize(device_id, d3d12_ptr, cudart_dll)) + return "ok"; + return GlobalDStorageState::LastError(); + }, py::arg("device_id") = 0, py::arg("d3d12_device_ptr") = 0, py::arg("cudart_dll") = "cudart64_12.dll"); + + py::class_(m, "dstorage_file_handle") + .def(py::init<>()) + .def("open", &dstorage_file_handle::open) + .def("close", &dstorage_file_handle::close); + + py::class_(m, "dstorage_stream_reader") + .def(py::init<>()) + .def("is_ready", &dstorage_stream_reader::is_ready) + .def("read_to_cuda", &dstorage_stream_reader::read_to_cuda, + py::arg("fh"), + py::arg("dst_cuda_ptr"), + py::arg("file_offset"), + py::arg("total_bytes")) + .def("last_hresult", &dstorage_stream_reader::last_hresult) + .def("close", &dstorage_stream_reader::close); +} \ No newline at end of file diff --git a/fastsafetensors/cpp/dstorageerr.h b/fastsafetensors/cpp/dstorageerr.h new file mode 100644 index 0000000..1491c84 --- /dev/null +++ b/fastsafetensors/cpp/dstorageerr.h @@ -0,0 +1,492 @@ +/*------------------------------------------------------------------------------------- + * + * Copyright (c) Microsoft Corporation + * Licensed under the MIT license + * + *-------------------------------------------------------------------------------------*/ +#pragma once + +/*++ + + MessageId's 0x0000 - 0x00ff (inclusive) are reserved for DirectStorage. + +--*/ +// +// Values are 32 bit values laid out as follows: +// +// 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 +// 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 +// +---+-+-+-----------------------+-------------------------------+ +// |Sev|C|R| Facility | Code | +// +---+-+-+-----------------------+-------------------------------+ +// +// where +// +// Sev - is the severity code +// +// 00 - Success +// 01 - Informational +// 10 - Warning +// 11 - Error +// +// C - is the Customer code flag +// +// R - is a reserved bit +// +// Facility - is the facility code +// +// Code - is the facility's status code +// +// +// Define the facility codes +// +#define FACILITY_GAME 2340 + + +// +// Define the severity codes +// + + +// +// MessageId: E_DSTORAGE_ALREADY_RUNNING +// +// MessageText: +// +// DStorage is already running exclusively. +// +#define E_DSTORAGE_ALREADY_RUNNING ((HRESULT)0x89240001L) + +// +// MessageId: E_DSTORAGE_NOT_RUNNING +// +// MessageText: +// +// DStorage is not running. +// +#define E_DSTORAGE_NOT_RUNNING ((HRESULT)0x89240002L) + +// +// MessageId: E_DSTORAGE_INVALID_QUEUE_CAPACITY +// +// MessageText: +// +// Invalid queue capacity parameter. +// +#define E_DSTORAGE_INVALID_QUEUE_CAPACITY ((HRESULT)0x89240003L) + +// +// MessageId: E_DSTORAGE_XVD_DEVICE_NOT_SUPPORTED +// +// MessageText: +// +// The specified XVD is not on a supported NVMe device. +// This error only applies to Xbox. +// +#define E_DSTORAGE_XVD_DEVICE_NOT_SUPPORTED ((HRESULT)0x89240004L) + +// +// MessageId: E_DSTORAGE_UNSUPPORTED_VOLUME +// +// MessageText: +// +// The specified XVD is not on a supported volume. +// This error only applies to Xbox. +// +#define E_DSTORAGE_UNSUPPORTED_VOLUME ((HRESULT)0x89240005L) + +// +// MessageId: E_DSTORAGE_END_OF_FILE +// +// MessageText: +// +// The specified offset and length exceeds the size of the file. +// +#define E_DSTORAGE_END_OF_FILE ((HRESULT)0x89240007L) + +// +// MessageId: E_DSTORAGE_REQUEST_TOO_LARGE +// +// MessageText: +// +// The IO request is too large. +// +#define E_DSTORAGE_REQUEST_TOO_LARGE ((HRESULT)0x89240008L) + +// +// MessageId: E_DSTORAGE_ACCESS_VIOLATION +// +// MessageText: +// +// The destination buffer for the DStorage request is not accessible. +// +#define E_DSTORAGE_ACCESS_VIOLATION ((HRESULT)0x89240009L) + +// +// MessageId: E_DSTORAGE_UNSUPPORTED_FILE +// +// MessageText: +// +// The file is not supported by DStorage. Possible reasons include the file is a +// sparse file, or is compressed in NTFS. +// This error only applies to Xbox. +// +#define E_DSTORAGE_UNSUPPORTED_FILE ((HRESULT)0x8924000AL) + +// +// MessageId: E_DSTORAGE_FILE_NOT_OPEN +// +// MessageText: +// +// The file is not open. +// +#define E_DSTORAGE_FILE_NOT_OPEN ((HRESULT)0x8924000BL) + +// +// MessageId: E_DSTORAGE_RESERVED_FIELDS +// +// MessageText: +// +// A reserved field is not set to 0. +// +#define E_DSTORAGE_RESERVED_FIELDS ((HRESULT)0x8924000CL) + +// +// MessageId: E_DSTORAGE_INVALID_BCPACK_MODE +// +// MessageText: +// +// The request has invalid BCPack decompression mode. +// This error only applies to Xbox. +// +#define E_DSTORAGE_INVALID_BCPACK_MODE ((HRESULT)0x8924000DL) + +// +// MessageId: E_DSTORAGE_INVALID_SWIZZLE_MODE +// +// MessageText: +// +// The request has invalid swizzle mode. +// This error only applies to Xbox. +// +#define E_DSTORAGE_INVALID_SWIZZLE_MODE ((HRESULT)0x8924000EL) + +// +// MessageId: E_DSTORAGE_INVALID_DESTINATION_SIZE +// +// MessageText: +// +// The request's destination size is invalid. If no decompression is used, it must +// be equal to the request's length; If decompression is used, it must be larger +// than the request's length. +// +#define E_DSTORAGE_INVALID_DESTINATION_SIZE ((HRESULT)0x8924000FL) + +// +// MessageId: E_DSTORAGE_QUEUE_CLOSED +// +// MessageText: +// +// The request targets a queue that is closed. +// +#define E_DSTORAGE_QUEUE_CLOSED ((HRESULT)0x89240010L) + +// +// MessageId: E_DSTORAGE_INVALID_CLUSTER_SIZE +// +// MessageText: +// +// The volume is formatted with an unsupported cluster size. +// This error only applies to Xbox. +// +#define E_DSTORAGE_INVALID_CLUSTER_SIZE ((HRESULT)0x89240011L) + +// +// MessageId: E_DSTORAGE_TOO_MANY_QUEUES +// +// MessageText: +// +// The number of queues has reached the maximum limit. +// +#define E_DSTORAGE_TOO_MANY_QUEUES ((HRESULT)0x89240012L) + +// +// MessageId: E_DSTORAGE_INVALID_QUEUE_PRIORITY +// +// MessageText: +// +// Invalid priority is specified for the queue. +// +#define E_DSTORAGE_INVALID_QUEUE_PRIORITY ((HRESULT)0x89240013L) + +// +// MessageId: E_DSTORAGE_TOO_MANY_FILES +// +// MessageText: +// +// The number of files has reached the maximum limit. +// +#define E_DSTORAGE_TOO_MANY_FILES ((HRESULT)0x89240014L) + +// +// MessageId: E_DSTORAGE_INDEX_BOUND +// +// MessageText: +// +// The index parameter is out of bound. +// +#define E_DSTORAGE_INDEX_BOUND ((HRESULT)0x89240015L) + +// +// MessageId: E_DSTORAGE_IO_TIMEOUT +// +// MessageText: +// +// The IO operation has timed out. +// +#define E_DSTORAGE_IO_TIMEOUT ((HRESULT)0x89240016L) + +// +// MessageId: E_DSTORAGE_INVALID_FILE_HANDLE +// +// MessageText: +// +// The specified file has not been opened. +// +#define E_DSTORAGE_INVALID_FILE_HANDLE ((HRESULT)0x89240017L) + +// +// MessageId: E_DSTORAGE_DEPRECATED_PREVIEW_GDK +// +// MessageText: +// +// This GDK preview is deprecated. Update to a supported GDK version. +// This error only applies to Xbox. +// +#define E_DSTORAGE_DEPRECATED_PREVIEW_GDK ((HRESULT)0x89240018L) + +// +// MessageId: E_DSTORAGE_XVD_NOT_REGISTERED +// +// MessageText: +// +// The specified XVD is not registered or unmounted. +// This error only applies to Xbox. +// +#define E_DSTORAGE_XVD_NOT_REGISTERED ((HRESULT)0x89240019L) + +// +// MessageId: E_DSTORAGE_INVALID_FILE_OFFSET +// +// MessageText: +// +// The request has invalid file offset for the specified decompression mode. +// +#define E_DSTORAGE_INVALID_FILE_OFFSET ((HRESULT)0x8924001AL) + +// +// MessageId: E_DSTORAGE_INVALID_SOURCE_TYPE +// +// MessageText: +// +// A memory source request was enqueued into a file source queue, or a file source +// request was enqueued into a memory source queue. +// +#define E_DSTORAGE_INVALID_SOURCE_TYPE ((HRESULT)0x8924001BL) + +// +// MessageId: E_DSTORAGE_INVALID_INTERMEDIATE_SIZE +// +// MessageText: +// +// The request has invalid intermediate size for the specified decompression modes. +// This error only applies to Xbox. +// +#define E_DSTORAGE_INVALID_INTERMEDIATE_SIZE ((HRESULT)0x8924001CL) + +// +// MessageId: E_DSTORAGE_SYSTEM_NOT_SUPPORTED +// +// MessageText: +// +// This console generation doesn't support DirectStorage. +// This error only applies to Xbox. +// +#define E_DSTORAGE_SYSTEM_NOT_SUPPORTED ((HRESULT)0x8924001DL) + +// +// MessageId: E_DSTORAGE_STAGING_BUFFER_LOCKED +// +// MessageText: +// +// Staging buffer size can only be changed when no queue is created and no file is +// open. +// +#define E_DSTORAGE_STAGING_BUFFER_LOCKED ((HRESULT)0x8924001FL) + +// +// MessageId: E_DSTORAGE_INVALID_STAGING_BUFFER_SIZE +// +// MessageText: +// +// The specified staging buffer size is not valid. +// +#define E_DSTORAGE_INVALID_STAGING_BUFFER_SIZE ((HRESULT)0x89240020L) + +// +// MessageId: E_DSTORAGE_STAGING_BUFFER_TOO_SMALL +// +// MessageText: +// +// The staging buffer isn't large enough to perform this operation. +// +#define E_DSTORAGE_STAGING_BUFFER_TOO_SMALL ((HRESULT)0x89240021L) + +// +// MessageId: E_DSTORAGE_INVALID_FENCE +// +// MessageText: +// +// The fence is not valid or has been released. +// +#define E_DSTORAGE_INVALID_FENCE ((HRESULT)0x89240022L) + +// +// MessageId: E_DSTORAGE_INVALID_STATUS_ARRAY +// +// MessageText: +// +// The status array is not valid or has been released. +// +#define E_DSTORAGE_INVALID_STATUS_ARRAY ((HRESULT)0x89240023L) + +// +// MessageId: E_DSTORAGE_INVALID_MEMORY_QUEUE_PRIORITY +// +// MessageText: +// +// Invalid priority is specified for the queue. Only DSTORAGE_PRIORITY_REALTIME +// is a valid priority for a memory queue. +// +#define E_DSTORAGE_INVALID_MEMORY_QUEUE_PRIORITY ((HRESULT)0x89240024L) + +// +// MessageId: E_DSTORAGE_DECOMPRESSION_ERROR +// +// MessageText: +// +// A generic error has happened during decompression. +// +#define E_DSTORAGE_DECOMPRESSION_ERROR ((HRESULT)0x89240030L) + +// +// MessageId: E_DSTORAGE_ZLIB_BAD_HEADER +// +// MessageText: +// +// ZLIB header is corrupted. +// This error only applies to Xbox. +// +#define E_DSTORAGE_ZLIB_BAD_HEADER ((HRESULT)0x89240031L) + +// +// MessageId: E_DSTORAGE_ZLIB_BAD_DATA +// +// MessageText: +// +// ZLIB compressed data is corrupted/invalid. +// This error only applies to Xbox. +// +#define E_DSTORAGE_ZLIB_BAD_DATA ((HRESULT)0x89240032L) + +// +// MessageId: E_DSTORAGE_ZLIB_PARITY_FAIL +// +// MessageText: +// +// Block-level ADLER parity check failed during ZLIB decompression. +// This error only applies to Xbox. +// +#define E_DSTORAGE_ZLIB_PARITY_FAIL ((HRESULT)0x89240033L) + +// +// MessageId: E_DSTORAGE_BCPACK_BAD_HEADER +// +// MessageText: +// +// BCPack header is corrupted. +// This error only applies to Xbox. +// +#define E_DSTORAGE_BCPACK_BAD_HEADER ((HRESULT)0x89240034L) + +// +// MessageId: E_DSTORAGE_BCPACK_BAD_DATA +// +// MessageText: +// +// BCPack decoder has generated more data than expected, most likely due to +// corrupted bitstream. +// This error only applies to Xbox. +// +#define E_DSTORAGE_BCPACK_BAD_DATA ((HRESULT)0x89240035L) + +// +// MessageId: E_DSTORAGE_DECRYPTION_ERROR +// +// MessageText: +// +// A generic error has happened during decryption. +// This error only applies to Xbox. +// +#define E_DSTORAGE_DECRYPTION_ERROR ((HRESULT)0x89240036L) + +// +// MessageId: E_DSTORAGE_PASSTHROUGH_ERROR +// +// MessageText: +// +// A generic error has happened during copy operation. +// This error only applies to Xbox. +// +#define E_DSTORAGE_PASSTHROUGH_ERROR ((HRESULT)0x89240037L) + +// +// MessageId: E_DSTORAGE_FILE_TOO_FRAGMENTED +// +// MessageText: +// +// The file is too fragmented to be accessed by DStorage. This error can only +// happen with files overly fragmented on a writable volume. +// This error only applies to Xbox. +// +#define E_DSTORAGE_FILE_TOO_FRAGMENTED ((HRESULT)0x89240038L) + +// +// MessageId: E_DSTORAGE_COMPRESSED_DATA_TOO_LARGE +// +// MessageText: +// +// The size of the resulting compressed data is too large for +// DirectStorage to decompress successfully on the GPU. +// +#define E_DSTORAGE_COMPRESSED_DATA_TOO_LARGE ((HRESULT)0x89240039L) + +// +// MessageId: E_DSTORAGE_INVALID_DESTINATION_TYPE +// +// MessageText: +// +// A gpu memory destination request was enqueued into a queue that +// was created without a D3D device or the destination type is +// unknown. +// +#define E_DSTORAGE_INVALID_DESTINATION_TYPE ((HRESULT)0x89240040L) + +// +// MessageId: E_DSTORAGE_FILEBUFFERING_REQUIRES_DISABLED_BYPASSIO +// +// MessageText: +// +// ForceFileBuffering was enabled without disabling BypassIO. +// +#define E_DSTORAGE_FILEBUFFERING_REQUIRES_DISABLED_BYPASSIO ((HRESULT)0x89240041L) diff --git a/fastsafetensors/cpp/ext.cpp b/fastsafetensors/cpp/ext.cpp index 700e97f..33f2fae 100644 --- a/fastsafetensors/cpp/ext.cpp +++ b/fastsafetensors/cpp/ext.cpp @@ -1,11 +1,50 @@ // SPDX-License-Identifier: Apache-2.0 +#ifdef _MSC_VER +#define _CRT_SECURE_NO_WARNINGS +#endif + #include #include +#ifdef _MSC_VER +#include +#include +#include +#include +#include +#include "mman.h" +#include "dlfcn.h" +// Windows-compatible posix_memalign +static inline int posix_memalign(void **memptr, size_t alignment, size_t size) { + *memptr = _aligned_malloc(size, alignment); + return (*memptr) ? 0 : errno; +} +// Windows-compatible pread +static inline int64_t pread(int fd, void *buf, size_t count, int64_t offset) { + int64_t cur = _lseeki64(fd, 0, 1 /*SEEK_CUR*/); + if (cur < 0) return -1; + if (_lseeki64(fd, offset, 0 /*SEEK_SET*/) < 0) return -1; + int rd = _read(fd, buf, (unsigned int)count); + _lseeki64(fd, cur, 0 /*SEEK_SET*/); + return rd; +} +#ifndef RTLD_NODELETE +#define RTLD_NODELETE 0 +#endif +// Map POSIX names to MSVC equivalents +#define open _open +#define close _close +#define O_RDONLY _O_RDONLY +#ifndef O_DIRECT +#define O_DIRECT 0 +#endif +#else #include #include #include #include +#endif +#include #include #include @@ -14,7 +53,9 @@ #define ALIGN 4096 -static bool debug_log = false; +void init_dstorage_bindings(pybind11::module_&); + +bool debug_log = false; // non-static: referenced by dstorage_compat.cpp on Windows static bool enable_gil_release = false; static cpp_metrics_t mc = {.bounce_buffer_bytes = 0}; @@ -53,7 +94,11 @@ static cudaError_t cpu_cudaHostAlloc(void ** p, size_t length, unsigned int) { return cudaSuccess; } static cudaError_t cpu_cudaFreeHost(void * p) { +#ifdef _MSC_VER + _aligned_free(p); +#else free(p); +#endif return cudaSuccess; } static cudaError_t cpu_cudaDeviceGetPCIBusId(char * in, int s, int) { @@ -82,10 +127,14 @@ ext_funcs_t cpu_fns = ext_funcs_t { .cudaDeviceGetPCIBusId = cpu_cudaDeviceGetPCIBusId, .numa_run_on_node = cpu_numa_run_on_node, .cudaSetDevice = cpu_cudaSetDevice, + .cudaImportExternalMemory = nullptr, + .cudaExternalMemoryGetMappedBuffer = nullptr, + .cudaDestroyExternalMemory = nullptr, }; ext_funcs_t cuda_fns; static bool cuda_found = false; +static bool is_hip_runtime = false; // Track if we loaded HIP (not auto-hipified) static bool cufile_found = false; static int cufile_ver = 0; @@ -94,27 +143,37 @@ template void mydlsym(T** h, void* lib, std::string const& name) { *h = reinterpret_cast(dlsym(lib, name.c_str())); } -static void load_library_functions() { +static void load_library_functions(const std::string& cudart_override = "") { cudaError_t (*cudaGetDeviceCount)(int*); +#ifdef _MSC_VER + const char* cufileLib = nullptr; // cuFile not available on Windows + const char* numaLib = nullptr; // NUMA not available on Windows +#else const char* cufileLib = "libcufile.so.0"; - const char* cudartLib = GPU_RUNTIME_LIB; const char* numaLib = "libnuma.so.1"; +#endif + // Use the runtime-provided library name if given, otherwise fall back + // to the compile-time default from cuda_compat.h (GPU_RUNTIME_LIB). + std::string cudart_name = cudart_override.empty() ? GPU_RUNTIME_LIB : cudart_override; + const char* cudartLib = cudart_name.c_str(); bool init_log = getenv(ENV_ENABLE_INIT_LOG); int mode = RTLD_LAZY | RTLD_GLOBAL | RTLD_NODELETE; - void* handle_numa = dlopen(numaLib, mode); - if (handle_numa) { - mydlsym(&cpu_fns.numa_run_on_node, handle_numa, "numa_run_on_node"); - if (cpu_fns.numa_run_on_node) { - cuda_fns.numa_run_on_node = cpu_fns.numa_run_on_node; - if (init_log) { - fprintf(stderr, "[DEBUG] loaded: %s\n", numaLib); + if (numaLib) { + void* handle_numa = dlopen(numaLib, mode); + if (handle_numa) { + mydlsym(&cpu_fns.numa_run_on_node, handle_numa, "numa_run_on_node"); + if (cpu_fns.numa_run_on_node) { + cuda_fns.numa_run_on_node = cpu_fns.numa_run_on_node; + if (init_log) { + fprintf(stderr, "[DEBUG] loaded: %s\n", numaLib); + } } + dlclose(handle_numa); } - dlclose(handle_numa); } if (!cpu_fns.numa_run_on_node) { - if (init_log) { + if (init_log && numaLib) { fprintf(stderr, "[DEBUG] %s is not installed. fallback\n", numaLib); } cpu_fns.numa_run_on_node = cpu_numa_run_on_node; @@ -130,6 +189,10 @@ static void load_library_functions() { count = 0; // why cudaGetDeviceCount returns non-zero for errors? } cuda_found = count > 0; + // Detect if we loaded HIP runtime (ROCm) vs CUDA runtime + if (cuda_found && std::string(cudartLib).find("hip") != std::string::npos) { + is_hip_runtime = true; + } if (init_log) { fprintf(stderr, "[DEBUG] device count=%d, cuda_found=%d\n", count, cuda_found); } @@ -151,11 +214,16 @@ static void load_library_functions() { mydlsym(&cuda_fns.cudaDriverGetVersion, handle_cudart, GPU_SYM_DRIVER_GET_VERSION); mydlsym(&cuda_fns.cudaDeviceGetAttribute, handle_cudart, GPU_SYM_DEVICE_GET_ATTRIBUTE); mydlsym(&cuda_fns.cudaSetDevice, handle_cudart, GPU_SYM_SET_DEVICE); + mydlsym(&cuda_fns.cudaImportExternalMemory, handle_cudart, "cudaImportExternalMemory"); + mydlsym(&cuda_fns.cudaExternalMemoryGetMappedBuffer, handle_cudart, "cudaExternalMemoryGetMappedBuffer"); + mydlsym(&cuda_fns.cudaDestroyExternalMemory, handle_cudart, "cudaDestroyExternalMemory"); bool success = cuda_fns.cudaMemcpy && cuda_fns.cudaDeviceSynchronize; success = success && cuda_fns.cudaHostAlloc && cuda_fns.cudaFreeHost; success = success && cuda_fns.cudaDeviceGetPCIBusId && cuda_fns.cudaDeviceMalloc; success = success && cuda_fns.cudaDeviceFree && cuda_fns.cudaDriverGetVersion; success = success && cuda_fns.cudaDeviceGetAttribute && cuda_fns.cudaSetDevice; + success = success && cuda_fns.cudaImportExternalMemory && cuda_fns.cudaExternalMemoryGetMappedBuffer; + success = success && cuda_fns.cudaDestroyExternalMemory; if (!success) { cuda_found = false; if (init_log) { @@ -176,10 +244,13 @@ static void load_library_functions() { cuda_fns.cudaFreeHost = cpu_cudaFreeHost; cuda_fns.cudaDeviceGetPCIBusId = cpu_cudaDeviceGetPCIBusId; cuda_fns.cudaSetDevice = cpu_cudaSetDevice; + cuda_fns.cudaImportExternalMemory = nullptr; + cuda_fns.cudaExternalMemoryGetMappedBuffer = nullptr; + cuda_fns.cudaDestroyExternalMemory = nullptr; } cufile_found = false; - if (cuda_found) { + if (cuda_found && cufileLib) { void* handle_cufile = dlopen(cufileLib, mode); if (handle_cufile) { CUfileError_t (*cuFileGetVersion)(int *); @@ -250,6 +321,14 @@ bool cuda_not_available() return false; // On ROCm, CUDA is never available } +// Separate function for checking HIP runtime detection (not hipified) +// On CUDA: checks if HIP runtime was detected +// On ROCm: not used (is_cuda_found gets hipified to is_hip_found) +bool check_hip_runtime() +{ + return is_hip_runtime; +} + bool is_cufile_found() { return cufile_found; @@ -403,7 +482,11 @@ uintptr_t cpu_malloc(uint64_t length) { void cpu_free(uintptr_t addr) { void *p = reinterpret_cast(addr); +#ifdef _MSC_VER + _aligned_free(p); +#else free(p); +#endif } uintptr_t gpu_malloc(uint64_t length) { @@ -822,6 +905,9 @@ static int memcpy_h2d_async(uintptr_t dst, uintptr_t src, size_t size) { PYBIND11_MODULE(__MOD_NAME__, m) { +#ifdef _MSC_VER + init_dstorage_bindings(m); +#endif // Initialize GIL release setting from environment variable on module load init_gil_release_from_env(); // Export both is_cuda_found and is_hip_found on all platforms. @@ -846,12 +932,32 @@ PYBIND11_MODULE(__MOD_NAME__, m) m.def("cpu_free", &cpu_free); m.def("gpu_malloc", &gpu_malloc); m.def("gpu_free", &gpu_free); - m.def("load_library_functions", &load_library_functions); + m.def("load_library_functions", &load_library_functions, + pybind11::arg("cudart_lib_name") = ""); m.def("memcpy_h2d_async", &memcpy_h2d_async); m.def("get_cpp_metrics", &get_cpp_metrics); m.def("set_gil_release", &set_gil_release); m.def("get_gil_release", &get_gil_release); + m.def("cuda_memcpy_device_to_host", [](uintptr_t dev_ptr, size_t size) -> pybind11::bytes { + if (!cuda_fns.cudaMemcpy || !cuda_fns.cudaSetDevice) { + throw std::runtime_error("CUDA functions not loaded"); + } + cuda_fns.cudaSetDevice(0); // or pass device_id + std::string buf(size, '\0'); + cudaError_t err = cuda_fns.cudaMemcpy(buf.data(), reinterpret_cast(dev_ptr), size, cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { + throw std::runtime_error("cudaMemcpy failed: " + std::to_string(err)); + } + return pybind11::bytes(buf); + }, pybind11::arg("dev_ptr"), pybind11::arg("size")); + + m.def("cuda_memcpy_host_to_device", [](uintptr_t dev_ptr, pybind11::bytes data) -> void { + std::string s = data; + cudaError_t err = cuda_fns.cudaMemcpy(reinterpret_cast(dev_ptr), s.data(), s.size(), cudaMemcpyHostToDevice); + if (err != cudaSuccess) throw std::runtime_error("cudaMemcpy H2D failed"); + }); + pybind11::class_(m, "gds_device_buffer") .def(pybind11::init()) .def("cufile_register", &gds_device_buffer::cufile_register) @@ -914,4 +1020,4 @@ PYBIND11_MODULE(__MOD_NAME__, m) pybind11::class_(m, "cpp_metrics") .def(pybind11::init<>()) .def_readwrite("bounce_buffer_bytes", &cpp_metrics_t::bounce_buffer_bytes); -} +} \ No newline at end of file diff --git a/fastsafetensors/cpp/ext.hpp b/fastsafetensors/cpp/ext.hpp index edf0bf0..773893a 100644 --- a/fastsafetensors/cpp/ext.hpp +++ b/fastsafetensors/cpp/ext.hpp @@ -8,6 +8,12 @@ #include #include #include +#include + +#ifdef _MSC_VER +#include +typedef SSIZE_T ssize_t; +#endif #include #include @@ -37,9 +43,39 @@ typedef struct CUfileError { CUfileOpError err; } CUfileError_t; // We load all GPU functions dynamically at runtime via dlopen() typedef enum cudaError { cudaSuccess = 0, cudaErrorMemoryAllocation = 2 } cudaError_t; enum cudaDeviceAttr {cudaDevAttrGPUDirectRDMASupported = 116}; -enum cudaMemcpyKind { cudaMemcpyHostToDevice=1, cudaMemcpyDefault = 4 }; +enum cudaMemcpyKind { cudaMemcpyHostToDevice=1, cudaMemcpyDefault = 4, cudaMemcpyDeviceToHost=2, cudaMemcpyDeviceToDevice=3 }; typedef void * cudaStream_t; +enum cudaExternalMemoryHandleType { + cudaExternalMemoryHandleTypeOpaqueFd = 1, + cudaExternalMemoryHandleTypeOpaqueWin32 = 2, + cudaExternalMemoryHandleTypeOpaqueWin32Kmt = 3, + cudaExternalMemoryHandleTypeD3D12Heap = 4, + cudaExternalMemoryHandleTypeD3D12Resource = 5, + cudaExternalMemoryHandleTypeD3D11Resource = 6, + cudaExternalMemoryHandleTypeD3D11ResourceKmt = 7 +}; + +struct cudaExternalMemoryHandleDesc { + cudaExternalMemoryHandleType type; + union { + struct { + void *handle; + const void *name; + } win32; + } handle; + unsigned long long size; + unsigned int flags; +}; + +struct cudaExternalMemoryBufferDesc { + unsigned long long offset; + unsigned long long size; + unsigned int flags; +}; + +typedef void * cudaExternalMemory_t; +static const unsigned int cudaExternalMemoryDedicated = 0x1; typedef enum CUfileFeatureFlags { CU_FILE_DYN_ROUTING_SUPPORTED =0, @@ -199,7 +235,11 @@ typedef struct ext_funcs { CUfileError_t (*cuFileBufDeregister)(const void *); CUfileError_t (*cuFileHandleRegister)(CUfileHandle_t *, CUfileDescr_t *); void (*cuFileHandleDeregister)(CUfileHandle_t); +#ifdef _MSC_VER + ssize_t (*cuFileRead)(CUfileHandle_t, void *, size_t, int64_t, int64_t); +#else ssize_t (*cuFileRead)(CUfileHandle_t, void *, size_t, off_t, off_t); +#endif cudaError_t (*cudaMemcpy)(void *, const void *, size_t, enum cudaMemcpyKind); cudaError_t (*cudaMemcpyAsync)(void *, const void *, size_t, enum cudaMemcpyKind, cudaStream_t); cudaError_t (*cudaDeviceSynchronize)(void); @@ -212,10 +252,13 @@ typedef struct ext_funcs { cudaError_t (*cudaDriverGetVersion)(int *); cudaError_t (*cudaDeviceGetAttribute)(int *, enum cudaDeviceAttr, int); cudaError_t (*cudaSetDevice)(int); + cudaError_t (*cudaImportExternalMemory)(cudaExternalMemory_t*, const struct cudaExternalMemoryHandleDesc*); + cudaError_t (*cudaExternalMemoryGetMappedBuffer)(void**, cudaExternalMemory_t, const struct cudaExternalMemoryBufferDesc*); + cudaError_t (*cudaDestroyExternalMemory)(cudaExternalMemory_t); } ext_funcs_t; typedef struct cpp_metrics { size_t bounce_buffer_bytes; } cpp_metrics_t; -#endif //__EXT_HPP__ +#endif //__EXT_HPP__ \ No newline at end of file diff --git a/fastsafetensors/cpp/mman.cpp b/fastsafetensors/cpp/mman.cpp new file mode 100644 index 0000000..3601633 --- /dev/null +++ b/fastsafetensors/cpp/mman.cpp @@ -0,0 +1,178 @@ +#include +#include + +#include +#include +#include + +#include "mman.h" + +#ifndef FILE_MAP_EXECUTE +#define FILE_MAP_EXECUTE 0x0020 +#endif + +namespace { + + +int MapMmanError(DWORD err, int /*deferr*/) noexcept +{ + if (err == 0) + return 0; + // TODO: implement proper Windows -> errno mapping + return static_cast(err); +} + +DWORD MapMmapProtPage(int prot) noexcept +{ + if (prot == PROT_NONE) + return 0; + + if ((prot & PROT_EXEC) != 0) + { + return ((prot & PROT_WRITE) != 0) + ? PAGE_EXECUTE_READWRITE + : PAGE_EXECUTE_READ; + } + + return ((prot & PROT_WRITE) != 0) ? PAGE_READWRITE : PAGE_READONLY; +} + +DWORD MapMmapProtFile(int prot) noexcept +{ + if (prot == PROT_NONE) + return 0; + + DWORD desiredAccess = 0; + if ((prot & PROT_READ) != 0) + desiredAccess |= FILE_MAP_READ; + if ((prot & PROT_WRITE) != 0) + desiredAccess |= FILE_MAP_WRITE; + if ((prot & PROT_EXEC) != 0) + desiredAccess |= FILE_MAP_EXECUTE; + + return desiredAccess; +} + +// Split a 64-bit-capable offset into the high/low DWORD pair Win32 expects, +// in a way that avoids "shift count >= width of type" warnings when +// OffsetType is itself only 32 bits wide. +struct DwordPair +{ + DWORD low; + DWORD high; +}; + +DwordPair SplitOffset(OffsetType value) noexcept +{ + if constexpr (sizeof(OffsetType) <= sizeof(DWORD)) + { + return { static_cast(value), 0 }; + } + else + { + const auto u = static_cast(value); + return { + static_cast(u & 0xFFFFFFFFu), + static_cast((u >> 32) & 0xFFFFFFFFu) + }; + } +} + +} // namespace + +void* mmap(void* addr, std::size_t len, int prot, int flags, int fildes, OffsetType off) +{ + errno = 0; + + // Reject zero-length and unsupported protection combinations. + if (len == 0 || prot == PROT_EXEC) + { + errno = EINVAL; + return MAP_FAILED; + } + + const DWORD protect = MapMmapProtPage(prot); + const DWORD desiredAccess = MapMmapProtFile(prot); + + const auto fileOffset = SplitOffset(off); + const auto maxSize = SplitOffset(off + static_cast(len)); + + HANDLE h = ((flags & MAP_ANONYMOUS) == 0) + ? reinterpret_cast(_get_osfhandle(fildes)) + : INVALID_HANDLE_VALUE; + + if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE) + { + errno = EBADF; + return MAP_FAILED; + } + + HANDLE fm = ::CreateFileMapping(h, nullptr, protect, maxSize.high, maxSize.low, nullptr); + if (fm == nullptr) + { + errno = MapMmanError(::GetLastError(), EPERM); + return MAP_FAILED; + } + + void* map = ((flags & MAP_FIXED) == 0) + ? ::MapViewOfFile(fm, desiredAccess, fileOffset.high, fileOffset.low, len) + : ::MapViewOfFileEx(fm, desiredAccess, fileOffset.high, fileOffset.low, len, addr); + + ::CloseHandle(fm); + + if (map == nullptr) + { + errno = MapMmanError(::GetLastError(), EPERM); + return MAP_FAILED; + } + + return map; +} + +int munmap(void* addr, std::size_t /*len*/) +{ + if (::UnmapViewOfFile(addr)) + return 0; + + errno = MapMmanError(::GetLastError(), EPERM); + return -1; +} + +int _mprotect(void* addr, std::size_t len, int prot) +{ + const DWORD newProtect = MapMmapProtPage(prot); + DWORD oldProtect = 0; + + if (::VirtualProtect(addr, len, newProtect, &oldProtect)) + return 0; + + errno = MapMmanError(::GetLastError(), EPERM); + return -1; +} + +int msync(void* addr, std::size_t len, int /*flags*/) +{ + if (::FlushViewOfFile(addr, len)) + return 0; + + errno = MapMmanError(::GetLastError(), EPERM); + return -1; +} + +int mlock(const void* addr, std::size_t len) +{ + if (::VirtualLock(const_cast(addr), len)) + return 0; + + errno = MapMmanError(::GetLastError(), EPERM); + return -1; +} + +int munlock(const void* addr, std::size_t len) +{ + if (::VirtualUnlock(const_cast(addr), len)) + return 0; + + errno = MapMmanError(::GetLastError(), EPERM); + return -1; +} \ No newline at end of file diff --git a/fastsafetensors/cpp/mman.h b/fastsafetensors/cpp/mman.h new file mode 100644 index 0000000..047d3a0 --- /dev/null +++ b/fastsafetensors/cpp/mman.h @@ -0,0 +1,76 @@ +/* + * sys/mman.h + * mman-win32 + */ + +#ifndef _SYS_MMAN_H_ +#define _SYS_MMAN_H_ + +#ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later. +#define _WIN32_WINNT 0x0501 // Change this to the appropriate value to target other versions of Windows. +#endif + +/* All the headers include this file. */ +#ifndef _MSC_VER +#include <_mingw.h> +#endif + +#if defined(MMAN_LIBRARY_DLL) +/* Windows shared libraries (DLL) must be declared export when building the lib and import when building the +application which links against the library. */ +#if defined(MMAN_LIBRARY) +#define MMANSHARED_EXPORT __declspec(dllexport) +#else +#define MMANSHARED_EXPORT __declspec(dllimport) +#endif /* MMAN_LIBRARY */ +#else +/* Static libraries do not require a __declspec attribute.*/ +#define MMANSHARED_EXPORT +#endif /* MMAN_LIBRARY_DLL */ + +/* Determine offset type */ +#include +#if defined(_WIN64) +typedef int64_t OffsetType; +#else +typedef uint32_t OffsetType; +#endif + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define PROT_NONE 0 +#define PROT_READ 1 +#define PROT_WRITE 2 +#define PROT_EXEC 4 + +#define MAP_FILE 0 +#define MAP_SHARED 1 +#define MAP_PRIVATE 2 +#define MAP_TYPE 0xf +#define MAP_FIXED 0x10 +#define MAP_ANONYMOUS 0x20 +#define MAP_ANON MAP_ANONYMOUS + +#define MAP_FAILED ((void *)-1) + +/* Flags for msync. */ +#define MS_ASYNC 1 +#define MS_SYNC 2 +#define MS_INVALIDATE 4 + +MMANSHARED_EXPORT void* mmap(void *addr, size_t len, int prot, int flags, int fildes, OffsetType off); +MMANSHARED_EXPORT int munmap(void *addr, size_t len); +MMANSHARED_EXPORT int _mprotect(void *addr, size_t len, int prot); +MMANSHARED_EXPORT int msync(void *addr, size_t len, int flags); +MMANSHARED_EXPORT int mlock(const void *addr, size_t len); +MMANSHARED_EXPORT int munlock(const void *addr, size_t len); + +#ifdef __cplusplus +} +#endif + +#endif /* _SYS_MMAN_H_ */ diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py index 83515d8..df55a89 100644 --- a/fastsafetensors/loader.py +++ b/fastsafetensors/loader.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import platform from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union from . import cpp as fstcpp @@ -199,7 +200,10 @@ def __init__( fstcpp.set_debug_log(debug_log) if not nogds: - copier_type = "gds" + if platform.system() == "Windows": + copier_type = "dstorage" + else: + copier_type = "gds" elif self.device.type != DeviceType.CPU and is_unified_memory_system(): # When GDS is unavailable, prefer the unified copier on systems # with shared CPU/GPU memory (e.g., DGX Spark) over the @@ -291,4 +295,4 @@ def __exit__(self, exc_type, exc_value, tb): if self.fb: self.fb.close() if self.loader: - self.loader.close() + self.loader.close() \ No newline at end of file diff --git a/fastsafetensors/threefs_loader.py b/fastsafetensors/threefs_loader.py index 85b6c5a..2daf10e 100644 --- a/fastsafetensors/threefs_loader.py +++ b/fastsafetensors/threefs_loader.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional from . import cpp as fstcpp -from .common import init_logger +from .common import init_logger, resolve_cudart_lib_name from .frameworks import get_framework_op from .loader import BaseSafeTensorsFileLoader, loaded_library from .parallel_loader import PipelineParallel @@ -48,7 +48,7 @@ def __init__( global loaded_library if not loaded_library: - fstcpp.load_library_functions() + fstcpp.load_library_functions(resolve_cudart_lib_name()) loaded_library = True fstcpp.set_debug_log(debug_log) super().__init__( @@ -130,4 +130,4 @@ def __init__( queue_size, use_tqdm_on_load, **kwargs, - ) + ) \ No newline at end of file diff --git a/perf/fastsafetensors_perf/perf.py b/perf/fastsafetensors_perf/perf.py index 64fba97..b3e77c5 100644 --- a/perf/fastsafetensors_perf/perf.py +++ b/perf/fastsafetensors_perf/perf.py @@ -400,14 +400,14 @@ def drop_cache( for filename in filenames: targets[os.path.realpath(filename)] = True for filename in targets.keys(): - fd = os.open(filename, os.O_RDONLY) + fd = os.open(filename, os.O_RDONLY | (os.O_BINARY if sys.platform == "win32" and hasattr(os, "O_BINARY") else 0)) s = os.fstat(fd) if hasattr(os, "posix_fadvise") and hasattr(os, "POSIX_FADV_DONTNEED"): os.posix_fadvise(fd, 0, s.st_size, os.POSIX_FADV_DONTNEED) # type: ignore[attr-defined] os.close(fd) print(f"DROP_CACHE: {filename}, {s.st_size/1024/1024/1024} GiB") total += s.st_size - fd = os.open(sten_collection_json, os.O_RDONLY) + fd = os.open(sten_collection_json, os.O_RDONLY | (os.O_BINARY if sys.platform == "win32" and hasattr(os, "O_BINARY") else 0)) s = os.fstat(fd) if hasattr(os, "posix_fadvise") and hasattr(os, "POSIX_FADV_DONTNEED"): os.posix_fadvise(fd, 0, s.st_size, os.POSIX_FADV_DONTNEED) # type: ignore[attr-defined] @@ -798,4 +798,4 @@ def run_gds( if __name__ == "__main__": - app() + app() \ No newline at end of file diff --git a/setup.py b/setup.py index c3c702a..9e19b21 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os +import platform from setuptools import Extension, setup from setuptools.command.build_ext import build_ext @@ -33,6 +34,25 @@ def MyExtension(name, sources, mod_name, platform_type, *args, **kwargs): kwargs["language"] = "c++" kwargs["extra_compile_args"] = ["-fvisibility=hidden", "-std=c++17"] + # Windows-specific configuration for DirectStorage + D3D12/CUDA interop + if platform.system() == "Windows": + sources.append("fastsafetensors/cpp/dstorage_reader.cpp") + sources.append("fastsafetensors/cpp/mman.cpp") + sources.append("fastsafetensors/cpp/dlfcn.cpp") + kwargs["libraries"] = [] + #c++20 required for designated initializers at ext.hpp + kwargs["extra_compile_args"] = ["/std:c++20"] + # Note: dstorage.dll is loaded at runtime via LoadLibrary, not linked. + kwargs["libraries"].extend(["ole32", "d3d12", "dxgi", "dxguid", "uuid"]) + + # CUDA interop headers: if CUDA_HOME/CUDA_PATH is set, add include path + # for cudaExternalMemory types used by the interop bridge. + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home: + cuda_include = os.path.join(cuda_home, "include") + if os.path.isdir(cuda_include): + kwargs["include_dirs"].append(cuda_include) + if platform_type == "rocm": # Define platform macros so cuda_compat.h selects the ROCm symbol names. # No ROCm headers or libraries are needed at build time — the runtime @@ -66,4 +86,4 @@ def MyExtension(name, sources, mod_name, platform_type, *args, **kwargs): platform_type=platform_type, ) ], -) +) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 368b943..dfa0487 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from fastsafetensors import SingleGroup from fastsafetensors import cpp as fstcpp -from fastsafetensors.common import is_gpu_found +from fastsafetensors.common import is_gpu_found, resolve_cudart_lib_name from fastsafetensors.cpp import load_library_functions from fastsafetensors.frameworks import FrameworkOpBase, get_framework_op from fastsafetensors.st_types import Device @@ -23,7 +23,7 @@ os.makedirs(TF_DIR, 0o777, True) os.makedirs(TMP_DIR, 0o777, True) -load_library_functions() +load_library_functions(resolve_cudart_lib_name()) FRAMEWORK = get_framework_op(os.getenv("TEST_FASTSAFETENSORS_FRAMEWORK", "please set")) # Print platform information at test startup @@ -97,4 +97,4 @@ def fstcpp_log() -> None: @pytest.fixture(scope="function") def tmp_dir() -> str: - return TMP_DIR + return TMP_DIR \ No newline at end of file diff --git a/tests/test_fastsafetensors.py b/tests/test_fastsafetensors.py index d5ff509..234b35a 100644 --- a/tests/test_fastsafetensors.py +++ b/tests/test_fastsafetensors.py @@ -77,7 +77,10 @@ def run_nogds_file_read( input_file: str, framework: FrameworkOpBase, ) -> Tuple[SafeTensorsMetadata, fstcpp.gds_device_buffer]: - fd = os.open(input_file, os.O_RDONLY, 0o644) + flags = os.O_RDONLY + if sys.platform == "win32" and hasattr(os, "O_BINARY"): + flags |= os.O_BINARY + fd = os.open(input_file, flags, 0o644) meta = SafeTensorsMetadata.from_file(input_file, framework) size = meta.size_bytes - meta.header_length device, dev_is_gpu = get_and_check_device(framework) @@ -294,7 +297,10 @@ def test_memmove(fstcpp_log, framework) -> None: def test_nogds_file_reader(fstcpp_log, input_files, framework) -> None: print("test_nogds_file_reader") - fd = os.open(input_files[0], os.O_RDONLY, 0o644) + flags = os.O_RDONLY + if sys.platform == "win32" and hasattr(os, "O_BINARY"): + flags |= os.O_BINARY + fd = os.open(input_files[0], flags, 0o644) s = os.fstat(fd) assert fd > 0 device, dev_is_gpu = get_and_check_device(framework) @@ -692,4 +698,4 @@ def test_cpp_metrics(fstcpp_log, framework) -> None: assert framework.get_mem_used() == exp_length assert exp_length == 0 - assert framework.get_mem_used() == 0 + assert framework.get_mem_used() == 0 \ No newline at end of file diff --git a/tests/threefs/conftest.py b/tests/threefs/conftest.py index c50c09d..ec4b336 100644 --- a/tests/threefs/conftest.py +++ b/tests/threefs/conftest.py @@ -14,7 +14,7 @@ from fastsafetensors import SingleGroup from fastsafetensors import cpp as fstcpp -from fastsafetensors.common import is_gpu_found +from fastsafetensors.common import is_gpu_found, resolve_cudart_lib_name from fastsafetensors.cpp import load_library_functions from fastsafetensors.frameworks import FrameworkOpBase, get_framework_op from fastsafetensors.st_types import Device @@ -34,7 +34,7 @@ def mock_3fs_reader(): yield -load_library_functions() +load_library_functions(resolve_cudart_lib_name()) FRAMEWORK = get_framework_op(os.getenv("TEST_FASTSAFETENSORS_FRAMEWORK", "please set")) @@ -84,4 +84,4 @@ def framework() -> FrameworkOpBase: @pytest.fixture(scope="function") def fstcpp_log() -> None: - fstcpp.set_debug_log(True) + fstcpp.set_debug_log(True) \ No newline at end of file diff --git a/tests/threefs/mock_reader.py b/tests/threefs/mock_reader.py index 760de35..406a8b5 100644 --- a/tests/threefs/mock_reader.py +++ b/tests/threefs/mock_reader.py @@ -3,6 +3,7 @@ import ctypes import os +import sys class MockFileReader: @@ -20,7 +21,10 @@ def read_chunked( self, path, dev_ptr, file_offset, total_length, chunk_size=0, **kwargs ) -> int: if path not in self._fd_map: - self._fd_map[path] = os.open(path, os.O_RDONLY) + flags = os.O_RDONLY + if sys.platform == "win32" and hasattr(os, "O_BINARY"): + flags |= os.O_BINARY + self._fd_map[path] = os.open(path, flags) fd = self._fd_map[path] data = os.pread(fd, total_length, file_offset) if dev_ptr != 0: @@ -45,4 +49,4 @@ def close(self) -> None: def extract_mount_point(path: str) -> str: """Fallback: return the directory containing the file.""" - return os.path.dirname(os.path.abspath(path)) + return os.path.dirname(os.path.abspath(path)) \ No newline at end of file