Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e4e4898
Fix idx2crd and crd2flat to accept Layout objects as shape argument
jduprat Apr 1, 2026
4c19c9f
Fix crd2crd to thread src_shape through per-mode recursion
jduprat Apr 1, 2026
e612868
Add cell_labels parameter to draw_layout and show_layout
jduprat Apr 1, 2026
08ad6e6
Use plain integers for axis labels in hierarchical mode
jduprat Apr 1, 2026
94d33e1
Add examples and check targets to Makefile
jduprat Apr 1, 2026
bcdf983
Fix rank>=3 panel splitting to match CuTe convention
jduprat Apr 1, 2026
ac7ea1a
Add interleave_colors option for hue-grouped palette
jduprat Apr 1, 2026
e0e61b5
Accept None as free-dimension marker in Tensor.__getitem__
jduprat Apr 2, 2026
425d3e9
Add Tensor.__str__ with offset notation and update draw_slice title
jduprat Apr 2, 2026
9f61168
Fix Tensor slicing for hierarchical specs with nested Nones
jduprat Apr 2, 2026
b0ae0fb
Fix slice_modes to preserve hierarchical mode boundaries
jduprat Apr 2, 2026
db0172d
Add storage to Tensor with auto-labeling in draw_layout
jduprat Apr 2, 2026
7280c90
Add is_contiguous() as a readable alias for is_bijective()
jduprat Apr 2, 2026
1a2622e
Add transpose option to draw_layout for rank-1 column vectors
jduprat Apr 2, 2026
f336457
Add tests for per-mode coalesce with None profile
jduprat Apr 3, 2026
9782acf
Add to_F2_matrix() with tests cross-referenced against Triton LinearL…
jduprat Apr 3, 2026
4c97cf8
Make Tensor[int] flat 1D evaluation on any-rank tensor
jduprat Apr 3, 2026
f29f0db
Make size, rank, cosize, depth, mode, flatten, image accept Tensors
jduprat Apr 4, 2026
6cbd37f
Remove show_* viz functions; draw_* handles inline display via filena…
jduprat Apr 4, 2026
13c1117
Add Tensor.view(layout) for same-storage reinterpretation
jduprat Apr 5, 2026
7ee675d
Fix Tensor detection in viz with duck typing instead of isinstance
jduprat Apr 5, 2026
9f5e3ce
Auto-compute draw_composite panel_size from layout dimensions
jduprat Apr 5, 2026
eca1cd6
Pass draw_composite rendering options through **kwargs
jduprat Apr 5, 2026
3e2ba54
Add draw_gemm for matmul spatial arrangement of A, B, C
jduprat Apr 5, 2026
99d5a03
Add algorithms.ipynb: COPY and GEMM visualized with layout algebra
jduprat Apr 5, 2026
91a487a
Add community feedback notice to all atom definition files
jduprat Apr 5, 2026
6716293
Add RDNA3/RDNA4 WMMA atom definitions to atoms_amd.py
jduprat Apr 5, 2026
718638d
Add Intel Xe GPU DPAS atom definitions
jduprat Apr 5, 2026
59dfdfd
Add applications.ipynb: six layout algebra patterns from arXiv:2603.0…
jduprat Apr 5, 2026
5f10d36
Fix draw_composite auto-sizing to respect grid_rows/grid_cols overrides
jduprat Apr 6, 2026
8a5fa49
Add Grouped GEMM and REDUCE sections to algorithms.ipynb
jduprat Apr 6, 2026
679f6ac
Add Epilogue Fusion and Online Softmax sections to algorithms.ipynb
jduprat Apr 6, 2026
2109550
Fix trailing comma in Layout.__str__ for 1-tuple shapes
jduprat Apr 6, 2026
453f5b0
Add precision parameter for float cell labels in viz
jduprat Apr 6, 2026
1c4ad36
Reorder algorithms.ipynb by rank and polish notation
jduprat Apr 6, 2026
c38f2df
Update applications.ipynb and add intile/oftile figure
jduprat Apr 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: build clean test docs lint
.PHONY: build clean test check docs lint examples

build:
pip install -e .
Expand All @@ -13,8 +13,15 @@ clean:
test:
python -m pytest tests/ -v

check: test

docs:
python3 docs/generate_figures.py

lint:
ruff check src/ tests/ examples/

