From 172a8c5f2e964acf23031e9e29845aaf704170d4 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 1 May 2025 16:02:51 -0500 Subject: [PATCH] execute_distribute_partition: allow recording profiling events --- pytato/distributed/execute.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pytato/distributed/execute.py b/pytato/distributed/execute.py index 3be65c76b..dd3ee9757 100644 --- a/pytato/distributed/execute.py +++ b/pytato/distributed/execute.py @@ -110,7 +110,8 @@ def execute_distributed_partition( queue: Any, mpi_communicator: Any, *, allocator: Any | None = None, - input_args: dict[str, Any] | None = None) -> dict[str, Any]: + input_args: dict[str, Any] | None = None, + actx: Any | None = None) -> dict[str, Any]: if input_args is None: input_args = {} @@ -168,10 +169,18 @@ def _get_partition_input_name_refcount(partition: DistributedGraphPartition) \ def exec_ready_part(part: DistributedGraphPart) -> None: inputs = {k: context[k] for k in part.all_input_names()} - _evt, result_dict = prg_per_partition[part.pid](queue, + if actx and actx.profile_kernels: + import pyopencl as cl + start_evt = cl.enqueue_marker(queue) + + evt, result_dict = prg_per_partition[part.pid](queue, allocator=allocator, **inputs) + if actx and actx.profile_kernels: + name = next(iter(prg_per_partition[part.pid].program.entrypoints)) + actx._add_profiling_events(start_evt, evt, name) + context.update(result_dict) for name, send_nodes in part.name_to_send_nodes.items():