Merged
24 commits
d7a439c
Add vertical arrangement to draw_swizzle for wide layouts
jduprat Apr 7, 2026
bbb0407
Add im2col figure and clarify CONV→GEMM mapping
jduprat Apr 7, 2026
ccf637d
Fix logical_divide to support Layout tilers in by-mode tuples
jduprat Apr 8, 2026
d45da8e
Fix compose to truncate unreachable modes before divisibility check
jduprat Apr 8, 2026
7760f7f
Fix left_inverse for non-contiguous (padded) layouts
jduprat Apr 8, 2026
0358750
Add paper examples test suite (arXiv:2603.02298v1)
jduprat Apr 8, 2026
20e2d7a
Fix Tensor storage validation for offsets and negative strides
jduprat Apr 8, 2026
37329cd
Support negative strides in Layout and Tensor.view()
jduprat Apr 9, 2026
f9b8f6b
Handle negative strides in analysis functions
jduprat Apr 9, 2026
c478b38
Rebase negative offsets in visualization functions
jduprat Apr 9, 2026
d20cd53
Reject free coordinates in Tensor.__setitem__
jduprat Apr 9, 2026
2917ecb
Fix compose and logical_divide for nested tuple tilers
jduprat Apr 9, 2026
60e2216
Fix divide variants to preserve Layout tiler strides
jduprat Apr 9, 2026
3653cc4
Fix explain(logical_product) to use cosize(B) for complement bound
jduprat Apr 9, 2026
6a0ce0e
Fix explain(compose) crash on tuple tilers
jduprat Apr 9, 2026
b224e54
Configure Ruff and fix lint warnings across the codebase
jduprat Apr 9, 2026
881b918
Fix duplicate test name shadowing draw_swizzle coverage
jduprat Apr 9, 2026
e12c156
Canonicalize stride to 0 for unit-extent modes in logical_divide
jduprat Apr 9, 2026
c287e8c
Add CuTe C++ oracle tests for compose and logical_divide regressions
jduprat Apr 9, 2026
dc14e9c
Preserve swizzle attribute in slice_and_offset sublayout results
jduprat Apr 9, 2026
699349f
Handle bare None as a full-slice identity in Layout.__call__
jduprat Apr 9, 2026
891bddd
Support tensor[:] as a whole-view full slice on Tensor
jduprat Apr 9, 2026
ef43526
Document shape_div's strict scalar divisibility policy
jduprat Apr 9, 2026
5d5a6c8
Move exhaustive introspection helpers from layouts.py to analysis.py
jduprat Apr 9, 2026
16 changes: 15 additions & 1 deletion README.md
@@ -144,14 +144,28 @@ pip install -e ".[test]"
```bash
pytest tests/
```

For local linting, install the dev extras and run Ruff on the Python sources:

```bash
pip install -e ".[dev]"
ruff check src/ tests/ examples/
```

The default Ruff configuration excludes `*.ipynb`; notebooks are treated as
worked material rather than part of the Python lint surface.

Oracle tests cross-validate against vendor reference implementations and are
skipped automatically if the corresponding tool is unavailable:

```bash
# NVIDIA (cross-validation against pycute)
# NVIDIA pycute oracle
pip install -e ".[test,oracle-nv]"
pytest tests/oracle_nv.py

# Direct CuTe C++ oracle
# Requires a C++ compiler plus CUTLASS/CUDA headers in the active environment.
pytest tests/oracle_cute_cpp.py

# AMD (cross-validation against amd_matrix_instruction_calculator)
pip install -e ".[test,oracle-amd]"
pytest tests/oracle_amd.py
```
66 changes: 66 additions & 0 deletions docs/analysis_api.md
@@ -44,13 +44,53 @@ bijective layouts, and trace the algebra step by step.

```python
from tensor_layouts.analysis import (
image, is_injective, is_surjective, is_bijective,
offset_table, bank_conflicts, coalescing_efficiency,
cycles, fixed_points, order, explain,
)
```

---

## Image and Injectivity

These functions analyze a layout viewed as a function from coordinates to
memory offsets. They enumerate all coordinates, so their cost is O(size).

| Function | Description |
|----------|-------------|
| `image(L)` | Sorted list of distinct offsets produced |
| `is_injective(L)` | True if no two coordinates share an offset |
| `is_surjective(L, codomain_size=None)` | True if every offset in `[0, codomain_size)` is produced |
| `is_bijective(L)` | True if both injective and surjective (a permutation) |
| `is_contiguous(L)` | Alias for `is_bijective` — reads as "one dense block?" |

```python
layout = Layout((4, 8), (1, 4))
image(layout) # [0, 1, 2, ..., 31]
is_bijective(layout) # True

broadcast = Layout((4, 8), (0, 1))
image(broadcast) # [0, 1, 2, 3, 4, 5, 6, 7]
is_injective(broadcast) # False (stride-0 causes aliasing)
```
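
As a rough illustration of what these checks compute, the sketch below brute-forces the image and injectivity of a flat (non-nested) layout in plain Python. The helper names `offsets`, `image_sketch`, and `is_injective_sketch` are hypothetical and not part of `tensor_layouts`; the library's own implementation may differ.

```python
from itertools import product


def offsets(shape, stride):
    """Offset of every coordinate of a flat layout (hypothetical helper)."""
    return [
        sum(c * d for c, d in zip(coord, stride))
        for coord in product(*(range(extent) for extent in shape))
    ]


def image_sketch(shape, stride):
    """Sorted list of distinct offsets."""
    return sorted(set(offsets(shape, stride)))


def is_injective_sketch(shape, stride):
    """True if no two coordinates share an offset."""
    offs = offsets(shape, stride)
    return len(set(offs)) == len(offs)


print(image_sketch((4, 8), (1, 4)) == list(range(32)))  # True
print(is_injective_sketch((4, 8), (1, 4)))              # True
print(image_sketch((4, 8), (0, 1)))                     # [0, 1, 2, 3, 4, 5, 6, 7]
print(is_injective_sketch((4, 8), (0, 1)))              # False
```

Both helpers visit every coordinate once, which is where the O(size) cost comes from.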

## Functional Equivalence

`functionally_equal(a, b)` returns True if two layouts compute the same
offset for every flat index, even when they have different shapes or strides.
This is useful for verifying that algebraic transformations like `coalesce()`
and `flatten()` preserve behavior. Cost is O(size).

```python
L = Layout(((2, 2), (2, 4)), ((1, 4), (2, 8)))
coalesce(L) == L # False (structurally different)
functionally_equal(L, coalesce(L)) # True (same mapping)
functionally_equal(L, flatten(L)) # True
```
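
The comparison can be pictured as a loop over flat indices. The sketch below is a plain-Python approximation, not the library's code: layouts are passed as `(shape, stride)` pairs, flat indices are unpacked with the leftmost mode varying fastest, and layouts of different sizes are treated as unequal (an assumption). All helper names are hypothetical.

```python
from math import prod


def flatten_nested(x):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(x, int):
        return (x,)
    return tuple(v for item in x for v in flatten_nested(item))


def offset_at(flat_index, shape, stride):
    """Offset of a flat index, unpacking coordinates leftmost-mode-fastest."""
    offset = 0
    for extent, step in zip(flatten_nested(shape), flatten_nested(stride)):
        flat_index, coord = divmod(flat_index, extent)
        offset += coord * step
    return offset


def functionally_equal_sketch(a, b):
    """True if both layouts map every flat index to the same offset."""
    (shape_a, stride_a), (shape_b, stride_b) = a, b
    n = prod(flatten_nested(shape_a))
    if n != prod(flatten_nested(shape_b)):
        return False  # assumption: different sizes never compare equal
    return all(
        offset_at(i, shape_a, stride_a) == offset_at(i, shape_b, stride_b)
        for i in range(n)
    )


nested = (((2, 2), (2, 4)), ((1, 4), (2, 8)))
flat = ((2, 2, 2, 4), (1, 4, 2, 8))
print(functionally_equal_sketch(nested, flat))                  # True
print(functionally_equal_sketch(((4, 8), (1, 4)), (32, 1)))     # True
print(functionally_equal_sketch(((4, 8), (8, 1)), (32, 1)))     # False
```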

---

## offset_table(layout)

Inverse mapping: `{offset: [coord, ...]}`. Reveals aliasing --- when
@@ -289,6 +329,32 @@ explain(complement, Layout(4, 2), 16)
```python
# image(complement) = [0, 1, 8, 9]
```

`explain(compose, ...)` also handles tuple tilers directly and shows the
mode-by-mode decomposition CuTe uses:

```python
explain(compose, Layout((4, 8), (8, 1)), (2, 4))
# compose((4, 8) : (8, 1), (2, 4))
# For tuple tilers, composition is applied mode-by-mode.
#
# A = (4, 8) : (8, 1)
# B = (2, 4)
# result = (2, 4) : (8, 1)
# mode 0: compose(4 : 8, 2 : 1) = 2 : 8
# mode 1: compose(8 : 1, 4 : 1) = 4 : 1
```
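
The per-mode results in that trace can be checked by evaluating the composition directly. The sketch below handles only the simple case shown here, one-dimensional modes given as `(extent, stride)` pairs with each integer tiler entry `n` read as `n : 1`. The helpers `compose_1d` and `compose_by_mode` are hypothetical and do not reproduce CuTe's general algorithm.

```python
def compose_1d(a, b):
    """Compose two 1-D layouts by evaluating A(B(i)) for each i in B."""
    (a_extent, a_stride), (b_extent, b_stride) = a, b
    offsets = []
    for i in range(b_extent):
        j = i * b_stride                 # B(i): a coordinate into A
        if j >= a_extent:
            raise ValueError("B maps outside the domain of A")
        offsets.append(j * a_stride)     # A(B(i))
    # For 1-D inputs the offsets are evenly spaced; read off the stride.
    step = offsets[1] - offsets[0] if len(offsets) > 1 else 0
    return (b_extent, step)


def compose_by_mode(shape, stride, tiler):
    """Apply compose_1d to each mode, pairing mode k with tiler[k] : 1."""
    modes = [
        compose_1d((extent, step), (tile, 1))
        for extent, step, tile in zip(shape, stride, tiler)
    ]
    return tuple(m[0] for m in modes), tuple(m[1] for m in modes)


print(compose_1d((4, 8), (2, 1)))                # (2, 8)
print(compose_1d((8, 1), (4, 1)))                # (4, 1)
print(compose_by_mode((4, 8), (8, 1), (2, 4)))   # ((2, 4), (8, 1))
```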

For true `Layout` tilers, the `logical_product` explanation uses CuTe's
actual complement bound, `size(A) * cosize(B)`, rather than `size(A) * size(B)`:

```python
explain(logical_product, Layout(4, 1), Layout(3, 2))
# ...
# size(A) = 4
# cosize(B) = 5
# size(A) * cosize(B) = 20
```
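
The difference between the two bounds comes from `cosize`, which is one past the largest offset a layout produces rather than the number of coordinates. A minimal sketch, assuming flat layouts with non-negative strides and using hypothetical helpers rather than the library's own `size` and `cosize`:

```python
from math import prod


def size(shape):
    """Number of coordinates in a flat layout."""
    return prod(shape)


def cosize(shape, stride):
    """One past the largest offset, assuming all strides are non-negative."""
    return 1 + sum((extent - 1) * step for extent, step in zip(shape, stride))


a_shape, a_stride = (4,), (1,)   # A = 4 : 1
b_shape, b_stride = (3,), (2,)   # B = 3 : 2, offsets 0, 2, 4

print(cosize(b_shape, b_stride))                  # 5
print(size(a_shape) * cosize(b_shape, b_stride))  # 20
print(size(a_shape) * size(b_shape))              # 12
```

Because `B` is strided, its offsets reach 4 even though it has only 3 elements, so the `size(B)`-based bound of 12 is smaller than the `cosize(B)`-based bound of 20.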

## F2 Linear Layout Matrix

`to_F2_matrix(layout)` converts a layout with power-of-2 shapes to its
148 changes: 148 additions & 0 deletions docs/generate_figures.py
@@ -151,6 +151,151 @@ def _draw_grid(ax, cell_text_fn, title, subtitle):
    plt.close(fig)


def _generate_im2col(
    path: Path,
    H: int = 4,
    W: int = 4,
    R: int = 2,
    S: int = 2,
) -> None:
    """im2col diagram: input matrix on the left, unrolled output on the right.

    Shows how a sliding R×S window over an H×W input produces a (P*Q)×(R*S)
    matrix where each row is one flattened window. Color-coding links each
    window position to its row in the output.
    """
    import colorsys
    from collections import defaultdict

    P, Q = H - R + 1, W - S + 1
    n_windows = P * Q
    n_taps = R * S
    labels = [chr(ord("A") + i) for i in range(H * W)]

    # -- im2col index matrix: each row lists the H*W indices for one window --
    im2col_rows = []
    for p in range(P):
        for q in range(Q):
            im2col_rows.append(
                [(p + r) * W + (q + s) for r in range(R) for s in range(S)]
            )

    # -- color palette: one pastel color per window --
    window_colors_rgb = []
    for i in range(n_windows):
        hue = i / n_windows
        r, g, b = colorsys.hsv_to_rgb(hue, 0.30, 0.95)
        window_colors_rgb.append((r, g, b))

    def rgb_to_hex(rgb):
        return f"#{int(rgb[0]*255):02X}{int(rgb[1]*255):02X}{int(rgb[2]*255):02X}"

    window_colors = [rgb_to_hex(c) for c in window_colors_rgb]

    # -- per-input-cell color: blend all windows that cover each cell --
    cell_windows = defaultdict(list)
    for win_idx, indices in enumerate(im2col_rows):
        for idx in indices:
            cell_windows[idx].append(win_idx)

    def blend(win_indices):
        rgbs = [window_colors_rgb[i] for i in win_indices]
        avg = tuple(sum(c[k] for c in rgbs) / len(rgbs) for k in range(3))
        return rgb_to_hex(avg)

    # -- figure --
    fig, axes = plt.subplots(
        1, 2, figsize=(4 + n_taps * 1.2, max(H, n_windows) * 0.8 + 1.5),
        gridspec_kw={"width_ratios": [W, n_taps * 1.1]},
    )

    # ── Left panel: input grid ──────────────────────────────────────
    ax = axes[0]
    for row in range(H):
        for col in range(W):
            idx = row * W + col
            y = H - 1 - row
            color = blend(cell_windows[idx])
            rect = patches.Rectangle(
                (col, y), 1, 1,
                facecolor=color, edgecolor="#D1D5DB", linewidth=0.5,
            )
            ax.add_patch(rect)
            ax.text(
                col + 0.5, y + 0.5, labels[idx],
                ha="center", va="center", fontsize=13,
                color="#374151", family="monospace", fontweight="bold",
            )
    # thick outer border
    rect = patches.Rectangle(
        (0, 0), W, H,
        facecolor="none", edgecolor="#1F2937", linewidth=2.5,
    )
    ax.add_patch(rect)
    ax.set_xlim(-0.5, W + 0.5)
    ax.set_ylim(-0.8, H + 0.8)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_title(f"Input ({H}\u00d7{W})", fontsize=11, fontweight="bold", pad=10)

    # ── Right panel: im2col output grid ─────────────────────────────
    ax = axes[1]
    for row_idx in range(n_windows):
        y = n_windows - 1 - row_idx
        for col_idx in range(n_taps):
            cell_label = labels[im2col_rows[row_idx][col_idx]]
            rect = patches.Rectangle(
                (col_idx, y), 1, 1,
                facecolor=window_colors[row_idx],
                edgecolor="#D1D5DB", linewidth=0.5,
            )
            ax.add_patch(rect)
            ax.text(
                col_idx + 0.5, y + 0.5, cell_label,
                ha="center", va="center", fontsize=12,
                color="#374151", family="monospace", fontweight="bold",
            )
        # row label: window position (p, q)
        p, q = divmod(row_idx, Q)
        ax.text(
            -0.2, y + 0.5, f"({p},{q})",
            ha="right", va="center", fontsize=8,
            color="#6B7280", family="monospace",
        )
    # thick outer border
    rect = patches.Rectangle(
        (0, 0), n_taps, n_windows,
        facecolor="none", edgecolor="#1F2937", linewidth=2.5,
    )
    ax.add_patch(rect)
    ax.set_xlim(-1.5, n_taps + 0.5)
    ax.set_ylim(-0.5, n_windows + 0.8)
    ax.set_aspect("equal")
    ax.axis("off")
    ax.set_title(
        f"im2col output ({n_windows}\u00d7{n_taps})",
        fontsize=11, fontweight="bold", pad=10,
    )

    # ── Arrow connecting the panels ─────────────────────────────────
    arrow = patches.FancyArrowPatch(
        (0.44, 0.5), (0.52, 0.5),
        transform=fig.transFigure,
        arrowstyle="->,head_width=6,head_length=5",
        color="#374151", linewidth=2,
    )
    fig.patches.append(arrow)
    fig.text(
        0.48, 0.54, f"im2col({R}\u00d7{S})",
        ha="center", va="bottom", fontsize=11, fontweight="bold",
        color="#374151", family="monospace", transform=fig.transFigure,
    )

    plt.tight_layout()
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def main():
    IMAGES.mkdir(exist_ok=True)

@@ -255,6 +400,9 @@ def main():
    # -- intile / oftile (applications.ipynb §3.3.5) --
    _generate_intile_oftile(IMAGES / "intile_oftile.png")

    # -- im2col (algorithms.ipynb §CONV) --
    _generate_im2col(IMAGES / "im2col.png")

    print(f"Generated {len(list(IMAGES.glob('*.png')))} figures in {IMAGES}")


Binary file added docs/images/im2col.png
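
Since the image itself is not shown here, the index matrix that `_generate_im2col` draws can be reproduced without Matplotlib. The sketch below lifts the same indexing expression into a standalone helper; the name `im2col_indices` is hypothetical and does not appear in the PR.

```python
def im2col_indices(H=4, W=4, R=2, S=2):
    """Flat input indices covered by each R x S window of an H x W input.

    Returns a (P*Q) x (R*S) list of lists, one row per window position
    (p, q), with P = H - R + 1 and Q = W - S + 1 (stride 1, no padding).
    """
    P, Q = H - R + 1, W - S + 1
    return [
        [(p + r) * W + (q + s) for r in range(R) for s in range(S)]
        for p in range(P)
        for q in range(Q)
    ]


rows = im2col_indices()
labels = [chr(ord("A") + i) for i in range(16)]  # same A..P labels as the figure

print(len(rows), len(rows[0]))               # 9 4
print(rows[0])                               # [0, 1, 4, 5]
print("".join(labels[i] for i in rows[0]))   # ABEF
print("".join(labels[i] for i in rows[-1]))  # KLOP
```

With the default 4×4 input and 2×2 window, the output has 9 rows of 4 taps each, matching the `(P*Q)×(R*S)` shape described in the docstring.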