Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/numeric:int128",
Expand Down
40 changes: 7 additions & 33 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ limitations under the License.

#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/numeric/int128.h"
Expand Down Expand Up @@ -178,17 +177,6 @@ absl::StatusOr<hipFunction_t> GetModuleFunction(Context* context,
const char* kernel_name) {
ScopedActivateContext activated(context);
CHECK(module != nullptr && kernel_name != nullptr);
// Check for pre-existing HIP errors before the call. On ROCm 7+
// the per-thread error state is sticky: successful HIP calls do
// not clear it, so a stale error from a prior operation would
// produce confusing diagnostics if we proceeded.
hipError_t pre_err = ::hipPeekAtLastError();
if (pre_err != hipSuccess) {
return absl::InternalError(
absl::StrCat("There was a HIP error before calling "
"hipModuleGetFunction for kernel '",
kernel_name, "': ", ToString(pre_err)));
}
hipFunction_t function;
TF_RETURN_IF_ERROR(
ToStatus(wrap::hipModuleGetFunction(&function, module, kernel_name),
Expand All @@ -210,16 +198,6 @@ absl::Status GetModuleSymbol(Context* context, hipModule_t module,
absl::StrCat("Failed to get symbol '", symbol_name, "'"));
}

// Derives a ModuleHandle from the contents of an HSACO image instead of from
// the address the image happens to live at.
//
// Keying on the bytes rather than the pointer prevents a stale cache hit when
// one HSACO buffer is freed and a different one is later allocated at the
// same address.
//
// `hsaco` points at the image bytes and `size` is their length in bytes.
ModuleHandle HsacoModuleHandle(const char* hsaco, size_t size) {
  const absl::string_view hsaco_bytes(hsaco, size);
  // Force the low bit on so the value can never be 0: ModuleHandle treats
  // nullptr as invalid.
  const auto nonzero_digest = absl::HashOf(hsaco_bytes) | 1;
  return ModuleHandle{reinterpret_cast<const void*>(nonzero_digest)};
}

// Unloads module from the current context via hipModuleUnload.
void UnloadRocmModule(Context* context, hipModule_t module) {
ScopedActivateContext activated(context);
Expand Down Expand Up @@ -682,12 +660,9 @@ absl::StatusOr<std::unique_ptr<Kernel>> RocmExecutor::LoadKernel(
const auto& cubin = spec.cuda_cubin_in_memory()->cubin_bytes;
const char* hsaco = reinterpret_cast<const char*>(cubin.data());
absl::MutexLock lock{in_memory_modules_mu_};
ModuleHandle module_handle = HsacoModuleHandle(hsaco, cubin.size());
hipModule_t& module = in_memory_modules_[module_handle];

if (module == nullptr) {
TF_ASSIGN_OR_RETURN(module, LoadHsaco(&rocm_context_, hsaco));
}
TF_ASSIGN_OR_RETURN(ModuleHandle module_handle,
LoadModuleFromHsaco(hsaco));
hipModule_t module = gpu_binary_to_module_.at(module_handle).first;
kernel_to_gpu_binary_[rocm_kernel.get()] = module_handle;

VLOG(2) << "getting function " << kernel_name << " from module " << module;
Expand Down Expand Up @@ -757,17 +732,16 @@ absl::StatusOr<ModuleHandle> RocmExecutor::LoadModule(
// TODO(ROCm): Need generic term instead of cubin/cuda/ptx
if (spec.has_cuda_cubin_in_memory()) {
absl::MutexLock lock{in_memory_modules_mu_};
const auto& cubin = spec.cuda_cubin_in_memory();
return LoadModuleFromHsaco(reinterpret_cast<const char*>(cubin.data()),
cubin.size());
return LoadModuleFromHsaco(
reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()));
} else {
return absl::InternalError("No HASCO binary found");
}
}

absl::StatusOr<ModuleHandle> RocmExecutor::LoadModuleFromHsaco(
const char* hsaco, size_t size) {
ModuleHandle module_handle = HsacoModuleHandle(hsaco, size);
const char* hsaco) {
ModuleHandle module_handle{hsaco};
uint64_t module_refcount;
hipModule_t module;
std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle];
Expand Down
3 changes: 1 addition & 2 deletions xla/stream_executor/rocm/rocm_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ class RocmExecutor : public GpuExecutor {
absl::Status InitBlas();

// Loads a module in HSACO format.
absl::StatusOr<ModuleHandle> LoadModuleFromHsaco(const char* hsaco,
size_t size)
absl::StatusOr<ModuleHandle> LoadModuleFromHsaco(const char* hsaco)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);

bool UnloadGpuBinary(ModuleHandle module_handle)
Expand Down
Loading