Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions benchmarks/speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"""Measure the speed of the MCMC and its interfaces."""

from collections.abc import Mapping
from contextlib import nullcontext, redirect_stdout
from contextlib import redirect_stdout
from dataclasses import replace
from functools import partial
from inspect import signature
Expand All @@ -50,7 +50,6 @@
from jax.tree_util import tree_map
from jaxtyping import Array, Float32, Integer, Key, UInt8

import bartz
from bartz import mcmcloop, mcmcstep
from bartz.mcmcloop import run_mcmc
from benchmarks.latest_bartz.jaxext import get_device_count, split
Expand Down Expand Up @@ -313,7 +312,6 @@ def setup(
niters: int = NITERS,
nchains: int = 1,
cache: Cache = 'warm',
profile: bool = False,
predict: bool = False,
kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> None:
Expand Down Expand Up @@ -357,22 +355,12 @@ def setup(
self.kw.update(kwargs)
block_until_ready(self.kw)

# set profile mode
if not profile:
self.context = nullcontext
elif hasattr(bartz, 'profile_mode'):
self.context = lambda: bartz.profile_mode(True)
else:
msg = 'Profile mode not supported.'
raise NotImplementedError(msg)

# save information used to run predictions
self.predict = predict
if predict:
self.test = test
with self.context():
self.bart = gbart(**self.kw)
block_bart(self.bart)
self.bart = gbart(**self.kw)
block_bart(self.bart)

# decide how much to cold-start
match cache:
Expand All @@ -385,7 +373,7 @@ def setup(

def time_gbart(self, *_: Any) -> None:
"""Time instantiating the class."""
with redirect_stdout(StringIO()), self.context():
with redirect_stdout(StringIO()):
if self.predict:
ypred = self.bart.predict(self.test.x)
block_until_ready(ypred)
Expand Down Expand Up @@ -439,13 +427,13 @@ def setup(self, nchains: int, shard: bool) -> None: # ty:ignore[invalid-method-
# on gpu shard explicitly
kwargs = dict(num_chain_devices=min(nchains, get_device_count()))

super().setup(NITERS, nchains, 'warm', False, False, dict(bart_kwargs=kwargs))
super().setup(NITERS, nchains, 'warm', False, dict(bart_kwargs=kwargs))


class GbartGeneric(BaseGbart):
"""General timing of `mc_gbart` with many settings."""

params = ((0, NITERS), (1, 6), ('warm', 'cold'), (False, True), (False, True))
params = ((0, NITERS), (1, 6), ('warm', 'cold'), (False, True))


class BaseRunMcmc(AutoParamNames):
Expand Down
11 changes: 5 additions & 6 deletions docs/development.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,25 +148,26 @@ This runs only benchmarks whose name matches <pattern>, only once, within the wo
Profiling
---------

Use the `JAX profiling utilities <https://docs.jax.dev/en/latest/profiling.html>`_ to profile `bartz`. By default the MCMC loop is compiled all at once, which makes it quite opaque to profiling. There are two ways to understand what's going on inside in more detail: 1) inspect the individual operations and use intuition to understand which piece of code they correspond to, 2) turn on bartz's profile mode. Basic workflow:
Use the `JAX profiling utilities <https://docs.jax.dev/en/latest/profiling.html>`_ to profile `bartz`. It works well on GPU, not on CPU.

.. code-block:: python

from jax.profiler import trace, ProfileOptions
from jax import block_until_ready
from bartz.BART import gbart
from bartz import profile_mode

traceopt = ProfileOptions()

# this setting makes Python function calls show up in the trace
traceopt.python_tracer_level = 1

# on cpu, this makes the trace detailed enough to understand what's going on
# even within compiled functions
# even within compiled functions by manual inspection of each operation
traceopt.host_tracer_level = 2

with trace('./trace_results', profiler_options=traceopt), profile_mode(True):
with trace('./trace_results', profiler_options=traceopt):
bart = gbart(...)
block_until_ready(bart)

On the first run, the trace will show compilation operations, while subsequent runs (within the same Python shell) will be warmed up. Start an xprof server to visualize the results:

Expand All @@ -177,5 +178,3 @@ On the first run, the trace will show compilation operations, while subsequent r
XProf at http://localhost:8791/ (Press CTRL+C to quit)

Open the provided URL in a browser. In the sidebar, select the tool "Trace Viewer".

In "profile mode", the MCMC loop is split into a few chunks that are compiled separately, making it possible to see at a glance how much time each phase of the MCMC cycle takes. This causes some overhead, so the timings are not equivalent to the normal mode ones. In one specific example on CPU, bartz was 20% slower in profile mode with one chain, and 2x slower with multiple chains.
1 change: 0 additions & 1 deletion docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@ Reference
jaxext.rst
debug.rst
test.rst
profile.rst
28 changes: 0 additions & 28 deletions docs/reference/profile.rst

This file was deleted.

1 change: 0 additions & 1 deletion src/bartz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@

from bartz import BART, grove, jaxext, mcmcloop, mcmcstep, prepcovars # noqa: F401
from bartz._interface import Bart # noqa: F401
from bartz._profiler import profile_mode # noqa: F401
from bartz._version import __version__, __version_info__ # noqa: F401
Loading