Skip to content

Commit eac8afa

Browse files
committed
UI improvements
1 parent 45eb054 commit eac8afa

3 files changed

Lines changed: 16 additions & 19 deletions

File tree

demo_linear_layout.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
COLOR_AXES = {"warp": "H", "thread": "S", "reg": "L"}
66
COLOR_RANGES = {
77
"H": (0.0, 0.8),
8-
#"S": (0, 1),
98
"S": (0, 0),
10-
#"L": (1.0, 0.25),
9+
#"S": (0.25, 1.0),
1110
"L": (0, 1.0),
1211
}
1312

@@ -23,10 +22,6 @@
2322
["x", "y"],
2423
),
2524
),
26-
'''
27-
T0T1T2T3T4 R0R1R2 -perm->
28-
R1T0T1T2 R0T3T4R2
29-
'''
3025
"mma": (
3126
"MMA A Layout (m16n8k16)",
3227
LinearLayout.from_bases(

linear_layout_viz.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,17 +161,16 @@ def _infer_output_dims(
161161
input_dims: list[tuple[str, Any]],
162162
output_names: list[str],
163163
) -> list[tuple[str, int]]:
164-
"""Infer output sizes from the basis vectors, matching the browser path."""
164+
"""Infer power-of-two output sizes from the highest bit used on each axis."""
165165

166-
input_shape = tuple(1 << len(bases) for _dim_name, bases in input_dims)
167-
maxima = [0] * len(output_names)
168-
for input_coord in np.ndindex(input_shape):
169-
output_coord = _map_linear_layout_coord(input_coord, input_dims, len(output_names))
170-
for axis, value in enumerate(output_coord):
171-
maxima[axis] = max(maxima[axis], value)
166+
sizes = [1] * len(output_names)
167+
for _dim_name, bases in input_dims:
168+
for basis in bases:
169+
for axis, value in enumerate(basis[: len(output_names)]):
170+
sizes[axis] = max(sizes[axis], 1 if int(value) <= 0 else 1 << int(value).bit_length())
172171
return [
173-
(dim_name, dim_max + 1)
174-
for dim_name, dim_max in zip(output_names, maxima, strict=True)
172+
(dim_name, dim_size)
173+
for dim_name, dim_size in zip(output_names, sizes, strict=True)
175174
]
176175

177176

@@ -405,12 +404,12 @@ def create_layout_session_data(
405404

406405
session_data = create_session_data(
407406
{
408-
"Hardware tensor": hardware_tensor,
409-
"Logical tensor": logical_tensor,
407+
"Hardware Layout": hardware_tensor,
408+
"Logical Layout": logical_tensor,
410409
},
411410
name=name or "Layout",
412411
labels={
413-
"Hardware tensor": _viewer_axis_labels(hardware_names),
412+
"Hardware Layout": _viewer_axis_labels(hardware_names),
414413
},
415414
color_instructions={
416415
"tensor-1": [
@@ -422,6 +421,9 @@ def create_layout_session_data(
422421
},
423422
)
424423
manifest = json.loads(session_data.manifest_bytes)
424+
logical_marker_coords = np.argwhere(logical_tensor < 0).tolist()
425+
if logical_marker_coords:
426+
manifest["tabs"][0]["tensors"][1]["markerCoords"] = logical_marker_coords
425427
manifest["tabs"][0]["viewer"]["dimensionMappingScheme"] = "contiguous"
426428
manifest["tabs"][0]["viewer"]["linearLayoutState"] = linear_layout_state
427429
manifest["tabs"][0]["viewer"]["linearLayoutSpec"] = linear_layout_spec

0 commit comments

Comments
 (0)