Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions examples/scripts/code_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,10 +836,6 @@ def _compile_one_kernel(kernel):
logger.debug(f"Tensor order: {list(tensors.keys())}")
logger.debug(f"orch_args count: {len(orch_args)}")

# Create and initialize runtime (including kernel registration)
logger.info("=== Initializing Runtime ===")
runtime = Runtime()

# Build environment for runtime initialization
run_env = _kernel_config_runtime_env(self._kernel_config, self.kernels_dir)
if run_env:
Expand All @@ -856,27 +852,32 @@ def _compile_one_kernel(kernel):

initial_outputs = {k: v.clone() for k, v in outputs.items()}

runtime = Runtime()

with _temporary_env(run_env):
runtime.initialize(
orch_so_binary,
self.orchestration["function_name"],
orch_args,
kernel_binaries=kernel_binaries,
)

for round_idx in range(self.repeat_rounds):
if self.repeat_rounds > 1:
logger.info(f"--- Round {round_idx + 1}/{self.repeat_rounds} ---")

for k, v in initial_outputs.items():
outputs[k].copy_(v)

runtime = Runtime()
t_round_start = time.perf_counter()

# Enable profiling if requested (only first round)
if self.enable_profiling and round_idx == 0:
runtime.enable_profiling(True)
logger.info("Profiling enabled")

with _temporary_env(run_env):
runtime.initialize(
orch_so_binary,
self.orchestration["function_name"],
orch_args,
kernel_binaries=kernel_binaries,
)
for k, v in initial_outputs.items():
outputs[k].copy_(v)

runtime.initialize_round(
orch_args,
)

launch_runtime(
runtime,
Expand All @@ -888,10 +889,14 @@ def _compile_one_kernel(kernel):
orch_thread_num=self.orch_thread_num,
)

runtime.finalize()
runtime.finalize_round()
if not self.skip_golden:
self._compare_with_golden(outputs, golden)

t_round_end = time.perf_counter()
logger.info(f"HOST_TIMING round={round_idx} total_us={(t_round_end - t_round_start) * 1e6:.1f}")

runtime.finalize()
logger.info(f"=== Case {case_idx + 1}/{total_cases} Passed ===")

logger.info("=" * 60)
Expand Down
88 changes: 88 additions & 0 deletions python/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@ def _setup_functions(self):
self.lib.finalize_runtime.argtypes = [c_void_p]
self.lib.finalize_runtime.restype = c_int

# init_runtime_round - per-round data copy (INPUT+INOUT) to device
self.lib.init_runtime_round.argtypes = [
c_void_p, # runtime
POINTER(TaskArgC), # orch_args
c_int, # orch_args_count
POINTER(c_int), # arg_types
]
self.lib.init_runtime_round.restype = c_int

# finalize_runtime_round - copy results back without freeing resources
self.lib.finalize_runtime_round.argtypes = [c_void_p]
self.lib.finalize_runtime_round.restype = c_int

# Note: register_kernel has been internalized into init_runtime
# Kernel binaries are now passed directly to init_runtime()

Expand Down Expand Up @@ -232,6 +245,32 @@ def __init__(self, lib: CDLL):
size = lib.get_runtime_size()
self._buffer = ctypes.create_string_buffer(size)
self._handle = ctypes.cast(self._buffer, c_void_p)
self._initialized = False

def _convert_orch_params(self, orch_args, arg_types):
"""Convert orch_args and arg_types to ctypes arrays."""
orch_args = orch_args or []
orch_args_count = len(orch_args)

# Accept either a nanobind TaskArgArray (from task_interface) or a
# plain list of TaskArgC structs.
from _task_interface import TaskArgArray as _NbTaskArgArray

if isinstance(orch_args, _NbTaskArgArray):
orch_args_array = cast(orch_args.ctypes_ptr(), POINTER(TaskArgC)) if orch_args_count > 0 else None
# Prevent GC of the nanobind array while the ctypes pointer is live
self._nb_args_ref = orch_args
elif orch_args_count > 0:
orch_args_array = (TaskArgC * orch_args_count)(*orch_args)
else:
orch_args_array = None

if arg_types is not None and len(arg_types) > 0:
arg_types_array = (c_int * len(arg_types))(*arg_types)
else:
arg_types_array = None

return orch_args_array, orch_args_count, arg_types_array

