5454import abc
5555import sys
5656from collections .abc import Callable
57+ from dataclasses import dataclass
5758from typing import TYPE_CHECKING , Any
5859
5960import numpy as np
7475
7576if TYPE_CHECKING :
7677 import loopy as lp
77- import pyopencl as cl
7878 import pytato
7979
8080if getattr (sys , "_BUILDING_SPHINX_DOCS" , False ):
@@ -235,6 +235,16 @@ def get_target(self):
235235
236236# {{{ PytatoPyOpenCLArrayContext
237237
238+
239+ @dataclass
240+ class ProfileEvent :
241+ """Holds a profile event that has not been collected by the profiler yet."""
242+
243+ start_cl_event : cl ._cl .Event
244+ stop_cl_event : cl ._cl .Event
245+ t_unit_name : str
246+
247+
238248class PytatoPyOpenCLArrayContext (_BasePytatoArrayContext ):
239249 """
240250 An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -259,7 +269,7 @@ def __init__(
259269 self , queue : cl .CommandQueue , allocator = None , * ,
260270 use_memory_pool : bool | None = None ,
261271 compile_trace_callback : Callable [[Any , str , Any ], None ] | None = None ,
262-
272+ profile_kernels : bool = False ,
263273 # do not use: only for testing
264274 _force_svm_arg_limit : int | None = None ,
265275 ) -> None :
@@ -322,6 +332,59 @@ def __init__(
322332
323333 self ._force_svm_arg_limit = _force_svm_arg_limit
324334
335+ self ._enable_profiling (profile_kernels )
336+
337+ # {{{ Profiling functionality
338+
339+ def _enable_profiling (self , enable : bool ) -> None :
340+ # List of ProfileEvents that haven't been transferred to profiled
341+ # results yet
342+ self ._profile_events : list [ProfileEvent ] = []
343+
344+ # Dict of kernel name -> list of kernel execution times
345+ self ._profile_results : dict [str , list [int ]] = {}
346+
347+ if enable :
348+ import pyopencl as cl
349+ if not self .queue .properties & cl .command_queue_properties .PROFILING_ENABLE :
350+ raise RuntimeError ("Profiling was not enabled in the command queue. "
351+ "Please create the queue with "
352+ "cl.command_queue_properties.PROFILING_ENABLE." )
353+ self .profile_kernels = True
354+
355+ else :
356+ self .profile_kernels = False
357+
358+ def _wait_and_transfer_profile_events (self ) -> None :
359+ """Wait for all profiling events to finish and transfer the results
360+ to *self._profile_results*."""
361+ import pyopencl as cl
362+ # First, wait for completion of all events
363+ if self ._profile_events :
364+ cl .wait_for_events ([p_event .stop_cl_event
365+ for p_event in self ._profile_events ])
366+
367+ # Then, collect all events and store them
368+ for t in self ._profile_events :
369+ name = t .t_unit_name
370+
371+ time = t .stop_cl_event .profile .end - t .start_cl_event .profile .end
372+
373+ self ._profile_results .setdefault (name , []).append (time )
374+
375+ self ._profile_events = []
376+
377+ def _add_profiling_events (self , start : cl ._cl .Event , stop : cl ._cl .Event ,
378+ t_unit_name : str ) -> None :
379+ """Add profiling events to the list of profiling events."""
380+ self ._profile_events .append (ProfileEvent (start , stop , t_unit_name ))
381+
382+ def _reset_profiling_data (self ) -> None :
383+ """Reset profiling data."""
384+ self ._profile_results = {}
385+
386+ # }}}
387+
325388 @property
326389 def _frozen_array_types (self ) -> tuple [type , ...]:
327390 import pyopencl .array as cla
@@ -546,9 +609,18 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
546609 self ._dag_transform_cache [normalized_expr ])
547610
548611 assert len (pt_prg .bound_arguments ) == 0
549- _evt , out_dict = pt_prg (self .queue ,
612+
613+ if self .profile_kernels :
614+ import pyopencl as cl
615+ start_evt = cl .enqueue_marker (self .queue )
616+
617+ evt , out_dict = pt_prg (self .queue ,
550618 allocator = self .allocator ,
551619 ** bound_arguments )
620+
621+ if self .profile_kernels :
622+ self ._add_profiling_events (start_evt , evt , pt_prg .program .entrypoint )
623+
552624 assert len (set (out_dict ) & set (key_to_frozen_subary )) == 0
553625
554626 key_to_frozen_subary = {
0 commit comments