Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions backends/aoti/aoti_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

class COMPILE_SPEC_KEYS(Enum):
METHOD_NAME = "method_name"
SHARE_KV_CACHE_ACROSS_METHODS = "share_kv_cache_across_methods"


@experimental(
Expand Down Expand Up @@ -286,3 +287,13 @@ def method_name_from_compile_specs(
raise RuntimeError(
f"Could not find method name in compile specs: {compile_specs}"
)

@classmethod
def generate_share_kv_cache_compile_spec(cls) -> CompileSpec:
    """
    Build the CompileSpec that turns on cross-method KV cache sharing.

    Returns:
        A CompileSpec keyed by SHARE_KV_CACHE_ACROSS_METHODS whose value
        is a single truthy byte, signalling the backend runtime to share
        KV-cache constants between methods (e.g. prefill and decode).
    """
    spec_key = COMPILE_SPEC_KEYS.SHARE_KV_CACHE_ACROSS_METHODS.value
    return CompileSpec(spec_key, b"\x01")
28 changes: 24 additions & 4 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ namespace {
constexpr char kSkipCopyOutputToCpuForMethod[] =
"skip_copy_output_to_cpu_for_method";
constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
constexpr char kShareKvCacheAcrossMethods[] = "share_kv_cache_across_methods";
} // anonymous namespace

class ET_EXPERIMENTAL CudaBackend final
Expand Down Expand Up @@ -287,12 +288,17 @@ class ET_EXPERIMENTAL CudaBackend final
ArrayRef<CompileSpec> compile_specs // This will be my empty list
) const override {
std::string method_name;
bool share_kv_cache = false;
for (const CompileSpec& spec : compile_specs) {
if (std::strcmp(spec.key, "method_name") == 0) {
method_name.assign(
static_cast<const char*>(spec.value.buffer),
spec.value.nbytes); // no nullptr guarantee, so pass size
break;
} else if (std::strcmp(spec.key, kShareKvCacheAcrossMethods) == 0) {
if (spec.value.nbytes >= 1) {
share_kv_cache =
static_cast<const uint8_t*>(spec.value.buffer)[0] != 0;
}
}
}

Expand Down Expand Up @@ -416,14 +422,16 @@ class ET_EXPERIMENTAL CudaBackend final
// ---------------------------------------------------------------
// Cross-method constant sharing (e.g., KV cache between prefill/decode).
//
// Only enabled when share_kv_cache_across_methods compile spec is set.
// The first container to initialize extracts its constants (keyed by
// original FQN) and stores the AtenTensorHandle's. Subsequent containers
// with matching FQNs are updated to point to the same GPU tensors via
// UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
// the source container retains ownership).
// ---------------------------------------------------------------
if (handle->get_num_constants && handle->get_constant_name &&
handle->get_constant_original_fqn && handle->extract_constants_map &&
if (share_kv_cache && handle->get_num_constants &&
handle->get_constant_name && handle->get_constant_original_fqn &&
handle->extract_constants_map &&
handle->update_user_managed_constant_buffer_pairs) {
size_t num_constants = 0;
handle->get_num_constants(handle->container_handle, &num_constants);
Expand Down Expand Up @@ -469,6 +477,8 @@ class ET_EXPERIMENTAL CudaBackend final
Error,
"Failed to extract constants from '%s'",
method_name.c_str());
delete handle;
return Error::Internal;
}
} else {
// Subsequent container: share matching constants from the first.
Expand Down Expand Up @@ -501,14 +511,24 @@ class ET_EXPERIMENTAL CudaBackend final
Error,
"Failed to share constants into '%s'",
method_name.c_str());
delete handle;
return Error::Internal;
}
}
}
}
} else if (share_kv_cache) {
ET_LOG(
Error,
"share_kv_cache_across_methods requested but constant sharing APIs "
"not available for method '%s'",
method_name.c_str());
delete handle;
return Error::Internal;
} else {
ET_LOG(
Info,
"Constant sharing APIs not available for method '%s'",
"Constant sharing not requested for method '%s'",
method_name.c_str());
}

Expand Down
10 changes: 8 additions & 2 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,12 +659,18 @@ def _export_cuda(model, config, args):
partitioner={
"decode": [
CudaPartitioner(
[CudaBackend.generate_method_name_compile_spec("decode")]
[
CudaBackend.generate_method_name_compile_spec("decode"),
CudaBackend.generate_share_kv_cache_compile_spec(),
]
)
],
"prefill": [
CudaPartitioner(
[CudaBackend.generate_method_name_compile_spec("prefill")]
[
CudaBackend.generate_method_name_compile_spec("prefill"),
CudaBackend.generate_share_kv_cache_compile_spec(),
]
)
],
},
Expand Down
Loading