def initialize(
self,
Expand Down Expand Up @@ -323,6 +362,7 @@ def initialize(
)
if rc != 0:
raise RuntimeError(f"init_runtime failed: {rc}")
self._initialized = True

def finalize(self) -> None:
"""
Expand All @@ -335,10 +375,58 @@ def finalize(self) -> None:
Raises:
RuntimeError: If finalization fails
"""
if not self._initialized:
return

rc = self.lib.finalize_runtime(self._handle)
if rc != 0:
raise RuntimeError(f"finalize_runtime failed: {rc}")
self._initialized = False

def initialize_round(
self,
orch_args: Optional[list] = None,
arg_types: Optional[List[int]] = None,
) -> None:
"""
Per-round initialization: copy INPUT and INOUT tensor data to device.

Uses existing device memory allocations from initialize().
Called every round (including the first) before launch_runtime().

Args:
orch_args: List of TaskArgC structs for orchestration
arg_types: Array describing each argument's type

Raises:
RuntimeError: If round initialization fails
"""
orch_args_array, orch_args_count, arg_types_array = \
self._convert_orch_params(orch_args, arg_types)

rc = self.lib.init_runtime_round(
self._handle,
orch_args_array,
orch_args_count,
arg_types_array,
)
if rc != 0:
raise RuntimeError(f"init_runtime_round failed: {rc}")

def finalize_round(self) -> None:
"""
Round-level finalize: copy results back but keep device resources alive.

Copies output/inout tensors from device to host without freeing
device memory or kernel binaries. Use between rounds in the same case.

Raises:
RuntimeError: If round finalization fails
"""
rc = self.lib.finalize_runtime_round(self._handle)
if rc != 0:
# Not supported by this runtime, fallback to full finalize
self.finalize()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里释放掉所有资源后,其他round继续执行会造成潜在的bug


def enable_profiling(self, enabled: bool = True) -> None:
"""
Expand Down
32 changes: 32 additions & 0 deletions src/a2a3/platform/include/host/pto_runtime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,38 @@ int launch_runtime(RuntimeHandle runtime,
size_t aicore_size,
int orch_thread_num);

/**
* Per-round initialization: copy INPUT and INOUT tensor data to device.
*
* Uses existing device memory allocations from init_runtime().
* Called every round (including the first) before launch_runtime().
*
* Must be called after a successful init_runtime(). The Runtime handle
* must not have been fully finalized.
*
* @param runtime Runtime handle (previously initialized)
* @param orch_args Array of TaskArg describing orchestration arguments
* @param orch_args_count Number of orchestration arguments
* @param arg_types Array describing each argument's type (ArgType enum)
* @return 0 on success, -1 on failure
*/
int init_runtime_round(RuntimeHandle runtime,
const struct TaskArg* orch_args,
int orch_args_count,
int* arg_types);

/**
* Round-level finalize: copy results back but keep device resources alive.
*
* Copies output/inout tensors from device to host, but does NOT free
* device memory, kernel binaries, or call the Runtime destructor.
* Use this between rounds within the same case.
*
* @param runtime Runtime handle to finalize for this round
* @return 0 on success, -1 on failure
*/
int finalize_runtime_round(RuntimeHandle runtime);

/**
* Finalize and cleanup a runtime instance.
*
Expand Down
32 changes: 32 additions & 0 deletions src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ int init_runtime_impl(Runtime* runtime,
const size_t* kernel_sizes,
int kernel_count);
int validate_runtime_impl(Runtime* runtime);
int init_runtime_round_impl(Runtime* runtime,
const TaskArg* orch_args,
int orch_args_count,
int* arg_types);
int validate_runtime_round_impl(Runtime* runtime);

/* Forward declarations for device memory functions used in init_runtime */
void* device_malloc(size_t size);
Expand Down Expand Up @@ -200,6 +205,33 @@ int launch_runtime(RuntimeHandle runtime,
}
}

int init_runtime_round(RuntimeHandle runtime,
const TaskArg* orch_args,
int orch_args_count,
int* arg_types) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return init_runtime_round_impl(r, orch_args, orch_args_count, arg_types);
} catch (...) {
return -1;
}
}

int finalize_runtime_round(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return validate_runtime_round_impl(r);
} catch (...) {
return -1;
}
}

int finalize_runtime(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
Expand Down
32 changes: 32 additions & 0 deletions src/a2a3/platform/sim/host/pto_runtime_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ int init_runtime_impl(Runtime* runtime,
const size_t* kernel_sizes,
int kernel_count);
int validate_runtime_impl(Runtime* runtime);
int init_runtime_round_impl(Runtime* runtime,
const TaskArg* orch_args,
int orch_args_count,
int* arg_types);
int validate_runtime_round_impl(Runtime* runtime);

/* Forward declarations */
void* device_malloc(size_t size);
Expand Down Expand Up @@ -203,6 +208,33 @@ int launch_runtime(RuntimeHandle runtime,
}
}

int init_runtime_round(RuntimeHandle runtime,
const TaskArg* orch_args,
int orch_args_count,
int* arg_types) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return init_runtime_round_impl(r, orch_args, orch_args_count, arg_types);
} catch (...) {
return -1;
}
}

int finalize_runtime_round(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return validate_runtime_round_impl(r);
} catch (...) {
return -1;
}
}

int finalize_runtime(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
Expand Down
Loading
Loading