diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 95b2394..d25e08f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,4 +28,4 @@ outlined on that page and do not file a public issue. ## License By contributing to tensor-layouts, you agree that your contributions will be licensed -under the LICENSE file in the root directory of this source tree. \ No newline at end of file +under the LICENSE file in the root directory of this source tree. diff --git a/docs/analysis_api.md b/docs/analysis_api.md index dafcc18..496118e 100644 --- a/docs/analysis_api.md +++ b/docs/analysis_api.md @@ -1,3 +1,27 @@ + + # Analysis API GPU kernel performance lives or dies by memory access patterns. Two @@ -44,7 +68,7 @@ offset_table(Layout((4, 2), (0, 1))) # 1: [(0,1), (1,1), (2,1), (3,1)]} ``` -## bank_conflicts(layout, *, num_banks=32, element_bytes=2, bank_width_bytes=4, group_size=32) +## bank_conflicts(layout, *, element_bytes, num_banks=32, bank_width_bytes=4, group_size=32) Analyze shared memory bank conflicts for a thread-to-offset layout. @@ -75,6 +99,18 @@ The `max_ways` value is the worst-case serialization factor: 1 means no conflicts, N means N-way serialization. Two threads accessing the *same* word get a broadcast (no conflict on NVIDIA hardware). +For multi-mode (TV) layouts where mode 0 is the thread dimension and +mode 1+ are value dimensions, all values per thread are included in the +analysis. 
This models vectorized loads where each thread accesses +multiple elements: + +```python +# TV layout: 32 threads, each loading 2 fp16 elements +tv = Layout((32, 2), (1, 32)) +result = bank_conflicts(tv, element_bytes=2) +result['conflict_free'] # True: values land in distinct banks +``` + Returns a dict: | Key | Type | Description | @@ -83,7 +119,7 @@ Returns a dict: | `max_ways` | int | Worst-case serialization factor across all banks | | `bank_to_threads` | dict | `{bank_id: [thread_ids...]}` for all accessed banks | -## coalescing_efficiency(layout, *, warp_size=32, element_bytes=2, cache_line_bytes=128) +## coalescing_efficiency(layout, *, element_bytes, warp_size=32, cache_line_bytes=128) Analyze global memory coalescing for a thread-to-offset layout. @@ -98,7 +134,7 @@ result['transactions'] # 1 result['efficiency'] # 1.0 (128 unique useful bytes / 128 transferred) # Worst case: each thread hits a separate cache line -result = coalescing_efficiency(Layout(32, 64)) +result = coalescing_efficiency(Layout(32, 64), element_bytes=2) result['transactions'] # 32 result['efficiency'] # 0.016 (64 unique useful bytes / 4096 transferred) ``` @@ -111,6 +147,17 @@ Returns a dict: | `efficiency` | float | Unique useful bytes / transferred bytes (1.0 = perfect) | | `cache_lines` | list | Sorted cache line indices touched | +For multi-mode (TV) layouts, all values per thread are included, +modeling vectorized loads: + +```python +# TV layout: 32 threads, 4 values each, contiguous within each thread +tv = Layout((32, 4), (4, 1)) +result = coalescing_efficiency(tv, element_bytes=2) +result['transactions'] # 2 (256 bytes spans 2 cache lines) +result['efficiency'] # 1.0 (256 unique bytes / 256 transferred) +``` + ## Permutation Analysis When a layout is bijective (every offset is hit exactly once), it defines diff --git a/docs/generate_figures.py b/docs/generate_figures.py index 753aa5e..9e0402c 100644 --- a/docs/generate_figures.py +++ b/docs/generate_figures.py @@ -1,3 +1,25 
@@
 #!/usr/bin/env python3
