From 19b2fb02b3898f2bb79671540856d6a46ab34251 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:37:22 -0700 Subject: [PATCH 01/16] Fix coalescing_efficiency docstring to show explicit element_bytes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The strided access example used the default element_bytes=2 (fp16), under which Layout(32, 2) fits in a single cache line (1 transaction). The example implies fp32 — make element_bytes=4 explicit so the documented 2 transactions and 0.5 efficiency are correct. --- src/tensor_layouts/analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index 98e9a31..93f4a0d 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -258,8 +258,8 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, coalescing_efficiency(Layout(32, 1)) # {'transactions': 1, 'efficiency': 0.5, ...} -- 64B used of 128B line - # Strided access: each thread 2 elements apart - coalescing_efficiency(Layout(32, 2)) + # Strided access: each thread 2 elements apart, fp32 + coalescing_efficiency(Layout(32, 2), element_bytes=4) # {'transactions': 2, 'efficiency': 0.5, ...} """ layout = as_layout(layout) From 97273fd0b98fd485bffcf0814cbc4b523cbc5da0 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:38:06 -0700 Subject: [PATCH 02/16] Fix explain() crash for logical_product with tuple tilers explain(logical_product, ...) assumed B was always a Layout and called compose(complement(A, bound), B) directly. When B is a tuple tiler like (2, 2), compose fails because the complement has fewer modes than the tuple length. Mirror the actual logical_product implementation: for tuple tilers, show mode-by-mode decomposition instead of the single-layout formula. 
--- src/tensor_layouts/analysis.py | 38 +++++++++++++++++++++++----------- tests/analysis.py | 10 +++++++++ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index 93f4a0d..d09d2d2 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -870,18 +870,32 @@ def explain(fn, *args): if isinstance(B, int): B = Layout(B) lines.append(f'logical_product({A}, {B})') - lines.append(f' = Layout(A, compose(complement(A, size(A)*size(B)), B))') - lines.append(f'') - lines.append(f' A = {A}') - lines.append(f' B = {B}') - bound = size(A) * size(B) - lines.append(f' size(A) * size(B) = {bound}') - comp = complement(A, bound) - lines.append(f' complement(A, {bound}) = {comp}') - comp_b = compose(comp, B) - lines.append(f' compose(complement, B) = {comp_b}') - result = Layout(A, comp_b) - lines.append(f' Layout(A, {comp_b}) = {result}') + + if is_layout(B): + lines.append(f' = Layout(A, compose(complement(A, size(A)*size(B)), B))') + lines.append(f'') + lines.append(f' A = {A}') + lines.append(f' B = {B}') + bound = size(A) * size(B) + lines.append(f' size(A) * size(B) = {bound}') + comp = complement(A, bound) + lines.append(f' complement(A, {bound}) = {comp}') + comp_b = compose(comp, B) + lines.append(f' compose(complement, B) = {comp_b}') + result = Layout(A, comp_b) + lines.append(f' Layout(A, {comp_b}) = {result}') + else: + # Tuple tiler: mode-by-mode decomposition + lines.append(f' For tuple tilers, applies logical_product mode-by-mode.') + lines.append(f'') + lines.append(f' A = {A}') + lines.append(f' B = {B}') + for i in range(len(B)): + mi = mode(A, i) + bi = B[i] + ri = logical_product(mi, bi) + lines.append(f' mode {i}: logical_product({mi}, {bi}) = {ri}') + lines.append(f'') actual = logical_product(A, B) lines.append(f' result = {actual}') diff --git a/tests/analysis.py b/tests/analysis.py index eb0dc6f..c46970f 100644 --- a/tests/analysis.py +++ 
b/tests/analysis.py @@ -562,6 +562,16 @@ def test_explain_logical_product(): assert '(4, 3) : (1, 4)' in text +def test_explain_logical_product_tuple_tiler(): + """explain handles logical_product with tuple tiler without crashing.""" + text = explain(logical_product, Layout((4, 4), (1, 4)), (2, 2)) + assert 'logical_product' in text + assert 'mode 0' in text + assert 'mode 1' in text + expected = logical_product(Layout((4, 4), (1, 4)), (2, 2)) + assert str(expected) in text + + def test_explain_complement(): """explain shows complement with image and codomain.""" text = explain(complement, Layout(4, 2), 16) From 7f67459988013cda07ccf9ce08e64f4e6be5ddb1 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:40:22 -0700 Subject: [PATCH 03/16] Fix draw_slice for 1D layouts with tuple slice_spec _get_slice_highlight_mask_2d only handled tuple slice_spec for rank-2 layouts, silently returning an all-False mask for rank-0 and rank-1 layouts. Add an elif branch for r < 2 that unpacks the single-element tuple and matches against the layout shape. 
--- docs/viz_api.md | 6 ++++++ src/tensor_layouts/viz.py | 16 +++++++++++++++ tests/viz.py | 43 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/docs/viz_api.md b/docs/viz_api.md index ca13f98..15a33ec 100644 --- a/docs/viz_api.md +++ b/docs/viz_api.md @@ -257,6 +257,12 @@ layout = Layout(((3, 2), ((2, 3), 2)), ((4, 1), ((2, 15), 100))) draw_slice(layout, ((1, None), ((None, 0), None)), title="((1,:),((:,0),:))") ``` +For 1D layouts, wrap the slice in a single-element tuple: + +```python +draw_slice(Layout(8, 1), (slice(2, 5),), title="1D slice [2:5]") +``` + ![draw_slice](images/draw_slice.png) **Parameters:** diff --git a/src/tensor_layouts/viz.py b/src/tensor_layouts/viz.py index 0d93d2b..5beb83b 100644 --- a/src/tensor_layouts/viz.py +++ b/src/tensor_layouts/viz.py @@ -2775,6 +2775,22 @@ def _get_slice_highlight_mask_2d(layout, slice_spec) -> np.ndarray: if col_match: mask[i, j] = True + elif isinstance(slice_spec, tuple) and r < 2: + if len(slice_spec) != 1: + raise ValueError( + f"Rank-{r} layout requires a 1-element tuple slice_spec, " + f"got {len(slice_spec)}" + ) + (col_spec,) = slice_spec + col_flat = _is_flat_slice_component(col_spec) + for j in range(cols): + col_coord = idx2crd(j, layout.shape) + mask[0, j] = ( + _match_flat_slice_component(j, col_spec, cols) + if col_flat + else _match_nested_slice_component(col_coord, col_spec, layout.shape) + ) + return mask diff --git a/tests/viz.py b/tests/viz.py index 76a5dd3..4a73cc6 100644 --- a/tests/viz.py +++ b/tests/viz.py @@ -1015,6 +1015,49 @@ def test_slice_highlight_mask_tracks_logical_cells_not_offsets(): ] +@requires_viz +def test_slice_highlight_mask_1d_tuple_spec(): + """1D layout with tuple slice_spec should highlight the correct elements.""" + layout = Layout(8, 1) + mask = _get_slice_highlight_mask_2d(layout, (slice(2, 5),)) + assert mask.shape == (1, 8) + assert mask.tolist() == [[False, False, True, True, True, False, False, False]] + + +@requires_viz +def 
test_slice_highlight_mask_1d_tuple_spec_rank1(): + """Rank-1 layout with tuple slice_spec should highlight the correct elements.""" + layout = Layout((8,), (1,)) + mask = _get_slice_highlight_mask_2d(layout, (slice(2, 5),)) + assert mask.shape == (1, 8) + assert mask.tolist() == [[False, False, True, True, True, False, False, False]] + + +@requires_viz +def test_slice_highlight_mask_1d_tuple_int_spec(): + """1D layout with tuple (int,) slice_spec highlights a single element.""" + layout = Layout(8, 1) + mask = _get_slice_highlight_mask_2d(layout, (3,)) + assert mask.shape == (1, 8) + assert mask.tolist() == [[False, False, False, True, False, False, False, False]] + + +@requires_viz +def test_slice_highlight_mask_1d_tuple_none_spec(): + """1D layout with tuple (None,) selects all elements.""" + layout = Layout(4, 1) + mask = _get_slice_highlight_mask_2d(layout, (None,)) + assert mask.tolist() == [[True, True, True, True]] + + +@requires_viz +def test_slice_highlight_mask_1d_wrong_tuple_length_raises(): + """1D layout with 2-element tuple slice_spec raises ValueError.""" + layout = Layout(4, 1) + with pytest.raises(ValueError, match="1-element tuple"): + _get_slice_highlight_mask_2d(layout, (1, 2)) + + @requires_viz def test_compute_tv_mapping_uses_first_wins_for_duplicate_cells(): layout = Layout((2, 2), (0, 0)) From 958340ef3ca848eb7aedf8cabb8da59868dadb95 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:40:45 -0700 Subject: [PATCH 04/16] Add type validation for Layout shape and stride arguments Layout.__init__ now validates that shape and stride arguments are int or nested tuples of ints before calling normalize(). Invalid types like strings, floats, or None produce a clear TypeError naming the offending parameter (e.g. "Layout stride must be int or tuple of ints, got str"). 
--- src/tensor_layouts/layouts.py | 20 ++++++++++++++++++++ tests/layouts.py | 25 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/tensor_layouts/layouts.py b/src/tensor_layouts/layouts.py index ff821c4..2d21921 100644 --- a/src/tensor_layouts/layouts.py +++ b/src/tensor_layouts/layouts.py @@ -301,6 +301,23 @@ def normalize(x: Any) -> IntOrIntTuple: # +def _validate_shape_type(x, name: str) -> None: + """Validate that *x* is a valid shape or stride: int or nested tuple of ints. + + Raises TypeError with a clear message naming the offending parameter + (``name`` should be ``"shape"`` or ``"stride"``). + """ + if is_int(x): + return + if isinstance(x, (list, tuple)): + for elem in x: + _validate_shape_type(elem, name) + return + raise TypeError( + f"Layout {name} must be int or tuple of ints, got {type(x).__name__}" + ) + + class Layout: """A function from logical coordinates to memory offsets: offset = sum(coord_i * stride_i). @@ -355,11 +372,14 @@ def __init__(self, *args, swizzle: "Swizzle | None" = None): elif len(args) == 1: shape = args[0] + _validate_shape_type(shape, "shape") self._shape = normalize(shape) self._stride = compute_col_major_strides(self._shape) elif len(args) == 2: shape, stride = args + _validate_shape_type(shape, "shape") + _validate_shape_type(stride, "stride") self._shape = normalize(shape) self._stride = normalize(stride) diff --git a/tests/layouts.py b/tests/layouts.py index cf5a911..4244cdd 100644 --- a/tests/layouts.py +++ b/tests/layouts.py @@ -201,6 +201,31 @@ def test_layout_basic(): Layout((6, 1, 12, 2, 2), (2, 0, 12, 144, 1)) # Complex layout +def test_layout_type_validation(): + """Layout rejects invalid shape/stride types with clear messages.""" + # Strings rejected + with pytest.raises(TypeError, match="stride.*str"): + Layout((4, 2), "row") + with pytest.raises(TypeError, match="shape.*str"): + Layout("abc") + + # Floats rejected + with pytest.raises(TypeError, match="stride.*float"): + Layout((4, 
2), 1.5) + with pytest.raises(TypeError, match="shape.*float"): + Layout(3.14) + + # None rejected + with pytest.raises(TypeError, match="shape.*NoneType"): + Layout(None) + + # Valid constructions still work + Layout((4, 2), (1, 4)) + Layout(((2, 2), 4), ((1, 2), 4)) + Layout(8) + Layout([4, 2]) # lists are fine + + def test_layout_rank_size_cosize(): L5 = Layout((64, 32), (1, 128)) assert rank(L5) == 2 From aae5ee84ec9bd3d9f6bd60a7d2d1ed066e199eb7 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:43:19 -0700 Subject: [PATCH 05/16] Fix per_group analysis iteration for TV layouts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit per_group_bank_conflicts and per_group_coalescing iterated over flat indices, splitting a (32,4) TV layout into 4 groups of 32 — one per value — instead of 1 group of 32 threads with all 4 values each. Add _tv_dimensions() helper to extract (thread_count, value_count). Group by the thread dimension (mode 0) and iterate all value modes per thread using colexicographic indexing. Rank-1 layouts are unchanged. --- src/tensor_layouts/analysis.py | 57 ++++++++++++++++++++++++---------- tests/analysis.py | 18 +++++++++++ 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index d09d2d2..e04f7f5 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -356,6 +356,19 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, # Per-group analysis # ============================================================================= + +def _tv_dimensions(layout: Layout): + """Extract (thread_count, value_count) from a layout. + + For rank-1 (scalar shape) layouts: thread_count = size, value_count = 1. + For rank>1 (TV) layouts: thread_count = size(mode 0), value_count = + product of remaining modes. 
+ """ + if is_int(layout.shape): + return size(layout), 1 + return size(mode(layout, 0)), size(layout) // size(mode(layout, 0)) + + def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, num_banks: int = 32, element_bytes: int = 2, bank_width_bytes: int = 4) -> dict: @@ -364,6 +377,10 @@ def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, Splits the layout into groups of ``group_size`` threads and analyzes bank conflicts for each group independently. + For multi-mode (TV) layouts, groups are formed along mode 0 (the thread + dimension). Each thread's accesses across all value modes (mode 1+) are + included in its group's analysis. + Args: layout: Maps thread_id -> memory offset (in elements). group_size: Threads per group (32 = NVIDIA warp, 64 = AMD wave). @@ -381,8 +398,8 @@ def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, layout = as_layout(layout) if group_size <= 0: raise ValueError(f"group_size must be positive, got {group_size}") - n = size(layout) - num_groups = (n + group_size - 1) // group_size + thread_count, value_count = _tv_dimensions(layout) + num_groups = (thread_count + group_size - 1) // group_size groups = [] worst_idx = 0 @@ -390,15 +407,17 @@ def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, for g in range(num_groups): start = g * group_size - end = min(start + group_size, n) + end = min(start + group_size, thread_count) thread_banks = {} for t in range(start, end): - offset = layout(t) - byte_addr = offset * element_bytes - word_addr = byte_addr // bank_width_bytes - bank = word_addr % num_banks - thread_banks.setdefault(bank, []).append((t, word_addr)) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + byte_addr = offset * element_bytes + word_addr = byte_addr // bank_width_bytes + bank = word_addr % num_banks + thread_banks.setdefault(bank, []).append((t, word_addr)) max_ways = 1 bank_to_threads = {} @@ -436,6 +455,10 @@ def 
per_group_coalescing(layout: Layout, *, group_size: int = 32, Splits the layout into groups of ``group_size`` threads and analyzes coalescing for each group independently. + For multi-mode (TV) layouts, groups are formed along mode 0 (the thread + dimension). Each thread's accesses across all value modes (mode 1+) are + included in its group's analysis. + Args: layout: Maps thread_id -> memory offset (in elements). group_size: Threads per group (32 = NVIDIA warp, 64 = AMD wave). @@ -451,8 +474,8 @@ def per_group_coalescing(layout: Layout, *, group_size: int = 32, layout = as_layout(layout) if group_size <= 0: raise ValueError(f"group_size must be positive, got {group_size}") - n = size(layout) - num_groups = (n + group_size - 1) // group_size + thread_count, value_count = _tv_dimensions(layout) + num_groups = (thread_count + group_size - 1) // group_size groups = [] worst_idx = 0 @@ -460,16 +483,18 @@ def per_group_coalescing(layout: Layout, *, group_size: int = 32, for g in range(num_groups): start = g * group_size - end = min(start + group_size, n) + end = min(start + group_size, thread_count) cache_lines = set() unique_offsets = set() for t in range(start, end): - offset = layout(t) - unique_offsets.add(offset) - byte_addr = offset * element_bytes - cache_line = byte_addr // cache_line_bytes - cache_lines.add(cache_line) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + unique_offsets.add(offset) + byte_addr = offset * element_bytes + cache_line = byte_addr // cache_line_bytes + cache_lines.add(cache_line) transactions = len(cache_lines) useful_bytes = len(unique_offsets) * element_bytes diff --git a/tests/analysis.py b/tests/analysis.py index c46970f..f31a8f6 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -260,6 +260,14 @@ def test_per_group_bank_conflicts(): assert r_per['worst_max_ways'] == r_single['max_ways'] +def test_per_group_bank_conflicts_tv_layout(): + """TV layout groups by thread dimension, not 
flat index.""" + # 32 threads, 4 values each: should be 1 group (not 4) + tv = Layout((32, 4), (1, 32)) + result = per_group_bank_conflicts(tv, group_size=32) + assert len(result['groups']) == 1 + + def test_per_group_coalescing(): """Per-group coalescing for a uniform layout gives identical per-warp results.""" r_per = per_group_coalescing(Layout(64, 1), element_bytes=2) @@ -269,6 +277,16 @@ def test_per_group_coalescing(): assert g['transactions'] == 1 +def test_per_group_coalescing_tv_layout(): + """TV layout groups by thread dimension, not flat index.""" + # 32 threads, 4 values each (contiguous within each thread's block) + tv = Layout((32, 4), (4, 1)) + result = per_group_coalescing(tv, group_size=32) + assert len(result['groups']) == 1 + # 32 threads * 4 values = 128 elements * 2B = 256B -> 2 cache lines + assert result['groups'][0]['transactions'] == 2 + + ## cycles From e8f3b146d592680b757b1137decb47fb60977467 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:45:09 -0700 Subject: [PATCH 06/16] Add hierarchical layout support to draw_composite draw_composite and show_composite now accept flatten_hierarchical and label_hierarchy_levels parameters (both as top-level defaults and as per-panel overrides via the options dict). When flatten_hierarchical is False, hierarchical panels render with nested coordinate labels and tile boundary lines, matching draw_layout's existing behavior. 
--- docs/viz_api.md | 7 +++++ src/tensor_layouts/viz.py | 65 +++++++++++++++++++++++++++++++++------ tests/viz.py | 34 ++++++++++++++++++++ 3 files changed, 97 insertions(+), 9 deletions(-) diff --git a/docs/viz_api.md b/docs/viz_api.md index 15a33ec..d0d9e6f 100644 --- a/docs/viz_api.md +++ b/docs/viz_api.md @@ -308,6 +308,13 @@ draw_composite(panels, "comparison.png", | `panel_size` | `(w, h)` | `(4, 4)` | Size per panel | | `colorize` | `bool` | `False` | Rainbow colors | | `tv_mode` | `bool` | `False` | Use TV-layout rendering | +| `flatten_hierarchical` | `bool` | `True` | Flatten nested shapes to 2D grid | +| `label_hierarchy_levels` | `bool` | `False` | In nested hierarchical mode, annotate hierarchy levels | + +Per-panel options (`(Layout, opts_dict)` tuples) override the top-level +defaults: `colorize`, `color_layout`, `num_colors`, `tv_mode`, +`flatten_hierarchical`, `label_hierarchy_levels`, and the TV-specific +`grid_rows`, `grid_cols`, `thr_id_layout`, `col_major`. ## draw_tiled_grid diff --git a/src/tensor_layouts/viz.py b/src/tensor_layouts/viz.py index 5beb83b..dfc4797 100644 --- a/src/tensor_layouts/viz.py +++ b/src/tensor_layouts/viz.py @@ -660,6 +660,8 @@ def _build_composite_figure( panel_size: Tuple[float, float] = (4, 4), colorize: bool = False, tv_mode: bool = False, + flatten_hierarchical: bool = True, + label_hierarchy_levels: bool = False, ): """Build the composite figure used by draw_composite/show_composite.""" n = len(panels) @@ -715,6 +717,10 @@ def _build_composite_figure( panel_tv_mode = opts.get("tv_mode", tv_mode) color_layout = opts.get("color_layout", None) num_colors = opts.get("num_colors", 8) + panel_flatten = opts.get("flatten_hierarchical", flatten_hierarchical) + panel_label_levels = opts.get( + "label_hierarchy_levels", label_hierarchy_levels + ) # Get title title = titles[idx] if titles and idx < len(titles) else None @@ -733,17 +739,41 @@ def _build_composite_figure( col_major=opts.get("col_major", True), ) else: - grid 
= _prepare_offset_grid( - layout, color_layout=color_layout, eval_fn=eval_fn + # Check if this panel should use hierarchical rendering + r = rank(layout) + is_hier = r == 2 and not panel_flatten and ( + isinstance(mode(layout.shape, 0), tuple) + or isinstance(mode(layout.shape, 1), tuple) ) - _draw_grid( - ax, - grid.indices, - title=title, - colorize=panel_colorize, - color_indices=grid.color_indices, - num_colors=num_colors, + grid = _prepare_offset_grid( + layout, color_layout=color_layout, eval_fn=eval_fn, + hierarchical=is_hier, ) + if grid.is_hierarchical: + _draw_hierarchical_grid( + ax, + grid.indices, + grid.rows, + grid.cols, + cell_coords=grid.cell_coords, + row_shape=grid.row_shape, + col_shape=grid.col_shape, + title=title, + colorize=panel_colorize, + color_indices=grid.color_indices, + flatten_hierarchical=False, + label_hierarchy_levels=panel_label_levels, + num_colors=num_colors, + ) + else: + _draw_grid( + ax, + grid.indices, + title=title, + colorize=panel_colorize, + color_indices=grid.color_indices, + num_colors=num_colors, + ) # Hide unused axes for idx in range(len(panels), len(axes)): @@ -766,6 +796,8 @@ def draw_composite( panel_size: Tuple[float, float] = (4, 4), colorize: bool = False, tv_mode: bool = False, + flatten_hierarchical: bool = True, + label_hierarchy_levels: bool = False, ): """Draw multiple layouts in a single composite figure. 
@@ -782,6 +814,7 @@ def draw_composite( colorize, color_layout, num_colors -- offset-grid options tv_mode -- if True, render this panel as a TV grid grid_rows, grid_cols, thr_id_layout, col_major -- TV options + flatten_hierarchical, label_hierarchy_levels -- hierarchy options filename: Output path (.svg, .png, or .pdf) arrangement: How to arrange panels: - "horizontal": side by side (1 row) @@ -793,6 +826,10 @@ def draw_composite( panel_size: Size of each panel in inches (width, height) colorize: Default colorize setting for all panels tv_mode: If True, render panels as TV layouts with T/V labels + flatten_hierarchical: Default for all panels. If False, show explicit + nested coordinate labels for hierarchical layouts + label_hierarchy_levels: Default for all panels. If True, annotate axes + with hierarchy level labels Example: # Side-by-side comparison @@ -811,6 +848,8 @@ def draw_composite( panel_size=panel_size, colorize=colorize, tv_mode=tv_mode, + flatten_hierarchical=flatten_hierarchical, + label_hierarchy_levels=label_hierarchy_levels, ) _save_figure(fig, filename, dpi) @@ -3200,6 +3239,8 @@ def show_composite( panel_size: Tuple[float, float] = (4, 4), colorize: bool = False, tv_mode: bool = False, + flatten_hierarchical: bool = True, + label_hierarchy_levels: bool = False, ): """Display a composite figure inline (for Jupyter notebooks). @@ -3211,6 +3252,10 @@ def show_composite( panel_size: Size of each panel in inches (width, height) colorize: Default colorize setting for all panels tv_mode: If True, render panels as TV layouts with T/V labels + flatten_hierarchical: Default for all panels. If False, show explicit + nested coordinate labels for hierarchical layouts + label_hierarchy_levels: Default for all panels. 
If True, annotate axes + with hierarchy level labels Returns: matplotlib Figure @@ -3223,6 +3268,8 @@ def show_composite( panel_size=panel_size, colorize=colorize, tv_mode=tv_mode, + flatten_hierarchical=flatten_hierarchical, + label_hierarchy_levels=label_hierarchy_levels, ) diff --git a/tests/viz.py b/tests/viz.py index 4a73cc6..116101b 100644 --- a/tests/viz.py +++ b/tests/viz.py @@ -400,6 +400,40 @@ def test_draw_composite_mixed_tv_and_offset(): plt.close(fig) +@requires_viz +def test_draw_composite_hierarchical_panel(): + """Composite figure with flatten_hierarchical=False renders hierarchy lines.""" + hier = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) + flat = Layout((4, 4), (4, 1)) + panels = [ + (hier, {'flatten_hierarchical': False}), + flat, + ] + fig = show_composite(panels, titles=["Hierarchical", "Flat"]) + try: + assert isinstance(fig, matplotlib.figure.Figure) + # The hierarchical panel should have hierarchy boundary lines + hier_ax = fig.axes[0] + assert len(hier_ax.lines) > 0 + # The flat panel should have no hierarchy lines + flat_ax = fig.axes[1] + assert len(flat_ax.lines) == 0 + finally: + plt.close(fig) + + +@requires_viz +def test_draw_composite_hierarchical_top_level_default(): + """flatten_hierarchical=False as top-level default applies to all panels.""" + hier = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) + fig = show_composite([hier], flatten_hierarchical=False) + try: + ax = fig.axes[0] + assert len(ax.lines) > 0 + finally: + plt.close(fig) + + @requires_viz def test_draw_copy_layout_smoke(): src = Layout((4, 2), (2, 1)) From 1c38385f4668241ba0c81e421091700bf7592134 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:49:40 -0700 Subject: [PATCH 07/16] Warn when composite figure panels exceed grid capacity _build_composite_figure previously silently dropped panels that did not fit into the grid. Now emits a UserWarning so users know data is being omitted. 
--- src/tensor_layouts/viz.py | 9 +++++++++ tests/viz.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/src/tensor_layouts/viz.py b/src/tensor_layouts/viz.py index dfc4797..81e1db9 100644 --- a/src/tensor_layouts/viz.py +++ b/src/tensor_layouts/viz.py @@ -693,6 +693,15 @@ def _build_composite_figure( axes = [axes_array[i, j] for i in range(nrows) for j in range(ncols)] # Process each panel + if len(panels) > nrows * ncols: + import warnings + + warnings.warn( + f"{len(panels)} panels provided but grid has only " + f"{nrows * ncols} cells ({nrows}x{ncols}); " + f"extra panels will be dropped", + stacklevel=3, + ) for idx, panel in enumerate(panels): if idx >= len(axes): break diff --git a/tests/viz.py b/tests/viz.py index 116101b..e34ab7c 100644 --- a/tests/viz.py +++ b/tests/viz.py @@ -434,6 +434,15 @@ def test_draw_composite_hierarchical_top_level_default(): plt.close(fig) +@requires_viz +def test_draw_composite_warns_on_panel_truncation(): + """Warning emitted when panels exceed grid capacity.""" + panels = [Layout((2, 2), (2, 1)) for _ in range(5)] + with pytest.warns(UserWarning, match="5 panels.*4 cells"): + fig = show_composite(panels, arrangement="grid:2x2") + plt.close(fig) + + @requires_viz def test_draw_copy_layout_smoke(): src = Layout((4, 2), (2, 1)) From 01ee60308609cea9fb15c85de0b09a0d45f42a8b Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:51:36 -0700 Subject: [PATCH 08/16] Add TV-aware vectorized access modeling to analysis functions bank_conflicts, coalescing_efficiency, and segment_analysis previously treated each thread as issuing a single scalar access. For TV layouts where mode 0 is the thread dimension and mode 1+ are value dimensions, the functions now iterate all values per thread, correctly modeling vectorized loads (e.g., LDG.128, LDS.128). Rank-1 layouts are unchanged (value_count=1). 
--- docs/analysis_api.md | 23 ++++++++++++ src/tensor_layouts/analysis.py | 66 ++++++++++++++++++++++------------ tests/analysis.py | 29 +++++++++++++++ 3 files changed, 95 insertions(+), 23 deletions(-) diff --git a/docs/analysis_api.md b/docs/analysis_api.md index dafcc18..337529e 100644 --- a/docs/analysis_api.md +++ b/docs/analysis_api.md @@ -75,6 +75,18 @@ The `max_ways` value is the worst-case serialization factor: 1 means no conflicts, N means N-way serialization. Two threads accessing the *same* word get a broadcast (no conflict on NVIDIA hardware). +For multi-mode (TV) layouts where mode 0 is the thread dimension and +mode 1+ are value dimensions, all values per thread are included in the +analysis. This models vectorized loads where each thread accesses +multiple elements: + +```python +# TV layout: 32 threads, each loading 2 fp16 elements +tv = Layout((32, 2), (1, 32)) +result = bank_conflicts(tv, element_bytes=2) +result['conflict_free'] # True: values land in distinct banks +``` + Returns a dict: | Key | Type | Description | @@ -111,6 +123,17 @@ Returns a dict: | `efficiency` | float | Unique useful bytes / transferred bytes (1.0 = perfect) | | `cache_lines` | list | Sorted cache line indices touched | +For multi-mode (TV) layouts, all values per thread are included, +modeling vectorized loads: + +```python +# TV layout: 32 threads, 4 values each, contiguous within each thread +tv = Layout((32, 4), (4, 1)) +result = coalescing_efficiency(tv, element_bytes=2) +result['transactions'] # 2 (256 bytes spans 2 cache lines) +result['efficiency'] # 1.0 (256 unique bytes / 256 transferred) +``` + ## Permutation Analysis When a layout is bijective (every offset is hit exactly once), it defines diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index e04f7f5..eb897a7 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -153,8 +153,11 @@ def bank_conflicts(layout: Layout, *, num_banks: int = 32, Only the 
first ``group_size`` threads are analyzed, matching the hardware issue granularity (warp on NVIDIA, wavefront on AMD). This avoids overstating conflicts when the layout spans multiple - warps. The model assigns each access to its starting bank word; - accesses wider than one bank word are not tracked across banks. + warps. + + For multi-mode (TV) layouts, mode 0 is the thread dimension and all + remaining modes are value dimensions. Each thread's accesses across + all values are included in the analysis, modeling vectorized loads. Args: layout: Maps thread_id -> memory offset (in elements). @@ -182,18 +185,21 @@ def bank_conflicts(layout: Layout, *, num_banks: int = 32, layout = as_layout(layout) if group_size <= 0: raise ValueError(f"group_size must be positive, got {group_size}") - n = min(size(layout), group_size) + thread_count, value_count = _tv_dimensions(layout) + n = min(thread_count, group_size) # Map each thread to (bank, word_address) # A bank conflict occurs when threads access different 4-byte words in the # same bank. Two threads accessing the same word get a broadcast (no conflict). thread_banks = {} # bank -> [(thread_id, word_address), ...] for t in range(n): - offset = layout(t) - byte_addr = offset * element_bytes - word_addr = byte_addr // bank_width_bytes - bank = word_addr % num_banks - thread_banks.setdefault(bank, []).append((t, word_addr)) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + byte_addr = offset * element_bytes + word_addr = byte_addr // bank_width_bytes + bank = word_addr % num_banks + thread_banks.setdefault(bank, []).append((t, word_addr)) # Compute conflicts per bank # Two threads conflict if they hit the same bank but different addresses. @@ -239,6 +245,10 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, transaction. In the worst case, each thread triggers a separate transaction. 
+ For multi-mode (TV) layouts, mode 0 is the thread dimension and all + remaining modes are value dimensions. Each thread's accesses across + all values are included in the analysis, modeling vectorized loads. + Args: layout: Maps thread_id -> memory offset (in elements). warp_size: Threads per warp (32 on NVIDIA/AMD GPUs). @@ -263,17 +273,20 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, # {'transactions': 2, 'efficiency': 0.5, ...} """ layout = as_layout(layout) - n = min(size(layout), warp_size) + thread_count, value_count = _tv_dimensions(layout) + n = min(thread_count, warp_size) # Find which cache lines are touched and count unique offsets cache_lines = set() unique_offsets = set() for t in range(n): - offset = layout(t) - unique_offsets.add(offset) - byte_addr = offset * element_bytes - cache_line = byte_addr // cache_line_bytes - cache_lines.add(cache_line) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + unique_offsets.add(offset) + byte_addr = offset * element_bytes + cache_line = byte_addr // cache_line_bytes + cache_lines.add(cache_line) transactions = len(cache_lines) useful_bytes = len(unique_offsets) * element_bytes @@ -298,6 +311,10 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, warp access may touch fewer cache lines than segments when accesses cluster within a line but span multiple segments. + For multi-mode (TV) layouts, mode 0 is the thread dimension and all + remaining modes are value dimensions. Each thread's accesses across + all values are included in the analysis, modeling vectorized loads. + Args: layout: Maps thread_id -> memory offset (in elements). warp_size: Threads per warp. 
@@ -317,7 +334,8 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, first_alignment: alignment of first_byte_addr to segment_bytes """ layout = as_layout(layout) - n = min(size(layout), warp_size) + thread_count, value_count = _tv_dimensions(layout) + n = min(thread_count, warp_size) segments = set() lines = set() @@ -325,19 +343,21 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, first_byte = None for t in range(n): - offset = layout(t) - unique_offsets.add(offset) - byte_addr = offset * element_bytes - if first_byte is None: - first_byte = byte_addr - segments.add(byte_addr // segment_bytes) - lines.add(byte_addr // cache_line_bytes) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + unique_offsets.add(offset) + byte_addr = offset * element_bytes + if first_byte is None: + first_byte = byte_addr + segments.add(byte_addr // segment_bytes) + lines.add(byte_addr // cache_line_bytes) first_byte = first_byte if first_byte is not None else 0 n_segments = len(segments) n_lines = len(lines) unique_bytes = len(unique_offsets) * element_bytes - requested_bytes = n * element_bytes + requested_bytes = n * value_count * element_bytes transferred_bytes = n_segments * segment_bytes return { diff --git a/tests/analysis.py b/tests/analysis.py index f31a8f6..07b81de 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -172,6 +172,15 @@ def test_bank_conflicts_group_size_validation(): bank_conflicts(Layout(32, 1), group_size=-1) +def test_bank_conflicts_tv_layout(): + """TV layout analyzes all values per thread, not just value 0.""" + # 32 threads, 2 values: stride-1 threads, stride-32 values + tv = Layout((32, 2), (1, 32)) + r = bank_conflicts(tv, element_bytes=2) + assert r['conflict_free'] + assert len(r['bank_to_threads']) == 32 # all banks accessed + + ## coalescing_efficiency @@ -214,6 +223,16 @@ def test_coalescing_broadcast(): assert result['efficiency'] == pytest.approx(2.0 / 128) +def 
test_coalescing_tv_layout(): + """TV layout counts all values in cache line computation.""" + # 32 threads, 4 values each, stride-4 between threads + tv = Layout((32, 4), (4, 1)) + result = coalescing_efficiency(tv, element_bytes=2) + # 128 unique offsets * 2B = 256B -> cache lines 0, 1 + assert result['transactions'] == 2 + assert result['efficiency'] == pytest.approx(1.0) + + ## segment_analysis @@ -246,6 +265,16 @@ def test_segment_analysis_broadcast(): assert result['requested_bytes'] == 64 +def test_segment_analysis_tv_layout(): + """TV layout includes all values in segment computation.""" + tv = Layout((32, 4), (4, 1)) + result = segment_analysis(tv, element_bytes=2) + # 128 elements * 2B = 256B -> 8 segments, 2 cache lines + assert result['segments'] == 8 + assert result['cache_lines'] == 2 + assert result['requested_bytes'] == 256 # 32 * 4 * 2 + + ## per-group analysis From bea9484cd452eecc947a5846b4f7470e7d64e425 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 12:58:38 -0700 Subject: [PATCH 09/16] Make element_bytes a required parameter in analysis functions element_bytes varies per use case (fp16=2, fp32=4, fp8=1) and should be set explicitly on every call. Hardware constants like warp_size, num_banks, and cache_line_bytes rarely change and keep their defaults. Reorder parameters to put element_bytes first (no default) across bank_conflicts, coalescing_efficiency, segment_analysis, and their per_group variants. 
--- docs/analysis_api.md | 6 +++--- examples/layouts.py | 2 +- src/tensor_layouts/analysis.py | 29 +++++++++++++++-------------- tests/analysis.py | 22 +++++++++++----------- 4 files changed, 30 insertions(+), 29 deletions(-) diff --git a/docs/analysis_api.md b/docs/analysis_api.md index 337529e..9d684d2 100644 --- a/docs/analysis_api.md +++ b/docs/analysis_api.md @@ -44,7 +44,7 @@ offset_table(Layout((4, 2), (0, 1))) # 1: [(0,1), (1,1), (2,1), (3,1)]} ``` -## bank_conflicts(layout, *, num_banks=32, element_bytes=2, bank_width_bytes=4, group_size=32) +## bank_conflicts(layout, *, element_bytes, num_banks=32, bank_width_bytes=4, group_size=32) Analyze shared memory bank conflicts for a thread-to-offset layout. @@ -95,7 +95,7 @@ Returns a dict: | `max_ways` | int | Worst-case serialization factor across all banks | | `bank_to_threads` | dict | `{bank_id: [thread_ids...]}` for all accessed banks | -## coalescing_efficiency(layout, *, warp_size=32, element_bytes=2, cache_line_bytes=128) +## coalescing_efficiency(layout, *, element_bytes, warp_size=32, cache_line_bytes=128) Analyze global memory coalescing for a thread-to-offset layout. 
@@ -110,7 +110,7 @@ result['transactions'] # 1 result['efficiency'] # 1.0 (128 unique useful bytes / 128 transferred) # Worst case: each thread hits a separate cache line -result = coalescing_efficiency(Layout(32, 64)) +result = coalescing_efficiency(Layout(32, 64), element_bytes=2) result['transactions'] # 32 result['efficiency'] # 0.016 (64 unique useful bytes / 4096 transferred) ``` diff --git a/examples/layouts.py b/examples/layouts.py index 47602a7..4348e94 100644 --- a/examples/layouts.py +++ b/examples/layouts.py @@ -763,7 +763,7 @@ def example_analysis(): f"efficiency {r1['efficiency']:.0%}") scattered = Layout(32, 64) - r2 = coalescing_efficiency(scattered) + r2 = coalescing_efficiency(scattered, element_bytes=2) print(f" Stride-64 (fp16): {r2['transactions']} transactions, " f"efficiency {r2['efficiency']:.1%}") diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index eb897a7..9d01a2c 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -136,8 +136,8 @@ def footprint(layout: Layout) -> dict: # Bank conflict analysis # ============================================================================= -def bank_conflicts(layout: Layout, *, num_banks: int = 32, - element_bytes: int = 2, bank_width_bytes: int = 4, +def bank_conflicts(layout: Layout, *, element_bytes: int, + num_banks: int = 32, bank_width_bytes: int = 4, group_size: int = 32): """Analyze shared memory bank conflicts for a thread-to-offset layout. 
@@ -174,12 +174,12 @@ def bank_conflicts(layout: Layout, *, num_banks: int = 32, bank_to_threads: {bank_id: [thread_ids...]} for all accessed banks Examples: - # Linear layout: threads access consecutive elements - bank_conflicts(Layout(32, 1)) + # Linear layout: threads access consecutive fp16 elements + bank_conflicts(Layout(32, 1), element_bytes=2) # {'conflict_free': True, 'max_ways': 1, ...} # All threads hit the same address - bank_conflicts(Layout(32, 0)) + bank_conflicts(Layout(32, 0), element_bytes=2) # {'conflict_free': True, 'max_ways': 1, ...} (broadcast, not a conflict) """ layout = as_layout(layout) @@ -231,8 +231,8 @@ def bank_conflicts(layout: Layout, *, num_banks: int = 32, # Coalescing analysis # ============================================================================= -def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, - element_bytes: int = 2, +def coalescing_efficiency(layout: Layout, *, element_bytes: int, + warp_size: int = 32, cache_line_bytes: int = 128): """Analyze global memory coalescing for a thread-to-offset layout. @@ -265,7 +265,7 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, Examples: # Perfectly coalesced: 32 threads, stride 1, fp16 - coalescing_efficiency(Layout(32, 1)) + coalescing_efficiency(Layout(32, 1), element_bytes=2) # {'transactions': 1, 'efficiency': 0.5, ...} -- 64B used of 128B line # Strided access: each thread 2 elements apart, fp32 @@ -300,8 +300,8 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, } -def segment_analysis(layout: Layout, *, warp_size: int = 32, - element_bytes: int = 2, +def segment_analysis(layout: Layout, *, element_bytes: int, + warp_size: int = 32, segment_bytes: int = 32, cache_line_bytes: int = 128): """Segment- and alignment-aware global memory transaction analysis. 
@@ -389,8 +389,9 @@ def _tv_dimensions(layout: Layout): return size(mode(layout, 0)), size(layout) // size(mode(layout, 0)) -def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, - num_banks: int = 32, element_bytes: int = 2, +def per_group_bank_conflicts(layout: Layout, *, element_bytes: int, + group_size: int = 32, + num_banks: int = 32, bank_width_bytes: int = 4) -> dict: """Analyze bank conflicts per warp/wavefront group across a full layout. @@ -467,8 +468,8 @@ def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, } -def per_group_coalescing(layout: Layout, *, group_size: int = 32, - element_bytes: int = 2, +def per_group_coalescing(layout: Layout, *, element_bytes: int, + group_size: int = 32, cache_line_bytes: int = 128) -> dict: """Analyze coalescing efficiency per warp/wavefront group across a full layout. diff --git a/tests/analysis.py b/tests/analysis.py index 07b81de..90f7142 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -98,14 +98,14 @@ def test_footprint_broadcast(): def test_bank_conflicts_linear(): """Linear stride-1 access: no conflicts.""" - result = bank_conflicts(Layout(32, 1)) + result = bank_conflicts(Layout(32, 1), element_bytes=2) assert result['conflict_free'] assert result['max_ways'] == 1 def test_bank_conflicts_broadcast(): """All threads access same address: broadcast, not a conflict.""" - result = bank_conflicts(Layout(32, 0)) + result = bank_conflicts(Layout(32, 0), element_bytes=2) assert result['conflict_free'] @@ -167,9 +167,9 @@ def test_bank_conflicts_group_size(): def test_bank_conflicts_group_size_validation(): """group_size <= 0 must raise ValueError.""" with pytest.raises(ValueError, match="group_size must be positive"): - bank_conflicts(Layout(32, 1), group_size=0) + bank_conflicts(Layout(32, 1), element_bytes=2, group_size=0) with pytest.raises(ValueError, match="group_size must be positive"): - bank_conflicts(Layout(32, 1), group_size=-1) + bank_conflicts(Layout(32, 1), 
element_bytes=2, group_size=-1) def test_bank_conflicts_tv_layout(): @@ -186,7 +186,7 @@ def test_bank_conflicts_tv_layout(): def test_coalescing_contiguous_fp16(): """32 threads, stride 1, fp16: one cache line (64B of 128B).""" - result = coalescing_efficiency(Layout(32, 1)) + result = coalescing_efficiency(Layout(32, 1), element_bytes=2) assert result['transactions'] == 1 assert result['efficiency'] == pytest.approx(0.5) @@ -200,7 +200,7 @@ def test_coalescing_contiguous_fp32(): def test_coalescing_strided(): """Stride-2 access doubles the cache lines touched.""" - result = coalescing_efficiency(Layout(32, 2)) + result = coalescing_efficiency(Layout(32, 2), element_bytes=2) assert result['transactions'] == 1 # 32*2*2=128 bytes, still fits in 1 line # Actually: offsets 0,2,4,...,62. byte addrs 0,4,8,...,124. All in line 0. assert result['efficiency'] == pytest.approx(0.5) @@ -209,7 +209,7 @@ def test_coalescing_strided(): def test_coalescing_large_stride(): """Large stride: each thread touches a different cache line.""" # stride 64 elements * 2 bytes = 128 bytes = 1 cache line apart - result = coalescing_efficiency(Layout(32, 64)) + result = coalescing_efficiency(Layout(32, 64), element_bytes=2) assert result['transactions'] == 32 # 32 threads * 2 bytes = 64 useful bytes, 32 * 128 = 4096 transferred assert result['efficiency'] == pytest.approx(64.0 / (32 * 128)) @@ -217,7 +217,7 @@ def test_coalescing_large_stride(): def test_coalescing_broadcast(): """All threads access same element: single transaction, minimal useful bytes.""" - result = coalescing_efficiency(Layout(32, 0)) + result = coalescing_efficiency(Layout(32, 0), element_bytes=2) assert result['transactions'] == 1 # Only 1 unique offset: 1 * 2 bytes useful out of 128 transferred assert result['efficiency'] == pytest.approx(2.0 / 128) @@ -238,7 +238,7 @@ def test_coalescing_tv_layout(): def test_segment_analysis_contiguous_fp16(): """32 threads, stride 1, fp16: 2 segments, 1 cache line.""" - result = 
segment_analysis(Layout(32, 1)) + result = segment_analysis(Layout(32, 1), element_bytes=2) # 32 * 2B = 64B -> 2 segments of 32B, 1 cache line of 128B assert result['segments'] == 2 assert result['cache_lines'] == 1 @@ -293,7 +293,7 @@ def test_per_group_bank_conflicts_tv_layout(): """TV layout groups by thread dimension, not flat index.""" # 32 threads, 4 values each: should be 1 group (not 4) tv = Layout((32, 4), (1, 32)) - result = per_group_bank_conflicts(tv, group_size=32) + result = per_group_bank_conflicts(tv, element_bytes=2, group_size=32) assert len(result['groups']) == 1 @@ -310,7 +310,7 @@ def test_per_group_coalescing_tv_layout(): """TV layout groups by thread dimension, not flat index.""" # 32 threads, 4 values each (contiguous within each thread's block) tv = Layout((32, 4), (4, 1)) - result = per_group_coalescing(tv, group_size=32) + result = per_group_coalescing(tv, element_bytes=2, group_size=32) assert len(result['groups']) == 1 # 32 threads * 4 values = 128 elements * 2B = 256B -> 2 cache lines assert result['groups'][0]['transactions'] == 2 From b5475cacbd371a324f7af51f5202e2dc960169b3 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 13:35:11 -0700 Subject: [PATCH 10/16] Add identity short-circuit to Layout.__eq__ and Swizzle.__eq__ Add `if self is other: return True` as the first check in both Layout.__eq__ and Swizzle.__eq__. This is a standard Python best practice that avoids redundant field-by-field comparison when testing an object against itself. Addresses REVIEW_ANALYSIS.md Section 5 (Equality Short-Circuiting). 
--- src/tensor_layouts/layouts.py | 4 +++ tests/layouts.py | 46 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/src/tensor_layouts/layouts.py b/src/tensor_layouts/layouts.py index 2d21921..41b623b 100644 --- a/src/tensor_layouts/layouts.py +++ b/src/tensor_layouts/layouts.py @@ -394,6 +394,8 @@ def __init__(self, *args, swizzle: "Swizzle | None" = None): ) def __eq__(self, other): + if self is other: + return True if not isinstance(other, Layout): return False return ( @@ -3259,6 +3261,8 @@ def __repr__(self) -> str: return f"Swizzle({self.bits}, {self.base}, {self.shift})" def __eq__(self, other: object) -> bool: + if self is other: + return True if not isinstance(other, Swizzle): return False return self.bits == other.bits and self.base == other.base and self.shift == other.shift diff --git a/tests/layouts.py b/tests/layouts.py index 4244cdd..45e432f 100644 --- a/tests/layouts.py +++ b/tests/layouts.py @@ -1110,6 +1110,52 @@ def test_layout_hash(): assert len(s) == 2 +## Layout.__eq__ identity short-circuit + + +def test_layout_eq_identity_shortcircuit(): + """Same object identity returns True immediately.""" + L = Layout((4, 8), (1, 4)) + assert L == L + assert L is L + + sw = Swizzle(3, 0, 3) + L_sw = compose(sw, L) + assert L_sw == L_sw + + +def test_layout_eq_structural(): + """Distinct objects with equal shape/stride/swizzle are equal.""" + L1 = Layout((4, 8), (1, 4)) + L2 = Layout((4, 8), (1, 4)) + assert L1 is not L2 + assert L1 == L2 + + +def test_layout_eq_non_layout(): + """Comparing Layout with non-Layout returns False, not an error.""" + L = Layout((4, 8), (1, 4)) + assert L != 42 + assert L != "not a layout" + assert L != (4, 8) + assert L != None # noqa: E711 + + +def test_swizzle_eq_identity_shortcircuit(): + """Same Swizzle identity returns True immediately.""" + sw = Swizzle(3, 0, 3) + assert sw == sw + assert sw is sw + + +def test_swizzle_eq_structural(): + """Distinct Swizzle objects with equal fields are 
equal.""" + sw1 = Swizzle(3, 0, 3) + sw2 = Swizzle(3, 0, 3) + assert sw1 is not sw2 + assert sw1 == sw2 + + ## compose() functional property From bff02458e232c82a45c2e73bede0585ecd2447a0 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 13:37:08 -0700 Subject: [PATCH 11/16] Make Layout.__repr__ return eval-safe constructor string Split the string representation into two methods following Python conventions: - __repr__ now returns an eval-safe constructor string such as Layout((4, 2), (1, 4)) or Layout((8, 8), (8, 1), swizzle=Swizzle(3, 0, 3)). This satisfies the Python data model guideline that repr should, where feasible, return a string that can recreate the object via eval(). - __str__ retains the human-readable CuTe notation (4, 2) : (1, 4) used in print() and casual display. Addresses REVIEW_ANALYSIS.md Section 5 (String Representations). --- src/tensor_layouts/layouts.py | 17 +++++++-- tests/layouts.py | 69 ++++++++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/src/tensor_layouts/layouts.py b/src/tensor_layouts/layouts.py index 41b623b..dbbb106 100644 --- a/src/tensor_layouts/layouts.py +++ b/src/tensor_layouts/layouts.py @@ -411,15 +411,24 @@ def __hash__(self): return hash((self.shape, self.stride, swizzle_hash)) def __repr__(self): + """Return an eval-safe constructor string: Layout((4, 2), (1, 4)).""" + if self._swizzle is not None: + return ( + f"Layout({self._shape!r}, {self._stride!r}, " + f"swizzle={self._swizzle!r})" + ) + return f"Layout({self._shape!r}, {self._stride!r})" + + def __str__(self): + """Return human-readable CuTe notation: (4, 2) : (1, 4).""" def fmt(x): - """Format shape/stride: int as-is, tuple with parens.""" if isinstance(x, int): return str(x) return repr(x) - base_repr = f"{fmt(self._shape)} : {fmt(self._stride)}" + base = f"{fmt(self._shape)} : {fmt(self._stride)}" if self._swizzle is not None: - return f"({self._swizzle}) o ({base_repr})" - return base_repr + 
return f"({self._swizzle}) o ({base})" + return base @property def shape(self) -> IntOrIntTuple: diff --git a/tests/layouts.py b/tests/layouts.py index 45e432f..0ee5e8d 100644 --- a/tests/layouts.py +++ b/tests/layouts.py @@ -1087,7 +1087,74 @@ def test_safe_div(): def test_tile_repr(): tiler = Tile(Layout(3, 4), Layout(8, 2)) r = repr(tiler) - assert r == "Tile(3 : 4, 8 : 2)" + assert r == "Tile(Layout(3, 4), Layout(8, 2))" + + +## Layout.__repr__ and __str__ + + +def test_layout_repr_scalar(): + """repr() of a 1D layout returns an eval-safe constructor string.""" + L = Layout(8, 2) + assert repr(L) == "Layout(8, 2)" + + +def test_layout_repr_tuple(): + """repr() of a multi-dimensional layout returns an eval-safe constructor string.""" + L = Layout((4, 8), (1, 4)) + assert repr(L) == "Layout((4, 8), (1, 4))" + + +def test_layout_repr_hierarchical(): + """repr() of a hierarchical layout returns an eval-safe constructor string.""" + L = Layout(((2, 3), (2, 4)), ((1, 6), (2, 12))) + assert repr(L) == "Layout(((2, 3), (2, 4)), ((1, 6), (2, 12)))" + + +def test_layout_repr_swizzled(): + """repr() of a swizzled layout includes the swizzle keyword argument.""" + sw = Swizzle(3, 0, 3) + L = compose(sw, Layout((8, 8), (8, 1))) + r = repr(L) + assert r == "Layout((8, 8), (8, 1), swizzle=Swizzle(3, 0, 3))" + + +def test_layout_repr_eval_roundtrip(): + """eval(repr(L)) reconstructs an equal Layout (the gold standard for repr).""" + cases = [ + Layout(8, 2), + Layout((4, 8), (1, 4)), + Layout((4, 8), (0, 1)), + Layout(((2, 3), (2, 4)), ((1, 6), (2, 12))), + ] + for L in cases: + reconstructed = eval(repr(L)) # noqa: S307 + assert reconstructed == L, f"Roundtrip failed for {repr(L)}" + + +def test_layout_repr_eval_roundtrip_swizzled(): + """eval(repr(L)) works for swizzled layouts too.""" + L = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) + reconstructed = eval(repr(L)) # noqa: S307 + assert reconstructed == L + + +def test_layout_str_scalar(): + """str() returns the 
human-readable CuTe notation.""" + L = Layout(8, 2) + assert str(L) == "8 : 2" + + +def test_layout_str_tuple(): + """str() returns the human-readable CuTe notation for multi-dim layouts.""" + L = Layout((4, 8), (1, 4)) + assert str(L) == "(4, 8) : (1, 4)" + + +def test_layout_str_swizzled(): + """str() returns the CuTe composition notation for swizzled layouts.""" + L = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) + assert str(L) == "(Swizzle(3, 0, 3)) o ((8, 8) : (8, 1))" ## Layout.__hash__ From ced7f9aade59d5cfc3451083a474110a6e76103a Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 13:42:18 -0700 Subject: [PATCH 12/16] Add Intel AMX tile matrix multiply atom definitions Introduces atoms_amx.py with MMAAtom definitions for Intel AMX instructions (tdpbf16ps, tdpfp16ps, tdpbssd, tdpbsud, tdpbusd, tdpbuud). AMX is a true tile matrix multiply (16x16 output) executed by a single CPU core (T=1), making it the cleanest CPU-to-MMAAtom mapping in the layout algebra framework. --- src/tensor_layouts/atoms_amx.py | 152 ++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 src/tensor_layouts/atoms_amx.py diff --git a/src/tensor_layouts/atoms_amx.py b/src/tensor_layouts/atoms_amx.py new file mode 100644 index 0000000..c7b05b4 --- /dev/null +++ b/src/tensor_layouts/atoms_amx.py @@ -0,0 +1,152 @@ +# MIT License +# +# Copyright (c) 2026 Meta Platforms, Inc. and affiliates. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Intel AMX (Advanced Matrix Extensions) tile atom definitions. + +Maps AMX tile matrix-multiply instructions to the (Thread, Value) -> +element-offset framework used by CuTe-style layout algebra. + +Conceptual mapping +================== + +AMX is a true tile matrix multiply executed by a single CPU core. The +instruction operates on eight 1 KiB tile registers (tmm0-tmm7), each holding +up to 16 rows x 64 bytes. A single ``tdp*`` instruction computes the +full C[M,N] += A[M,K] * B[K,N] tile with no thread cooperation, so: + + T = 1 (one CPU core) + V = M*K, N*K, or M*N (all elements in the Value dimension) + +This is the cleanest CPU -> MMAAtom mapping because AMX is a genuine +matrix multiply, unlike AVX512 VNNI which is a batched dot-product. + +Tile dimensions +=============== + +All AMX tile-multiply instructions use M=16, N=16 output tiles. 
+The K dimension depends on the element type: + + BF16 (tdpbf16ps): K=32 (2 BF16 per 32-bit pair, 64-byte rows) + FP16 (tdpfp16ps): K=32 (2 FP16 per 32-bit pair, 64-byte rows) + INT8 (tdpbssd): K=64 (4 INT8 per 32-bit group, 64-byte rows) + +B-tile storage +-------------- + +The B tile register uses "VNNI format" (pairs/quads of K-elements packed +into 32-bit groups along the row), but the LOGICAL layout is still K x N. +Our atom layouts describe the logical element ordering, not the physical +register packing. + +References +========== + +- Intel Architecture Instruction Set Extensions Programming Reference + (ISE), Chapter 3 -- AMX (TILECFG, TILELOADD, TDPBF16PS, TDPBSSD, etc.) +- Intel 64 and IA-32 Architectures Software Developer's Manual (SDM), + Volume 2 -- Instruction Set Reference + +Usage:: + + from tensor_layouts.atoms_amx import AMX_16x16x32_F32BF16BF16F32 + atom = AMX_16x16x32_F32BF16BF16F32 + print(atom.shape_mnk) # (16, 16, 32) + print(atom.c_layout) # (1, (16, 16)) : (0, (1, 16)) +""" + +from .layouts import Layout +from .atoms import MMAAtom + + +# ============================================================================= +# Intel AMX tile matrix multiply atoms — T=1 (single CPU core) +# Source: Intel ISE Chapter 3, AMX instructions +# +# tdpbf16ps tmm, tmm, tmm -- C[16,16] FP32 += A[16,32] BF16 * B[32,16] BF16 +# tdpfp16ps tmm, tmm, tmm -- C[16,16] FP32 += A[16,32] FP16 * B[32,16] FP16 +# tdpbssd tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] INT8 * B[64,16] INT8 +# tdpbsud tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] INT8 * B[64,16] UINT8 +# tdpbusd tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] UINT8 * B[64,16] INT8 +# tdpbuud tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] UINT8 * B[64,16] UINT8 +# +# All produce a 16x16 output tile. K varies by datatype. 
+# ============================================================================= + +# -- BF16 -> FP32 ------------------------------------------------------------- +AMX_16x16x32_F32BF16BF16F32 = MMAAtom( + name="AMX_16x16x32_F32BF16BF16F32", + ptx="tdpbf16ps", + shape_mnk=(16, 16, 32), thr_id=Layout(1), + # (T=1, V=512) -> col-major offset in (M=16, K=32) + a_layout=Layout((1, (16, 32)), (0, (1, 16))), + # (T=1, V=512) -> col-major offset in (N=16, K=32) + b_layout=Layout((1, (16, 32)), (0, (1, 16))), + # (T=1, V=256) -> col-major offset in (M=16, N=16) + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- FP16 -> FP32 ------------------------------------------------------------- +AMX_16x16x32_F32F16F16F32 = MMAAtom( + name="AMX_16x16x32_F32F16F16F32", + ptx="tdpfp16ps", + shape_mnk=(16, 16, 32), thr_id=Layout(1), + a_layout=Layout((1, (16, 32)), (0, (1, 16))), + b_layout=Layout((1, (16, 32)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- INT8 x INT8 -> INT32 (signed x signed) ----------------------------------- +AMX_16x16x64_S32S8S8S32 = MMAAtom( + name="AMX_16x16x64_S32S8S8S32", + ptx="tdpbssd", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + # (T=1, V=1024) -> col-major offset in (M=16, K=64) + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + # (T=1, V=1024) -> col-major offset in (N=16, K=64) + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + # (T=1, V=256) -> col-major offset in (M=16, N=16) + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- INT8 x UINT8 -> INT32 (signed x unsigned) -------------------------------- +AMX_16x16x64_S32S8U8S32 = MMAAtom( + name="AMX_16x16x64_S32S8U8S32", + ptx="tdpbsud", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- UINT8 x INT8 -> INT32 (unsigned x signed) -------------------------------- +AMX_16x16x64_S32U8S8S32 = MMAAtom( + 
name="AMX_16x16x64_S32U8S8S32", + ptx="tdpbusd", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- UINT8 x UINT8 -> INT32 (unsigned x unsigned) ----------------------------- +AMX_16x16x64_S32U8U8S32 = MMAAtom( + name="AMX_16x16x64_S32U8U8S32", + ptx="tdpbuud", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) From 676d683935ce273f9bbad94e870f29530dd680c6 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 13:52:39 -0700 Subject: [PATCH 13/16] Add __str__ to MMAAtom and CopyAtom for concise display The dataclass-generated __repr__ includes all fields (layouts, PTX strings) producing 300+ character lines that are hard to scan in REPL sessions and logs. Add a short __str__ that shows just the atom name and shape: str(atom) -> MMAAtom('SM80_16x8x16_F32F16F16F32_TN', 16x8x16) str(copy) -> CopyAtom('SM75_U32x4_LDSM_N') The verbose eval-safe __repr__ from @dataclass is unchanged. Addresses REVIEW_ANALYSIS.md Section 5 (String Representations). 
--- src/tensor_layouts/atoms.py | 7 +++++ tests/analysis.py | 52 +++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/src/tensor_layouts/atoms.py b/src/tensor_layouts/atoms.py index 2280ef3..729b634 100644 --- a/src/tensor_layouts/atoms.py +++ b/src/tensor_layouts/atoms.py @@ -53,6 +53,10 @@ class MMAAtom: b_layout: Layout c_layout: Layout + def __str__(self) -> str: + m, n, k = self.shape_mnk + return f"MMAAtom('{self.name}', {m}x{n}x{k})" + @dataclass(frozen=True) class CopyAtom: @@ -72,3 +76,6 @@ class CopyAtom: thr_id: Layout src_layout_bits: Layout dst_layout_bits: Layout + + def __str__(self) -> str: + return f"CopyAtom('{self.name}')" diff --git a/tests/analysis.py b/tests/analysis.py index 90f7142..bdd5618 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -695,6 +695,58 @@ def test_explain_flat_divide(): assert '(tile0, tile1, ..., rest0, rest1, ...)' in text +## MMAAtom and CopyAtom __str__ + + +def test_mma_atom_str(): + """MMAAtom.__str__ returns a concise summary with name and shape.""" + from tensor_layouts.atoms import MMAAtom + + atom = MMAAtom( + name="test_16x8x4", + ptx="test.op", + shape_mnk=(16, 8, 4), + thr_id=Layout(32), + a_layout=Layout((32, 4), (4, 1)), + b_layout=Layout((32, 2), (2, 1)), + c_layout=Layout((32, 4), (4, 1)), + ) + assert str(atom) == "MMAAtom('test_16x8x4', 16x8x4)" + + +def test_copy_atom_str(): + """CopyAtom.__str__ returns a concise summary with name.""" + from tensor_layouts.atoms import CopyAtom + + atom = CopyAtom( + name="test_copy_128b", + ptx="test.copy", + thr_id=Layout(32), + src_layout_bits=Layout((32, 128), (128, 1)), + dst_layout_bits=Layout((32, 128), (128, 1)), + ) + assert str(atom) == "CopyAtom('test_copy_128b')" + + +def test_mma_atom_repr_is_verbose(): + """MMAAtom.__repr__ (dataclass-generated) includes all fields.""" + from tensor_layouts.atoms import MMAAtom + + atom = MMAAtom( + name="test_2x2x1", + ptx="test", + shape_mnk=(2, 2, 1), + thr_id=None, + 
a_layout=Layout(2, 1), + b_layout=Layout(2, 1), + c_layout=Layout((2, 2), (1, 2)), + ) + r = repr(atom) + assert r.startswith("MMAAtom(name=") + assert "shape_mnk=(2, 2, 1)" in r + assert "a_layout=Layout(2, 1)" in r + + if __name__ == "__main__": import subprocess import sys From e37faabba91dec0d0d6ded9e3d896abee8615914 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 14:03:58 -0700 Subject: [PATCH 14/16] Fix idx2crd to correctly wrap coordinates for scalar shapes Ensure idx2crd matches NVIDIA CuTe behavior where strictly scalar shapes always modulo-wrap the coordinate. Added an oracle differential test validating our idx2crd implementation against NVIDIA's authoritative pycute implementation across a range of shapes and indices. --- src/tensor_layouts/layouts.py | 2 ++ tests/oracle_nv.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/tensor_layouts/layouts.py b/src/tensor_layouts/layouts.py index dbbb106..9579430 100644 --- a/src/tensor_layouts/layouts.py +++ b/src/tensor_layouts/layouts.py @@ -1793,6 +1793,8 @@ def idx2crd(coord: Any, shape: Any) -> Any: """Convert index into a hierarchical coordinate.""" if isinstance(shape, int): + if isinstance(coord, int): + return coord % shape return coord # Case: Input is a single integer index for this entire sub-hierarchy diff --git a/tests/oracle_nv.py b/tests/oracle_nv.py index 0a8ce32..9f1e49e 100644 --- a/tests/oracle_nv.py +++ b/tests/oracle_nv.py @@ -1492,6 +1492,21 @@ def test_product_each_matches_pycute_size(): assert result == expected, f"product_each({shape}) = {result} != {expected}" + +@pytest.mark.skipif(pycute is None, reason="pycute not installed") +def test_oracle_idx2crd(): + shapes = [ + 4, + (4, 2), + (2, (2, 2)), + ((2, 2), (2, 2)), + ] + indices = [0, 1, 3, 5, 10, 16] + for s in shapes: + for idx in indices: + assert idx2crd(idx, s) == pycute.idx2crd(idx, s) + + if __name__ == "__main__": import traceback test_funcs = [v for k, v in 
sorted(globals().items()) if k.startswith("test_") and callable(v)] From 092f92a86660dc6ac4559ae698eb4d5821450ea3 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 17:38:29 -0700 Subject: [PATCH 15/16] Add missing copyright and license blocks to documentation files --- CONTRIBUTING.md | 2 +- docs/analysis_api.md | 24 ++++++++++++++++++++++++ docs/generate_figures.py | 22 ++++++++++++++++++++++ docs/layout_api.md | 24 ++++++++++++++++++++++++ docs/viz_api.md | 24 ++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 95b2394..d25e08f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,4 +28,4 @@ outlined on that page and do not file a public issue. ## License By contributing to tensor-layouts, you agree that your contributions will be licensed -under the LICENSE file in the root directory of this source tree. \ No newline at end of file +under the LICENSE file in the root directory of this source tree. diff --git a/docs/analysis_api.md b/docs/analysis_api.md index 9d684d2..496118e 100644 --- a/docs/analysis_api.md +++ b/docs/analysis_api.md @@ -1,3 +1,27 @@ + + # Analysis API GPU kernel performance lives or dies by memory access patterns. Two diff --git a/docs/generate_figures.py b/docs/generate_figures.py index 753aa5e..9e0402c 100644 --- a/docs/generate_figures.py +++ b/docs/generate_figures.py @@ -1,3 +1,25 @@ +# MIT License +# +# Copyright (c) 2026 Meta Platforms, Inc. and affiliates. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + #!/usr/bin/env python3 """Regenerate all PNG figures used in the documentation. 
diff --git a/docs/layout_api.md b/docs/layout_api.md index ff67bf6..0857981 100644 --- a/docs/layout_api.md +++ b/docs/layout_api.md @@ -1,3 +1,27 @@ + + # Layout Algebra API This document covers the core `tensor_layouts` API: constructing layouts, diff --git a/docs/viz_api.md b/docs/viz_api.md index d0d9e6f..78dc71e 100644 --- a/docs/viz_api.md +++ b/docs/viz_api.md @@ -1,3 +1,27 @@ + + # Visualization API This document covers the `tensor_layouts.viz` module for drawing layouts, From a872c2fca601f70430ea05162a883c10456ca3e2 Mon Sep 17 00:00:00 2001 From: Jean-Luc Duprat Date: Tue, 24 Mar 2026 18:22:48 -0700 Subject: [PATCH 16/16] [NFC] Project is now lint clean --- .gitignore | 1 + examples/layouts.py | 74 ++-- examples/viz.py | 676 ++++++++++++++++++---------- pyproject.toml | 2 +- src/tensor_layouts/__init__.py | 2 +- src/tensor_layouts/analysis.py | 415 +++++++++-------- src/tensor_layouts/atoms.py | 2 + src/tensor_layouts/atoms_amd.py | 577 +++++++++++++++++------- src/tensor_layouts/atoms_amx.py | 36 +- src/tensor_layouts/atoms_nv.py | 689 ++++++++++++++++++++--------- src/tensor_layouts/layout_utils.py | 23 +- src/tensor_layouts/layouts.py | 249 ++++++++--- src/tensor_layouts/tensor.py | 19 +- src/tensor_layouts/viz.py | 109 ++--- tests/analysis.py | 267 +++++------ tests/external.py | 196 ++++---- tests/layouts.py | 92 ++-- tests/oracle_amd.py | 324 +++++++++----- tests/oracle_nv.py | 283 +++++++----- tests/tensor.py | 24 +- tests/viz.py | 94 ++-- 21 files changed, 2604 insertions(+), 1550 deletions(-) diff --git a/.gitignore b/.gitignore index 755d4e3..c8787b7 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ htmlcov/ .idea/ *.swp *.swo +.arclint # OS .DS_Store diff --git a/examples/layouts.py b/examples/layouts.py index 4348e94..c04f864 100644 --- a/examples/layouts.py +++ b/examples/layouts.py @@ -35,14 +35,15 @@ viz.ipynb — Jupyter notebook gallery """ -from tensor_layouts import * -from tensor_layouts.analysis import * +from tensor_layouts 
import * # noqa: F401,F403,F405 +from tensor_layouts.analysis import * # noqa: F401,F403,F405 # ============================================================================= # Section 1: Layout Construction # ============================================================================= + def example_construction(): """Building layouts from shape and stride. @@ -91,6 +92,7 @@ def example_construction(): # Section 2: Querying Layouts # ============================================================================= + def example_querying(): """Query functions for shape, size, rank, and depth. @@ -106,10 +108,10 @@ def example_querying(): print(f" Layout: {layout}") print(f" shape: {layout.shape}") print(f" stride: {layout.stride}") - print(f" size: {size(layout)}") # Total number of elements - print(f" cosize: {cosize(layout)}") # Span: max offset + 1 - print(f" rank: {rank(layout)}") # Number of top-level modes - print(f" depth: {depth(layout)}") # Maximum nesting depth + print(f" size: {size(layout)}") # Total number of elements + print(f" cosize: {cosize(layout)}") # Span: max offset + 1 + print(f" rank: {rank(layout)}") # Number of top-level modes + print(f" depth: {depth(layout)}") # Maximum nesting depth # mode() extracts a single mode as a Layout print(f" mode 0: {mode(layout, 0)}") @@ -125,6 +127,7 @@ def example_querying(): # Section 3: Coordinate Mapping # ============================================================================= + def example_coordinate_mapping(): """Calling a layout to map coordinates to memory offsets. 
@@ -139,17 +142,17 @@ def example_coordinate_mapping(): print(f" Layout: {layout}") # Multi-dimensional coordinates - print(f" (0, 0) -> {layout(0, 0)}") # 0 - print(f" (2, 3) -> {layout(2, 3)}") # 2 + 12 = 14 - print(f" (3, 7) -> {layout(3, 7)}") # 3 + 28 = 31 + print(f" (0, 0) -> {layout(0, 0)}") # 0 + print(f" (2, 3) -> {layout(2, 3)}") # 2 + 12 = 14 + print(f" (3, 7) -> {layout(3, 7)}") # 3 + 28 = 31 # Flat index: column-major traversal of the domain - print(f" flat 0 -> {layout(0)}") # Same as (0,0) -> 0 - print(f" flat 5 -> {layout(5)}") # Same as (1,1) -> 5 - print(f" flat 31 -> {layout(31)}") # Same as (3,7) -> 31 + print(f" flat 0 -> {layout(0)}") # Same as (0,0) -> 0 + print(f" flat 5 -> {layout(5)}") # Same as (1,1) -> 5 + print(f" flat 31 -> {layout(31)}") # Same as (3,7) -> 31 # idx2crd: convert flat index to multi-dimensional coordinate - print(f"\n idx2crd(5, (4, 8)): {idx2crd(5, (4, 8))}") # (1, 1) + print(f"\n idx2crd(5, (4, 8)): {idx2crd(5, (4, 8))}") # (1, 1) print(f" idx2crd(14, (4, 8)): {idx2crd(14, (4, 8))}") # (2, 3) # crd2flat: convert coordinate to flat index @@ -160,6 +163,7 @@ def example_coordinate_mapping(): # Section 4: Tuple Arithmetic # ============================================================================= + def example_tuple_arithmetic(): """Arithmetic on nested tuples — the foundation of layout algebra. 
@@ -174,12 +178,12 @@ def example_tuple_arithmetic(): # column-major strides from a shape shape = (2, 3, 4) pp = prefix_product(shape) - print(f" prefix_product({shape}): {pp}") # (1, 2, 6) + print(f" prefix_product({shape}): {pp}") # (1, 2, 6) # This is exactly the column-major stride for shape (2,3,4) # suffix_product: running product from the right sp = suffix_product(shape) - print(f" suffix_product({shape}): {sp}") # (12, 4, 1) + print(f" suffix_product({shape}): {sp}") # (12, 4, 1) # This is exactly the row-major stride for shape (2,3,4) # inner_product: sum of element-wise products @@ -200,6 +204,7 @@ def example_tuple_arithmetic(): # Section 5: Layout Manipulation # ============================================================================= + def example_manipulation(): """Reshape and reorganize layouts without changing the mapping. @@ -244,6 +249,7 @@ def example_manipulation(): # Section 6: Composition # ============================================================================= + def example_composition(): """compose(A, B) — function composition: C(i) = A(B(i)). @@ -289,6 +295,7 @@ def example_composition(): # Section 7: Complement # ============================================================================= + def example_complement(): """complement(L) — the layout that fills in L's gaps. @@ -333,6 +340,7 @@ def example_complement(): # Section 8: Division # ============================================================================= + def example_division(): """logical_divide — split a layout into (tile, rest). @@ -382,6 +390,7 @@ def example_division(): # Section 9: Product # ============================================================================= + def example_product(): """logical_product — replicate A's pattern across B's domain. 
@@ -422,6 +431,7 @@ def example_product(): # Section 10: Inverse # ============================================================================= + def example_inverse(): """right_inverse, left_inverse — undo a layout's mapping. @@ -461,6 +471,7 @@ def example_inverse(): # Section 11: Swizzle # ============================================================================= + def example_swizzle(): """Swizzle(bits, base, shift) — XOR-based bank conflict avoidance. @@ -504,6 +515,7 @@ def example_swizzle(): # Section 12: Tensor # ============================================================================= + def example_tensor(): """Tensor — a Layout combined with a base offset. @@ -548,6 +560,7 @@ def example_tensor(): # Section 13: Tile # ============================================================================= + def example_tile(): """Tile — a tuple of Layouts for mode-by-mode composition. @@ -586,6 +599,7 @@ def example_tile(): # Section 14: Iteration # ============================================================================= + def example_iteration(): """Iterating over layouts. @@ -627,6 +641,7 @@ def example_iteration(): # Section 15: Image and Injectivity # ============================================================================= + def example_image_injectivity(): """Analyzing a layout as a function. @@ -668,6 +683,7 @@ def example_image_injectivity(): # Section 16: Functional Equivalence # ============================================================================= + def example_functional_equivalence(): """Checking if two layouts compute the same mapping. @@ -704,6 +720,7 @@ def example_functional_equivalence(): # Section 17: GPU Analysis # ============================================================================= + def example_analysis(): """Analyzing layouts for GPU performance. 
@@ -735,7 +752,9 @@ def example_analysis(): result = bank_conflicts(col_access, element_bytes=4) print(f" Column access (stride 8): {col_access}") print(f" conflict_free: {result['conflict_free']}") - print(f" max_ways: {result['max_ways']} (threads serialize {result['max_ways']}x)") + print( + f" max_ways: {result['max_ways']} (threads serialize {result['max_ways']}x)" + ) # --- Swizzle fix --- sw_tile = compose(Swizzle(3, 0, 3), tile) @@ -759,13 +778,17 @@ def example_analysis(): print(" " + "-" * 40) coalesced = Layout(32, 1) r1 = coalescing_efficiency(coalesced, element_bytes=4) - print(f" Stride-1 (fp32): {r1['transactions']} transaction, " - f"efficiency {r1['efficiency']:.0%}") + print( + f" Stride-1 (fp32): {r1['transactions']} transaction, " + f"efficiency {r1['efficiency']:.0%}" + ) scattered = Layout(32, 64) r2 = coalescing_efficiency(scattered, element_bytes=2) - print(f" Stride-64 (fp16): {r2['transactions']} transactions, " - f"efficiency {r2['efficiency']:.1%}") + print( + f" Stride-64 (fp16): {r2['transactions']} transactions, " + f"efficiency {r2['efficiency']:.1%}" + ) # --- Permutation structure --- print(f"\n Permutation Structure") @@ -792,11 +815,11 @@ def example_analysis(): print(" " + "-" * 40) layouts = [ - ("Contiguous 1D", Layout(8, 1)), - ("Strided 1D", Layout(8, 2)), - ("Col-major 4x8", Layout((4, 8), (1, 4))), - ("Gapped 4x8", Layout((4, 8), (1, 8))), - ("Row-major 4x8", Layout((4, 8), (8, 1))), + ("Contiguous 1D", Layout(8, 1)), + ("Strided 1D", Layout(8, 2)), + ("Col-major 4x8", Layout((4, 8), (1, 4))), + ("Gapped 4x8", Layout((4, 8), (1, 8))), + ("Row-major 4x8", Layout((4, 8), (8, 1))), ] for label, l in layouts: print(f" {label:20s} {str(l):25s} contiguity={contiguity(l)}") @@ -818,6 +841,7 @@ def example_analysis(): # Main # ============================================================================= + def main(): """Run all layout algebra examples.""" print("=" * 70) diff --git a/examples/viz.py b/examples/viz.py index 
24620b4..d143b35 100644 --- a/examples/viz.py +++ b/examples/viz.py @@ -21,6 +21,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# ruff: noqa: F405 + """Visualization Examples Based on CuTe C++ Documentation. This file demonstrates the visualization capabilities of the layouts library, @@ -62,6 +64,7 @@ def setup_output_dir(name: str = "examples_output") -> Path: # Section 1: Output Formats (SVG, PNG, PDF) # ============================================================================= + def example_output_formats(output: Path): """Demonstrate SVG, PNG, and PDF output formats. @@ -89,71 +92,95 @@ def example_output_formats(output: Path): layout = Layout((4, 8), (8, 1)) # SVG output - vector format (default, best for most uses) - draw_layout(layout, output / "format_example.svg", - title="(4,8):(8,1)") + draw_layout(layout, output / "format_example.svg", title="(4,8):(8,1)") print(f"✓ SVG: format_example.svg (vector, scalable)") # PNG output - raster format with configurable DPI - draw_layout(layout, output / "format_example.png", - title="(4,8):(8,1)", dpi=150) + draw_layout(layout, output / "format_example.png", title="(4,8):(8,1)", dpi=150) print(f"✓ PNG: format_example.png (raster, 150 dpi)") # PDF output - print-ready format - draw_layout(layout, output / "format_example.pdf", - title="(4,8):(8,1)") + draw_layout(layout, output / "format_example.pdf", title="(4,8):(8,1)") print(f"✓ PDF: format_example.pdf (print-ready)") # Demonstrate color_layout options layout_8x8 = Layout((8, 8), (8, 1)) # Color by value (default) - same value = same color - draw_layout(layout_8x8, output / "color_by_value.svg", - title="color_layout=None (by value)") + draw_layout( + layout_8x8, output / "color_by_value.svg", title="color_layout=None (by value)" + ) print(f"✓ Color by value: color_by_value.svg") # Color by column - darker across columns (cute-viz style) - draw_layout(layout_8x8, output / "color_by_col.svg", - 
title="color_layout=(8,8):(0,1) (by column)", - color_layout=Layout((8, 8), (0, 1))) + draw_layout( + layout_8x8, + output / "color_by_col.svg", + title="color_layout=(8,8):(0,1) (by column)", + color_layout=Layout((8, 8), (0, 1)), + ) print(f"✓ Color by column: color_by_col.svg") # Color by row - darker down rows - draw_layout(layout_8x8, output / "color_by_row.svg", - title="color_layout=(8,8):(1,0) (by row)", - color_layout=Layout((8, 8), (1, 0))) + draw_layout( + layout_8x8, + output / "color_by_row.svg", + title="color_layout=(8,8):(1,0) (by row)", + color_layout=Layout((8, 8), (1, 0)), + ) print(f"✓ Color by row: color_by_row.svg") # Uniform color - no variation - draw_layout(layout_8x8, output / "color_uniform.svg", - title="color_layout=1:0 (uniform)", - color_layout=Layout(1, 0)) + draw_layout( + layout_8x8, + output / "color_uniform.svg", + title="color_layout=1:0 (uniform)", + color_layout=Layout(1, 0), + ) print(f"✓ Uniform color: color_uniform.svg") # Rainbow colors with different color_layout - draw_layout(layout_8x8, output / "color_by_col_rainbow.svg", - title="colorize=True, by column", - colorize=True, color_layout=Layout((8, 8), (0, 1))) + draw_layout( + layout_8x8, + output / "color_by_col_rainbow.svg", + title="colorize=True, by column", + colorize=True, + color_layout=Layout((8, 8), (0, 1)), + ) print(f"✓ Rainbow by column: color_by_col_rainbow.svg") # color_by shorthand — equivalent to the manual color_layout above - draw_layout(layout_8x8, output / "color_by_row_shorthand.svg", - title='color_by="row"', color_by="row") - draw_layout(layout_8x8, output / "color_by_col_shorthand.svg", - title='color_by="column"', color_by="column") - print(f"✓ color_by shorthand: color_by_row_shorthand.svg, color_by_col_shorthand.svg") + draw_layout( + layout_8x8, + output / "color_by_row_shorthand.svg", + title='color_by="row"', + color_by="row", + ) + draw_layout( + layout_8x8, + output / "color_by_col_shorthand.svg", + title='color_by="column"', + 
color_by="column", + ) + print( + f"✓ color_by shorthand: color_by_row_shorthand.svg, color_by_col_shorthand.svg" + ) # Swizzle comparison showing row-group coloring (reveals permutation effect) base = Layout((8, 8), (8, 1)) sw = Swizzle(3, 0, 3) draw_swizzle(base, sw, output / "swizzle_example.svg") draw_swizzle(base, sw, output / "swizzle_example_color.svg", colorize=True) - print(f"✓ Swizzle with row-group coloring: swizzle_example.svg, swizzle_example_color.svg") + print( + f"✓ Swizzle with row-group coloring: swizzle_example.svg, swizzle_example_color.svg" + ) # ============================================================================= # Section 2: 1D Layouts # ============================================================================= + def example_1d_layouts(output: Path): """1D contiguous and strided layouts. @@ -168,22 +195,21 @@ def example_1d_layouts(output: Path): # Contiguous 1D layout: 8 elements, stride 1 layout_1d_contiguous = Layout(8, 1) - draw_layout(layout_1d_contiguous, output / "1d_contiguous.svg", - title="1D Contiguous: 8:1") + draw_layout( + layout_1d_contiguous, output / "1d_contiguous.svg", title="1D Contiguous: 8:1" + ) print(f"✓ 1D Contiguous: 8:1") print(f" Maps index i → offset i (e.g., 3 → 3)") # Strided 1D layout: 8 elements, stride 2 layout_1d_strided = Layout(8, 2) - draw_layout(layout_1d_strided, output / "1d_strided.svg", - title="1D Strided: 8:2") + draw_layout(layout_1d_strided, output / "1d_strided.svg", title="1D Strided: 8:2") print(f"✓ 1D Strided: 8:2") print(f" Maps index i → offset 2*i (e.g., 3 → 6)") # Strided 1D layout: 4 elements, stride 4 layout_1d_stride4 = Layout(4, 4) - draw_layout(layout_1d_stride4, output / "1d_stride4.svg", - title="1D Stride-4: 4:4") + draw_layout(layout_1d_stride4, output / "1d_stride4.svg", title="1D Stride-4: 4:4") print(f"✓ 1D Stride-4: 4:4") print(f" Maps index i → offset 4*i (e.g., 2 → 8)") @@ -192,6 +218,7 @@ def example_1d_layouts(output: Path): # Section 3: 2D Layouts # 
============================================================================= + def example_2d_layouts(output: Path): """2D row-major and column-major layouts. @@ -207,35 +234,55 @@ def example_2d_layouts(output: Path): # Row-major 4x3: shape (4 rows, 3 cols), stride (3, 1) # Row i, Col j → offset = i*3 + j row_major_4x3 = Layout((4, 3), (3, 1)) - draw_layout(row_major_4x3, output / "2d_row_major_4x3.svg", - title="Row-Major 4×3: (4,3):(3,1)") + draw_layout( + row_major_4x3, + output / "2d_row_major_4x3.svg", + title="Row-Major 4×3: (4,3):(3,1)", + ) print(f"✓ Row-Major 4×3: (4,3):(3,1)") print(f" offset(i,j) = i*3 + j*1") # Column-major 4x3: shape (4 rows, 3 cols), stride (1, 4) # Row i, Col j → offset = i*1 + j*4 col_major_4x3 = Layout((4, 3), (1, 4)) - draw_layout(col_major_4x3, output / "2d_col_major_4x3.svg", - title="Col-Major 4×3: (4,3):(1,4)") + draw_layout( + col_major_4x3, + output / "2d_col_major_4x3.svg", + title="Col-Major 4×3: (4,3):(1,4)", + ) print(f"✓ Col-Major 4×3: (4,3):(1,4)") print(f" offset(i,j) = i*1 + j*4") # 8x8 Row-major: shape (8 rows, 8 cols), stride (8, 1) # This is the common layout for matrix operations row_major_8x8 = Layout((8, 8), (8, 1)) - draw_layout(row_major_8x8, output / "2d_row_major_8x8.svg", - title="Row-Major 8×8: (8,8):(8,1)") - draw_layout(row_major_8x8, output / "2d_row_major_8x8_color.svg", - title="Row-Major 8×8: (8,8):(8,1)", colorize=True) + draw_layout( + row_major_8x8, + output / "2d_row_major_8x8.svg", + title="Row-Major 8×8: (8,8):(8,1)", + ) + draw_layout( + row_major_8x8, + output / "2d_row_major_8x8_color.svg", + title="Row-Major 8×8: (8,8):(8,1)", + colorize=True, + ) print(f"✓ Row-Major 8×8: (8,8):(8,1) [grayscale and colorized]") print(f" offset(i,j) = i*8 + j*1") # 8x8 Column-major: shape (8 rows, 8 cols), stride (1, 8) col_major_8x8 = Layout((8, 8), (1, 8)) - draw_layout(col_major_8x8, output / "2d_col_major_8x8.svg", - title="Col-Major 8×8: (8,8):(1,8)") - draw_layout(col_major_8x8, output / 
"2d_col_major_8x8_color.svg", - title="Col-Major 8×8: (8,8):(1,8)", colorize=True) + draw_layout( + col_major_8x8, + output / "2d_col_major_8x8.svg", + title="Col-Major 8×8: (8,8):(1,8)", + ) + draw_layout( + col_major_8x8, + output / "2d_col_major_8x8_color.svg", + title="Col-Major 8×8: (8,8):(1,8)", + colorize=True, + ) print(f"✓ Col-Major 8×8: (8,8):(1,8) [grayscale and colorized]") print(f" offset(i,j) = i*1 + j*8") @@ -244,6 +291,7 @@ def example_2d_layouts(output: Path): # Section 4: Hierarchical Layouts # ============================================================================= + def example_hierarchical_layouts(output: Path): """Hierarchical (nested) layouts - flattened and nested views. @@ -275,8 +323,12 @@ def example_hierarchical_layouts(output: Path): print(f" (({i},0),({j},0)) → {idx}") # Flat view (default) - draw_layout(hier_2x2_3x4, output / "hier_2x2_3x4_flat.svg", - title=f"Flat: {hier_2x2_3x4}", flatten_hierarchical=True) + draw_layout( + hier_2x2_3x4, + output / "hier_2x2_3x4_flat.svg", + title=f"Flat: {hier_2x2_3x4}", + flatten_hierarchical=True, + ) print(f"✓ Flat view: hier_2x2_3x4_flat.svg") # Nested pedagogical view: @@ -284,9 +336,13 @@ def example_hierarchical_layouts(output: Path): # - each cell shows col=... (nested column coordinate) # - each cell shows offset=... (resulting offset) # - axes stay simple (R0, R1, ... / C0, C1, ...) 
- draw_layout(hier_2x2_3x4, output / "hier_2x2_3x4_nested.svg", - title=f"Nested: {hier_2x2_3x4}", flatten_hierarchical=False, - label_hierarchy_levels=True) + draw_layout( + hier_2x2_3x4, + output / "hier_2x2_3x4_nested.svg", + title=f"Nested: {hier_2x2_3x4}", + flatten_hierarchical=False, + label_hierarchy_levels=True, + ) print(f"✓ Nested view: hier_2x2_3x4_nested.svg") # ========================================================================= @@ -306,13 +362,23 @@ def example_hierarchical_layouts(output: Path): # Tile (1,0): 4 6 Tile (1,1): 12 14 # 5 7 13 15 - draw_layout(logo_layout, output / "hier_2x2_tiles_flat.svg", - title=f"Flat: {logo_layout}", flatten_hierarchical=True) - draw_layout(logo_layout, output / "hier_2x2_tiles_nested.svg", - title=f"Nested: {logo_layout}", flatten_hierarchical=False, - label_hierarchy_levels=True) + draw_layout( + logo_layout, + output / "hier_2x2_tiles_flat.svg", + title=f"Flat: {logo_layout}", + flatten_hierarchical=True, + ) + draw_layout( + logo_layout, + output / "hier_2x2_tiles_nested.svg", + title=f"Nested: {logo_layout}", + flatten_hierarchical=False, + label_hierarchy_levels=True, + ) print(f"✓ Hierarchical 2×2 in 2×2 (logo layout): {logo_layout}") - print(f" Nested view is pedagogical: row=... / col=... show nested coordinates, offset=... shows mapping") + print( + f" Nested view is pedagogical: row=... / col=... show nested coordinates, offset=... 
shows mapping" + ) # ========================================================================= # Example 3: 3-level asymmetric hierarchy with per-level axis labels @@ -373,14 +439,16 @@ def example_hierarchical_layouts(output: Path): # Flatten the hierarchical layout (algebra operation) flat_layout = flatten(logo_layout) - draw_layout(flat_layout, output / "hier_flattened.svg", - title=f"flatten(): {flat_layout}") + draw_layout( + flat_layout, output / "hier_flattened.svg", title=f"flatten(): {flat_layout}" + ) print(f"✓ Flattened (algebra): {flat_layout}") # Coalesce to merge contiguous dimensions coal_layout = coalesce(logo_layout) - draw_layout(coal_layout, output / "hier_coalesced.svg", - title=f"coalesce(): {coal_layout}") + draw_layout( + coal_layout, output / "hier_coalesced.svg", title=f"coalesce(): {coal_layout}" + ) print(f"✓ Coalesced: {coal_layout}") # ========================================================================= @@ -390,53 +458,74 @@ def example_hierarchical_layouts(output: Path): # (4,8):(1,4) — column-major 4x8 cecka_1 = Layout((4, 8), (1, 4)) - draw_layout(cecka_1, output / "cecka_4x8_col.svg", - title="(4,8):(1,4)") + draw_layout(cecka_1, output / "cecka_4x8_col.svg", title="(4,8):(1,4)") print(f"✓ (4,8):(1,4) — column-major 4×8") # (4,8):(8,1) — row-major 4x8 cecka_2 = Layout((4, 8), (8, 1)) - draw_layout(cecka_2, output / "cecka_4x8_row.svg", - title="(4,8):(8,1)") + draw_layout(cecka_2, output / "cecka_4x8_row.svg", title="(4,8):(8,1)") print(f"✓ (4,8):(8,1) — row-major 4×8") # (4,8):(1,5) — non-injective layout (stride 5 with shape 8 wraps) cecka_3 = Layout((4, 8), (1, 5)) - draw_layout(cecka_3, output / "cecka_4x8_s1_s5.svg", - title="(4,8):(1,5)") + draw_layout(cecka_3, output / "cecka_4x8_s1_s5.svg", title="(4,8):(1,5)") print(f"✓ (4,8):(1,5) — non-injective (surjective) layout") # (4,(4,2)):(4,(1,16)) — hierarchical column dimension # Nested rendering explicitly shows how the hierarchical column coordinate # maps to the final 
offset for each displayed cell. cecka_4 = Layout((4, (4, 2)), (4, (1, 16))) - draw_layout(cecka_4, output / "cecka_hier_col.svg", - title="(4,(4,2)):(4,(1,16))", flatten_hierarchical=False, - label_hierarchy_levels=True) - draw_layout(cecka_4, output / "cecka_hier_col_flat.svg", - title="(4,(4,2)):(4,(1,16))", flatten_hierarchical=True) + draw_layout( + cecka_4, + output / "cecka_hier_col.svg", + title="(4,(4,2)):(4,(1,16))", + flatten_hierarchical=False, + label_hierarchy_levels=True, + ) + draw_layout( + cecka_4, + output / "cecka_hier_col_flat.svg", + title="(4,(4,2)):(4,(1,16))", + flatten_hierarchical=True, + ) print(f"✓ (4,(4,2)):(4,(1,16)) — hierarchical column") # ((2,2),(4,2)):((1,8),(2,16)) — hierarchical in both modes # This is a good example where explicit row=... / col=... labels help explain # the two-level row and column structure. cecka_5 = Layout(((2, 2), (4, 2)), ((1, 8), (2, 16))) - draw_layout(cecka_5, output / "cecka_hier_both.svg", - title="((2,2),(4,2)):((1,8),(2,16))", flatten_hierarchical=False, - label_hierarchy_levels=True) - draw_layout(cecka_5, output / "cecka_hier_both_flat.svg", - title="((2,2),(4,2)):((1,8),(2,16))", flatten_hierarchical=True) + draw_layout( + cecka_5, + output / "cecka_hier_both.svg", + title="((2,2),(4,2)):((1,8),(2,16))", + flatten_hierarchical=False, + label_hierarchy_levels=True, + ) + draw_layout( + cecka_5, + output / "cecka_hier_both_flat.svg", + title="((2,2),(4,2)):((1,8),(2,16))", + flatten_hierarchical=True, + ) print(f"✓ ((2,2),(4,2)):((1,8),(2,16)) — hierarchical both modes") # ((2,2),(2,4)):((0,2),(0,4)) — zero-stride (broadcast) layout # The pedagogical nested view is especially useful here because repeated # offsets are easier to interpret when the source coordinates are explicit. 
cecka_6 = Layout(((2, 2), (2, 4)), ((0, 2), (0, 4))) - draw_layout(cecka_6, output / "cecka_broadcast.svg", - title="((2,2),(2,4)):((0,2),(0,4))", flatten_hierarchical=False, - label_hierarchy_levels=True) - draw_layout(cecka_6, output / "cecka_broadcast_flat.svg", - title="((2,2),(2,4)):((0,2),(0,4))", flatten_hierarchical=True) + draw_layout( + cecka_6, + output / "cecka_broadcast.svg", + title="((2,2),(2,4)):((0,2),(0,4))", + flatten_hierarchical=False, + label_hierarchy_levels=True, + ) + draw_layout( + cecka_6, + output / "cecka_broadcast_flat.svg", + title="((2,2),(2,4)):((0,2),(0,4))", + flatten_hierarchical=True, + ) print(f"✓ ((2,2),(2,4)):((0,2),(0,4)) — broadcast (zero-stride) layout") # Morton/Z-order layout using blocked_product (CuTe pattern) @@ -446,14 +535,15 @@ def example_hierarchical_layouts(output: Path): morton1 = Layout((2, 2), (1, 2)) morton2 = blocked_product(morton1, morton1) morton3 = blocked_product(morton1, morton2) - draw_layout(morton1, output / "hier_morton_2x2.svg", - title=f"Morton 2×2: {morton1}") - draw_layout(morton2, output / "hier_morton_4x4.svg", - title=f"Morton 4×4: {morton2}") - draw_layout(morton3, output / "hier_morton_8x8.svg", - title=f"Morton 8×8: {morton3}") - draw_layout(morton3, output / "hier_morton_8x8_color.svg", - title=f"Morton 8×8: {morton3}", colorize=True) + draw_layout(morton1, output / "hier_morton_2x2.svg", title=f"Morton 2×2: {morton1}") + draw_layout(morton2, output / "hier_morton_4x4.svg", title=f"Morton 4×4: {morton2}") + draw_layout(morton3, output / "hier_morton_8x8.svg", title=f"Morton 8×8: {morton3}") + draw_layout( + morton3, + output / "hier_morton_8x8_color.svg", + title=f"Morton 8×8: {morton3}", + colorize=True, + ) print(f"✓ Morton 2×2: {morton1}") print(f"✓ Morton 4×4: {morton2}") print(f"✓ Morton 8×8: {morton3}") @@ -461,14 +551,12 @@ def example_hierarchical_layouts(output: Path): # Show nested mode access # Mode 0 is the row dimension with shape (2, 2) mode0 = mode(logo_layout, 0) - 
draw_layout(mode0, output / "hier_mode0.svg", - title=f"Mode 0 (rows): {mode0}") + draw_layout(mode0, output / "hier_mode0.svg", title=f"Mode 0 (rows): {mode0}") print(f"✓ Mode 0 (rows): {mode0}") # Mode 1 is the column dimension with shape (2, 2) mode1 = mode(logo_layout, 1) - draw_layout(mode1, output / "hier_mode1.svg", - title=f"Mode 1 (cols): {mode1}") + draw_layout(mode1, output / "hier_mode1.svg", title=f"Mode 1 (cols): {mode1}") print(f"✓ Mode 1 (cols): {mode1}") @@ -476,6 +564,7 @@ def example_hierarchical_layouts(output: Path): # Section 5: Swizzled Layouts # ============================================================================= + def example_swizzled_layouts(output: Path): """Swizzled layouts for GPU shared memory bank conflict avoidance. @@ -526,7 +615,9 @@ def example_swizzled_layouts(output: Path): # Column-major variant base_8x8_col = Layout((8, 8), (1, 8)) - draw_swizzle(base_8x8_col, sw_303, output / "swizzle_8x8_col_303.svg", colorize=True) + draw_swizzle( + base_8x8_col, sw_303, output / "swizzle_8x8_col_303.svg", colorize=True + ) print(f"✓ Swizzle(3,0,3) on 8×8 col-major") # 16x8 variant (common for tensor core) @@ -568,7 +659,9 @@ def example_swizzled_layouts(output: Path): # Canonical byte layout: 8 rows × 128 columns (128 bytes per row) sw_343 = Swizzle(3, 4, 3) base_8x128 = Layout((8, 128), (128, 1)) - draw_swizzle(base_8x128, sw_343, output / "swizzle_8x128_343_SW128.svg", colorize=True) + draw_swizzle( + base_8x128, sw_343, output / "swizzle_8x128_343_SW128.svg", colorize=True + ) print(f"✓ Swizzle(3,4,3) SW128 on 8×128: XOR bits [4,7) with [7,10)") # ========================================================================= @@ -578,7 +671,9 @@ def example_swizzled_layouts(output: Path): # Swizzle(0, M, S) is identity - no XOR applied sw_043 = Swizzle(0, 4, 3) - draw_swizzle(base_8x128, sw_043, output / "swizzle_8x128_043_none.svg", colorize=True) + draw_swizzle( + base_8x128, sw_043, output / "swizzle_8x128_043_none.svg", 
colorize=True + ) print(f"✓ Swizzle(0,4,3) on 8×128: Identity (no XOR)") @@ -586,6 +681,7 @@ def example_swizzled_layouts(output: Path): # Section 6: Thread-Value (TV) Layouts # ============================================================================= + def example_thread_value_layouts(output: Path): """Thread-Value (TV) layouts for GPU parallelism. @@ -604,38 +700,55 @@ def example_thread_value_layouts(output: Path): # Simple TV layout: 4 threads, 2 values each = 8 elements # Thread 0: V0, V1; Thread 1: V0, V1; etc. tv_4x2 = Layout((4, 2), (2, 1)) - draw_tv_layout(tv_4x2, output / "tv_4threads_2values.svg", - title="TV: (4,2):(2,1) - 4 threads, 2 values each") - draw_tv_layout(tv_4x2, output / "tv_4threads_2values_color.svg", - title="TV: (4,2):(2,1)", colorize=True) + draw_tv_layout( + tv_4x2, + output / "tv_4threads_2values.svg", + title="TV: (4,2):(2,1) - 4 threads, 2 values each", + ) + draw_tv_layout( + tv_4x2, + output / "tv_4threads_2values_color.svg", + title="TV: (4,2):(2,1)", + colorize=True, + ) print(f"✓ TV Layout 4×2: 4 threads, 2 values each") print(f" Thread t owns values V0, V1 at offsets 2*t and 2*t+1") # TV layout with interleaved threads tv_4x2_col = Layout((4, 2), (1, 4)) - draw_tv_layout(tv_4x2_col, output / "tv_4threads_2values_interleaved.svg", - title="TV interleaved: (4,2):(1,4)") + draw_tv_layout( + tv_4x2_col, + output / "tv_4threads_2values_interleaved.svg", + title="TV interleaved: (4,2):(1,4)", + ) print(f"✓ TV Layout 4×2 interleaved: offsets t and t+4") # 8x4 TV layout (smaller than full warp for clarity) tv_8x4 = Layout((8, 4), (4, 1)) - draw_tv_layout(tv_8x4, output / "tv_8x4.svg", - title="TV: (8,4):(4,1) - 8 threads, 4 values") - draw_tv_layout(tv_8x4, output / "tv_8x4_color.svg", - title="TV: (8,4):(4,1)", colorize=True) + draw_tv_layout( + tv_8x4, output / "tv_8x4.svg", title="TV: (8,4):(4,1) - 8 threads, 4 values" + ) + draw_tv_layout( + tv_8x4, output / "tv_8x4_color.svg", title="TV: (8,4):(4,1)", colorize=True + ) 
print(f"✓ TV Layout 8×4: 8 threads, 4 values each") # 8x8 TV layout (common for LDMATRIX) tv_8x8 = Layout((8, 8), (8, 1)) - draw_tv_layout(tv_8x8, output / "tv_8x8.svg", - title="TV: (8,8):(8,1) - 8 threads, 8 values") - draw_tv_layout(tv_8x8, output / "tv_8x8_color.svg", - title="TV: (8,8):(8,1)", colorize=True) + draw_tv_layout( + tv_8x8, output / "tv_8x8.svg", title="TV: (8,8):(8,1) - 8 threads, 8 values" + ) + draw_tv_layout( + tv_8x8, output / "tv_8x8_color.svg", title="TV: (8,8):(8,1)", colorize=True + ) print(f"✓ TV Layout 8×8: 8 threads, 8 values each (LDMATRIX style)") # Also show the regular layout view for comparison - draw_layout(tv_8x8, output / "tv_8x8_offsets.svg", - title="TV: (8,8):(8,1) - Memory offsets view") + draw_layout( + tv_8x8, + output / "tv_8x8_offsets.svg", + title="TV: (8,8):(8,1) - Memory offsets view", + ) print(f" (Also showing memory offset view for comparison)") @@ -643,6 +756,7 @@ def example_thread_value_layouts(output: Path): # Section 7: Copy Atom Traits (LDMATRIX, STMATRIX, TMA) # ============================================================================= + def example_copy_atoms(output: Path): """Copy atom TV layouts across GPU architectures. 
@@ -678,8 +792,9 @@ def example_copy_atoms(output: Path): ] for atom in ldsm_atoms: # draw_copy_atom handles upcast from bit to element coords automatically - draw_copy_atom(atom, element_bits=element_bits, - filename=output / f"{atom.name}_copy.svg") + draw_copy_atom( + atom, element_bits=element_bits, filename=output / f"{atom.name}_copy.svg" + ) dst = upcast(atom.dst_layout_bits, element_bits) n_thr = size(atom.thr_id) @@ -694,8 +809,9 @@ def example_copy_atoms(output: Path): stsm_atoms = [SM90_U32x4_STSM_N, SM90_U16x8_STSM_T] for atom in stsm_atoms: - draw_copy_atom(atom, element_bits=element_bits, - filename=output / f"{atom.name}_copy.svg") + draw_copy_atom( + atom, element_bits=element_bits, filename=output / f"{atom.name}_copy.svg" + ) print(f"✓ {atom.name} ({atom.ptx})") # ===================================================================== @@ -718,14 +834,19 @@ def example_copy_atoms(output: Path): # For fp16: Swizzle<3,4,3> ∘ (8, 64):(64, 1) = 8 rows × 64 cols print("\n TMA target: GMMA K-major SW128 smem layout (fp16):") base_tma = Layout((8, 64), (64, 1)) - draw_swizzle(base_tma, Swizzle(3, 4, 3), - output / "SM90_TMA_GMMA_K_SW128.svg", colorize=True) + draw_swizzle( + base_tma, Swizzle(3, 4, 3), output / "SM90_TMA_GMMA_K_SW128.svg", colorize=True + ) print(f"✓ SM90 TMA → GMMA K-major SW128: Swizzle(3,4,3) ∘ (8,64):(64,1)") print("\n TMA target: GMMA M|N-major SW128 smem layout (fp16):") base_tma_mn = Layout((64, 8), (1, 64)) - draw_swizzle(base_tma_mn, Swizzle(3, 4, 3), - output / "SM90_TMA_GMMA_MN_SW128.svg", colorize=True) + draw_swizzle( + base_tma_mn, + Swizzle(3, 4, 3), + output / "SM90_TMA_GMMA_MN_SW128.svg", + colorize=True, + ) print(f"✓ SM90 TMA → GMMA M|N-major SW128: Swizzle(3,4,3) ∘ (64,8):(1,64)") # ===================================================================== @@ -733,8 +854,9 @@ def example_copy_atoms(output: Path): # ===================================================================== print("\n --- LDMATRIX Shared Memory with 
Swizzle ---") smem_8x8 = Layout((8, 8), (8, 1)) - draw_swizzle(smem_8x8, Swizzle(3, 0, 3), - output / "ldmatrix_smem_swizzle.svg", colorize=True) + draw_swizzle( + smem_8x8, Swizzle(3, 0, 3), output / "ldmatrix_smem_swizzle.svg", colorize=True + ) print(f"✓ LDMATRIX shared memory with Swizzle(3,0,3)") @@ -742,6 +864,7 @@ def example_copy_atoms(output: Path): # Section 8: MMA Atom Traits # ============================================================================= + def _draw_mma_atom(atom, output: Path): """Draw A, B, C, and combined figures for one MMA atom.""" name = atom.name @@ -757,23 +880,44 @@ def _draw_mma_atom(atom, output: Path): b_rows, b_cols = cs_b // N if cs_b % N == 0 else K, N c_rows, c_cols = M, cs_c // M if cs_c % M == 0 else N - draw_tv_layout(atom.a_layout, output / f"{name}_A.svg", - title=f"{name} A ({a_rows}×{a_cols})", - colorize=True, grid_shape=(a_rows, a_cols), thr_id_layout=thr) + draw_tv_layout( + atom.a_layout, + output / f"{name}_A.svg", + title=f"{name} A ({a_rows}×{a_cols})", + colorize=True, + grid_shape=(a_rows, a_cols), + thr_id_layout=thr, + ) - draw_tv_layout(atom.b_layout, output / f"{name}_B.svg", - title=f"{name} B ({b_rows}×{b_cols})", - colorize=True, grid_shape=(b_rows, b_cols), thr_id_layout=thr, - col_major=False) + draw_tv_layout( + atom.b_layout, + output / f"{name}_B.svg", + title=f"{name} B ({b_rows}×{b_cols})", + colorize=True, + grid_shape=(b_rows, b_cols), + thr_id_layout=thr, + col_major=False, + ) - draw_tv_layout(atom.c_layout, output / f"{name}_C.svg", - title=f"{name} C ({c_rows}×{c_cols})", - colorize=True, grid_shape=(c_rows, c_cols), thr_id_layout=thr) + draw_tv_layout( + atom.c_layout, + output / f"{name}_C.svg", + title=f"{name} C ({c_rows}×{c_cols})", + colorize=True, + grid_shape=(c_rows, c_cols), + thr_id_layout=thr, + ) - draw_mma_layout(atom.a_layout, atom.b_layout, atom.c_layout, - output / f"{name}_combined.svg", - tile_mnk=(a_rows, c_cols, a_cols), main_title=name, - colorize=True, 
thr_id_layout=thr) + draw_mma_layout( + atom.a_layout, + atom.b_layout, + atom.c_layout, + output / f"{name}_combined.svg", + tile_mnk=(a_rows, c_cols, a_cols), + main_title=name, + colorize=True, + thr_id_layout=thr, + ) n_thr = size(mode(atom.c_layout, 0)) n_val_a = size(mode(atom.a_layout, 1)) @@ -815,14 +959,16 @@ def _draw_tiled_mma(atom, atom_layout, output: Path, tile_mnk=None): label = f"{atom.name}_{n_am}x{n_an}_{M}x{N}x{K}" # Compute tiled grids - c_grid, _ = tile_mma_grid(atom, atom_layout, 'C', tile_mnk=tile_mnk) - a_grid, _ = tile_mma_grid(atom, atom_layout, 'A', tile_mnk=tile_mnk) - b_grid, _ = tile_mma_grid(atom, atom_layout, 'B', tile_mnk=tile_mnk) - - draw_tiled_grid(c_grid, M, N, output / f"{label}_C.svg", - title=f"{label} C ({M}×{N})") - draw_tiled_grid(a_grid, M, K, output / f"{label}_A.svg", - title=f"{label} A ({M}×{K})") + c_grid, _ = tile_mma_grid(atom, atom_layout, "C", tile_mnk=tile_mnk) + a_grid, _ = tile_mma_grid(atom, atom_layout, "A", tile_mnk=tile_mnk) + b_grid, _ = tile_mma_grid(atom, atom_layout, "B", tile_mnk=tile_mnk) + + draw_tiled_grid( + c_grid, M, N, output / f"{label}_C.svg", title=f"{label} C ({M}×{N})" + ) + draw_tiled_grid( + a_grid, M, K, output / f"{label}_A.svg", title=f"{label} A ({M}×{K})" + ) # B displayed as K×N (transposed) b_display = {} for (r, c), val in b_grid.items(): @@ -830,17 +976,28 @@ def _draw_tiled_mma(atom, atom_layout, output: Path, tile_mnk=None): n_coord = r k_coord = c b_display[(k_coord, n_coord)] = val - draw_tiled_grid(b_display, K, N, output / f"{label}_B.svg", - title=f"{label} B ({K}×{N})") + draw_tiled_grid( + b_display, K, N, output / f"{label}_B.svg", title=f"{label} B ({K}×{N})" + ) # Combined figure: A (left), B (top-right), C (bottom-right) - draw_combined_mma_grid(a_grid, b_display, c_grid, M, N, K, - output / f"{label}_combined.svg", title=label) + draw_combined_mma_grid( + a_grid, + b_display, + c_grid, + M, + N, + K, + output / f"{label}_combined.svg", + title=label, + ) print(f"✓ 
Tiled MMA: {label}") - print(f" {size(atom_layout)} atoms ({n_am}×{n_an}), " - f"tile {M}×{N}×{K}, " - f"{size(atom_layout) * size(mode(atom.c_layout, 0))} threads") + print( + f" {size(atom_layout)} atoms ({n_am}×{n_an}), " + f"tile {M}×{N}×{K}, " + f"{size(atom_layout) * size(mode(atom.c_layout, 0))} threads" + ) def example_mma_atom(output: Path): @@ -866,8 +1023,7 @@ def example_mma_atom(output: Path): # make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, # Layout, Stride<_2,_1>>{}); # Reference: HMMA.8x8x4.NT_2x2.png - _draw_tiled_mma(SM70_8x8x4_F32F16F16F32_NT, - Layout((2, 2), (2, 1)), output) + _draw_tiled_mma(SM70_8x8x4_F32F16F16F32_NT, Layout((2, 2), (2, 1)), output) # Tiled MMA expanded to 32×32×4 via value tiling # Equivalent to C++: @@ -875,9 +1031,9 @@ def example_mma_atom(output: Path): # Layout, Stride<_2,_1>>{}, # Tile<_32,_32,_4>{}); # Reference: HMMA.8x8x4.NT_2x2_32x32x4.png - _draw_tiled_mma(SM70_8x8x4_F32F16F16F32_NT, - Layout((2, 2), (2, 1)), output, - tile_mnk=(32, 32, 4)) + _draw_tiled_mma( + SM70_8x8x4_F32F16F16F32_NT, Layout((2, 2), (2, 1)), output, tile_mnk=(32, 32, 4) + ) # SM80 Ampere print("\n --- SM80 Ampere (32 threads, warp) ---") @@ -893,9 +1049,13 @@ def example_mma_atom(output: Path): print("\n --- SM90 Hopper GMMA (128 threads, warpgroup) ---") for atom in [SM90_64x8x16_F16F16F16_SS, SM90_64x64x16_F16F16F16_SS]: M, N, K = atom.shape_mnk - draw_tv_layout(atom.c_layout, output / f"{atom.name}_C.svg", - title=f"{atom.name} C ({M}×{N})", - colorize=True, grid_shape=(M, N)) + draw_tv_layout( + atom.c_layout, + output / f"{atom.name}_C.svg", + title=f"{atom.name} C ({M}×{N})", + colorize=True, + grid_shape=(M, N), + ) n_vals = size(mode(atom.c_layout, 1)) print(f"✓ {atom.name} C: 128 thr × {n_vals} vals = {128*n_vals} elements") @@ -924,17 +1084,24 @@ def example_mma_atom(output: Path): # Compare: SM90 GMMA 64×8 C uses 128 threads with hierarchical layout sm90_atom = SM90_64x8x16_F16F16F16_SS sm90_c = sm90_atom.c_layout - 
draw_tv_layout(sm90_c, output / "SM100_compare_SM90_64x8_C.svg", - title="SM90 GMMA C (64×8) — 128 threads", - colorize=True, grid_shape=(64, 8)) + draw_tv_layout( + sm90_c, + output / "SM100_compare_SM90_64x8_C.svg", + title="SM90 GMMA C (64×8) — 128 threads", + colorize=True, + grid_shape=(64, 8), + ) print(f"✓ SM90 GMMA 64×8 C: {sm90_c} (128 thr × {size(mode(sm90_c, 1))} vals)") # SM100 UMMA 64×8 C — same logical tile, 1 "thread", all values umma_atom = make_umma_atom_ss(64, 8) umma_c = umma_atom.c_layout - draw_layout(umma_c, output / "SM100_compare_UMMA_64x8_C.svg", - title="SM100 UMMA C (64×8) — 1 thread, TMEM", - flatten_hierarchical=True) + draw_layout( + umma_c, + output / "SM100_compare_UMMA_64x8_C.svg", + title="SM100 UMMA C (64×8) — 1 thread, TMEM", + flatten_hierarchical=True, + ) print(f"✓ SM100 UMMA 64×8 C: {umma_c} (1 thr × {size(umma_c)} vals)") # SM120 Blackwell B200 — warp-level FP8 @@ -957,6 +1124,7 @@ def example_mma_atom(output: Path): # Section 9: Slicing Examples # ============================================================================= + def example_slicing(output: Path): """Slicing layouts - row, column, and complex discontinuous slices. 
@@ -973,25 +1141,32 @@ def example_slicing(output: Path): base = Layout((8, 8), (8, 1)) # Row slice: select row 3 (all columns) - draw_slice(base, (3, None), output / "slice_row3.svg", - title="Row Slice: layout(3, :)") + draw_slice( + base, (3, None), output / "slice_row3.svg", title="Row Slice: layout(3, :)" + ) print(f"✓ Row slice: layout(3, :)") print(f" Selects all 8 elements in row 3") # Column slice: select column 5 (all rows) - draw_slice(base, (None, 5), output / "slice_col5.svg", - title="Column Slice: layout(:, 5)") + draw_slice( + base, (None, 5), output / "slice_col5.svg", title="Column Slice: layout(:, 5)" + ) print(f"✓ Column slice: layout(:, 5)") print(f" Selects all 8 elements in column 5") # Single element - draw_slice(base, (4, 6), output / "slice_element.svg", - title="Single Element: layout(4, 6)") + draw_slice( + base, (4, 6), output / "slice_element.svg", title="Single Element: layout(4, 6)" + ) print(f"✓ Single element: layout(4, 6)") # Rectangular region: rows 2-5, columns 1-4 - draw_slice(base, (slice(2, 6), slice(1, 5)), output / "slice_rect.svg", - title="Rectangular: layout[2:6, 1:5]") + draw_slice( + base, + (slice(2, 6), slice(1, 5)), + output / "slice_rect.svg", + title="Rectangular: layout[2:6, 1:5]", + ) print(f"✓ Rectangular region: layout[2:6, 1:5]") print(f" Selects 4×4 = 16 elements") @@ -1001,33 +1176,40 @@ def example_slicing(output: Path): # Divide 8 rows into 4 groups of 2 divided = logical_divide(base, Layout((2, 4), (1, 2))) - draw_layout(divided, output / "slice_divided_base.svg", - title="Divided: 2-row groups") + draw_layout( + divided, output / "slice_divided_base.svg", title="Divided: 2-row groups" + ) print(f"✓ Divided layout: groups of 2 rows") # Tile-based slicing: select every other 2x2 tile tiled = Layout(((2, 4), (2, 4)), ((1, 16), (2, 8))) - draw_layout(tiled, output / "slice_tiled.svg", - title="Tiled: ((2,4),(2,4)):((1,16),(2,8))") + draw_layout( + tiled, output / "slice_tiled.svg", title="Tiled: 
((2,4),(2,4)):((1,16),(2,8))" + ) print(f"✓ Tiled layout: 2×2 tiles in 8×8") # Strided row access (every other row) strided_rows = Layout((4, 8), (16, 1)) - draw_layout(strided_rows, output / "slice_strided_rows.svg", - title="Strided Rows: (4,8):(16,1)") + draw_layout( + strided_rows, + output / "slice_strided_rows.svg", + title="Strided Rows: (4,8):(16,1)", + ) print(f"✓ Strided rows: every other row") # Strided column access (every other column) strided_cols = Layout((8, 4), (8, 2)) - draw_layout(strided_cols, output / "slice_strided_cols.svg", - title="Strided Cols: (8,4):(8,2)") + draw_layout( + strided_cols, + output / "slice_strided_cols.svg", + title="Strided Cols: (8,4):(8,2)", + ) print(f"✓ Strided columns: every other column") # Diagonal-like pattern using hierarchical layout # Access elements (0,0), (1,1), (2,2), (3,3), ... diag = Layout(8, 9) # stride 9 = 8+1 gives diagonal - draw_layout(diag, output / "slice_diagonal.svg", - title="Diagonal: 8:9") + draw_layout(diag, output / "slice_diagonal.svg", title="Diagonal: 8:9") print(f"✓ Diagonal access: stride 9 (row_stride + 1)") # ========================================================================= @@ -1036,45 +1218,63 @@ def example_slicing(output: Path): # ========================================================================= print(f"\n --- Cecka Hierarchical Slicing ---") cecka_t = Layout(((3, 2), ((2, 3), 2)), ((4, 1), ((2, 15), 100))) - draw_layout(cecka_t, output / "cecka_slice_base.svg", - title=f"Tensor: {cecka_t}", flatten_hierarchical=True) - draw_layout(cecka_t, output / "cecka_slice_base_nested.svg", - title=f"Tensor: {cecka_t}", flatten_hierarchical=False, - label_hierarchy_levels=True) + draw_layout( + cecka_t, + output / "cecka_slice_base.svg", + title=f"Tensor: {cecka_t}", + flatten_hierarchical=True, + ) + draw_layout( + cecka_t, + output / "cecka_slice_base_nested.svg", + title=f"Tensor: {cecka_t}", + flatten_hierarchical=False, + label_hierarchy_levels=True, + ) print(f"✓ Base 
tensor: {cecka_t}") # Slice (2, None) — fix mode-0 to flat index 2, keep all of mode-1 - draw_slice(cecka_t, (2, None), output / "cecka_slice_2_None.svg", - title="(2,:)") + draw_slice(cecka_t, (2, None), output / "cecka_slice_2_None.svg", title="(2,:)") print(f"✓ Slice (2, :) — fix row to 2") # Slice (None, 5) — keep all of mode-0, fix mode-1 to flat index 5 - draw_slice(cecka_t, (None, 5), output / "cecka_slice_None_5.svg", - title="(:,5)") + draw_slice(cecka_t, (None, 5), output / "cecka_slice_None_5.svg", title="(:,5)") print(f"✓ Slice (:, 5) — fix col to 5") # Slice (2, ((0,None),None)) — fix mode-0 to 2, partially slice mode-1 - draw_slice(cecka_t, (2, ((0, None), None)), - output / "cecka_slice_2_0NN.svg", - title="(2,((0,:),:))") + draw_slice( + cecka_t, + (2, ((0, None), None)), + output / "cecka_slice_2_0NN.svg", + title="(2,((0,:),:))", + ) print(f"✓ Slice (2, ((0,:),:)) — fix row=2, inner-col-0=0, rest free") # Slice ((None,1),(None,0)) — fix outer-row=1, inner-col-outer=0 - draw_slice(cecka_t, ((None, 1), (None, 0)), - output / "cecka_slice_N1_N0.svg", - title="((:,1),(:,0))") + draw_slice( + cecka_t, + ((None, 1), (None, 0)), + output / "cecka_slice_N1_N0.svg", + title="((:,1),(:,0))", + ) print(f"✓ Slice ((:,1), (:,0)) — outer-row=1, mode-1 partially fixed") # Slice ((None,0),((0,None),1)) — outer-row=0, inner-col-0=0, outer-col=1 - draw_slice(cecka_t, ((None, 0), ((0, None), 1)), - output / "cecka_slice_N0_0N1.svg", - title="((:,0),((0,:),1))") + draw_slice( + cecka_t, + ((None, 0), ((0, None), 1)), + output / "cecka_slice_N0_0N1.svg", + title="((:,0),((0,:),1))", + ) print(f"✓ Slice ((:,0), ((0,:),1)) — outer-row=0, inner-0=0, outer-col=1") # Slice ((1,None),((None,0),None)) — inner-row=1, middle-col=0 - draw_slice(cecka_t, ((1, None), ((None, 0), None)), - output / "cecka_slice_1N_N0N.svg", - title="((1,:),((:,0),:))") + draw_slice( + cecka_t, + ((1, None), ((None, 0), None)), + output / "cecka_slice_1N_N0N.svg", + title="((1,:),((:,0),:))", + ) 
print(f"✓ Slice ((1,:), ((:,0),:)) — inner-row=1, middle-col=0") @@ -1082,6 +1282,7 @@ def example_slicing(output: Path): # Section 10: Layout Algebra Operations # ============================================================================= + def example_algebra_operations(output: Path): """Layout algebra operations: compose, complement, divide, product. @@ -1098,51 +1299,65 @@ def example_algebra_operations(output: Path): inner = Layout((4, 2), (1, 4)) outer = Layout(8, 2) composed = compose(outer, inner) - draw_layout(inner, output / "algebra_inner.svg", - title=f"Inner: {inner}") - draw_layout(composed, output / "algebra_composed.svg", - title=f"Composed: compose({outer}, {inner})") + draw_layout(inner, output / "algebra_inner.svg", title=f"Inner: {inner}") + draw_layout( + composed, + output / "algebra_composed.svg", + title=f"Composed: compose({outer}, {inner})", + ) print(f"✓ Composition: compose({outer}, {inner}) = {composed}") # Complement: find layout that covers remaining indices base = Layout((4, 2), (2, 1)) comp = complement(base, 16) - draw_layout(base, output / "algebra_base.svg", - title=f"Base: {base}") - draw_layout(comp, output / "algebra_complement.svg", - title=f"Complement: complement({base}, 16)") + draw_layout(base, output / "algebra_base.svg", title=f"Base: {base}") + draw_layout( + comp, + output / "algebra_complement.svg", + title=f"Complement: complement({base}, 16)", + ) print(f"✓ Complement: complement({base}, 16) = {comp}") # Logical divide: tile a layout matrix = Layout((8, 8), (8, 1)) tiler = Layout((2, 2), (1, 2)) divided = logical_divide(matrix, tiler) - draw_layout(matrix, output / "algebra_matrix.svg", - title=f"Matrix: {matrix}") - draw_layout(divided, output / "algebra_divided.svg", - title=f"Divided: logical_divide by {tiler}") + draw_layout(matrix, output / "algebra_matrix.svg", title=f"Matrix: {matrix}") + draw_layout( + divided, + output / "algebra_divided.svg", + title=f"Divided: logical_divide by {tiler}", + ) print(f"✓ 
Logical divide: 8×8 by 2×2 tiler") # Logical product: replicate a layout tile = Layout((2, 2), (2, 1)) grid = Layout((4, 4), (1, 4)) product = logical_product(tile, grid) - draw_layout(tile, output / "algebra_tile.svg", - title=f"Tile: {tile}") - draw_layout(product, output / "algebra_product.svg", - title=f"Product: logical_product({tile}, {grid})") + draw_layout(tile, output / "algebra_tile.svg", title=f"Tile: {tile}") + draw_layout( + product, + output / "algebra_product.svg", + title=f"Product: logical_product({tile}, {grid})", + ) print(f"✓ Logical product: {tile} × {grid}") # Rank >= 3 results: flat_divide and flat_product produce rank-3 layouts # that are now automatically rendered as multi-panel 2D grids fd = flat_divide(matrix, Layout(2, 1)) - draw_layout(fd, output / "algebra_flat_divide.svg", - title=f"flat_divide result (rank {rank(fd)})") + draw_layout( + fd, + output / "algebra_flat_divide.svg", + title=f"flat_divide result (rank {rank(fd)})", + ) print(f"✓ flat_divide: shape={fd.shape}, rank={rank(fd)} → multi-panel") fp = flat_product(Layout((2, 2), (1, 2)), Layout(4, 1)) - draw_layout(fp, output / "algebra_flat_product.svg", - title=f"flat_product result (rank {rank(fp)})") + draw_layout( + fp, + output / "algebra_flat_product.svg", + title=f"flat_product result (rank {rank(fp)})", + ) print(f"✓ flat_product: shape={fp.shape}, rank={rank(fp)} → multi-panel") @@ -1150,6 +1365,7 @@ def example_algebra_operations(output: Path): # Section 11: Tensor Slicing with Visualization # ============================================================================= + def example_tensor_slicing(output: Path): """Tensor slicing - visualizing tensors directly. 
@@ -1174,10 +1390,12 @@ def example_tensor_slicing(output: Path): print(f"✓ Offset tensor (offset=16): cell (0,0) = {tensor_16(0, 0)}") # Side-by-side comparison using draw_composite - draw_composite([tensor, tensor_16], - output / "tensor_offset_compare.svg", - titles=["offset=0", "offset=16"], - main_title="Tensor Offset Comparison") + draw_composite( + [tensor, tensor_16], + output / "tensor_offset_compare.svg", + titles=["offset=0", "offset=16"], + main_title="Tensor Offset Comparison", + ) print(f"✓ Tensor comparison: tensor_offset_compare.svg") # Swizzled tensor — swizzle applied to total linear offset @@ -1191,14 +1409,14 @@ def example_tensor_slicing(output: Path): row2 = tensor[2, :] print(f"\n tensor[2, :] = {row2}") print(f" tensor[2, :](0) = {row2(0)}, tensor(2, 0) = {tensor(2, 0)}") - draw_layout(row2, output / "tensor_slice_row2.svg", - title=f"tensor[2, :] = {row2}") + draw_layout(row2, output / "tensor_slice_row2.svg", title=f"tensor[2, :] = {row2}") # ============================================================================= # Main Entry Point # ============================================================================= + def main(output_dir: str = "examples_output"): """Run all visualization examples.""" output = setup_output_dir(output_dir) diff --git a/pyproject.toml b/pyproject.toml index a382636..284ec6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ ignore = ["E501"] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401", "F403"] "tests/*.py" = ["F403", "F405", "E741"] -"examples/*" = ["E402", "F403", "F405", "F541"] +"examples/*" = ["E402", "E741", "F403", "F405", "F541"] "src/layout_algebra/layout_utils.py" = ["F403", "F405"] "src/layout_algebra/tensor.py" = ["F403", "F405"] "src/layout_algebra/viz.py" = ["F403", "F405"] diff --git a/src/tensor_layouts/__init__.py b/src/tensor_layouts/__init__.py index c53e3a1..c1130b0 100644 --- a/src/tensor_layouts/__init__.py +++ b/src/tensor_layouts/__init__.py @@ -22,7 +22,7 
@@ """Pure-Python implementation of GPU layout algebra.""" -from .layouts import * # noqa: F401,F403 +from .layouts import * # noqa: F401,F403,F405 from .tensor import Tensor # noqa: F401 from .atoms import MMAAtom, CopyAtom # noqa: F401 diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index 9d01a2c..88bd5cb 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# ruff: noqa: F405 + """GPU layout analysis: bank conflicts, coalescing, permutation structure. These functions analyze layouts in GPU-specific contexts --- shared memory @@ -59,6 +61,7 @@ # Inverse mapping # ============================================================================= + def offset_table(layout: Layout) -> dict: """Return {offset: [coord, ...]} mapping each offset to its coordinates. @@ -122,13 +125,13 @@ def footprint(layout: Layout) -> dict: span = max_off - min_off + 1 if offsets else 0 return { - 'min_offset': min_off, - 'max_offset': max_off, - 'span': span, - 'unique_offsets': n_unique, - 'total_elements': n_total, - 'reuse_factor': n_total / n_unique if n_unique > 0 else 0.0, - 'holes': span - n_unique, + "min_offset": min_off, + "max_offset": max_off, + "span": span, + "unique_offsets": n_unique, + "total_elements": n_total, + "reuse_factor": n_total / n_unique if n_unique > 0 else 0.0, + "holes": span - n_unique, } @@ -136,9 +139,15 @@ def footprint(layout: Layout) -> dict: # Bank conflict analysis # ============================================================================= -def bank_conflicts(layout: Layout, *, element_bytes: int, - num_banks: int = 32, bank_width_bytes: int = 4, - group_size: int = 32): + +def bank_conflicts( + layout: Layout, + *, + element_bytes: int, + num_banks: int = 32, + bank_width_bytes: int = 4, + group_size: int = 32, +): """Analyze shared memory bank conflicts for a thread-to-offset 
layout. Given a layout that maps thread indices to shared memory offsets, @@ -221,9 +230,9 @@ def bank_conflicts(layout: Layout, *, element_bytes: int, max_ways = ways return { - 'conflict_free': max_ways <= 1, - 'max_ways': max_ways, - 'bank_to_threads': bank_to_threads, + "conflict_free": max_ways <= 1, + "max_ways": max_ways, + "bank_to_threads": bank_to_threads, } @@ -231,9 +240,14 @@ def bank_conflicts(layout: Layout, *, element_bytes: int, # Coalescing analysis # ============================================================================= -def coalescing_efficiency(layout: Layout, *, element_bytes: int, - warp_size: int = 32, - cache_line_bytes: int = 128): + +def coalescing_efficiency( + layout: Layout, + *, + element_bytes: int, + warp_size: int = 32, + cache_line_bytes: int = 128, +): """Analyze global memory coalescing for a thread-to-offset layout. Given a layout that maps thread indices to global memory offsets, @@ -294,16 +308,20 @@ def coalescing_efficiency(layout: Layout, *, element_bytes: int, efficiency = useful_bytes / transferred_bytes if transferred_bytes > 0 else 0.0 return { - 'transactions': transactions, - 'efficiency': efficiency, - 'cache_lines': sorted(cache_lines), + "transactions": transactions, + "efficiency": efficiency, + "cache_lines": sorted(cache_lines), } -def segment_analysis(layout: Layout, *, element_bytes: int, - warp_size: int = 32, - segment_bytes: int = 32, - cache_line_bytes: int = 128): +def segment_analysis( + layout: Layout, + *, + element_bytes: int, + warp_size: int = 32, + segment_bytes: int = 32, + cache_line_bytes: int = 128, +): """Segment- and alignment-aware global memory transaction analysis. A more detailed model than ``coalescing_efficiency()``. 
NVIDIA GPUs @@ -361,14 +379,14 @@ def segment_analysis(layout: Layout, *, element_bytes: int, transferred_bytes = n_segments * segment_bytes return { - 'segments': n_segments, - 'cache_lines': n_lines, - 'unique_bytes': unique_bytes, - 'requested_bytes': requested_bytes, - 'transferred_bytes': transferred_bytes, - 'segment_efficiency': unique_bytes / transferred_bytes if transferred_bytes > 0 else 0.0, - 'first_byte_addr': first_byte, - 'first_alignment': first_byte % segment_bytes, + "segments": n_segments, + "cache_lines": n_lines, + "unique_bytes": unique_bytes, + "requested_bytes": requested_bytes, + "transferred_bytes": transferred_bytes, + "segment_efficiency": (unique_bytes / transferred_bytes if transferred_bytes > 0 else 0.0), + "first_byte_addr": first_byte, + "first_alignment": first_byte % segment_bytes, } @@ -389,10 +407,14 @@ def _tv_dimensions(layout: Layout): return size(mode(layout, 0)), size(layout) // size(mode(layout, 0)) -def per_group_bank_conflicts(layout: Layout, *, element_bytes: int, - group_size: int = 32, - num_banks: int = 32, - bank_width_bytes: int = 4) -> dict: +def per_group_bank_conflicts( + layout: Layout, + *, + element_bytes: int, + group_size: int = 32, + num_banks: int = 32, + bank_width_bytes: int = 4, +) -> dict: """Analyze bank conflicts per warp/wavefront group across a full layout. 
Splits the layout into groups of ``group_size`` threads and analyzes @@ -452,9 +474,9 @@ def per_group_bank_conflicts(layout: Layout, *, element_bytes: int, max_ways = ways result = { - 'conflict_free': max_ways <= 1, - 'max_ways': max_ways, - 'bank_to_threads': bank_to_threads, + "conflict_free": max_ways <= 1, + "max_ways": max_ways, + "bank_to_threads": bank_to_threads, } groups.append(result) if max_ways > worst_ways: @@ -462,15 +484,19 @@ def per_group_bank_conflicts(layout: Layout, *, element_bytes: int, worst_idx = g return { - 'groups': groups, - 'worst_group': worst_idx, - 'worst_max_ways': worst_ways, + "groups": groups, + "worst_group": worst_idx, + "worst_max_ways": worst_ways, } -def per_group_coalescing(layout: Layout, *, element_bytes: int, - group_size: int = 32, - cache_line_bytes: int = 128) -> dict: +def per_group_coalescing( + layout: Layout, + *, + element_bytes: int, + group_size: int = 32, + cache_line_bytes: int = 128, +) -> dict: """Analyze coalescing efficiency per warp/wavefront group across a full layout. 
Splits the layout into groups of ``group_size`` threads and analyzes @@ -500,7 +526,7 @@ def per_group_coalescing(layout: Layout, *, element_bytes: int, groups = [] worst_idx = 0 - worst_eff = float('inf') + worst_eff = float("inf") for g in range(num_groups): start = g * group_size @@ -523,9 +549,9 @@ def per_group_coalescing(layout: Layout, *, element_bytes: int, efficiency = useful_bytes / transferred_bytes if transferred_bytes > 0 else 0.0 result = { - 'transactions': transactions, - 'efficiency': efficiency, - 'cache_lines': sorted(cache_lines), + "transactions": transactions, + "efficiency": efficiency, + "cache_lines": sorted(cache_lines), } groups.append(result) if efficiency < worst_eff: @@ -533,9 +559,9 @@ def per_group_coalescing(layout: Layout, *, element_bytes: int, worst_idx = g return { - 'groups': groups, - 'worst_group': worst_idx, - 'worst_efficiency': worst_eff, + "groups": groups, + "worst_group": worst_idx, + "worst_efficiency": worst_eff, } @@ -543,6 +569,7 @@ def per_group_coalescing(layout: Layout, *, element_bytes: int, # Permutation analysis # ============================================================================= + def cycles(layout: Layout) -> list: """Return the cycle decomposition of a bijective layout. @@ -633,6 +660,7 @@ def order(layout: Layout) -> int: # Contiguity # ============================================================================= + def contiguity(layout: Layout) -> int: """Return the longest contiguous vector width from the start of the layout. @@ -722,6 +750,7 @@ def slice_contiguity(layout: Layout, coord) -> int: # Atom analysis # ============================================================================= + def atom_summary(atom: MMAAtom) -> dict: """Summarize an MMA atom's key properties. 
@@ -764,40 +793,39 @@ def atom_summary(atom: MMAAtom) -> dict: for v in range(num_v): c_offset_list.append(atom.c_layout(t, v)) c_offsets = set(c_offset_list) - c_coverage_ok = (c_offsets == set(range(M * N)) - and len(c_offset_list) == M * N) + c_coverage_ok = c_offsets == set(range(M * N)) and len(c_offset_list) == M * N # Check for broadcast (stride-0) in A and B a_broadcast = atom.a_layout.filter() != atom.a_layout b_broadcast = atom.b_layout.filter() != atom.b_layout result = { - 'name': atom.name, - 'shape_mnk': atom.shape_mnk, - 'threads': threads, - 'values_a': values_a, - 'values_b': values_b, - 'values_c': values_c, - 'c_coverage_ok': c_coverage_ok, - 'a_broadcast': a_broadcast, - 'b_broadcast': b_broadcast, + "name": atom.name, + "shape_mnk": atom.shape_mnk, + "threads": threads, + "values_a": values_a, + "values_b": values_b, + "values_c": values_c, + "c_coverage_ok": c_coverage_ok, + "a_broadcast": a_broadcast, + "b_broadcast": b_broadcast, } lines = [ atom.name, - f' Shape (M, N, K): {M} x {N} x {K}', - f' Threads: {threads}', - f' Values per thread: A={values_a}, B={values_b}, C={values_c}', - f' C covers M*N: {c_coverage_ok}', + f" Shape (M, N, K): {M} x {N} x {K}", + f" Threads: {threads}", + f" Values per thread: A={values_a}, B={values_b}, C={values_c}", + f" C covers M*N: {c_coverage_ok}", ] if a_broadcast: - lines.append(f' A has broadcast (stride-0) modes') + lines.append(" A has broadcast (stride-0) modes") if b_broadcast: - lines.append(f' B has broadcast (stride-0) modes') + lines.append(" B has broadcast (stride-0) modes") - text = '\n'.join(lines) + text = "\n".join(lines) print(text) - result['text'] = text + result["text"] = text return result @@ -819,14 +847,14 @@ def _operand_coverage(layout: Layout, domain_size: int) -> dict: duplicates = total_accesses - len(unique) return { - 'domain_size': domain_size, - 'unique_offsets': len(unique), - 'total_accesses': total_accesses, - 'duplicates': duplicates, - 'coverage_ok': unique == 
expected, - 'missing': sorted(missing) if missing else [], - 'extra': sorted(extra) if extra else [], - 'thread_utilization': len(unique) / total_accesses if total_accesses > 0 else 0.0, + "domain_size": domain_size, + "unique_offsets": len(unique), + "total_accesses": total_accesses, + "duplicates": duplicates, + "coverage_ok": unique == expected, + "missing": sorted(missing) if missing else [], + "extra": sorted(extra) if extra else [], + "thread_utilization": (len(unique) / total_accesses if total_accesses > 0 else 0.0), } @@ -853,9 +881,9 @@ def operand_analysis(atom: MMAAtom) -> dict: M, N, K = atom.shape_mnk return { - 'a': _operand_coverage(atom.a_layout, M * K), - 'b': _operand_coverage(atom.b_layout, N * K), - 'c': _operand_coverage(atom.c_layout, M * N), + "a": _operand_coverage(atom.a_layout, M * K), + "b": _operand_coverage(atom.b_layout, N * K), + "c": _operand_coverage(atom.c_layout, M * N), } @@ -863,6 +891,7 @@ def operand_analysis(atom: MMAAtom) -> dict: # Algebra explanation # ============================================================================= + def explain(fn, *args): """Show step-by-step how an algebra operation computes its result. 
@@ -883,189 +912,189 @@ def explain(fn, *args): name = fn.__name__ lines = [] - if name == 'logical_divide': + if name == "logical_divide": L, T = args if isinstance(T, int): T = Layout(T) - lines.append(f'logical_divide({L}, {T})') + lines.append(f"logical_divide({L}, {T})") actual = logical_divide(L, T) if is_layout(T): - lines.append(f' = compose(L, Layout(T, complement(T, size(L))))') - lines.append(f'') - lines.append(f' L = {L}') - lines.append(f' T = {T}') - lines.append(f' size(L) = {size(L)}') + lines.append(" = compose(L, Layout(T, complement(T, size(L))))") + lines.append("") + lines.append(f" L = {L}") + lines.append(f" T = {T}") + lines.append(f" size(L) = {size(L)}") comp = complement(T, size(L)) - lines.append(f' complement(T, {size(L)}) = {comp}') + lines.append(f" complement(T, {size(L)}) = {comp}") intermediate = Layout(T, comp) - lines.append(f' Layout(T, complement) = {intermediate}') + lines.append(f" Layout(T, complement) = {intermediate}") result = compose(L, intermediate) - lines.append(f' compose(L, {intermediate}) = {result}') + lines.append(f" compose(L, {intermediate}) = {result}") else: - lines.append(f' Divides each mode of L by the corresponding tiler element.') - lines.append(f'') - lines.append(f' L = {L}') - lines.append(f' T = {T}') + lines.append(" Divides each mode of L by the corresponding tiler element.") + lines.append("") + lines.append(f" L = {L}") + lines.append(f" T = {T}") - lines.append(f'') - lines.append(f' result = {actual}') + lines.append("") + lines.append(f" result = {actual}") - elif name == 'logical_product': + elif name == "logical_product": A, B = args if isinstance(B, int): B = Layout(B) - lines.append(f'logical_product({A}, {B})') + lines.append(f"logical_product({A}, {B})") if is_layout(B): - lines.append(f' = Layout(A, compose(complement(A, size(A)*size(B)), B))') - lines.append(f'') - lines.append(f' A = {A}') - lines.append(f' B = {B}') + lines.append(" = Layout(A, compose(complement(A, 
size(A)*size(B)), B))") + lines.append("") + lines.append(f" A = {A}") + lines.append(f" B = {B}") bound = size(A) * size(B) - lines.append(f' size(A) * size(B) = {bound}') + lines.append(f" size(A) * size(B) = {bound}") comp = complement(A, bound) - lines.append(f' complement(A, {bound}) = {comp}') + lines.append(f" complement(A, {bound}) = {comp}") comp_b = compose(comp, B) - lines.append(f' compose(complement, B) = {comp_b}') + lines.append(f" compose(complement, B) = {comp_b}") result = Layout(A, comp_b) - lines.append(f' Layout(A, {comp_b}) = {result}') + lines.append(f" Layout(A, {comp_b}) = {result}") else: # Tuple tiler: mode-by-mode decomposition - lines.append(f' For tuple tilers, applies logical_product mode-by-mode.') - lines.append(f'') - lines.append(f' A = {A}') - lines.append(f' B = {B}') + lines.append(" For tuple tilers, applies logical_product mode-by-mode.") + lines.append("") + lines.append(f" A = {A}") + lines.append(f" B = {B}") for i in range(len(B)): mi = mode(A, i) bi = B[i] ri = logical_product(mi, bi) - lines.append(f' mode {i}: logical_product({mi}, {bi}) = {ri}') + lines.append(f" mode {i}: logical_product({mi}, {bi}) = {ri}") - lines.append(f'') + lines.append("") actual = logical_product(A, B) - lines.append(f' result = {actual}') + lines.append(f" result = {actual}") - elif name == 'complement': + elif name == "complement": L = args[0] bound = args[1] if len(args) > 1 else None if bound is not None: - lines.append(f'complement({L}, {bound})') + lines.append(f"complement({L}, {bound})") else: - lines.append(f'complement({L})') + lines.append(f"complement({L})") bound = cosize(L) - lines.append(f' Fills the gaps in L\'s codomain up to bound={bound}.') - lines.append(f'') - lines.append(f' L = {L}') - lines.append(f' image(L) = {image(L)}') - lines.append(f' codomain = [0, {bound})') + lines.append(f" Fills the gaps in L's codomain up to bound={bound}.") + lines.append("") + lines.append(f" L = {L}") + lines.append(f" image(L) = 
{image(L)}") + lines.append(f" codomain = [0, {bound})") comp = complement(*args) - lines.append(f' complement = {comp}') - lines.append(f' image(complement) = {image(comp)}') + lines.append(f" complement = {comp}") + lines.append(f" image(complement) = {image(comp)}") - elif name == 'compose': + elif name == "compose": A, B = args - lines.append(f'compose({A}, {B})') - lines.append(f' C(i) = A(B(i))') - lines.append(f'') - lines.append(f' A = {A}') - lines.append(f' B = {B}') + lines.append(f"compose({A}, {B})") + lines.append(" C(i) = A(B(i))") + lines.append("") + lines.append(f" A = {A}") + lines.append(f" B = {B}") result = compose(A, B) - lines.append(f' result = {result}') - lines.append(f'') + lines.append(f" result = {result}") + lines.append("") n = min(size(result), 8) - lines.append(f' First {n} values:') + lines.append(f" First {n} values:") for i in range(n): - lines.append(f' i={i}: B({i})={B(i)}, A({B(i)})={result(i)}') + lines.append(f" i={i}: B({i})={B(i)}, A({B(i)})={result(i)}") - elif name == 'right_inverse': + elif name == "right_inverse": L = args[0] - lines.append(f'right_inverse({L})') - lines.append(f' R such that L(R(i)) == i') - lines.append(f'') + lines.append(f"right_inverse({L})") + lines.append(" R such that L(R(i)) == i") + lines.append("") R = right_inverse(L) - lines.append(f' L = {L}') - lines.append(f' R = {R}') + lines.append(f" L = {L}") + lines.append(f" R = {R}") n = min(size(R), 8) - lines.append(f'') - lines.append(f' Verification (first {n}):') + lines.append("") + lines.append(f" Verification (first {n}):") for i in range(n): - lines.append(f' R({i})={R(i)}, L(R({i}))={L(R(i))}') + lines.append(f" R({i})={R(i)}, L(R({i}))={L(R(i))}") - elif name == 'left_inverse': + elif name == "left_inverse": L = args[0] - lines.append(f'left_inverse({L})') - lines.append(f' R such that R(L(i)) == i') - lines.append(f'') + lines.append(f"left_inverse({L})") + lines.append(" R such that R(L(i)) == i") + lines.append("") R = 
left_inverse(L) - lines.append(f' L = {L}') - lines.append(f' R = {R}') + lines.append(f" L = {L}") + lines.append(f" R = {R}") n = min(size(L), 8) - lines.append(f'') - lines.append(f' Verification (first {n}):') + lines.append("") + lines.append(f" Verification (first {n}):") for i in range(n): - lines.append(f' L({i})={L(i)}, R(L({i}))={R(L(i))}') + lines.append(f" L({i})={L(i)}, R(L({i}))={R(L(i))}") - elif name == 'blocked_product': + elif name == "blocked_product": A, B = args - lines.append(f'blocked_product({A}, {B})') - lines.append(f' Like logical_product, but interleaves corresponding modes:') - lines.append(f' ((A0, B0), (A1, B1), ...) — A varies fastest (block-first).') - lines.append(f'') + lines.append(f"blocked_product({A}, {B})") + lines.append(" Like logical_product, but interleaves corresponding modes:") + lines.append(" ((A0, B0), (A1, B1), ...) — A varies fastest (block-first).") + lines.append("") lp = logical_product(A, B) - lines.append(f' logical_product(A, B) = {lp}') + lines.append(f" logical_product(A, B) = {lp}") actual = blocked_product(A, B) - lines.append(f' blocked_product(A, B) = {actual}') - lines.append(f'') - lines.append(f' Mode structure:') + lines.append(f" blocked_product(A, B) = {actual}") + lines.append("") + lines.append(" Mode structure:") for i in range(max(1, len(actual.shape) if isinstance(actual.shape, tuple) else 1)): m = mode(actual, i) if isinstance(actual.shape, tuple) else actual - lines.append(f' mode {i}: {m.shape} : {m.stride}') + lines.append(f" mode {i}: {m.shape} : {m.stride}") - elif name == 'raked_product': + elif name == "raked_product": A, B = args - lines.append(f'raked_product({A}, {B})') - lines.append(f' Like blocked_product, but B varies fastest (rake-first):') - lines.append(f' ((B0, A0), (B1, A1), ...) 
— elements are interleaved.') - lines.append(f'') + lines.append(f"raked_product({A}, {B})") + lines.append(" Like blocked_product, but B varies fastest (rake-first):") + lines.append(" ((B0, A0), (B1, A1), ...) — elements are interleaved.") + lines.append("") bp = blocked_product(A, B) - lines.append(f' blocked_product(A, B) = {bp}') + lines.append(f" blocked_product(A, B) = {bp}") actual = raked_product(A, B) - lines.append(f' raked_product(A, B) = {actual}') - lines.append(f'') - lines.append(f' Compare first 8 offsets:') + lines.append(f" raked_product(A, B) = {actual}") + lines.append("") + lines.append(" Compare first 8 offsets:") n = min(size(actual), 8) bp_vals = [bp(i) for i in range(n)] rp_vals = [actual(i) for i in range(n)] - lines.append(f' blocked: {bp_vals}') - lines.append(f' raked: {rp_vals}') + lines.append(f" blocked: {bp_vals}") + lines.append(f" raked: {rp_vals}") - elif name in ('zipped_divide', 'tiled_divide', 'flat_divide'): + elif name in ("zipped_divide", "tiled_divide", "flat_divide"): L, T = args - lines.append(f'{name}({L}, {T})') - lines.append(f' Rearrangement of logical_divide result.') - lines.append(f'') + lines.append(f"{name}({L}, {T})") + lines.append(" Rearrangement of logical_divide result.") + lines.append("") ld = logical_divide(L, T) - lines.append(f' logical_divide({L}, {T})') - lines.append(f' = {ld}') + lines.append(f" logical_divide({L}, {T})") + lines.append(f" = {ld}") actual = fn(L, T) - lines.append(f' {name}:') - lines.append(f' = {actual}') - lines.append(f'') - if name == 'zipped_divide': - lines.append(f' Structure: ((tiles), (rests))') - elif name == 'tiled_divide': - lines.append(f' Structure: ((tiles), rest0, rest1, ...)') + lines.append(f" {name}:") + lines.append(f" = {actual}") + lines.append("") + if name == "zipped_divide": + lines.append(" Structure: ((tiles), (rests))") + elif name == "tiled_divide": + lines.append(" Structure: ((tiles), rest0, rest1, ...)") else: - lines.append(f' Structure: (tile0, 
tile1, ..., rest0, rest1, ...)') + lines.append(" Structure: (tile0, tile1, ..., rest0, rest1, ...)") else: - lines.append(f'explain() does not support {name}.') - lines.append(f'Supported: logical_divide, logical_product, complement,') - lines.append(f' compose, right_inverse, left_inverse,') - lines.append(f' blocked_product, raked_product,') - lines.append(f' zipped_divide, tiled_divide, flat_divide.') + lines.append(f"explain() does not support {name}.") + lines.append("Supported: logical_divide, logical_product, complement,") + lines.append(" compose, right_inverse, left_inverse,") + lines.append(" blocked_product, raked_product,") + lines.append(" zipped_divide, tiled_divide, flat_divide.") - text = '\n'.join(lines) + text = "\n".join(lines) print(text) return text diff --git a/src/tensor_layouts/atoms.py b/src/tensor_layouts/atoms.py index 729b634..eedd278 100644 --- a/src/tensor_layouts/atoms.py +++ b/src/tensor_layouts/atoms.py @@ -45,6 +45,7 @@ class MMAAtom: b_layout: (T, V) -> col-major offset in (N, K) c_layout: (T, V) -> col-major offset in (M, N) """ + name: str ptx: str shape_mnk: Tuple[int, int, int] @@ -71,6 +72,7 @@ class CopyAtom: src_layout_bits: (thr, val) -> bit offset for source dst_layout_bits: (thr, val) -> bit offset for destination """ + name: str ptx: str thr_id: Layout diff --git a/src/tensor_layouts/atoms_amd.py b/src/tensor_layouts/atoms_amd.py index 4356eb4..7f1214e 100644 --- a/src/tensor_layouts/atoms_amd.py +++ b/src/tensor_layouts/atoms_amd.py @@ -120,6 +120,7 @@ # Helper: construct CuTe layouts from MFMA structural parameters # ============================================================================= + def _mfma_c_layout( m: int, n: int, @@ -219,21 +220,23 @@ def make_mfma_atom( # Sanity checks matching the CK static_asserts if num_threads_per_blk != n: - raise ValueError( - f"num_threads_per_blk ({num_threads_per_blk}) != n ({n})") + raise ValueError(f"num_threads_per_blk ({num_threads_per_blk}) != n ({n})") if 
num_regs_per_blk * num_input_blks != m: raise ValueError( f"num_regs_per_blk * num_input_blks " - f"({num_regs_per_blk * num_input_blks}) != m ({m})") + f"({num_regs_per_blk * num_input_blks}) != m ({m})" + ) if num_regs_per_blk * wave_size != m * n: raise ValueError( f"num_regs_per_blk * wave_size " - f"({num_regs_per_blk * wave_size}) != m*n ({m * n})") + f"({num_regs_per_blk * wave_size}) != m*n ({m * n})" + ) if wave_size != num_input_blks * num_threads_per_blk: raise ValueError( f"wave_size ({wave_size}) != " f"num_input_blks * num_threads_per_blk " - f"({num_input_blks * num_threads_per_blk})") + f"({num_input_blks * num_threads_per_blk})" + ) # For k-reduction variants: K = k_per_blk * num_input_blks # For non-k-reduction: K = k_per_blk @@ -242,16 +245,28 @@ def make_mfma_atom( raise ValueError(f"total_k ({total_k}) != k ({k})") c_layout = _mfma_c_layout( - m, n, group_size, num_groups_per_blk, - num_threads_per_blk, num_input_blks, + m, + n, + group_size, + num_groups_per_blk, + num_threads_per_blk, + num_input_blks, ) a_layout = _mfma_input_layout( - m, k, num_threads_per_blk, num_input_blks, k_per_blk, + m, + k, + num_threads_per_blk, + num_input_blks, + k_per_blk, ) b_layout = _mfma_input_layout( - n, k, num_threads_per_blk, num_input_blks, k_per_blk, + n, + k, + num_threads_per_blk, + num_input_blks, + k_per_blk, ) return MMAAtom( @@ -278,11 +293,18 @@ def make_mfma_atom( CDNA_32x32x8_F32F16F16_MFMA = make_mfma_atom( name="CDNA_32x32x8_F32F16F16_MFMA", inst="v_mfma_f32_32x32x8f16", - m=32, n=32, k=8, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=8, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x16f16: D[16x16] = C[16x16] + A[16x16]*B[16x16] @@ -292,11 +314,18 @@ def 
make_mfma_atom( CDNA_16x16x16_F32F16F16_MFMA = make_mfma_atom( name="CDNA_16x16x16_F32F16F16_MFMA", inst="v_mfma_f32_16x16x16f16", - m=16, n=16, k=16, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=16, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_4x4x4f16: D[4x4] = C[4x4] + A[4x4]*B[4x4] @@ -309,11 +338,18 @@ def make_mfma_atom( CDNA_4x4x4_F32F16F16_MFMA = make_mfma_atom( name="CDNA_4x4x4_F32F16F16_MFMA", inst="v_mfma_f32_4x4x4f16", - m=4, n=64, k=4, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=64, num_input_blks=1, - num_output_blks=1, k_per_blk=4, - is_k_reduction=False, num_v_a=2, num_v_b=2, + m=4, + n=64, + k=4, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=64, + num_input_blks=1, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=False, + num_v_a=2, + num_v_b=2, ) # --- Non-k-reduction variants (larger K, multiple output blocks) --- @@ -325,11 +361,18 @@ def make_mfma_atom( CDNA_32x32x4_F32F16F16_MFMA = make_mfma_atom( name="CDNA_32x32x4_F32F16F16_MFMA", inst="v_mfma_f32_32x32x4f16", - m=32, n=32, k=4, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=2, k_per_blk=4, - is_k_reduction=False, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=4, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=2, + k_per_blk=4, + is_k_reduction=False, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x4f16: 4 output blocks (non-k-reduction) @@ -338,11 +381,18 @@ def make_mfma_atom( CDNA_16x16x4_F32F16F16_MFMA = make_mfma_atom( name="CDNA_16x16x4_F32F16F16_MFMA", inst="v_mfma_f32_16x16x4f16", - m=16, n=16, k=4, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, 
num_input_blks=4, - num_output_blks=4, k_per_blk=4, - is_k_reduction=False, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=4, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=4, + k_per_blk=4, + is_k_reduction=False, + num_v_a=2, + num_v_b=2, ) @@ -355,22 +405,36 @@ def make_mfma_atom( CDNA_32x32x8_F32BF16BF16_1K_MFMA = make_mfma_atom( name="CDNA_32x32x8_F32BF16BF16_1K_MFMA", inst="v_mfma_f32_32x32x8bf16_1k", - m=32, n=32, k=8, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=8, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x16bf16_1k: identical layout to 16x16x16f16 CDNA_16x16x16_F32BF16BF16_1K_MFMA = make_mfma_atom( name="CDNA_16x16x16_F32BF16BF16_1K_MFMA", inst="v_mfma_f32_16x16x16bf16_1k", - m=16, n=16, k=16, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=16, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -382,22 +446,36 @@ def make_mfma_atom( CDNA_32x32x4_F32BF16BF16_MFMA = make_mfma_atom( name="CDNA_32x32x4_F32BF16BF16_MFMA", inst="v_mfma_f32_32x32x4bf16", - m=32, n=32, k=4, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=4, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x8bf16 CDNA_16x16x8_F32BF16BF16_MFMA = 
make_mfma_atom( name="CDNA_16x16x8_F32BF16BF16_MFMA", inst="v_mfma_f32_16x16x8bf16", - m=16, n=16, k=8, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=8, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -409,22 +487,36 @@ def make_mfma_atom( CDNA_32x32x8_I32I8I8_MFMA = make_mfma_atom( name="CDNA_32x32x8_I32I8I8_MFMA", inst="v_mfma_i32_32x32x8i8", - m=32, n=32, k=8, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=32, + n=32, + k=8, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=1, + num_v_b=1, ) # v_mfma_i32_16x16x16i8 CDNA_16x16x16_I32I8I8_MFMA = make_mfma_atom( name="CDNA_16x16x16_I32I8I8_MFMA", inst="v_mfma_i32_16x16x16i8", - m=16, n=16, k=16, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=4, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=16, + n=16, + k=16, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=4, + is_k_reduction=True, + num_v_a=1, + num_v_b=1, ) @@ -436,22 +528,36 @@ def make_mfma_atom( CDNA_32x32x2_F32F32F32_MFMA = make_mfma_atom( name="CDNA_32x32x2_F32F32F32_MFMA", inst="v_mfma_f32_32x32x2f32", - m=32, n=32, k=2, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=1, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=32, + n=32, + k=2, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=1, + is_k_reduction=True, 
+ num_v_a=1, + num_v_b=1, ) # v_mfma_f32_16x16x4f32 CDNA_16x16x4_F32F32F32_MFMA = make_mfma_atom( name="CDNA_16x16x4_F32F32F32_MFMA", inst="v_mfma_f32_16x16x4f32", - m=16, n=16, k=4, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=1, - is_k_reduction=True, num_v_a=1, num_v_b=1, + m=16, + n=16, + k=4, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=1, + is_k_reduction=True, + num_v_a=1, + num_v_b=1, ) @@ -463,11 +569,18 @@ def make_mfma_atom( CDNA_16x16x4_F64F64F64_MFMA = make_mfma_atom( name="CDNA_16x16x4_F64F64F64_MFMA", inst="v_mfma_f64_16x16x4f64", - m=16, n=16, k=4, - group_size=1, num_groups_per_blk=4, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=1, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=4, + group_size=1, + num_groups_per_blk=4, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=1, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -501,11 +614,18 @@ def make_mfma_atom( CDNA3_32x32x16_I32I8I8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_I32I8I8_MFMA", inst="v_mfma_i32_32x32x16i8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_i32_16x16x32i8: 16x16 output, K=32 @@ -514,11 +634,18 @@ def make_mfma_atom( CDNA3_16x16x32_I32I8I8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_I32I8I8_MFMA", inst="v_mfma_i32_16x16x32i8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + 
k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # --- XF32 (TF32-like, CDNA3) --- @@ -527,22 +654,36 @@ def make_mfma_atom( CDNA3_32x32x4_F32XF32XF32_MFMA = make_mfma_atom( name="CDNA3_32x32x4_F32XF32XF32_MFMA", inst="v_mfma_f32_32x32x4_xf32", - m=32, n=32, k=4, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=4, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x8_xf32 CDNA3_16x16x8_F32XF32XF32_MFMA = make_mfma_atom( name="CDNA3_16x16x8_F32XF32XF32_MFMA", inst="v_mfma_f32_16x16x8_xf32", - m=16, n=16, k=8, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=2, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=8, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=2, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -554,85 +695,141 @@ def make_mfma_atom( CDNA3_32x32x16_F32F8F8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32F8F8_MFMA", inst="v_mfma_f32_32x32x16_fp8_fp8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_fp8_fp8 CDNA3_16x16x32_F32F8F8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32F8F8_MFMA", inst="v_mfma_f32_16x16x32_fp8_fp8", - m=16, n=16, k=32, - group_size=4, 
num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_32x32x16_bf8_bf8: same layout as fp8_fp8 32x32 CDNA3_32x32x16_F32BF8BF8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32BF8BF8_MFMA", inst="v_mfma_f32_32x32x16_bf8_bf8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_bf8_bf8 CDNA3_16x16x32_F32BF8BF8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32BF8BF8_MFMA", inst="v_mfma_f32_16x16x32_bf8_bf8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # Mixed FP8 variants (fp8 x bf8, bf8 x fp8) — same layouts CDNA3_32x32x16_F32F8BF8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32F8BF8_MFMA", inst="v_mfma_f32_32x32x16_fp8_bf8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) 
CDNA3_16x16x32_F32F8BF8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32F8BF8_MFMA", inst="v_mfma_f32_16x16x32_fp8_bf8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) CDNA3_32x32x16_F32BF8F8_MFMA = make_mfma_atom( name="CDNA3_32x32x16_F32BF8F8_MFMA", inst="v_mfma_f32_32x32x16_bf8_fp8", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) CDNA3_16x16x32_F32BF8F8_MFMA = make_mfma_atom( name="CDNA3_16x16x32_F32BF8F8_MFMA", inst="v_mfma_f32_16x16x32_bf8_fp8", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) @@ -646,11 +843,18 @@ def make_mfma_atom( CDNA3P_32x32x16_F32F16F16_MFMA = make_mfma_atom( name="CDNA3P_32x32x16_F32F16F16_MFMA", inst="v_mfma_f32_32x32x16_f16", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + 
is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_f16 (gfx950 only): 2x K vs 16x16x16f16 @@ -659,11 +863,18 @@ def make_mfma_atom( CDNA3P_16x16x32_F32F16F16_MFMA = make_mfma_atom( name="CDNA3P_16x16x32_F32F16F16_MFMA", inst="v_mfma_f32_16x16x32_f16", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_32x32x16_bf16 (gfx950 only) @@ -672,44 +883,72 @@ def make_mfma_atom( CDNA3P_32x32x16_F32BF16BF16_MFMA = make_mfma_atom( name="CDNA3P_32x32x16_F32BF16BF16_MFMA", inst="v_mfma_f32_32x32x16_bf16", - m=32, n=32, k=16, - group_size=4, num_groups_per_blk=4, - num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=16, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_f32_16x16x32_bf16 (gfx950 only) CDNA3P_16x16x32_F32BF16BF16_MFMA = make_mfma_atom( name="CDNA3P_16x16x32_F32BF16BF16_MFMA", inst="v_mfma_f32_16x16x32_bf16", - m=16, n=16, k=32, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=8, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=32, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=8, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_i32_32x32x32_i8 (gfx950 only) CDNA3P_32x32x32_I32I8I8_MFMA = make_mfma_atom( name="CDNA3P_32x32x32_I32I8I8_MFMA", inst="v_mfma_i32_32x32x32i8", - m=32, n=32, k=32, - group_size=4, num_groups_per_blk=4, - 
num_threads_per_blk=32, num_input_blks=2, - num_output_blks=1, k_per_blk=16, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=32, + n=32, + k=32, + group_size=4, + num_groups_per_blk=4, + num_threads_per_blk=32, + num_input_blks=2, + num_output_blks=1, + k_per_blk=16, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) # v_mfma_i32_16x16x64_i8 (gfx950 only) CDNA3P_16x16x64_I32I8I8_MFMA = make_mfma_atom( name="CDNA3P_16x16x64_I32I8I8_MFMA", inst="v_mfma_i32_16x16x64i8", - m=16, n=16, k=64, - group_size=4, num_groups_per_blk=1, - num_threads_per_blk=16, num_input_blks=4, - num_output_blks=1, k_per_blk=16, - is_k_reduction=True, num_v_a=2, num_v_b=2, + m=16, + n=16, + k=64, + group_size=4, + num_groups_per_blk=1, + num_threads_per_blk=16, + num_input_blks=4, + num_output_blks=1, + k_per_blk=16, + is_k_reduction=True, + num_v_a=2, + num_v_b=2, ) diff --git a/src/tensor_layouts/atoms_amx.py b/src/tensor_layouts/atoms_amx.py index c7b05b4..f487113 100644 --- a/src/tensor_layouts/atoms_amx.py +++ b/src/tensor_layouts/atoms_amx.py @@ -95,58 +95,70 @@ AMX_16x16x32_F32BF16BF16F32 = MMAAtom( name="AMX_16x16x32_F32BF16BF16F32", ptx="tdpbf16ps", - shape_mnk=(16, 16, 32), thr_id=Layout(1), + shape_mnk=(16, 16, 32), + thr_id=Layout(1), # (T=1, V=512) -> col-major offset in (M=16, K=32) a_layout=Layout((1, (16, 32)), (0, (1, 16))), # (T=1, V=512) -> col-major offset in (N=16, K=32) b_layout=Layout((1, (16, 32)), (0, (1, 16))), # (T=1, V=256) -> col-major offset in (M=16, N=16) - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- FP16 -> FP32 ------------------------------------------------------------- AMX_16x16x32_F32F16F16F32 = MMAAtom( name="AMX_16x16x32_F32F16F16F32", ptx="tdpfp16ps", - shape_mnk=(16, 16, 32), thr_id=Layout(1), + shape_mnk=(16, 16, 32), + thr_id=Layout(1), a_layout=Layout((1, (16, 32)), (0, (1, 16))), b_layout=Layout((1, (16, 32)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + 
c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- INT8 x INT8 -> INT32 (signed x signed) ----------------------------------- AMX_16x16x64_S32S8S8S32 = MMAAtom( name="AMX_16x16x64_S32S8S8S32", ptx="tdpbssd", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), # (T=1, V=1024) -> col-major offset in (M=16, K=64) a_layout=Layout((1, (16, 64)), (0, (1, 16))), # (T=1, V=1024) -> col-major offset in (N=16, K=64) b_layout=Layout((1, (16, 64)), (0, (1, 16))), # (T=1, V=256) -> col-major offset in (M=16, N=16) - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- INT8 x UINT8 -> INT32 (signed x unsigned) -------------------------------- AMX_16x16x64_S32S8U8S32 = MMAAtom( name="AMX_16x16x64_S32S8U8S32", ptx="tdpbsud", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), a_layout=Layout((1, (16, 64)), (0, (1, 16))), b_layout=Layout((1, (16, 64)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- UINT8 x INT8 -> INT32 (unsigned x signed) -------------------------------- AMX_16x16x64_S32U8S8S32 = MMAAtom( name="AMX_16x16x64_S32U8S8S32", ptx="tdpbusd", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), a_layout=Layout((1, (16, 64)), (0, (1, 16))), b_layout=Layout((1, (16, 64)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) # -- UINT8 x UINT8 -> INT32 (unsigned x unsigned) ----------------------------- AMX_16x16x64_S32U8U8S32 = MMAAtom( name="AMX_16x16x64_S32U8U8S32", ptx="tdpbuud", - shape_mnk=(16, 16, 64), thr_id=Layout(1), + shape_mnk=(16, 16, 64), + thr_id=Layout(1), a_layout=Layout((1, (16, 64)), (0, (1, 16))), b_layout=Layout((1, (16, 64)), (0, (1, 16))), - c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + c_layout=Layout((1, (16, 16)), (0, (1, 16))), +) diff --git 
a/src/tensor_layouts/atoms_nv.py b/src/tensor_layouts/atoms_nv.py index 91a8bc7..da28850 100644 --- a/src/tensor_layouts/atoms_nv.py +++ b/src/tensor_layouts/atoms_nv.py @@ -53,6 +53,7 @@ from .layouts import Layout from .atoms import MMAAtom, CopyAtom + # ============================================================================= # SM61 Pascal DP MMA atoms — 1 "thread" (scalar) # Source: include/cute/atom/mma_traits_sm61.hpp @@ -62,18 +63,22 @@ SM61_1x1x4_S32S8S8S32 = MMAAtom( name="SM61_DP4A", ptx="dp4a.s32.s32", - shape_mnk=(1, 1, 4), thr_id=Layout(1), + shape_mnk=(1, 1, 4), + thr_id=Layout(1), a_layout=Layout((1, 4)), b_layout=Layout((1, 4)), - c_layout=Layout((1, 1))) + c_layout=Layout((1, 1)), +) SM61_1x1x2_S32S16S16S32 = MMAAtom( name="SM61_DP2A", ptx="dp2a.s32.s32", - shape_mnk=(1, 1, 2), thr_id=Layout(1), + shape_mnk=(1, 1, 2), + thr_id=Layout(1), a_layout=Layout((1, 2)), b_layout=Layout((1, 2)), - c_layout=Layout((1, 1))) + c_layout=Layout((1, 1)), +) # ============================================================================= @@ -82,28 +87,25 @@ # ============================================================================= # Logical thread id → warp lane index (quadpair: lanes 0-3 and 16-19) -SM70_QuadPair = Layout((4, 2), (1, 16)) # line 44 -SM70_8x4_Row = Layout((8, 4), (1, 8)) # line 47: (T8,V4) → (M8,K4) -SM70_8x4_Col = Layout(((4, 2), 4), # line 50: (T8,V4) → (M8,K4) - ((8, 4), 1)) -SM70_8x8_16b = Layout((8, 8), (1, 8)) # line 53: (T8,V8) → (M8,N8) fp16 accum -SM70_8x8_32b = Layout(((2, 2, 2), # line 56: (T8,V8) → (M8,N8) fp32 accum - (2, 2, 2)), - ((1, 16, 4), - (8, 2, 32))) +SM70_QuadPair = Layout((4, 2), (1, 16)) # line 44 +SM70_8x4_Row = Layout((8, 4), (1, 8)) # line 47: (T8,V4) → (M8,K4) +SM70_8x4_Col = Layout(((4, 2), 4), ((8, 4), 1)) # line 50: (T8,V4) → (M8,K4) +SM70_8x8_16b = Layout((8, 8), (1, 8)) # line 53: (T8,V8) → (M8,N8) fp16 accum +SM70_8x8_32b = Layout( + ((2, 2, 2), (2, 2, 2)), # line 56: (T8,V8) → (M8,N8) fp32 accum + ((1, 
16, 4), (8, 2, 32)), +) # ============================================================================= # From mma_traits_sm80.hpp (lines 41-55) # ============================================================================= -SM80_8x4 = Layout(((4, 8), 1), # line 42: (T32,V1) → (M8,N8) - ((8, 1), 0)) -SM80_8x8_Row = Layout(((4, 8), 2), # line 46: (T32,V2) → (M8,N8) - ((16, 1), 8)) -SM80_8x16_Row = Layout(((4, 8), 4), # line 50: (T32,V4) → (M8,N16) - ((32, 1), 8)) -SM80_16x8_Row = Layout(((4, 8), (2, 2)), # line 53: (T32,V4) → (M16,N8) - ((32, 1), (16, 8))) +SM80_8x4 = Layout(((4, 8), 1), ((8, 1), 0)) # line 42: (T32,V1) → (M8,N8) +SM80_8x8_Row = Layout(((4, 8), 2), ((16, 1), 8)) # line 46: (T32,V2) → (M8,N8) +SM80_8x16_Row = Layout(((4, 8), 4), ((32, 1), 8)) # line 50: (T32,V4) → (M8,N16) +SM80_16x8_Row = Layout( + ((4, 8), (2, 2)), ((32, 1), (16, 8)) # line 53: (T32,V4) → (M16,N8) +) # ============================================================================= @@ -117,58 +119,90 @@ SM70_8x8x4_F16F16F16F16_TN = MMAAtom( name="SM70_8x8x4_F16F16F16F16_TN", ptx="mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_16b, +) # line 81 — fp16 accumulator, A=col-major, B=row-major SM70_8x8x4_F16F16F16F16_NT = MMAAtom( name="SM70_8x8x4_F16F16F16F16_NT", ptx="mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_16b, +) # line 98 SM70_8x8x4_F16F16F16F16_NN = MMAAtom( name="SM70_8x8x4_F16F16F16F16_NN", ptx="mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - 
a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_16b, +) # line 115 SM70_8x8x4_F16F16F16F16_TT = MMAAtom( name="SM70_8x8x4_F16F16F16F16_TT", ptx="mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_16b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_16b, +) # line 132 — fp32 accumulator, A=row-major, B=col-major SM70_8x8x4_F32F16F16F32_TN = MMAAtom( name="SM70_8x8x4_F32F16F16F32_TN", ptx="mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_32b, +) # line 149 — fp32 accumulator, A=col-major, B=row-major # Reference image: media/images/cute/HMMA.8x8x4.NT_Atom.png SM70_8x8x4_F32F16F16F32_NT = MMAAtom( name="SM70_8x8x4_F32F16F16F32_NT", ptx="mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_32b, +) # line 166 SM70_8x8x4_F32F16F16F32_NN = MMAAtom( name="SM70_8x8x4_F32F16F16F32_NN", ptx="mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Col, b_layout=SM70_8x4_Row, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Col, + b_layout=SM70_8x4_Row, + c_layout=SM70_8x8_32b, +) # line 183 SM70_8x8x4_F32F16F16F32_TT = MMAAtom( name="SM70_8x8x4_F32F16F16F32_TT", 
ptx="mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32", - shape_mnk=(8, 8, 4), thr_id=SM70_QuadPair, - a_layout=SM70_8x4_Row, b_layout=SM70_8x4_Col, c_layout=SM70_8x8_32b) + shape_mnk=(8, 8, 4), + thr_id=SM70_QuadPair, + a_layout=SM70_8x4_Row, + b_layout=SM70_8x4_Col, + c_layout=SM70_8x8_32b, +) # ============================================================================= @@ -181,18 +215,22 @@ SM75_16x8x8_F32F16F16F32_TN = MMAAtom( name="SM75_16x8x8_F32F16F16F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=Layout(((4, 8), (2, 2)), ((32, 1), (16, 8))), b_layout=Layout(((4, 8), 2), ((16, 1), 8)), - c_layout=Layout(((4, 8), (2, 2)), ((32, 1), (16, 8)))) + c_layout=Layout(((4, 8), (2, 2)), ((32, 1), (16, 8))), +) SM75_8x8x16_S32S8S8S32_TN = MMAAtom( name="SM75_8x8x16_S32S8S8S32_TN", ptx="mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32", - shape_mnk=(8, 8, 16), thr_id=None, + shape_mnk=(8, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), 4), ((32, 1), 8)), b_layout=Layout(((4, 8), 4), ((32, 1), 8)), - c_layout=Layout(((4, 8), 2), ((16, 1), 8))) + c_layout=Layout(((4, 8), 2), ((16, 1), 8)), +) # ============================================================================= @@ -212,70 +250,96 @@ SM80_16x8x8_F16F16F16F16_TN = MMAAtom( name="SM80_16x8x8_F16F16F16F16_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16", - shape_mnk=(16, 8, 8), thr_id=None, - a_layout=SM80_16x8_Row, b_layout=SM80_8x8_Row, c_layout=SM80_16x8_Row) + shape_mnk=(16, 8, 8), + thr_id=None, + a_layout=SM80_16x8_Row, + b_layout=SM80_8x8_Row, + c_layout=SM80_16x8_Row, +) SM80_16x8x16_F16F16F16F16_TN = MMAAtom( name="SM80_16x8x16_F16F16F16F16_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 2, 2)), ((32, 1), (16, 8, 128))), b_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), - 
c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- FP32 accumulator with FP16 inputs --- SM80_16x8x8_F32F16F16F32_TN = MMAAtom( name="SM80_16x8x8_F32F16F16F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32", - shape_mnk=(16, 8, 8), thr_id=None, - a_layout=SM80_16x8_Row, b_layout=SM80_8x8_Row, c_layout=SM80_16x8_Row) + shape_mnk=(16, 8, 8), + thr_id=None, + a_layout=SM80_16x8_Row, + b_layout=SM80_8x8_Row, + c_layout=SM80_16x8_Row, +) SM80_16x8x16_F32F16F16F32_TN = MMAAtom( name="SM80_16x8x16_F32F16F16F32_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 2, 2)), ((32, 1), (16, 8, 128))), b_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- BF16 (same layouts as FP16) --- SM80_16x8x8_F32BF16BF16F32_TN = MMAAtom( name="SM80_16x8x8_F32BF16BF16F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32", - shape_mnk=(16, 8, 8), thr_id=None, - a_layout=SM80_16x8_Row, b_layout=SM80_8x8_Row, c_layout=SM80_16x8_Row) + shape_mnk=(16, 8, 8), + thr_id=None, + a_layout=SM80_16x8_Row, + b_layout=SM80_8x8_Row, + c_layout=SM80_16x8_Row, +) SM80_16x8x16_F32BF16BF16F32_TN = MMAAtom( name="SM80_16x8x16_F32BF16BF16F32_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 2, 2)), ((32, 1), (16, 8, 128))), b_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- TF32 (TensorFloat-32) --- SM80_16x8x4_F32TF32TF32F32_TN = MMAAtom( name="SM80_16x8x4_F32TF32TF32F32_TN", ptx="mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32", - shape_mnk=(16, 8, 4), thr_id=None, + shape_mnk=(16, 8, 4), + thr_id=None, a_layout=Layout(((4, 8), 2), ((16, 1), 8)), b_layout=SM80_8x4, - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) 
SM80_16x8x8_F32TF32TF32F32_TN = MMAAtom( name="SM80_16x8x8_F32TF32TF32F32_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), b_layout=Layout(((4, 8), 2), ((8, 1), 32)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- FP64 --- SM80_8x8x4_F64F64F64F64_TN = MMAAtom( name="SM80_8x8x4_F64F64F64F64_TN", ptx="mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64", - shape_mnk=(8, 8, 4), thr_id=None, - a_layout=SM80_8x4, b_layout=SM80_8x4, c_layout=SM80_8x8_Row) + shape_mnk=(8, 8, 4), + thr_id=None, + a_layout=SM80_8x4, + b_layout=SM80_8x4, + c_layout=SM80_8x8_Row, +) # --- INT8 (s8×s8, s8×u8, u8×s8, u8×u8 all share layouts at same tile size) --- @@ -283,26 +347,34 @@ SM80_8x8x16_S32S8S8S32_TN = MMAAtom( name="SM80_8x8x16_S32S8S8S32_TN", ptx="mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32", - shape_mnk=(8, 8, 16), thr_id=None, - a_layout=SM80_8x16_Row, b_layout=SM80_8x16_Row, c_layout=SM80_8x8_Row) + shape_mnk=(8, 8, 16), + thr_id=None, + a_layout=SM80_8x16_Row, + b_layout=SM80_8x16_Row, + c_layout=SM80_8x8_Row, +) # 16x8x16 SM80_16x8x16_S32S8S8S32_TN = MMAAtom( name="SM80_16x8x16_S32S8S8S32_TN", ptx="mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (4, 2)), ((64, 1), (16, 8))), b_layout=SM80_8x16_Row, - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # 16x8x32 SM80_16x8x32_S32S8S8S32_TN = MMAAtom( name="SM80_16x8x32_S32S8S8S32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- INT4 --- @@ -310,54 +382,66 @@ SM80_8x8x32_S32S4S4S32_TN = MMAAtom( 
name="SM80_8x8x32_S32S4S4S32_TN", ptx="mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32", - shape_mnk=(8, 8, 32), thr_id=None, + shape_mnk=(8, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), 8), ((64, 1), 8)), b_layout=Layout(((4, 8), 8), ((64, 1), 8)), - c_layout=SM80_8x8_Row) + c_layout=SM80_8x8_Row, +) # 16x8x32 SM80_16x8x32_S32S4S4S32_TN = MMAAtom( name="SM80_16x8x32_S32S4S4S32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (8, 2)), ((128, 1), (16, 8))), b_layout=Layout(((4, 8), 8), ((32, 1), 8)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # 16x8x64 SM80_16x8x64_S32S4S4S32_TN = MMAAtom( name="SM80_16x8x64_S32S4S4S32_TN", ptx="mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (8, 2)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- Binary (U1) --- SM80_8x8x128_S32U1U1S32_TN_XORPOPC = MMAAtom( name="SM80_8x8x128_S32U1U1S32_TN_XORPOPC", ptx="mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc", - shape_mnk=(8, 8, 128), thr_id=None, + shape_mnk=(8, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), 32), ((256, 1), 8)), b_layout=Layout(((4, 8), 32), ((256, 1), 8)), - c_layout=SM80_8x8_Row) + c_layout=SM80_8x8_Row, +) SM80_16x8x128_S32U1U1S32_TN_XORPOPC = MMAAtom( name="SM80_16x8x128_S32U1U1S32_TN_XORPOPC", ptx="mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc", - shape_mnk=(16, 8, 128), thr_id=None, + shape_mnk=(16, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), (32, 2)), ((512, 1), (16, 8))), b_layout=Layout(((4, 8), 32), ((256, 1), 8)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM80_16x8x256_S32U1U1S32_TN_XORPOPC = MMAAtom( name="SM80_16x8x256_S32U1U1S32_TN_XORPOPC", 
ptx="mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc", - shape_mnk=(16, 8, 256), thr_id=None, + shape_mnk=(16, 8, 256), + thr_id=None, a_layout=Layout(((4, 8), (32, 2, 2)), ((512, 1), (16, 8, 2048))), b_layout=Layout(((4, 8), (32, 2)), ((256, 1), (8, 1024))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # ============================================================================= @@ -373,67 +457,83 @@ SM89_16x8x32_F32E4M3E4M3F32_TN = MMAAtom( name="SM89_16x8x32_F32E4M3E4M3F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM89_16x8x32_F32E4M3E5M2F32_TN = MMAAtom( name="SM89_16x8x32_F32E4M3E5M2F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F32E5M2E5M2F32_TN = MMAAtom( name="SM89_16x8x32_F32E5M2E5M2F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F32E5M2E4M3F32_TN = MMAAtom( name="SM89_16x8x32_F32E5M2E4M3F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - 
c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) # FP16 accumulator variants (same layouts) SM89_16x8x32_F16E4M3E4M3F16_TN = MMAAtom( name="SM89_16x8x32_F16E4M3E4M3F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F16E4M3E5M2F16_TN = MMAAtom( name="SM89_16x8x32_F16E4M3E5M2F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F16E5M2E5M2F16_TN = MMAAtom( name="SM89_16x8x32_F16E5M2E5M2F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) SM89_16x8x32_F16E5M2E4M3F16_TN = MMAAtom( name="SM89_16x8x32_F16E5M2E4M3F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.a_layout, b_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.b_layout, - c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout) + c_layout=SM89_16x8x32_F32E4M3E4M3F32_TN.c_layout, +) # ============================================================================= @@ -447,54 +547,66 @@ SM90_16x8x4_F64F64F64F64_TN = MMAAtom( 
name="SM90_16x8x4_F64F64F64F64_TN", ptx="mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64", - shape_mnk=(16, 8, 4), thr_id=None, + shape_mnk=(16, 8, 4), + thr_id=None, a_layout=Layout(((4, 8), 2), ((16, 1), 8)), b_layout=SM80_8x4, - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # line 67 SM90_16x8x8_F64F64F64F64_TN = MMAAtom( name="SM90_16x8x8_F64F64F64F64_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=Layout(((4, 8), (2, 2)), ((16, 1), (8, 64))), b_layout=Layout(((4, 8), 2), ((8, 1), 32)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # line 87 SM90_16x8x16_F64F64F64F64_TN = MMAAtom( name="SM90_16x8x16_F64F64F64F64_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=Layout(((4, 8), (2, 4)), ((16, 1), (8, 64))), b_layout=Layout(((4, 8), 4), ((8, 1), 32)), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- Complex FP64 (same layouts as FP64, different value types) --- SM90_16x8x4_C64C64C64C64_TN = MMAAtom( name="SM90_16x8x4_C64C64C64C64_TN", ptx="mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 (complex)", - shape_mnk=(16, 8, 4), thr_id=None, + shape_mnk=(16, 8, 4), + thr_id=None, a_layout=SM90_16x8x4_F64F64F64F64_TN.a_layout, b_layout=SM90_16x8x4_F64F64F64F64_TN.b_layout, - c_layout=SM90_16x8x4_F64F64F64F64_TN.c_layout) + c_layout=SM90_16x8x4_F64F64F64F64_TN.c_layout, +) SM90_16x8x8_C64C64C64C64_TN = MMAAtom( name="SM90_16x8x8_C64C64C64C64_TN", ptx="mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 (complex)", - shape_mnk=(16, 8, 8), thr_id=None, + shape_mnk=(16, 8, 8), + thr_id=None, a_layout=SM90_16x8x8_F64F64F64F64_TN.a_layout, b_layout=SM90_16x8x8_F64F64F64F64_TN.b_layout, - c_layout=SM90_16x8x8_F64F64F64F64_TN.c_layout) + c_layout=SM90_16x8x8_F64F64F64F64_TN.c_layout, +) SM90_16x8x16_C64C64C64C64_TN = MMAAtom( 
name="SM90_16x8x16_C64C64C64C64_TN", ptx="mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 (complex)", - shape_mnk=(16, 8, 16), thr_id=None, + shape_mnk=(16, 8, 16), + thr_id=None, a_layout=SM90_16x8x16_F64F64F64F64_TN.a_layout, b_layout=SM90_16x8x16_F64F64F64F64_TN.b_layout, - c_layout=SM90_16x8x16_F64F64F64F64_TN.c_layout) + c_layout=SM90_16x8x16_F64F64F64F64_TN.c_layout, +) # ============================================================================= @@ -511,65 +623,79 @@ # tile with stride-0 in the thread dimension (line 436-443 in 0t_mma_atom.md). # ============================================================================= + def gmma_c_layout(n: int) -> Layout: """CLayout_64xN: accumulator layout for SM90 GMMA with N columns. Source: mma_traits_sm90_gmma.hpp line 432.""" - return Layout(((4, 8, 4), (2, 2, n // 8)), - ((128, 1, 16), (64, 8, 512))) + return Layout(((4, 8, 4), (2, 2, n // 8)), ((128, 1, 16), (64, 8, 512))) + def gmma_ab_layout(m: int, k: int) -> Layout: """ABLayout: shared memory descriptor layout — all threads see entire tile. 
Source: mma_traits_sm90_gmma.hpp; 0t_mma_atom.md lines 436-443.""" return Layout((128, (m, k)), (0, (1, m))) + # line 657 — SM90_64x64x16_F16F16F16_SS SM90_64x8x16_F16F16F16_SS = MMAAtom( name="SM90_64x8x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16", - shape_mnk=(64, 8, 16), thr_id=None, + shape_mnk=(64, 8, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(8, 16), - c_layout=gmma_c_layout(8)) + c_layout=gmma_c_layout(8), +) SM90_64x16x16_F16F16F16_SS = MMAAtom( name="SM90_64x16x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16", - shape_mnk=(64, 16, 16), thr_id=None, + shape_mnk=(64, 16, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(16, 16), - c_layout=gmma_c_layout(16)) + c_layout=gmma_c_layout(16), +) SM90_64x32x16_F16F16F16_SS = MMAAtom( name="SM90_64x32x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16", - shape_mnk=(64, 32, 16), thr_id=None, + shape_mnk=(64, 32, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(32, 16), - c_layout=gmma_c_layout(32)) + c_layout=gmma_c_layout(32), +) SM90_64x64x16_F16F16F16_SS = MMAAtom( name="SM90_64x64x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16", - shape_mnk=(64, 64, 16), thr_id=None, + shape_mnk=(64, 64, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(64, 16), - c_layout=gmma_c_layout(64)) + c_layout=gmma_c_layout(64), +) SM90_64x128x16_F16F16F16_SS = MMAAtom( name="SM90_64x128x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16", - shape_mnk=(64, 128, 16), thr_id=None, + shape_mnk=(64, 128, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(128, 16), - c_layout=gmma_c_layout(128)) + c_layout=gmma_c_layout(128), +) SM90_64x256x16_F16F16F16_SS = MMAAtom( name="SM90_64x256x16_F16F16F16_SS", ptx="wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16", - 
shape_mnk=(64, 256, 16), thr_id=None, + shape_mnk=(64, 256, 16), + thr_id=None, a_layout=gmma_ab_layout(64, 16), b_layout=gmma_ab_layout(256, 16), - c_layout=gmma_c_layout(256)) + c_layout=gmma_c_layout(256), +) # ============================================================================= @@ -585,8 +711,10 @@ def gmma_ab_layout(m: int, k: int) -> Layout: # provide a factory instead of enumerating hundreds of concrete atoms. # ============================================================================= -def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", - ab_type: str | None = None) -> MMAAtom: + +def make_gmma_atom_ss( + n: int, k: int = 16, d_type: str = "F16", ab_type: str | None = None +) -> MMAAtom: """Create an SM90 GMMA SS atom for 64×N×K with the given data types. Args: @@ -603,10 +731,12 @@ def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", return MMAAtom( name=name, ptx=f"wgmma.mma_async.sync.aligned.m64n{n}k{k}", - shape_mnk=(64, n, k), thr_id=None, + shape_mnk=(64, n, k), + thr_id=None, a_layout=gmma_ab_layout(64, k), b_layout=gmma_ab_layout(n, k), - c_layout=gmma_c_layout(n)) + c_layout=gmma_c_layout(n), + ) # Representative ext atoms (N values not in the base set) @@ -642,19 +772,35 @@ def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", # FP8 E4M3 GMMA atoms (K=32 for 8-bit types) SM90_64x64x32_F32E4M3E4M3_SS = make_gmma_atom_ss(64, k=32, d_type="F32", ab_type="E4M3") -SM90_64x128x32_F32E4M3E4M3_SS = make_gmma_atom_ss(128, k=32, d_type="F32", ab_type="E4M3") -SM90_64x256x32_F32E4M3E4M3_SS = make_gmma_atom_ss(256, k=32, d_type="F32", ab_type="E4M3") +SM90_64x128x32_F32E4M3E4M3_SS = make_gmma_atom_ss( + 128, k=32, d_type="F32", ab_type="E4M3" +) +SM90_64x256x32_F32E4M3E4M3_SS = make_gmma_atom_ss( + 256, k=32, d_type="F32", ab_type="E4M3" +) SM90_64x64x32_F16E4M3E4M3_SS = make_gmma_atom_ss(64, k=32, d_type="F16", ab_type="E4M3") -SM90_64x128x32_F16E4M3E4M3_SS = make_gmma_atom_ss(128, k=32, d_type="F16", 
ab_type="E4M3") -SM90_64x256x32_F16E4M3E4M3_SS = make_gmma_atom_ss(256, k=32, d_type="F16", ab_type="E4M3") +SM90_64x128x32_F16E4M3E4M3_SS = make_gmma_atom_ss( + 128, k=32, d_type="F16", ab_type="E4M3" +) +SM90_64x256x32_F16E4M3E4M3_SS = make_gmma_atom_ss( + 256, k=32, d_type="F16", ab_type="E4M3" +) # FP8 E5M2 GMMA atoms SM90_64x64x32_F32E5M2E5M2_SS = make_gmma_atom_ss(64, k=32, d_type="F32", ab_type="E5M2") -SM90_64x128x32_F32E5M2E5M2_SS = make_gmma_atom_ss(128, k=32, d_type="F32", ab_type="E5M2") -SM90_64x256x32_F32E5M2E5M2_SS = make_gmma_atom_ss(256, k=32, d_type="F32", ab_type="E5M2") +SM90_64x128x32_F32E5M2E5M2_SS = make_gmma_atom_ss( + 128, k=32, d_type="F32", ab_type="E5M2" +) +SM90_64x256x32_F32E5M2E5M2_SS = make_gmma_atom_ss( + 256, k=32, d_type="F32", ab_type="E5M2" +) SM90_64x64x32_F16E5M2E5M2_SS = make_gmma_atom_ss(64, k=32, d_type="F16", ab_type="E5M2") -SM90_64x128x32_F16E5M2E5M2_SS = make_gmma_atom_ss(128, k=32, d_type="F16", ab_type="E5M2") -SM90_64x256x32_F16E5M2E5M2_SS = make_gmma_atom_ss(256, k=32, d_type="F16", ab_type="E5M2") +SM90_64x128x32_F16E5M2E5M2_SS = make_gmma_atom_ss( + 128, k=32, d_type="F16", ab_type="E5M2" +) +SM90_64x256x32_F16E5M2E5M2_SS = make_gmma_atom_ss( + 256, k=32, d_type="F16", ab_type="E5M2" +) # ============================================================================= @@ -667,8 +813,10 @@ def make_gmma_atom_ss(n: int, k: int = 16, d_type: str = "F16", # K_sparse = 2 * K_dense (e.g. K=32 for F16 sparse vs K=16 for F16 dense). 
# ============================================================================= -def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", - ab_type: str | None = None) -> MMAAtom: + +def make_gmma_sparse_atom_ss( + n: int, k: int = 32, d_type: str = "F16", ab_type: str | None = None +) -> MMAAtom: """Create an SM90 GMMA sparse SS atom for 64×N×K.""" if ab_type is None: ab_type = d_type @@ -678,10 +826,12 @@ def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", return MMAAtom( name=name, ptx=f"wgmma.mma_async.sp.sync.aligned.m64n{n}k{k}", - shape_mnk=(64, n, k), thr_id=None, + shape_mnk=(64, n, k), + thr_id=None, a_layout=gmma_ab_layout(64, k), b_layout=gmma_ab_layout(n, k), - c_layout=gmma_c_layout(n)) + c_layout=gmma_c_layout(n), + ) # F16 sparse (K=32, double the dense K=16) @@ -690,14 +840,26 @@ def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", SM90_64x256x32_F16F16F16_SS_SPARSE = make_gmma_sparse_atom_ss(256) # TF32 sparse (K=16, double the dense K=8) -SM90_64x64x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss(64, k=16, d_type="F32", ab_type="TF32") -SM90_64x128x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss(128, k=16, d_type="F32", ab_type="TF32") -SM90_64x256x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss(256, k=16, d_type="F32", ab_type="TF32") +SM90_64x64x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss( + 64, k=16, d_type="F32", ab_type="TF32" +) +SM90_64x128x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss( + 128, k=16, d_type="F32", ab_type="TF32" +) +SM90_64x256x16_F32TF32TF32_SS_SPARSE = make_gmma_sparse_atom_ss( + 256, k=16, d_type="F32", ab_type="TF32" +) # INT8 sparse (K=64, double the dense K=32) -SM90_64x64x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss(64, k=64, d_type="S32", ab_type="S8") -SM90_64x128x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss(128, k=64, d_type="S32", ab_type="S8") -SM90_64x256x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss(256, k=64, d_type="S32", 
ab_type="S8") +SM90_64x64x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss( + 64, k=64, d_type="S32", ab_type="S8" +) +SM90_64x128x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss( + 128, k=64, d_type="S32", ab_type="S8" +) +SM90_64x256x64_S32S8S8_SS_SPARSE = make_gmma_sparse_atom_ss( + 256, k=64, d_type="S32", ab_type="S8" +) # ============================================================================= @@ -719,83 +881,103 @@ def make_gmma_sparse_atom_ss(n: int, k: int = 32, d_type: str = "F16", # M ∈ {64, 128}, N ∈ {8, 16, 24, ..., 256} (multiples of 8) # ============================================================================= + def umma_layout(rows: int, cols: int) -> Layout: """SM100 UMMA layout: (1, (rows, cols)) : (0, (1, rows)) — col-major.""" return Layout((1, (rows, cols)), (0, (1, rows))) + # --- F16/BF16 SS (both operands from shared memory) --- SM100_64x64x16_F16F16F16_SS = MMAAtom( name="SM100_64x64x16_F16F16F16_SS", ptx="tcgen05.mma ... m64n64k16.f16.f16.f16", - shape_mnk=(64, 64, 16), thr_id=Layout(1), + shape_mnk=(64, 64, 16), + thr_id=Layout(1), a_layout=umma_layout(64, 16), b_layout=umma_layout(64, 16), - c_layout=umma_layout(64, 64)) + c_layout=umma_layout(64, 64), +) SM100_64x128x16_F16F16F16_SS = MMAAtom( name="SM100_64x128x16_F16F16F16_SS", ptx="tcgen05.mma ... m64n128k16.f16.f16.f16", - shape_mnk=(64, 128, 16), thr_id=Layout(1), + shape_mnk=(64, 128, 16), + thr_id=Layout(1), a_layout=umma_layout(64, 16), b_layout=umma_layout(128, 16), - c_layout=umma_layout(64, 128)) + c_layout=umma_layout(64, 128), +) SM100_64x256x16_F16F16F16_SS = MMAAtom( name="SM100_64x256x16_F16F16F16_SS", ptx="tcgen05.mma ... m64n256k16.f16.f16.f16", - shape_mnk=(64, 256, 16), thr_id=Layout(1), + shape_mnk=(64, 256, 16), + thr_id=Layout(1), a_layout=umma_layout(64, 16), b_layout=umma_layout(256, 16), - c_layout=umma_layout(64, 256)) + c_layout=umma_layout(64, 256), +) SM100_128x64x16_F16F16F16_SS = MMAAtom( name="SM100_128x64x16_F16F16F16_SS", ptx="tcgen05.mma ... 
m128n64k16.f16.f16.f16", - shape_mnk=(128, 64, 16), thr_id=Layout(1), + shape_mnk=(128, 64, 16), + thr_id=Layout(1), a_layout=umma_layout(128, 16), b_layout=umma_layout(64, 16), - c_layout=umma_layout(128, 64)) + c_layout=umma_layout(128, 64), +) SM100_128x128x16_F16F16F16_SS = MMAAtom( name="SM100_128x128x16_F16F16F16_SS", ptx="tcgen05.mma ... m128n128k16.f16.f16.f16", - shape_mnk=(128, 128, 16), thr_id=Layout(1), + shape_mnk=(128, 128, 16), + thr_id=Layout(1), a_layout=umma_layout(128, 16), b_layout=umma_layout(128, 16), - c_layout=umma_layout(128, 128)) + c_layout=umma_layout(128, 128), +) SM100_128x256x16_F16F16F16_SS = MMAAtom( name="SM100_128x256x16_F16F16F16_SS", ptx="tcgen05.mma ... m128n256k16.f16.f16.f16", - shape_mnk=(128, 256, 16), thr_id=Layout(1), + shape_mnk=(128, 256, 16), + thr_id=Layout(1), a_layout=umma_layout(128, 16), b_layout=umma_layout(256, 16), - c_layout=umma_layout(128, 256)) + c_layout=umma_layout(128, 256), +) # --- TF32 SS (K=8 because 256/32=8) --- SM100_64x64x8_F32TF32TF32F32_SS = MMAAtom( name="SM100_64x64x8_F32TF32TF32F32_SS", ptx="tcgen05.mma ... m64n64k8.f32.tf32.tf32.f32", - shape_mnk=(64, 64, 8), thr_id=Layout(1), + shape_mnk=(64, 64, 8), + thr_id=Layout(1), a_layout=umma_layout(64, 8), b_layout=umma_layout(64, 8), - c_layout=umma_layout(64, 64)) + c_layout=umma_layout(64, 64), +) SM100_128x128x8_F32TF32TF32F32_SS = MMAAtom( name="SM100_128x128x8_F32TF32TF32F32_SS", ptx="tcgen05.mma ... 
m128n128k8.f32.tf32.tf32.f32", - shape_mnk=(128, 128, 8), thr_id=Layout(1), + shape_mnk=(128, 128, 8), + thr_id=Layout(1), a_layout=umma_layout(128, 8), b_layout=umma_layout(128, 8), - c_layout=umma_layout(128, 128)) + c_layout=umma_layout(128, 128), +) # --- SM100 UMMA factory --- -def make_umma_atom_ss(m: int, n: int, k: int = 16, - d_type: str = "F16", ab_type: str | None = None) -> MMAAtom: + +def make_umma_atom_ss( + m: int, n: int, k: int = 16, d_type: str = "F16", ab_type: str | None = None +) -> MMAAtom: """Create an SM100 UMMA SS atom for M×N×K with the given data types.""" if ab_type is None: ab_type = d_type @@ -803,10 +985,13 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, return MMAAtom( name=name, ptx=f"tcgen05.mma ... m{m}n{n}k{k}", - shape_mnk=(m, n, k), thr_id=Layout(1), + shape_mnk=(m, n, k), + thr_id=Layout(1), a_layout=umma_layout(m, k), b_layout=umma_layout(n, k), - c_layout=umma_layout(m, n)) + c_layout=umma_layout(m, n), + ) + # F32-accumulator with F16 inputs SM100_64x64x16_F32F16F16_SS = make_umma_atom_ss(64, 64, d_type="F32", ab_type="F16") @@ -818,19 +1003,39 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, # F32-accumulator with BF16 inputs SM100_64x64x16_F32BF16BF16_SS = make_umma_atom_ss(64, 64, d_type="F32", ab_type="BF16") -SM100_64x128x16_F32BF16BF16_SS = make_umma_atom_ss(64, 128, d_type="F32", ab_type="BF16") -SM100_64x256x16_F32BF16BF16_SS = make_umma_atom_ss(64, 256, d_type="F32", ab_type="BF16") -SM100_128x64x16_F32BF16BF16_SS = make_umma_atom_ss(128, 64, d_type="F32", ab_type="BF16") -SM100_128x128x16_F32BF16BF16_SS = make_umma_atom_ss(128, 128, d_type="F32", ab_type="BF16") -SM100_128x256x16_F32BF16BF16_SS = make_umma_atom_ss(128, 256, d_type="F32", ab_type="BF16") +SM100_64x128x16_F32BF16BF16_SS = make_umma_atom_ss( + 64, 128, d_type="F32", ab_type="BF16" +) +SM100_64x256x16_F32BF16BF16_SS = make_umma_atom_ss( + 64, 256, d_type="F32", ab_type="BF16" +) +SM100_128x64x16_F32BF16BF16_SS = make_umma_atom_ss( + 128, 
64, d_type="F32", ab_type="BF16" +) +SM100_128x128x16_F32BF16BF16_SS = make_umma_atom_ss( + 128, 128, d_type="F32", ab_type="BF16" +) +SM100_128x256x16_F32BF16BF16_SS = make_umma_atom_ss( + 128, 256, d_type="F32", ab_type="BF16" +) # F16-accumulator with BF16 inputs SM100_64x64x16_F16BF16BF16_SS = make_umma_atom_ss(64, 64, d_type="F16", ab_type="BF16") -SM100_64x128x16_F16BF16BF16_SS = make_umma_atom_ss(64, 128, d_type="F16", ab_type="BF16") -SM100_64x256x16_F16BF16BF16_SS = make_umma_atom_ss(64, 256, d_type="F16", ab_type="BF16") -SM100_128x64x16_F16BF16BF16_SS = make_umma_atom_ss(128, 64, d_type="F16", ab_type="BF16") -SM100_128x128x16_F16BF16BF16_SS = make_umma_atom_ss(128, 128, d_type="F16", ab_type="BF16") -SM100_128x256x16_F16BF16BF16_SS = make_umma_atom_ss(128, 256, d_type="F16", ab_type="BF16") +SM100_64x128x16_F16BF16BF16_SS = make_umma_atom_ss( + 64, 128, d_type="F16", ab_type="BF16" +) +SM100_64x256x16_F16BF16BF16_SS = make_umma_atom_ss( + 64, 256, d_type="F16", ab_type="BF16" +) +SM100_128x64x16_F16BF16BF16_SS = make_umma_atom_ss( + 128, 64, d_type="F16", ab_type="BF16" +) +SM100_128x128x16_F16BF16BF16_SS = make_umma_atom_ss( + 128, 128, d_type="F16", ab_type="BF16" +) +SM100_128x256x16_F16BF16BF16_SS = make_umma_atom_ss( + 128, 256, d_type="F16", ab_type="BF16" +) # ============================================================================= @@ -845,19 +1050,23 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, SM120_16x8x32_F32E4M3E4M3F32_TN = MMAAtom( name="SM120_16x8x32_F32E4M3E4M3F32_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # SM120 block-scaled MXF8F6F4 16x8x64 SM120_16x8x64_F32E4M3E4M3F32_TN = MMAAtom( name="SM120_16x8x64_F32E4M3E4M3F32_TN", 
ptx="mma.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (8, 2)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- SM120 Sparse (structured 2:4 sparsity) --- # Source: include/cute/atom/mma_traits_sm120_sparse.hpp @@ -866,53 +1075,65 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, SM120_16x8x64_F32E4M3E4M3F32_TN_SPARSE = MMAAtom( name="SM120_16x8x64_F32E4M3E4M3F32_TN_SPARSE", ptx="mma.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 (sparse)", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (4, 4)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # SM120 sparse block-scaled 16x8x128 (FP4, 2:4 sparsity) SM120_16x8x128_F32E4M3E4M3F32_TN_SPARSE = MMAAtom( name="SM120_16x8x128_F32E4M3E4M3F32_TN_SPARSE", ptx="mma.sync.aligned.m16n8k128.row.col.f32.e4m3.e4m3.f32 (sparse)", - shape_mnk=(16, 8, 128), thr_id=None, + shape_mnk=(16, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), (16, 2, 2)), ((256, 1), (16, 8, 1024))), b_layout=Layout(((4, 8), (8, 4)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # --- SM120 F16-accumulator variants (same layouts as F32, different register width) --- SM120_16x8x32_F16E4M3E4M3F16_TN = MMAAtom( name="SM120_16x8x32_F16E4M3E4M3F16_TN", ptx="mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16", - shape_mnk=(16, 8, 32), thr_id=None, + shape_mnk=(16, 8, 32), + thr_id=None, a_layout=Layout(((4, 8), (4, 2, 2)), ((64, 1), (16, 8, 256))), b_layout=Layout(((4, 8), (4, 2)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM120_16x8x64_F16E4M3E4M3F16_TN = MMAAtom( name="SM120_16x8x64_F16E4M3E4M3F16_TN", 
ptx="mma.sync.aligned.m16n8k64.row.col.f16.e4m3.e4m3.f16", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (8, 2)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM120_16x8x64_F16E4M3E4M3F16_TN_SPARSE = MMAAtom( name="SM120_16x8x64_F16E4M3E4M3F16_TN_SPARSE", ptx="mma.sync.aligned.m16n8k64.row.col.f16.e4m3.e4m3.f16 (sparse)", - shape_mnk=(16, 8, 64), thr_id=None, + shape_mnk=(16, 8, 64), + thr_id=None, a_layout=Layout(((4, 8), (8, 2, 2)), ((128, 1), (16, 8, 512))), b_layout=Layout(((4, 8), (4, 4)), ((32, 1), (8, 128))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) SM120_16x8x128_F16E4M3E4M3F16_TN_SPARSE = MMAAtom( name="SM120_16x8x128_F16E4M3E4M3F16_TN_SPARSE", ptx="mma.sync.aligned.m16n8k128.row.col.f16.e4m3.e4m3.f16 (sparse)", - shape_mnk=(16, 8, 128), thr_id=None, + shape_mnk=(16, 8, 128), + thr_id=None, a_layout=Layout(((4, 8), (16, 2, 2)), ((256, 1), (16, 8, 1024))), b_layout=Layout(((4, 8), (8, 4)), ((64, 1), (8, 256))), - c_layout=SM80_16x8_Row) + c_layout=SM80_16x8_Row, +) # ============================================================================= @@ -926,14 +1147,16 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ptx="shfl.sync.bfly (XOR1 2x2 transpose)", thr_id=Layout(32), src_layout_bits=Layout((32, 64), (64, 1)), - dst_layout_bits=Layout(((2, 16), (32, 2)), ((32, 128), (1, 64)))) + dst_layout_bits=Layout(((2, 16), (32, 2)), ((32, 128), (1, 64))), +) SM50_Shuffle_U32_2x2Trans_XOR4 = CopyAtom( name="SM50_Shuffle_U32_2x2Trans_XOR4", ptx="shfl.sync.bfly (XOR4 2x2 transpose)", thr_id=Layout(32), src_layout_bits=Layout((32, 64), (64, 1)), - dst_layout_bits=Layout(((4, 2, 4), (32, 2)), ((64, 32, 512), (1, 256)))) + dst_layout_bits=Layout(((4, 2, 4), (32, 2)), ((64, 32, 512), (1, 256))), +) # ============================================================================= @@ -948,42 +1171,48 @@ def 
make_umma_atom_ss(m: int, n: int, k: int = 16, ptx="ldmatrix.sync.aligned.x1.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((8, 4), 128), ((128, 0), 1)), - dst_layout_bits=Layout((32, 32), (32, 1))) + dst_layout_bits=Layout((32, 32), (32, 1)), +) SM75_U32x2_LDSM_N = CopyAtom( name="SM75_U32x2_LDSM_N", ptx="ldmatrix.sync.aligned.x2.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((16, 2), 128), ((128, 0), 1)), - dst_layout_bits=Layout((32, (32, 2)), (32, (1, 1024)))) + dst_layout_bits=Layout((32, (32, 2)), (32, (1, 1024))), +) SM75_U32x4_LDSM_N = CopyAtom( name="SM75_U32x4_LDSM_N", ptx="ldmatrix.sync.aligned.x4.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout((32, 128), (128, 1)), - dst_layout_bits=Layout((32, (32, 4)), (32, (1, 1024)))) + dst_layout_bits=Layout((32, (32, 4)), (32, (1, 1024))), +) SM75_U16x2_LDSM_T = CopyAtom( name="SM75_U16x2_LDSM_T", ptx="ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((8, 4), 128), ((128, 0), 1)), - dst_layout_bits=Layout(((4, 8), (16, 2)), ((256, 16), (1, 128)))) + dst_layout_bits=Layout(((4, 8), (16, 2)), ((256, 16), (1, 128))), +) SM75_U16x4_LDSM_T = CopyAtom( name="SM75_U16x4_LDSM_T", ptx="ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout(((16, 2), 128), ((128, 0), 1)), - dst_layout_bits=Layout(((4, 8), (16, 2, 2)), ((256, 16), (1, 128, 1024)))) + dst_layout_bits=Layout(((4, 8), (16, 2, 2)), ((256, 16), (1, 128, 1024))), +) SM75_U16x8_LDSM_T = CopyAtom( name="SM75_U16x8_LDSM_T", ptx="ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=Layout((32, 128), (128, 1)), - dst_layout_bits=Layout(((4, 8), (16, 2, 4)), ((256, 16), (1, 128, 1024)))) + dst_layout_bits=Layout(((4, 8), (16, 2, 4)), ((256, 16), (1, 128, 1024))), +) # ============================================================================= @@ -1002,14 +1231,16 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, 
ptx="cp.async.ca.shared.global [16B]", thr_id=Layout(1), src_layout_bits=Layout((1, 128)), - dst_layout_bits=Layout((1, 128))) + dst_layout_bits=Layout((1, 128)), +) SM80_CP_ASYNC_CACHEGLOBAL_16B = CopyAtom( name="SM80_CP_ASYNC_CACHEGLOBAL_16B", ptx="cp.async.cg.shared.global [16B]", thr_id=Layout(1), src_layout_bits=Layout((1, 128)), - dst_layout_bits=Layout((1, 128))) + dst_layout_bits=Layout((1, 128)), +) # ============================================================================= @@ -1024,42 +1255,48 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ptx="stmatrix.sync.aligned.x1.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U32x1_LDSM_N.dst_layout_bits, - dst_layout_bits=SM75_U32x1_LDSM_N.src_layout_bits) + dst_layout_bits=SM75_U32x1_LDSM_N.src_layout_bits, +) SM90_U32x2_STSM_N = CopyAtom( name="SM90_U32x2_STSM_N", ptx="stmatrix.sync.aligned.x2.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U32x2_LDSM_N.dst_layout_bits, - dst_layout_bits=SM75_U32x2_LDSM_N.src_layout_bits) + dst_layout_bits=SM75_U32x2_LDSM_N.src_layout_bits, +) SM90_U32x4_STSM_N = CopyAtom( name="SM90_U32x4_STSM_N", ptx="stmatrix.sync.aligned.x4.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U32x4_LDSM_N.dst_layout_bits, - dst_layout_bits=SM75_U32x4_LDSM_N.src_layout_bits) + dst_layout_bits=SM75_U32x4_LDSM_N.src_layout_bits, +) SM90_U16x2_STSM_T = CopyAtom( name="SM90_U16x2_STSM_T", ptx="stmatrix.sync.aligned.x1.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U16x2_LDSM_T.dst_layout_bits, - dst_layout_bits=SM75_U16x2_LDSM_T.src_layout_bits) + dst_layout_bits=SM75_U16x2_LDSM_T.src_layout_bits, +) SM90_U16x4_STSM_T = CopyAtom( name="SM90_U16x4_STSM_T", ptx="stmatrix.sync.aligned.x2.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U16x4_LDSM_T.dst_layout_bits, - dst_layout_bits=SM75_U16x4_LDSM_T.src_layout_bits) + dst_layout_bits=SM75_U16x4_LDSM_T.src_layout_bits, +) SM90_U16x8_STSM_T = CopyAtom( name="SM90_U16x8_STSM_T", 
ptx="stmatrix.sync.aligned.x4.trans.m8n8.shared.b16", thr_id=Layout(32), src_layout_bits=SM75_U16x8_LDSM_T.dst_layout_bits, - dst_layout_bits=SM75_U16x8_LDSM_T.src_layout_bits) + dst_layout_bits=SM75_U16x8_LDSM_T.src_layout_bits, +) # ============================================================================= @@ -1072,10 +1309,14 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] MMA_ATOMS_SM70 = [ - SM70_8x8x4_F16F16F16F16_TN, SM70_8x8x4_F16F16F16F16_NT, - SM70_8x8x4_F16F16F16F16_NN, SM70_8x8x4_F16F16F16F16_TT, - SM70_8x8x4_F32F16F16F32_TN, SM70_8x8x4_F32F16F16F32_NT, - SM70_8x8x4_F32F16F16F32_NN, SM70_8x8x4_F32F16F16F32_TT, + SM70_8x8x4_F16F16F16F16_TN, + SM70_8x8x4_F16F16F16F16_NT, + SM70_8x8x4_F16F16F16F16_NN, + SM70_8x8x4_F16F16F16F16_TT, + SM70_8x8x4_F32F16F16F32_TN, + SM70_8x8x4_F32F16F16F32_NT, + SM70_8x8x4_F32F16F16F32_NN, + SM70_8x8x4_F32F16F16F32_TT, ] MMA_ATOMS_SM75 = [ @@ -1084,14 +1325,20 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] MMA_ATOMS_SM80 = [ - SM80_16x8x8_F16F16F16F16_TN, SM80_16x8x16_F16F16F16F16_TN, - SM80_16x8x8_F32F16F16F32_TN, SM80_16x8x16_F32F16F16F32_TN, - SM80_16x8x8_F32BF16BF16F32_TN, SM80_16x8x16_F32BF16BF16F32_TN, - SM80_16x8x4_F32TF32TF32F32_TN, SM80_16x8x8_F32TF32TF32F32_TN, + SM80_16x8x8_F16F16F16F16_TN, + SM80_16x8x16_F16F16F16F16_TN, + SM80_16x8x8_F32F16F16F32_TN, + SM80_16x8x16_F32F16F16F32_TN, + SM80_16x8x8_F32BF16BF16F32_TN, + SM80_16x8x16_F32BF16BF16F32_TN, + SM80_16x8x4_F32TF32TF32F32_TN, + SM80_16x8x8_F32TF32TF32F32_TN, SM80_8x8x4_F64F64F64F64_TN, - SM80_8x8x16_S32S8S8S32_TN, SM80_16x8x16_S32S8S8S32_TN, + SM80_8x8x16_S32S8S8S32_TN, + SM80_16x8x16_S32S8S8S32_TN, SM80_16x8x32_S32S8S8S32_TN, - SM80_8x8x32_S32S4S4S32_TN, SM80_16x8x32_S32S4S4S32_TN, + SM80_8x8x32_S32S4S4S32_TN, + SM80_16x8x32_S32S4S4S32_TN, SM80_16x8x64_S32S4S4S32_TN, SM80_8x8x128_S32U1U1S32_TN_XORPOPC, SM80_16x8x128_S32U1U1S32_TN_XORPOPC, @@ -1219,8 +1466,12 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] COPY_ATOMS_SM75 = [ - 
SM75_U32x1_LDSM_N, SM75_U32x2_LDSM_N, SM75_U32x4_LDSM_N, - SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, SM75_U16x8_LDSM_T, + SM75_U32x1_LDSM_N, + SM75_U32x2_LDSM_N, + SM75_U32x4_LDSM_N, + SM75_U16x2_LDSM_T, + SM75_U16x4_LDSM_T, + SM75_U16x8_LDSM_T, ] COPY_ATOMS_SM80 = [ @@ -1229,6 +1480,10 @@ def make_umma_atom_ss(m: int, n: int, k: int = 16, ] COPY_ATOMS_SM90 = [ - SM90_U32x1_STSM_N, SM90_U32x2_STSM_N, SM90_U32x4_STSM_N, - SM90_U16x2_STSM_T, SM90_U16x4_STSM_T, SM90_U16x8_STSM_T, + SM90_U32x1_STSM_N, + SM90_U32x2_STSM_N, + SM90_U32x4_STSM_N, + SM90_U16x2_STSM_T, + SM90_U16x4_STSM_T, + SM90_U16x8_STSM_T, ] diff --git a/src/tensor_layouts/layout_utils.py b/src/tensor_layouts/layout_utils.py index eba4fcc..61cdf08 100644 --- a/src/tensor_layouts/layout_utils.py +++ b/src/tensor_layouts/layout_utils.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# ruff: noqa: F403,F405 + """Convenience utilities for working with CuTe layouts. These functions provide higher-level operations built on top of the core @@ -106,9 +108,7 @@ def tile_to_shape(layout: Layout, target_shape, order: tuple = None) -> Layout: block_shape = product_each(layout.shape) - product_shape = tuple( - (t + b - 1) // b for t, b in zip(target_shape, block_shape) - ) + product_shape = tuple((t + b - 1) // b for t, b in zip(target_shape, block_shape)) replication = make_ordered_layout(product_shape, order) @@ -165,7 +165,8 @@ def get_strides_for_shape(shape, offset=0): result_strides = get_strides_for_shape(tiler_shape) return Layout(tiler_shape, result_strides) -def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): + +def tile_mma_grid(atom, atom_layout, matrix="C", tile_mnk=None): """Compute the tiled MMA grid by replicating an atom across quadpairs. Mirrors the C++ make_tiled_mma(atom, atom_layout, Tile) function. 
@@ -206,17 +207,17 @@ def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): qp_offset = n_thr_per_atom // 2 if thr_id is not None else n_thr_per_atom # Select atom layout and tile dimensions based on matrix - if matrix == 'C': + if matrix == "C": atom_lyt = atom.c_layout row_atoms = n_atoms_m col_atoms = n_atoms_n atom_rows, atom_cols = M_atom, N_atom - elif matrix == 'A': + elif matrix == "A": atom_lyt = atom.a_layout row_atoms = n_atoms_m col_atoms = 1 atom_rows, atom_cols = M_atom, K_atom - elif matrix == 'B': + elif matrix == "B": atom_lyt = atom.b_layout row_atoms = n_atoms_n col_atoms = 1 @@ -233,9 +234,9 @@ def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): for am in range(row_atoms): for an in range(col_atoms): # Determine atom index from the atom_layout - if matrix == 'C': + if matrix == "C": atom_idx = atom_layout((am, an)) if not is_int(atom_shape) else am - elif matrix == 'A': + elif matrix == "A": # A tiles along M only; use first N-column atom atom_idx = atom_layout((am, 0)) if not is_int(atom_shape) else am else: # B @@ -268,10 +269,10 @@ def tile_mma_grid(atom, atom_layout, matrix='C', tile_mnk=None): if tile_mnk is not None: tile_M, tile_N, tile_K = tile_mnk - if matrix == 'C': + if matrix == "C": rep_m = tile_M // nat_M rep_n = tile_N // nat_N - elif matrix == 'A': + elif matrix == "A": rep_m = tile_M // nat_M rep_n = 1 else: # B diff --git a/src/tensor_layouts/layouts.py b/src/tensor_layouts/layouts.py index 9579430..6a5375b 100644 --- a/src/tensor_layouts/layouts.py +++ b/src/tensor_layouts/layouts.py @@ -72,45 +72,101 @@ # Type alias "IntOrIntTuple", # Type predicates - "is_tuple", "is_int", "is_scalar", "is_iterable", "is_layout", - "is_pure_shape", "has_none", + "is_tuple", + "is_int", + "is_scalar", + "is_iterable", + "is_layout", + "is_pure_shape", + "has_none", # Shape conversions - "as_tuple", "as_shape", "as_layout", "unwrap", "normalize", + "as_tuple", + "as_shape", + "as_layout", + "unwrap", + "normalize", # 
Core types - "Layout", "Tile", "Swizzle", "make_swizzle", + "Layout", + "Tile", + "Swizzle", + "make_swizzle", # Stride computation - "compute_col_major_strides", "compute_row_major_strides", + "compute_col_major_strides", + "compute_row_major_strides", # Query functions - "size", "cosize", "rank", "depth", "mode", + "size", + "cosize", + "rank", + "depth", + "mode", # Tuple operations - "concat", "congruent", "compatible", - "tuple_max", "transform_tuple", "zip_transform", - "fold", "fold_accumulate", "elem_scale", "inner_product", - "prefix_product", "suffix_product", "product_each", + "concat", + "congruent", + "compatible", + "tuple_max", + "transform_tuple", + "zip_transform", + "fold", + "fold_accumulate", + "elem_scale", + "inner_product", + "prefix_product", + "suffix_product", + "product_each", # Layout manipulation - "append", "prepend", "replace", "group", - "flatten", "unflatten", "sort", "coalesce", + "append", + "prepend", + "replace", + "group", + "flatten", + "unflatten", + "sort", + "coalesce", # Coordinate conversion - "idx2crd", "crd2flat", "crd2offset", "crd2idx", "crd2crd", - "slice_modes", "dice_modes", "slice_and_offset", + "idx2crd", + "crd2flat", + "crd2offset", + "crd2idx", + "crd2crd", + "slice_modes", + "dice_modes", + "slice_and_offset", # Core algebra - "compose", "complement", "logical_divide", "logical_product", + "compose", + "complement", + "logical_divide", + "logical_product", # Division variants - "zipped_divide", "tiled_divide", "flat_divide", + "zipped_divide", + "tiled_divide", + "flat_divide", # Product variants - "zipped_product", "tiled_product", "hier_unzip", - "blocked_product", "raked_product", "flat_product", + "zipped_product", + "tiled_product", + "hier_unzip", + "blocked_product", + "raked_product", + "flat_product", # Inverse and related - "right_inverse", "left_inverse", "nullspace", - "max_common_layout", "max_common_vector", + "right_inverse", + "left_inverse", + "nullspace", + "max_common_layout", + 
"max_common_vector", # Shape arithmetic - "safe_div", "shape_div", "shape_mod", + "safe_div", + "shape_div", + "shape_mod", # Upcast / downcast - "upcast", "downcast", + "upcast", + "downcast", # Iteration "iter_layout", # Image and injectivity - "image", "is_injective", "is_surjective", "is_bijective", + "image", + "is_injective", + "is_surjective", + "is_bijective", # Functional equivalence "functionally_equal", ] @@ -143,7 +199,7 @@ def as_layout(obj): """ if isinstance(obj, Layout): return obj - if hasattr(obj, 'shape') and hasattr(obj, 'stride'): + if hasattr(obj, "shape") and hasattr(obj, "stride"): return Layout(obj.shape, obj.stride) raise TypeError(f"Expected Layout, got {type(obj).__name__}") @@ -199,6 +255,7 @@ def has_none(a) -> bool: """ return fold(a, False, lambda acc, v: acc or v is None) + # ============================================================================= # Shape conversions # ============================================================================= @@ -421,10 +478,12 @@ def __repr__(self): def __str__(self): """Return human-readable CuTe notation: (4, 2) : (1, 4).""" + def fmt(x): if isinstance(x, int): return str(x) return repr(x) + base = f"{fmt(self._shape)} : {fmt(self._stride)}" if self._swizzle is not None: return f"({self._swizzle}) o ({base})" @@ -587,6 +646,7 @@ def _zero_leading_unit_strides(shape, strides): still_leading = False return tuple(result) + # ============================================================================= # Query functions: size, rank, depth, mode # ============================================================================= @@ -599,6 +659,7 @@ def _zero_leading_unit_strides(shape, strides): # mode -- extract a single mode (dimension) from a shape or layout # + def size(obj: Any) -> int: """Returns the logical number of elements (product of shape).""" if isinstance(obj, Layout): @@ -665,8 +726,10 @@ def concat(t1: Any, t2: Any): if is_tuple(t1) and is_tuple(t2): return t1 + t2 if isinstance(t1, 
Layout) and isinstance(t2, Layout): - return Layout(as_tuple(t1.shape) + as_tuple(t2.shape), - as_tuple(t1.stride) + as_tuple(t2.stride)) + return Layout( + as_tuple(t1.shape) + as_tuple(t2.shape), + as_tuple(t1.stride) + as_tuple(t2.stride), + ) raise TypeError( f"Cannot concatenate objects of {type(t1).__name__} and {type(t2).__name__}" ) @@ -729,7 +792,10 @@ def _can_group_a_into_b(a_modes: list, b) -> bool: return acc_size == target_size if is_tuple(b): - return all(_can_group_a_into_b(a_modes, sub_b) for sub_b in b) and len(a_modes) == 0 + return ( + all(_can_group_a_into_b(a_modes, sub_b) for sub_b in b) + and len(a_modes) == 0 + ) return False @@ -744,6 +810,7 @@ def _can_group_a_into_b(a_modes: list, b) -> bool: # and coordinates are computed via idx2crd. # + def iter_layout(layout: Layout): """Yield (coordinate, offset) pairs for every element in the layout. @@ -772,6 +839,7 @@ def iter_layout(layout: Layout): # is_bijective -- both (the layout is a permutation) # + def image(layout: Layout) -> list: """Return the sorted list of distinct offsets produced by the layout. @@ -842,6 +910,7 @@ def is_bijective(layout: Layout) -> bool: # Functional equivalence # ============================================================================= + def functionally_equal(a: Layout, b: Layout) -> bool: """True if two layouts compute the same mapping for every flat index. 
@@ -913,9 +982,7 @@ def group(layout: Layout, start: int, end: int) -> Layout: """ r = rank(layout) if start < 0 or end > r or start >= end: - raise ValueError( - f"Invalid group range [{start}, {end}) for layout of rank {r}" - ) + raise ValueError(f"Invalid group range [{start}, {end}) for layout of rank {r}") shapes = list(as_tuple(layout.shape)) strides = list(as_tuple(layout.stride)) @@ -981,6 +1048,7 @@ def unflatten(obj, target_profile): flatten(obj) == obj (obj must already be flat) rank(flatten(target_profile)) == rank(obj) """ + def _unflatten_helper(flat_tuple, profile): """Consume elements from flat_tuple to match profile's structure.""" if is_tuple(profile): @@ -1063,6 +1131,7 @@ def sort(obj: Layout) -> Layout: # from a shape, which is how Layout(shape) auto-computes its strides. # + def tuple_max(a: Any) -> int: """Return the maximum value across all terminals of a (possibly nested) int-tuple. @@ -1071,7 +1140,7 @@ def tuple_max(a: Any) -> int: tuple_max((3, 7, 2)) -> 7 tuple_max(((1, 9), (4, 2))) -> 9 """ - return fold(a, -float('inf'), lambda acc, x: max(acc, x)) + return fold(a, -float("inf"), lambda acc, x: max(acc, x)) def transform_tuple(t: Any, f) -> Any: @@ -1218,7 +1287,9 @@ def inner_product(a: Any, b: Any) -> int: return sum(inner_product(x, y) for x, y in zip(a, b)) else: if not isinstance(a, int) or not isinstance(b, int): - raise TypeError(f"Expected int, got {type(a).__name__} and {type(b).__name__}") + raise TypeError( + f"Expected int, got {type(a).__name__} and {type(b).__name__}" + ) return a * b @@ -1300,6 +1371,7 @@ def suffix_product(a: Any, init: Any = 1) -> Any: # always safe and always preserves semantics. # + def coalesce(obj: Layout, profile: Any = None) -> Layout: """Returns a new Layout where contiguous dimensions are merged. 
@@ -1366,7 +1438,9 @@ def _coalesce_by_mode(layout: Layout, profile: tuple) -> Layout: result_s.append(1) result_d.append(0) else: - coalesced = _coalesce_flat(Layout(mode(layout.shape, i), mode(layout.stride, i))) + coalesced = _coalesce_flat( + Layout(mode(layout.shape, i), mode(layout.stride, i)) + ) result_s.append(coalesced.shape) result_d.append(coalesced.stride) return Layout(as_shape(result_s), as_shape(result_d)) @@ -1434,6 +1508,7 @@ def _coalesce_by_mode(layout: Layout, profile: tuple) -> Layout: # free dimensions, much like NumPy's array[3, :, :] syntax. # + def complement(layout: Layout, cosize_bound: Any = None) -> Layout: """Compute the complement of a layout: a layout that fills in the gaps. @@ -1505,13 +1580,9 @@ def _step_mode(current_stride, stride, shape): # CuTe/pycute asserts current_stride <= stride * shape (injectivity). # Negative strides or zero-sized shapes violate this invariant. if stride < 0: - raise ValueError( - f"complement: negative stride {stride} is not supported" - ) + raise ValueError(f"complement: negative stride {stride} is not supported") if shape == 0: - raise ValueError( - f"complement: zero-sized shape is not supported" - ) + raise ValueError("complement: zero-sized shape is not supported") gap_size, next_stride = _step_mode(current_stride, stride, shape) if gap_size > 1: result_shapes.append(gap_size) @@ -1595,10 +1666,7 @@ def _step_mode(current_idx, stride, shape): if not result_shape: return Layout(1, 0) - return coalesce(Layout( - tuple(result_shape), - tuple(result_stride) - )) + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) def left_inverse(layout: Any) -> Layout: @@ -1789,6 +1857,7 @@ def slice_and_offset(crd, layout: Layout): # crd2crd: convert between two shapes' coordinate spaces # + def idx2crd(coord: Any, shape: Any) -> Any: """Convert index into a hierarchical coordinate.""" @@ -1895,13 +1964,10 @@ def crd2offset(coord, shape, stride) -> int: # Case 3: nD coordinate mapping (coord tuple 
-> offset) if not is_tuple(coord): - raise TypeError( - f"Coordinate must be int or tuple, got {type(coord).__name__}" - ) + raise TypeError(f"Coordinate must be int or tuple, got {type(coord).__name__}") if len(coord) != len(shape): raise ValueError( - f"Coordinate rank {len(coord)} does not match " - f"layout rank {len(shape)}" + f"Coordinate rank {len(coord)} does not match " f"layout rank {len(shape)}" ) offset = 0 for c, s, d in zip(coord, shape, stride): @@ -1949,12 +2015,16 @@ def crd2crd(crd: Any, dst_shape: Any, src_shape: Any = None) -> Any: if is_tuple(crd): if is_tuple(dst_shape): if len(crd) != len(dst_shape): - raise ValueError(f"Rank mismatch: crd has {len(crd)} elements, dst_shape has {len(dst_shape)}") + raise ValueError( + f"Rank mismatch: crd has {len(crd)} elements, dst_shape has {len(dst_shape)}" + ) return zip_transform(crd, dst_shape, crd2crd) else: # crd is tuple, dst_shape is scalar: flatten using src_shape if src_shape is None: - raise ValueError("src_shape required to flatten tuple coordinate to scalar") + raise ValueError( + "src_shape required to flatten tuple coordinate to scalar" + ) return crd2flat(crd, src_shape) else: if is_tuple(dst_shape): @@ -1987,7 +2057,9 @@ def slice_modes(crd, trg): if is_tuple(crd): if is_tuple(trg): if len(crd) != len(trg): - raise ValueError(f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}") + raise ValueError( + f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}" + ) # Flatten and concatenate non-empty results result = [] for c, s in zip(crd, trg): @@ -2028,12 +2100,15 @@ def dice_modes(crd, layout): dice_modes((0, None), Layout((3,4),(1,4))) -> 3:1 # keep mode 0 dice_modes((None, 0), Layout((3,4),(1,4))) -> 4:4 # keep mode 1 """ + def dice_tuple(crd, trg): """Keep elements of trg paired with integers in crd.""" if is_tuple(crd): if is_tuple(trg): if len(crd) != len(trg): - raise ValueError(f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}") + raise 
ValueError( + f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}" + ) result = [] for c, s in zip(crd, trg): result.extend(dice_tuple(c, s)) @@ -2082,6 +2157,7 @@ def dice_tuple(crd, trg): # nested shape elements, consuming from the innermost (leftmost) modes first. # + class Tile(tuple): """A Tiler is a tuple-of-Layouts used for mode-by-mode composition. @@ -2197,6 +2273,7 @@ def shape_mod(shape: Any, modulus: int) -> Any: shape_mod((4, 3), 2) -> (2, 1) # 2 consumed from first mode, nothing from second shape_mod((4, 3), 12) -> (4, 3) # All kept (modulus >= size) """ + def _scalar(s, m): return s if m >= s else math.gcd(s, m) @@ -2249,7 +2326,9 @@ def _upcast_leaf(s, d): def _apply(shape, stride): if is_tuple(shape): if not is_tuple(stride) or len(shape) != len(stride): - raise ValueError(f"Shape/stride structure mismatch: {shape} vs {stride}") + raise ValueError( + f"Shape/stride structure mismatch: {shape} vs {stride}" + ) pairs = [_apply(s, d) for s, d in zip(shape, stride)] new_s = tuple(p[0] for p in pairs) new_d = tuple(p[1] for p in pairs) @@ -2283,7 +2362,9 @@ def _downcast_leaf(s, d): def _apply(shape, stride): if is_tuple(shape): if not is_tuple(stride) or len(shape) != len(stride): - raise ValueError(f"Shape/stride structure mismatch: {shape} vs {stride}") + raise ValueError( + f"Shape/stride structure mismatch: {shape} vs {stride}" + ) pairs = [_apply(s, d) for s, d in zip(shape, stride)] new_s = tuple(p[0] for p in pairs) new_d = tuple(p[1] for p in pairs) @@ -2360,17 +2441,20 @@ def _compose_layouts(layout_a: Layout, layout_b: Layout) -> Layout: def compose_element(b_shape, b_stride): """Recursively compose A with one element of B's shape/stride.""" if is_tuple(b_shape): - results = [compose_element(b_shape[i], b_stride[i]) - for i in range(len(b_shape))] - return Layout(tuple(r.shape for r in results), - tuple(r.stride for r in results)) + results = [ + compose_element(b_shape[i], b_stride[i]) for i in range(len(b_shape)) + ] + return 
Layout( + tuple(r.shape for r in results), tuple(r.stride for r in results) + ) return _composition_1d(layout_a, b_shape, b_stride) if is_tuple(layout_b.shape): - results = [compose_element(layout_b.shape[i], layout_b.stride[i]) - for i in range(len(layout_b.shape))] - return Layout(tuple(r.shape for r in results), - tuple(r.stride for r in results)) + results = [ + compose_element(layout_b.shape[i], layout_b.stride[i]) + for i in range(len(layout_b.shape)) + ] + return Layout(tuple(r.shape for r in results), tuple(r.stride for r in results)) return _composition_1d(layout_a, layout_b.shape, layout_b.stride) @@ -2484,6 +2568,7 @@ def compose(layout_a: Any, layout_b: Any) -> Any: # Tuple tiler - convert elements and recurse if is_tuple(layout_b): + def to_layout(elem): if isinstance(elem, Layout): return elem @@ -2649,8 +2734,16 @@ def _logical_divide_by_shape(layout: Layout, tiler_shape: Any) -> Layout: result_strides.append(divided.stride) else: tile_part = compose(Layout(s, d), Layout(tile_size, 1)) - tile_s = unwrap(tile_part.shape) if is_tuple(tile_part.shape) else tile_part.shape - tile_d = unwrap(tile_part.stride) if is_tuple(tile_part.stride) else tile_part.stride + tile_s = ( + unwrap(tile_part.shape) + if is_tuple(tile_part.shape) + else tile_part.shape + ) + tile_d = ( + unwrap(tile_part.stride) + if is_tuple(tile_part.stride) + else tile_part.stride + ) result_shapes.append((tile_s, 1)) result_strides.append((tile_d, elem_scale(d, mode_size))) @@ -2725,7 +2818,9 @@ def zipped_divide(layout: Layout, tiler: Any) -> Layout: Examples: zipped_divide(Layout((4,8)), (2,4)) -> Layout(((2,4),(2,2)), ((1,4),(2,16))) """ - tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes(layout, tiler) + tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes( + layout, tiler + ) tiles_shape = as_shape(tile_shapes) tiles_stride = as_shape(tile_strides) @@ -2757,7 +2852,9 @@ def tiled_divide(layout: Layout, tiler: Any) -> Layout: 
Examples: tiled_divide(Layout((8,8)), (2,2)) -> Layout(((2,2), 4, 4), ...) """ - tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes(layout, tiler) + tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes( + layout, tiler + ) tiles_shape = as_shape(tile_shapes) tiles_stride = as_shape(tile_strides) @@ -2786,7 +2883,9 @@ def flat_divide(layout: Layout, tiler: Any) -> Layout: Examples: flat_divide(Layout((8,8)), (2,2)) -> Layout((2, 2, 4, 4), ...) """ - tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes(layout, tiler) + tile_shapes, tile_strides, rest_shapes, rest_strides = _split_divided_modes( + layout, tiler + ) all_shapes = tile_shapes + rest_shapes all_strides = tile_strides + rest_strides @@ -2886,8 +2985,10 @@ def hier_unzip(splitter, layout_a: Layout, layout_b) -> Layout: f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})" ) - splits = [hier_unzip(splitter, mode(layout_a, i), layout_b[i]) - for i in range(len(layout_b))] + splits = [ + hier_unzip(splitter, mode(layout_a, i), layout_b[i]) + for i in range(len(layout_b)) + ] first_shapes = [mode(s, 0).shape for s in splits] first_strides = [mode(s, 0).stride for s in splits] @@ -3107,7 +3208,9 @@ def _zip_layouts(layout_a: Layout, layout_b: Layout) -> Layout: # Handle scalar layouts by treating them as rank-1 if a_rank == 0 and b_rank == 0: # Both scalar: create a single mode with paired shapes/strides - return Layout((layout_a.shape, layout_b.shape), (layout_a.stride, layout_b.stride)) + return Layout( + (layout_a.shape, layout_b.shape), (layout_a.stride, layout_b.stride) + ) if a_rank != b_rank: raise ValueError(f"Rank mismatch in zip: {a_rank} vs {b_rank}") @@ -3276,7 +3379,11 @@ def __eq__(self, other: object) -> bool: return True if not isinstance(other, Swizzle): return False - return self.bits == other.bits and self.base == other.base and self.shift == other.shift + return ( + self.bits == other.bits + and self.base == 
other.base + and self.shift == other.shift + ) @property def yyy_msk(self) -> int: diff --git a/src/tensor_layouts/tensor.py b/src/tensor_layouts/tensor.py index 9a85618..46becdf 100644 --- a/src/tensor_layouts/tensor.py +++ b/src/tensor_layouts/tensor.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# ruff: noqa: F403,F405 + """Tensor class: combines a Layout with a base offset (pointer equivalent). In CuTe, a Tensor is (Engine/Pointer, Layout). Here we represent the pointer @@ -144,8 +146,10 @@ def _slice_single(self, key, mode_idx: int) -> "Tensor | int": if isinstance(key, slice) and key == slice(None): # Slice with : (all elements) - return tensor for this mode mode_layout = mode(self._layout, mode_idx) - return Tensor(Layout(mode_layout.shape, mode_layout.stride, - swizzle=self._layout.swizzle), self._offset) + return Tensor( + Layout(mode_layout.shape, mode_layout.stride, swizzle=self._layout.swizzle), + self._offset, + ) elif isinstance(key, (int, tuple)): # Fixed coordinate - compute the linear offset contribution return self._fix_mode(mode_idx, key) @@ -155,9 +159,7 @@ def _slice_single(self, key, mode_idx: int) -> "Tensor | int": def _slice_multi(self, keys: tuple) -> "Tensor | int": """Handle multi-dimensional slicing like tensor[i, :].""" if len(keys) != rank(self._layout): - raise IndexError( - f"Expected {rank(self._layout)} indices, got {len(keys)}" - ) + raise IndexError(f"Expected {rank(self._layout)} indices, got {len(keys)}") fixed_modes = [] sliced_modes = [] @@ -187,8 +189,11 @@ def _build_remaining_layout(self, mode_indices) -> Layout: m = mode(self._layout, idx) remaining_shapes.append(unwrap(m.shape)) remaining_strides.append(unwrap(m.stride)) - return Layout(as_shape(remaining_shapes), as_shape(remaining_strides), - swizzle=self._layout.swizzle) + return Layout( + as_shape(remaining_shapes), + as_shape(remaining_strides), + swizzle=self._layout.swizzle, + ) def 
_fix_mode(self, mode_idx: int, coord) -> "Tensor | int": """Fix one mode to a specific coordinate value.""" diff --git a/src/tensor_layouts/viz.py b/src/tensor_layouts/viz.py index 81e1db9..d9c0ac0 100644 --- a/src/tensor_layouts/viz.py +++ b/src/tensor_layouts/viz.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# ruff: noqa: F403, F405 + """Layout visualization with PNG/SVG/PDF output. Visualize layouts, swizzled layouts, and tensor slices in cute-viz style. @@ -179,7 +181,7 @@ def _make_rainbow_palette(n: int) -> list: for i in range(n): hue = i / n r, g, b = colorsys.hsv_to_rgb(hue, sat, val) - monotonic.append(f"#{int(r*255):02X}{int(g*255):02X}{int(b*255):02X}") + monotonic.append(f"#{int(r * 255):02X}{int(g * 255):02X}{int(b * 255):02X}") order = _max_contrast_order(n) return [monotonic[k] for k in order] @@ -413,14 +415,10 @@ def _get_color_indices_2d(layout, color_layout) -> Optional[np.ndarray]: row_coord = idx2crd(i, row_shape) for j in range(cols): col_coord = idx2crd(j, col_shape) - color_indices[i, j] = _color_result_to_index( - color_layout(row_coord, col_coord) - ) + color_indices[i, j] = _color_result_to_index(color_layout(row_coord, col_coord)) return color_indices - raise ValueError( - f"Unsupported color_layout rank {color_rank} for layout rank {layout_rank}" - ) + raise ValueError(f"Unsupported color_layout rank {color_rank} for layout rank {layout_rank}") # ============================================================================= @@ -727,9 +725,7 @@ def _build_composite_figure( color_layout = opts.get("color_layout", None) num_colors = opts.get("num_colors", 8) panel_flatten = opts.get("flatten_hierarchical", flatten_hierarchical) - panel_label_levels = opts.get( - "label_hierarchy_levels", label_hierarchy_levels - ) + panel_label_levels = opts.get("label_hierarchy_levels", label_hierarchy_levels) # Get title title = titles[idx] if titles and idx < len(titles) else None @@ 
-750,12 +746,18 @@ def _build_composite_figure( else: # Check if this panel should use hierarchical rendering r = rank(layout) - is_hier = r == 2 and not panel_flatten and ( - isinstance(mode(layout.shape, 0), tuple) - or isinstance(mode(layout.shape, 1), tuple) + is_hier = ( + r == 2 + and not panel_flatten + and ( + isinstance(mode(layout.shape, 0), tuple) + or isinstance(mode(layout.shape, 1), tuple) + ) ) grid = _prepare_offset_grid( - layout, color_layout=color_layout, eval_fn=eval_fn, + layout, + color_layout=color_layout, + eval_fn=eval_fn, hierarchical=is_hier, ) if grid.is_hierarchical: @@ -974,9 +976,7 @@ def _draw_hierarchy_boundary_lines( n_row_levels = len(row_block_sizes) n_col_levels = len(col_block_sizes) - def _is_shadowed_by_coarser( - level: int, pos: int, block_sizes: tuple[int, ...] - ) -> bool: + def _is_shadowed_by_coarser(level: int, pos: int, block_sizes: tuple[int, ...]) -> bool: """Return True if a same-orientation coarser hierarchy line also sits at pos.""" for coarser_level in range(level + 1, len(block_sizes)): coarser_block = block_sizes[coarser_level] @@ -1032,9 +1032,7 @@ def _draw_boundary_line( _draw_boundary_line(j, 0, j, rows, color, linewidth, zorder) -def _format_hierarchical_cell_lines( - row_coord, col_coord, offset: int -) -> tuple[str, str, str]: +def _format_hierarchical_cell_lines(row_coord, col_coord, offset: int) -> tuple[str, str, str]: """Format pedagogical hierarchical cell labels. 
Returns three explicit lines: @@ -1153,12 +1151,10 @@ def _auto_hierarchical_figsize( scale_pts = max(required_cell_width_pts, required_cell_height_pts) subplot_width_frac = ( - matplotlib.rcParams["figure.subplot.right"] - - matplotlib.rcParams["figure.subplot.left"] + matplotlib.rcParams["figure.subplot.right"] - matplotlib.rcParams["figure.subplot.left"] ) subplot_height_frac = ( - matplotlib.rcParams["figure.subplot.top"] - - matplotlib.rcParams["figure.subplot.bottom"] + matplotlib.rcParams["figure.subplot.top"] - matplotlib.rcParams["figure.subplot.bottom"] ) total_x_range = cols + 0.5 + left_margin total_y_range = rows + 0.5 + top_margin @@ -1196,11 +1192,7 @@ def _draw_colored_coord_line( pieces.append( ( str(value), - ( - _hierarchy_level_color(level, for_dark_bg) - if use_level_colors - else base_color - ), + (_hierarchy_level_color(level, for_dark_bg) if use_level_colors else base_color), ) ) if level != len(levels) - 1: @@ -1521,8 +1513,7 @@ def _build_layout_figure( color_layout = None # default behavior else: raise ValueError( - f"Unknown color_by value: {color_by!r} " - f"(expected 'row', 'column', or 'offset')" + f"Unknown color_by value: {color_by!r} (expected 'row', 'column', or 'offset')" ) colorize = True @@ -1625,8 +1616,7 @@ def fn(*args): # Check if this is a hierarchical layout (has nested tuple shapes) is_hierarchical = r == 2 and ( - isinstance(mode(layout.shape, 0), tuple) - or isinstance(mode(layout.shape, 1), tuple) + isinstance(mode(layout.shape, 0), tuple) or isinstance(mode(layout.shape, 1), tuple) ) want_hierarchical = is_hierarchical and not flatten_hierarchical @@ -2169,12 +2159,8 @@ def draw_tv_matrix( title_above=True, col_major=False, ) - draw_tv_matrix( - layout_a, a_offset_x, a_offset_y, M, K, f"A ({M}×{K})", title_above=False - ) - draw_tv_matrix( - layout_c, c_offset_x, c_offset_y, M, N, f"C ({M}×{N})", title_above=False - ) + draw_tv_matrix(layout_a, a_offset_x, a_offset_y, M, K, f"A ({M}×{K})", title_above=False) + 
draw_tv_matrix(layout_c, c_offset_x, c_offset_y, M, N, f"C ({M}×{N})", title_above=False) for k in range(K): ax.text( @@ -2283,16 +2269,12 @@ def draw_mma_layout( _save_figure(fig, filename, dpi) -def _build_tiled_grid_figure( - grid: dict, rows: int, cols: int, title: Optional[str] = None -): +def _build_tiled_grid_figure(grid: dict, rows: int, cols: int, title: Optional[str] = None): """Build the tiled-grid figure used by draw_tiled_grid/show_tiled_grid.""" colors = _make_rainbow_palette(8) font = max(4, min(7, int(60 / max(rows, cols)))) fig, ax = plt.subplots(figsize=(cols * 0.45 + 1.5, rows * 0.4 + 1.0)) - _setup_axes( - ax, (-0.5, cols + 0.5), (-0.5, rows + 0.5), title=title, title_fontsize=9 - ) + _setup_axes(ax, (-0.5, cols + 0.5), (-0.5, rows + 0.5), title=title, title_fontsize=9) _draw_tv_cells(ax, grid, rows, cols, colors, fontsize=font, linewidth=0.5) plt.tight_layout() return fig @@ -2389,9 +2371,7 @@ def _build_combined_grid_figure(a_grid, b_grid, c_grid, M, N, K, title=None): return fig -def draw_combined_mma_grid( - a_grid, b_grid, c_grid, M, N, K, filename=None, dpi=150, title=None -): +def draw_combined_mma_grid(a_grid, b_grid, c_grid, M, N, K, filename=None, dpi=150, title=None): """Draw combined A/B/C grid-dict panels in the standard MMA arrangement. This is the grid-dict counterpart of draw_mma_layout. 
Use it when @@ -2473,9 +2453,7 @@ def _build_copy_figure( fontsize=6, linewidth=0.5, ) - ax.text( - cols / 2, -0.6, "Src", ha="center", va="bottom", fontsize=10, fontweight="bold" - ) + ax.text(cols / 2, -0.6, "Src", ha="center", va="bottom", fontsize=10, fontweight="bold") # Destination grid (right) dst_ox = cols + gap @@ -2619,9 +2597,7 @@ def _build_swizzle_figure( figsize = _swizzle_figsize(linear_idx, swizzle_idx, rows, cols) def _swizzle_color_indices(idx_array): - return np.vectorize(lambda v: (int(v) >> bit_shift) % effective_colors)( - idx_array - ) + return np.vectorize(lambda v: (int(v) >> bit_shift) % effective_colors)(idx_array) linear_ci = _swizzle_color_indices(linear_idx) swizzle_ci = _swizzle_color_indices(swizzle_idx) @@ -2712,7 +2688,9 @@ def _expand_hier_slice(spec, shape): elif is_tuple(spec): if is_tuple(shape): if len(spec) != len(shape): - raise ValueError(f"Rank mismatch: spec has {len(spec)} elements, shape has {len(shape)}") + raise ValueError( + f"Rank mismatch: spec has {len(spec)} elements, shape has {len(shape)}" + ) sub_iters = [_expand_hier_slice(s, sh) for s, sh in zip(spec, shape)] for combo in itertools.product(*sub_iters): yield combo @@ -2745,24 +2723,17 @@ def _match_nested_slice_component(coord, spec, shape) -> bool: return coord == spec if isinstance(spec, slice): if is_tuple(coord): - raise TypeError( - f"Slice spec {spec} requires a scalar coordinate, got {coord!r}" - ) + raise TypeError(f"Slice spec {spec} requires a scalar coordinate, got {coord!r}") return coord in range(*spec.indices(shape)) if is_tuple(spec): if not is_tuple(coord) or not is_tuple(shape): - raise TypeError( - f"Tuple spec {spec!r} is incompatible with coordinate {coord!r}" - ) + raise TypeError(f"Tuple spec {spec!r} is incompatible with coordinate {coord!r}") if len(spec) != len(coord) or len(spec) != len(shape): raise ValueError( f"Tuple spec length {len(spec)} does not match coordinate/shape lengths " f"{len(coord)} / {len(shape)}" ) - return all( - 
_match_nested_slice_component(c, s, sh) - for c, s, sh in zip(coord, spec, shape) - ) + return all(_match_nested_slice_component(c, s, sh) for c, s, sh in zip(coord, spec, shape)) raise TypeError(f"Unsupported slice component {spec!r}") @@ -2826,8 +2797,7 @@ def _get_slice_highlight_mask_2d(layout, slice_spec) -> np.ndarray: elif isinstance(slice_spec, tuple) and r < 2: if len(slice_spec) != 1: raise ValueError( - f"Rank-{r} layout requires a 1-element tuple slice_spec, " - f"got {len(slice_spec)}" + f"Rank-{r} layout requires a 1-element tuple slice_spec, got {len(slice_spec)}" ) (col_spec,) = slice_spec col_flat = _is_flat_slice_component(col_spec) @@ -2852,9 +2822,7 @@ def _build_slice_figure( num_colors=8, ): """Build the slice figure used by draw_slice/show_slice.""" - grid = _prepare_offset_grid( - layout, color_layout=color_layout, slice_spec=slice_spec - ) + grid = _prepare_offset_grid(layout, color_layout=color_layout, slice_spec=slice_spec) if figsize is None: figsize = (grid.cols * 0.5 + 1, grid.rows * 0.5 + 1) @@ -3291,7 +3259,6 @@ def demo(output_dir: str = "."): """Generate example visualizations in all formats.""" from pathlib import Path - output = Path(output_dir) output.mkdir(parents=True, exist_ok=True) diff --git a/tests/analysis.py b/tests/analysis.py index bdd5618..3345b39 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -65,32 +65,32 @@ def test_offset_table_strided(): def test_footprint_contiguous(): """Contiguous layout: no holes, no reuse.""" result = footprint(Layout(8, 1)) - assert result['min_offset'] == 0 - assert result['max_offset'] == 7 - assert result['span'] == 8 - assert result['unique_offsets'] == 8 - assert result['total_elements'] == 8 - assert result['reuse_factor'] == 1.0 - assert result['holes'] == 0 + assert result["min_offset"] == 0 + assert result["max_offset"] == 7 + assert result["span"] == 8 + assert result["unique_offsets"] == 8 + assert result["total_elements"] == 8 + assert result["reuse_factor"] == 1.0 + 
assert result["holes"] == 0 def test_footprint_strided(): """Strided layout: holes between offsets.""" result = footprint(Layout(4, 2)) - assert result['min_offset'] == 0 - assert result['max_offset'] == 6 - assert result['span'] == 7 - assert result['unique_offsets'] == 4 - assert result['holes'] == 3 + assert result["min_offset"] == 0 + assert result["max_offset"] == 6 + assert result["span"] == 7 + assert result["unique_offsets"] == 4 + assert result["holes"] == 3 def test_footprint_broadcast(): """Broadcast: high reuse factor.""" result = footprint(Layout((4, 2), (0, 1))) - assert result['unique_offsets'] == 2 - assert result['total_elements'] == 8 - assert result['reuse_factor'] == 4.0 - assert result['holes'] == 0 + assert result["unique_offsets"] == 2 + assert result["total_elements"] == 8 + assert result["reuse_factor"] == 4.0 + assert result["holes"] == 0 ## bank_conflicts @@ -99,14 +99,14 @@ def test_footprint_broadcast(): def test_bank_conflicts_linear(): """Linear stride-1 access: no conflicts.""" result = bank_conflicts(Layout(32, 1), element_bytes=2) - assert result['conflict_free'] - assert result['max_ways'] == 1 + assert result["conflict_free"] + assert result["max_ways"] == 1 def test_bank_conflicts_broadcast(): """All threads access same address: broadcast, not a conflict.""" result = bank_conflicts(Layout(32, 0), element_bytes=2) - assert result['conflict_free'] + assert result["conflict_free"] def test_bank_conflicts_stride_32(): @@ -115,7 +115,7 @@ def test_bank_conflicts_stride_32(): # Actually: thread t -> offset 32*t, byte_addr = 64*t, bank = (64t/4) % 32 = 16t % 32 # This causes 2-way conflicts (threads 0,2,4,... hit bank 0; threads 1,3,5,... 
hit bank 16) result = bank_conflicts(Layout(32, 32), element_bytes=2) - assert not result['conflict_free'] + assert not result["conflict_free"] def test_bank_conflicts_swizzled(): @@ -134,7 +134,6 @@ def test_bank_conflicts_swizzled(): # Also verify via bank_conflicts: each row as a 1D layout for thread in range(8): - row = Layout(8, 1) # value indices 0..7 # Build a layout mapping value -> swizzled offset for this thread offsets = [sw_layout(thread, v) for v in range(8)] result = bank_conflicts( @@ -142,13 +141,13 @@ def test_bank_conflicts_swizzled(): element_bytes=4, # treat each offset as a 4-byte word ) # stride-1, 8 consecutive elements with 4-byte words: 8 different banks - assert result['conflict_free'] + assert result["conflict_free"] def test_bank_conflicts_fp32(): """4-byte elements: bank width matches element width.""" result = bank_conflicts(Layout(32, 1), element_bytes=4) - assert result['conflict_free'] + assert result["conflict_free"] def test_bank_conflicts_group_size(): @@ -157,11 +156,11 @@ def test_bank_conflicts_group_size(): r32 = bank_conflicts(Layout(32, 32), element_bytes=2) r64_default = bank_conflicts(Layout(64, 32), element_bytes=2) # Default group_size=32 limits analysis to first warp - assert r64_default['max_ways'] == r32['max_ways'] + assert r64_default["max_ways"] == r32["max_ways"] # Explicitly analyzing all 64 threads gives a larger conflict factor r64_full = bank_conflicts(Layout(64, 32), element_bytes=2, group_size=64) - assert r64_full['max_ways'] > r32['max_ways'] + assert r64_full["max_ways"] > r32["max_ways"] def test_bank_conflicts_group_size_validation(): @@ -177,8 +176,8 @@ def test_bank_conflicts_tv_layout(): # 32 threads, 2 values: stride-1 threads, stride-32 values tv = Layout((32, 2), (1, 32)) r = bank_conflicts(tv, element_bytes=2) - assert r['conflict_free'] - assert len(r['bank_to_threads']) == 32 # all banks accessed + assert r["conflict_free"] + assert len(r["bank_to_threads"]) == 32 # all banks accessed ## 
coalescing_efficiency @@ -187,40 +186,40 @@ def test_bank_conflicts_tv_layout(): def test_coalescing_contiguous_fp16(): """32 threads, stride 1, fp16: one cache line (64B of 128B).""" result = coalescing_efficiency(Layout(32, 1), element_bytes=2) - assert result['transactions'] == 1 - assert result['efficiency'] == pytest.approx(0.5) + assert result["transactions"] == 1 + assert result["efficiency"] == pytest.approx(0.5) def test_coalescing_contiguous_fp32(): """32 threads, stride 1, fp32: one cache line (128B of 128B).""" result = coalescing_efficiency(Layout(32, 1), element_bytes=4) - assert result['transactions'] == 1 - assert result['efficiency'] == pytest.approx(1.0) + assert result["transactions"] == 1 + assert result["efficiency"] == pytest.approx(1.0) def test_coalescing_strided(): """Stride-2 access doubles the cache lines touched.""" result = coalescing_efficiency(Layout(32, 2), element_bytes=2) - assert result['transactions'] == 1 # 32*2*2=128 bytes, still fits in 1 line + assert result["transactions"] == 1 # 32*2*2=128 bytes, still fits in 1 line # Actually: offsets 0,2,4,...,62. byte addrs 0,4,8,...,124. All in line 0. 
- assert result['efficiency'] == pytest.approx(0.5) + assert result["efficiency"] == pytest.approx(0.5) def test_coalescing_large_stride(): """Large stride: each thread touches a different cache line.""" # stride 64 elements * 2 bytes = 128 bytes = 1 cache line apart result = coalescing_efficiency(Layout(32, 64), element_bytes=2) - assert result['transactions'] == 32 + assert result["transactions"] == 32 # 32 threads * 2 bytes = 64 useful bytes, 32 * 128 = 4096 transferred - assert result['efficiency'] == pytest.approx(64.0 / (32 * 128)) + assert result["efficiency"] == pytest.approx(64.0 / (32 * 128)) def test_coalescing_broadcast(): """All threads access same element: single transaction, minimal useful bytes.""" result = coalescing_efficiency(Layout(32, 0), element_bytes=2) - assert result['transactions'] == 1 + assert result["transactions"] == 1 # Only 1 unique offset: 1 * 2 bytes useful out of 128 transferred - assert result['efficiency'] == pytest.approx(2.0 / 128) + assert result["efficiency"] == pytest.approx(2.0 / 128) def test_coalescing_tv_layout(): @@ -229,8 +228,8 @@ def test_coalescing_tv_layout(): tv = Layout((32, 4), (4, 1)) result = coalescing_efficiency(tv, element_bytes=2) # 128 unique offsets * 2B = 256B -> cache lines 0, 1 - assert result['transactions'] == 2 - assert result['efficiency'] == pytest.approx(1.0) + assert result["transactions"] == 2 + assert result["efficiency"] == pytest.approx(1.0) ## segment_analysis @@ -240,29 +239,29 @@ def test_segment_analysis_contiguous_fp16(): """32 threads, stride 1, fp16: 2 segments, 1 cache line.""" result = segment_analysis(Layout(32, 1), element_bytes=2) # 32 * 2B = 64B -> 2 segments of 32B, 1 cache line of 128B - assert result['segments'] == 2 - assert result['cache_lines'] == 1 - assert result['unique_bytes'] == 64 - assert result['requested_bytes'] == 64 - assert result['transferred_bytes'] == 64 # 2 * 32 - assert result['segment_efficiency'] == pytest.approx(1.0) - assert result['first_alignment'] 
== 0 + assert result["segments"] == 2 + assert result["cache_lines"] == 1 + assert result["unique_bytes"] == 64 + assert result["requested_bytes"] == 64 + assert result["transferred_bytes"] == 64 # 2 * 32 + assert result["segment_efficiency"] == pytest.approx(1.0) + assert result["first_alignment"] == 0 def test_segment_analysis_strided(): """Stride-2 touches more segments than contiguous.""" result = segment_analysis(Layout(32, 2), element_bytes=2) # offsets 0,2,4,...,62 -> byte addrs 0,4,8,...,124 -> 4 segments - assert result['segments'] == 4 - assert result['cache_lines'] == 1 + assert result["segments"] == 4 + assert result["cache_lines"] == 1 def test_segment_analysis_broadcast(): """Broadcast: 1 segment, minimal unique bytes.""" result = segment_analysis(Layout(32, 0), element_bytes=2) - assert result['segments'] == 1 - assert result['unique_bytes'] == 2 - assert result['requested_bytes'] == 64 + assert result["segments"] == 1 + assert result["unique_bytes"] == 2 + assert result["requested_bytes"] == 64 def test_segment_analysis_tv_layout(): @@ -270,9 +269,9 @@ def test_segment_analysis_tv_layout(): tv = Layout((32, 4), (4, 1)) result = segment_analysis(tv, element_bytes=2) # 128 elements * 2B = 256B -> 8 segments, 2 cache lines - assert result['segments'] == 8 - assert result['cache_lines'] == 2 - assert result['requested_bytes'] == 256 # 32 * 4 * 2 + assert result["segments"] == 8 + assert result["cache_lines"] == 2 + assert result["requested_bytes"] == 256 # 32 * 4 * 2 ## per-group analysis @@ -282,11 +281,11 @@ def test_per_group_bank_conflicts(): """Per-group analysis matches single-group result for each warp.""" r_single = bank_conflicts(Layout(32, 32), element_bytes=2) r_per = per_group_bank_conflicts(Layout(64, 32), element_bytes=2) - assert len(r_per['groups']) == 2 + assert len(r_per["groups"]) == 2 # Each group should match the single-warp result - for g in r_per['groups']: - assert g['max_ways'] == r_single['max_ways'] - assert 
r_per['worst_max_ways'] == r_single['max_ways'] + for g in r_per["groups"]: + assert g["max_ways"] == r_single["max_ways"] + assert r_per["worst_max_ways"] == r_single["max_ways"] def test_per_group_bank_conflicts_tv_layout(): @@ -294,16 +293,16 @@ def test_per_group_bank_conflicts_tv_layout(): # 32 threads, 4 values each: should be 1 group (not 4) tv = Layout((32, 4), (1, 32)) result = per_group_bank_conflicts(tv, element_bytes=2, group_size=32) - assert len(result['groups']) == 1 + assert len(result["groups"]) == 1 def test_per_group_coalescing(): """Per-group coalescing for a uniform layout gives identical per-warp results.""" r_per = per_group_coalescing(Layout(64, 1), element_bytes=2) - assert len(r_per['groups']) == 2 - for g in r_per['groups']: - assert g['efficiency'] == pytest.approx(0.5) - assert g['transactions'] == 1 + assert len(r_per["groups"]) == 2 + for g in r_per["groups"]: + assert g["efficiency"] == pytest.approx(0.5) + assert g["transactions"] == 1 def test_per_group_coalescing_tv_layout(): @@ -311,9 +310,9 @@ def test_per_group_coalescing_tv_layout(): # 32 threads, 4 values each (contiguous within each thread's block) tv = Layout((32, 4), (4, 1)) result = per_group_coalescing(tv, element_bytes=2, group_size=32) - assert len(result['groups']) == 1 + assert len(result["groups"]) == 1 # 32 threads * 4 values = 128 elements * 2B = 256B -> 2 cache lines - assert result['groups'][0]['transactions'] == 2 + assert result["groups"][0]["transactions"] == 2 ## cycles @@ -473,43 +472,48 @@ def test_slice_contiguity_col_major(): def test_atom_summary_nv_sm80(): """SM80 16x8x16 F16 atom summary.""" from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + result = atom_summary(SM80_16x8x16_F16F16F16F16_TN) - assert result['shape_mnk'] == (16, 8, 16) - assert result['threads'] == 32 - assert result['values_c'] > 0 - assert result['c_coverage_ok'] + assert result["shape_mnk"] == (16, 8, 16) + assert result["threads"] == 32 + assert result["values_c"] 
> 0 + assert result["c_coverage_ok"] def test_atom_summary_nv_sm80_f32(): """SM80 16x8x8 F32 accumulator atom.""" from tensor_layouts.atoms_nv import SM80_16x8x8_F32F16F16F32_TN + result = atom_summary(SM80_16x8x8_F32F16F16F32_TN) - assert result['shape_mnk'] == (16, 8, 8) - assert result['threads'] == 32 - assert result['c_coverage_ok'] + assert result["shape_mnk"] == (16, 8, 8) + assert result["threads"] == 32 + assert result["c_coverage_ok"] def test_atom_summary_amd_cdna(): """AMD CDNA 32x32x8 MFMA atom summary.""" from tensor_layouts.atoms_amd import CDNA_32x32x8_F32F16F16_MFMA + result = atom_summary(CDNA_32x32x8_F32F16F16_MFMA) - assert result['shape_mnk'] == (32, 32, 8) - assert result['threads'] == 64 # AMD wavefront - assert result['c_coverage_ok'] + assert result["shape_mnk"] == (32, 32, 8) + assert result["threads"] == 64 # AMD wavefront + assert result["c_coverage_ok"] def test_atom_summary_text_output(): """atom_summary returns a readable text summary.""" from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + result = atom_summary(SM80_16x8x16_F16F16F16F16_TN) - assert 'SM80' in result['text'] - assert '16 x 8 x 16' in result['text'] - assert 'Threads' in result['text'] + assert "SM80" in result["text"] + assert "16 x 8 x 16" in result["text"] + assert "Threads" in result["text"] def test_atom_summary_rejects_wrong_c_offsets(): """c_coverage_ok must check exact offset set, not just cardinality.""" from tensor_layouts.atoms import MMAAtom + # Build a 2x2 atom where C layout produces offsets {0, 1, 2, 5} # instead of the expected {0, 1, 2, 3}. Cardinality is 4 = M*N, # but the set is wrong. 
@@ -518,10 +522,10 @@ def test_atom_summary_rejects_wrong_c_offsets(): ptx="test", shape_mnk=(2, 2, 1), thr_id=Layout(4), - a_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test - b_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test + a_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test + b_layout=Layout((4, 1), (1, 0)), # doesn't matter for this test # C layout: 4 threads, 1 value each -> offsets 0, 1, 2, 5 - c_layout=Layout((4, 1), (1, 0)), # placeholder, override below + c_layout=Layout((4, 1), (1, 0)), # placeholder, override below ) # Manually construct a C layout that maps t -> {0, 1, 2, 5} # Layout((4, 1), (1, 0)) maps t -> t, giving {0, 1, 2, 3} — that's correct. @@ -529,16 +533,18 @@ def test_atom_summary_rejects_wrong_c_offsets(): # Layout with shape (2, 2) stride (1, 2) gives 0,1,2,3 — still correct. # Use a non-standard construction: ((2, 2), 1) : ((1, 4), 0) -> 0,1,4,5 import dataclasses + bad_c = Layout(((2, 2), 1), ((1, 4), 0)) bad_atom = dataclasses.replace(bad_atom, c_layout=bad_c) result = atom_summary(bad_atom) - assert not result['c_coverage_ok'] + assert not result["c_coverage_ok"] def test_atom_summary_rejects_duplicate_c_coverage(): """c_coverage_ok must be False when C layout produces duplicate offsets.""" from tensor_layouts.atoms import MMAAtom import dataclasses + # Build a 2x2x1 atom where C layout has shape (4, 2) stride (1, 0). # This maps (t, v) pairs to offsets [0,0,1,1,2,2,3,3] — correct set # but each offset appears twice. 
@@ -554,26 +560,27 @@ def test_atom_summary_rejects_duplicate_c_coverage(): dup_c = Layout((4, 2), (1, 0)) # 8 accesses, offsets 0..3 each twice bad_atom = dataclasses.replace(base, c_layout=dup_c) result = atom_summary(bad_atom) - assert not result['c_coverage_ok'] + assert not result["c_coverage_ok"] def test_operand_analysis_sm80(): """operand_analysis on a well-formed atom reports full coverage.""" from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + result = operand_analysis(SM80_16x8x16_F16F16F16F16_TN) - for op in ['a', 'b', 'c']: - assert result[op]['coverage_ok'] - assert result[op]['duplicates'] == 0 - assert result[op]['thread_utilization'] == pytest.approx(1.0) - assert result['a']['domain_size'] == 16 * 16 # M * K - assert result['b']['domain_size'] == 8 * 16 # N * K - assert result['c']['domain_size'] == 16 * 8 # M * N + for op in ["a", "b", "c"]: + assert result[op]["coverage_ok"] + assert result[op]["duplicates"] == 0 + assert result[op]["thread_utilization"] == pytest.approx(1.0) + assert result["a"]["domain_size"] == 16 * 16 # M * K + assert result["b"]["domain_size"] == 8 * 16 # N * K + assert result["c"]["domain_size"] == 16 * 8 # M * N def test_operand_analysis_bad_coverage(): """operand_analysis detects malformed operand coverage.""" from tensor_layouts.atoms import MMAAtom - import dataclasses + base = MMAAtom( name="test_bad_operand", ptx="test", @@ -584,9 +591,9 @@ def test_operand_analysis_bad_coverage(): c_layout=Layout(((2, 2), 1), ((1, 4), 0)), # offsets {0,1,4,5}, not {0,1,2,3} ) result = operand_analysis(base) - assert not result['c']['coverage_ok'] - assert len(result['c']['missing']) > 0 - assert len(result['c']['extra']) > 0 + assert not result["c"]["coverage_ok"] + assert len(result["c"]["missing"]) > 0 + assert len(result["c"]["extra"]) > 0 ## explain @@ -595,26 +602,26 @@ def test_operand_analysis_bad_coverage(): def test_explain_logical_divide(): """explain shows step-by-step logical_divide computation.""" text = 
explain(logical_divide, Layout(16, 1), 4) - assert 'logical_divide' in text - assert 'complement' in text - assert 'compose' in text - assert '(4, 4) : (1, 4)' in text + assert "logical_divide" in text + assert "complement" in text + assert "compose" in text + assert "(4, 4) : (1, 4)" in text def test_explain_logical_product(): """explain shows step-by-step logical_product computation.""" text = explain(logical_product, Layout(4, 1), Layout(3, 1)) - assert 'logical_product' in text - assert 'complement' in text - assert '(4, 3) : (1, 4)' in text + assert "logical_product" in text + assert "complement" in text + assert "(4, 3) : (1, 4)" in text def test_explain_logical_product_tuple_tiler(): """explain handles logical_product with tuple tiler without crashing.""" text = explain(logical_product, Layout((4, 4), (1, 4)), (2, 2)) - assert 'logical_product' in text - assert 'mode 0' in text - assert 'mode 1' in text + assert "logical_product" in text + assert "mode 0" in text + assert "mode 1" in text expected = logical_product(Layout((4, 4), (1, 4)), (2, 2)) assert str(expected) in text @@ -622,77 +629,77 @@ def test_explain_logical_product_tuple_tiler(): def test_explain_complement(): """explain shows complement with image and codomain.""" text = explain(complement, Layout(4, 2), 16) - assert 'image' in text - assert 'codomain' in text - assert '[0, 16)' in text + assert "image" in text + assert "codomain" in text + assert "[0, 16)" in text def test_explain_compose(): """explain shows compose with per-element trace.""" text = explain(compose, Layout(8, 2), Layout(4, 1)) - assert 'C(i) = A(B(i))' in text - assert 'i=0' in text + assert "C(i) = A(B(i))" in text + assert "i=0" in text def test_explain_right_inverse(): """explain shows right_inverse with verification.""" text = explain(right_inverse, Layout(4, 2)) - assert 'R such that L(R(i)) == i' in text - assert 'Verification' in text + assert "R such that L(R(i)) == i" in text + assert "Verification" in text def 
test_explain_left_inverse(): """explain shows left_inverse with verification.""" text = explain(left_inverse, Layout(4, 2)) - assert 'R such that R(L(i)) == i' in text - assert 'Verification' in text + assert "R such that R(L(i)) == i" in text + assert "Verification" in text def test_explain_unsupported(): """explain gracefully handles unsupported functions.""" text = explain(size, Layout(4, 1)) - assert 'does not support' in text + assert "does not support" in text def test_explain_blocked_product(): """explain shows blocked_product as interleaved logical_product.""" text = explain(blocked_product, Layout((2, 3), (1, 2)), Layout((4, 2), (1, 4))) - assert 'blocked_product' in text - assert 'logical_product' in text - assert 'A varies fastest' in text + assert "blocked_product" in text + assert "logical_product" in text + assert "A varies fastest" in text def test_explain_raked_product(): """explain shows raked_product with comparison to blocked.""" text = explain(raked_product, Layout(4, 1), Layout(3, 1)) - assert 'raked_product' in text - assert 'B varies fastest' in text - assert 'blocked' in text - assert 'raked' in text + assert "raked_product" in text + assert "B varies fastest" in text + assert "blocked" in text + assert "raked" in text def test_explain_zipped_divide(): """explain shows zipped_divide as rearranged logical_divide.""" text = explain(zipped_divide, Layout((4, 6), (1, 4)), (2, 3)) - assert 'zipped_divide' in text - assert 'logical_divide' in text - assert '((tiles), (rests))' in text + assert "zipped_divide" in text + assert "logical_divide" in text + assert "((tiles), (rests))" in text def test_explain_tiled_divide(): """explain shows tiled_divide structure.""" text = explain(tiled_divide, Layout((4, 6), (1, 4)), (2, 3)) - assert 'tiled_divide' in text - assert 'logical_divide' in text - assert '((tiles), rest0, rest1, ...)' in text + assert "tiled_divide" in text + assert "logical_divide" in text + assert "((tiles), rest0, rest1, ...)" in text 
def test_explain_flat_divide(): """explain shows flat_divide structure.""" text = explain(flat_divide, Layout((4, 6), (1, 4)), (2, 3)) - assert 'flat_divide' in text - assert 'logical_divide' in text - assert '(tile0, tile1, ..., rest0, rest1, ...)' in text + assert "flat_divide" in text + assert "logical_divide" in text + assert "(tile0, tile1, ..., rest0, rest1, ...)" in text ## MMAAtom and CopyAtom __str__ diff --git a/tests/external.py b/tests/external.py index dd45717..da35d2a 100644 --- a/tests/external.py +++ b/tests/external.py @@ -22,7 +22,7 @@ import pytest -from tensor_layouts import * +from tensor_layouts import * # noqa: F401,F403,F405 from tensor_layouts.layout_utils import round_up @@ -225,17 +225,20 @@ def _test_coalesce_properties(layout): coalesce_layout = coalesce(layout) # Property 1: Result depth is at most 1 (flattened) - assert depth(coalesce_layout) <= 1, \ - f"depth(coalesce_layout)={depth(coalesce_layout)} > 1" + assert ( + depth(coalesce_layout) <= 1 + ), f"depth(coalesce_layout)={depth(coalesce_layout)} > 1" # Property 2: Size is preserved - assert size(coalesce_layout) == size(layout), \ - f"size(coalesce_layout)={size(coalesce_layout)} != size(layout)={size(layout)}" + assert size(coalesce_layout) == size( + layout + ), f"size(coalesce_layout)={size(coalesce_layout)} != size(layout)={size(layout)}" # Property 3: All indices map to the same offsets for i in range(size(layout)): - assert coalesce_layout(i) == layout(i), \ - f"coalesce_layout({i})={coalesce_layout(i)} != layout({i})={layout(i)}" + assert coalesce_layout(i) == layout( + i + ), f"coalesce_layout({i})={coalesce_layout(i)} != layout({i})={layout(i)}" def test_coalesce_simple(): @@ -328,16 +331,18 @@ def _test_composition_properties(layout_a, layout_b): layout_r = compose(layout_a, layout_b) # Property 1: Layout B is compatible with layout R - assert compatible(layout_b.shape, layout_r.shape), \ - f"layoutB.shape={layout_b.shape} not compatible with 
layoutR.shape={layout_r.shape}" + assert compatible( + layout_b.shape, layout_r.shape + ), f"layoutB.shape={layout_b.shape} not compatible with layoutR.shape={layout_r.shape}" # Property 2: R(c) = A(B(c)) for coordinates within A's domain a_size = size(layout_a) for c in range(size(layout_b)): bc = layout_b(c) if bc < a_size: - assert layout_r(c) == layout_a(bc), \ - f"layoutR({c})={layout_r(c)} != layoutA(layoutB({c}))={layout_a(bc)}" + assert layout_r(c) == layout_a( + bc + ), f"layoutR({c})={layout_r(c)} != layoutA(layoutB({c}))={layout_a(bc)}" def test_composition_simple(): @@ -435,21 +440,16 @@ def test_composition_multidimensional(): def test_composition_nested(): # Layout((8, 8)) o Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) _test_composition_properties( - Layout((8, 8)), - Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) + Layout((8, 8)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) ) # Layout((8, 8), (8, 1)) o Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) _test_composition_properties( - Layout((8, 8), (8, 1)), - Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) + Layout((8, 8), (8, 1)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) ) # Layout(((4, 2),), ((1, 16),)) o Layout((4, 2), (2, 1)) - _test_composition_properties( - Layout(((4, 2),), ((1, 16),)), - Layout((4, 2), (2, 1)) - ) + _test_composition_properties(Layout(((4, 2),), ((1, 16),)), Layout((4, 2), (2, 1))) # Layout((2, 2), (2, 1)) o Layout((2, 2), (2, 1)) _test_composition_properties(Layout((2, 2), (2, 1)), Layout((2, 2), (2, 1))) @@ -459,14 +459,12 @@ def test_composition_nested(): # Layout((4, 8, 2), (2, 8, 1)) o Layout((2, 2, 2), (1, 8, 2)) _test_composition_properties( - Layout((4, 8, 2), (2, 8, 1)), - Layout((2, 2, 2), (1, 8, 2)) + Layout((4, 8, 2), (2, 8, 1)), Layout((2, 2, 2), (1, 8, 2)) ) # Layout((4, 8, 2), (2, 8, 1)) o Layout((4, 2, 2), (2, 8, 1)) _test_composition_properties( - Layout((4, 8, 2), (2, 8, 1)), - Layout((4, 2, 
2), (2, 8, 1)) + Layout((4, 8, 2), (2, 8, 1)), Layout((4, 2, 2), (2, 8, 1)) ) @@ -483,16 +481,10 @@ def test_composition_dynamic(): _test_composition_properties(Layout(16, 2), Layout(4, 2)) # Layout((128, 24, 5), (1, 128, 3072)) o Layout(64, 2) - _test_composition_properties( - Layout((128, 24, 5), (1, 128, 3072)), - Layout(64, 2) - ) + _test_composition_properties(Layout((128, 24, 5), (1, 128, 3072)), Layout(64, 2)) # Layout((128, 24, 5), (1, 128, 3072)) o Layout(480, 32) - _test_composition_properties( - Layout((128, 24, 5), (1, 128, 3072)), - Layout(480, 32) - ) + _test_composition_properties(Layout((128, 24, 5), (1, 128, 3072)), Layout(480, 32)) def test_composition_cosize_larger(): @@ -582,12 +574,15 @@ def _test_logical_divide_properties(layout, tile): # For Layout tilers, verify the result rank is 2 (tile, rest) if isinstance(tile, Layout): # CuTe formula produces rank-2 result: (Tile, Rest) - assert rank(result) == 2, f"Expected rank 2 for Layout tiler, got {rank(result)}" + assert ( + rank(result) == 2 + ), f"Expected rank 2 for Layout tiler, got {rank(result)}" # The tile part (mode 0) should have size equal to size(tiler) tile_part = mode(result, 0) - assert size(tile_part) == size(tile_layout), \ - f"Tile part size {size(tile_part)} != tiler size {size(tile_layout)}" + assert size(tile_part) == size( + tile_layout + ), f"Tile part size {size(tile_part)} != tiler size {size(tile_layout)}" def test_logical_divide_simple(): @@ -707,6 +702,7 @@ def _test_swizzle_2d(sw_layout): This tests that slicing a tensor with swizzled layout preserves correct indexing. 
""" from tensor_layouts import Tensor + tensor = Tensor(sw_layout) # Get dimensions @@ -750,10 +746,7 @@ def test_swizzle_3_0_3(): - Col index contributes bits [0,3) - XOR creates the swizzle pattern """ - sw_layout = compose( - Swizzle(3, 0, 3), - Layout((8, 8), (8, 1)) # 8x8 row-major - ) + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) # 8x8 row-major _test_swizzle_2d(sw_layout) @@ -765,10 +758,7 @@ def test_swizzle_3_0_neg3(): - The XOR goes in the opposite direction - Bits at [0,3) are shifted left and XORed into bits [3,6) """ - sw_layout = compose( - Swizzle(3, 0, -3), - Layout((8, 8), (8, 1)) # 8x8 row-major - ) + sw_layout = compose(Swizzle(3, 0, -3), Layout((8, 8), (8, 1))) # 8x8 row-major _test_swizzle_2d(sw_layout) @@ -788,11 +778,7 @@ def test_swizzle_2_1_3(): - shift=3: masks are 3 positions apart (bits [1,3) and [4,6)) """ sw_layout = compose( - Swizzle(2, 1, 3), - Layout( - ((2, 2, 2), (2, 2, 2)), - ((32, 2, 8), (4, 1, 16)) - ) + Swizzle(2, 1, 3), Layout(((2, 2, 2), (2, 2, 2)), ((32, 2, 8), (4, 1, 16))) ) _test_swizzle_2d(sw_layout) @@ -870,10 +856,7 @@ def test_swizzle_with_base(): def test_composed_layout_repr(): """Test swizzled Layout string representation.""" - sw_layout = compose( - Swizzle(3, 0, 3), - Layout((8, 8), (8, 1)) - ) + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) repr_str = repr(sw_layout) assert "Swizzle(3, 0, 3)" in repr_str @@ -903,10 +886,7 @@ def test_inner_product(): """Test inner_product (pycute test_int_tuple.py::test_inner_product).""" assert inner_product(2, 3) == 6 assert inner_product((1, 2), (3, 2)) == 7 - assert inner_product( - ((2, 3), 4), - ((2, 1), 2) - ) == 15 + assert inner_product(((2, 3), 4), ((2, 1), 2)) == 15 def test_prefix_product(): @@ -915,9 +895,11 @@ def test_prefix_product(): assert prefix_product((3, 2)) == (1, 3) assert prefix_product((3, 2, 4)) == (1, 3, 6) assert prefix_product(((2, 3), 4)) == ((1, 2), 6) - assert prefix_product( - ((2, 3), (2, 1, 2), (5, 2, 1)) - ) 
== ((1, 2), (6, 12, 12), (24, 120, 240)) + assert prefix_product(((2, 3), (2, 1, 2), (5, 2, 1))) == ( + (1, 2), + (6, 12, 12), + (24, 120, 240), + ) def test_shape_div_pycute(): @@ -976,13 +958,14 @@ def test_coalesce_pycute(): Uses the pycute helper: verify size and functional equivalence. """ + def _check(layout): layoutR = coalesce(layout) assert size(layoutR) == size(layout) for i in range(size(layout)): - assert layoutR(i) == layout(i), ( - f"coalesce({layout})({i}) = {layoutR(i)} != {layout(i)}" - ) + assert layoutR(i) == layout( + i + ), f"coalesce({layout})({i}) = {layoutR(i)} != {layout(i)}" _check(Layout(1, 0)) _check(Layout(1, 1)) @@ -1007,12 +990,13 @@ def test_composition_pycute(): Uses the pycute helper: R(i) == A(B(i)) for all i. """ + def _check(A, B): R = compose(A, B) for i in range(size(R)): - assert R(i) == A(B(i)), ( - f"compose({A}, {B})({i}) = {R(i)} != A(B({i})) = {A(B(i))}" - ) + assert R(i) == A( + B(i) + ), f"compose({A}, {B})({i}) = {R(i)} != A(B({i})) = {A(B(i))}" # All test cases from pycute test_composition.py _check(Layout(1, 0), Layout(1, 0)) @@ -1044,7 +1028,9 @@ def _check(A, B): _check(Layout((4, 3), (3, 1)), Layout(6, 2)) _check(Layout((4, 3), (3, 1)), Layout((6, 2), (2, 1))) _check(Layout((8, 8)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32)))) - _check(Layout((8, 8), (8, 1)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32)))) + _check( + Layout((8, 8), (8, 1)), Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))) + ) # Layout applied from right with stride (from pycute, not in C++ tests) _check(Layout(((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), Layout(8, 4)) _check(Layout((4, 2), (1, 16)), Layout((4, 2), (2, 1))) @@ -1067,8 +1053,7 @@ def _test_right_inverse(layout): inv_layout = right_inverse(layout) for i in range(size(inv_layout)): assert layout(inv_layout(i)) == i, ( - f"right_inverse({layout}): L(R({i})) = " - f"{layout(inv_layout(i))} != {i}" + f"right_inverse({layout}): L(R({i})) = " 
f"{layout(inv_layout(i))} != {i}" ) @@ -1110,8 +1095,7 @@ def _test_left_inverse(layout): inv_layout = left_inverse(layout) for i in range(size(layout)): assert inv_layout(layout(i)) == i, ( - f"left_inverse({layout}): R(L({i})) = " - f"{inv_layout(layout(i))} != {i}" + f"left_inverse({layout}): R(L({i})) = " f"{inv_layout(layout(i))} != {i}" ) @@ -1164,15 +1148,14 @@ def _test_logical_product_properties(layout_a, layout_b): R = logical_product(layout_a, layout_b) # Property 1: Result has rank 2 - assert rank(R) == 2, ( - f"logical_product({layout_a}, {layout_b}): rank={rank(R)}, expected 2" - ) + assert ( + rank(R) == 2 + ), f"logical_product({layout_a}, {layout_b}): rank={rank(R)}, expected 2" # Property 2: First mode of R equals A R0 = mode(R, 0) assert R0.shape == layout_a.shape and R0.stride == layout_a.stride, ( - f"logical_product({layout_a}, {layout_b}): " - f"mode(R,0)={R0} != A={layout_a}" + f"logical_product({layout_a}, {layout_b}): " f"mode(R,0)={R0} != A={layout_a}" ) # Property 3: B is compatible with second mode of R @@ -1213,13 +1196,9 @@ def test_logical_product_multidim_tile(): # Layout((2,4)) x Layout(3) _test_logical_product_properties(Layout((2, 4)), Layout(3)) # Layout((8,(2,2))) x Layout(4,2) - _test_logical_product_properties( - Layout((8, (2, 2)), (1, (8, 16))), Layout(4, 2) - ) + _test_logical_product_properties(Layout((8, (2, 2)), (1, (8, 16))), Layout(4, 2)) # Layout((2,2)) x Layout((3,3),(3,1)) - _test_logical_product_properties( - Layout((2, 2), (1, 2)), Layout((3, 3), (3, 1)) - ) + _test_logical_product_properties(Layout((2, 2), (1, 2)), Layout((3, 3), (3, 1))) def test_logical_product_large_stride(): @@ -1231,13 +1210,9 @@ def test_logical_product_large_stride(): def test_logical_product_nested(): """Logical product with nested/hierarchical layouts (C++ lines 175-213).""" # Layout(((4,2)),((1,16))) x Layout((4,4)) - _test_logical_product_properties( - Layout((4, 2), (1, 16)), Layout((4, 4)) - ) + 
_test_logical_product_properties(Layout((4, 2), (1, 16)), Layout((4, 4))) # Layout(((4,2)),((1,16))) x Layout((4,2),(2,1)) - _test_logical_product_properties( - Layout((4, 2), (1, 16)), Layout((4, 2), (2, 1)) - ) + _test_logical_product_properties(Layout((4, 2), (1, 16)), Layout((4, 2), (2, 1))) # Layout(((2,2),(2,2)),((1,4),(8,32))) x Layout((2,2),(1,2)) _test_logical_product_properties( Layout(((2, 2), (2, 2)), ((1, 4), (8, 32))), @@ -1249,9 +1224,7 @@ def test_logical_product_nested(): Layout((2, 2), (2, 1)), ) # Layout(((4,6)),((1,6))) x Layout(3,1) - _test_logical_product_properties( - Layout((4, 6), (1, 6)), Layout(3, 1) - ) + _test_logical_product_properties(Layout((4, 6), (1, 6)), Layout(3, 1)) ## Left Inverse edge cases (C++ inverse_left.cpp) @@ -1281,9 +1254,7 @@ def _test_left_inverse_cpp(layout): # Fast path: stride-0 modes make injectivity impossible if _has_broadcast(layout): - assert size(inv_layout) >= 1, ( - f"left_inverse({layout}): empty result" - ) + assert size(inv_layout) >= 1, f"left_inverse({layout}): empty result" return # No broadcast modes — check injectivity and contiguity via enumeration @@ -1296,13 +1267,10 @@ def _test_left_inverse_cpp(layout): ili = inv_layout(li) lili = layout(ili) assert lili == li, ( - f"left_inverse({layout}): " - f"L(inv(L({i})))={lili} != L({i})={li}" + f"left_inverse({layout}): " f"L(inv(L({i})))={lili} != L({i})={li}" ) else: - assert size(inv_layout) >= 1, ( - f"left_inverse({layout}): empty result" - ) + assert size(inv_layout) >= 1, f"left_inverse({layout}): empty result" def test_left_inverse_cpp_broadcast(): @@ -1342,10 +1310,12 @@ def test_left_inverse_cpp_deep_nested(): Shape: (((( 32, 4), 1), ( 32, 2)), 4), 1, (2, 2), 2) Stride: ((((262144, 4), 0), ( 0, 1)), 8388608), 0, (2, 16), 32) """ - _test_left_inverse_cpp(Layout( - ((((32, 4), 1), (32, 2)), 4, 1, (2, 2), 2), - ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), - )) + _test_left_inverse_cpp( + Layout( + ((((32, 4), 1), (32, 2)), 4, 1, (2, 
2), 2), + ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), + ) + ) ## Right Inverse edge cases (C++ inverse_right.cpp) @@ -1360,10 +1330,7 @@ def _test_right_inverse_cpp(layout): inv_layout = right_inverse(layout) for i in range(size(inv_layout)): li = layout(inv_layout(i)) - assert li == i, ( - f"right_inverse({layout}): " - f"L(R({i}))={li} != {i}" - ) + assert li == i, f"right_inverse({layout}): " f"L(R({i}))={li} != {i}" def test_right_inverse_cpp_4d(): @@ -1392,10 +1359,12 @@ def test_right_inverse_cpp_broadcast_middle(): def test_right_inverse_cpp_deep_nested(): """Right-inverse of deeply nested layout (C++ inverse_right.cpp line 210).""" - _test_right_inverse_cpp(Layout( - ((((32, 4), 1), (32, 2)), 4, 1, (2, 2), 2), - ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), - )) + _test_right_inverse_cpp( + Layout( + ((((32, 4), 1), (32, 2)), 4, 1, (2, 2), 2), + ((((262144, 4), 0), (0, 1)), 8388608, 0, (2, 16), 32), + ) + ) ## Composition edge case (C++ composition.cpp line 241-246) @@ -1406,9 +1375,7 @@ def test_composition_transposed_strides(): Layout((4,3)) o Layout((4,3),(3,1)) -- col-major transposed. 
""" - _test_composition_properties( - Layout((4, 3)), Layout((4, 3), (3, 1)) - ) + _test_composition_properties(Layout((4, 3)), Layout((4, 3), (3, 1))) ## Complement edge case (pycute Python test_complement.py) @@ -1455,6 +1422,7 @@ def test_elem_scale(): assert elem_scale((1, 1), (7, 8)) == (7, 8) # Tuple x scalar -> error import pytest + with pytest.raises(TypeError): elem_scale((2, 3), 4) @@ -1560,7 +1528,9 @@ def test_slice_and_offset(): # Verify: sublayout(i) + offset == original(fixed_coord, i) for i in range(size(sub)): - assert sub(i) + offset == layout(2, i), f"i={i}: {sub(i) + offset} != {layout(2, i)}" + assert sub(i) + offset == layout( + 2, i + ), f"i={i}: {sub(i) + offset} != {layout(2, i)}" ## zipped_product diff --git a/tests/layouts.py b/tests/layouts.py index 0ee5e8d..307ce46 100644 --- a/tests/layouts.py +++ b/tests/layouts.py @@ -22,8 +22,12 @@ import pytest -from tensor_layouts import * -from tensor_layouts.layout_utils import make_layout_like, make_ordered_layout, tile_to_shape +from tensor_layouts import * # noqa: F401,F403,F405 +from tensor_layouts.layout_utils import ( + make_layout_like, + make_ordered_layout, + tile_to_shape, +) # These tests roughly follow: @@ -64,7 +68,7 @@ def test_tuple_single_element(): def test_tuple_nested_single_element(): - t = (((2,))) + t = (2,) assert len(t) == 1 assert size(t) == 2 assert rank(t) == 1 @@ -435,9 +439,9 @@ def test_coordinate_validation(): L([1, 2]) # Valid cases still work - assert L(1) == 1 # flat index - assert L(1, 2) == 9 # tuple coord via *args - assert L((1, 2)) == 9 # tuple coord via single arg + assert L(1) == 1 # flat index + assert L(1, 2) == 9 # tuple coord via *args + assert L((1, 2)) == 9 # tuple coord via single arg def test_idx2crd_crd2flat_crd2offset(): @@ -491,16 +495,16 @@ def test_shape_div_non_divisible(): (e.g., shape_div/mod won't be complementary), so we assert. 
""" # Valid cases where divisibility holds - assert shape_div(12, 4) == 3 # 12%4==0 - assert shape_div(4, 12) == 1 # 12%4==0 - assert shape_div(8, 2) == 4 # 8%2==0 - assert shape_div(2, 8) == 1 # 8%2==0 + assert shape_div(12, 4) == 3 # 12%4==0 + assert shape_div(4, 12) == 1 # 12%4==0 + assert shape_div(8, 2) == 4 # 8%2==0 + assert shape_div(2, 8) == 1 # 8%2==0 # Invalid cases should raise ValueError with pytest.raises(ValueError): - shape_div(6, 4) # 6%4≠0, 4%6≠0 + shape_div(6, 4) # 6%4≠0, 4%6≠0 with pytest.raises(ValueError): - shape_div(4, 6) # 4%6≠0, 6%4≠0 + shape_div(4, 6) # 4%6≠0, 6%4≠0 def test_shape_mod_non_divisible(): @@ -514,9 +518,9 @@ def test_shape_mod_non_divisible(): # shape_mod(2, 2) = gcd(2,2) = 2 assert shape_mod((6, 2), 4) == (2, 2) # Scalar shape_mod: when modulus < shape, returns gcd - assert shape_mod(6, 4) == 2 # gcd(6,4) = 2 + assert shape_mod(6, 4) == 2 # gcd(6,4) = 2 # Scalar shape_mod: when modulus >= shape, returns shape - assert shape_mod(4, 6) == 4 # 6 >= 4, returns 4 + assert shape_mod(4, 6) == 4 # 6 >= 4, returns 4 def test_shape_div_mod_complementary(): @@ -525,11 +529,15 @@ def test_shape_div_mod_complementary(): This holds when the divisor evenly divides each mode it consumes. 
""" test_cases = [ - ((6, 2), 2), ((6, 2), 3), - ((6, 2), 6), ((6, 2), 12), - ((4, 3), 2), ((4, 3), 4), + ((6, 2), 2), + ((6, 2), 3), + ((6, 2), 6), + ((6, 2), 12), + ((4, 3), 2), + ((4, 3), 4), ((4, 3), 12), - ((3, 6, 2, 8), 3), ((3, 6, 2, 8), 9), + ((3, 6, 2, 8), 3), + ((3, 6, 2, 8), 9), ((3, 6, 2, 8), 72), ] for shape, div in test_cases: @@ -626,7 +634,9 @@ def test_compose_two_2d(): # B(0,0)=0, B(1,0)=1, B(0,1)=2, B(1,1)=3 # A(0)=0, A(1)=1, A(2)=2, A(3)=3 # So compose gives same result as B indexing into first 4 elements of A - assert compose(Layout((4, 4), (1, 4)), Layout((2, 2), (1, 2))) == Layout((2, 2), (1, 2)) + assert compose(Layout((4, 4), (1, 4)), Layout((2, 2), (1, 2))) == Layout( + (2, 2), (1, 2) + ) def test_compose_functional_equivalence(): @@ -1055,7 +1065,9 @@ def test_core_matrix_operations(): # For a 2-byte dtype such as f16, core matrix is 8x8 tile1 = Layout((8, 1), (1, 0)) # (8,1):(1,0) mul1 = Layout((1, 8), (0, 1)) - tile2 = coalesce(blocked_product(tile1, mul1), profile=(None, None)) # (8,8):(1,8) -> One core Matrix + tile2 = coalesce( + blocked_product(tile1, mul1), profile=(None, None) + ) # (8,8):(1,8) -> One core Matrix assert tile2 == Layout((8, 8), (1, 8)) # Now organize core matrices into 8x8 pattern, so that we have a 64x64 Tile, say in SMem mul2 = Layout((8, 8), (1, 8)) @@ -1280,11 +1292,12 @@ def test_swizzled_layout_eq_hash(): def test_offset_swizzled_layout_basic(): from tensor_layouts import Tensor + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) tensor = Tensor(sw_layout) # Slicing a Tensor produces a Tensor with offset row_slice = tensor[3, :] - assert hasattr(row_slice, 'offset') + assert hasattr(row_slice, "offset") assert row_slice.offset == Layout((8, 8), (8, 1))(3, 0) # = 24 # Check functional correctness: tensor[3, :](j) == tensor(3, j) @@ -1294,6 +1307,7 @@ def test_offset_swizzled_layout_basic(): def test_offset_swizzled_layout_repr(): from tensor_layouts import Tensor + sw_layout = compose(Swizzle(3, 0, 
3), Layout((8, 8), (8, 1))) tensor = Tensor(sw_layout) row_slice = tensor[2, :] @@ -1304,6 +1318,7 @@ def test_offset_swizzled_layout_repr(): def test_offset_swizzled_layout_eq(): from tensor_layouts import Tensor + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) tensor = Tensor(sw_layout) slice1 = tensor[3, :] @@ -1449,6 +1464,7 @@ def test_tile_to_shape_nested_block(): ## is_layout + def test_is_layout(): assert is_layout(Layout(4, 1)) is True assert is_layout(Layout((2, 3), (1, 2))) is True @@ -1460,6 +1476,7 @@ def test_is_layout(): ## unflatten + def test_unflatten_tuple(): # Flat tuple -> nested tuple assert unflatten((1, 2, 3, 4, 5), ((0, 0), (0, 0, 0))) == ((1, 2), (3, 4, 5)) @@ -1557,6 +1574,7 @@ def test_make_ordered_layout_scalar(): ## dice_modes + def test_dice_modes_scalar_coord(): # Scalar coord: identity (keep everything) layout = Layout((3, 4), (1, 4)) @@ -1614,6 +1632,7 @@ def test_dice_modes_complement_of_slice_modes(): ## nullspace + def test_nullspace_all_zero_strides(): # All stride-0: everything is in the kernel # Inspired by C++ test: Layout,Stride<_0,_0,_0>> @@ -1681,6 +1700,7 @@ def test_nullspace_scalar_zero_stride(): ## max_common_vector and max_common_layout + def test_max_common_vector_identical(): # Same layout: all elements are common a = Layout(8, 1) @@ -1724,6 +1744,7 @@ def test_max_common_layout_partial(): ## flat_product + def test_flat_product_basic(): # flat_product = zipped_product then unpack both modes block = Layout(4, 1) @@ -1750,6 +1771,7 @@ def test_flat_product_2d(): ## raked_product + def test_raked_product_basic(): # raked_product vs blocked_product: reversed zip order block = Layout((2, 2), (1, 2)) @@ -1864,8 +1886,11 @@ def test_upcast_known_copy_atoms(): derived from the CUTLASS C++ copy_traits_sm75.hpp source. 
""" from tensor_layouts.atoms_nv import ( - SM75_U32x1_LDSM_N, SM75_U32x4_LDSM_N, - SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, SM75_U16x8_LDSM_T, + SM75_U32x1_LDSM_N, + SM75_U32x4_LDSM_N, + SM75_U16x2_LDSM_T, + SM75_U16x4_LDSM_T, + SM75_U16x8_LDSM_T, ) cases = [ @@ -1879,12 +1904,12 @@ def test_upcast_known_copy_atoms(): for atom, exp_shape, exp_stride in cases: result = upcast(atom.dst_layout_bits, 16) - assert result.shape == exp_shape, ( - f"{atom.name}: shape {result.shape} != expected {exp_shape}" - ) - assert result.stride == exp_stride, ( - f"{atom.name}: stride {result.stride} != expected {exp_stride}" - ) + assert ( + result.shape == exp_shape + ), f"{atom.name}: shape {result.shape} != expected {exp_shape}" + assert ( + result.stride == exp_stride + ), f"{atom.name}: stride {result.stride} != expected {exp_stride}" def test_downcast_simple(): @@ -1945,9 +1970,12 @@ def test_iter_layout_2d_col_major(): layout = Layout((2, 3), (1, 2)) result = list(iter_layout(layout)) expected = [ - ((0, 0), 0), ((1, 0), 1), # col 0 - ((0, 1), 2), ((1, 1), 3), # col 1 - ((0, 2), 4), ((1, 2), 5), # col 2 + ((0, 0), 0), + ((1, 0), 1), # col 0 + ((0, 1), 2), + ((1, 1), 3), # col 1 + ((0, 2), 4), + ((1, 2), 5), # col 2 ] assert result == expected diff --git a/tests/oracle_amd.py b/tests/oracle_amd.py index dfbc7d0..c81d1cd 100644 --- a/tests/oracle_amd.py +++ b/tests/oracle_amd.py @@ -44,42 +44,51 @@ from tensor_layouts import Layout, size, rank, depth, mode, cosize from tensor_layouts.layouts import ( - compose, complement, flatten, coalesce, - logical_divide, logical_product, - left_inverse, right_inverse, - idx2crd, crd2idx, + compose, + complement, + flatten, + coalesce, + logical_divide, + logical_product, + left_inverse, + right_inverse, + idx2crd, + crd2idx, ) from tensor_layouts.layout_utils import ( - make_ordered_layout, tile_to_shape, product_each, + make_ordered_layout, + tile_to_shape, + product_each, ) -from tensor_layouts.atoms_amd import * +from 
tensor_layouts.atoms_amd import * # noqa: F401,F403,F405 # Try to import the AMD matrix instruction calculator. try: from amd_matrix_instruction_calculator import matrix_calculator + HAS_CALCULATOR = True except ImportError: try: import matrix_calculator + HAS_CALCULATOR = True except ImportError: HAS_CALCULATOR = False requires_calculator = pytest.mark.skipif( - not HAS_CALCULATOR, - reason="amd_matrix_instruction_calculator not available" + not HAS_CALCULATOR, reason="amd_matrix_instruction_calculator not available" ) # Try to import visualization module (requires matplotlib). try: from tensor_layouts.viz import draw_tv_layout, draw_mma_layout, _compute_tv_mapping + HAS_VIZ = True except ImportError: HAS_VIZ = False requires_viz = pytest.mark.skipif( - not HAS_VIZ, - reason="tensor_layouts.viz not available (needs matplotlib)" + not HAS_VIZ, reason="tensor_layouts.viz not available (needs matplotlib)" ) @@ -133,9 +142,12 @@ # Helpers # ============================================================================= + def _num_threads(layout): """Number of threads from a TV layout's thread dimension.""" - return size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + return ( + size(layout.shape[0]) if isinstance(layout.shape, tuple) else size(layout.shape) + ) def _num_values(layout): @@ -166,10 +178,10 @@ def get_calculator_d_mapping(arch: str, instruction: str, m: int, n: int): try: info = matrix_calculator.get_instruction_info(arch, instruction) mapping = {} - num_vgprs = info['num_output_regs'] + num_vgprs = info["num_output_regs"] for lane in range(64): for vgpr in range(num_vgprs): - row, col = info['get_output'](lane, vgpr) + row, col = info["get_output"](lane, vgpr) mapping[(lane, vgpr)] = (row, col) return mapping except (AttributeError, KeyError): @@ -206,9 +218,8 @@ def validate_c_layout(atom, arch: str): f"ref=({ref_row},{ref_col})" ) - assert not errors, ( - f"{atom.name}: {len(errors)} mismatches:\n" + - 
"\n".join(errors[:20]) + assert not errors, f"{atom.name}: {len(errors)} mismatches:\n" + "\n".join( + errors[:20] ) @@ -216,14 +227,17 @@ def validate_c_layout(atom, arch: str): # CDNA1 (gfx908) FP16 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x8_f32f16f16(): validate_c_layout(CDNA_32x32x8_F32F16F16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x16_f32f16f16(): validate_c_layout(CDNA_16x16x16_F32F16F16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_4x4x4_f32f16f16(): validate_c_layout(CDNA_4x4x4_F32F16F16_MFMA, "cdna1") @@ -233,10 +247,12 @@ def test_oracle_cdna_4x4x4_f32f16f16(): # CDNA1 non-k-reduction variants # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x4_f32f16f16(): validate_c_layout(CDNA_32x32x4_F32F16F16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x4_f32f16f16(): validate_c_layout(CDNA_16x16x4_F32F16F16_MFMA, "cdna1") @@ -246,10 +262,12 @@ def test_oracle_cdna_16x16x4_f32f16f16(): # CDNA2 (gfx90a) BF16_1K atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x8_f32bf16bf16_1k(): validate_c_layout(CDNA_32x32x8_F32BF16BF16_1K_MFMA, "cdna2") + @requires_calculator def test_oracle_cdna_16x16x16_f32bf16bf16_1k(): validate_c_layout(CDNA_16x16x16_F32BF16BF16_1K_MFMA, "cdna2") @@ -259,10 +277,12 @@ def test_oracle_cdna_16x16x16_f32bf16bf16_1k(): # CDNA1/2 BF16 (original, non-1K) atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x4_f32bf16bf16(): validate_c_layout(CDNA_32x32x4_F32BF16BF16_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x8_f32bf16bf16(): validate_c_layout(CDNA_16x16x8_F32BF16BF16_MFMA, "cdna1") @@ -272,10 +292,12 @@ def 
test_oracle_cdna_16x16x8_f32bf16bf16(): # CDNA1/2 INT8 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x8_i32i8i8(): validate_c_layout(CDNA_32x32x8_I32I8I8_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x16_i32i8i8(): validate_c_layout(CDNA_16x16x16_I32I8I8_MFMA, "cdna1") @@ -285,10 +307,12 @@ def test_oracle_cdna_16x16x16_i32i8i8(): # CDNA1/2 FP32 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna_32x32x2_f32f32f32(): validate_c_layout(CDNA_32x32x2_F32F32F32_MFMA, "cdna1") + @requires_calculator def test_oracle_cdna_16x16x4_f32f32f32(): validate_c_layout(CDNA_16x16x4_F32F32F32_MFMA, "cdna1") @@ -298,6 +322,7 @@ def test_oracle_cdna_16x16x4_f32f32f32(): # CDNA2/3 FP64 atom # ============================================================================= + @requires_calculator def test_oracle_cdna_16x16x4_f64f64f64(): validate_c_layout(CDNA_16x16x4_F64F64F64_MFMA, "cdna2") @@ -307,18 +332,22 @@ def test_oracle_cdna_16x16x4_f64f64f64(): # CDNA3 (gfx942) enhanced atoms # ============================================================================= + @requires_calculator def test_oracle_cdna3_32x32x16_i32i8i8(): validate_c_layout(CDNA3_32x32x16_I32I8I8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_i32i8i8(): validate_c_layout(CDNA3_16x16x32_I32I8I8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x4_f32xf32xf32(): validate_c_layout(CDNA3_32x32x4_F32XF32XF32_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x8_f32xf32xf32(): validate_c_layout(CDNA3_16x16x8_F32XF32XF32_MFMA, "cdna3") @@ -328,34 +357,42 @@ def test_oracle_cdna3_16x16x8_f32xf32xf32(): # CDNA3 FP8 atoms # ============================================================================= + @requires_calculator def test_oracle_cdna3_32x32x16_f32f8f8(): 
validate_c_layout(CDNA3_32x32x16_F32F8F8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32f8f8(): validate_c_layout(CDNA3_16x16x32_F32F8F8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x16_f32bf8bf8(): validate_c_layout(CDNA3_32x32x16_F32BF8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32bf8bf8(): validate_c_layout(CDNA3_16x16x32_F32BF8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x16_f32f8bf8(): validate_c_layout(CDNA3_32x32x16_F32F8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32f8bf8(): validate_c_layout(CDNA3_16x16x32_F32F8BF8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_32x32x16_f32bf8f8(): validate_c_layout(CDNA3_32x32x16_F32BF8F8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3_16x16x32_f32bf8f8(): validate_c_layout(CDNA3_16x16x32_F32BF8F8_MFMA, "cdna3") @@ -365,26 +402,32 @@ def test_oracle_cdna3_16x16x32_f32bf8f8(): # CDNA3+ (gfx950) double-rate atoms # ============================================================================= + @requires_calculator def test_oracle_cdna3p_32x32x16_f32f16f16(): validate_c_layout(CDNA3P_32x32x16_F32F16F16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_16x16x32_f32f16f16(): validate_c_layout(CDNA3P_16x16x32_F32F16F16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_32x32x16_f32bf16bf16(): validate_c_layout(CDNA3P_32x32x16_F32BF16BF16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_16x16x32_f32bf16bf16(): validate_c_layout(CDNA3P_16x16x32_F32BF16BF16_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_32x32x32_i32i8i8(): validate_c_layout(CDNA3P_32x32x32_I32I8I8_MFMA, "cdna3") + @requires_calculator def test_oracle_cdna3p_16x16x64_i32i8i8(): validate_c_layout(CDNA3P_16x16x64_I32I8I8_MFMA, "cdna3") @@ -397,6 +440,7 @@ def test_oracle_cdna3p_16x16x64_i32i8i8(): # These tests verify algebraic properties of the layouts themselves, # 
independent of the AMD calculator. They always run. + @pytest.mark.parametrize("atom", ALL_ATOMS, ids=lambda a: a.name) class TestMFMAStructural: """Structural invariants that must hold for any valid MFMA atom.""" @@ -412,118 +456,127 @@ def test_c_layout_covers_all_elements(self, atom): for t in range(num_t): for v in range(num_v): offset = c(t, v) - assert 0 <= offset < m * n, \ - f"{atom.name}: offset {offset} out of range [0, {m*n})" - assert offset not in seen, \ - f"{atom.name}: duplicate offset {offset} at t={t}, v={v}" + assert ( + 0 <= offset < m * n + ), f"{atom.name}: offset {offset} out of range [0, {m*n})" + assert ( + offset not in seen + ), f"{atom.name}: duplicate offset {offset} at t={t}, v={v}" seen.add(offset) - assert len(seen) == m * n, \ - f"{atom.name}: covers {len(seen)} elements, expected {m*n}" + assert ( + len(seen) == m * n + ), f"{atom.name}: covers {len(seen)} elements, expected {m*n}" def test_c_layout_thread_count(self, atom): """Thread dimension has exactly 64 elements (one wavefront).""" c = atom.c_layout - assert _num_threads(c) == 64, \ - f"{atom.name}: {_num_threads(c)} threads, expected 64" + assert ( + _num_threads(c) == 64 + ), f"{atom.name}: {_num_threads(c)} threads, expected 64" def test_a_layout_thread_count(self, atom): """A layout thread dimension has exactly 64 elements.""" a = atom.a_layout - assert _num_threads(a) == 64, \ - f"{atom.name}: A has {_num_threads(a)} threads, expected 64" + assert ( + _num_threads(a) == 64 + ), f"{atom.name}: A has {_num_threads(a)} threads, expected 64" def test_b_layout_thread_count(self, atom): """B layout thread dimension has exactly 64 elements.""" b = atom.b_layout - assert _num_threads(b) == 64, \ - f"{atom.name}: B has {_num_threads(b)} threads, expected 64" + assert ( + _num_threads(b) == 64 + ), f"{atom.name}: B has {_num_threads(b)} threads, expected 64" def test_a_layout_broadcast(self, atom): """A layout broadcasts across blocks (stride-0 in block dimension).""" a = 
atom.a_layout if isinstance(a.stride, tuple) and isinstance(a.stride[0], tuple): blk_stride = a.stride[0][0] - assert blk_stride == 0, \ - f"{atom.name}: A layout block stride is {blk_stride}, expected 0" + assert ( + blk_stride == 0 + ), f"{atom.name}: A layout block stride is {blk_stride}, expected 0" def test_b_layout_broadcast(self, atom): """B layout broadcasts across blocks (stride-0 in block dimension).""" b = atom.b_layout if isinstance(b.stride, tuple) and isinstance(b.stride[0], tuple): blk_stride = b.stride[0][0] - assert blk_stride == 0, \ - f"{atom.name}: B layout block stride is {blk_stride}, expected 0" + assert ( + blk_stride == 0 + ), f"{atom.name}: B layout block stride is {blk_stride}, expected 0" def test_a_layout_cosize_bounded(self, atom): """A layout codomain is bounded by thread_count * values_per_thread.""" a = atom.a_layout # cosize is max_offset + 1; for broadcast layouts this can exceed M*K # but must be bounded by the underlying coordinate space - assert cosize(a) >= 1, \ - f"{atom.name}: A cosize must be positive" + assert cosize(a) >= 1, f"{atom.name}: A cosize must be positive" def test_b_layout_cosize_bounded(self, atom): """B layout codomain is bounded by thread_count * values_per_thread.""" b = atom.b_layout - assert cosize(b) >= 1, \ - f"{atom.name}: B cosize must be positive" + assert cosize(b) >= 1, f"{atom.name}: B cosize must be positive" def test_c_layout_cosize_equals_mn(self, atom): """C layout codomain spans exactly M x N (since it's a bijection).""" m, n, k = atom.shape_mnk c = atom.c_layout - assert cosize(c) == m * n, \ - f"{atom.name}: C cosize {cosize(c)} != M*N={m*n}" + assert cosize(c) == m * n, f"{atom.name}: C cosize {cosize(c)} != M*N={m*n}" def test_thr_id_is_none(self, atom): """AMD MFMA atoms use identity thread mapping (thr_id is None).""" - assert atom.thr_id is None, \ - f"{atom.name}: thr_id should be None, got {atom.thr_id}" + assert ( + atom.thr_id is None + ), f"{atom.name}: thr_id should be None, got 
{atom.thr_id}" def test_c_layout_rank_is_2(self, atom): """C layout is rank-2: (thread, value).""" c = atom.c_layout - assert rank(c) == 2, \ - f"{atom.name}: C rank {rank(c)}, expected 2" + assert rank(c) == 2, f"{atom.name}: C rank {rank(c)}, expected 2" def test_a_layout_rank_is_2(self, atom): """A layout is rank-2: (thread, value).""" a = atom.a_layout - assert rank(a) == 2, \ - f"{atom.name}: A rank {rank(a)}, expected 2" + assert rank(a) == 2, f"{atom.name}: A rank {rank(a)}, expected 2" def test_b_layout_rank_is_2(self, atom): """B layout is rank-2: (thread, value).""" b = atom.b_layout - assert rank(b) == 2, \ - f"{atom.name}: B rank {rank(b)}, expected 2" + assert rank(b) == 2, f"{atom.name}: B rank {rank(b)}, expected 2" def test_layout_sizes_match_shape_mnk(self, atom): """Layout domain sizes are consistent with M, N, K.""" m, n, k = atom.shape_mnk a, b, c = atom.a_layout, atom.b_layout, atom.c_layout - assert size(c) == m * n, \ - f"{atom.name}: C size {size(c)} != M*N={m*n}" + assert size(c) == m * n, f"{atom.name}: C size {size(c)} != M*N={m*n}" # A and B sizes include the broadcast dimension, so size >= M*K / N*K # but since broadcast replicates the same data, size == 64 * values_per_thread - assert size(a) == 64 * _num_values(a), \ - f"{atom.name}: A size {size(a)} != 64 * {_num_values(a)}" - assert size(b) == 64 * _num_values(b), \ - f"{atom.name}: B size {size(b)} != 64 * {_num_values(b)}" + assert size(a) == 64 * _num_values( + a + ), f"{atom.name}: A size {size(a)} != 64 * {_num_values(a)}" + assert size(b) == 64 * _num_values( + b + ), f"{atom.name}: B size {size(b)} != 64 * {_num_values(b)}" # ============================================================================= # Layout algebra tests (run without the calculator) # ============================================================================= + @pytest.mark.parametrize("atom", ALL_ATOMS, ids=lambda a: a.name) class TestLayoutAlgebra: """Test layout algebra operations on real AMD atom 
layouts.""" def test_size_rank_depth_mode(self, atom): """Exercise size(), rank(), depth(), mode() on all three layouts.""" - for layout_name, layout in [("C", atom.c_layout), ("A", atom.a_layout), ("B", atom.b_layout)]: + for layout_name, layout in [ + ("C", atom.c_layout), + ("A", atom.a_layout), + ("B", atom.b_layout), + ]: s = size(layout) r = rank(layout) d = depth(layout) @@ -534,8 +587,9 @@ def test_size_rank_depth_mode(self, atom): # mode(layout, 0) is the thread dimension thr_mode = mode(layout, 0) val_mode = mode(layout, 1) - assert size(thr_mode) * size(val_mode) == s, \ - f"{atom.name} {layout_name}: mode sizes don't multiply to total" + assert ( + size(thr_mode) * size(val_mode) == s + ), f"{atom.name} {layout_name}: mode sizes don't multiply to total" def test_flatten_preserves_mapping(self, atom): """flatten(c_layout) produces the same offsets for all flat indices.""" @@ -543,16 +597,18 @@ def test_flatten_preserves_mapping(self, atom): c_flat = flatten(c) # Flattened layout should produce same offsets when indexed linearly for i in range(size(c)): - assert c_flat(i) == c(i), \ - f"{atom.name}: flatten mismatch at {i}: {c_flat(i)} != {c(i)}" + assert c_flat(i) == c( + i + ), f"{atom.name}: flatten mismatch at {i}: {c_flat(i)} != {c(i)}" def test_coalesce_preserves_mapping(self, atom): """coalesce(c_layout) produces the same offsets.""" c = atom.c_layout c_coal = coalesce(c) for i in range(size(c)): - assert c_coal(i) == c(i), \ - f"{atom.name}: coalesce mismatch at {i}: {c_coal(i)} != {c(i)}" + assert c_coal(i) == c( + i + ), f"{atom.name}: coalesce mismatch at {i}: {c_coal(i)} != {c(i)}" def test_compose_with_identity(self, atom): """compose(L, identity) == L for all indices.""" @@ -560,8 +616,7 @@ def test_compose_with_identity(self, atom): identity = Layout(size(c)) # col-major identity composed = compose(c, identity) for i in range(size(c)): - assert composed(i) == c(i), \ - f"{atom.name}: compose(C, id) mismatch at {i}" + assert composed(i) == 
c(i), f"{atom.name}: compose(C, id) mismatch at {i}" def test_complement_c_layout(self, atom): """complement of flattened C layout produces valid ordered disjoint layout.""" @@ -571,13 +626,15 @@ def test_complement_c_layout(self, atom): comp = complement(c_flat) # complement must be ordered: comp(i-1) < comp(i) for i >= 1 for i in range(1, size(comp)): - assert comp(i - 1) < comp(i), \ - f"{atom.name}: complement not ordered at {i}: {comp(i-1)} >= {comp(i)}" + assert comp(i - 1) < comp( + i + ), f"{atom.name}: complement not ordered at {i}: {comp(i-1)} >= {comp(i)}" # complement must be disjoint from layout for i >= 1 c_offsets = {c_flat(j) for j in range(size(c_flat))} for i in range(1, size(comp)): - assert comp(i) not in c_offsets, \ - f"{atom.name}: complement({i})={comp(i)} overlaps with layout" + assert ( + comp(i) not in c_offsets + ), f"{atom.name}: complement({i})={comp(i)} overlaps with layout" def test_left_inverse_c_layout(self, atom): """left_inverse(C) composed with C gives identity for flat indices.""" @@ -588,8 +645,9 @@ def test_left_inverse_c_layout(self, atom): for i in range(size(c_flat)): offset = c_flat(i) roundtrip = linv(offset) - assert roundtrip == i, \ - f"{atom.name}: left_inverse roundtrip at {i}: {roundtrip} != {i}" + assert ( + roundtrip == i + ), f"{atom.name}: left_inverse roundtrip at {i}: {roundtrip} != {i}" def test_right_inverse_c_layout(self, atom): """C composed with right_inverse(C) gives identity for offsets in range.""" @@ -600,8 +658,9 @@ def test_right_inverse_c_layout(self, atom): for i in range(size(c_flat)): offset = c_flat(i) roundtrip = c_flat(rinv(offset)) - assert roundtrip == offset, \ - f"{atom.name}: right_inverse roundtrip at offset {offset}: {roundtrip} != {offset}" + assert ( + roundtrip == offset + ), f"{atom.name}: right_inverse roundtrip at offset {offset}: {roundtrip} != {offset}" def test_logical_divide_c_layout(self, atom): """logical_divide factors C layout into (tile, rest).""" @@ -614,8 +673,9 @@ 
def test_logical_divide_c_layout(self, atom): c_flat_thr = flatten(mode(c, 0)) divided = logical_divide(c_flat_thr, tiler) # The divided layout must cover the same total size - assert size(divided) == size(c_flat_thr), \ - f"{atom.name}: logical_divide changed size: {size(divided)} != {size(c_flat_thr)}" + assert size(divided) == size( + c_flat_thr + ), f"{atom.name}: logical_divide changed size: {size(divided)} != {size(c_flat_thr)}" def test_logical_product(self, atom): """logical_product replicates a pattern across positions.""" @@ -625,12 +685,14 @@ def test_logical_product(self, atom): replicator = Layout(2, size(c_flat)) product = logical_product(c_flat, replicator) # Size should be original * 2 - assert size(product) == size(c_flat) * 2, \ - f"{atom.name}: logical_product size {size(product)} != {size(c_flat) * 2}" + assert ( + size(product) == size(c_flat) * 2 + ), f"{atom.name}: logical_product size {size(product)} != {size(c_flat) * 2}" # First half should match original for i in range(size(c_flat)): - assert product(i) == c_flat(i), \ - f"{atom.name}: logical_product first-half mismatch at {i}" + assert product(i) == c_flat( + i + ), f"{atom.name}: logical_product first-half mismatch at {i}" def test_idx2crd_crd2idx_roundtrip(self, atom): """idx2crd and crd2idx are inverses on the thread dimension shape.""" @@ -639,8 +701,9 @@ def test_idx2crd_crd2idx_roundtrip(self, atom): for i in range(size(thr_shape)): crd = idx2crd(i, thr_shape) idx = crd2idx(crd, thr_shape) - assert idx == i, \ - f"{atom.name}: idx2crd/crd2idx roundtrip at {i}: {idx} != {i}" + assert ( + idx == i + ), f"{atom.name}: idx2crd/crd2idx roundtrip at {i}: {idx} != {i}" def test_idx2crd_crd2idx_roundtrip_val(self, atom): """idx2crd/crd2idx roundtrip on value dimension.""" @@ -649,8 +712,9 @@ def test_idx2crd_crd2idx_roundtrip_val(self, atom): for i in range(size(val_shape)): crd = idx2crd(i, val_shape) idx = crd2idx(crd, val_shape) - assert idx == i, \ - f"{atom.name}: val idx2crd/crd2idx 
roundtrip at {i}: {idx} != {i}" + assert ( + idx == i + ), f"{atom.name}: val idx2crd/crd2idx roundtrip at {i}: {idx} != {i}" def test_flatten_is_idempotent(self, atom): """flatten(flatten(L)) == flatten(L).""" @@ -658,8 +722,7 @@ def test_flatten_is_idempotent(self, atom): once = flatten(c) twice = flatten(once) for i in range(size(c)): - assert once(i) == twice(i), \ - f"{atom.name}: flatten not idempotent at {i}" + assert once(i) == twice(i), f"{atom.name}: flatten not idempotent at {i}" def test_coalesce_is_idempotent(self, atom): """coalesce(coalesce(L)) == coalesce(L).""" @@ -667,16 +730,14 @@ def test_coalesce_is_idempotent(self, atom): once = coalesce(c) twice = coalesce(once) for i in range(size(c)): - assert once(i) == twice(i), \ - f"{atom.name}: coalesce not idempotent at {i}" + assert once(i) == twice(i), f"{atom.name}: coalesce not idempotent at {i}" def test_flatten_then_coalesce(self, atom): """flatten then coalesce produces same mapping.""" c = atom.c_layout fc = coalesce(flatten(c)) for i in range(size(c)): - assert fc(i) == c(i), \ - f"{atom.name}: flatten+coalesce mismatch at {i}" + assert fc(i) == c(i), f"{atom.name}: flatten+coalesce mismatch at {i}" def test_compose_chain(self, atom): """compose(compose(L, A), B) == compose(L, compose(A, B)) (associativity).""" @@ -689,8 +750,9 @@ def test_compose_chain(self, atom): lhs = compose(compose(c_flat, a), b) rhs = compose(c_flat, compose(a, b)) for i in range(size(b)): - assert lhs(i) == rhs(i), \ - f"{atom.name}: compose associativity failed at {i}: {lhs(i)} != {rhs(i)}" + assert lhs(i) == rhs( + i + ), f"{atom.name}: compose associativity failed at {i}: {lhs(i)} != {rhs(i)}" def test_make_ordered_layout_flat_c_shape(self, atom): """make_ordered_layout on flattened C shape produces ordered strides.""" @@ -698,12 +760,14 @@ def test_make_ordered_layout_flat_c_shape(self, atom): c_flat = flatten(c) ordered = make_ordered_layout(c_flat.shape) # Same size - assert size(ordered) == size(c), \ - 
f"{atom.name}: make_ordered_layout changed size" + assert size(ordered) == size( + c + ), f"{atom.name}: make_ordered_layout changed size" # Ordered: strides should be increasing (column-major order) for i in range(1, size(ordered)): - assert ordered(i) > ordered(i - 1), \ - f"{atom.name}: make_ordered_layout not ordered at {i}" + assert ordered(i) > ordered( + i - 1 + ), f"{atom.name}: make_ordered_layout not ordered at {i}" # ============================================================================= @@ -730,16 +794,18 @@ def test_compute_tv_mapping_c(self, atom): """_compute_tv_mapping on c_layout covers every cell of the M x N grid.""" m, n, k = atom.shape_mnk c = atom.c_layout - tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, - col_major=True) + tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, col_major=True) # Every (row, col) in [0,M) x [0,N) should have an entry for row in range(m): for col in range(n): - assert (row, col) in tv_map, \ - f"{atom.name}: C tv_map missing ({row},{col})" + assert ( + row, + col, + ) in tv_map, f"{atom.name}: C tv_map missing ({row},{col})" phys_t, v_idx, logical_t = tv_map[(row, col)] - assert 0 <= phys_t < 64, \ - f"{atom.name}: C invalid thread {phys_t} at ({row},{col})" + assert ( + 0 <= phys_t < 64 + ), f"{atom.name}: C invalid thread {phys_t} at ({row},{col})" def test_compute_tv_mapping_a(self, atom): """_compute_tv_mapping on a_layout produces valid entries.""" @@ -756,8 +822,9 @@ def test_compute_tv_mapping_a(self, atom): for t in range(64): for v in range(num_v): offset = a(t, v) - assert 0 <= offset < a_cosize, \ - f"{atom.name}: A offset {offset} out of range [0, {a_cosize})" + assert ( + 0 <= offset < a_cosize + ), f"{atom.name}: A offset {offset} out of range [0, {a_cosize})" def test_compute_tv_mapping_b(self, atom): """_compute_tv_mapping on b_layout produces valid entries.""" @@ -768,15 +835,15 @@ def test_compute_tv_mapping_b(self, atom): for t in range(64): for v in range(num_v): offset = 
b(t, v) - assert 0 <= offset < b_cosize, \ - f"{atom.name}: B offset {offset} out of range [0, {b_cosize})" + assert ( + 0 <= offset < b_cosize + ), f"{atom.name}: B offset {offset} out of range [0, {b_cosize})" def test_compute_tv_mapping_c_threads_match(self, atom): """Thread IDs from tv_mapping match direct layout evaluation.""" m, n, k = atom.shape_mnk c = atom.c_layout - tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, - col_major=True) + tv_map = _compute_tv_mapping(c, grid_cols=n, grid_rows=m, col_major=True) # Rebuild the forward map and compare num_v = _num_values(c) for t in range(64): @@ -784,41 +851,57 @@ def test_compute_tv_mapping_c_threads_match(self, atom): offset = c(t, v) row = offset % m col = offset // m - assert (row, col) in tv_map, \ - f"{atom.name}: ({row},{col}) missing from tv_map" + assert ( + row, + col, + ) in tv_map, f"{atom.name}: ({row},{col}) missing from tv_map" phys_t, v_idx, logical_t = tv_map[(row, col)] - assert phys_t == t, \ - f"{atom.name}: thread mismatch at ({row},{col}): {phys_t} != {t}" - assert v_idx == v, \ - f"{atom.name}: value mismatch at ({row},{col}): {v_idx} != {v}" + assert ( + phys_t == t + ), f"{atom.name}: thread mismatch at ({row},{col}): {phys_t} != {t}" + assert ( + v_idx == v + ), f"{atom.name}: value mismatch at ({row},{col}): {v_idx} != {v}" def test_draw_tv_layout_smoke(self, atom): """draw_tv_layout runs without error (output to tempfile).""" m, n, k = atom.shape_mnk c = atom.c_layout with tempfile.NamedTemporaryFile(suffix=".png") as f: - draw_tv_layout(c, filename=f.name, - grid_shape=(m, n), colorize=True) + draw_tv_layout(c, filename=f.name, grid_shape=(m, n), colorize=True) def test_draw_mma_layout_smoke(self, atom): """draw_mma_layout runs without error.""" m, n, k = atom.shape_mnk with tempfile.NamedTemporaryFile(suffix=".png") as f: if atom.name == "CDNA_4x4x4_F32F16F16_MFMA": - with pytest.raises(ValueError, match=r"A .*panel shape .*out of bounds"): - draw_mma_layout(atom.a_layout, 
atom.b_layout, atom.c_layout, - filename=f.name, tile_mnk=(m, n, k), - main_title=atom.name) + with pytest.raises( + ValueError, match=r"A .*panel shape .*out of bounds" + ): + draw_mma_layout( + atom.a_layout, + atom.b_layout, + atom.c_layout, + filename=f.name, + tile_mnk=(m, n, k), + main_title=atom.name, + ) else: - draw_mma_layout(atom.a_layout, atom.b_layout, atom.c_layout, - filename=f.name, tile_mnk=(m, n, k), - main_title=atom.name) + draw_mma_layout( + atom.a_layout, + atom.b_layout, + atom.c_layout, + filename=f.name, + tile_mnk=(m, n, k), + main_title=atom.name, + ) # ============================================================================= # Layout utils tests # ============================================================================= + @pytest.mark.parametrize("atom", ALL_ATOMS, ids=lambda a: a.name) class TestLayoutUtils: """Test layout_utils functions on AMD atom layouts.""" @@ -842,5 +925,6 @@ def test_tile_to_shape_c(self, atom): # Tile to 2x the original shape target = (size(c.shape[0]) * 2, size(c.shape[1])) tiled = tile_to_shape(c, target) - assert size(tiled) == size(c) * 2, \ - f"{atom.name}: tile_to_shape wrong size: {size(tiled)} != {size(c) * 2}" + assert ( + size(tiled) == size(c) * 2 + ), f"{atom.name}: tile_to_shape wrong size: {size(tiled)} != {size(c) * 2}" diff --git a/tests/oracle_nv.py b/tests/oracle_nv.py index 9f1e49e..02a2564 100644 --- a/tests/oracle_nv.py +++ b/tests/oracle_nv.py @@ -30,7 +30,7 @@ invariants over all valid layouts up to a size bound. """ -from tensor_layouts import * +from tensor_layouts import * # noqa: F401,F403,F405 from tensor_layouts.layout_utils import make_ordered_layout, tile_to_shape import pytest @@ -40,12 +40,15 @@ # We need NVIDIA's pycute from the CUTLASS source tree. 
try: import pycute - if not hasattr(pycute, 'Layout'): + + if not hasattr(pycute, "Layout"): pycute = None except ImportError: pycute = None -pytestmark = pytest.mark.skipif(pycute is None, reason="pycute (NVIDIA CUTLASS) not available") +pytestmark = pytest.mark.skipif( + pycute is None, reason="pycute (NVIDIA CUTLASS) not available" +) ############################################################################### @@ -115,33 +118,53 @@ def layouts_functionally_equal(our, ref, domain_size): # Standard test layouts: (shape, stride) pairs covering many patterns LAYOUT_CORPUS = [ # Trivial - (1, 0), (1, 1), + (1, 0), + (1, 1), # 1D - (4, 1), (4, 2), (8, 1), (8, 2), (12, 1), (12, 3), + (4, 1), + (4, 2), + (8, 1), + (8, 2), + (12, 1), + (12, 3), # Zero stride (broadcast) - (4, 0), (8, 0), + (4, 0), + (8, 0), # 2D col-major - ((2, 4), (1, 2)), ((4, 3), (1, 4)), ((8, 4), (1, 8)), + ((2, 4), (1, 2)), + ((4, 3), (1, 4)), + ((8, 4), (1, 8)), # 2D row-major - ((2, 4), (4, 1)), ((4, 3), (3, 1)), ((8, 4), (4, 1)), + ((2, 4), (4, 1)), + ((4, 3), (3, 1)), + ((8, 4), (4, 1)), # 2D with gaps - ((2, 4), (1, 4)), ((2, 4), (1, 6)), ((4, 2), (1, 10)), ((4, 2), (1, 16)), + ((2, 4), (1, 4)), + ((2, 4), (1, 6)), + ((4, 2), (1, 10)), + ((4, 2), (1, 16)), # 2D with broadcast - ((2, 4), (0, 2)), ((4, 2), (2, 0)), + ((2, 4), (0, 2)), + ((4, 2), (2, 0)), # 3D - ((2, 4, 6), (1, 2, 8)), ((2, 4, 6), (4, 1, 8)), - ((2, 3, 4), (1, 2, 6)), ((2, 4, 8), (8, 1, 64)), + ((2, 4, 6), (1, 2, 8)), + ((2, 4, 6), (4, 1, 8)), + ((2, 3, 4), (1, 2, 6)), + ((2, 4, 8), (8, 1, 64)), ((2, 4, 6), (24, 6, 1)), # 3D with broadcast - ((2, 4, 8), (8, 1, 0)), ((2, 4, 3), (1, 2, 0)), + ((2, 4, 8), (8, 1, 0)), + ((2, 4, 3), (1, 2, 0)), # Nested (hierarchical) (((2, 2), (2, 2)), ((1, 4), (8, 32))), ((2, (3, 4)), (3, (1, 6))), (((4, 2),), ((1, 16),)), # Auto-stride (col-major) - ((2, 4), None), ((4, 3), None), ((2, 4, 6), None), ((2, 3, 4), None), + ((2, 4), None), + ((4, 3), None), + ((2, 4, 6), None), + ((2, 3, 4), None), 
((8, 8), None), - # ===== FROM C++ TESTS ===== # C++ inverse / complement tests: broadcast shapes (((3, 7),), ((0, 0),)), @@ -164,7 +187,6 @@ def layouts_functionally_equal(our, ref, domain_size): ((4, 10), (1, 10)), # C++ composition tests: transposed strides ((4, 3), (3, 1)), - # ===== EDGE CASES ===== # All-zero strides (pure broadcast) ((2, 3, 4), (0, 0, 0)), @@ -307,11 +329,22 @@ def test_oracle_composition(): # Composition pairs: (A_shape, A_stride, B_shape, B_stride) composition_pairs = [ # Simple - (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 1, 0), (1, 1, 1, 1), - (4, 1, 4, 1), (4, 2, 4, 1), (4, 1, 4, 2), (4, 0, 4, 1), - (4, 1, 4, 0), (1, 0, 4, 1), (4, 1, 1, 0), + (1, 0, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 0), + (1, 1, 1, 1), + (4, 1, 4, 1), + (4, 2, 4, 1), + (4, 1, 4, 2), + (4, 0, 4, 1), + (4, 1, 4, 0), + (1, 0, 4, 1), + (4, 1, 1, 0), # Partial - (4, 1, 2, 1), (4, 2, 2, 1), (4, 1, 2, 2), (4, 2, 2, 2), + (4, 1, 2, 1), + (4, 2, 2, 1), + (4, 1, 2, 2), + (4, 2, 2, 2), # Multi-dim A, 1D B ((4, 3), (1, 4), 12, 1), ((4, 3), (1, 4), 6, 1), @@ -366,16 +399,32 @@ def test_oracle_shape_div(): """Cross-validate shape_div() against pycute.""" test_cases = [ # (shape, divisor) -- only cases valid for pycute (a%b==0 or b%a==0 at each level) - ((3, 4), 1), ((3, 4), 3), ((3, 4), 6), - ((3, 4), 12), ((3, 4), 36), - ((4, 3), 2), ((4, 3), 4), ((4, 3), 12), - ((6, 2), 2), ((6, 2), 3), ((6, 2), 6), ((6, 2), 12), + ((3, 4), 1), + ((3, 4), 3), + ((3, 4), 6), + ((3, 4), 12), + ((3, 4), 36), + ((4, 3), 2), + ((4, 3), 4), + ((4, 3), 12), + ((6, 2), 2), + ((6, 2), 3), + ((6, 2), 6), + ((6, 2), 12), # Nested - (((3, 4), 6), 1), (((3, 4), 6), 3), (((3, 4), 6), 12), - (((3, 4), 6), 36), (((3, 4), 6), 72), - ((6, (3, 4)), 6), ((6, (3, 4)), 36), + (((3, 4), 6), 1), + (((3, 4), 6), 3), + (((3, 4), 6), 12), + (((3, 4), 6), 36), + (((3, 4), 6), 72), + ((6, (3, 4)), 6), + ((6, (3, 4)), 36), # Scalars - (12, 1), (12, 3), (12, 4), (12, 6), (12, 12), + (12, 1), + (12, 3), + (12, 4), + (12, 6), + (12, 12), 
] for shape, divisor in test_cases: @@ -391,10 +440,9 @@ def test_oracle_shape_div(): assert shapes_equal( ours_r if isinstance(ours_r, int) else ours_r, - ref_r if isinstance(ref_r, int) else ref_r + ref_r if isinstance(ref_r, int) else ref_r, ), ( - f"shape_div({shape}, {divisor}): " - f"ours={ours_r} vs pycute={ref_r}" + f"shape_div({shape}, {divisor}): " f"ours={ours_r} vs pycute={ref_r}" ) @@ -417,10 +465,8 @@ def test_oracle_prefix_product(): assert shapes_equal( ours_r if isinstance(ours_r, int) else ours_r, - ref_r if isinstance(ref_r, int) else ref_r - ), ( - f"prefix_product({shape}): ours={ours_r} vs pycute={ref_r}" - ) + ref_r if isinstance(ref_r, int) else ref_r, + ), f"prefix_product({shape}): ours={ours_r} vs pycute={ref_r}" def test_oracle_inner_product(): @@ -440,9 +486,9 @@ def test_oracle_inner_product(): ours_r = inner_product(ours_a, ours_b) ref_r = pycute.inner_product(ref_a, ref_b) - assert ours_r == ref_r, ( - f"inner_product({a}, {b}): ours={ours_r} vs pycute={ref_r}" - ) + assert ( + ours_r == ref_r + ), f"inner_product({a}, {b}): ours={ours_r} vs pycute={ref_r}" def test_oracle_right_inverse(): @@ -490,12 +536,22 @@ def test_oracle_logical_divide(): """Cross-validate logical_divide() with Layout tilers against pycute.""" # (layout_shape, layout_stride, tiler_shape, tiler_stride) divide_cases = [ - (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 1, 0), (1, 1, 1, 1), - (6, 1, 2, 1), (6, 1, 2, 3), (6, 2, 2, 1), (6, 2, 2, 3), - (6, 1, (2, 3), (3, 1)), (6, 2, (2, 3), (3, 1)), + (1, 0, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 0), + (1, 1, 1, 1), + (6, 1, 2, 1), + (6, 1, 2, 3), + (6, 2, 2, 1), + (6, 2, 2, 3), + (6, 1, (2, 3), (3, 1)), + (6, 2, (2, 3), (3, 1)), (32, 1, 2, 8), - (12, 1, 4, 1), (12, 1, 6, 1), (12, 2, 4, 1), - (48, 1, 32, 1), (96, 1, 32, 2), + (12, 1, 4, 1), + (12, 1, 6, 1), + (12, 2, 4, 1), + (48, 1, 32, 1), + (96, 1, 32, 2), ] for item in divide_cases: @@ -521,16 +577,26 @@ def test_oracle_logical_product(): """Cross-validate logical_product() 
against pycute.""" product_cases = [ # (A_shape, A_stride, B_shape, B_stride) - (4, 1, 3, 1), (4, 2, 3, 1), (4, 1, 3, 2), + (4, 1, 3, 1), + (4, 2, 3, 1), + (4, 1, 3, 2), ((2, 4), (1, 2), 3, 1), - (8, 1, 4, 1), (8, 2, 4, 1), + (8, 1, 4, 1), + (8, 2, 4, 1), # === C++ test cases === # Trivial - (1, 0, 1, 0), (1, 1, 1, 0), (1, 0, 1, 1), (1, 1, 1, 1), + (1, 0, 1, 0), + (1, 1, 1, 0), + (1, 0, 1, 1), + (1, 1, 1, 1), # Broadcast - (3, 1, 4, 0), (3, 0, 4, 1), (3, 0, 4, 0), (3, 2, 4, 1), + (3, 1, 4, 0), + (3, 0, 4, 1), + (3, 0, 4, 0), + (3, 2, 4, 1), # 1D - (3, 1, (2, 4), None), ((2, 4), None, 3, 1), + (3, 1, (2, 4), None), + ((2, 4), None, 3, 1), # Hierarchical ((8, (2, 2)), None, 4, 2), ((2, 2), None, (3, 3), (3, 1)), @@ -647,22 +713,20 @@ def test_exhaustive_coalesce_preserves_mapping(): """Verify coalesce preserves the index mapping for all small layouts.""" for layout in _generate_small_layouts(): coal = coalesce(layout) - assert size(coal) == size(layout), ( - f"coalesce({layout}): size changed from {size(layout)} to {size(coal)}" - ) + assert size(coal) == size( + layout + ), f"coalesce({layout}): size changed from {size(layout)} to {size(coal)}" for i in range(size(layout)): - assert coal(i) == layout(i), ( - f"coalesce({layout})({i}) = {coal(i)} != {layout(i)}" - ) + assert coal(i) == layout( + i + ), f"coalesce({layout})({i}) = {coal(i)} != {layout(i)}" def test_exhaustive_coalesce_reduces_depth(): """Verify coalesce produces depth <= 1.""" for layout in _generate_small_layouts(): coal = coalesce(layout) - assert depth(coal) <= 1, ( - f"coalesce({layout}) has depth {depth(coal)} > 1" - ) + assert depth(coal) <= 1, f"coalesce({layout}) has depth {depth(coal)} > 1" def test_exhaustive_right_inverse_identity(): @@ -670,9 +734,9 @@ def test_exhaustive_right_inverse_identity(): for layout in _generate_small_layouts(): rinv = right_inverse(layout) for i in range(size(rinv)): - assert layout(rinv(i)) == i, ( - f"right_inverse({layout}): L(R({i}))={layout(rinv(i))} != {i}" 
- ) + assert ( + layout(rinv(i)) == i + ), f"right_inverse({layout}): L(R({i}))={layout(rinv(i))} != {i}" def _is_injective(layout): @@ -711,9 +775,9 @@ def test_exhaustive_left_inverse_identity(): continue linv = left_inverse(layout) for i in range(size(layout)): - assert linv(layout(i)) == i, ( - f"left_inverse({layout}): R(L({i}))={linv(layout(i))} != {i}" - ) + assert ( + linv(layout(i)) == i + ), f"left_inverse({layout}): R(L({i}))={linv(layout(i))} != {i}" def test_exhaustive_compose_identity(): @@ -745,9 +809,9 @@ def test_exhaustive_compose_identity(): n = min(size(r), pycute.size(ref_r)) for i in range(n): - assert r(i) == ref_r(i), ( - f"compose({a}, {b})({i}) = {r(i)} != pycute={ref_r(i)}" - ) + assert r(i) == ref_r( + i + ), f"compose({a}, {b})({i}) = {r(i)} != pycute={ref_r(i)}" tested += 1 assert tested > 100, f"Only tested {tested} composition pairs, expected more" @@ -761,8 +825,13 @@ def test_exhaustive_shape_div_mod_complementary(): so not all divisors of size(s) are valid. We try each and skip failures. 
""" shapes = [ - (2, 3), (3, 4), (2, 2, 3), (4, 3), - (6, 2), (2, 6), (3, 2, 4), + (2, 3), + (3, 4), + (2, 2, 3), + (4, 3), + (6, 2), + (2, 6), + (3, 2, 4), ] tested = 0 @@ -824,17 +893,13 @@ def test_exhaustive_inverse_roundtrip(): # right_inverse property: L(R(i)) == i (works for all layouts) for i in range(size(rinv)): - assert layout(rinv(i)) == i, ( - f"right_inverse({layout}): L(R({i})) != {i}" - ) + assert layout(rinv(i)) == i, f"right_inverse({layout}): L(R({i})) != {i}" # left_inverse property: R(L(i)) == i (only for injective + contiguous layouts) if _is_injective(layout) and _is_contiguous(layout): linv = left_inverse(layout) for i in range(size(layout)): - assert linv(layout(i)) == i, ( - f"left_inverse({layout}): R(L({i})) != {i}" - ) + assert linv(layout(i)) == i, f"left_inverse({layout}): R(L({i})) != {i}" ############################################################################### @@ -861,8 +926,8 @@ def test_oracle_tuple_max(): def test_oracle_elem_scale(): """Cross-validate elem_scale against pycute.""" cases = [ - (3, 4), # int x int - (2, (3, 4)), # int x tuple + (3, 4), # int x int + (2, (3, 4)), # int x tuple ((2, 3), (4, 5)), # tuple x tuple (1, (2, 3, 4)), # int x tuple ] @@ -968,9 +1033,9 @@ def test_oracle_zipped_product(): ours_result = zipped_product(ours_a, ours_b) ref_result = pycute.zipped_product(ref_a, ref_b) - assert size(ours_result) == pycute.size(ref_result), ( - f"zipped_product size mismatch: {size(ours_result)} != {pycute.size(ref_result)}" - ) + assert size(ours_result) == pycute.size( + ref_result + ), f"zipped_product size mismatch: {size(ours_result)} != {pycute.size(ref_result)}" for i in range(size(ours_result)): assert ours_result(i) == ref_result(i), ( f"zipped_product({a_s}:{a_d}, {b_s}:{b_d})({i}): " @@ -993,9 +1058,9 @@ def test_oracle_tiled_product(): ours_result = tiled_product(ours_a, ours_b) ref_result = pycute.tiled_product(ref_a, ref_b) - assert size(ours_result) == pycute.size(ref_result), ( - 
f"tiled_product size mismatch: {size(ours_result)} != {pycute.size(ref_result)}" - ) + assert size(ours_result) == pycute.size( + ref_result + ), f"tiled_product size mismatch: {size(ours_result)} != {pycute.size(ref_result)}" for i in range(size(ours_result)): assert ours_result(i) == ref_result(i), ( f"tiled_product({a_s}:{a_d}, {b_s}:{b_d})({i}): " @@ -1021,14 +1086,14 @@ def test_oracle_layout_slice(): # Both should return sublayouts with same shape/stride ours_shape = to_pycute_shape(ours_sub.shape) ref_shape_val = ref_sub.shape - assert ours_shape == ref_shape_val, ( - f"Layout({shape},{stride})({crd}) shape: {ours_shape} != {ref_shape_val}" - ) + assert ( + ours_shape == ref_shape_val + ), f"Layout({shape},{stride})({crd}) shape: {ours_shape} != {ref_shape_val}" ours_stride = to_pycute_shape(ours_sub.stride) ref_stride_val = ref_sub.stride - assert ours_stride == ref_stride_val, ( - f"Layout({shape},{stride})({crd}) stride: {ours_stride} != {ref_stride_val}" - ) + assert ( + ours_stride == ref_stride_val + ), f"Layout({shape},{stride})({crd}) stride: {ours_stride} != {ref_stride_val}" ############################################################################### @@ -1303,10 +1368,16 @@ def test_exhaustive_filter_idempotent(): def test_exhaustive_blocked_product_size(): """Verify blocked_product(A, B) has size(A) * size(B).""" layouts_1d = [ - Layout(2, 1), Layout(3, 1), Layout(4, 1), Layout(2, 2), Layout(4, 2), + Layout(2, 1), + Layout(3, 1), + Layout(4, 1), + Layout(2, 2), + Layout(4, 2), ] layouts_2d = [ - Layout((2, 2), (1, 2)), Layout((2, 3), (3, 1)), Layout((3, 2)), + Layout((2, 2), (1, 2)), + Layout((2, 3), (3, 1)), + Layout((3, 2)), ] all_layouts = layouts_1d + layouts_2d @@ -1315,8 +1386,7 @@ def test_exhaustive_blocked_product_size(): result = blocked_product(a, b) expected_size = size(a) * size(b) assert size(result) == expected_size, ( - f"blocked_product({a}, {b}): " - f"size={size(result)} != {expected_size}" + f"blocked_product({a}, {b}): " 
f"size={size(result)} != {expected_size}" ) @@ -1327,10 +1397,14 @@ def test_exhaustive_blocked_product_covers_offsets(): blocks: B's offsets are shifted by cosize(A) * i for each copy. """ compact_layouts = [ - Layout(2, 1), Layout(4, 1), Layout((2, 2)), + Layout(2, 1), + Layout(4, 1), + Layout((2, 2)), ] tiling_layouts = [ - Layout(2, 1), Layout(3, 1), Layout((2, 2)), + Layout(2, 1), + Layout(3, 1), + Layout((2, 2)), ] for a in compact_layouts: @@ -1373,7 +1447,11 @@ def test_exhaustive_flat_divide_preserves_mapping(): # For multi-mode layouts, only test tilers that divide evenly # within the first mode to avoid cross-mode reordering issues if r > 1: - first_mode_size = layout.shape[0] if isinstance(layout.shape[0], int) else size(Layout(layout.shape[0])) + first_mode_size = ( + layout.shape[0] + if isinstance(layout.shape[0], int) + else size(Layout(layout.shape[0])) + ) if t > first_mode_size or first_mode_size % t != 0: continue try: @@ -1383,8 +1461,7 @@ def test_exhaustive_flat_divide_preserves_mapping(): for i in range(s): assert result(i) == layout(i), ( - f"flat_divide({layout}, {t})({i}) = " - f"{result(i)} != {layout(i)}" + f"flat_divide({layout}, {t})({i}) = " f"{result(i)} != {layout(i)}" ) tested += 1 @@ -1435,9 +1512,9 @@ def test_tile_to_shape_size_preserved(): for block, target in test_cases: result = tile_to_shape(block, target) target_size = target if isinstance(target, int) else size(target) - assert size(result) == target_size, ( - f"tile_to_shape({block}, {target}): size={size(result)} != {target_size}" - ) + assert ( + size(result) == target_size + ), f"tile_to_shape({block}, {target}): size={size(result)} != {target_size}" def test_tile_to_shape_blocked_structure(): @@ -1492,7 +1569,6 @@ def test_product_each_matches_pycute_size(): assert result == expected, f"product_each({shape}) = {result} != {expected}" - @pytest.mark.skipif(pycute is None, reason="pycute not installed") def test_oracle_idx2crd(): shapes = [ @@ -1509,7 +1585,10 @@ def 
test_oracle_idx2crd(): if __name__ == "__main__": import traceback - test_funcs = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)] + + test_funcs = [ + v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v) + ] passed = 0 failed = 0 errors = [] diff --git a/tests/tensor.py b/tests/tensor.py index 79eae7f..cb4e5e9 100644 --- a/tests/tensor.py +++ b/tests/tensor.py @@ -37,8 +37,18 @@ import pytest from tensor_layouts import ( - Layout, Swizzle, compose, complement, logical_divide, logical_product, - rank, size, cosize, mode, flatten, coalesce + Layout, + Swizzle, + compose, + complement, + logical_divide, + logical_product, + rank, + size, + cosize, + mode, + flatten, + coalesce, ) from tensor_layouts import Tensor @@ -119,7 +129,7 @@ def test_with_offset(self): assert tensor(0, 0) == 100 assert tensor(1, 0) == 101 assert tensor(0, 1) == 104 - assert tensor(3, 7) == 100 + 3 + 7*4 + assert tensor(3, 7) == 100 + 3 + 7 * 4 def test_hierarchical_shape(self): """Nested/hierarchical shape - key CuTe feature.""" @@ -381,11 +391,11 @@ def test_sequential_slices(self): # Second slice s2 = s1[3, :] - assert s2.offset == 2 + 3*4 # = 14 + assert s2.offset == 2 + 3 * 4 # = 14 # Third slice s3 = s2[1] - assert s3 == 2 + 3*4 + 1*32 # = 46 + assert s3 == 2 + 3 * 4 + 1 * 32 # = 46 def test_slice_with_initial_offset(self): """Slicing tensor with non-zero offset accumulates correctly.""" @@ -970,7 +980,9 @@ def test_cute_mma_fragment_pattern(self): assert isinstance(thread1_slice, Tensor) # Verify they access different starting positions - assert thread0_slice.offset != thread1_slice.offset or thread0_slice(0) != thread1_slice(0) + assert thread0_slice.offset != thread1_slice.offset or thread0_slice( + 0 + ) != thread1_slice(0) if __name__ == "__main__": diff --git a/tests/viz.py b/tests/viz.py index e34ab7c..4d85167 100644 --- a/tests/viz.py +++ b/tests/viz.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 
THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +# ruff: noqa: F405 + from collections import defaultdict import tempfile @@ -28,7 +30,9 @@ from tensor_layouts import Layout, Swizzle from tensor_layouts.tensor import Tensor from tensor_layouts.layouts import ( - mode, rank, flat_divide, tiled_divide, flat_product, + mode, + rank, + flat_divide, ) from tensor_layouts.atoms_amd import ( CDNA3P_16x16x32_F32F16F16_MFMA, @@ -48,7 +52,7 @@ from matplotlib.transforms import Bbox import tensor_layouts.viz as viz_mod from tensor_layouts.viz import ( - _build_swizzle_figure, + _build_swizzle_figure, # noqa: F401 (used via monkeypatch.setattr) _compute_tv_mapping, _draw_hierarchical_grid, _format_hierarchical_cell_lines, @@ -82,14 +86,14 @@ show_tiled_grid, show_tv_layout, ) + HAS_VIZ = True except ImportError: HAS_VIZ = False requires_viz = pytest.mark.skipif( - not HAS_VIZ, - reason="tensor_layouts.viz not available (needs matplotlib)" + not HAS_VIZ, reason="tensor_layouts.viz not available (needs matplotlib)" ) @@ -100,12 +104,17 @@ def _call_draw_hierarchical_grid(ax, layout, **kwargs): row_shape = mode(layout.shape, 0) col_shape = mode(layout.shape, 1) return _draw_hierarchical_grid( - ax, indices, rows, cols, + ax, + indices, + rows, + cols, cell_coords=cell_coords, - row_shape=row_shape, col_shape=col_shape, + row_shape=row_shape, + col_shape=col_shape, **kwargs, ) + MIXED_VIZ_ATOMS = [ # Representative cross-section for visualization smoke tests: # - NVIDIA: Ampere (SM80), Hopper-era scalar/legacy-style atom (SM90), @@ -171,11 +180,13 @@ def test_show_layout_tensor_zero_offset(): fig_layout = show_layout(layout) fig_tensor = show_layout(tensor) try: + def _cell_values(fig): ax = fig.axes[0] return sorted( [(t.get_position(), t.get_text()) for t in ax.texts if t.get_text().isdigit()], ) + assert _cell_values(fig_layout) == _cell_values(fig_tensor) finally: plt.close(fig_layout) @@ -206,8 +217,7 @@ def test_color_by_row_matches_color_layout(): """color_by='row' produces 
the same color indices as the manual color_layout.""" layout = Layout((4, 8), (8, 1)) fig_by = show_layout(layout, color_by="row") - fig_manual = show_layout(layout, color_layout=Layout((4, 8), (1, 0)), - colorize=True) + fig_manual = show_layout(layout, color_layout=Layout((4, 8), (1, 0)), colorize=True) try: # Both should have the same cell background colors patches_by = [p for p in fig_by.axes[0].patches] @@ -234,9 +244,7 @@ def test_color_by_column(): def test_color_by_and_color_layout_exclusive(): """Providing both color_by and color_layout raises ValueError.""" with pytest.raises(ValueError, match="mutually exclusive"): - show_layout(Layout((4, 4), (4, 1)), - color_by="row", - color_layout=Layout((4, 4), (1, 0))) + show_layout(Layout((4, 4), (4, 1)), color_by="row", color_layout=Layout((4, 4), (1, 0))) @requires_viz @@ -261,6 +269,7 @@ def test_rank3_panel_values_match_layout(): divided = flat_divide(matrix, Layout(2, 1)) fig = show_layout(divided) try: + def _cell_val(ax, x, y): for t in ax.texts: tx = round(t.get_position()[0], 1) @@ -390,7 +399,7 @@ def test_draw_composite_mixed_tv_and_offset(): atom = SM80_16x8x16_F16F16F16F16_TN panels = [ Layout((4, 4), (4, 1)), # offset grid (default) - (atom.c_layout, {'tv_mode': True}), # TV grid + (atom.c_layout, {"tv_mode": True}), # TV grid ] fig = show_composite(panels, titles=["Offset", "TV"]) try: @@ -406,7 +415,7 @@ def test_draw_composite_hierarchical_panel(): hier = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) flat = Layout((4, 4), (4, 1)) panels = [ - (hier, {'flatten_hierarchical': False}), + (hier, {"flatten_hierarchical": False}), flat, ] fig = show_composite(panels, titles=["Hierarchical", "Flat"]) @@ -448,15 +457,13 @@ def test_draw_copy_layout_smoke(): src = Layout((4, 2), (2, 1)) dst = Layout((4, 2), (1, 4)) with tempfile.NamedTemporaryFile(suffix=".png") as f: - draw_copy_layout(src, dst, filename=f.name, - title="copy smoke", colorize=True) + draw_copy_layout(src, dst, filename=f.name, title="copy 
smoke", colorize=True) @requires_viz def test_draw_copy_layout_rejects_rank1(): with pytest.raises(ValueError, match="rank 2"): - draw_copy_layout(Layout(8, 1), Layout((4, 2), (2, 1)), - filename="ignored.png") + draw_copy_layout(Layout(8, 1), Layout((4, 2), (2, 1)), filename="ignored.png") @requires_viz @@ -474,6 +481,7 @@ def test_show_copy_layout_returns_figure(): def test_draw_copy_atom_smoke(): """draw_copy_atom handles the upcast from bit coordinates automatically.""" from tensor_layouts.atoms_nv import SM75_U32x1_LDSM_N + with tempfile.NamedTemporaryFile(suffix=".png") as f: draw_copy_atom(SM75_U32x1_LDSM_N, element_bits=16, filename=f.name) @@ -482,6 +490,7 @@ def test_draw_copy_atom_smoke(): def test_show_copy_atom_returns_figure(): """show_copy_atom returns a Figure for Jupyter display.""" from tensor_layouts.atoms_nv import SM90_U32x4_STSM_N + fig = show_copy_atom(SM90_U32x4_STSM_N, element_bits=16) try: assert isinstance(fig, matplotlib.figure.Figure) @@ -501,10 +510,16 @@ def test_show_tv_layout_returns_figure(): @requires_viz def test_show_mma_layout_returns_figure(): from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + atom = SM80_16x8x16_F16F16F16F16_TN - fig = show_mma_layout(atom.a_layout, atom.b_layout, atom.c_layout, - tile_mnk=atom.shape_mnk, colorize=True, - thr_id_layout=atom.thr_id) + fig = show_mma_layout( + atom.a_layout, + atom.b_layout, + atom.c_layout, + tile_mnk=atom.shape_mnk, + colorize=True, + thr_id_layout=atom.thr_id, + ) try: assert isinstance(fig, matplotlib.figure.Figure) finally: @@ -514,6 +529,7 @@ def test_show_mma_layout_returns_figure(): @requires_viz def test_show_tiled_grid_returns_figure(): from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN + atom = SM80_16x8x16_F16F16F16F16_TN atom_layout = Layout((2, 2), (1, 2)) grid, tile_shape = tile_mma_grid(atom, atom_layout, matrix="C") @@ -773,7 +789,7 @@ def _label_bboxes(ax): def _has_bbox_overlap(boxes): """Return True if any pair of bounding boxes 
overlaps.""" for i, (_, bbox_i) in enumerate(boxes): - for _, bbox_j in boxes[i + 1:]: + for _, bbox_j in boxes[i + 1 :]: if Bbox.overlaps(bbox_i, bbox_j): return True return False @@ -811,7 +827,9 @@ def _cell_patch_bboxes(ax): continue if patch.get_width() != 1.0 or patch.get_height() != 1.0: continue - boxes[(int(round(patch.get_y())), int(round(patch.get_x())))] = patch.get_window_extent(renderer=renderer) + boxes[(int(round(patch.get_y())), int(round(patch.get_x())))] = patch.get_window_extent( + renderer=renderer + ) return boxes @@ -852,8 +870,9 @@ def test_draw_hierarchical_grid_cecka_hier_col_margin_labels_do_not_overlap(): fig, ax = plt.subplots(figsize=(8 * 0.8 + 1, 4 * 0.8 + 1)) try: - _call_draw_hierarchical_grid(ax, layout, flatten_hierarchical=False, - label_hierarchy_levels=True) + _call_draw_hierarchical_grid( + ax, layout, flatten_hierarchical=False, label_hierarchy_levels=True + ) row_boxes, col_boxes = _label_bboxes(ax) assert row_boxes assert col_boxes @@ -869,8 +888,9 @@ def test_draw_hierarchical_grid_offset_values_clear_offset_equals_label(): fig, ax = plt.subplots(figsize=(8 * 0.8 + 1, 4 * 0.8 + 1)) try: - _call_draw_hierarchical_grid(ax, layout, flatten_hierarchical=False, - label_hierarchy_levels=True) + _call_draw_hierarchical_grid( + ax, layout, flatten_hierarchical=False, label_hierarchy_levels=True + ) pairs = _offset_label_value_bboxes(ax) assert pairs min_gap = min(value_bbox.x0 - label_bbox.x1 for label_bbox, value_bbox in pairs) @@ -923,8 +943,9 @@ def test_draw_hierarchical_grid_leaves_corner_gap_between_axis_label_bands(): fig, ax = plt.subplots(figsize=(12 * 0.8 + 1, 6 * 0.8 + 1)) try: - _call_draw_hierarchical_grid(ax, layout, flatten_hierarchical=False, - label_hierarchy_levels=True) + _call_draw_hierarchical_grid( + ax, layout, flatten_hierarchical=False, label_hierarchy_levels=True + ) row_boxes, col_boxes = _label_bboxes(ax) assert row_boxes assert col_boxes @@ -944,8 +965,7 @@ def 
test_draw_hierarchical_grid_leaves_corner_gap_between_axis_label_bands(): @requires_viz def test_draw_hierarchical_grid_draws_outer_perimeter_for_multiple_levels(): - layout = Layout(((3, 2, 2, 2), (4, 2, 2, 2)), - ((1, 3, 6, 12), (24, 96, 192, 384))) + layout = Layout(((3, 2, 2, 2), (4, 2, 2, 2)), ((1, 3, 6, 12), (24, 96, 192, 384))) fig, ax = plt.subplots() try: @@ -1029,14 +1049,10 @@ def fake_save(fig, filename, dpi=150): seen["line_count"] = len(ax.lines) red_edge = to_rgba(viz_mod.HIGHLIGHT_EDGE) seen["highlight_zorders"] = [ - patch.get_zorder() - for patch in ax.patches - if patch.get_edgecolor() == red_edge + patch.get_zorder() for patch in ax.patches if patch.get_edgecolor() == red_edge ] seen["base_zorders"] = [ - patch.get_zorder() - for patch in ax.patches - if patch.get_edgecolor() != red_edge + patch.get_zorder() for patch in ax.patches if patch.get_edgecolor() != red_edge ] plt.close(fig) @@ -1172,8 +1188,7 @@ def test_draw_combined_mma_grid_smoke(): M, N, K = M_a * 2, N_a * 2, K_a with tempfile.NamedTemporaryFile(suffix=".png") as f: - draw_combined_mma_grid(a_grid, b_display, c_grid, M, N, K, - filename=f.name, title="test") + draw_combined_mma_grid(a_grid, b_display, c_grid, M, N, K, filename=f.name, title="test") @requires_viz @@ -1189,8 +1204,7 @@ def test_show_combined_mma_grid_returns_figure(): M_a, N_a, K_a = atom.shape_mnk M, N, K = M_a * 2, N_a * 2, K_a - fig = show_combined_mma_grid(a_grid, b_display, c_grid, M, N, K, - title="test") + fig = show_combined_mma_grid(a_grid, b_display, c_grid, M, N, K, title="test") try: assert isinstance(fig, matplotlib.figure.Figure) finally: