Skip to content

Commit 55a0c2c

Browse files
add add_profiling_event
1 parent 8ade259 commit 55a0c2c

2 files changed

Lines changed: 8 additions & 8 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,11 @@ def _wait_and_transfer_profile_events(self) -> None:
374374

375375
self.profile_events = []
376376

377+
def add_profiling_event(self, evt: cl._cl.Event, translation_unit: Any) -> None:
378+
"""Add a profiling event to the list of profiling events."""
379+
if self.profile_kernels:
380+
self.profile_events.append(ProfileEvent(evt, translation_unit))
381+
377382
def get_profiling_data_for_kernel(self, kernel_name: str) \
378383
-> MultiCallKernelProfile:
379384
"""Return profiling data for kernel *kernel_name*."""
@@ -660,8 +665,7 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
660665
allocator=self.allocator,
661666
**bound_arguments)
662667

663-
if self.profile_kernels:
664-
self.profile_events.append(ProfileEvent(evt, pt_prg))
668+
self.add_profiling_event(evt, pt_prg)
665669

666670
assert len(set(out_dict) & set(key_to_frozen_subary)) == 0
667671

arraycontext/impl/pytato/compile.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -640,9 +640,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
640640
allocator=self.actx.allocator,
641641
**input_kwargs_for_loopy)
642642

643-
if self.actx.profile_kernels:
644-
from arraycontext.impl.pytato import ProfileEvent
645-
self.actx.profile_events.append(ProfileEvent(evt, self.pytato_program))
643+
self.actx.add_profiling_event(evt, self.pytato_program)
646644

647645
def to_output_template(keys, _):
648646
name_in_program = self.output_id_to_name_in_program[keys]
@@ -683,9 +681,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
683681
allocator=self.actx.allocator,
684682
**input_kwargs_for_loopy)
685683

686-
if self.actx.profile_kernels:
687-
from arraycontext.impl.pytato import ProfileEvent
688-
self.actx.profile_events.append(ProfileEvent(evt, self.pytato_program))
684+
self.actx.add_profiling_event(evt, self.pytato_program)
689685

690686
return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
691687
axes=get_cl_axes_from_pt_axes(

0 commit comments

Comments
 (0)