From 19cd59cdcdf07b01eba02f5cb363e936a163d137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 10 Nov 2025 16:28:06 +0000 Subject: [PATCH] Support new flattened kernel builds Support kernels that have the main module at `build/`. See: https://github.com/huggingface/kernel-builder/pull/293 --- src/kernels/utils.py | 95 +++++++++++++++++++++++--------------------- tests/test_basic.py | 58 ++++++++++++++++++--------- 2 files changed, 88 insertions(+), 65 deletions(-) diff --git a/src/kernels/utils.py b/src/kernels/utils.py index f5338901..e35d8e0a 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -88,7 +88,11 @@ def universal_build_variant() -> str: return "torch-universal" -def import_from_path(module_name: str, file_path: Path) -> ModuleType: +def _import_from_path(module_name: str, variant_path: Path) -> ModuleType: + file_path = variant_path / "__init__.py" + if not file_path.exists(): + file_path = variant_path / module_name / "__init__.py" + # We cannot use the module name as-is, after adding it to `sys.modules`, # it would also be used for other imports. So, we make a module name that # depends on the path for it to be unique using the hex-encoded hash of @@ -149,29 +153,34 @@ def install_kernel( ) try: - return _load_kernel_from_path(repo_path, package_name, variant_locks) + return _find_kernel_in_repo_path(repo_path, package_name, variant_locks) except FileNotFoundError: - # Redo with more specific error message. raise FileNotFoundError( - f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" + f"Cannot install kernel from repo {repo_id} (revision: {revision})" ) -def _load_kernel_from_path( +def _find_kernel_in_repo_path( repo_path: Path, package_name: str, variant_locks: Optional[Dict[str, VariantLock]] = None, ) -> Tuple[str, Path]: - variant = build_variant() + specific_variant = build_variant() universal_variant = universal_build_variant() - variant_path = repo_path / "build" / variant + specific_variant_path = repo_path / "build" / specific_variant universal_variant_path = repo_path / "build" / universal_variant - if not variant_path.exists() and universal_variant_path.exists(): - # Fall back to universal variant. + if specific_variant_path.exists(): + variant = specific_variant + variant_path = specific_variant_path + elif universal_variant_path.exists(): variant = universal_variant variant_path = universal_variant_path + else: + raise FileNotFoundError( + f"Kernel at path `{repo_path}` does not have one of build variants: {specific_variant}, {universal_variant}" + ) if variant_locks is not None: variant_lock = variant_locks.get(variant) @@ -179,12 +188,13 @@ def _load_kernel_from_path( raise ValueError(f"No lock found for build variant: {variant}") validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash) - module_init_path = variant_path / package_name / "__init__.py" + module_init_path = variant_path / "__init__.py" + if not os.path.exists(module_init_path): + # Compatibility with older kernels. + module_init_path = variant_path / package_name / "__init__.py" if not os.path.exists(module_init_path): - raise FileNotFoundError( - f"Kernel at path `{repo_path}` does not have build: {variant}" - ) + raise FileNotFoundError(f"No kernel module found at: `{variant_path}`") return package_name, variant_path @@ -258,10 +268,10 @@ def get_kernel( ``` """ revision = select_revision_or_version(repo_id, revision, version) - package_name, package_path = install_kernel( + package_name, variant_path = install_kernel( repo_id, revision=revision, user_agent=user_agent ) - return import_from_path(package_name, package_path / package_name / "__init__.py") + return _import_from_path(package_name, variant_path) def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: @@ -284,15 +294,15 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: for base_path in [repo_path, repo_path / "build"]: # Prefer the universal variant if it exists. for v in [universal_variant, variant]: - package_path = base_path / v / package_name / "__init__.py" - if package_path.exists(): - return import_from_path(package_name, package_path) + variant_path = base_path / v + if variant_path.exists(): + return _import_from_path(package_name, variant_path) # If we didn't find the package in the repo we may have a explicit # package path. - package_path = repo_path / package_name / "__init__.py" - if package_path.exists(): - return import_from_path(package_name, package_path) + variant_path = repo_path + if variant_path.exists(): + return _import_from_path(package_name, variant_path) raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}") @@ -321,18 +331,16 @@ def has_kernel( variant = build_variant() universal_variant = universal_build_variant() - if file_exists( - repo_id, - revision=revision, - filename=f"build/{universal_variant}/{package_name}/__init__.py", - ): - return True - - return file_exists( - repo_id, - revision=revision, - filename=f"build/{variant}/{package_name}/__init__.py", - ) + for variant in [universal_variant, variant]: + for init_file in ["__init__.py", f"{package_name}/__init__.py"]: + if file_exists( + repo_id, + revision=revision, + filename=f"build/{variant}/{init_file}", + ): + return True + + return False def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: @@ -376,21 +384,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: ) ) - variant_path = repo_path / "build" / variant - universal_variant_path = repo_path / "build" / universal_variant - if not variant_path.exists() and universal_variant_path.exists(): - # Fall back to universal variant. - variant = universal_variant - variant_path = universal_variant_path - - module_init_path = variant_path / package_name / "__init__.py" - if not os.path.exists(module_init_path): + try: + package_name, variant_path = _find_kernel_in_repo_path( + repo_path, package_name, variant_locks=None + ) + return _import_from_path(package_name, variant_path) + except FileNotFoundError: raise FileNotFoundError( f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download `" ) - return import_from_path(package_name, variant_path / package_name / "__init__.py") - def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: """ @@ -410,11 +413,11 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp if locked_sha is None: raise ValueError(f"Kernel `{repo_id}` is not locked") - package_name, package_path = install_kernel( + package_name, variant_path = install_kernel( repo_id, locked_sha, local_files_only=local_files_only ) - return import_from_path(package_name, package_path / package_name / "__init__.py") + return _import_from_path(package_name, variant_path) def _get_caller_locked_kernel(repo_id: str) -> Optional[str]: diff --git a/tests/test_basic.py b/tests/test_basic.py index f77f5c19..7039ceb2 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.nn.functional as F from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel @@ -72,37 +73,33 @@ def test_local_kernel(local_kernel, device): assert torch.allclose(y, expected) -@pytest.mark.cuda_only -def test_local_kernel_path_types(local_kernel_path, device): - package_name, path = local_kernel_path +@pytest.mark.parametrize( + "repo_revision", + [ + ("kernels-test/flattened-build", "pre-flattening"), + ("kernels-test/flattened-build", "main"), + ("kernels-test/flattened-build", "without-compat-module"), + ], +) +def test_local_kernel_path_types(repo_revision, device): + repo_id, revision = repo_revision + package_name, path = install_kernel(repo_id, revision) # Top-level repo path # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071 kernel = get_local_kernel(path.parent.parent, package_name) - x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) - y = torch.empty_like(x) - - kernel.gelu_fast(y, x) - expected = torch.tensor( - [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], - device=device, - dtype=torch.float16, - ) - assert torch.allclose(y, expected) + x = torch.arange(0, 32, dtype=torch.float16, device=device).view(2, 16) + torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x)) # Build directory path # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build kernel = get_local_kernel(path.parent.parent / "build", package_name) - y = torch.empty_like(x) - kernel.gelu_fast(y, x) - assert torch.allclose(y, expected) + torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x)) # Explicit package path # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux kernel = get_local_kernel(path, package_name) - y = torch.empty_like(x) - kernel.gelu_fast(y, x) - assert torch.allclose(y, expected) + torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x)) @pytest.mark.darwin_only @@ -123,6 +120,8 @@ def test_relu_metal(metal_kernel, dtype): # support/test against this version). ("kernels-test/only-torch-2.4", "main", False), ("google-bert/bert-base-uncased", "87565a309", False), + ("kernels-test/flattened-build", "main", True), + ("kernels-test/flattened-build", "without-compat-module", True), ], ) def test_has_kernel(kernel_exists): @@ -162,3 +161,24 @@ def test_universal_kernel(universal_kernel): out_check = out_check.to(torch.float16) torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1) + + +@pytest.mark.parametrize( + "repo_revision", + [ + ("kernels-test/flattened-build", "pre-flattening"), + ("kernels-test/flattened-build", "main"), + ("kernels-test/flattened-build", "without-compat-module"), + ], +) +def test_flattened_build(repo_revision, device): + repo_id, revision = repo_revision + kernel = get_kernel(repo_id, revision=revision) + + x = torch.arange(0, 32, dtype=torch.float16, device=device).view(2, 16) + torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x)) + + +def silu_and_mul_torch(x: torch.Tensor): + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:]