+# MIT License
+#
+# Copyright (c) 2026 Meta Platforms, Inc. and affiliates.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
 """Regenerate all PNG figures used in the documentation.
diff --git a/docs/layout_api.md b/docs/layout_api.md index ff67bf6..0857981 100644 --- a/docs/layout_api.md +++ b/docs/layout_api.md @@ -1,3 +1,27 @@ + + # Layout Algebra API This document covers the core `tensor_layouts` API: constructing layouts, diff --git a/docs/viz_api.md b/docs/viz_api.md index ca13f98..78dc71e 100644 --- a/docs/viz_api.md +++ b/docs/viz_api.md @@ -1,3 +1,27 @@ + + # Visualization API This document covers the `tensor_layouts.viz` module for drawing layouts, @@ -257,6 +281,12 @@ layout = Layout(((3, 2), ((2, 3), 2)), ((4, 1), ((2, 15), 100))) draw_slice(layout, ((1, None), ((None, 0), None)), title="((1,:),((:,0),:))") ``` +For 1D layouts, wrap the slice in a single-element tuple: + +```python +draw_slice(Layout(8, 1), (slice(2, 5),), title="1D slice [2:5]") +``` + ![draw_slice](images/draw_slice.png) **Parameters:** @@ -302,6 +332,13 @@ draw_composite(panels, "comparison.png", | `panel_size` | `(w, h)` | `(4, 4)` | Size per panel | | `colorize` | `bool` | `False` | Rainbow colors | | `tv_mode` | `bool` | `False` | Use TV-layout rendering | +| `flatten_hierarchical` | `bool` | `True` | Flatten nested shapes to 2D grid | +| `label_hierarchy_levels` | `bool` | `False` | In nested hierarchical mode, annotate hierarchy levels | + +Per-panel options (`(Layout, opts_dict)` tuples) override the top-level +defaults: `colorize`, `color_layout`, `num_colors`, `tv_mode`, +`flatten_hierarchical`, `label_hierarchy_levels`, and the TV-specific +`grid_rows`, `grid_cols`, `thr_id_layout`, `col_major`. 
## draw_tiled_grid diff --git a/examples/layouts.py b/examples/layouts.py index 47602a7..4348e94 100644 --- a/examples/layouts.py +++ b/examples/layouts.py @@ -763,7 +763,7 @@ def example_analysis(): f"efficiency {r1['efficiency']:.0%}") scattered = Layout(32, 64) - r2 = coalescing_efficiency(scattered) + r2 = coalescing_efficiency(scattered, element_bytes=2) print(f" Stride-64 (fp16): {r2['transactions']} transactions, " f"efficiency {r2['efficiency']:.1%}") diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index 98e9a31..2bc5d02 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -136,8 +136,8 @@ def footprint(layout: Layout) -> dict: # Bank conflict analysis # ============================================================================= -def bank_conflicts(layout: Layout, *, num_banks: int = 32, - element_bytes: int = 2, bank_width_bytes: int = 4, +def bank_conflicts(layout: Layout, *, element_bytes: int, + num_banks: int = 32, bank_width_bytes: int = 4, group_size: int = 32): """Analyze shared memory bank conflicts for a thread-to-offset layout. @@ -153,8 +153,11 @@ def bank_conflicts(layout: Layout, *, num_banks: int = 32, Only the first ``group_size`` threads are analyzed, matching the hardware issue granularity (warp on NVIDIA, wavefront on AMD). This avoids overstating conflicts when the layout spans multiple - warps. The model assigns each access to its starting bank word; - accesses wider than one bank word are not tracked across banks. + warps. + + For multi-mode (TV) layouts, mode 0 is the thread dimension and all + remaining modes are value dimensions. Each thread's accesses across + all values are included in the analysis, modeling vectorized loads. Args: layout: Maps thread_id -> memory offset (in elements). 
@@ -171,29 +174,32 @@ def bank_conflicts(layout: Layout, *, num_banks: int = 32, bank_to_threads: {bank_id: [thread_ids...]} for all accessed banks Examples: - # Linear layout: threads access consecutive elements - bank_conflicts(Layout(32, 1)) + # Linear layout: threads access consecutive fp16 elements + bank_conflicts(Layout(32, 1), element_bytes=2) # {'conflict_free': True, 'max_ways': 1, ...} # All threads hit the same address - bank_conflicts(Layout(32, 0)) + bank_conflicts(Layout(32, 0), element_bytes=2) # {'conflict_free': True, 'max_ways': 1, ...} (broadcast, not a conflict) """ layout = as_layout(layout) if group_size <= 0: raise ValueError(f"group_size must be positive, got {group_size}") - n = min(size(layout), group_size) + thread_count, value_count = _tv_dimensions(layout) + n = min(thread_count, group_size) # Map each thread to (bank, word_address) # A bank conflict occurs when threads access different 4-byte words in the # same bank. Two threads accessing the same word get a broadcast (no conflict). thread_banks = {} # bank -> [(thread_id, word_address), ...] for t in range(n): - offset = layout(t) - byte_addr = offset * element_bytes - word_addr = byte_addr // bank_width_bytes - bank = word_addr % num_banks - thread_banks.setdefault(bank, []).append((t, word_addr)) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + byte_addr = offset * element_bytes + word_addr = byte_addr // bank_width_bytes + bank = word_addr % num_banks + thread_banks.setdefault(bank, []).append((t, word_addr)) # Compute conflicts per bank # Two threads conflict if they hit the same bank but different addresses. 
@@ -225,8 +231,8 @@ def bank_conflicts(layout: Layout, *, num_banks: int = 32, # Coalescing analysis # ============================================================================= -def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, - element_bytes: int = 2, +def coalescing_efficiency(layout: Layout, *, element_bytes: int, + warp_size: int = 32, cache_line_bytes: int = 128): """Analyze global memory coalescing for a thread-to-offset layout. @@ -239,6 +245,10 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, transaction. In the worst case, each thread triggers a separate transaction. + For multi-mode (TV) layouts, mode 0 is the thread dimension and all + remaining modes are value dimensions. Each thread's accesses across + all values are included in the analysis, modeling vectorized loads. + Args: layout: Maps thread_id -> memory offset (in elements). warp_size: Threads per warp (32 on NVIDIA/AMD GPUs). @@ -255,7 +265,7 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, Examples: # Perfectly coalesced: 32 threads, stride 1, fp16 - coalescing_efficiency(Layout(32, 1)) + coalescing_efficiency(Layout(32, 1), element_bytes=2) # {'transactions': 1, 'efficiency': 0.5, ...} -- 64B used of 128B line # Strided access: each thread 2 elements apart @@ -263,17 +273,20 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, # {'transactions': 2, 'efficiency': 0.5, ...} """ layout = as_layout(layout) - n = min(size(layout), warp_size) + thread_count, value_count = _tv_dimensions(layout) + n = min(thread_count, warp_size) # Find which cache lines are touched and count unique offsets cache_lines = set() unique_offsets = set() for t in range(n): - offset = layout(t) - unique_offsets.add(offset) - byte_addr = offset * element_bytes - cache_line = byte_addr // cache_line_bytes - cache_lines.add(cache_line) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + 
unique_offsets.add(offset) + byte_addr = offset * element_bytes + cache_line = byte_addr // cache_line_bytes + cache_lines.add(cache_line) transactions = len(cache_lines) useful_bytes = len(unique_offsets) * element_bytes @@ -287,8 +300,8 @@ def coalescing_efficiency(layout: Layout, *, warp_size: int = 32, } -def segment_analysis(layout: Layout, *, warp_size: int = 32, - element_bytes: int = 2, +def segment_analysis(layout: Layout, *, element_bytes: int, + warp_size: int = 32, segment_bytes: int = 32, cache_line_bytes: int = 128): """Segment- and alignment-aware global memory transaction analysis. @@ -298,6 +311,10 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, warp access may touch fewer cache lines than segments when accesses cluster within a line but span multiple segments. + For multi-mode (TV) layouts, mode 0 is the thread dimension and all + remaining modes are value dimensions. Each thread's accesses across + all values are included in the analysis, modeling vectorized loads. + Args: layout: Maps thread_id -> memory offset (in elements). warp_size: Threads per warp. 
@@ -317,7 +334,8 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, first_alignment: alignment of first_byte_addr to segment_bytes """ layout = as_layout(layout) - n = min(size(layout), warp_size) + thread_count, value_count = _tv_dimensions(layout) + n = min(thread_count, warp_size) segments = set() lines = set() @@ -325,19 +343,21 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, first_byte = None for t in range(n): - offset = layout(t) - unique_offsets.add(offset) - byte_addr = offset * element_bytes - if first_byte is None: - first_byte = byte_addr - segments.add(byte_addr // segment_bytes) - lines.add(byte_addr // cache_line_bytes) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + unique_offsets.add(offset) + byte_addr = offset * element_bytes + if first_byte is None: + first_byte = byte_addr + segments.add(byte_addr // segment_bytes) + lines.add(byte_addr // cache_line_bytes) first_byte = first_byte if first_byte is not None else 0 n_segments = len(segments) n_lines = len(lines) unique_bytes = len(unique_offsets) * element_bytes - requested_bytes = n * element_bytes + requested_bytes = n * value_count * element_bytes transferred_bytes = n_segments * segment_bytes return { @@ -356,14 +376,32 @@ def segment_analysis(layout: Layout, *, warp_size: int = 32, # Per-group analysis # ============================================================================= -def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, - num_banks: int = 32, element_bytes: int = 2, + +def _tv_dimensions(layout: Layout): + """Extract (thread_count, value_count) from a layout. + + For rank-1 (scalar shape) layouts: thread_count = size, value_count = 1. + For rank>1 (TV) layouts: thread_count = size(mode 0), value_count = + product of remaining modes. 
+ """ + if is_int(layout.shape): + return size(layout), 1 + return size(mode(layout, 0)), size(layout) // size(mode(layout, 0)) + + +def per_group_bank_conflicts(layout: Layout, *, element_bytes: int, + group_size: int = 32, + num_banks: int = 32, bank_width_bytes: int = 4) -> dict: """Analyze bank conflicts per warp/wavefront group across a full layout. Splits the layout into groups of ``group_size`` threads and analyzes bank conflicts for each group independently. + For multi-mode (TV) layouts, groups are formed along mode 0 (the thread + dimension). Each thread's accesses across all value modes (mode 1+) are + included in its group's analysis. + Args: layout: Maps thread_id -> memory offset (in elements). group_size: Threads per group (32 = NVIDIA warp, 64 = AMD wave). @@ -381,8 +419,8 @@ def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, layout = as_layout(layout) if group_size <= 0: raise ValueError(f"group_size must be positive, got {group_size}") - n = size(layout) - num_groups = (n + group_size - 1) // group_size + thread_count, value_count = _tv_dimensions(layout) + num_groups = (thread_count + group_size - 1) // group_size groups = [] worst_idx = 0 @@ -390,15 +428,17 @@ def per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, for g in range(num_groups): start = g * group_size - end = min(start + group_size, n) + end = min(start + group_size, thread_count) thread_banks = {} for t in range(start, end): - offset = layout(t) - byte_addr = offset * element_bytes - word_addr = byte_addr // bank_width_bytes - bank = word_addr % num_banks - thread_banks.setdefault(bank, []).append((t, word_addr)) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + byte_addr = offset * element_bytes + word_addr = byte_addr // bank_width_bytes + bank = word_addr % num_banks + thread_banks.setdefault(bank, []).append((t, word_addr)) max_ways = 1 bank_to_threads = {} @@ -428,14 +468,18 @@ def 
per_group_bank_conflicts(layout: Layout, *, group_size: int = 32, } -def per_group_coalescing(layout: Layout, *, group_size: int = 32, - element_bytes: int = 2, +def per_group_coalescing(layout: Layout, *, element_bytes: int, + group_size: int = 32, cache_line_bytes: int = 128) -> dict: """Analyze coalescing efficiency per warp/wavefront group across a full layout. Splits the layout into groups of ``group_size`` threads and analyzes coalescing for each group independently. + For multi-mode (TV) layouts, groups are formed along mode 0 (the thread + dimension). Each thread's accesses across all value modes (mode 1+) are + included in its group's analysis. + Args: layout: Maps thread_id -> memory offset (in elements). group_size: Threads per group (32 = NVIDIA warp, 64 = AMD wave). @@ -451,8 +495,8 @@ def per_group_coalescing(layout: Layout, *, group_size: int = 32, layout = as_layout(layout) if group_size <= 0: raise ValueError(f"group_size must be positive, got {group_size}") - n = size(layout) - num_groups = (n + group_size - 1) // group_size + thread_count, value_count = _tv_dimensions(layout) + num_groups = (thread_count + group_size - 1) // group_size groups = [] worst_idx = 0 @@ -460,16 +504,18 @@ def per_group_coalescing(layout: Layout, *, group_size: int = 32, for g in range(num_groups): start = g * group_size - end = min(start + group_size, n) + end = min(start + group_size, thread_count) cache_lines = set() unique_offsets = set() for t in range(start, end): - offset = layout(t) - unique_offsets.add(offset) - byte_addr = offset * element_bytes - cache_line = byte_addr // cache_line_bytes - cache_lines.add(cache_line) + for v in range(value_count): + flat_idx = v * thread_count + t + offset = layout(flat_idx) + unique_offsets.add(offset) + byte_addr = offset * element_bytes + cache_line = byte_addr // cache_line_bytes + cache_lines.add(cache_line) transactions = len(cache_lines) useful_bytes = len(unique_offsets) * element_bytes @@ -870,18 +916,32 @@ def 
explain(fn, *args): if isinstance(B, int): B = Layout(B) lines.append(f'logical_product({A}, {B})') - lines.append(f' = Layout(A, compose(complement(A, size(A)*size(B)), B))') - lines.append(f'') - lines.append(f' A = {A}') - lines.append(f' B = {B}') - bound = size(A) * size(B) - lines.append(f' size(A) * size(B) = {bound}') - comp = complement(A, bound) - lines.append(f' complement(A, {bound}) = {comp}') - comp_b = compose(comp, B) - lines.append(f' compose(complement, B) = {comp_b}') - result = Layout(A, comp_b) - lines.append(f' Layout(A, {comp_b}) = {result}') + + if is_layout(B): + lines.append(f' = Layout(A, compose(complement(A, size(A)*size(B)), B))') + lines.append(f'') + lines.append(f' A = {A}') + lines.append(f' B = {B}') + bound = size(A) * size(B) + lines.append(f' size(A) * size(B) = {bound}') + comp = complement(A, bound) + lines.append(f' complement(A, {bound}) = {comp}') + comp_b = compose(comp, B) + lines.append(f' compose(complement, B) = {comp_b}') + result = Layout(A, comp_b) + lines.append(f' Layout(A, {comp_b}) = {result}') + else: + # Tuple tiler: mode-by-mode decomposition + lines.append(f' For tuple tilers, applies logical_product mode-by-mode.') + lines.append(f'') + lines.append(f' A = {A}') + lines.append(f' B = {B}') + for i in range(len(B)): + mi = mode(A, i) + bi = B[i] + ri = logical_product(mi, bi) + lines.append(f' mode {i}: logical_product({mi}, {bi}) = {ri}') + lines.append(f'') actual = logical_product(A, B) lines.append(f' result = {actual}') diff --git a/src/tensor_layouts/atoms.py b/src/tensor_layouts/atoms.py index 2280ef3..729b634 100644 --- a/src/tensor_layouts/atoms.py +++ b/src/tensor_layouts/atoms.py @@ -53,6 +53,10 @@ class MMAAtom: b_layout: Layout c_layout: Layout + def __str__(self) -> str: + m, n, k = self.shape_mnk + return f"MMAAtom('{self.name}', {m}x{n}x{k})" + @dataclass(frozen=True) class CopyAtom: @@ -72,3 +76,6 @@ class CopyAtom: thr_id: Layout src_layout_bits: Layout dst_layout_bits: Layout + + def 
__str__(self) -> str: + return f"CopyAtom('{self.name}')" diff --git a/src/tensor_layouts/atoms_amx.py b/src/tensor_layouts/atoms_amx.py new file mode 100644 index 0000000..c7b05b4 --- /dev/null +++ b/src/tensor_layouts/atoms_amx.py @@ -0,0 +1,152 @@ +# MIT License +# +# Copyright (c) 2026 Meta Platforms, Inc. and affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Intel AMX (Advanced Matrix Extensions) tile atom definitions. + +Maps AMX tile matrix-multiply instructions to the (Thread, Value) -> +element-offset framework used by CuTe-style layout algebra. + +Conceptual mapping +================== + +AMX is a true tile matrix multiply executed by a single CPU core. The +instruction operates on 8 KiB tile registers (tmm0-tmm7), each holding +up to 16 rows x 64 bytes. 
A single ``tdp*`` instruction computes the +full C[M,N] += A[M,K] * B[K,N] tile with no thread cooperation, so: + + T = 1 (one CPU core) + V = M*K, N*K, or M*N (all elements in the Value dimension) + +This is the cleanest CPU -> MMAAtom mapping because AMX is a genuine +matrix multiply, unlike AVX512 VNNI which is a batched dot-product. + +Tile dimensions +=============== + +All AMX tile-multiply instructions use M=16, N=16 output tiles. +The K dimension depends on the element type: + + BF16 (tdpbf16ps): K=32 (2 BF16 per 32-bit pair, 64-byte rows) + FP16 (tdpfp16ps): K=32 (2 FP16 per 32-bit pair, 64-byte rows) + INT8 (tdpbssd): K=64 (4 INT8 per 32-bit group, 64-byte rows) + +B-tile storage +-------------- + +The B tile register uses "VNNI format" (pairs/quads of K-elements packed +into 32-bit groups along the row), but the LOGICAL layout is still K x N. +Our atom layouts describe the logical element ordering, not the physical +register packing. + +References +========== + +- Intel Architecture Instruction Set Extensions Programming Reference + (ISE), Chapter 3 -- AMX (TILECFG, TILELOADD, TDPBF16PS, TDPBSSD, etc.) 
+- Intel 64 and IA-32 Architectures Software Developer's Manual (SDM), + Volume 2 -- Instruction Set Reference + +Usage:: + + from tensor_layouts.atoms_amx import AMX_16x16x32_F32BF16BF16F32 + atom = AMX_16x16x32_F32BF16BF16F32 + print(atom.shape_mnk) # (16, 16, 32) + print(atom.c_layout) # (1, (16, 16)):(0, (1, 16)) +""" + +from .layouts import Layout +from .atoms import MMAAtom + + +# ============================================================================= +# Intel AMX tile matrix multiply atoms — T=1 (single CPU core) +# Source: Intel ISE Chapter 3, AMX instructions +# +# tdpbf16ps tmm, tmm, tmm -- C[16,16] FP32 += A[16,32] BF16 * B[32,16] BF16 +# tdpfp16ps tmm, tmm, tmm -- C[16,16] FP32 += A[16,32] FP16 * B[32,16] FP16 +# tdpbssd tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] INT8 * B[64,16] INT8 +# tdpbsud tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] INT8 * B[64,16] UINT8 +# tdpbusd tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] UINT8 * B[64,16] INT8 +# tdpbuud tmm, tmm, tmm -- C[16,16] INT32 += A[16,64] UINT8 * B[64,16] UINT8 +# +# All produce a 16x16 output tile. K varies by datatype. 
+# ============================================================================= + +# -- BF16 -> FP32 ------------------------------------------------------------- +AMX_16x16x32_F32BF16BF16F32 = MMAAtom( + name="AMX_16x16x32_F32BF16BF16F32", + ptx="tdpbf16ps", + shape_mnk=(16, 16, 32), thr_id=Layout(1), + # (T=1, V=512) -> col-major offset in (M=16, K=32) + a_layout=Layout((1, (16, 32)), (0, (1, 16))), + # (T=1, V=512) -> col-major offset in (N=16, K=32) + b_layout=Layout((1, (16, 32)), (0, (1, 16))), + # (T=1, V=256) -> col-major offset in (M=16, N=16) + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- FP16 -> FP32 ------------------------------------------------------------- +AMX_16x16x32_F32F16F16F32 = MMAAtom( + name="AMX_16x16x32_F32F16F16F32", + ptx="tdpfp16ps", + shape_mnk=(16, 16, 32), thr_id=Layout(1), + a_layout=Layout((1, (16, 32)), (0, (1, 16))), + b_layout=Layout((1, (16, 32)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- INT8 x INT8 -> INT32 (signed x signed) ----------------------------------- +AMX_16x16x64_S32S8S8S32 = MMAAtom( + name="AMX_16x16x64_S32S8S8S32", + ptx="tdpbssd", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + # (T=1, V=1024) -> col-major offset in (M=16, K=64) + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + # (T=1, V=1024) -> col-major offset in (N=16, K=64) + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + # (T=1, V=256) -> col-major offset in (M=16, N=16) + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- INT8 x UINT8 -> INT32 (signed x unsigned) -------------------------------- +AMX_16x16x64_S32S8U8S32 = MMAAtom( + name="AMX_16x16x64_S32S8U8S32", + ptx="tdpbsud", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- UINT8 x INT8 -> INT32 (unsigned x signed) -------------------------------- +AMX_16x16x64_S32U8S8S32 = MMAAtom( + 
name="AMX_16x16x64_S32U8S8S32", + ptx="tdpbusd", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) + +# -- UINT8 x UINT8 -> INT32 (unsigned x unsigned) ----------------------------- +AMX_16x16x64_S32U8U8S32 = MMAAtom( + name="AMX_16x16x64_S32U8U8S32", + ptx="tdpbuud", + shape_mnk=(16, 16, 64), thr_id=Layout(1), + a_layout=Layout((1, (16, 64)), (0, (1, 16))), + b_layout=Layout((1, (16, 64)), (0, (1, 16))), + c_layout=Layout((1, (16, 16)), (0, (1, 16)))) diff --git a/src/tensor_layouts/layouts.py b/src/tensor_layouts/layouts.py index ff821c4..8c637ad 100644 --- a/src/tensor_layouts/layouts.py +++ b/src/tensor_layouts/layouts.py @@ -72,45 +72,102 @@ # Type alias "IntOrIntTuple", # Type predicates - "is_tuple", "is_int", "is_scalar", "is_iterable", "is_layout", - "is_pure_shape", "has_none", + "is_tuple", + "is_int", + "is_scalar", + "is_iterable", + "is_layout", + "is_pure_shape", + "has_none", # Shape conversions - "as_tuple", "as_shape", "as_layout", "unwrap", "normalize", + "as_tuple", + "as_shape", + "as_layout", + "unwrap", + "normalize", # Core types - "Layout", "Tile", "Swizzle", "make_swizzle", + "Layout", + "Tile", + "Swizzle", + "make_swizzle", # Stride computation - "compute_col_major_strides", "compute_row_major_strides", + "compute_col_major_strides", + "compute_row_major_strides", # Query functions - "size", "cosize", "rank", "depth", "mode", + "size", + "cosize", + "rank", + "depth", + "mode", # Tuple operations - "concat", "congruent", "compatible", - "tuple_max", "transform_tuple", "zip_transform", - "fold", "fold_accumulate", "elem_scale", "inner_product", - "prefix_product", "suffix_product", "product_each", + "concat", + "congruent", + "weakly_congruent", + "compatible", + "tuple_max", + "transform_tuple", + "zip_transform", + "fold", + "fold_accumulate", + "elem_scale", + "inner_product", + "prefix_product", + 
"suffix_product", + "product_each", # Layout manipulation - "append", "prepend", "replace", "group", - "flatten", "unflatten", "sort", "coalesce", + "append", + "prepend", + "replace", + "group", + "flatten", + "unflatten", + "sort", + "coalesce", # Coordinate conversion - "idx2crd", "crd2flat", "crd2offset", "crd2idx", "crd2crd", - "slice_modes", "dice_modes", "slice_and_offset", + "idx2crd", + "crd2flat", + "crd2offset", + "crd2idx", + "crd2crd", + "slice_modes", + "dice_modes", + "slice_and_offset", # Core algebra - "compose", "complement", "logical_divide", "logical_product", + "compose", + "complement", + "logical_divide", + "logical_product", # Division variants - "zipped_divide", "tiled_divide", "flat_divide", + "zipped_divide", + "tiled_divide", + "flat_divide", # Product variants - "zipped_product", "tiled_product", "hier_unzip", - "blocked_product", "raked_product", "flat_product", + "zipped_product", + "tiled_product", + "hier_unzip", + "blocked_product", + "raked_product", + "flat_product", # Inverse and related - "right_inverse", "left_inverse", "nullspace", - "max_common_layout", "max_common_vector", + "right_inverse", + "left_inverse", + "nullspace", + "max_common_layout", + "max_common_vector", # Shape arithmetic - "safe_div", "shape_div", "shape_mod", + "safe_div", + "shape_div", + "shape_mod", # Upcast / downcast - "upcast", "downcast", + "upcast", + "downcast", # Iteration "iter_layout", # Image and injectivity - "image", "is_injective", "is_surjective", "is_bijective", + "image", + "is_injective", + "is_surjective", + "is_bijective", # Functional equivalence "functionally_equal", ] @@ -143,7 +200,7 @@ def as_layout(obj): """ if isinstance(obj, Layout): return obj - if hasattr(obj, 'shape') and hasattr(obj, 'stride'): + if hasattr(obj, "shape") and hasattr(obj, "stride"): return Layout(obj.shape, obj.stride) raise TypeError(f"Expected Layout, got {type(obj).__name__}") @@ -199,6 +256,7 @@ def has_none(a) -> bool: """ return fold(a, False, lambda 
acc, v: acc or v is None) + # ============================================================================= # Shape conversions # ============================================================================= @@ -301,6 +359,21 @@ def normalize(x: Any) -> IntOrIntTuple: # +def _validate_shape_type(x, name: str) -> None: + """Validate that *x* is a valid shape or stride: int or nested tuple of ints. + + Raises TypeError with a clear message naming the offending parameter + (``name`` should be ``"shape"`` or ``"stride"``). + """ + if is_int(x): + return + if isinstance(x, (list, tuple)): + for elem in x: + _validate_shape_type(elem, name) + return + raise TypeError(f"Layout {name} must be int or tuple of ints, got {type(x).__name__}") + + class Layout: """A function from logical coordinates to memory offsets: offset = sum(coord_i * stride_i). @@ -355,11 +428,14 @@ def __init__(self, *args, swizzle: "Swizzle | None" = None): elif len(args) == 1: shape = args[0] + _validate_shape_type(shape, "shape") self._shape = normalize(shape) self._stride = compute_col_major_strides(self._shape) elif len(args) == 2: shape, stride = args + _validate_shape_type(shape, "shape") + _validate_shape_type(stride, "stride") self._shape = normalize(shape) self._stride = normalize(stride) @@ -369,11 +445,11 @@ def __init__(self, *args, swizzle: "Swizzle | None" = None): ) if not congruent(self._shape, self._stride): - raise ValueError( - f"Shape {self._shape} and Stride {self._stride} are not congruent" - ) + raise ValueError(f"Shape {self._shape} and Stride {self._stride} are not congruent") def __eq__(self, other): + if self is other: + return True if not isinstance(other, Layout): return False return ( @@ -389,15 +465,23 @@ def __hash__(self): return hash((self.shape, self.stride, swizzle_hash)) def __repr__(self): + """Return an eval-safe constructor string: Layout((4, 2), (1, 4)).""" + if self._swizzle is not None: + return f"Layout({self._shape!r}, {self._stride!r}, 
swizzle={self._swizzle!r})" + return f"Layout({self._shape!r}, {self._stride!r})" + + def __str__(self): + """Return human-readable CuTe notation: (4, 2) : (1, 4).""" + def fmt(x): - """Format shape/stride: int as-is, tuple with parens.""" if isinstance(x, int): return str(x) return repr(x) - base_repr = f"{fmt(self._shape)} : {fmt(self._stride)}" + + base = f"{fmt(self._shape)} : {fmt(self._stride)}" if self._swizzle is not None: - return f"({self._swizzle}) o ({base_repr})" - return base_repr + return f"({self._swizzle}) o ({base})" + return base @property def shape(self) -> IntOrIntTuple: @@ -415,9 +499,7 @@ def swizzle(self) -> "Swizzle | None": @staticmethod def _calculate_max_offset(shape: Any, stride: Any) -> int: if is_tuple(shape): - return sum( - Layout._calculate_max_offset(s, d) for s, d in zip(shape, stride) - ) + return sum(Layout._calculate_max_offset(s, d) for s, d in zip(shape, stride)) return (shape - 1) * stride def __call__(self, *args): @@ -490,7 +572,6 @@ def filter_strides(self, shape, stride, target): d_out = [] for s, d in zip(shape, stride): if is_tuple(s): - sub_s, sub_d = self.filter_strides(s, d, target) if sub_s != (): s_out.append(sub_s) @@ -556,6 +637,7 @@ def _zero_leading_unit_strides(shape, strides): still_leading = False return tuple(result) + # ============================================================================= # Query functions: size, rank, depth, mode # ============================================================================= @@ -568,6 +650,7 @@ def _zero_leading_unit_strides(shape, strides): # mode -- extract a single mode (dimension) from a shape or layout # + def size(obj: Any) -> int: """Returns the logical number of elements (product of shape).""" if isinstance(obj, Layout): @@ -591,7 +674,7 @@ def rank(obj: Any) -> int: return len(obj) if isinstance(obj, Layout): if is_int(obj.shape): - return 0 + return 1 return len(obj.shape) if is_int(obj): return 0 @@ -627,6 +710,10 @@ def mode(obj: Any, idx): raise 
IndexError(f"Index {idx} out of range for scalar layout") return obj return Layout(obj.shape[idx], obj.stride[idx]) + if is_int(obj): + if idx != 0: + raise IndexError(f"Index {idx} out of range for scalar") + return obj raise TypeError(f"Cannot get mode of {type(obj).__name__}") @@ -634,11 +721,10 @@ def concat(t1: Any, t2: Any): if is_tuple(t1) and is_tuple(t2): return t1 + t2 if isinstance(t1, Layout) and isinstance(t2, Layout): - return Layout(as_tuple(t1.shape) + as_tuple(t2.shape), - as_tuple(t1.stride) + as_tuple(t2.stride)) - raise TypeError( - f"Cannot concatenate objects of {type(t1).__name__} and {type(t2).__name__}" - ) + return Layout( + as_tuple(t1.shape) + as_tuple(t2.shape), as_tuple(t1.stride) + as_tuple(t2.stride) + ) + raise TypeError(f"Cannot concatenate objects of {type(t1).__name__} and {type(t2).__name__}") def congruent(a: IntOrIntTuple, b: IntOrIntTuple) -> bool: @@ -660,6 +746,29 @@ def congruent(a: IntOrIntTuple, b: IntOrIntTuple) -> bool: return False +def weakly_congruent(a: IntOrIntTuple, b: IntOrIntTuple) -> bool: + """Returns True if A's profile is contained in B's profile. + + Matches CuTe's weakly_congruent(): a partial order A <= B where A's + hierarchical rank division is "at most as deep as" B's. A scalar on + the A side matches any sub-tree on the B side, but a tuple on the A + side requires at least as much structure on the B side. 
+ + Examples: + weakly_congruent(6, (2, 3)) -> True (scalar matches anything) + weakly_congruent((2, 3), 6) -> False (tuple vs scalar) + weakly_congruent((2, 3), (4, 5)) -> True (same rank) + weakly_congruent((2, (3, 4)), (5, (6, 7))) -> True (same nesting) + weakly_congruent((2, (3, 4)), (5, 6)) -> False (A deeper than B) + weakly_congruent((2, 3), (5, (6, 7))) -> True (A flatter than B) + """ + if isinstance(a, int): + return True + if is_tuple(a) and is_tuple(b): + return len(a) == len(b) and all(weakly_congruent(sa, sb) for sa, sb in zip(a, b)) + return False + + def compatible(a: IntOrIntTuple, b: IntOrIntTuple) -> bool: """Checks if shape A is compatible with shape B. @@ -713,6 +822,7 @@ def _can_group_a_into_b(a_modes: list, b) -> bool: # and coordinates are computed via idx2crd. # + def iter_layout(layout: Layout): """Yield (coordinate, offset) pairs for every element in the layout. @@ -741,6 +851,7 @@ def iter_layout(layout: Layout): # is_bijective -- both (the layout is a permutation) # + def image(layout: Layout) -> list: """Return the sorted list of distinct offsets produced by the layout. @@ -811,6 +922,7 @@ def is_bijective(layout: Layout) -> bool: # Functional equivalence # ============================================================================= + def functionally_equal(a: Layout, b: Layout) -> bool: """True if two layouts compute the same mapping for every flat index. 
@@ -882,9 +994,7 @@ def group(layout: Layout, start: int, end: int) -> Layout: """ r = rank(layout) if start < 0 or end > r or start >= end: - raise ValueError( - f"Invalid group range [{start}, {end}) for layout of rank {r}" - ) + raise ValueError(f"Invalid group range [{start}, {end}) for layout of rank {r}") shapes = list(as_tuple(layout.shape)) strides = list(as_tuple(layout.stride)) @@ -950,6 +1060,7 @@ def unflatten(obj, target_profile): flatten(obj) == obj (obj must already be flat) rank(flatten(target_profile)) == rank(obj) """ + def _unflatten_helper(flat_tuple, profile): """Consume elements from flat_tuple to match profile's structure.""" if is_tuple(profile): @@ -1032,6 +1143,7 @@ def sort(obj: Layout) -> Layout: # from a shape, which is how Layout(shape) auto-computes its strides. # + def tuple_max(a: Any) -> int: """Return the maximum value across all terminals of a (possibly nested) int-tuple. @@ -1040,7 +1152,7 @@ def tuple_max(a: Any) -> int: tuple_max((3, 7, 2)) -> 7 tuple_max(((1, 9), (4, 2))) -> 9 """ - return fold(a, -float('inf'), lambda acc, x: max(acc, x)) + return fold(a, -float("inf"), lambda acc, x: max(acc, x)) def transform_tuple(t: Any, f) -> Any: @@ -1269,6 +1381,7 @@ def suffix_product(a: Any, init: Any = 1) -> Any: # always safe and always preserves semantics. # + def coalesce(obj: Layout, profile: Any = None) -> Layout: """Returns a new Layout where contiguous dimensions are merged. @@ -1403,6 +1516,7 @@ def _coalesce_by_mode(layout: Layout, profile: tuple) -> Layout: # free dimensions, much like NumPy's array[3, :, :] syntax. # + def complement(layout: Layout, cosize_bound: Any = None) -> Layout: """Compute the complement of a layout: a layout that fills in the gaps. 
@@ -1461,9 +1575,7 @@ def _step_mode(current_stride, stride, shape): flat_shapes = list(flat.shape) flat_strides = list(flat.stride) - modes = sorted( - ((d, s) for s, d in zip(flat_shapes, flat_strides) if s != 1 and d != 0) - ) + modes = sorted(((d, s) for s, d in zip(flat_shapes, flat_strides) if s != 1 and d != 0)) # Fold _step_mode over sorted modes, collecting gap-fills result_shapes = [] @@ -1474,13 +1586,9 @@ def _step_mode(current_stride, stride, shape): # CuTe/pycute asserts current_stride <= stride * shape (injectivity). # Negative strides or zero-sized shapes violate this invariant. if stride < 0: - raise ValueError( - f"complement: negative stride {stride} is not supported" - ) + raise ValueError(f"complement: negative stride {stride} is not supported") if shape == 0: - raise ValueError( - f"complement: zero-sized shape is not supported" - ) + raise ValueError(f"complement: zero-sized shape is not supported") gap_size, next_stride = _step_mode(current_stride, stride, shape) if gap_size > 1: result_shapes.append(gap_size) @@ -1564,10 +1672,7 @@ def _step_mode(current_idx, stride, shape): if not result_shape: return Layout(1, 0) - return coalesce(Layout( - tuple(result_shape), - tuple(result_stride) - )) + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) def left_inverse(layout: Any) -> Layout: @@ -1758,10 +1863,13 @@ def slice_and_offset(crd, layout: Layout): # crd2crd: convert between two shapes' coordinate spaces # + def idx2crd(coord: Any, shape: Any) -> Any: """Convert index into a hierarchical coordinate.""" if isinstance(shape, int): + if isinstance(coord, int): + return coord % shape return coord # Case: Input is a single integer index for this entire sub-hierarchy @@ -1779,9 +1887,7 @@ def idx2crd(coord: Any, shape: Any) -> Any: # We map the modes of the coordinate to the modes of the shape if is_tuple(coord): if len(coord) != len(shape): - raise ValueError( - f"Coordinate rank {len(coord)} mismatch with Shape rank 
{len(shape)}" - ) + raise ValueError(f"Coordinate rank {len(coord)} mismatch with Shape rank {len(shape)}") return zip_transform(coord, shape, idx2crd) @@ -1862,14 +1968,9 @@ def crd2offset(coord, shape, stride) -> int: # Case 3: nD coordinate mapping (coord tuple -> offset) if not is_tuple(coord): - raise TypeError( - f"Coordinate must be int or tuple, got {type(coord).__name__}" - ) + raise TypeError(f"Coordinate must be int or tuple, got {type(coord).__name__}") if len(coord) != len(shape): - raise ValueError( - f"Coordinate rank {len(coord)} does not match " - f"layout rank {len(shape)}" - ) + raise ValueError(f"Coordinate rank {len(coord)} does not match layout rank {len(shape)}") offset = 0 for c, s, d in zip(coord, shape, stride): if c is None: @@ -1916,7 +2017,9 @@ def crd2crd(crd: Any, dst_shape: Any, src_shape: Any = None) -> Any: if is_tuple(crd): if is_tuple(dst_shape): if len(crd) != len(dst_shape): - raise ValueError(f"Rank mismatch: crd has {len(crd)} elements, dst_shape has {len(dst_shape)}") + raise ValueError( + f"Rank mismatch: crd has {len(crd)} elements, dst_shape has {len(dst_shape)}" + ) return zip_transform(crd, dst_shape, crd2crd) else: # crd is tuple, dst_shape is scalar: flatten using src_shape @@ -1995,12 +2098,15 @@ def dice_modes(crd, layout): dice_modes((0, None), Layout((3,4),(1,4))) -> 3:1 # keep mode 0 dice_modes((None, 0), Layout((3,4),(1,4))) -> 4:4 # keep mode 1 """ + def dice_tuple(crd, trg): """Keep elements of trg paired with integers in crd.""" if is_tuple(crd): if is_tuple(trg): if len(crd) != len(trg): - raise ValueError(f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}") + raise ValueError( + f"Rank mismatch: crd has {len(crd)} elements, trg has {len(trg)}" + ) result = [] for c, s in zip(crd, trg): result.extend(dice_tuple(c, s)) @@ -2049,6 +2155,7 @@ def dice_tuple(crd, trg): # nested shape elements, consuming from the innermost (leftmost) modes first. 
# + class Tile(tuple): """A Tiler is a tuple-of-Layouts used for mode-by-mode composition. @@ -2082,9 +2189,7 @@ def __new__(cls, *layouts): """ for i, layout in enumerate(layouts): if not isinstance(layout, Layout): - raise TypeError( - f"Tile element {i} must be a Layout, got {type(layout).__name__}" - ) + raise TypeError(f"Tile element {i} must be a Layout, got {type(layout).__name__}") return super().__new__(cls, layouts) def __repr__(self): @@ -2139,8 +2244,7 @@ def shape_div(shape: Any, divisor: int) -> Any: def _scalar(s, d): if s % d != 0 and d % s != 0: raise ValueError( - f"shape_div({s}, {d}): one must divide the other for clean " - f"factorization" + f"shape_div({s}, {d}): one must divide the other for clean factorization" ) return (s + d - 1) // d @@ -2164,6 +2268,7 @@ def shape_mod(shape: Any, modulus: int) -> Any: shape_mod((4, 3), 2) -> (2, 1) # 2 consumed from first mode, nothing from second shape_mod((4, 3), 12) -> (4, 3) # All kept (modulus >= size) """ + def _scalar(s, m): return s if m >= s else math.gcd(s, m) @@ -2299,8 +2404,7 @@ def _composition_1d(layout_a: "Layout", b_shape: int, b_stride: int) -> "Layout" for curr_shape, curr_stride in zip(flat_shapes[:-1], flat_strides[:-1]): if curr_shape % remaining_stride != 0 and remaining_stride % curr_shape != 0: raise ValueError( - f"complement: shape {curr_shape} and stride {remaining_stride} " - f"are not divisible" + f"complement: shape {curr_shape} and stride {remaining_stride} are not divisible" ) new_shape = min(max(1, curr_shape // remaining_stride), remaining_shape) if new_shape != 1: @@ -2327,17 +2431,16 @@ def _compose_layouts(layout_a: Layout, layout_b: Layout) -> Layout: def compose_element(b_shape, b_stride): """Recursively compose A with one element of B's shape/stride.""" if is_tuple(b_shape): - results = [compose_element(b_shape[i], b_stride[i]) - for i in range(len(b_shape))] - return Layout(tuple(r.shape for r in results), - tuple(r.stride for r in results)) + results = 
[compose_element(b_shape[i], b_stride[i]) for i in range(len(b_shape))] + return Layout(tuple(r.shape for r in results), tuple(r.stride for r in results)) return _composition_1d(layout_a, b_shape, b_stride) if is_tuple(layout_b.shape): - results = [compose_element(layout_b.shape[i], layout_b.stride[i]) - for i in range(len(layout_b.shape))] - return Layout(tuple(r.shape for r in results), - tuple(r.stride for r in results)) + results = [ + compose_element(layout_b.shape[i], layout_b.stride[i]) + for i in range(len(layout_b.shape)) + ] + return Layout(tuple(r.shape for r in results), tuple(r.stride for r in results)) return _composition_1d(layout_a, layout_b.shape, layout_b.stride) @@ -2451,6 +2554,7 @@ def compose(layout_a: Any, layout_b: Any) -> Any: # Tuple tiler - convert elements and recurse if is_tuple(layout_b): + def to_layout(elem): if isinstance(elem, Layout): return elem @@ -2849,12 +2953,11 @@ def hier_unzip(splitter, layout_a: Layout, layout_b) -> Layout: if is_tuple(layout_b) and not isinstance(layout_b, Layout): if rank(layout_a) < len(layout_b): - raise ValueError( - f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})" - ) + raise ValueError(f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})") - splits = [hier_unzip(splitter, mode(layout_a, i), layout_b[i]) - for i in range(len(layout_b))] + splits = [ + hier_unzip(splitter, mode(layout_a, i), layout_b[i]) for i in range(len(layout_b)) + ] first_shapes = [mode(s, 0).shape for s in splits] first_strides = [mode(s, 0).stride for s in splits] @@ -2905,9 +3008,7 @@ def logical_product(layout_a: Layout, layout_b: Layout) -> Layout: # For tuple tilers, apply mode-by-mode if is_tuple(layout_b) and not isinstance(layout_b, Layout): if rank(layout_a) < len(layout_b): - raise ValueError( - f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})" - ) + raise ValueError(f"layout_a rank ({rank(layout_a)}) < tiler length ({len(layout_b)})") result_modes = [] for i in 
range(len(layout_b)): result_modes.append(logical_product(mode(layout_a, i), layout_b[i])) @@ -3239,6 +3340,8 @@ def __repr__(self) -> str: return f"Swizzle({self.bits}, {self.base}, {self.shift})" def __eq__(self, other: object) -> bool: + if self is other: + return True if not isinstance(other, Swizzle): return False return self.bits == other.bits and self.base == other.base and self.shift == other.shift diff --git a/src/tensor_layouts/viz.py b/src/tensor_layouts/viz.py index 0d93d2b..81e1db9 100644 --- a/src/tensor_layouts/viz.py +++ b/src/tensor_layouts/viz.py @@ -660,6 +660,8 @@ def _build_composite_figure( panel_size: Tuple[float, float] = (4, 4), colorize: bool = False, tv_mode: bool = False, + flatten_hierarchical: bool = True, + label_hierarchy_levels: bool = False, ): """Build the composite figure used by draw_composite/show_composite.""" n = len(panels) @@ -691,6 +693,15 @@ def _build_composite_figure( axes = [axes_array[i, j] for i in range(nrows) for j in range(ncols)] # Process each panel + if len(panels) > nrows * ncols: + import warnings + + warnings.warn( + f"{len(panels)} panels provided but grid has only " + f"{nrows * ncols} cells ({nrows}x{ncols}); " + f"extra panels will be dropped", + stacklevel=3, + ) for idx, panel in enumerate(panels): if idx >= len(axes): break @@ -715,6 +726,10 @@ def _build_composite_figure( panel_tv_mode = opts.get("tv_mode", tv_mode) color_layout = opts.get("color_layout", None) num_colors = opts.get("num_colors", 8) + panel_flatten = opts.get("flatten_hierarchical", flatten_hierarchical) + panel_label_levels = opts.get( + "label_hierarchy_levels", label_hierarchy_levels + ) # Get title title = titles[idx] if titles and idx < len(titles) else None @@ -733,17 +748,41 @@ def _build_composite_figure( col_major=opts.get("col_major", True), ) else: - grid = _prepare_offset_grid( - layout, color_layout=color_layout, eval_fn=eval_fn + # Check if this panel should use hierarchical rendering + r = rank(layout) + is_hier = r == 
2 and not panel_flatten and ( + isinstance(mode(layout.shape, 0), tuple) + or isinstance(mode(layout.shape, 1), tuple) ) - _draw_grid( - ax, - grid.indices, - title=title, - colorize=panel_colorize, - color_indices=grid.color_indices, - num_colors=num_colors, + grid = _prepare_offset_grid( + layout, color_layout=color_layout, eval_fn=eval_fn, + hierarchical=is_hier, ) + if grid.is_hierarchical: + _draw_hierarchical_grid( + ax, + grid.indices, + grid.rows, + grid.cols, + cell_coords=grid.cell_coords, + row_shape=grid.row_shape, + col_shape=grid.col_shape, + title=title, + colorize=panel_colorize, + color_indices=grid.color_indices, + flatten_hierarchical=False, + label_hierarchy_levels=panel_label_levels, + num_colors=num_colors, + ) + else: + _draw_grid( + ax, + grid.indices, + title=title, + colorize=panel_colorize, + color_indices=grid.color_indices, + num_colors=num_colors, + ) # Hide unused axes for idx in range(len(panels), len(axes)): @@ -766,6 +805,8 @@ def draw_composite( panel_size: Tuple[float, float] = (4, 4), colorize: bool = False, tv_mode: bool = False, + flatten_hierarchical: bool = True, + label_hierarchy_levels: bool = False, ): """Draw multiple layouts in a single composite figure. @@ -782,6 +823,7 @@ def draw_composite( colorize, color_layout, num_colors -- offset-grid options tv_mode -- if True, render this panel as a TV grid grid_rows, grid_cols, thr_id_layout, col_major -- TV options + flatten_hierarchical, label_hierarchy_levels -- hierarchy options filename: Output path (.svg, .png, or .pdf) arrangement: How to arrange panels: - "horizontal": side by side (1 row) @@ -793,6 +835,10 @@ def draw_composite( panel_size: Size of each panel in inches (width, height) colorize: Default colorize setting for all panels tv_mode: If True, render panels as TV layouts with T/V labels + flatten_hierarchical: Default for all panels. If False, show explicit + nested coordinate labels for hierarchical layouts + label_hierarchy_levels: Default for all panels. 
If True, annotate axes + with hierarchy level labels Example: # Side-by-side comparison @@ -811,6 +857,8 @@ def draw_composite( panel_size=panel_size, colorize=colorize, tv_mode=tv_mode, + flatten_hierarchical=flatten_hierarchical, + label_hierarchy_levels=label_hierarchy_levels, ) _save_figure(fig, filename, dpi) @@ -2775,6 +2823,22 @@ def _get_slice_highlight_mask_2d(layout, slice_spec) -> np.ndarray: if col_match: mask[i, j] = True + elif isinstance(slice_spec, tuple) and r < 2: + if len(slice_spec) != 1: + raise ValueError( + f"Rank-{r} layout requires a 1-element tuple slice_spec, " + f"got {len(slice_spec)}" + ) + (col_spec,) = slice_spec + col_flat = _is_flat_slice_component(col_spec) + for j in range(cols): + col_coord = idx2crd(j, layout.shape) + mask[0, j] = ( + _match_flat_slice_component(j, col_spec, cols) + if col_flat + else _match_nested_slice_component(col_coord, col_spec, layout.shape) + ) + return mask @@ -3184,6 +3248,8 @@ def show_composite( panel_size: Tuple[float, float] = (4, 4), colorize: bool = False, tv_mode: bool = False, + flatten_hierarchical: bool = True, + label_hierarchy_levels: bool = False, ): """Display a composite figure inline (for Jupyter notebooks). @@ -3195,6 +3261,10 @@ def show_composite( panel_size: Size of each panel in inches (width, height) colorize: Default colorize setting for all panels tv_mode: If True, render panels as TV layouts with T/V labels + flatten_hierarchical: Default for all panels. If False, show explicit + nested coordinate labels for hierarchical layouts + label_hierarchy_levels: Default for all panels. 
If True, annotate axes + with hierarchy level labels Returns: matplotlib Figure @@ -3207,6 +3277,8 @@ def show_composite( panel_size=panel_size, colorize=colorize, tv_mode=tv_mode, + flatten_hierarchical=flatten_hierarchical, + label_hierarchy_levels=label_hierarchy_levels, ) diff --git a/tests/analysis.py b/tests/analysis.py index eb0dc6f..bdd5618 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -98,14 +98,14 @@ def test_footprint_broadcast(): def test_bank_conflicts_linear(): """Linear stride-1 access: no conflicts.""" - result = bank_conflicts(Layout(32, 1)) + result = bank_conflicts(Layout(32, 1), element_bytes=2) assert result['conflict_free'] assert result['max_ways'] == 1 def test_bank_conflicts_broadcast(): """All threads access same address: broadcast, not a conflict.""" - result = bank_conflicts(Layout(32, 0)) + result = bank_conflicts(Layout(32, 0), element_bytes=2) assert result['conflict_free'] @@ -167,9 +167,18 @@ def test_bank_conflicts_group_size(): def test_bank_conflicts_group_size_validation(): """group_size <= 0 must raise ValueError.""" with pytest.raises(ValueError, match="group_size must be positive"): - bank_conflicts(Layout(32, 1), group_size=0) + bank_conflicts(Layout(32, 1), element_bytes=2, group_size=0) with pytest.raises(ValueError, match="group_size must be positive"): - bank_conflicts(Layout(32, 1), group_size=-1) + bank_conflicts(Layout(32, 1), element_bytes=2, group_size=-1) + + +def test_bank_conflicts_tv_layout(): + """TV layout analyzes all values per thread, not just value 0.""" + # 32 threads, 2 values: stride-1 threads, stride-32 values + tv = Layout((32, 2), (1, 32)) + r = bank_conflicts(tv, element_bytes=2) + assert r['conflict_free'] + assert len(r['bank_to_threads']) == 32 # all banks accessed ## coalescing_efficiency @@ -177,7 +186,7 @@ def test_bank_conflicts_group_size_validation(): def test_coalescing_contiguous_fp16(): """32 threads, stride 1, fp16: one cache line (64B of 128B).""" - result = 
coalescing_efficiency(Layout(32, 1)) + result = coalescing_efficiency(Layout(32, 1), element_bytes=2) assert result['transactions'] == 1 assert result['efficiency'] == pytest.approx(0.5) @@ -191,7 +200,7 @@ def test_coalescing_contiguous_fp32(): def test_coalescing_strided(): """Stride-2 access doubles the cache lines touched.""" - result = coalescing_efficiency(Layout(32, 2)) + result = coalescing_efficiency(Layout(32, 2), element_bytes=2) assert result['transactions'] == 1 # 32*2*2=128 bytes, still fits in 1 line # Actually: offsets 0,2,4,...,62. byte addrs 0,4,8,...,124. All in line 0. assert result['efficiency'] == pytest.approx(0.5) @@ -200,7 +209,7 @@ def test_coalescing_strided(): def test_coalescing_large_stride(): """Large stride: each thread touches a different cache line.""" # stride 64 elements * 2 bytes = 128 bytes = 1 cache line apart - result = coalescing_efficiency(Layout(32, 64)) + result = coalescing_efficiency(Layout(32, 64), element_bytes=2) assert result['transactions'] == 32 # 32 threads * 2 bytes = 64 useful bytes, 32 * 128 = 4096 transferred assert result['efficiency'] == pytest.approx(64.0 / (32 * 128)) @@ -208,18 +217,28 @@ def test_coalescing_large_stride(): def test_coalescing_broadcast(): """All threads access same element: single transaction, minimal useful bytes.""" - result = coalescing_efficiency(Layout(32, 0)) + result = coalescing_efficiency(Layout(32, 0), element_bytes=2) assert result['transactions'] == 1 # Only 1 unique offset: 1 * 2 bytes useful out of 128 transferred assert result['efficiency'] == pytest.approx(2.0 / 128) +def test_coalescing_tv_layout(): + """TV layout counts all values in cache line computation.""" + # 32 threads, 4 values each, stride-4 between threads + tv = Layout((32, 4), (4, 1)) + result = coalescing_efficiency(tv, element_bytes=2) + # 128 unique offsets * 2B = 256B -> cache lines 0, 1 + assert result['transactions'] == 2 + assert result['efficiency'] == pytest.approx(1.0) + + ## segment_analysis def 
test_segment_analysis_contiguous_fp16(): """32 threads, stride 1, fp16: 2 segments, 1 cache line.""" - result = segment_analysis(Layout(32, 1)) + result = segment_analysis(Layout(32, 1), element_bytes=2) # 32 * 2B = 64B -> 2 segments of 32B, 1 cache line of 128B assert result['segments'] == 2 assert result['cache_lines'] == 1 @@ -246,6 +265,16 @@ def test_segment_analysis_broadcast(): assert result['requested_bytes'] == 64 +def test_segment_analysis_tv_layout(): + """TV layout includes all values in segment computation.""" + tv = Layout((32, 4), (4, 1)) + result = segment_analysis(tv, element_bytes=2) + # 128 elements * 2B = 256B -> 8 segments, 2 cache lines + assert result['segments'] == 8 + assert result['cache_lines'] == 2 + assert result['requested_bytes'] == 256 # 32 * 4 * 2 + + ## per-group analysis @@ -260,6 +289,14 @@ def test_per_group_bank_conflicts(): assert r_per['worst_max_ways'] == r_single['max_ways'] +def test_per_group_bank_conflicts_tv_layout(): + """TV layout groups by thread dimension, not flat index.""" + # 32 threads, 4 values each: should be 1 group (not 4) + tv = Layout((32, 4), (1, 32)) + result = per_group_bank_conflicts(tv, element_bytes=2, group_size=32) + assert len(result['groups']) == 1 + + def test_per_group_coalescing(): """Per-group coalescing for a uniform layout gives identical per-warp results.""" r_per = per_group_coalescing(Layout(64, 1), element_bytes=2) @@ -269,6 +306,16 @@ def test_per_group_coalescing(): assert g['transactions'] == 1 +def test_per_group_coalescing_tv_layout(): + """TV layout groups by thread dimension, not flat index.""" + # 32 threads, 4 values each (contiguous within each thread's block) + tv = Layout((32, 4), (4, 1)) + result = per_group_coalescing(tv, element_bytes=2, group_size=32) + assert len(result['groups']) == 1 + # 32 threads * 4 values = 128 elements * 2B = 256B -> 2 cache lines + assert result['groups'][0]['transactions'] == 2 + + ## cycles @@ -562,6 +609,16 @@ def 
test_explain_logical_product(): assert '(4, 3) : (1, 4)' in text +def test_explain_logical_product_tuple_tiler(): + """explain handles logical_product with tuple tiler without crashing.""" + text = explain(logical_product, Layout((4, 4), (1, 4)), (2, 2)) + assert 'logical_product' in text + assert 'mode 0' in text + assert 'mode 1' in text + expected = logical_product(Layout((4, 4), (1, 4)), (2, 2)) + assert str(expected) in text + + def test_explain_complement(): """explain shows complement with image and codomain.""" text = explain(complement, Layout(4, 2), 16) @@ -638,6 +695,58 @@ def test_explain_flat_divide(): assert '(tile0, tile1, ..., rest0, rest1, ...)' in text +## MMAAtom and CopyAtom __str__ + + +def test_mma_atom_str(): + """MMAAtom.__str__ returns a concise summary with name and shape.""" + from tensor_layouts.atoms import MMAAtom + + atom = MMAAtom( + name="test_16x8x4", + ptx="test.op", + shape_mnk=(16, 8, 4), + thr_id=Layout(32), + a_layout=Layout((32, 4), (4, 1)), + b_layout=Layout((32, 2), (2, 1)), + c_layout=Layout((32, 4), (4, 1)), + ) + assert str(atom) == "MMAAtom('test_16x8x4', 16x8x4)" + + +def test_copy_atom_str(): + """CopyAtom.__str__ returns a concise summary with name.""" + from tensor_layouts.atoms import CopyAtom + + atom = CopyAtom( + name="test_copy_128b", + ptx="test.copy", + thr_id=Layout(32), + src_layout_bits=Layout((32, 128), (128, 1)), + dst_layout_bits=Layout((32, 128), (128, 1)), + ) + assert str(atom) == "CopyAtom('test_copy_128b')" + + +def test_mma_atom_repr_is_verbose(): + """MMAAtom.__repr__ (dataclass-generated) includes all fields.""" + from tensor_layouts.atoms import MMAAtom + + atom = MMAAtom( + name="test_2x2x1", + ptx="test", + shape_mnk=(2, 2, 1), + thr_id=None, + a_layout=Layout(2, 1), + b_layout=Layout(2, 1), + c_layout=Layout((2, 2), (1, 2)), + ) + r = repr(atom) + assert r.startswith("MMAAtom(name=") + assert "shape_mnk=(2, 2, 1)" in r + assert "a_layout=Layout(2, 1)" in r + + if __name__ == "__main__": 
import subprocess import sys diff --git a/tests/layouts.py b/tests/layouts.py index cf5a911..e1cf319 100644 --- a/tests/layouts.py +++ b/tests/layouts.py @@ -64,7 +64,7 @@ def test_tuple_single_element(): def test_tuple_nested_single_element(): - t = (((2,))) + t = (2,) assert len(t) == 1 assert size(t) == 2 assert rank(t) == 1 @@ -172,6 +172,42 @@ def test_congruent(): assert not congruent((2, (3, 4)), (5, 6)) +def test_weakly_congruent(): + # Both scalars → True (same as congruent) + assert weakly_congruent(3, 5) + # Scalar A matches any B (the key relaxation over congruent) + assert weakly_congruent(6, (2, 3)) + assert weakly_congruent(1, ((2, 3), (4, 5))) + # Tuple A vs scalar B → False (A is more structured) + assert not weakly_congruent((2, 3), 6) + assert not weakly_congruent(((2, 3), 4), 24) + # Scalar A vs 1-tuple B → True + assert weakly_congruent(6, (6,)) + # 1-tuple A vs scalar B → False + assert not weakly_congruent((6,), 6) + # Same flat rank → True + assert weakly_congruent((2, 3), (4, 5)) + assert weakly_congruent((3, 128, 128), (1, 256, 64)) + # Different flat rank → False + assert not weakly_congruent((3, 128), (1, 256, 64)) + assert not weakly_congruent((1, 256, 64), (3, 128)) + # Same nested structure → True + assert weakly_congruent((2, (3, 4)), (5, (6, 7))) + # A deeper than B in a sub-mode → False + assert not weakly_congruent((2, (3, 4)), (5, 6)) + # A flatter than B in a sub-mode → True (scalar sub-mode matches nested B) + assert weakly_congruent((2, 3), (5, (6, 7))) + # Deeply nested: A flat sub-mode vs B's deep nesting + assert weakly_congruent((2, 3), ((4, 5), ((6, 7), 8))) + # Asymmetry: congruent ↔ weakly_congruent in both directions, + # but weakly_congruent only in one direction when profiles differ + assert congruent((2, 3), (4, 5)) + assert weakly_congruent((2, 3), (4, 5)) + assert weakly_congruent((4, 5), (2, 3)) + assert weakly_congruent(6, (2, 3)) + assert not weakly_congruent((2, 3), 6) + + def test_compatible(): assert not 
compatible(24, 32) assert compatible(24, (4, 6)) @@ -201,7 +237,38 @@ def test_layout_basic(): Layout((6, 1, 12, 2, 2), (2, 0, 12, 144, 1)) # Complex layout +def test_layout_type_validation(): + """Layout rejects invalid shape/stride types with clear messages.""" + # Strings rejected + with pytest.raises(TypeError, match="stride.*str"): + Layout((4, 2), "row") + with pytest.raises(TypeError, match="shape.*str"): + Layout("abc") + + # Floats rejected + with pytest.raises(TypeError, match="stride.*float"): + Layout((4, 2), 1.5) + with pytest.raises(TypeError, match="shape.*float"): + Layout(3.14) + + # None rejected + with pytest.raises(TypeError, match="shape.*NoneType"): + Layout(None) + + # Valid constructions still work + Layout((4, 2), (1, 4)) + Layout(((2, 2), 4), ((1, 2), 4)) + Layout(8) + Layout([4, 2]) # lists are fine + + def test_layout_rank_size_cosize(): + # Single-mode layout is rank 1 (one mode), not rank 0 + L_vec = Layout(31, 1) + assert rank(L_vec) == 1 + assert size(L_vec) == 31 + assert mode(L_vec, 0) == L_vec + L5 = Layout((64, 32), (1, 128)) assert rank(L5) == 2 assert size(L5) == 2048 @@ -410,9 +477,9 @@ def test_coordinate_validation(): L([1, 2]) # Valid cases still work - assert L(1) == 1 # flat index - assert L(1, 2) == 9 # tuple coord via *args - assert L((1, 2)) == 9 # tuple coord via single arg + assert L(1) == 1 # flat index + assert L(1, 2) == 9 # tuple coord via *args + assert L((1, 2)) == 9 # tuple coord via single arg def test_idx2crd_crd2flat_crd2offset(): @@ -466,16 +533,16 @@ def test_shape_div_non_divisible(): (e.g., shape_div/mod won't be complementary), so we assert. 
""" # Valid cases where divisibility holds - assert shape_div(12, 4) == 3 # 12%4==0 - assert shape_div(4, 12) == 1 # 12%4==0 - assert shape_div(8, 2) == 4 # 8%2==0 - assert shape_div(2, 8) == 1 # 8%2==0 + assert shape_div(12, 4) == 3 # 12%4==0 + assert shape_div(4, 12) == 1 # 12%4==0 + assert shape_div(8, 2) == 4 # 8%2==0 + assert shape_div(2, 8) == 1 # 8%2==0 # Invalid cases should raise ValueError with pytest.raises(ValueError): - shape_div(6, 4) # 6%4≠0, 4%6≠0 + shape_div(6, 4) # 6%4≠0, 4%6≠0 with pytest.raises(ValueError): - shape_div(4, 6) # 4%6≠0, 6%4≠0 + shape_div(4, 6) # 4%6≠0, 6%4≠0 def test_shape_mod_non_divisible(): @@ -489,9 +556,9 @@ def test_shape_mod_non_divisible(): # shape_mod(2, 2) = gcd(2,2) = 2 assert shape_mod((6, 2), 4) == (2, 2) # Scalar shape_mod: when modulus < shape, returns gcd - assert shape_mod(6, 4) == 2 # gcd(6,4) = 2 + assert shape_mod(6, 4) == 2 # gcd(6,4) = 2 # Scalar shape_mod: when modulus >= shape, returns shape - assert shape_mod(4, 6) == 4 # 6 >= 4, returns 4 + assert shape_mod(4, 6) == 4 # 6 >= 4, returns 4 def test_shape_div_mod_complementary(): @@ -500,11 +567,15 @@ def test_shape_div_mod_complementary(): This holds when the divisor evenly divides each mode it consumes. 
""" test_cases = [ - ((6, 2), 2), ((6, 2), 3), - ((6, 2), 6), ((6, 2), 12), - ((4, 3), 2), ((4, 3), 4), + ((6, 2), 2), + ((6, 2), 3), + ((6, 2), 6), + ((6, 2), 12), + ((4, 3), 2), + ((4, 3), 4), ((4, 3), 12), - ((3, 6, 2, 8), 3), ((3, 6, 2, 8), 9), + ((3, 6, 2, 8), 3), + ((3, 6, 2, 8), 9), ((3, 6, 2, 8), 72), ] for shape, div in test_cases: @@ -1030,7 +1101,9 @@ def test_core_matrix_operations(): # For a 2-byte dtype such as f16, core matrix is 8x8 tile1 = Layout((8, 1), (1, 0)) # (8,1):(1,0) mul1 = Layout((1, 8), (0, 1)) - tile2 = coalesce(blocked_product(tile1, mul1), profile=(None, None)) # (8,8):(1,8) -> One core Matrix + tile2 = coalesce( + blocked_product(tile1, mul1), profile=(None, None) + ) # (8,8):(1,8) -> One core Matrix assert tile2 == Layout((8, 8), (1, 8)) # Now organize core matrices into 8x8 pattern, so that we have a 64x64 Tile, say in SMem mul2 = Layout((8, 8), (1, 8)) @@ -1062,7 +1135,74 @@ def test_safe_div(): def test_tile_repr(): tiler = Tile(Layout(3, 4), Layout(8, 2)) r = repr(tiler) - assert r == "Tile(3 : 4, 8 : 2)" + assert r == "Tile(Layout(3, 4), Layout(8, 2))" + + +## Layout.__repr__ and __str__ + + +def test_layout_repr_scalar(): + """repr() of a 1D layout returns an eval-safe constructor string.""" + L = Layout(8, 2) + assert repr(L) == "Layout(8, 2)" + + +def test_layout_repr_tuple(): + """repr() of a multi-dimensional layout returns an eval-safe constructor string.""" + L = Layout((4, 8), (1, 4)) + assert repr(L) == "Layout((4, 8), (1, 4))" + + +def test_layout_repr_hierarchical(): + """repr() of a hierarchical layout returns an eval-safe constructor string.""" + L = Layout(((2, 3), (2, 4)), ((1, 6), (2, 12))) + assert repr(L) == "Layout(((2, 3), (2, 4)), ((1, 6), (2, 12)))" + + +def test_layout_repr_swizzled(): + """repr() of a swizzled layout includes the swizzle keyword argument.""" + sw = Swizzle(3, 0, 3) + L = compose(sw, Layout((8, 8), (8, 1))) + r = repr(L) + assert r == "Layout((8, 8), (8, 1), swizzle=Swizzle(3, 0, 3))" + + 
+def test_layout_repr_eval_roundtrip(): + """eval(repr(L)) reconstructs an equal Layout (the gold standard for repr).""" + cases = [ + Layout(8, 2), + Layout((4, 8), (1, 4)), + Layout((4, 8), (0, 1)), + Layout(((2, 3), (2, 4)), ((1, 6), (2, 12))), + ] + for L in cases: + reconstructed = eval(repr(L)) # noqa: S307 + assert reconstructed == L, f"Roundtrip failed for {repr(L)}" + + +def test_layout_repr_eval_roundtrip_swizzled(): + """eval(repr(L)) works for swizzled layouts too.""" + L = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) + reconstructed = eval(repr(L)) # noqa: S307 + assert reconstructed == L + + +def test_layout_str_scalar(): + """str() returns the human-readable CuTe notation.""" + L = Layout(8, 2) + assert str(L) == "8 : 2" + + +def test_layout_str_tuple(): + """str() returns the human-readable CuTe notation for multi-dim layouts.""" + L = Layout((4, 8), (1, 4)) + assert str(L) == "(4, 8) : (1, 4)" + + +def test_layout_str_swizzled(): + """str() returns the CuTe composition notation for swizzled layouts.""" + L = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) + assert str(L) == "(Swizzle(3, 0, 3)) o ((8, 8) : (8, 1))" ## Layout.__hash__ @@ -1085,6 +1225,52 @@ def test_layout_hash(): assert len(s) == 2 +## Layout.__eq__ identity short-circuit + + +def test_layout_eq_identity_shortcircuit(): + """Same object identity returns True immediately.""" + L = Layout((4, 8), (1, 4)) + assert L == L + assert L is L + + sw = Swizzle(3, 0, 3) + L_sw = compose(sw, L) + assert L_sw == L_sw + + +def test_layout_eq_structural(): + """Distinct objects with equal shape/stride/swizzle are equal.""" + L1 = Layout((4, 8), (1, 4)) + L2 = Layout((4, 8), (1, 4)) + assert L1 is not L2 + assert L1 == L2 + + +def test_layout_eq_non_layout(): + """Comparing Layout with non-Layout returns False, not an error.""" + L = Layout((4, 8), (1, 4)) + assert L != 42 + assert L != "not a layout" + assert L != (4, 8) + assert L != None # noqa: E711 + + +def 
test_swizzle_eq_identity_shortcircuit(): + """Same Swizzle identity returns True immediately.""" + sw = Swizzle(3, 0, 3) + assert sw == sw + assert sw is sw + + +def test_swizzle_eq_structural(): + """Distinct Swizzle objects with equal fields are equal.""" + sw1 = Swizzle(3, 0, 3) + sw2 = Swizzle(3, 0, 3) + assert sw1 is not sw2 + assert sw1 == sw2 + + ## compose() functional property @@ -1142,11 +1328,12 @@ def test_swizzled_layout_eq_hash(): def test_offset_swizzled_layout_basic(): from tensor_layouts import Tensor + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) tensor = Tensor(sw_layout) # Slicing a Tensor produces a Tensor with offset row_slice = tensor[3, :] - assert hasattr(row_slice, 'offset') + assert hasattr(row_slice, "offset") assert row_slice.offset == Layout((8, 8), (8, 1))(3, 0) # = 24 # Check functional correctness: tensor[3, :](j) == tensor(3, j) @@ -1156,6 +1343,7 @@ def test_offset_swizzled_layout_basic(): def test_offset_swizzled_layout_repr(): from tensor_layouts import Tensor + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) tensor = Tensor(sw_layout) row_slice = tensor[2, :] @@ -1166,6 +1354,7 @@ def test_offset_swizzled_layout_repr(): def test_offset_swizzled_layout_eq(): from tensor_layouts import Tensor + sw_layout = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))) tensor = Tensor(sw_layout) slice1 = tensor[3, :] @@ -1311,6 +1500,7 @@ def test_tile_to_shape_nested_block(): ## is_layout + def test_is_layout(): assert is_layout(Layout(4, 1)) is True assert is_layout(Layout((2, 3), (1, 2))) is True @@ -1322,6 +1512,7 @@ def test_is_layout(): ## unflatten + def test_unflatten_tuple(): # Flat tuple -> nested tuple assert unflatten((1, 2, 3, 4, 5), ((0, 0), (0, 0, 0))) == ((1, 2), (3, 4, 5)) @@ -1419,6 +1610,7 @@ def test_make_ordered_layout_scalar(): ## dice_modes + def test_dice_modes_scalar_coord(): # Scalar coord: identity (keep everything) layout = Layout((3, 4), (1, 4)) @@ -1476,6 +1668,7 @@ def 
test_dice_modes_complement_of_slice_modes(): ## nullspace + def test_nullspace_all_zero_strides(): # All stride-0: everything is in the kernel # Inspired by C++ test: Layout,Stride<_0,_0,_0>> @@ -1543,6 +1736,7 @@ def test_nullspace_scalar_zero_stride(): ## max_common_vector and max_common_layout + def test_max_common_vector_identical(): # Same layout: all elements are common a = Layout(8, 1) @@ -1586,6 +1780,7 @@ def test_max_common_layout_partial(): ## flat_product + def test_flat_product_basic(): # flat_product = zipped_product then unpack both modes block = Layout(4, 1) @@ -1612,6 +1807,7 @@ def test_flat_product_2d(): ## raked_product + def test_raked_product_basic(): # raked_product vs blocked_product: reversed zip order block = Layout((2, 2), (1, 2)) @@ -1726,8 +1922,11 @@ def test_upcast_known_copy_atoms(): derived from the CUTLASS C++ copy_traits_sm75.hpp source. """ from tensor_layouts.atoms_nv import ( - SM75_U32x1_LDSM_N, SM75_U32x4_LDSM_N, - SM75_U16x2_LDSM_T, SM75_U16x4_LDSM_T, SM75_U16x8_LDSM_T, + SM75_U32x1_LDSM_N, + SM75_U32x4_LDSM_N, + SM75_U16x2_LDSM_T, + SM75_U16x4_LDSM_T, + SM75_U16x8_LDSM_T, ) cases = [ @@ -1807,9 +2006,12 @@ def test_iter_layout_2d_col_major(): layout = Layout((2, 3), (1, 2)) result = list(iter_layout(layout)) expected = [ - ((0, 0), 0), ((1, 0), 1), # col 0 - ((0, 1), 2), ((1, 1), 3), # col 1 - ((0, 2), 4), ((1, 2), 5), # col 2 + ((0, 0), 0), + ((1, 0), 1), # col 0 + ((0, 1), 2), + ((1, 1), 3), # col 1 + ((0, 2), 4), + ((1, 2), 5), # col 2 ] assert result == expected diff --git a/tests/oracle_nv.py b/tests/oracle_nv.py index 0a8ce32..9f1e49e 100644 --- a/tests/oracle_nv.py +++ b/tests/oracle_nv.py @@ -1492,6 +1492,21 @@ def test_product_each_matches_pycute_size(): assert result == expected, f"product_each({shape}) = {result} != {expected}" + +@pytest.mark.skipif(pycute is None, reason="pycute not installed") +def test_oracle_idx2crd(): + shapes = [ + 4, + (4, 2), + (2, (2, 2)), + ((2, 2), (2, 2)), + ] + indices = [0, 1, 3, 
5, 10, 16] + for s in shapes: + for idx in indices: + assert idx2crd(idx, s) == pycute.idx2crd(idx, s) + + if __name__ == "__main__": import traceback test_funcs = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)] diff --git a/tests/tensor.py b/tests/tensor.py index 79eae7f..f443d02 100644 --- a/tests/tensor.py +++ b/tests/tensor.py @@ -81,8 +81,8 @@ def test_rank1_contiguous(self): layout = Layout(32, 1) tensor = Tensor(layout) - # Scalar shape has rank 0 in CuTe convention - assert rank(tensor.layout) == 0 + # Single-mode layout has rank 1 (one mode) + assert rank(tensor.layout) == 1 assert size(tensor.layout) == 32 for i in range(32): assert tensor(i) == i diff --git a/tests/viz.py b/tests/viz.py index 76a5dd3..e34ab7c 100644 --- a/tests/viz.py +++ b/tests/viz.py @@ -400,6 +400,49 @@ def test_draw_composite_mixed_tv_and_offset(): plt.close(fig) +@requires_viz +def test_draw_composite_hierarchical_panel(): + """Composite figure with flatten_hierarchical=False renders hierarchy lines.""" + hier = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) + flat = Layout((4, 4), (4, 1)) + panels = [ + (hier, {'flatten_hierarchical': False}), + flat, + ] + fig = show_composite(panels, titles=["Hierarchical", "Flat"]) + try: + assert isinstance(fig, matplotlib.figure.Figure) + # The hierarchical panel should have hierarchy boundary lines + hier_ax = fig.axes[0] + assert len(hier_ax.lines) > 0 + # The flat panel should have no hierarchy lines + flat_ax = fig.axes[1] + assert len(flat_ax.lines) == 0 + finally: + plt.close(fig) + + +@requires_viz +def test_draw_composite_hierarchical_top_level_default(): + """flatten_hierarchical=False as top-level default applies to all panels.""" + hier = Layout(((2, 2), (2, 2)), ((1, 4), (2, 8))) + fig = show_composite([hier], flatten_hierarchical=False) + try: + ax = fig.axes[0] + assert len(ax.lines) > 0 + finally: + plt.close(fig) + + +@requires_viz +def test_draw_composite_warns_on_panel_truncation(): + 
"""Warning emitted when panels exceed grid capacity.""" + panels = [Layout((2, 2), (2, 1)) for _ in range(5)] + with pytest.warns(UserWarning, match="5 panels.*4 cells"): + fig = show_composite(panels, arrangement="grid:2x2") + plt.close(fig) + + @requires_viz def test_draw_copy_layout_smoke(): src = Layout((4, 2), (2, 1)) @@ -1015,6 +1058,49 @@ def test_slice_highlight_mask_tracks_logical_cells_not_offsets(): ] +@requires_viz +def test_slice_highlight_mask_1d_tuple_spec(): + """1D layout with tuple slice_spec should highlight the correct elements.""" + layout = Layout(8, 1) + mask = _get_slice_highlight_mask_2d(layout, (slice(2, 5),)) + assert mask.shape == (1, 8) + assert mask.tolist() == [[False, False, True, True, True, False, False, False]] + + +@requires_viz +def test_slice_highlight_mask_1d_tuple_spec_rank1(): + """Rank-1 layout with tuple slice_spec should highlight the correct elements.""" + layout = Layout((8,), (1,)) + mask = _get_slice_highlight_mask_2d(layout, (slice(2, 5),)) + assert mask.shape == (1, 8) + assert mask.tolist() == [[False, False, True, True, True, False, False, False]] + + +@requires_viz +def test_slice_highlight_mask_1d_tuple_int_spec(): + """1D layout with tuple (int,) slice_spec highlights a single element.""" + layout = Layout(8, 1) + mask = _get_slice_highlight_mask_2d(layout, (3,)) + assert mask.shape == (1, 8) + assert mask.tolist() == [[False, False, False, True, False, False, False, False]] + + +@requires_viz +def test_slice_highlight_mask_1d_tuple_none_spec(): + """1D layout with tuple (None,) selects all elements.""" + layout = Layout(4, 1) + mask = _get_slice_highlight_mask_2d(layout, (None,)) + assert mask.tolist() == [[True, True, True, True]] + + +@requires_viz +def test_slice_highlight_mask_1d_wrong_tuple_length_raises(): + """1D layout with 2-element tuple slice_spec raises ValueError.""" + layout = Layout(4, 1) + with pytest.raises(ValueError, match="1-element tuple"): + _get_slice_highlight_mask_2d(layout, (1, 2)) + 
+ @requires_viz def test_compute_tv_mapping_uses_first_wins_for_duplicate_cells(): layout = Layout((2, 2), (0, 0))