Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 129 additions & 49 deletions tools/ais-check/ais-check
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ If components are missing the program exits with a non-zero exit code.
import argparse
import ctypes
import ctypes.util
import glob
import gzip
import os
import sys

# Global mapping of HIP runtime library paths to AIS support flags
hip_libraries = {}


def kernel_supports_p2pdma():
"""
Expand Down Expand Up @@ -46,67 +50,134 @@ def kernel_supports_p2pdma():
return False


def find_hip_runtimes():
"""
Populate the global mapping of HIP runtime library paths to AIS
support flags by looking in the usual places.
"""

# NOTE: CodeQL will be unhappy if you are not careful about paths
# in this function

candidates = []

# 1. Respect runtime linker paths
for p in os.environ.get("LD_LIBRARY_PATH", "").split(":"):
if p:
# Clean up the path by removing `..`, etc. and getting
# an absolute path.
safe_p = os.path.abspath(os.path.normpath(p))

candidates.append(os.path.join(safe_p, "libamdhip64.so"))

# 2. Environment variables commonly set by ROCm or modules
for var in ("ROCM_HOME", "ROCM_PATH", "HIP_PATH"):
base = os.environ.get(var)
Comment on lines +64 to +75
if base:
# Also clean up this path
safe_base = os.path.abspath(os.path.normpath(base))
candidates += [
os.path.join(safe_base, "lib", "libamdhip64.so"),
os.path.join(safe_base, "lib64", "libamdhip64.so"),
]

# 3. Standard ROCm install paths
candidates += [
"/opt/rocm/lib/libamdhip64.so",
"/opt/rocm/lib64/libamdhip64.so",
]

# 4. Versioned installs (/opt/rocm-5.x, etc.)
for d in glob.glob("/opt/rocm*/lib*/libamdhip64.so"):
candidates.append(d)
Comment on lines +84 to +92
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

find_hip_runtimes() only looks for an unversioned libamdhip64.so in a small set of locations (/opt/rocm*, LD_LIBRARY_PATH, ROCM_* / HIP_PATH). This drops the previous ctypes.util.find_library('amdhip64') behavior and can miss systems where the HIP runtime is discoverable via the default loader paths (e.g., /usr/lib*) or where only a versioned SONAME exists (e.g., libamdhip64.so.6 without the unversioned symlink). Consider adding a fallback candidate from ctypes.util.find_library('amdhip64') and/or globbing for libamdhip64.so.* in the searched directories to avoid false negatives.

Copilot uses AI. Check for mistakes.

# Drop any paths that don't exist
existing_paths = [
os.path.abspath(os.path.normpath(path))
for path in candidates
if os.path.exists(os.path.abspath(os.path.normpath(path)))
]

# Populate the global dictionary of paths
#
# Tell pylint to be quiet since returning the dictionary
# would result in uglier downstream code
global hip_libraries # pylint: disable=W0603
hip_libraries = dict.fromkeys(existing_paths, False)


def hip_runtime_supports_ais():
"""
Check if hipAmdFileRead and hipAmdFileWrite are available in HIP
"""
hip_path = ctypes.util.find_library("amdhip64")
if hip_path is None:
return False
find_hip_runtimes()

hip = ctypes.CDLL(hip_path)
# Check for AIS functions in the list of found HIP libraries
for hip_path in hip_libraries:

hipError_t = ctypes.c_int
hipDriverProcAddressQueryResult = ctypes.c_int
try:
hip = ctypes.CDLL(hip_path)
except OSError:
continue

hip.hipRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
hip.hipRuntimeGetVersion.restype = hipError_t
hipError_t = ctypes.c_int
hipDriverProcAddressQueryResult = ctypes.c_int

hip.hipGetProcAddress.argtypes = [
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_void_p),
ctypes.c_int,
ctypes.c_uint64,
ctypes.POINTER(hipDriverProcAddressQueryResult),
]
hip.hipGetProcAddress.restype = hipError_t
hip.hipRuntimeGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
hip.hipRuntimeGetVersion.restype = hipError_t

hip.hipGetErrorString.argtypes = [hipError_t]
hip.hipGetErrorString.restype = ctypes.c_char_p
hip.hipGetProcAddress.argtypes = [
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_void_p),
ctypes.c_int,
ctypes.c_uint64,
ctypes.POINTER(hipDriverProcAddressQueryResult),
]
hip.hipGetProcAddress.restype = hipError_t

version = ctypes.c_int()
err = hip.hipRuntimeGetVersion(ctypes.byref(version))
if err != 0:
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipRuntimeGetVersion failed with err code {err} ({err_str})",
file=sys.stderr,
)
return False

for symbol in [b"hipAmdFileWrite", b"hipAmdFileRead"]:
func_ptr = ctypes.c_void_p()
symbol_status = hipDriverProcAddressQueryResult()
err = hip.hipGetProcAddress(
symbol,
ctypes.byref(func_ptr),
version.value,
0,
ctypes.byref(symbol_status),
)
hip.hipGetErrorString.argtypes = [hipError_t]
hip.hipGetErrorString.restype = ctypes.c_char_p
Comment on lines +123 to +139
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When probing each libamdhip64.so, the code assumes hipRuntimeGetVersion, hipGetProcAddress, and hipGetErrorString exist. If a candidate library loads but is missing one of these symbols, ctypes will raise AttributeError and the script will crash instead of treating that path as “not a HIP runtime”. Wrap the symbol lookups/argtypes setup in a try/except AttributeError (or use getattr checks) and continue to the next candidate.

Copilot uses AI. Check for mistakes.

version = ctypes.c_int()
err = hip.hipRuntimeGetVersion(ctypes.byref(version))
if err != 0:
if symbol_status.value != 1:
symbol = symbol.decode("utf-8")
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipGetProcAddress({symbol}) failed with err code"
f" {err} ({err_str}) and symbolStatus"
f" {symbol_status.value}",
file=sys.stderr,
)
return False
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipRuntimeGetVersion failed with err code {err} ({err_str})",
file=sys.stderr,
)
continue

return True
# Track whether all required AIS symbols are available in this library
supported = True

for symbol in [b"hipAmdFileWrite", b"hipAmdFileRead"]:
func_ptr = ctypes.c_void_p()
symbol_status = hipDriverProcAddressQueryResult()
err = hip.hipGetProcAddress(
symbol,
ctypes.byref(func_ptr),
version.value,
0,
ctypes.byref(symbol_status),
)
if err != 0:
if symbol_status.value != 1:
symbol = symbol.decode("utf-8")
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipGetProcAddress({symbol}) failed with err code"
Comment on lines +164 to +169
f" {err} ({err_str}) and symbolStatus"
f" {symbol_status.value}",
file=sys.stderr,
)
Comment on lines +164 to +173
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AIS symbol check only treats err != 0 as a failure. In this repo’s C++ HIP wrapper (src/amd_detail/hip.cpp), hipGetProcAddress can succeed yet return a null pointer for an unavailable symbol (the wrapper checks for nullptr). In that case, this script would incorrectly mark the runtime as AIS-supported. After hipGetProcAddress, also verify the returned pointer (e.g., func_ptr.value) and/or symbol_status indicates the symbol is found before keeping supported = True.

Suggested change
if err != 0:
if symbol_status.value != 1:
symbol = symbol.decode("utf-8")
err_str = hip.hipGetErrorString(err).decode("utf-8")
print(
f"hipGetProcAddress({symbol}) failed with err code"
f" {err} ({err_str}) and symbolStatus"
f" {symbol_status.value}",
file=sys.stderr,
)
if err != 0 or func_ptr.value is None or symbol_status.value != 1:
symbol_name = symbol.decode("utf-8")
if err != 0:
err_str = hip.hipGetErrorString(err).decode("utf-8")
else:
err_str = "success"
print(
f"hipGetProcAddress({symbol_name}) failed with err code"
f" {err} ({err_str}), symbolStatus {symbol_status.value},"
f" func_ptr {func_ptr.value}",
file=sys.stderr,
)

Copilot uses AI. Check for mistakes.
supported = False
break

if supported:
hip_libraries[hip_path] = True

return any(hip_libraries.values())


def amdgpu_supports_ais():
Expand Down Expand Up @@ -154,8 +225,17 @@ def main():
u = os.uname()
print()
print(u.sysname, u.nodename, u.release, u.version, u.machine)

print()
print("Found these HIP libraries (some may be redundant symlinks):")
for lib, support in hip_libraries.items():
if support:
pretty_supported = "supported"
else:
pretty_supported = "NOT supported"
print(f"\t{lib} (AIS {pretty_supported})")

print()
print("AIS support in:")
for name, supported in component_support:
print(f"\t{name:<24}: {supported}")
Expand Down
Loading