diff --git a/nx_parallel/tests/test_should_run.py b/nx_parallel/tests/test_should_run.py index 22def8b6..d1649d80 100644 --- a/nx_parallel/tests/test_should_run.py +++ b/nx_parallel/tests/test_should_run.py @@ -3,7 +3,6 @@ import networkx as nx import inspect import pytest -import os def get_functions_with_should_run(): @@ -16,70 +15,6 @@ def test_get_functions_with_should_run(): assert set(get_functions_with_should_run()) == set(ALGORITHMS) -def test_default_should_run(): - @nxp._configure_if_nx_active() - def dummy_default(): - pass - - with pytest.MonkeyPatch().context() as mp: - mp.delitem(os.environ, "PYTEST_CURRENT_TEST", raising=False) - with nx.config.backends.parallel(n_jobs=1): - assert ( - dummy_default.should_run() - == "Parallel backend requires `n_jobs` > 1 to run" - ) - - assert dummy_default.should_run() - - -def test_skip_parallel_backend(): - @nxp._configure_if_nx_active(should_run=nxp.should_skip_parallel) - def dummy_skip_parallel(): - pass - - assert dummy_skip_parallel.should_run() == "Fast algorithm; skip parallel execution" - - -def test_should_run_if_large(): - @nxp._configure_if_nx_active(should_run=nxp.should_run_if_large) - def dummy_if_large(G): - pass - - smallG = nx.fast_gnp_random_graph(20, 0.6, seed=42) - largeG = nx.fast_gnp_random_graph(250, 0.6, seed=42) - - assert dummy_if_large.should_run(smallG) == "Graph too small for parallel execution" - assert dummy_if_large.should_run(largeG) - - -def test_should_run_if_nodes_none(): - @nxp._configure_if_nx_active(should_run=nxp.should_run_if_nodes_none) - def dummy_nodes_none(G, nodes=None): - pass - - G = nx.fast_gnp_random_graph(20, 0.6, seed=42) - assert ( - dummy_nodes_none.should_run(G, nodes=[1, 3]) - == "Parallel execution only supported when `nodes` is None" - ) - assert dummy_nodes_none.should_run(G) - - -def test_should_run_if_sparse(): - @nxp._configure_if_nx_active(should_run=nxp.should_run_if_sparse(threshold=0.4)) - def dummy_if_sparse(G): - pass - - G_dense = nx.fast_gnp_random_graph(20, 0.6, seed=42) - assert ( - dummy_if_sparse.should_run(G_dense) - == "Graph too dense to benefit from parallel execution" - ) - - G_sparse = nx.fast_gnp_random_graph(20, 0.2, seed=42) - assert dummy_if_sparse.should_run(G_sparse) - - @pytest.mark.parametrize("func_name", get_functions_with_should_run()) def test_should_run(func_name): tournament_funcs = [ diff --git a/nx_parallel/utils/chunk.py b/nx_parallel/utils/chunk.py index d8df9564..694be3c8 100644 --- a/nx_parallel/utils/chunk.py +++ b/nx_parallel/utils/chunk.py @@ -67,12 +67,14 @@ def get_n_jobs(n_jobs=None): return 2 if n_jobs is None: - if nx.config.backends.parallel.active: - n_jobs = nx.config.backends.parallel.n_jobs - else: - from joblib.parallel import get_active_backend + from joblib.parallel import get_active_backend - n_jobs = get_active_backend()[1] + # Always check Joblib first (it reflects the live/innermost state) + _, n_jobs = get_active_backend() + + # Fallback to NX config if Joblib has no explicit value + if n_jobs is None and nx.config.backends.parallel.active: + n_jobs = nx.config.backends.parallel.n_jobs if n_jobs is None: return 1 diff --git a/nx_parallel/utils/decorators.py b/nx_parallel/utils/decorators.py index 0af6182c..5bad5f49 100644 --- a/nx_parallel/utils/decorators.py +++ b/nx_parallel/utils/decorators.py @@ -3,6 +3,7 @@ from functools import wraps import networkx as nx from joblib import parallel_config +from joblib.parallel import get_active_backend from nx_parallel.utils.should_run_policies import default_should_run @@ -21,11 +22,18 @@ def wrapper(*args, **kwargs): nx.config.backends.parallel.active or "PYTEST_CURRENT_TEST" in os.environ ): - # Activate nx config system in nx_parallel with: - # `nx.config.backends.parallel.active = True` + # Peeking at joblib's current state to see if there's an outer context + _, current_n_jobs = get_active_backend() + + # Get NetworkX config config_dict = asdict(nx.config.backends.parallel) config_dict.update(config_dict.pop("backend_params")) config_dict.pop("active", None) + + # SYNC: If user has an outer Joblib context for n_jobs, respect it! + if current_n_jobs is not None: + config_dict["n_jobs"] = current_n_jobs + with parallel_config(**config_dict): return func(*args, **kwargs) return func(*args, **kwargs) diff --git a/nx_parallel/utils/tests/test_should_run_policies.py b/nx_parallel/utils/tests/test_should_run_policies.py new file mode 100644 index 00000000..3b6f29e8 --- /dev/null +++ b/nx_parallel/utils/tests/test_should_run_policies.py @@ -0,0 +1,68 @@ +import os +import pytest +import networkx as nx +import nx_parallel as nxp + + +def test_default_should_run(): + @nxp._configure_if_nx_active() + def dummy_default(): + pass + + with pytest.MonkeyPatch().context() as mp: + mp.delitem(os.environ, "PYTEST_CURRENT_TEST", raising=False) + with nx.config.backends.parallel(n_jobs=1): + assert ( + dummy_default.should_run() + == "Parallel backend requires `n_jobs` > 1 to run" + ) + + assert dummy_default.should_run() + + +def test_skip_parallel_backend(): + @nxp._configure_if_nx_active(should_run=nxp.should_skip_parallel) + def dummy_skip_parallel(): + pass + + assert dummy_skip_parallel.should_run() == "Fast algorithm; skip parallel execution" + + +def test_should_run_if_large(): + @nxp._configure_if_nx_active(should_run=nxp.should_run_if_large) + def dummy_if_large(G): + pass + + smallG = nx.fast_gnp_random_graph(20, 0.6, seed=42) + largeG = nx.fast_gnp_random_graph(250, 0.6, seed=42) + + assert dummy_if_large.should_run(smallG) == "Graph too small for parallel execution" + assert dummy_if_large.should_run(largeG) + + +def test_should_run_if_nodes_none(): + @nxp._configure_if_nx_active(should_run=nxp.should_run_if_nodes_none) + def dummy_nodes_none(G, nodes=None): + pass + + G = nx.fast_gnp_random_graph(20, 0.6, seed=42) + assert ( + dummy_nodes_none.should_run(G, nodes=[1, 3]) + == "Parallel execution only supported when `nodes` is None" + ) + assert dummy_nodes_none.should_run(G) + + +def test_should_run_if_sparse(): + @nxp._configure_if_nx_active(should_run=nxp.should_run_if_sparse(threshold=0.4)) + def dummy_if_sparse(G): + pass + + G_dense = nx.fast_gnp_random_graph(20, 0.6, seed=42) + assert ( + dummy_if_sparse.should_run(G_dense) + == "Graph too dense to benefit from parallel execution" + ) + + G_sparse = nx.fast_gnp_random_graph(20, 0.2, seed=42) + assert dummy_if_sparse.should_run(G_sparse)