examples:
python3 examples/layouts.py
python3 examples/tensor.py
python3 examples/viz.py
81 changes: 81 additions & 0 deletions docs/analysis_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,84 @@ explain(complement, Layout(4, 2), 16)
# complement = (2, 2) : (1, 8)
# image(complement) = [0, 1, 8, 9]
```

## F2 Linear Layout Matrix

`to_F2_matrix(layout)` converts a layout with power-of-2 shapes to its
binary matrix representation over GF(2). The layout mapping becomes
`offset_bits = M @ coord_bits (mod 2)`.

This is the "linear layout" representation from arXiv 2603.02298 Section 2.4.4.
Swizzles (XOR operations) are linear over F2 and fold into the matrix.

```python
from tensor_layouts.analysis import to_F2_matrix
```

### Identity (column-major)

A contiguous column-major layout is the identity map over F2:

```python
to_F2_matrix(Layout((4, 8), (1, 4)))
# [[1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0],
# [0, 0, 1, 0, 0],
# [0, 0, 0, 1, 0],
# [0, 0, 0, 0, 1]]
```

### Row-major (bit permutation)

Row-major swaps the row and column bit groups -- a permutation matrix:

```python
to_F2_matrix(Layout((4, 8), (8, 1)))
# [[0, 0, 1, 0, 0], coord bits: [row0, row1, col0, col1, col2]
# [0, 0, 0, 1, 0], offset bits: row bits moved to high positions
# [0, 0, 0, 0, 1],
# [1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0]]
```

### Swizzle (XOR connections)

Swizzle(3,0,3) XORs offset bits 0-2 with bits 3-5, adding off-diagonal
1s to the identity:

```python
to_F2_matrix(compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1))))
# [[1, 0, 0, 1, 0, 0], col0 = col0 XOR row0
# [0, 1, 0, 0, 1, 0], col1 = col1 XOR row1
# [0, 0, 1, 0, 0, 1], col2 = col2 XOR row2
# [1, 0, 0, 0, 0, 0], row0 = row0
# [0, 1, 0, 0, 0, 0], row1 = row1
# [0, 0, 1, 0, 0, 0]] row2 = row2
```

### MMA register mapping

The SM80 16x8x16 C accumulator layout maps (thread, value) bits to
(m, n) coordinates of the output tile. The F2 matrix reveals which
thread and value bits control which output dimensions:

```python
from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN
c = SM80_16x8x16_F16F16F16F16_TN.c_layout
# ((4, 8), (2, 2)) : ((32, 1), (16, 8))
# Thread bits T0-T4, Value bits V0-V1 -> m0-m3, n0-n2

to_F2_matrix(c)
# T0 T1 T2 T3 T4 V0 V1
# m0 [ 0, 0, 1, 0, 0, 0, 0] m0 = T2
# m1 [ 0, 0, 0, 1, 0, 0, 0] m1 = T3
# m2 [ 0, 0, 0, 0, 1, 0, 0] m2 = T4
# m3 [ 0, 0, 0, 0, 0, 0, 1] m3 = V1
# n0 [ 0, 0, 0, 0, 0, 1, 0] n0 = V0
# n1 [ 1, 0, 0, 0, 0, 0, 0] n1 = T0
# n2 [ 0, 1, 0, 0, 0, 0, 0] n2 = T1
```

Reading: threads 0-3 (T0, T1) select N-dimension column pairs, threads
within each group of 4 (T2-T4) select M-dimension rows, and the two
value bits split across M bit 3 and N bit 0.
117 changes: 104 additions & 13 deletions docs/generate_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
import shutil
from pathlib import Path

import matplotlib.patches as patches
import matplotlib.pyplot as plt

from tensor_layouts import Layout, Swizzle
from tensor_layouts.atoms_nv import SM80_16x8x16_F16F16F16F16_TN
from tensor_layouts.layout_utils import tile_mma_grid
Expand All @@ -49,6 +52,105 @@
IMAGES = Path(__file__).resolve().parent / "images"


def _generate_intile_oftile(path: Path) -> None:
"""intile/oftile coordinate diagram: manual index math vs layout algebra.

Three panels on a 4×8 matrix tiled by (2,4):
Left: cells show linear index i
Center: cells show 2D index (r,c)
Right: cells show intile coords — the output of logical_divide
Tile coloring (shared across all panels) shows the oftile grouping.
Formulas below each panel explain the conversion.
"""
M, K = 4, 8
tm, tk = 2, 4
tile_colors = {
(0, 0): "#DBEAFE", (0, 1): "#FEE2E2",
(1, 0): "#D1FAE5", (1, 1): "#EDE9FE",
}

fig, axes = plt.subplots(1, 3, figsize=(18, 5.2))

def _draw_grid(ax, cell_text_fn, title, subtitle):
"""Draw M×K grid with colored tiles and per-cell text."""
for r in range(M):
for c in range(K):
om, im = r // tm, r % tm
ok, ik = c // tk, c % tk
y = M - 1 - r
rect = patches.Rectangle(
(c, y), 1, 1,
facecolor=tile_colors[(om, ok)],
edgecolor="#D1D5DB", linewidth=0.5,
)
ax.add_patch(rect)
ax.text(
c + 0.5, y + 0.5, cell_text_fn(r, c),
ha="center", va="center", fontsize=9,
color="#374151", family="monospace",
)
# thick tile borders
for i in range(0, M + 1, tm):
ax.plot([0, K], [i, i], color="#1F2937", lw=2.5,
solid_capstyle="butt")
for j in range(0, K + 1, tk):
ax.plot([j, j], [0, M], color="#1F2937", lw=2.5,
solid_capstyle="butt")
# oftile margin labels
for om in range(M // tm):
y_c = M - om * tm - tm / 2
ax.text(-0.3, y_c, f"oftile\u2080={om}", ha="right", va="center",
fontsize=8, fontweight="bold", color="#7C3AED",
family="monospace")
for ok in range(K // tk):
x_c = ok * tk + tk / 2
ax.text(x_c, M + 0.15, f"oftile\u2081={ok}", ha="center",
va="bottom", fontsize=8, fontweight="bold", color="#7C3AED",
family="monospace")
ax.set_xlim(-2.5, K + 0.5)
ax.set_ylim(-1.8, M + 0.8)
ax.axis("off")
ax.set_title(title, fontsize=10.5, fontweight="bold", pad=10)
ax.text(K / 2, -0.3, subtitle, ha="center", va="top", fontsize=8,
color="#6B7280", family="monospace", linespacing=1.6)

# ── Panel 1: linear index ────────────────────────────────────
_draw_grid(
axes[0],
lambda r, c: str(r * K + c),
"Linear index i",
"row = i // 8, col = i % 8\n"
"intile = (row % 2, col % 4)\n"
"oftile = (row // 2, col // 4)",
)

# ── Panel 2: 2D index ────────────────────────────────────────
_draw_grid(
axes[1],
lambda r, c: f"{r},{c}",
"2D index (row, col)",
"intile = (row % 2, col % 4)\n"
"oftile = (row // 2, col // 4)",
)

# ── Panel 3: layout algebra ──────────────────────────────────
_draw_grid(
axes[2],
lambda r, c: f"{r % tm},{c % tk}",
"logical_divide((4,8):(8,1), (2,4))",
"result: ((2,2),(4,2)) : ((8,16),(1,4))\n"
"mode 0: (intile\u2080, oftile\u2080)\n"
"mode 1: (intile\u2081, oftile\u2081)",
)
axes[2].text(K / 2, -1.45, "cells show (intile\u2080, intile\u2081)",
ha="center", va="top", fontsize=9, fontweight="bold",
color="#2563EB", family="monospace")

plt.tight_layout()
fig.savefig(path, dpi=150, bbox_inches="tight")
plt.close(fig)


def main():
IMAGES.mkdir(exist_ok=True)

Expand Down Expand Up @@ -150,19 +252,8 @@ def main():
title="SM80 16x8x16 C \u2014 2x2 atoms",
)

# -- show_layout (no title) --
draw_layout(
layout_8x8,
IMAGES / "show_layout.png",
colorize=True,
)

# -- show_swizzle (no colorize) --
draw_swizzle(
layout_8x8,
Swizzle(3, 0, 3),
IMAGES / "show_swizzle.png",
)
# -- intile / oftile (applications.ipynb §3.3.5) --
_generate_intile_oftile(IMAGES / "intile_oftile.png")

print(f"Generated {len(list(IMAGES.glob('*.png')))} figures in {IMAGES}")

Expand Down
Binary file modified docs/images/draw_composite.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/hierarchical.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/intile_oftile.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 1 addition & 10 deletions docs/layout_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,7 @@ swizzled = compose(Swizzle(3, 0, 3), Layout((8, 8), (8, 1)))

## Tensor

`Tensor(layout, offset=0)` combines a Layout with a base offset (the
pointer equivalent from CuTe C++). Supports slicing:

```python
t = Tensor(Layout((4, 8), (8, 1)))
t(2, 5) # 21 — same as layout(2, 5)
t[2, :] # Tensor(8:1, offset=16) — row 2
t[:, 5] # Tensor(4:8, offset=5) — column 5
t[2, 5] # 21 — fix all modes, returns int
```
See [`docs/tensor_api.md`](tensor_api.md).

## Tile

Expand Down
Loading
Loading