diff --git a/pytato/distributed/execute.py b/pytato/distributed/execute.py index 3be65c76b..dd3ee9757 100644 --- a/pytato/distributed/execute.py +++ b/pytato/distributed/execute.py @@ -110,7 +110,8 @@ def execute_distributed_partition( queue: Any, mpi_communicator: Any, *, allocator: Any | None = None, - input_args: dict[str, Any] | None = None) -> dict[str, Any]: + input_args: dict[str, Any] | None = None, + actx: Any | None = None) -> dict[str, Any]: if input_args is None: input_args = {} @@ -168,10 +169,18 @@ def _get_partition_input_name_refcount(partition: DistributedGraphPartition) \ def exec_ready_part(part: DistributedGraphPart) -> None: inputs = {k: context[k] for k in part.all_input_names()} - _evt, result_dict = prg_per_partition[part.pid](queue, + if actx and actx.profile_kernels: + import pyopencl as cl + start_evt = cl.enqueue_marker(queue) + + evt, result_dict = prg_per_partition[part.pid](queue, allocator=allocator, **inputs) + if actx and actx.profile_kernels: + name = next(iter(prg_per_partition[part.pid].program.entrypoints)) + actx._add_profiling_events(start_evt, evt, name) + context.update(result_dict) for name, send_nodes in part.name_to_send_nodes.items():