[ROCm] Cherry-pick: Fix HSACO module cache using pointer-based key causing stale lookups#780
Conversation
…ookups - The in_memory_modules_ cache in RocmExecutor used the raw HSACO data pointer as the cache key. When a CustomKernelThunk's owned HSACO buffer was freed and a new buffer was allocated at the same address, the cache returned a stale module loaded from different binary content, causing hipModuleGetFunction to fail with hipErrorNotFound. - Replace pointer-based ModuleHandle key with a content hash of the HSACO bytes (absl::HashOf). Same content still hits the cache; different content at a reused address correctly misses. - Add hipPeekAtLastError guard before hipModuleGetFunction to surface pre-existing sticky errors early with clear diagnostics instead of producing confusing failures.
d60dad9 to
85555f4
Compare
There was a problem hiding this comment.
I'm not sure I fully got this. @pemeliya @nurmukhametov did you aware of this one?
There was a problem hiding this comment.
IIUC, this is because of openxla@e129675
do we have this offensive commit in 0.9.2?
|
yes that commit exposes this problem. but commit itself doesn't introduce the bug. |
my question is, does xla-0.9.2 has this commit? |
|
No Chao, openxla@e129675 is not in 9.2 xla |
| // the per-thread error state is sticky: successful HIP calls do | ||
| // not clear it, so a stale error from a prior operation would | ||
| // produce confusing diagnostics if we proceeded. | ||
| hipError_t pre_err = ::hipPeekAtLastError(); |
There was a problem hiding this comment.
Oh I am not sure we shall just clean up sticky error bit here. It may hide problems in other libraries. Like the recent one about rocprofiler collector using wrong device ID.
There was a problem hiding this comment.
Hi Pavel. intention here is to not to clear sticky error but get a peek at it and if there is an error dont proceed and error out here. if we simply proceed with this error we run into complex run to run variation issue under parallel loads. so I added more this check more of a guard.
There was a problem hiding this comment.
ah ok, I see, I have confused it wit hipGetLastError()
| // Compute a content-based ModuleHandle from HSACO bytes. | ||
| // Using a content hash instead of the raw data pointer avoids stale cache | ||
| // entries when an HSACO buffer is freed and a new one is allocated at the | ||
| // same address (pointer-reuse cache collision). |
There was a problem hiding this comment.
I am wondering how it happens that we have stale cache entries? I saw we have several flat_hash_maps in rocm_executor: in_memory_modules_, kernel_to_gpu_binary_ and gpu_binary_to_module_.
Which one was returning a stale cache entry? I see at least that UnloadGpuBinary() shall cleanup gpu_binary_to_module_ and in_memory_modules_. Though, the logic about erasing in_memory_modules_ is not easy to understand at first glance..
I mean, from my understanding, this shall not happen since GpuExecutable maintains a list of module handles which shall be automatically unloaded at the end.
There was a problem hiding this comment.
yes there is an issue with unload right now and it doesnt take all cases into account. I'm working on a fix , its not ready yet, might take couple more days.
But I agree with that fix we may not need hashing at all. I just discovered this yesterday we don't have to merge this PR now. I'll work on improved one.
but can we please not revert the upstream PR , without this fix we have very annoying jax bug in pytest run.
summary : this fix works but there is more to it, and I have found actual problem with unload an erase routine will work on it and come up with better fix.
cc: @i-chaochen
There was a problem hiding this comment.
Yes, I think we need to have a proper fix for that.. I mean, here we use the result of absl::HashOf() as a unique key to identify hipModule_t object. Though, absl::Hash is very fast but prone to collisions (not like SHA256), so this could be a subtle bug which only shows up under a heavy load - when two different hipModules suddenly get the same hash key..
Summary
Cherry-pick of openxla#40419 into
rocm-jaxlib-v0.9.2release branch.Fixes flaky sort-related test failures caused by HSACO module cache returning stale modules when a freed buffer's address is reused.
ModuleHandlekey with content hash of HSACO bytes (absl::HashOf)hipPeekAtLastErrorguard beforehipModuleGetFunctionfor clearer diagnostics