Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 49 additions & 46 deletions src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ def universal_build_variant() -> str:
return "torch-universal"


def import_from_path(module_name: str, file_path: Path) -> ModuleType:
def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
file_path = variant_path / "__init__.py"
if not file_path.exists():
file_path = variant_path / module_name / "__init__.py"

# We cannot use the module name as-is, after adding it to `sys.modules`,
# it would also be used for other imports. So, we make a module name that
# depends on the path for it to be unique using the hex-encoded hash of
Expand Down Expand Up @@ -149,42 +153,48 @@ def install_kernel(
)

try:
return _load_kernel_from_path(repo_path, package_name, variant_locks)
return _find_kernel_in_repo_path(repo_path, package_name, variant_locks)
except FileNotFoundError:
# Redo with more specific error message.
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
f"Cannot install kernel from repo {repo_id} (revision: {revision})"
)


def _load_kernel_from_path(
def _find_kernel_in_repo_path(
repo_path: Path,
package_name: str,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
variant = build_variant()
specific_variant = build_variant()
universal_variant = universal_build_variant()

variant_path = repo_path / "build" / variant
specific_variant_path = repo_path / "build" / specific_variant
universal_variant_path = repo_path / "build" / universal_variant

if not variant_path.exists() and universal_variant_path.exists():
# Fall back to universal variant.
if specific_variant_path.exists():
variant = specific_variant
variant_path = specific_variant_path
elif universal_variant_path.exists():
variant = universal_variant
variant_path = universal_variant_path
else:
raise FileNotFoundError(
f"Kernel at path `{repo_path}` does not have one of build variants: {specific_variant}, {universal_variant}"
)

if variant_locks is not None:
variant_lock = variant_locks.get(variant)
if variant_lock is None:
raise ValueError(f"No lock found for build variant: {variant}")
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)

module_init_path = variant_path / package_name / "__init__.py"
module_init_path = variant_path / "__init__.py"
if not os.path.exists(module_init_path):
# Compatibility with older kernels.
module_init_path = variant_path / package_name / "__init__.py"

if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel at path `{repo_path}` does not have build: {variant}"
)
raise FileNotFoundError(f"No kernel module found at: `{variant_path}`")

return package_name, variant_path

