Skip to content

Commit 6ae30d6

Browse files
factor out profile enable/disable
1 parent fedb836 commit 6ae30d6

2 files changed

Lines changed: 48 additions & 17 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -284,22 +284,6 @@ def __init__(
284284
if allocator is not None and use_memory_pool is not None:
285285
raise TypeError("may not specify both allocator and use_memory_pool")
286286

287-
self.profile_kernels = profile_kernels
288-
289-
if profile_kernels:
290-
import pyopencl as cl
291-
if not queue.properties & cl.command_queue_properties.PROFILING_ENABLE:
292-
raise RuntimeError("Profiling was not enabled in the command queue. "
293-
"Please create the queue with "
294-
"cl.command_queue_properties.PROFILING_ENABLE.")
295-
296-
# List of ProfileEvents that haven't been transferred to profiled
297-
# results yet
298-
self._profile_events: list[ProfileEvent] = []
299-
300-
# Dict of kernel name -> list of kernel execution times
301-
self._profile_results: dict[str, list[int]] = {}
302-
303287
self.using_svm = None
304288

305289
if allocator is None:
@@ -348,8 +332,29 @@ def __init__(
348332

349333
self._force_svm_arg_limit = _force_svm_arg_limit
350334

335+
self._enable_profiling(profile_kernels)
336+
351337
# {{{ Profiling functionality
352338

339+
def _enable_profiling(self, enable: bool) -> None:
340+
# List of ProfileEvents that haven't been transferred to profiled
341+
# results yet
342+
self._profile_events: list[ProfileEvent] = []
343+
344+
# Dict of kernel name -> list of kernel execution times
345+
self._profile_results: dict[str, list[int]] = {}
346+
347+
if enable:
348+
import pyopencl as cl
349+
if not self.queue.properties & cl.command_queue_properties.PROFILING_ENABLE:
350+
raise RuntimeError("Profiling was not enabled in the command queue. "
351+
"Please create the queue with "
352+
"cl.command_queue_properties.PROFILING_ENABLE.")
353+
self.profile_kernels = True
354+
355+
else:
356+
self.profile_kernels = False
357+
353358
def _wait_and_transfer_profile_events(self) -> None:
354359
"""Wait for all profiling events to finish and transfer the results
355360
to *self._profile_results*."""

test/test_pytato_arraycontext.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def twice(x):
329329

330330
# {{{ test disabling profiling
331331

332-
actx.profile_kernels = False
332+
actx._enable_profiling(False)
333333

334334
assert len(actx._profile_events) == 0
335335

@@ -341,6 +341,32 @@ def twice(x):
341341

342342
# }}}
343343

344+
# {{{ test enabling profiling
345+
346+
actx._enable_profiling(True)
347+
348+
assert len(actx._profile_events) == 0
349+
350+
for _ in range(10):
351+
assert actx.to_numpy(f(99)) == 198
352+
353+
assert len(actx._profile_events) == 10
354+
actx._wait_and_transfer_profile_events()
355+
assert len(actx._profile_events) == 0
356+
assert len(actx._profile_results) == 1
357+
358+
# }}}
359+
360+
queue2 = cl.CommandQueue(cl_ctx)
361+
362+
with pytest.raises(RuntimeError):
363+
PytatoPyOpenCLArrayContext(queue2, profile_kernels=True)
364+
365+
actx2 = PytatoPyOpenCLArrayContext(queue2)
366+
367+
with pytest.raises(RuntimeError):
368+
actx2._enable_profiling(True)
369+
344370

345371
if __name__ == "__main__":
346372
import sys

0 commit comments

Comments
 (0)