diff --git a/docs/source/user_guide/parallelization.md b/docs/source/user_guide/parallelization.md index 9d7ebb3845..8f910927b6 100644 --- a/docs/source/user_guide/parallelization.md +++ b/docs/source/user_guide/parallelization.md @@ -48,6 +48,32 @@ def fill(a: qd.Template) -> None: `I` is a `qd.Vector` with one element per dimension. +### Controlling iteration order with `layout=` + +By default, `qd.ndrange(d0, d1, ..., dN-1)` makes the **last argument the innermost (fastest-varying) axis** in the flat parallel loop: adjacent flat threads differ in the last index. The `layout=` keyword lets you choose a different iteration-nesting order. It's a tuple of `int` listing the **canonical axis index at each successive iteration-nesting level, outermost first**, and must be a permutation of `range(N)` where `N` is the number of arguments to `qd.ndrange`: + +```python +@qd.kernel +def k(): + # axis 1 is outermost (slowest-varying), axis 0 is innermost (fastest-varying) + for i, j in qd.ndrange(M, N, layout=(1, 0)): + ... +``` + +The yielded loop variables (`i`, `j`, ...) are still bound to canonical axes 0, 1, ... — only the visit order changes. `layout=None` (the default) and the identity permutation `(0, 1, ..., N-1)` are equivalent and reproduce the default last-argument-innermost order. Mismatched length and non-permutation values are rejected up front with `qd.QuadrantsSyntaxError`. + +`layout=` is independent of what's in the loop body: it controls the iteration order regardless of whether the body touches a `qd.field`, a `qd.ndarray`, a `qd.tensor`, a `qd.Vector` / `qd.Matrix` variant, or no tensor at all. + +`layout=` is supported by both the plain and `qd.grouped` forms: + +```python +for i, j in qd.ndrange(M, N, layout=(1, 0)): + ... +for I in qd.grouped(qd.ndrange(M, N, layout=(1, 0))): + # I[0] is still the canonical axis-0 index, regardless of layout + ... +``` + ## Does GPU kernel launch latency matter? Kernel launch can be done in parallel whilst the previously launched kernel is still running. This means that if the previously launched kernel takes longer to run than the launch time for the new kernel, then the kernel launch latency will be perfectly hidden. diff --git a/docs/source/user_guide/tensor.md b/docs/source/user_guide/tensor.md index f4b0ab25c3..8063ff524f 100644 --- a/docs/source/user_guide/tensor.md +++ b/docs/source/user_guide/tensor.md @@ -113,6 +113,8 @@ b[i, j] = ... # canonical indexing in kernels still works Any permutation is supported, up to Quadrants' `quadrants_max_num_indices` (currently 12). `layout=None` and the identity permutation (`(0, 1, ..., N-1)`) are equivalent and forward no permutation to the underlying allocator. +For best performance, pair `qd.tensor(..., layout=...)` with a matching iteration order via `qd.ndrange(..., layout=...)` (see [`parallelization`](parallelization.md#controlling-iteration-order-with-layout)): the permutation has the same meaning on both APIs, and using the same value on both lines adjacent flat threads up with adjacent physical memory slots. + Quadrants rejects mismatched / invalid layouts up front: ```python diff --git a/python/quadrants/lang/_ndrange.py b/python/quadrants/lang/_ndrange.py index 3766c4a0b0..049d657cf3 100644 --- a/python/quadrants/lang/_ndrange.py +++ b/python/quadrants/lang/_ndrange.py @@ -31,7 +31,7 @@ def _coerce_to_int(v): class _Ndrange: - def __init__(self, *args): + def __init__(self, *args, layout=None): args = list(args) for i, arg in enumerate(args): if not isinstance(arg, collections.abc.Sequence): @@ -49,33 +49,87 @@ def __init__(self, *args): raise QuadrantsTypeError( "Every argument of ndrange should be an integer scalar or a tuple/list of (int, int)" ) - self.bounds = args - self.dimensions = [None] * len(args) - for i, bound in enumerate(self.bounds): - self.dimensions[i] = bound[1] - bound[0] + n = len(args) + + # Validate and normalize ``layout``. Stored as ``self.layout`` (``None`` for the identity + # permutation, else the user-supplied tuple) for introspection / tests, and as + # ``self._physical_to_canonical`` (a Python list of int of length ``n``) for the AST + # builder to use when remapping per-physical-level decomposed indices to canonical loop + # targets. The identity case is kept as ``None`` so the AST-builder fast-path matches + # the pre-layout codegen byte-for-byte. + if layout is None: + self.layout = None + physical_to_canonical = list(range(n)) + else: + layout_t = tuple(layout) + if len(layout_t) != n: + raise QuadrantsSyntaxError( + f"qd.ndrange(layout={layout_t!r}) has {len(layout_t)} entries " + f"but ndrange was called with {n} dimension argument(s); they must match" + ) + # Type-check each entry before sorting / permutation checks, so mixed-type or + # non-integer entries surface a Quadrants error instead of Python's raw + # ``TypeError`` from ``sorted``. ``bool`` is rejected explicitly even though it is + # an ``int`` subclass — accepting ``True`` / ``False`` as axis indices would be a + # foot-gun. + for e in layout_t: + if isinstance(e, bool) or not isinstance(e, (int, np.integer)): + raise QuadrantsTypeError( + f"qd.ndrange(layout={layout_t!r}) entries must be Python ints; " + f"got {type(e).__name__} ({e!r})" + ) + if sorted(layout_t) != list(range(n)): + raise QuadrantsSyntaxError(f"qd.ndrange(layout={layout_t!r}) is not a permutation of range({n})") + if layout_t == tuple(range(n)): + self.layout = None + physical_to_canonical = list(range(n)) + else: + self.layout = layout_t + physical_to_canonical = list(layout_t) - self.acc_dimensions = self.dimensions.copy() - for i in reversed(range(len(self.bounds) - 1)): - self.acc_dimensions[i] = self.acc_dimensions[i] * self.acc_dimensions[i + 1] - if len(self.acc_dimensions) == 0: # for the empty case, e.g. qd.ndrange() - self.acc_dimensions = [1] + self._physical_to_canonical = physical_to_canonical - def __iter__(self): - def gen(d, prefix): - if d == len(self.bounds): - yield prefix - else: - for t in range(self.bounds[d][0], self.bounds[d][1]): - yield from gen(d + 1, prefix + (t,)) + canonical_bounds = args + canonical_dimensions = [bound[1] - bound[0] for bound in canonical_bounds] - yield from gen(0, ()) + physical_bounds = [canonical_bounds[c] for c in physical_to_canonical] + physical_dimensions = [canonical_dimensions[c] for c in physical_to_canonical] + + acc_dimensions = physical_dimensions.copy() + for i in reversed(range(n - 1)): + acc_dimensions[i] = acc_dimensions[i] * acc_dimensions[i + 1] + if not acc_dimensions: # for the empty case, e.g. qd.ndrange() + acc_dimensions = [1] + + self._canonical_bounds = canonical_bounds + self._canonical_dimensions = canonical_dimensions + self.bounds = physical_bounds + self.dimensions = physical_dimensions + self.acc_dimensions = acc_dimensions + + def __iter__(self): + p2c = self._physical_to_canonical + cbounds = self._canonical_bounds + n = len(p2c) + + def gen(level, current): + if level == n: + yield tuple(current) + return + ax = p2c[level] + b, e = cbounds[ax] + for t in range(b, e): + current[ax] = t + yield from gen(level + 1, current) + + yield from gen(0, [0] * n) def grouped(self): return GroupedNDRange(self) -def ndrange(*args) -> Iterable: +def ndrange(*args, layout=None) -> Iterable: """Return an immutable iterator object for looping over multi-dimensional indices. This returned set of multi-dimensional indices is the direct product (in the set-theory sense) @@ -91,6 +145,18 @@ def ndrange(*args) -> Iterable: Args: entries: (int, tuple): Must be either an integer, or a tuple/list of two integers. + layout (tuple of int, optional): Permutation of canonical axes describing the iteration + nesting order, outermost (slowest-varying) first. For an N-argument ndrange, must be + a permutation of ``range(N)``. ``None`` (default) and the identity permutation are + equivalent and reproduce the default order in which the **last argument is the + innermost / fastest-varying axis**. The yielded loop variables stay bound to + canonical axes 0, 1, ..., N-1 regardless of layout — only the visit order changes. + ``layout=`` is independent of the loop body; it controls iteration order whether + the body touches a field, ndarray, tensor, vector/matrix variant, or no tensor at + all. The motivating use case is aligning iteration with a non-default physical + memory layout (e.g. ``qd.tensor(..., layout=...)`` or ``qd.field(..., order=...)``): + using the matching permutation makes adjacent flat threads step through physically + adjacent memory. Returns: An immutable iterator object. @@ -154,8 +220,18 @@ def ndrange(*args) -> Iterable: >>> def loop_tensor(): >>> for row, col, channel in qd.ndrange(image_height, image_width, channels): >>> image[row, col, channel] = ... + + Aligning iteration order with a non-default tensor layout via ``layout=``: + + >>> A = qd.tensor(qd.f32, shape=(M, N), layout=(1, 0)) # axis 1 outer, axis 0 inner + >>> @qd.kernel + >>> def fill(): + >>> # adjacent flat threads now step along axis 0 (the inner physical axis of A), + >>> # i.e. touch physically adjacent memory in A + >>> for i, j in qd.ndrange(M, N, layout=(1, 0)): + >>> A[i, j] = i + j """ - return _Ndrange(*args) + return _Ndrange(*args, layout=layout) class GroupedNDRange: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 263a4a11a3..e59e189c0a 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1052,24 +1052,32 @@ def build_ndrange_for(ctx: ASTTransformerFuncContext, node: ast.For) -> None: "Please check if the number of arguments of qd.ndrange() is equal to " "the number of the loop variables." ) - for i, target in enumerate(targets): - if i + 1 < len(targets): - target_tmp = impl.expr_init(I // ndrange_var.acc_dimensions[i + 1]) + # ``physical_to_canonical[p]`` is the canonical (user-visible) axis index that receives + # the decomposed index for physical nesting level ``p``. For the identity / ``layout=None`` + # case this is ``[0, 1, ..., n-1]`` and the emitted IR matches the pre-layout codegen + # byte-for-byte. + physical_to_canonical = ndrange_var._physical_to_canonical + n_levels = len(ndrange_var.dimensions) + for p in range(n_levels): + if p + 1 < n_levels: + target_tmp = impl.expr_init(I // ndrange_var.acc_dimensions[p + 1]) else: target_tmp = impl.expr_init(I) + canonical_idx = physical_to_canonical[p] + target = targets[canonical_idx] ctx.create_variable( target, impl.expr_init( target_tmp + impl.subscript( ctx.ast_builder, - impl.subscript(ctx.ast_builder, ndrange_var.bounds, i), + impl.subscript(ctx.ast_builder, ndrange_var.bounds, p), 0, ) ), ) - if i + 1 < len(targets): - I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1]) + if p + 1 < n_levels: + I._assign(I - target_tmp * ndrange_var.acc_dimensions[p + 1]) ctx.loop_depth += 1 build_stmts(ctx, node.body) ctx.loop_depth -= 1 @@ -1098,14 +1106,22 @@ def build_grouped_ndrange_for(ctx: ASTTransformerFuncContext, node: ast.For) -> ctx.create_variable(target, target_var) I = impl.expr_init(ndrange_loop_var) - for i in range(len(ndrange_var.dimensions)): - if i + 1 < len(ndrange_var.dimensions): - target_tmp = I // ndrange_var.acc_dimensions[i + 1] + # See ``build_ndrange_for`` above for the layout semantics. The grouped target_var is a + # vector indexed by canonical axis, so element ``physical_to_canonical[p]`` (not ``p``) + # receives the decomposition of physical level ``p``. + physical_to_canonical = ndrange_var._physical_to_canonical + n_levels = len(ndrange_var.dimensions) + for p in range(n_levels): + if p + 1 < n_levels: + target_tmp = I // ndrange_var.acc_dimensions[p + 1] else: target_tmp = I - impl.subscript(ctx.ast_builder, target_var, i)._assign(target_tmp + ndrange_var.bounds[i][0]) - if i + 1 < len(ndrange_var.dimensions): - I._assign(I - target_tmp * ndrange_var.acc_dimensions[i + 1]) + canonical_idx = physical_to_canonical[p] + impl.subscript(ctx.ast_builder, target_var, canonical_idx)._assign( + target_tmp + ndrange_var.bounds[p][0] + ) + if p + 1 < n_levels: + I._assign(I - target_tmp * ndrange_var.acc_dimensions[p + 1]) ctx.loop_depth += 1 build_stmts(ctx, node.body) ctx.loop_depth -= 1 diff --git a/tests/python/test_ndrange_layout.py b/tests/python/test_ndrange_layout.py new file mode 100644 index 0000000000..a1456fce25 --- /dev/null +++ b/tests/python/test_ndrange_layout.py @@ -0,0 +1,400 @@ +"""Tests for the ``layout=`` keyword on :func:`quadrants.ndrange`. + +``layout=`` is canonical-preserving: the loop variables stay bound to canonical axes regardless of layout; +only the visit order (which canonical axis is the outermost / innermost iteration nesting level) changes. +``layout=None`` and the identity permutation are equivalent and produce the default last-arg-innermost +behaviour. +""" + +import itertools + +import numpy as np +import pytest + +import quadrants as qd + +from tests import test_utils + + +def _expected_flat_to_canonical(dims, layout): + """Build the expected sequence of canonical multi-indices yielded by an ``ndrange`` of the given + dimensions and layout. + + Iteration nests with physical level 0 outermost; physical level ``p`` indexes canonical axis ``layout[p]``. + """ + layout = tuple(range(len(dims))) if layout is None else tuple(layout) + ranges = [range(dims[axis]) for axis in layout] + out = [] + for physical_tuple in itertools.product(*ranges): + canonical = [0] * len(dims) + for p, ax in enumerate(layout): + canonical[ax] = physical_tuple[p] + out.append(tuple(canonical)) + return out + + +def _expected_flat_index(canonical, dims, layout): + """Return the flat thread index that visits ``canonical`` under (``dims``, ``layout``). + + Mirrors the AST-builder's decomposition: flat = sum_{p} canonical[layout[p]] * prod(dims[layout[p+1:]]). + """ + layout = tuple(range(len(dims))) if layout is None else tuple(layout) + n = len(dims) + flat = 0 + for p in range(n): + ax = layout[p] + inner = 1 + for q in range(p + 1, n): + inner *= dims[layout[q]] + flat += canonical[ax] * inner + return flat + + +# ---------------------------------------------------------------------------- +# Identity / default equivalence +# ---------------------------------------------------------------------------- + + +@test_utils.test() +def test_layout_none_matches_default(): + M, N = 5, 7 + x = qd.field(qd.i32, shape=(M, N)) + y = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill_default(): + for i, j in qd.ndrange(M, N): + x[i, j] = i * 100 + j + + @qd.kernel + def fill_layout_none(): + for i, j in qd.ndrange(M, N, layout=None): + y[i, j] = i * 100 + j + + fill_default() + fill_layout_none() + np.testing.assert_array_equal(x.to_numpy(), y.to_numpy()) + + +@test_utils.test() +def test_layout_identity_matches_default(): + M, N = 5, 7 + x = qd.field(qd.i32, shape=(M, N)) + y = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill_default(): + for i, j in qd.ndrange(M, N): + x[i, j] = i * 100 + j + + @qd.kernel + def fill_layout_identity(): + for i, j in qd.ndrange(M, N, layout=(0, 1)): + y[i, j] = i * 100 + j + + fill_default() + fill_layout_identity() + np.testing.assert_array_equal(x.to_numpy(), y.to_numpy()) + + +# ---------------------------------------------------------------------------- +# Non-identity layouts: canonical loop targets, full coverage +# ---------------------------------------------------------------------------- + + +@test_utils.test() +def test_layout_2d_transposed_canonical_targets(): + """With ``layout=(1, 0)``, the loop variables (i, j) are still canonical axes 0, 1.""" + M, N = 4, 6 + x = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill(): + for i, j in qd.ndrange(M, N, layout=(1, 0)): + x[i, j] = i * 100 + j + + fill() + expected = np.array([[i * 100 + j for j in range(N)] for i in range(M)], dtype=np.int32) + np.testing.assert_array_equal(x.to_numpy(), expected) + + +@test_utils.test() +def test_layout_3d_arbitrary_permutation_canonical_targets(): + """Rank-3 with a non-cyclic permutation.""" + D0, D1, D2 = 3, 4, 5 + x = qd.field(qd.i32, shape=(D0, D1, D2)) + + @qd.kernel + def fill(): + for i, j, k in qd.ndrange(D0, D1, D2, layout=(2, 0, 1)): + x[i, j, k] = i * 10000 + j * 100 + k + + fill() + expected = np.array( + [[[i * 10000 + j * 100 + k for k in range(D2)] for j in range(D1)] for i in range(D0)], + dtype=np.int32, + ) + np.testing.assert_array_equal(x.to_numpy(), expected) + + +@test_utils.test() +def test_layout_with_tuple_bounds_preserves_offsets(): + """Layout doesn't disturb (begin, end) tuples — each canonical axis keeps its own bounds.""" + M, N = 16, 16 + x = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill(): + for i, j in qd.ndrange((2, 10), (3, 7), layout=(1, 0)): + x[i, j] = i * 100 + j + + fill() + expected = np.zeros((M, N), dtype=np.int32) + for i in range(2, 10): + for j in range(3, 7): + expected[i, j] = i * 100 + j + np.testing.assert_array_equal(x.to_numpy(), expected) + + +@test_utils.test() +def test_layout_full_coverage_via_atomic_count(): + """Every canonical slot is visited exactly once.""" + M, N = 5, 7 + counts = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill(): + for i, j in qd.ndrange(M, N, layout=(1, 0)): + counts[i, j] += 1 + + fill() + np.testing.assert_array_equal(counts.to_numpy(), np.ones((M, N), dtype=np.int32)) + + +@test_utils.test() +def test_layout_flat_index_matches_decomposition(): + """The flat thread index reconstructed from the canonical loop variables under the requested + layout permutation matches what a sequential range-loop would assign — i.e. the AST decomposition + is the inverse of the canonical-from-physical mapping. + """ + M, N = 4, 6 + flat = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill(): + for i, j in qd.ndrange(M, N, layout=(1, 0)): + # If physical level 0 = axis 1 (outer) and level 1 = axis 0 (inner), then the flat index + # is j * M + i. Writing it into a per-canonical-slot grid lets us check coverage and the + # bijection in one pass. + flat[i, j] = j * M + i + + fill() + expected = np.array([[j * M + i for j in range(N)] for i in range(M)], dtype=np.int32) + np.testing.assert_array_equal(flat.to_numpy(), expected) + + +# ---------------------------------------------------------------------------- +# qd.grouped + layout +# ---------------------------------------------------------------------------- + + +@test_utils.test() +def test_layout_grouped_indices_are_canonical(): + """``I[0]`` is the canonical axis-0 index regardless of layout.""" + M, N = 4, 5 + x = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill(): + for I in qd.grouped(qd.ndrange(M, N, layout=(1, 0))): + x[I] = I[0] * 100 + I[1] + + fill() + expected = np.array([[i * 100 + j for j in range(N)] for i in range(M)], dtype=np.int32) + np.testing.assert_array_equal(x.to_numpy(), expected) + + +@test_utils.test() +def test_layout_static_grouped(): + """Unrolled (qd.static) grouped path also sees canonical indices in physical iteration order.""" + M, N = 3, 4 + x = qd.field(qd.i32, shape=(M, N)) + + @qd.kernel + def fill(): + for I in qd.static(qd.grouped(qd.ndrange(M, N, layout=(1, 0)))): + x[I] = I[0] * 100 + I[1] + + fill() + expected = np.array([[i * 100 + j for j in range(N)] for i in range(M)], dtype=np.int32) + np.testing.assert_array_equal(x.to_numpy(), expected) + + +# ---------------------------------------------------------------------------- +# Pairing with qd.tensor(..., layout=...) +# ---------------------------------------------------------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_layout_pairs_with_tensor_layout_field(): + """The documented pairing use case: matching ``layout=`` on both tensor and ndrange. The kernel + body uses canonical indexing throughout; correctness must hold (this exercises the + canonical->physical AST rewrite on the tensor side and the layout-aware decomposition on the + ndrange side together). + """ + M, N = 4, 6 + A = qd.tensor(qd.i32, shape=(M, N), backend=qd.Backend.FIELD, layout=(1, 0)) + + @qd.kernel + def fill(a: qd.template()): + for i, j in qd.ndrange(M, N, layout=(1, 0)): + a[i, j] = i * 100 + j + + fill(A) + expected = np.array([[i * 100 + j for j in range(N)] for i in range(M)], dtype=np.int32) + np.testing.assert_array_equal(A.to_numpy(), expected) + + +# ---------------------------------------------------------------------------- +# Python-side iteration (outside @qd.kernel) +# ---------------------------------------------------------------------------- + + +def test_layout_python_iteration_2d(): + qd.init(arch=qd.cpu) + M, N = 3, 4 + got = list(qd.ndrange(M, N, layout=(1, 0))) + assert got == _expected_flat_to_canonical((M, N), (1, 0)) + + +def test_layout_python_iteration_3d(): + qd.init(arch=qd.cpu) + dims = (2, 3, 4) + got = list(qd.ndrange(*dims, layout=(2, 0, 1))) + assert got == _expected_flat_to_canonical(dims, (2, 0, 1)) + + +def test_layout_python_iteration_identity_matches_default(): + qd.init(arch=qd.cpu) + M, N = 3, 4 + assert list(qd.ndrange(M, N, layout=(0, 1))) == list(qd.ndrange(M, N)) + assert list(qd.ndrange(M, N, layout=None)) == list(qd.ndrange(M, N)) + + +def test_layout_grouped_python_iteration_via_method(): + """``_Ndrange.grouped()`` (Python-scope method, not ``qd.grouped``) preserves the layout-induced + iteration order. ``qd.grouped`` itself is decorated ``@quadrants_scope`` and cannot be invoked + outside a kernel, so test the underlying method directly here. + """ + qd.init(arch=qd.cpu) + from quadrants.lang._ndrange import _Ndrange + + M, N = 3, 4 + got = [] + for vec in _Ndrange(M, N, layout=(1, 0)).grouped(): + got.append(tuple(vec.to_list())) + assert got == _expected_flat_to_canonical((M, N), (1, 0)) + + +# ---------------------------------------------------------------------------- +# Introspection +# ---------------------------------------------------------------------------- + + +def test_layout_attribute_identity_normalizes_to_none(): + qd.init(arch=qd.cpu) + # ``layout=None`` and identity layout both expose ``layout = None`` for introspection + # (so user code can treat "no layout" symmetrically). + from quadrants.lang._ndrange import _Ndrange + + a = _Ndrange(3, 4) + b = _Ndrange(3, 4, layout=None) + c = _Ndrange(3, 4, layout=(0, 1)) + assert a.layout is None + assert b.layout is None + assert c.layout is None + + +def test_layout_attribute_non_identity_preserved(): + qd.init(arch=qd.cpu) + from quadrants.lang._ndrange import _Ndrange + + a = _Ndrange(3, 4, layout=(1, 0)) + assert a.layout == (1, 0) + + +# ---------------------------------------------------------------------------- +# Degenerate ranks +# ---------------------------------------------------------------------------- + + +@test_utils.test() +def test_layout_1d_degenerate(): + """Layout (0,) on a 1-D ndrange is the only permutation and must match the default.""" + M = 7 + x = qd.field(qd.i32, shape=(M,)) + y = qd.field(qd.i32, shape=(M,)) + + @qd.kernel + def fill_default(): + for i in qd.ndrange(M): + x[i] = i + + @qd.kernel + def fill_layout(): + for i in qd.ndrange(M, layout=(0,)): + y[i] = i + + fill_default() + fill_layout() + np.testing.assert_array_equal(x.to_numpy(), y.to_numpy()) + + +def test_layout_zero_dim_degenerate(): + qd.init(arch=qd.cpu) + # Empty ndrange yields exactly one (empty) tuple. + assert list(qd.ndrange()) == [()] + assert list(qd.ndrange(layout=())) == [()] + + +# ---------------------------------------------------------------------------- +# Error cases +# ---------------------------------------------------------------------------- + + +def test_layout_wrong_length_raises(): + qd.init(arch=qd.cpu) + with pytest.raises(qd.QuadrantsSyntaxError, match=r"qd\.ndrange\(layout=.*\) has 3 entries but ndrange"): + qd.ndrange(4, 5, layout=(0, 1, 2)) + + +def test_layout_not_a_permutation_raises(): + qd.init(arch=qd.cpu) + with pytest.raises(qd.QuadrantsSyntaxError, match=r"qd\.ndrange\(layout=.*\) is not a permutation"): + qd.ndrange(4, 5, layout=(0, 0)) + + +def test_layout_out_of_range_raises(): + qd.init(arch=qd.cpu) + with pytest.raises(qd.QuadrantsSyntaxError, match=r"qd\.ndrange\(layout=.*\) is not a permutation"): + qd.ndrange(4, 5, layout=(0, 2)) + + +def test_layout_non_integer_entry_raises(): + """Non-integer entries (string, float, mixed) surface a QuadrantsTypeError instead of the raw + Python ``TypeError`` ``sorted`` would emit on mixed-type sequences. + """ + qd.init(arch=qd.cpu) + with pytest.raises(qd.QuadrantsTypeError, match=r"entries must be Python ints"): + qd.ndrange(4, 5, layout=(0, "1")) + with pytest.raises(qd.QuadrantsTypeError, match=r"entries must be Python ints"): + qd.ndrange(4, 5, layout=(0.0, 1.0)) + + +def test_layout_bool_entry_rejected(): + """``bool`` is an ``int`` subclass but rejecting ``True`` / ``False`` as axis indices avoids a + foot-gun. + """ + qd.init(arch=qd.cpu) + with pytest.raises(qd.QuadrantsTypeError, match=r"entries must be Python ints"): + qd.ndrange(4, 5, layout=(True, False))