Expand Down Expand Up @@ -258,10 +268,10 @@ def get_kernel(
```
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name, package_path = install_kernel(
package_name, variant_path = install_kernel(
repo_id, revision=revision, user_agent=user_agent
)
return import_from_path(package_name, package_path / package_name / "__init__.py")
return _import_from_path(package_name, variant_path)


def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
Expand All @@ -284,15 +294,15 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
for base_path in [repo_path, repo_path / "build"]:
# Prefer the universal variant if it exists.
for v in [universal_variant, variant]:
package_path = base_path / v / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
variant_path = base_path / v
if variant_path.exists():
return _import_from_path(package_name, variant_path)

# If we didn't find the package in the repo we may have an explicit
# package path.
package_path = repo_path / package_name / "__init__.py"
if package_path.exists():
return import_from_path(package_name, package_path)
variant_path = repo_path
if variant_path.exists():
return _import_from_path(package_name, variant_path)

raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")

Expand Down Expand Up @@ -321,18 +331,16 @@ def has_kernel(
variant = build_variant()
universal_variant = universal_build_variant()

if file_exists(
repo_id,
revision=revision,
filename=f"build/{universal_variant}/{package_name}/__init__.py",
):
return True

return file_exists(
repo_id,
revision=revision,
filename=f"build/{variant}/{package_name}/__init__.py",
)
for variant in [universal_variant, variant]:
for init_file in ["__init__.py", f"{package_name}/__init__.py"]:
if file_exists(
repo_id,
revision=revision,
filename=f"build/{variant}/{init_file}",
):
return True

return False


def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
Expand Down Expand Up @@ -376,21 +384,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
)
)

variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
if not variant_path.exists() and universal_variant_path.exists():
# Fall back to universal variant.
variant = universal_variant
variant_path = universal_variant_path

module_init_path = variant_path / package_name / "__init__.py"
if not os.path.exists(module_init_path):
try:
package_name, variant_path = _find_kernel_in_repo_path(
repo_path, package_name, variant_locks=None
)
return _import_from_path(package_name, variant_path)
except FileNotFoundError:
raise FileNotFoundError(
f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
)

return import_from_path(package_name, variant_path / package_name / "__init__.py")


def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
"""
Expand All @@ -410,11 +413,11 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
if locked_sha is None:
raise ValueError(f"Kernel `{repo_id}` is not locked")

package_name, package_path = install_kernel(
package_name, variant_path = install_kernel(
repo_id, locked_sha, local_files_only=local_files_only
)

return import_from_path(package_name, package_path / package_name / "__init__.py")
return _import_from_path(package_name, variant_path)


def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
Expand Down
58 changes: 39 additions & 19 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
import torch.nn.functional as F

from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel

Expand Down Expand Up @@ -72,37 +73,33 @@ def test_local_kernel(local_kernel, device):
assert torch.allclose(y, expected)


@pytest.mark.cuda_only
def test_local_kernel_path_types(local_kernel_path, device):
package_name, path = local_kernel_path
@pytest.mark.parametrize(
"repo_revision",
[
("kernels-test/flattened-build", "pre-flattening"),
("kernels-test/flattened-build", "main"),
("kernels-test/flattened-build", "without-compat-module"),
],
)
def test_local_kernel_path_types(repo_revision, device):
repo_id, revision = repo_revision
package_name, path = install_kernel(repo_id, revision)

# Top-level repo path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
kernel = get_local_kernel(path.parent.parent, package_name)
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)

kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
x = torch.arange(0, 32, dtype=torch.float16, device=device).view(2, 16)
torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x))

# Build directory path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
kernel = get_local_kernel(path.parent.parent / "build", package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x))

# Explicit package path
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
kernel = get_local_kernel(path, package_name)
y = torch.empty_like(x)
kernel.gelu_fast(y, x)
assert torch.allclose(y, expected)
torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x))


@pytest.mark.darwin_only
Expand All @@ -123,6 +120,8 @@ def test_relu_metal(metal_kernel, dtype):
# support/test against this version).
("kernels-test/only-torch-2.4", "main", False),
("google-bert/bert-base-uncased", "87565a309", False),
("kernels-test/flattened-build", "main", True),
("kernels-test/flattened-build", "without-compat-module", True),
],
)
def test_has_kernel(kernel_exists):
Expand Down Expand Up @@ -162,3 +161,24 @@ def test_universal_kernel(universal_kernel):
out_check = out_check.to(torch.float16)

torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)


@pytest.mark.parametrize(
    "repo_revision",
    [
        ("kernels-test/flattened-build", rev)
        for rev in ("pre-flattening", "main", "without-compat-module")
    ],
)
def test_flattened_build(repo_revision, device):
    """Load the kernel at each listed revision and check ``silu_and_mul``.

    The kernel's output is compared against the pure-PyTorch reference
    ``silu_and_mul_torch`` on a small float16 input.
    """
    repo_id, revision = repo_revision
    kernel = get_kernel(repo_id, revision=revision)

    inputs = torch.arange(0, 32, dtype=torch.float16, device=device).view(2, 16)
    actual = kernel.silu_and_mul(inputs)
    reference = silu_and_mul_torch(inputs)
    torch.testing.assert_close(actual, reference)


def silu_and_mul_torch(x: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch reference for the ``silu_and_mul`` operation.

    Splits the last dimension of ``x`` into two equal halves and returns
    ``F.silu(first_half) * second_half``.

    Args:
        x: Input tensor whose last dimension has an even size.

    Returns:
        A tensor shaped like ``x`` except that the last dimension is halved.

    Raises:
        ValueError: If the last dimension of ``x`` has an odd size.
    """
    width = x.shape[-1]
    if width % 2 != 0:
        # Odd widths give halves of unequal size. For widths 1 and 3 the
        # product would broadcast silently instead of failing, so reject
        # them up front with a clear message.
        raise ValueError(
            f"silu_and_mul_torch expects an even last dimension, got {width}"
        )
    d = width // 2
    return F.silu(x[..., :d]) * x[..., d:]
Loading