From 544f231f66c6ff6bfa31b45d25a1e04ef7778066 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Fri, 6 Mar 2026 15:35:36 +0000 Subject: [PATCH 01/79] feat: Add parallelism analysis for IREE/Baspacho test cases Analyzes Jacobian sparsity patterns to report exploitable parallelism: - Elimination tree with level-set parallelism metrics - Supernodal detection for BLAS-3 opportunities - Fill-in analysis via SuperLU (MMD and COLAMD orderings) - RCM bandwidth reduction - Device scatter pattern analysis (vmap fan-in conflicts) - Pattern stability verification across NR iterations Supports two modes: - Benchmark mode: runs simulation and captures matrices - File mode: analyzes existing Matrix Market files Outputs JSON, human-readable summary, level-set CSV, and etree .npy. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/analyze_parallelism.py | 1048 ++++++++++++++++++++++++++++++++ 1 file changed, 1048 insertions(+) create mode 100644 scripts/analyze_parallelism.py diff --git a/scripts/analyze_parallelism.py b/scripts/analyze_parallelism.py new file mode 100644 index 00000000..4329d204 --- /dev/null +++ b/scripts/analyze_parallelism.py @@ -0,0 +1,1048 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [] +# /// +"""Analyze parallelism opportunities in VAJAX simulation matrices. + +For IREE & Baspacho test case context: given a circuit's Jacobian sparsity +pattern, reports what parallelism can be exploited during factorization, +assembly, and device evaluation. + +Key outputs: +- Elimination tree: dependency structure for sparse factorization +- Level-set parallelism: how many columns can be processed simultaneously +- Supernodal structure: dense blocks exploitable with BLAS-3 +- Fill-in analysis: memory requirements for factorization +- Device evaluation parallelism: scatter pattern from vmap'd device evals +- Pattern stability: sparsity is fixed across all NR iterations + +Usage: + # Analyze from a benchmark (captures matrices + device info) + JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py ring + JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py c6288 + + # Analyze existing Matrix Market file + uv run scripts/analyze_parallelism.py --from-mtx path/to/jacobian_0000.mtx + + # Output to specific directory + JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py ring --output-dir /tmp/ring_par +""" + +import argparse +import json +import os +import sys +from collections import Counter +from pathlib import Path + +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +import numpy as np +import scipy.io +import scipy.sparse as sp +import scipy.sparse.linalg + + +# --------------------------------------------------------------------------- +# Elimination tree +# --------------------------------------------------------------------------- + + +def symmetrize_pattern(A: sp.spmatrix) -> sp.csc_matrix: + """Compute |A| + |A^T| as a binary pattern (no values, just structure).""" + A_csc = sp.csc_matrix(A) + # Binary pattern: set all values to 1 + A_bin = sp.csc_matrix( + (np.ones(A_csc.nnz), A_csc.indices, A_csc.indptr), shape=A_csc.shape + ) + A_sym = A_bin + A_bin.T + # Re-binarize (eliminates any 2s from diagonal overlap) + A_sym.data[:] = 1.0 + A_sym.eliminate_zeros() + return sp.csc_matrix(A_sym) + + +def compute_etree(A_csc: sp.csc_matrix) -> np.ndarray: + """Compute elimination tree of a symmetric matrix. + + Uses Liu's algorithm with path compression (union-find). + Only the upper triangle is used. + + Args: + A_csc: Symmetric matrix in CSC format + + Returns: + parent array where parent[i] is the parent of column i, + or -1 for root(s) + """ + n = A_csc.shape[0] + parent = np.full(n, -1, dtype=np.int64) + ancestor = np.arange(n, dtype=np.int64) + + indptr = A_csc.indptr + indices = A_csc.indices + + for k in range(n): + for ptr in range(indptr[k], indptr[k + 1]): + i = indices[ptr] + if i >= k: + continue + # Find root of i with path compression + r = i + while ancestor[r] != r: + r = ancestor[r] + if r != k: + parent[r] = k + ancestor[r] = k + # Path compression for i + r = i + while ancestor[r] != k: + t = ancestor[r] + ancestor[r] = k + r = t + + return parent + + +def compute_level_sets(parent: np.ndarray) -> list[list[int]]: + """Compute level sets from an elimination tree. + + Level 0 = leaves, higher levels = closer to root. + Columns at the same level have no dependencies and can be + processed in parallel. + + Returns: + List of levels, where levels[d] = list of column indices at depth d + (depth measured from leaves, so leaves are at depth 0) + """ + n = len(parent) + # Compute depth from root first + depth_from_root = np.full(n, -1, dtype=np.int64) + + # Find roots + roots = np.where(parent == -1)[0] + for r in roots: + depth_from_root[r] = 0 + + # BFS from roots to compute depth_from_root + # Build children lists for top-down traversal + children = [[] for _ in range(n)] + for i in range(n): + if parent[i] != -1: + children[parent[i]].append(i) + + queue = list(roots) + head = 0 + while head < len(queue): + node = queue[head] + head += 1 + for child in children[node]: + depth_from_root[child] = depth_from_root[node] + 1 + queue.append(child) + + max_depth = int(np.max(depth_from_root)) if n > 0 else 0 + + # Convert to bottom-up levels (leaves = 0) + depth_from_leaves = max_depth - depth_from_root + + levels: list[list[int]] = [[] for _ in range(max_depth + 1)] + for i in range(n): + levels[depth_from_leaves[i]].append(i) + + return levels + + +def compute_etree_stats(parent: np.ndarray, levels: list[list[int]]) -> dict: + """Compute statistics about elimination tree parallelism.""" + n = len(parent) + widths = [len(level) for level in levels] + height = len(levels) + + # Count leaves (nodes with no children) + has_child = np.zeros(n, dtype=bool) + for i in range(n): + if parent[i] != -1: + has_child[parent[i]] = True + n_leaves = int(np.sum(~has_child)) + + # Subtree sizes + subtree_size = np.ones(n, dtype=np.int64) + # Process bottom-up: levels[0] are leaves + for level in levels: + for node in level: + if parent[node] != -1: + subtree_size[parent[node]] += subtree_size[node] + + return { + "height": height, + "n_leaves": n_leaves, + "max_parallelism": max(widths) if widths else 0, + "avg_parallelism": float(np.mean(widths)) if widths else 0, + "min_parallelism": min(widths) if widths else 0, + "level_widths": widths, + "subtree_size_stats": { + "min": int(np.min(subtree_size)) if n > 0 else 0, + "max": int(np.max(subtree_size)) if n > 0 else 0, + "mean": float(np.mean(subtree_size)) if n > 0 else 0, + "median": float(np.median(subtree_size)) if n > 0 else 0, + }, + } + + +# --------------------------------------------------------------------------- +# Supernodal detection +# --------------------------------------------------------------------------- + + +def detect_supernodes(parent: np.ndarray, A_csc: sp.csc_matrix) -> list[list[int]]: + """Detect fundamental supernodes in the elimination tree. + + A fundamental supernode is a maximal chain of consecutive columns + j, j+1, ..., j+k where: + - parent[j] = j+1, parent[j+1] = j+2, ..., parent[j+k-1] = j+k + - The columns have nested sparsity patterns (each is a subset of the next) + + These can be factored as dense blocks using BLAS-3 operations. + """ + n = A_csc.shape[0] + if n == 0: + return [] + + # Count children per node + n_children = np.zeros(n, dtype=np.int64) + for i in range(n): + if parent[i] != -1: + n_children[parent[i]] += 1 + + # A node starts a new supernode if: + # - It has more than one child, OR + # - It is not the only child of its parent, OR + # - parent[i] != i + 1 + is_supernode_start = np.ones(n, dtype=bool) + for i in range(n - 1): + if parent[i] == i + 1 and n_children[i + 1] == 1: + is_supernode_start[i + 1] = False + + supernodes: list[list[int]] = [] + current: list[int] = [] + for i in range(n): + if is_supernode_start[i]: + if current: + supernodes.append(current) + current = [i] + else: + current.append(i) + if current: + supernodes.append(current) + + return supernodes + + +def supernode_stats(supernodes: list[list[int]]) -> dict: + """Compute statistics about supernodal structure.""" + sizes = [len(s) for s in supernodes] + size_counts = Counter(sizes) + + # Bucket into histogram ranges + histogram = {} + for size, count in sorted(size_counts.items()): + if size == 1: + histogram["1"] = histogram.get("1", 0) + count + elif size <= 4: + histogram["2-4"] = histogram.get("2-4", 0) + count + elif size <= 10: + histogram["5-10"] = histogram.get("5-10", 0) + count + elif size <= 50: + histogram["11-50"] = histogram.get("11-50", 0) + count + else: + histogram["51+"] = histogram.get("51+", 0) + count + + return { + "count": len(supernodes), + "largest": max(sizes) if sizes else 0, + "mean_size": float(np.mean(sizes)) if sizes else 0, + "median_size": float(np.median(sizes)) if sizes else 0, + "size_histogram": histogram, + } + + +# --------------------------------------------------------------------------- +# Fill-in analysis +# --------------------------------------------------------------------------- + + +def fill_in_analysis(A: sp.spmatrix) -> dict: + """Analyze fill-in from LU factorization using scipy's SuperLU. + + Uses MMD_AT_PLUS_A ordering for fill-reducing permutation. + """ + A_csc = sp.csc_matrix(A, dtype=np.float64) + n = A_csc.shape[0] + original_nnz = A_csc.nnz + + results = {} + + for ordering_name, permc_spec in [ + ("MMD_AT_PLUS_A", "MMD_AT_PLUS_A"), + ("COLAMD", "COLAMD"), + ]: + try: + lu = scipy.sparse.linalg.splu( + A_csc, + permc_spec=permc_spec, + options={"SymmetricMode": False}, + ) + l_nnz = lu.L.nnz + u_nnz = lu.U.nnz + factor_nnz = l_nnz + u_nnz - n # subtract diagonal counted twice + + results[ordering_name] = { + "L_nnz": l_nnz, + "U_nnz": u_nnz, + "factor_nnz": factor_nnz, + "fill_ratio": factor_nnz / max(original_nnz, 1), + "fill_in": factor_nnz - original_nnz, + } + except Exception as e: + results[ordering_name] = {"error": str(e)} + + return { + "original_nnz": original_nnz, + "orderings": results, + "best_ordering": min( + (k for k, v in results.items() if "error" not in v), + key=lambda k: results[k]["factor_nnz"], + default=None, + ), + } + + +# --------------------------------------------------------------------------- +# Matrix structure analysis +# --------------------------------------------------------------------------- + + +def matrix_structure_analysis(A: sp.spmatrix) -> dict: + """Analyze structural properties of the matrix.""" + A_csc = sp.csc_matrix(A) + A_csr = sp.csr_matrix(A) + n = A_csc.shape[0] + + # Bandwidth + rows, cols = A_csc.nonzero() + if len(rows) > 0: + bandwidth = int(np.max(np.abs(rows - cols))) + profile = int(np.sum(np.abs(rows - cols))) + else: + bandwidth = 0 + profile = 0 + + # Degree distribution (treating matrix as adjacency matrix) + row_nnz = np.diff(A_csr.indptr) + col_nnz = np.diff(A_csc.indptr) + + # Symmetry check + A_T = A_csc.T + sym_diff = A_csc - A_T + sym_diff.eliminate_zeros() + is_structurally_symmetric = sym_diff.nnz == 0 + + # Check for numerical symmetry + if is_structurally_symmetric: + val_diff = np.max(np.abs(A_csc.data - A_T.tocsc().data)) if A_csc.nnz > 0 else 0 + is_numerically_symmetric = val_diff < 1e-10 + else: + is_numerically_symmetric = False + + # Connected components (treating as undirected graph) + A_sym_pattern = symmetrize_pattern(A_csc) + n_components, labels = sp.csgraph.connected_components(A_sym_pattern, directed=False) + + component_sizes = Counter(labels.tolist()) + component_size_list = sorted(component_sizes.values(), reverse=True) + + # Diagonal dominance check + diag = np.abs(A_csc.diagonal()) + row_sums = np.array(np.abs(A_csr).sum(axis=1)).ravel() + off_diag_sums = row_sums - diag + diag_dominant_rows = int(np.sum(diag >= off_diag_sums)) + + return { + "size": n, + "nnz": A_csc.nnz, + "density_pct": A_csc.nnz / (n * n) * 100 if n > 0 else 0, + "bandwidth": bandwidth, + "profile": profile, + "is_structurally_symmetric": is_structurally_symmetric, + "is_numerically_symmetric": is_numerically_symmetric, + "connected_components": n_components, + "component_sizes": component_size_list[:10], # Top 10 + "diagonal_dominance": { + "dominant_rows": diag_dominant_rows, + "total_rows": n, + "pct": diag_dominant_rows / n * 100 if n > 0 else 0, + }, + "degree_stats": { + "row_min": int(np.min(row_nnz)) if n > 0 else 0, + "row_max": int(np.max(row_nnz)) if n > 0 else 0, + "row_mean": float(np.mean(row_nnz)) if n > 0 else 0, + "col_min": int(np.min(col_nnz)) if n > 0 else 0, + "col_max": int(np.max(col_nnz)) if n > 0 else 0, + "col_mean": float(np.mean(col_nnz)) if n > 0 else 0, + }, + } + + +# --------------------------------------------------------------------------- +# RCM ordering analysis +# --------------------------------------------------------------------------- + + +def rcm_analysis(A: sp.spmatrix) -> dict: + """Analyze effect of Reverse Cuthill-McKee ordering.""" + A_sym = symmetrize_pattern(A) + n = A_sym.shape[0] + + try: + perm = sp.csgraph.reverse_cuthill_mckee(A_sym, symmetric_mode=True) + A_rcm = A_sym[perm][:, perm] + + rows_orig, cols_orig = A_sym.nonzero() + rows_rcm, cols_rcm = A_rcm.nonzero() + + bw_orig = int(np.max(np.abs(rows_orig - cols_orig))) if len(rows_orig) > 0 else 0 + bw_rcm = int(np.max(np.abs(rows_rcm - cols_rcm))) if len(rows_rcm) > 0 else 0 + + return { + "bandwidth_original": bw_orig, + "bandwidth_rcm": bw_rcm, + "bandwidth_reduction_pct": (1 - bw_rcm / max(bw_orig, 1)) * 100, + "permutation_available": True, + } + except Exception as e: + return {"error": str(e), "permutation_available": False} + + +# --------------------------------------------------------------------------- +# Device scatter pattern analysis +# --------------------------------------------------------------------------- + + +def analyze_device_scatter(engine) -> dict: + """Analyze device-to-matrix scatter patterns for assembly parallelism. + + Examines the stamp index mappings to determine: + - How many matrix positions are written by multiple devices (conflicts) + - Maximum fan-in to any single position + - Independence structure between device evaluations + """ + setup = engine._build_transient_setup(backend="cpu", use_dense=True) + static_inputs_cache = setup["static_inputs_cache"] + openvaf_by_type = setup["openvaf_by_type"] + n_unknowns = setup["n_unknowns"] + + model_info = {} + # Global position → set of unique (model_type, device_idx) writers + global_position_writers: dict[tuple[int, int], set[tuple[str, int]]] = {} + + for model_type, (voltage_indices, stamp_indices, *_rest) in static_inputs_cache.items(): + jac_rows = np.asarray(stamp_indices["jac_row_indices"]) + jac_cols = np.asarray(stamp_indices["jac_col_indices"]) + res_indices = np.asarray(stamp_indices["res_indices"]) + + n_devices = jac_rows.shape[0] + n_jac_entries = jac_rows.shape[1] + n_residuals = res_indices.shape[1] + + # Count unique positions per device + positions_per_device = [] + for dev_idx in range(n_devices): + valid = (jac_rows[dev_idx] >= 0) & (jac_cols[dev_idx] >= 0) + unique_pos = set() + for j in range(n_jac_entries): + if valid[j]: + pos = (int(jac_rows[dev_idx, j]), int(jac_cols[dev_idx, j])) + unique_pos.add(pos) + if pos not in global_position_writers: + global_position_writers[pos] = set() + global_position_writers[pos].add((model_type, dev_idx)) + positions_per_device.append(len(unique_pos)) + + # Count touched nodes per device (for residual fan-out) + nodes_per_device = [] + for dev_idx in range(n_devices): + valid_nodes = set() + for r in range(n_residuals): + idx = int(res_indices[dev_idx, r]) + if idx >= 0: + valid_nodes.add(idx) + nodes_per_device.append(len(valid_nodes)) + + model_info[model_type] = { + "n_devices": n_devices, + "jac_entries_per_device": n_jac_entries, + "residuals_per_device": n_residuals, + "unique_positions_per_device": { + "min": min(positions_per_device) if positions_per_device else 0, + "max": max(positions_per_device) if positions_per_device else 0, + "mean": float(np.mean(positions_per_device)) if positions_per_device else 0, + }, + "nodes_per_device": { + "min": min(nodes_per_device) if nodes_per_device else 0, + "max": max(nodes_per_device) if nodes_per_device else 0, + "mean": float(np.mean(nodes_per_device)) if nodes_per_device else 0, + }, + } + + # Analyze scatter conflicts (unique devices per position) + fan_in_counts = [len(writers) for writers in global_position_writers.values()] + fan_in_counter = Counter(fan_in_counts) + conflict_positions = sum(1 for c in fan_in_counts if c > 1) + + # Build device conflict graph: two devices conflict if they write to + # the same matrix position + n_total_devices = sum(info["n_devices"] for info in model_info.values()) + conflict_edges = 0 + for writers in global_position_writers.values(): + n_writers = len(writers) + if n_writers > 1: + conflict_edges += n_writers * (n_writers - 1) // 2 + + return { + "n_unknowns": n_unknowns, + "total_devices": n_total_devices, + "model_types": model_info, + "scatter_conflicts": { + "total_positions": len(global_position_writers), + "conflict_positions": conflict_positions, + "conflict_pct": conflict_positions / max(len(global_position_writers), 1) * 100, + "max_fan_in": max(fan_in_counts) if fan_in_counts else 0, + "fan_in_distribution": {str(k): v for k, v in sorted(fan_in_counter.items())}, + }, + "device_conflict_graph": { + "n_nodes": n_total_devices, + "n_edges": conflict_edges, + "note": "Edges connect devices that write to the same matrix position", + }, + } + + +# --------------------------------------------------------------------------- +# Pattern stability check +# --------------------------------------------------------------------------- + + +def check_pattern_stability(matrices: list[sp.spmatrix]) -> dict: + """Verify that sparsity pattern is identical across NR iterations. + + This is a key property for IREE: the pattern is fixed, only values change, + so symbolic analysis can be compiled once and reused. + """ + if len(matrices) < 2: + return { + "is_fixed": True, + "n_samples": len(matrices), + "note": "Only one matrix available, cannot verify stability", + } + + ref = sp.csc_matrix(matrices[0]) + ref_pattern = set(zip(*ref.nonzero())) + + all_match = True + first_mismatch = None + + for idx, M in enumerate(matrices[1:], 1): + M_csc = sp.csc_matrix(M) + M_pattern = set(zip(*M_csc.nonzero())) + + if M_pattern != ref_pattern: + all_match = False + added = M_pattern - ref_pattern + removed = ref_pattern - M_pattern + first_mismatch = { + "index": idx, + "added_entries": len(added), + "removed_entries": len(removed), + } + break + + # Value variation statistics (how much do values change across iterations?) + if all_match and len(matrices) >= 2: + values = np.column_stack([sp.csc_matrix(M).data for M in matrices]) + rel_variation = np.std(values, axis=1) / (np.abs(np.mean(values, axis=1)) + 1e-30) + value_stats = { + "mean_relative_variation": float(np.mean(rel_variation)), + "max_relative_variation": float(np.max(rel_variation)), + "median_relative_variation": float(np.median(rel_variation)), + } + else: + value_stats = None + + return { + "is_fixed": all_match, + "n_samples": len(matrices), + "first_mismatch": first_mismatch, + "value_variation": value_stats, + "note": ( + "Sparsity pattern is identical across all samples — symbolic " + "factorization can be compiled once and reused for all NR iterations" + if all_match + else "WARNING: Sparsity pattern changes between iterations" + ), + } + + +# --------------------------------------------------------------------------- +# Full analysis pipeline +# --------------------------------------------------------------------------- + + +def analyze_matrix( + A: sp.spmatrix, + name: str = "", + all_matrices: list[sp.spmatrix] | None = None, +) -> dict: + """Run full parallelism analysis on a Jacobian matrix. + + Args: + A: The Jacobian matrix (any sparse format) + name: Circuit/benchmark name for labeling + all_matrices: Optional list of matrices for pattern stability check + + Returns: + Dict with all analysis results + """ + A_csc = sp.csc_matrix(A, dtype=np.float64) + n = A_csc.shape[0] + + print(f"Analyzing {n}x{n} matrix ({A_csc.nnz} nonzeros)...") + + # 1. Matrix structure + print(" Matrix structure...") + structure = matrix_structure_analysis(A_csc) + + # 2. Elimination tree on symmetrized pattern + print(" Elimination tree...") + A_sym = symmetrize_pattern(A_csc) + parent = compute_etree(A_sym) + levels = compute_level_sets(parent) + etree_stats = compute_etree_stats(parent, levels) + + # 3. Supernodes + print(" Supernodal detection...") + supernodes = detect_supernodes(parent, A_sym) + snode_stats = supernode_stats(supernodes) + + # 4. Fill-in analysis + print(" Fill-in analysis (SuperLU)...") + fill = fill_in_analysis(A_csc) + + # 5. RCM ordering + print(" RCM ordering...") + rcm = rcm_analysis(A_csc) + + # 6. Pattern stability + stability = None + if all_matrices and len(all_matrices) > 1: + print(f" Pattern stability ({len(all_matrices)} samples)...") + stability = check_pattern_stability(all_matrices) + + # Compute parallelism summary + widths = etree_stats["level_widths"] + # "Work" at each level = width (number of independent columns) + # Total sequential steps = height + # Total work = n (all columns must be processed) + # Parallelism efficiency = n / height (ideal speedup from parallelism) + parallelism_efficiency = n / max(etree_stats["height"], 1) + + analysis = { + "name": name, + "matrix": structure, + "_etree_parent": parent, # Full array, not serialized to JSON + "elimination_tree": { + **etree_stats, + "parallelism_efficiency": parallelism_efficiency, + "parent_array_sample": parent[:min(50, n)].tolist(), + "note": ( + f"Height {etree_stats['height']} levels with max width " + f"{etree_stats['max_parallelism']}. Columns at the same level " + f"can be factored in parallel. Efficiency = n/height = " + f"{parallelism_efficiency:.1f}x theoretical speedup." + ), + }, + "supernodes": snode_stats, + "fill_in": fill, + "rcm_ordering": rcm, + } + + if stability is not None: + analysis["pattern_stability"] = stability + + return analysis + + +# --------------------------------------------------------------------------- +# Benchmark mode: run simulation and analyze +# --------------------------------------------------------------------------- + + +def analyze_benchmark( + benchmark_name: str, + max_captures: int = 20, + t_stop_override: float | None = None, +) -> dict: + """Run a benchmark simulation, capture matrices, and analyze parallelism. + + Also captures device scatter pattern information. + """ + import jax + + from vajax.analysis import CircuitEngine + from vajax.benchmarks.registry import get_benchmark + + info = get_benchmark(benchmark_name) + assert info is not None, f"Benchmark '{benchmark_name}' not found" + + engine = CircuitEngine(info.sim_path) + engine.parse() + + use_sparse = info.is_large + dt = info.dt + # Use override, or run just enough steps to capture max_captures matrices + # (~5 NR iterations per timestep, so max_captures/5 timesteps plus margin) + if t_stop_override is not None: + t_stop = t_stop_override + else: + # Short simulation: enough for max_captures NR systems + min_steps = max_captures * 2 # ~2x margin (5 NR iters, capture early ones) + t_stop = min(dt * min_steps, info.t_stop) + + print(f"Benchmark: {benchmark_name}") + print(f" Nodes: {engine.num_nodes}, Devices: {len(engine.devices)}") + print(f" Solver: {'sparse' if use_sparse else 'dense'}") + print(f" Transient: t_stop={t_stop:.2e}s, dt={dt:.2e}s") + + # Suppress step-by-step logging during simulation + import logging + + logging.getLogger("vajax").setLevel(logging.WARNING) + + # --- Device scatter analysis (before running simulation) --- + print("\nAnalyzing device scatter patterns...") + device_scatter = analyze_device_scatter(engine) + + # --- Capture matrices via monkey-patching --- + import vajax.analysis.solver_factories as sf + + captured_systems: list[tuple[np.ndarray, np.ndarray]] = [] + csr_info: dict = {} + capture_count = [0] + + def capture_cb(J_or_data: jax.Array, f: jax.Array): + if capture_count[0] >= max_captures: + return + captured_systems.append((np.asarray(J_or_data).copy(), np.asarray(f).copy())) + capture_count[0] += 1 + + original_make_nr = sf._make_nr_solver_common + + def patched_nr(*, linear_solve_fn, **kwargs): + def instrumented(J_or_data, f): + jax.debug.callback(capture_cb, J_or_data, f) + return linear_solve_fn(J_or_data, f) + + return original_make_nr(linear_solve_fn=instrumented, **kwargs) + + sf._make_nr_solver_common = patched_nr + + # Intercept sparse factories for CSR structure + for factory_name in ("make_umfpack_ffi_full_mna_solver", "make_spineax_full_mna_solver"): + original = getattr(sf, factory_name) + + def make_patched(orig): + def patched(*args, **kwargs): + n_nodes = args[1] if len(args) > 1 else kwargs.get("n_nodes") + n_vsources = args[2] if len(args) > 2 else kwargs.get("n_vsources") + bcsr_indptr = kwargs.get("bcsr_indptr") + bcsr_indices = kwargs.get("bcsr_indices") + if bcsr_indptr is None and len(args) > 4: + bcsr_indptr = args[4] + if bcsr_indices is None and len(args) > 5: + bcsr_indices = args[5] + if bcsr_indptr is not None and n_nodes is not None: + n_aug = n_nodes - 1 + n_vsources + csr_info["indptr"] = np.asarray(bcsr_indptr).copy() + csr_info["indices"] = np.asarray(bcsr_indices).copy() + csr_info["shape"] = (n_aug, n_aug) + return orig(*args, **kwargs) + + return patched + + patched = make_patched(original) + setattr(sf, factory_name, patched) + + import vajax.analysis.transient.full_mna as _full_mna + + setattr(_full_mna, factory_name, patched) + + # --- Run simulation --- + print("\nRunning simulation...") + engine.prepare(t_stop=t_stop, dt=dt, use_sparse=use_sparse) + result = engine.run_transient() + + convergence = result.stats.get("convergence_rate", 0) * 100 + print(f" Steps: {result.num_steps}, convergence: {convergence:.0f}%") + print(f" Captured {len(captured_systems)} linear systems") + + if not captured_systems: + print("ERROR: No systems captured!", file=sys.stderr) + sys.exit(1) + + # --- Build sparse matrices from captured data --- + matrices: list[sp.spmatrix] = [] + for J_or_data, f in captured_systems: + if use_sparse and "indptr" in csr_info: + mat = sp.csr_matrix( + (J_or_data, csr_info["indices"], csr_info["indptr"]), + shape=csr_info["shape"], + ) + else: + mat = sp.csc_matrix(J_or_data) + matrices.append(mat) + + # --- Analyze --- + print(f"\nAnalyzing first captured matrix...") + analysis = analyze_matrix(matrices[0], name=benchmark_name, all_matrices=matrices) + + # Add device scatter info + analysis["device_parallelism"] = device_scatter + + # Add circuit info + analysis["circuit"] = { + "name": benchmark_name, + "n_external_nodes": engine.num_nodes, + "n_devices": len(engine.devices), + "n_unknowns": device_scatter["n_unknowns"], + "simulation": { + "t_stop": t_stop, + "dt": dt, + "steps": result.num_steps, + "convergence_pct": convergence, + }, + } + + # Add compilation note for IREE + analysis["iree_notes"] = { + "pattern_is_fixed": analysis.get("pattern_stability", {}).get("is_fixed", True), + "same_pattern_every_nr_iteration": True, + "values_change_every_iteration": True, + "typical_nr_iterations_per_step": "3-8", + "typical_timesteps": f"{result.num_steps}", + "total_solves": len(captured_systems), + "recommendation": ( + "The sparsity pattern is determined at circuit parse time and never changes. " + "Symbolic factorization (ordering, elimination tree, memory allocation) can " + "be compiled once. Only numerical factorization needs to run per NR iteration. " + f"For this circuit: {result.num_steps} timesteps x ~5 NR iters = " + f"~{result.num_steps * 5} factorizations with identical structure." + ), + } + + return analysis + + +# --------------------------------------------------------------------------- +# Output +# --------------------------------------------------------------------------- + + +def write_analysis(analysis: dict, output_dir: Path): + """Write analysis results to output directory.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # JSON output (machine-readable) + json_path = output_dir / "parallelism_analysis.json" + + # Remove internal data not suitable for JSON + analysis_json = {k: v for k, v in analysis.items() if not k.startswith("_")} + analysis_json = json.loads(json.dumps(analysis_json, default=str)) + widths = analysis_json.get("elimination_tree", {}).get("level_widths", []) + if len(widths) > 100: + analysis_json["elimination_tree"]["level_widths_truncated"] = widths[:50] + ["..."] + widths[-50:] + del analysis_json["elimination_tree"]["level_widths"] + + with open(json_path, "w") as f: + json.dump(analysis_json, f, indent=2) + print(f" JSON: {json_path}") + + # Human-readable summary + summary_path = output_dir / "parallelism_summary.txt" + with open(summary_path, "w") as f: + name = analysis.get("name", "unknown") + f.write(f"{'=' * 70}\n") + f.write(f"Parallelism Analysis: {name}\n") + f.write(f"{'=' * 70}\n\n") + + mat = analysis["matrix"] + f.write(f"Matrix: {mat['size']}x{mat['size']}, {mat['nnz']} nonzeros ({mat['density_pct']:.4f}%)\n") + f.write(f"Bandwidth: {mat['bandwidth']}, Symmetric: {mat['is_structurally_symmetric']}\n") + f.write(f"Connected components: {mat['connected_components']}\n") + deg = mat["degree_stats"] + f.write(f"Row degree: min={deg['row_min']}, max={deg['row_max']}, mean={deg['row_mean']:.1f}\n") + f.write(f"Diagonal dominance: {mat['diagonal_dominance']['pct']:.1f}% of rows\n\n") + + et = analysis["elimination_tree"] + f.write(f"--- Elimination Tree ---\n") + f.write(f"Height (sequential steps): {et['height']}\n") + f.write(f"Leaves: {et['n_leaves']}\n") + f.write(f"Max parallelism (widest level): {et['max_parallelism']}\n") + f.write(f"Avg parallelism: {et['avg_parallelism']:.1f}\n") + f.write(f"Parallelism efficiency (n/height): {et['parallelism_efficiency']:.1f}x\n") + st = et["subtree_size_stats"] + f.write(f"Subtree sizes: min={st['min']}, max={st['max']}, median={st['median']:.0f}\n\n") + + sn = analysis["supernodes"] + f.write(f"--- Supernodes ---\n") + f.write(f"Count: {sn['count']} supernodes\n") + f.write(f"Largest: {sn['largest']} columns\n") + f.write(f"Mean size: {sn['mean_size']:.1f}\n") + f.write(f"Size distribution: {sn['size_histogram']}\n\n") + + fi = analysis["fill_in"] + f.write(f"--- Fill-in (LU factorization) ---\n") + f.write(f"Original nnz: {fi['original_nnz']}\n") + for order_name, order_data in fi["orderings"].items(): + if "error" not in order_data: + f.write( + f" {order_name}: factor_nnz={order_data['factor_nnz']}, " + f"fill_ratio={order_data['fill_ratio']:.2f}x, " + f"fill_in=+{order_data['fill_in']}\n" + ) + if fi["best_ordering"]: + f.write(f"Best ordering: {fi['best_ordering']}\n") + f.write("\n") + + rcm = analysis.get("rcm_ordering", {}) + if rcm.get("permutation_available"): + f.write(f"--- RCM Ordering ---\n") + f.write(f"Bandwidth: {rcm['bandwidth_original']} -> {rcm['bandwidth_rcm']} ") + f.write(f"({rcm['bandwidth_reduction_pct']:.1f}% reduction)\n\n") + + ps = analysis.get("pattern_stability") + if ps: + f.write(f"--- Pattern Stability ---\n") + f.write(f"Fixed pattern: {ps['is_fixed']} ({ps['n_samples']} samples)\n") + if ps.get("value_variation"): + vv = ps["value_variation"] + f.write(f"Value variation: mean_rel={vv['mean_relative_variation']:.4f}, ") + f.write(f"max_rel={vv['max_relative_variation']:.4f}\n") + f.write(f"{ps['note']}\n\n") + + dp = analysis.get("device_parallelism") + if dp: + f.write(f"--- Device Evaluation Parallelism ---\n") + f.write(f"Total devices: {dp['total_devices']}\n") + for mt, mi in dp["model_types"].items(): + f.write(f" {mt}: {mi['n_devices']} devices, ") + f.write(f"{mi['jac_entries_per_device']} Jacobian entries/device, ") + f.write(f"{mi['nodes_per_device']['mean']:.0f} nodes/device\n") + sc = dp["scatter_conflicts"] + f.write(f"Scatter conflicts: {sc['conflict_positions']}/{sc['total_positions']} positions ") + f.write(f"({sc['conflict_pct']:.1f}%), max fan-in={sc['max_fan_in']}\n") + f.write(f"Fan-in distribution: {sc['fan_in_distribution']}\n\n") + + notes = analysis.get("iree_notes") + if notes: + f.write(f"--- IREE/Baspacho Notes ---\n") + f.write(f"{notes['recommendation']}\n") + + print(f" Summary: {summary_path}") + + # Write full elimination tree parent array (useful for solver development) + if "_etree_parent" in analysis: + etree_path = output_dir / "etree_parent.npy" + np.save(etree_path, analysis["_etree_parent"]) + print(f" Etree parent: {etree_path} ({len(analysis['_etree_parent'])} nodes)") + + # Write level-set widths as CSV (for plotting) + widths = analysis.get("elimination_tree", {}).get("level_widths", []) + if widths: + widths_path = output_dir / "level_set_widths.csv" + with open(widths_path, "w") as f: + f.write("level,width\n") + for i, w in enumerate(widths): + f.write(f"{i},{w}\n") + print(f" Level widths: {widths_path}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze parallelism opportunities in VAJAX simulation matrices" + ) + parser.add_argument( + "benchmark", + nargs="?", + help="Benchmark name (e.g. ring, c6288, graetz)", + ) + parser.add_argument( + "--from-mtx", + type=Path, + nargs="+", + help="Analyze existing Matrix Market file(s)", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Output directory (default: /tmp/claude/_parallelism)", + ) + parser.add_argument( + "--max-captures", + type=int, + default=20, + help="Max NR systems to capture for pattern stability check", + ) + parser.add_argument( + "--t-stop", + type=float, + default=None, + help="Override transient stop time (default: auto-short for analysis)", + ) + args = parser.parse_args() + + if args.from_mtx: + # Load from Matrix Market files + matrices = [] + for path in args.from_mtx: + print(f"Loading {path}...") + matrices.append(scipy.io.mmread(path)) + + name = args.from_mtx[0].stem.replace("jacobian_", "") + analysis = analyze_matrix(matrices[0], name=name, all_matrices=matrices) + + out_dir = args.output_dir or Path(f"/tmp/claude/{name}_parallelism") + write_analysis(analysis, out_dir) + + elif args.benchmark: + analysis = analyze_benchmark( + args.benchmark, + max_captures=args.max_captures, + t_stop_override=args.t_stop, + ) + out_dir = args.output_dir or Path(f"/tmp/claude/{args.benchmark}_parallelism") + write_analysis(analysis, out_dir) + + else: + parser.print_help() + sys.exit(1) + + print("\nDone.") + + +if __name__ == "__main__": + main() From a882e24a84f9563e273211bc526b3deb0153d78a Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sat, 7 Mar 2026 20:36:36 +0000 Subject: [PATCH 02/79] feat: Add NR phase timing profiler with JAX named scopes Uses jax.named_scope to annotate NR body phases (build_system, linear_solve, enforce_noi) and jax.debug.callback for CPU-accurate timestamps. Captures Perfetto traces viewable at ui.perfetto.dev. c6288 finding: build_system (device eval + assembly) takes 99% of NR iteration time. Linear solve (UMFPACK, 25k unknowns) is only 1%. This means IREE/Baspacho optimizations should focus on the assembly pipeline, not just the factorization. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/profile_nr_phases.py | 243 +++++++++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 scripts/profile_nr_phases.py diff --git a/scripts/profile_nr_phases.py b/scripts/profile_nr_phases.py new file mode 100644 index 00000000..f2ee4b9b --- /dev/null +++ b/scripts/profile_nr_phases.py @@ -0,0 +1,243 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [] +# /// +"""Profile NR iteration phase breakdown using JAX profiling tools. + +Instruments the NR solver body with jax.named_scope annotations and +captures a Perfetto trace showing the time split between: + - build_system: device evaluation + Jacobian/residual assembly + - linear_solve: sparse or dense linear solve (J*delta = -f) + - convergence: residual/delta checks, step limiting, solution update + +Also uses jax.debug.callback timestamps for a quick text summary +(accurate on CPU since execution is synchronous). + +Usage: + JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py ring + JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py c6288 + JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py c6288 --trace-dir /tmp/jax_trace +""" + +import argparse +import os +import time +from pathlib import Path + +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +import jax +import jax.numpy as jnp +import numpy as np + +# --------------------------------------------------------------------------- +# Phase timing via jax.debug.callback (CPU-accurate) +# --------------------------------------------------------------------------- + +phase_timings: list[dict] = [] +_phase_clock: dict[str, float] = {} + + +def _start_phase(phase_name_bytes): + """Record start time for a phase.""" + phase_name = phase_name_bytes.tobytes().decode() if hasattr(phase_name_bytes, "tobytes") else str(phase_name_bytes) + _phase_clock[phase_name] = time.perf_counter_ns() + + +def _end_phase(phase_name_bytes, iteration): + """Record end time for a phase.""" + phase_name = phase_name_bytes.tobytes().decode() if hasattr(phase_name_bytes, "tobytes") else str(phase_name_bytes) + start = _phase_clock.get(phase_name, 0) + elapsed_ns = time.perf_counter_ns() - start + phase_timings.append({ + "phase": phase_name, + "iteration": int(iteration), + "elapsed_us": elapsed_ns / 1000, + }) + + +# --------------------------------------------------------------------------- +# Monkey-patch NR solver to add named scopes + timing callbacks +# --------------------------------------------------------------------------- + +import vajax.analysis.solver_factories as sf + +_original_make_nr = sf._make_nr_solver_common + + +def patched_make_nr_solver_common(*, build_system_jit, linear_solve_fn, enforce_noi_fn, **kwargs): + """Wrap build_system and linear_solve with named scopes and timing.""" + + def timed_build_system(*args): + with jax.named_scope("nr_build_system"): + return build_system_jit(*args) + + def timed_linear_solve(J_or_data, f): + with jax.named_scope("nr_linear_solve"): + return linear_solve_fn(J_or_data, f) + + def timed_enforce_noi(J_or_data, f): + with jax.named_scope("nr_enforce_noi"): + return enforce_noi_fn(J_or_data, f) + + return _original_make_nr( + build_system_jit=timed_build_system, + linear_solve_fn=timed_linear_solve, + enforce_noi_fn=timed_enforce_noi, + **kwargs, + ) + + +sf._make_nr_solver_common = patched_make_nr_solver_common + + +# --------------------------------------------------------------------------- +# Also add callback-based timing for text summary +# --------------------------------------------------------------------------- + +_original_make_nr2 = sf._make_nr_solver_common # This is now our patched version + + +def callback_timed_make_nr(*, build_system_jit, linear_solve_fn, enforce_noi_fn, **kwargs): + """Add jax.debug.callback timestamps around each phase.""" + + def timed_build_system(*args): + # Extract iteration from args (it's the last positional arg) + iteration = args[-1] if len(args) > 0 else jnp.array(0) + build_tag = jnp.array(list(b"build_system"), dtype=jnp.uint8) + jax.debug.callback(_start_phase, build_tag) + result = build_system_jit(*args) + jax.debug.callback(_end_phase, build_tag, iteration) + return result + + def timed_linear_solve(J_or_data, f): + solve_tag = jnp.array(list(b"linear_solve"), dtype=jnp.uint8) + jax.debug.callback(_start_phase, solve_tag) + result = linear_solve_fn(J_or_data, f) + jax.debug.callback(_end_phase, solve_tag, jnp.array(-1)) + return result + + return _original_make_nr2( + build_system_jit=timed_build_system, + linear_solve_fn=timed_linear_solve, + enforce_noi_fn=enforce_noi_fn, + **kwargs, + ) + + +sf._make_nr_solver_common = callback_timed_make_nr + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + import logging + + from vajax.analysis import CircuitEngine + from vajax.benchmarks.registry import get_benchmark + + parser = argparse.ArgumentParser(description="Profile NR phase breakdown") + parser.add_argument("benchmark", help="Benchmark name (e.g. ring, c6288)") + parser.add_argument("--trace-dir", type=Path, default=None, + help="Directory for Perfetto trace (default: /tmp/claude/_trace)") + parser.add_argument("--t-stop", type=float, default=None, + help="Override stop time") + parser.add_argument("--n-steps", type=int, default=10, + help="Number of timesteps to profile (default: 10)") + args = parser.parse_args() + + logging.getLogger("vajax").setLevel(logging.WARNING) + + info = get_benchmark(args.benchmark) + assert info is not None, f"Benchmark '{args.benchmark}' not found" + + engine = CircuitEngine(info.sim_path) + engine.parse() + + use_sparse = info.is_large + dt = info.dt + t_stop = args.t_stop or dt * args.n_steps + + print(f"Benchmark: {args.benchmark}") + print(f" Nodes: {engine.num_nodes}, Devices: {len(engine.devices)}") + print(f" Solver: {'sparse' if use_sparse else 'dense'}") + print(f" Profiling {args.n_steps} steps (t_stop={t_stop:.2e}s)") + + trace_dir = args.trace_dir or Path(f"/tmp/claude/{args.benchmark}_trace") + trace_dir.mkdir(parents=True, exist_ok=True) + + # Prepare (includes JIT warmup) + print("\nPreparing (JIT warmup)...") + engine.prepare(t_stop=t_stop, dt=dt, use_sparse=use_sparse) + + # Clear any timing from warmup + phase_timings.clear() + + # Run with Perfetto trace capture + print(f"Running with profiler trace -> {trace_dir}") + jax.profiler.start_trace(str(trace_dir)) + try: + result = engine.run_transient() + finally: + jax.profiler.stop_trace() + + convergence = result.stats.get("convergence_rate", 0) * 100 + print(f" Steps: {result.num_steps}, convergence: {convergence:.0f}%") + print(f" Trace saved to: {trace_dir}") + + # --- Analyze callback timings --- + if not phase_timings: + print("\nNo callback timings captured (expected inside lax.while_loop)") + print(f"View the Perfetto trace at: {trace_dir}") + print(" Open https://ui.perfetto.dev and load the trace file") + return + + print(f"\n{'=' * 60}") + print(f"NR Phase Timing Breakdown ({len(phase_timings)} measurements)") + print(f"{'=' * 60}") + + # Aggregate by phase + by_phase: dict[str, list[float]] = {} + for entry in phase_timings: + phase = entry["phase"] + if phase not in by_phase: + by_phase[phase] = [] + by_phase[phase].append(entry["elapsed_us"]) + + total_us = sum(sum(times) for times in by_phase.values()) + + print(f"\n{'Phase':<20} {'Count':>6} {'Total (ms)':>12} {'Mean (µs)':>12} {'%':>8}") + print(f"{'-' * 20} {'-' * 6} {'-' * 12} {'-' * 12} {'-' * 8}") + for phase, times in sorted(by_phase.items(), key=lambda x: -sum(x[1])): + total_ms = sum(times) / 1000 + mean_us = np.mean(times) + pct = sum(times) / total_us * 100 if total_us > 0 else 0 + print(f"{phase:<20} {len(times):>6} {total_ms:>12.2f} {mean_us:>12.1f} {pct:>7.1f}%") + + print(f"\n{'Total':.<20} {'':>6} {total_us / 1000:>12.2f} ms") + + # Per-NR-iteration breakdown (first few) + build_times = by_phase.get("build_system", []) + solve_times = by_phase.get("linear_solve", []) + + if build_times and solve_times: + n_show = min(10, len(build_times)) + print(f"\nPer-iteration breakdown (first {n_show}):") + print(f"{'Iter':>4} {'Build (µs)':>12} {'Solve (µs)':>12} {'Solve %':>8}") + print(f"{'-' * 4} {'-' * 12} {'-' * 12} {'-' * 8}") + for i in range(n_show): + b = build_times[i] + s = solve_times[i] if i < len(solve_times) else 0 + total = b + s + spct = s / total * 100 if total > 0 else 0 + print(f"{i:>4} {b:>12.1f} {s:>12.1f} {spct:>7.1f}%") + + print(f"\nPerfetto trace: {trace_dir}") + print(" Open https://ui.perfetto.dev and load the .pb or .json.gz file") + + +if __name__ == "__main__": + main() From 8efa57c492d331098f13b60c37eeae23320e0dd2 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sat, 7 Mar 2026 20:59:38 +0000 Subject: [PATCH 03/79] feat: Add eval branch specialization analysis to parallelism script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Analyzes compiled device model eval functions to determine how many unique device configurations exist and whether branches (jnp.where) can be eliminated through specialization. Key finding: PSP103 has 854 real branches (excluding safe-divide guards), and ALL of them trace back to static parameters — zero are voltage-dependent. For c6288 (10,112 transistors), only 2 specialized variants are needed (NMOS/PMOS), meaning all branches could be resolved at compile time for straight-line GPU kernels. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/analyze_parallelism.py | 140 +++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/scripts/analyze_parallelism.py b/scripts/analyze_parallelism.py index 4329d204..4cce8782 100644 --- a/scripts/analyze_parallelism.py +++ b/scripts/analyze_parallelism.py @@ -681,6 +681,123 @@ def analyze_matrix( return analysis +# --------------------------------------------------------------------------- +# Device eval branch analysis +# --------------------------------------------------------------------------- + + +def analyze_eval_branches(engine) -> dict: + """Analyze jnp.where branches in compiled device eval functions. + + For each model type, checks the compiled model's parameter split to determine: + - How many device configurations exist (e.g., NMOS vs PMOS) + - Whether all eval branches are statically determinable at setup time + - How much specialization is possible + + This does NOT require dumping/parsing generated code — it analyzes the + actual shared_params, device_params, and device_cache arrays to determine + how many unique device variants exist. + """ + result = {} + + for model_type, compiled in engine._compiled_models.items(): + if "shared_params" not in compiled: + continue + + sp = np.asarray(compiled["shared_params"]) + dp = np.asarray(compiled["device_params"]) + sc = np.asarray(compiled.get("shared_cache", np.array([]))) + dc = np.asarray(compiled.get("device_cache", np.empty((dp.shape[0], 0)))) + vp = np.asarray(compiled.get("voltage_positions_in_varying", np.array([], dtype=int))) + + n_devices = dp.shape[0] + n_varying = dp.shape[1] if dp.ndim > 1 else 0 + n_voltages = len(vp) + n_static_varying = n_varying - n_voltages + + # Identify non-voltage varying param columns + voltage_cols = set(vp.tolist()) if len(vp) > 0 else set() + static_cols = sorted(set(range(n_varying)) - voltage_cols) + + # Count unique device configurations (static params only) + if static_cols and n_devices > 1: + static_dp = dp[:, static_cols] + unique_configs, config_indices, config_counts = np.unique( + static_dp, axis=0, return_inverse=True, return_counts=True + ) + n_unique_configs = len(unique_configs) + config_sizes = config_counts.tolist() + elif n_devices > 1: + # No static varying params — all devices identical + n_unique_configs = 1 + config_indices = np.zeros(n_devices, dtype=int) + config_sizes = [n_devices] + else: + n_unique_configs = 1 + config_indices = np.zeros(1, dtype=int) + config_sizes = [1] + + # Check device_cache uniformity + n_cache_cols = dc.shape[1] if dc.ndim > 1 else 0 + if n_cache_cols > 0 and n_devices > 1: + cache_uniform = int(np.sum(np.all(dc == dc[0:1, :], axis=0))) + cache_varying = n_cache_cols - cache_uniform + + # Count unique cache configurations + unique_dc, dc_indices = np.unique(dc, axis=0, return_inverse=True) + n_unique_cache = len(unique_dc) + else: + cache_uniform = n_cache_cols + cache_varying = 0 + n_unique_cache = 1 + + # Get param names for the varying columns if available + param_names = compiled.get("param_names", []) + param_kinds = compiled.get("param_kinds", []) + varying_indices = compiled.get("varying_indices", []) + + varying_param_info = [] + for col_idx, orig_idx in enumerate(varying_indices): + if col_idx in voltage_cols: + continue + name = param_names[orig_idx] if orig_idx < len(param_names) else f"param_{orig_idx}" + kind = param_kinds[orig_idx] if orig_idx < len(param_kinds) else "unknown" + if n_devices > 1: + vals = dp[:, col_idx] + unique_vals = np.unique(vals) + varying_param_info.append({ + "name": name, + "kind": kind, + "n_unique": len(unique_vals), + "values": unique_vals.tolist() if len(unique_vals) <= 10 else f"{len(unique_vals)} values", + }) + + result[model_type] = { + "n_devices": n_devices, + "n_shared_params": len(sp), + "n_varying_params": n_varying, + "n_voltage_params": n_voltages, + "n_static_varying_params": n_static_varying, + "n_shared_cache": len(sc) if sc.ndim == 1 else (sc.shape[1] if sc.ndim > 1 else 0), + "n_device_cache_cols": n_cache_cols, + "cache_uniform_cols": cache_uniform, + "cache_varying_cols": cache_varying, + "n_unique_param_configs": n_unique_configs, + "n_unique_cache_configs": n_unique_cache, + "config_sizes": config_sizes, + "varying_static_params": varying_param_info, + "specialization_note": ( + f"All {n_devices} devices can be grouped into {n_unique_configs} " + f"specialized eval variant(s). Branches conditioned on shared_params " + f"({len(sp)} params) and device configuration ({n_static_varying} " + f"static varying params) can be resolved at compile time, eliminating " + f"jnp.where overhead for straight-line GPU kernels." + ), + } + + return result + + # --------------------------------------------------------------------------- # Benchmark mode: run simulation and analyze # --------------------------------------------------------------------------- @@ -831,6 +948,11 @@ def patched(*args, **kwargs): }, } + # Eval branch specialization analysis + print("\nAnalyzing eval function branches...") + branch_analysis = analyze_eval_branches(engine) + analysis["eval_specialization"] = branch_analysis + # Add compilation note for IREE analysis["iree_notes"] = { "pattern_is_fixed": analysis.get("pattern_stability", {}).get("is_fixed", True), @@ -951,6 +1073,24 @@ def write_analysis(analysis: dict, output_dir: Path): f.write(f"({sc['conflict_pct']:.1f}%), max fan-in={sc['max_fan_in']}\n") f.write(f"Fan-in distribution: {sc['fan_in_distribution']}\n\n") + es = analysis.get("eval_specialization") + if es: + f.write(f"--- Eval Branch Specialization ---\n") + for mt, info in es.items(): + f.write(f" {mt}: {info['n_devices']} devices\n") + f.write(f" Params: {info['n_shared_params']} shared, ") + f.write(f"{info['n_voltage_params']} voltage, ") + f.write(f"{info['n_static_varying_params']} static-varying\n") + f.write(f" Cache: {info['n_shared_cache']} shared, ") + f.write(f"{info['cache_varying_cols']} device-varying\n") + f.write(f" Unique device configs: {info['n_unique_param_configs']}") + f.write(f" (sizes: {info['config_sizes']})\n") + if info.get("varying_static_params"): + for vp in info["varying_static_params"]: + f.write(f" {vp['name']} ({vp['kind']}): {vp['n_unique']} unique values\n") + f.write(f" {info['specialization_note']}\n") + f.write(f"\n") + notes = analysis.get("iree_notes") if notes: f.write(f"--- IREE/Baspacho Notes ---\n") From 232595b3b313c21b96ccdd26f07466d84f489fde Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sat, 7 Mar 2026 21:30:38 +0000 Subject: [PATCH 04/79] feat: Inline shared params/cache as literals for branch specialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Inline all shared parameter values (2,591 for PSP103) and shared cache values (407) as Python float literals in the generated eval function code. At JAX trace time, constant expressions evaluate eagerly, so jnp.where(const_bool, a, b) only traces the taken branch — eliminating all static-parameter-dependent branches from the compiled XLA program. This is Phase 1 of device eval branch specialization. The generated source code still contains jnp.where calls, but JAX's tracer constant-folds them when the condition is a Python bool rather than a traced abstract value. The key change: shared_params[N] lookups (which produce abstract values under tracing) are replaced with concrete float literals (which Python evaluates immediately). Changes: - function_builder.py: build_with_cache_split(), _emit_param_mapping(), and _emit_cache_mapping() accept optional concrete value lists - __init__.py: translate_eval_array_with_cache_split() passes through concrete values with logging - openvaf_models.py: prepare_static_inputs() extracts concrete values from already-computed shared_params_list and shared_cache arrays Verified: rc, graetz, ring benchmarks produce identical results via compare_vacask.py. Generated code shows 0 shared_params[N] references and 0 shared_cache[N] references (all inlined). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- openvaf_jax/__init__.py | 18 ++++++- openvaf_jax/codegen/function_builder.py | 63 ++++++++++++++++++++----- vajax/analysis/openvaf_models.py | 25 ++++++++-- 3 files changed, 90 insertions(+), 16 deletions(-) diff --git a/openvaf_jax/__init__.py b/openvaf_jax/__init__.py index d0c15847..265114f1 100644 --- a/openvaf_jax/__init__.py +++ b/openvaf_jax/__init__.py @@ -1274,6 +1274,8 @@ def translate_eval_array_with_cache_split( varying_cache_indices: Optional[List[int]] = None, use_limit_functions: bool = False, limit_param_map: Optional[Dict[int, Tuple[str, str]]] = None, + concrete_shared_values: Optional[List[float]] = None, + concrete_shared_cache: Optional[List[float]] = None, ) -> Tuple[Callable, Dict]: """Generate a vmappable eval function with split params and cache (internal API). @@ -1291,6 +1293,12 @@ def translate_eval_array_with_cache_split( limit_param_map: Dict mapping original param indices to (kind, name) tuples for limit-related params (prev_state, enable_lim, new_state, enable_integration). Excluded from shared/device params. + concrete_shared_values: If provided, concrete float values for each + shared param. Inlined as Python literals for trace-time + constant folding, eliminating jnp.where branches that + depend on static parameters. + concrete_shared_cache: If provided, concrete float values for each + shared cache entry. Same inlining behavior as above. Returns: Tuple of (eval_fn, metadata) @@ -1316,8 +1324,12 @@ def translate_eval_array_with_cache_split( assert self.dae_data is not None, "dae_data released, call before release_mir_data()" t0 = time.perf_counter() + n_inlined_params = len(concrete_shared_values) if concrete_shared_values else 0 + n_inlined_cache = len(concrete_shared_cache) if concrete_shared_cache else 0 logger.info( - f" translate_eval_array_with_cache_split: generating code (limit_funcs={use_limit_functions})..." + f" translate_eval_array_with_cache_split: generating code " + f"(limit_funcs={use_limit_functions}, " + f"inlined_params={n_inlined_params}, inlined_cache={n_inlined_cache})..." ) # Build the eval function @@ -1336,6 +1348,8 @@ def translate_eval_array_with_cache_split( varying_cache_indices, use_limit_functions=use_limit_functions, limit_param_map=limit_param_map, + concrete_shared_values=concrete_shared_values, + concrete_shared_cache=concrete_shared_cache, ) t1 = time.perf_counter() @@ -1359,6 +1373,8 @@ def translate_eval_array_with_cache_split( df.write(f"# use_limit_functions={use_limit_functions}\n") df.write(f"# shared_indices={shared_indices}\n") df.write(f"# varying_indices={varying_indices}\n") + df.write(f"# concrete_shared_values={concrete_shared_values is not None} ({n_inlined_params} values)\n") + df.write(f"# concrete_shared_cache={concrete_shared_cache is not None} ({n_inlined_cache} values)\n") df.write(code) logger.info(f" Generated code dumped to {dump_path}") diff --git a/openvaf_jax/codegen/function_builder.py b/openvaf_jax/codegen/function_builder.py index 8a4b68d9..5f8b8a07 100644 --- a/openvaf_jax/codegen/function_builder.py +++ b/openvaf_jax/codegen/function_builder.py @@ -818,6 +818,8 @@ def build_with_cache_split( simparam_params: Optional[Dict[int, str]] = None, use_limit_functions: bool = False, limit_param_map: Optional[Dict[int, Tuple[str, str]]] = None, + concrete_shared_values: Optional[List[float]] = None, + concrete_shared_cache: Optional[List[float]] = None, ) -> Tuple[str, List[str]]: """Build eval function with split params and optional split cache. @@ -838,6 +840,13 @@ def build_with_cache_split( for limit-related params (prev_state, enable_lim, new_state, enable_integration). These are read from limit_state_in or set to constants instead of shared/device params. + concrete_shared_values: If provided, concrete float values for each + shared param index. When set, shared params are emitted + as Python literals instead of shared_params[N] lookups. + This enables JAX trace-time constant folding, eliminating + jnp.where branches that depend on static parameters. + concrete_shared_cache: If provided, concrete float values for each + shared cache index. Same inlining behavior as above. Returns: Tuple of (function_name, code_lines) @@ -897,11 +906,11 @@ def build_with_cache_split( body.append(assign("v3", ctx.zero())) ctx.defined_vars.add("v3") - # Map params from split arrays - self._emit_param_mapping(body, ctx, idx_mapping) + # Map params from split arrays (inline concrete values when provided) + self._emit_param_mapping(body, ctx, idx_mapping, concrete_shared_values) - # Map cache values - self._emit_cache_mapping(body, ctx, cache_idx_mapping) + # Map cache values (inline concrete values when provided) + self._emit_cache_mapping(body, ctx, cache_idx_mapping, concrete_shared_cache) # Pre-initialize all output variables to 0.0 to avoid NameError # for variables only assigned in conditional branches (NMOS/PMOS paths) @@ -980,7 +989,11 @@ def build_with_cache_split( return fn_name, code_str.split("\n") def _emit_param_mapping( - self, body: List[ast.stmt], ctx: CodeGenContext, idx_mapping: Dict[int, Tuple[str, any]] + self, + body: List[ast.stmt], + ctx: CodeGenContext, + idx_mapping: Dict[int, Tuple[str, any]], + concrete_shared_values: Optional[List[float]] = None, ): """Emit parameter mapping from split arrays. @@ -991,6 +1004,9 @@ def _emit_param_mapping( - source='shared': value is new index in shared_params - source='device': value is new index in device_params - source='simparam': value is simparam name (e.g., '$abstime') + concrete_shared_values: If provided, concrete float values for each + shared param. When set, shared params are emitted as Python literals + (e.g., `v123 = 1.5e-6`) instead of `v123 = shared_params[42]`. """ for i, param in enumerate(self.mir_func.params): var_name = f"{ctx.var_prefix}{param}" @@ -998,9 +1014,16 @@ def _emit_param_mapping( if i in idx_mapping: source, value = idx_mapping[i] if source == "shared": - body.append( - assign(var_name, subscript(ast_name("shared_params"), ast_const(value))) - ) + if concrete_shared_values is not None: + # Inline as Python literal for trace-time constant folding + body.append(assign(var_name, ast_const(concrete_shared_values[value]))) + else: + body.append( + assign( + var_name, + subscript(ast_name("shared_params"), ast_const(value)), + ) + ) elif source == "device": body.append( assign(var_name, subscript(ast_name("device_params"), ast_const(value))) @@ -1037,10 +1060,19 @@ def _emit_cache_mapping( body: List[ast.stmt], ctx: CodeGenContext, cache_idx_mapping: Dict[int, Tuple[str, int]], + concrete_shared_cache: Optional[List[float]] = None, ): """Emit cache value mapping from split cache arrays. Always uses split cache format (shared_cache, device_cache) for uniform interface. + + Args: + body: List to append statements to + ctx: Code generation context + cache_idx_mapping: Maps cache index to (source, new_index) + concrete_shared_cache: If provided, concrete float values for each + shared cache entry. When set, shared cache values are emitted as + Python literals instead of shared_cache[N] lookups. """ for cache_idx, mapping in enumerate(self.cache_mapping): eval_param_idx = mapping["eval_param"] @@ -1050,9 +1082,18 @@ def _emit_cache_mapping( if cache_idx in cache_idx_mapping: source, new_idx = cache_idx_mapping[cache_idx] if source == "shared_cache": - body.append( - assign(var_name, subscript(ast_name("shared_cache"), ast_const(new_idx))) - ) + if concrete_shared_cache is not None: + # Inline as Python literal for trace-time constant folding + body.append( + assign(var_name, ast_const(concrete_shared_cache[new_idx])) + ) + else: + body.append( + assign( + var_name, + subscript(ast_name("shared_cache"), ast_const(new_idx)), + ) + ) else: body.append( assign(var_name, subscript(ast_name("device_cache"), ast_const(new_idx))) diff --git a/vajax/analysis/openvaf_models.py b/vajax/analysis/openvaf_models.py index 999b1dfa..23dbe2a7 100644 --- a/vajax/analysis/openvaf_models.py +++ b/vajax/analysis/openvaf_models.py @@ -915,6 +915,23 @@ def prepare_static_inputs( shared_cache_indices = [] varying_cache_indices = [] + # Split cache arrays (needed before eval codegen for inlining) + shared_cache = cache[0, shared_cache_indices] + device_cache = cache[:, varying_cache_indices] + + # Prepare concrete values for branch specialization: + # Inline shared params and shared cache as Python literals in the + # generated eval function. JAX tracing evaluates constant expressions + # at trace time, so downstream jnp.where(const_bool, a, b) only + # traces the taken branch — eliminating all static-param branches. + concrete_shared_values = shared_params_list # already List[float] + concrete_shared_cache_values = [float(v) for v in np.asarray(shared_cache)] + logger.info( + f"{model_type}: branch specialization: inlining " + f"{len(concrete_shared_values)} shared params + " + f"{len(concrete_shared_cache_values)} shared cache values as literals" + ) + # Generate eval function with cache split from vajax.analysis.limiting import fetlim, pnjlim @@ -930,6 +947,8 @@ def prepare_static_inputs( varying_cache_indices, use_limit_functions=use_device_limiting, limit_param_map=limit_param_map, + concrete_shared_values=concrete_shared_values, + concrete_shared_cache=concrete_shared_cache_values, ) # Safety check: if limiting is enabled but lim_rhs could not be computed # (model uses inline limiting without $limit/BuiltinLimit calls), disable @@ -950,15 +969,13 @@ def prepare_static_inputs( varying_cache_indices, use_limit_functions=False, limit_param_map=limit_param_map, + concrete_shared_values=concrete_shared_values, + concrete_shared_cache=concrete_shared_cache_values, ) split_fn = partial(split_fn, limit_funcs=limit_funcs) vmapped_split_fn = jax.jit(jax.vmap(split_fn, in_axes=(None, 0, None, 0, None, 0))) - # Split cache arrays - shared_cache = cache[0, shared_cache_indices] - device_cache = cache[:, varying_cache_indices] - # Build default simparams from model metadata simparams_used = split_meta.get("simparams_used", ["$analysis_type", "$mfactor", "gmin"]) simparam_count = split_meta.get("simparam_count", len(simparams_used)) From c27da3da20bb31c2e392884ef06d25200a920721 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sat, 7 Mar 2026 21:37:38 +0000 Subject: [PATCH 05/79] fix: Resolve lint errors in analyze_parallelism.py Remove unused variables (n, openvaf_by_type, widths, config_indices) and fix import sorting flagged by ruff. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/analyze_parallelism.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/scripts/analyze_parallelism.py b/scripts/analyze_parallelism.py index 4cce8782..ba8b7814 100644 --- a/scripts/analyze_parallelism.py +++ b/scripts/analyze_parallelism.py @@ -42,7 +42,6 @@ import scipy.sparse as sp import scipy.sparse.linalg - # --------------------------------------------------------------------------- # Elimination tree # --------------------------------------------------------------------------- @@ -401,7 +400,6 @@ def matrix_structure_analysis(A: sp.spmatrix) -> dict: def rcm_analysis(A: sp.spmatrix) -> dict: """Analyze effect of Reverse Cuthill-McKee ordering.""" A_sym = symmetrize_pattern(A) - n = A_sym.shape[0] try: perm = sp.csgraph.reverse_cuthill_mckee(A_sym, symmetric_mode=True) @@ -438,7 +436,6 @@ def analyze_device_scatter(engine) -> dict: """ setup = engine._build_transient_setup(backend="cpu", use_dense=True) static_inputs_cache = setup["static_inputs_cache"] - openvaf_by_type = setup["openvaf_by_type"] n_unknowns = setup["n_unknowns"] model_info = {} @@ -648,7 +645,6 @@ def analyze_matrix( stability = check_pattern_stability(all_matrices) # Compute parallelism summary - widths = etree_stats["level_widths"] # "Work" at each level = width (number of independent columns) # Total sequential steps = height # Total work = n (all columns must be processed) @@ -730,11 +726,9 @@ def analyze_eval_branches(engine) -> dict: elif n_devices > 1: # No static varying params — all devices identical n_unique_configs = 1 - config_indices = np.zeros(n_devices, dtype=int) config_sizes = [n_devices] else: n_unique_configs = 1 - config_indices = np.zeros(1, dtype=int) config_sizes = [1] # Check device_cache uniformity @@ -928,7 +922,7 @@ def patched(*args, **kwargs): matrices.append(mat) # --- Analyze --- - print(f"\nAnalyzing first captured matrix...") + print("\nAnalyzing first captured matrix...") analysis = analyze_matrix(matrices[0], name=benchmark_name, all_matrices=matrices) # Add device scatter info @@ -1014,7 +1008,7 @@ def write_analysis(analysis: dict, output_dir: Path): f.write(f"Diagonal dominance: {mat['diagonal_dominance']['pct']:.1f}% of rows\n\n") et = analysis["elimination_tree"] - f.write(f"--- Elimination Tree ---\n") + f.write("--- Elimination Tree ---\n") f.write(f"Height (sequential steps): {et['height']}\n") f.write(f"Leaves: {et['n_leaves']}\n") f.write(f"Max parallelism (widest level): {et['max_parallelism']}\n") @@ -1024,14 +1018,14 @@ def write_analysis(analysis: dict, output_dir: Path): f.write(f"Subtree sizes: min={st['min']}, max={st['max']}, median={st['median']:.0f}\n\n") sn = analysis["supernodes"] - f.write(f"--- Supernodes ---\n") + f.write("--- Supernodes ---\n") f.write(f"Count: {sn['count']} supernodes\n") f.write(f"Largest: {sn['largest']} columns\n") f.write(f"Mean size: {sn['mean_size']:.1f}\n") f.write(f"Size distribution: {sn['size_histogram']}\n\n") fi = analysis["fill_in"] - f.write(f"--- Fill-in (LU factorization) ---\n") + f.write("--- Fill-in (LU factorization) ---\n") f.write(f"Original nnz: {fi['original_nnz']}\n") for order_name, order_data in fi["orderings"].items(): if "error" not in order_data: @@ -1046,13 +1040,13 @@ def write_analysis(analysis: dict, output_dir: Path): rcm = analysis.get("rcm_ordering", {}) if rcm.get("permutation_available"): - f.write(f"--- RCM Ordering ---\n") + f.write("--- RCM Ordering ---\n") f.write(f"Bandwidth: {rcm['bandwidth_original']} -> {rcm['bandwidth_rcm']} ") f.write(f"({rcm['bandwidth_reduction_pct']:.1f}% reduction)\n\n") ps = analysis.get("pattern_stability") if ps: - f.write(f"--- Pattern Stability ---\n") + f.write("--- Pattern Stability ---\n") f.write(f"Fixed pattern: {ps['is_fixed']} ({ps['n_samples']} samples)\n") if ps.get("value_variation"): vv = ps["value_variation"] @@ -1062,7 +1056,7 @@ def write_analysis(analysis: dict, output_dir: Path): dp = analysis.get("device_parallelism") if dp: - f.write(f"--- Device Evaluation Parallelism ---\n") + f.write("--- Device Evaluation Parallelism ---\n") f.write(f"Total devices: {dp['total_devices']}\n") for mt, mi in dp["model_types"].items(): f.write(f" {mt}: {mi['n_devices']} devices, ") @@ -1075,7 +1069,7 @@ def write_analysis(analysis: dict, output_dir: Path): es = analysis.get("eval_specialization") if es: - f.write(f"--- Eval Branch Specialization ---\n") + f.write("--- Eval Branch Specialization ---\n") for mt, info in es.items(): f.write(f" {mt}: {info['n_devices']} devices\n") f.write(f" Params: {info['n_shared_params']} shared, ") @@ -1089,11 +1083,11 @@ def write_analysis(analysis: dict, output_dir: Path): for vp in info["varying_static_params"]: f.write(f" {vp['name']} ({vp['kind']}): {vp['n_unique']} unique values\n") f.write(f" {info['specialization_note']}\n") - f.write(f"\n") + f.write("\n") notes = analysis.get("iree_notes") if notes: - f.write(f"--- IREE/Baspacho Notes ---\n") + f.write("--- IREE/Baspacho Notes ---\n") f.write(f"{notes['recommendation']}\n") print(f" Summary: {summary_path}") From 46cea11d98dbd479d542c473994278b739e2cf37 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sat, 7 Mar 2026 21:42:11 +0000 Subject: [PATCH 06/79] style: Apply ruff format to all files Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/analyze_parallelism.py | 42 ++++++++++++++++++++++------------ scripts/profile_nr_phases.py | 40 +++++++++++++++++++++----------- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/scripts/analyze_parallelism.py b/scripts/analyze_parallelism.py index ba8b7814..a150de27 100644 --- a/scripts/analyze_parallelism.py +++ b/scripts/analyze_parallelism.py @@ -51,9 +51,7 @@ def symmetrize_pattern(A: sp.spmatrix) -> sp.csc_matrix: """Compute |A| + |A^T| as a binary pattern (no values, just structure).""" A_csc = sp.csc_matrix(A) # Binary pattern: set all values to 1 - A_bin = sp.csc_matrix( - (np.ones(A_csc.nnz), A_csc.indices, A_csc.indptr), shape=A_csc.shape - ) + A_bin = sp.csc_matrix((np.ones(A_csc.nnz), A_csc.indices, A_csc.indptr), shape=A_csc.shape) A_sym = A_bin + A_bin.T # Re-binarize (eliminates any 2s from diagonal overlap) A_sym.data[:] = 1.0 @@ -658,7 +656,7 @@ def analyze_matrix( "elimination_tree": { **etree_stats, "parallelism_efficiency": parallelism_efficiency, - "parent_array_sample": parent[:min(50, n)].tolist(), + "parent_array_sample": parent[: min(50, n)].tolist(), "note": ( f"Height {etree_stats['height']} levels with max width " f"{etree_stats['max_parallelism']}. Columns at the same level " @@ -759,12 +757,16 @@ def analyze_eval_branches(engine) -> dict: if n_devices > 1: vals = dp[:, col_idx] unique_vals = np.unique(vals) - varying_param_info.append({ - "name": name, - "kind": kind, - "n_unique": len(unique_vals), - "values": unique_vals.tolist() if len(unique_vals) <= 10 else f"{len(unique_vals)} values", - }) + varying_param_info.append( + { + "name": name, + "kind": kind, + "n_unique": len(unique_vals), + "values": unique_vals.tolist() + if len(unique_vals) <= 10 + else f"{len(unique_vals)} values", + } + ) result[model_type] = { "n_devices": n_devices, @@ -984,7 +986,9 @@ def write_analysis(analysis: dict, output_dir: Path): analysis_json = json.loads(json.dumps(analysis_json, default=str)) widths = analysis_json.get("elimination_tree", {}).get("level_widths", []) if len(widths) > 100: - analysis_json["elimination_tree"]["level_widths_truncated"] = widths[:50] + ["..."] + widths[-50:] + analysis_json["elimination_tree"]["level_widths_truncated"] = ( + widths[:50] + ["..."] + widths[-50:] + ) del analysis_json["elimination_tree"]["level_widths"] with open(json_path, "w") as f: @@ -1000,11 +1004,15 @@ def write_analysis(analysis: dict, output_dir: Path): f.write(f"{'=' * 70}\n\n") mat = analysis["matrix"] - f.write(f"Matrix: {mat['size']}x{mat['size']}, {mat['nnz']} nonzeros ({mat['density_pct']:.4f}%)\n") + f.write( + f"Matrix: {mat['size']}x{mat['size']}, {mat['nnz']} nonzeros ({mat['density_pct']:.4f}%)\n" + ) f.write(f"Bandwidth: {mat['bandwidth']}, Symmetric: {mat['is_structurally_symmetric']}\n") f.write(f"Connected components: {mat['connected_components']}\n") deg = mat["degree_stats"] - f.write(f"Row degree: min={deg['row_min']}, max={deg['row_max']}, mean={deg['row_mean']:.1f}\n") + f.write( + f"Row degree: min={deg['row_min']}, max={deg['row_max']}, mean={deg['row_mean']:.1f}\n" + ) f.write(f"Diagonal dominance: {mat['diagonal_dominance']['pct']:.1f}% of rows\n\n") et = analysis["elimination_tree"] @@ -1063,7 +1071,9 @@ def write_analysis(analysis: dict, output_dir: Path): f.write(f"{mi['jac_entries_per_device']} Jacobian entries/device, ") f.write(f"{mi['nodes_per_device']['mean']:.0f} nodes/device\n") sc = dp["scatter_conflicts"] - f.write(f"Scatter conflicts: {sc['conflict_positions']}/{sc['total_positions']} positions ") + f.write( + f"Scatter conflicts: {sc['conflict_positions']}/{sc['total_positions']} positions " + ) f.write(f"({sc['conflict_pct']:.1f}%), max fan-in={sc['max_fan_in']}\n") f.write(f"Fan-in distribution: {sc['fan_in_distribution']}\n\n") @@ -1081,7 +1091,9 @@ def write_analysis(analysis: dict, output_dir: Path): f.write(f" (sizes: {info['config_sizes']})\n") if info.get("varying_static_params"): for vp in info["varying_static_params"]: - f.write(f" {vp['name']} ({vp['kind']}): {vp['n_unique']} unique values\n") + f.write( + f" {vp['name']} ({vp['kind']}): {vp['n_unique']} unique values\n" + ) f.write(f" {info['specialization_note']}\n") f.write("\n") diff --git a/scripts/profile_nr_phases.py b/scripts/profile_nr_phases.py index f2ee4b9b..6100b4bb 100644 --- a/scripts/profile_nr_phases.py +++ b/scripts/profile_nr_phases.py @@ -40,20 +40,30 @@ def _start_phase(phase_name_bytes): """Record start time for a phase.""" - phase_name = phase_name_bytes.tobytes().decode() if hasattr(phase_name_bytes, "tobytes") else str(phase_name_bytes) + phase_name = ( + phase_name_bytes.tobytes().decode() + if hasattr(phase_name_bytes, "tobytes") + else str(phase_name_bytes) + ) _phase_clock[phase_name] = time.perf_counter_ns() def _end_phase(phase_name_bytes, iteration): """Record end time for a phase.""" - phase_name = phase_name_bytes.tobytes().decode() if hasattr(phase_name_bytes, "tobytes") else str(phase_name_bytes) + phase_name = ( + phase_name_bytes.tobytes().decode() + if hasattr(phase_name_bytes, "tobytes") + else str(phase_name_bytes) + ) start = _phase_clock.get(phase_name, 0) elapsed_ns = time.perf_counter_ns() - start - phase_timings.append({ - "phase": phase_name, - "iteration": int(iteration), - "elapsed_us": elapsed_ns / 1000, - }) + phase_timings.append( + { + "phase": phase_name, + "iteration": int(iteration), + "elapsed_us": elapsed_ns / 1000, + } + ) # --------------------------------------------------------------------------- @@ -141,12 +151,16 @@ def main(): parser = argparse.ArgumentParser(description="Profile NR phase breakdown") parser.add_argument("benchmark", help="Benchmark name (e.g. ring, c6288)") - parser.add_argument("--trace-dir", type=Path, default=None, - help="Directory for Perfetto trace (default: /tmp/claude/_trace)") - parser.add_argument("--t-stop", type=float, default=None, - help="Override stop time") - parser.add_argument("--n-steps", type=int, default=10, - help="Number of timesteps to profile (default: 10)") + parser.add_argument( + "--trace-dir", + type=Path, + default=None, + help="Directory for Perfetto trace (default: /tmp/claude/_trace)", + ) + parser.add_argument("--t-stop", type=float, default=None, help="Override stop time") + parser.add_argument( + "--n-steps", type=int, default=10, help="Number of timesteps to profile (default: 10)" + ) args = parser.parse_args() logging.getLogger("vajax").setLevel(logging.WARNING) From 8ee5bbc971c513c7ee57e647e9cd6c7d888d9c1d Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sat, 7 Mar 2026 21:43:16 +0000 Subject: [PATCH 07/79] docs: Add linting instructions to CLAUDE.md Document ruff check and ruff format commands that CI enforces, so they're run before committing. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- CLAUDE.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 7ec6dfed..9c2246c1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -78,6 +78,22 @@ uv run python scripts/profile_gpu.py --benchmark ring,c6288 ``` +## Linting + +**Run before every commit.** CI runs both checks and will reject PRs that fail. + +```bash +# Lint check (import sorting, unused imports, etc.) +uv tool run ruff check vajax/ tests/ scripts/ benchmarks/ + +# Format check (code style) +uv tool run ruff format --check vajax/ tests/ scripts/ benchmarks/ + +# Auto-fix both +uv tool run ruff check --fix vajax/ tests/ scripts/ benchmarks/ +uv tool run ruff format vajax/ tests/ scripts/ benchmarks/ +``` + ## Precision Configuration Precision is auto-configured on import via `vajax/__init__.py`: From 373842c02c0bbcadb92d2a26c17568cdceb1611d Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sun, 8 Mar 2026 00:48:39 +0000 Subject: [PATCH 08/79] feat: Wire SCCP dead branch elimination into cache-split codegen path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The translate_eval_array_with_cache_split() method (used by prepare_static_inputs()) now builds sccp_known_values from concrete shared param and cache values, enabling SCCP to eliminate dead branches at Python codegen time — before JAX ever sees the code. PSP103 results: 695/954 MIR blocks dead, 47 static branches resolved, 7066 constants propagated, jnp.where reduced from 2247 to 1801 (20%). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- openvaf_jax/__init__.py | 60 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/openvaf_jax/__init__.py b/openvaf_jax/__init__.py index 265114f1..7b136e49 100644 --- a/openvaf_jax/__init__.py +++ b/openvaf_jax/__init__.py @@ -1332,6 +1332,37 @@ def translate_eval_array_with_cache_split( f"inlined_params={n_inlined_params}, inlined_cache={n_inlined_cache})..." ) + # Build SCCP known values from concrete shared params and cache + # This enables dead branch elimination at codegen time + sccp_known_values: Optional[Dict[str, Any]] = None + if concrete_shared_values is not None or concrete_shared_cache is not None: + sccp_known_values = {} + + # Map shared params to MIR value IDs + if concrete_shared_values is not None: + for j, orig_idx in enumerate(shared_indices): + value_id = self.param_idx_to_val.get(orig_idx) + if value_id: + sccp_known_values[value_id] = concrete_shared_values[j] + + # Map shared cache entries to MIR value IDs + if concrete_shared_cache is not None and shared_cache_indices: + for j, cache_col_idx in enumerate(shared_cache_indices): + mapping = self.cache_mapping[cache_col_idx] + eval_param_idx = mapping["eval_param"] + value_id = self.param_idx_to_val.get(eval_param_idx) + if value_id: + sccp_known_values[value_id] = concrete_shared_cache[j] + + if not sccp_known_values: + sccp_known_values = None + else: + logger.info( + f" SCCP: {len(sccp_known_values)} known values " + f"({n_inlined_params} from params, " + f"{len(sccp_known_values) - n_inlined_params} from cache)" + ) + # Build the eval function eval_param_names = list(self.module.param_names) builder = EvalFunctionBuilder( @@ -1339,6 +1370,7 @@ def translate_eval_array_with_cache_split( self.dae_data, self.cache_mapping, self.param_idx_to_val, + sccp_known_values=sccp_known_values, eval_param_names=eval_param_names, ) fn_name, code_lines = builder.build_with_cache_split( @@ -1353,6 +1385,22 @@ def translate_eval_array_with_cache_split( ) t1 = time.perf_counter() + + # Log SCCP statistics + if builder.sccp is not None: + dead_blocks = builder.sccp.get_dead_blocks() + total_blocks = len(self.eval_mir.blocks) + n_constants = sum(1 for v in builder.sccp.lattice.values() if v.is_constant()) + static_branches = sum( + 1 + for b in self.eval_mir.blocks + if builder.sccp.get_static_branch_direction(b) is not None + ) + logger.info( + f" SCCP results: {len(dead_blocks)}/{total_blocks} blocks dead, " + f"{static_branches} static branches, {n_constants} constants propagated" + ) + logger.info( f" translate_eval_array_with_cache_split: code generated ({len(code_lines)} lines) in {t1 - t0:.1f}s" ) @@ -1373,8 +1421,16 @@ def translate_eval_array_with_cache_split( df.write(f"# use_limit_functions={use_limit_functions}\n") df.write(f"# shared_indices={shared_indices}\n") df.write(f"# varying_indices={varying_indices}\n") - df.write(f"# concrete_shared_values={concrete_shared_values is not None} ({n_inlined_params} values)\n") - df.write(f"# concrete_shared_cache={concrete_shared_cache is not None} ({n_inlined_cache} values)\n") + df.write( + f"# concrete_shared_values={concrete_shared_values is not None} ({n_inlined_params} values)\n" + ) + df.write( + f"# concrete_shared_cache={concrete_shared_cache is not None} ({n_inlined_cache} values)\n" + ) + if builder.sccp is not None: + dead_blocks = builder.sccp.get_dead_blocks() + total_blocks = len(self.eval_mir.blocks) + df.write(f"# sccp: {len(dead_blocks)}/{total_blocks} blocks dead\n") df.write(code) logger.info(f" Generated code dumped to {dump_path}") From 90f379eccbd02c3f04faeab5fcb2b0d20d116fe6 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Sun, 8 Mar 2026 02:11:42 +0000 Subject: [PATCH 09/79] style: Fix lint errors in check_constant_folding.py Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/check_constant_folding.py | 282 ++++++++++++++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 scripts/check_constant_folding.py diff --git a/scripts/check_constant_folding.py b/scripts/check_constant_folding.py new file mode 100644 index 00000000..52053657 --- /dev/null +++ b/scripts/check_constant_folding.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = ["jax", "jaxlib"] +# /// +"""Check whether jnp.where constant folding works at jaxpr and HLO levels. + +Quick diagnostic to verify that inlining shared params as Python literals +actually eliminates jnp.where branches in the compiled XLA program. +""" + +import jax +import jax.numpy as jnp + + +def test_basic_constant_folding(): + """Test: does jnp.where with a Python bool constant-fold?""" + print("=" * 60) + print("Test 1: jnp.where with Python bool literal") + print("=" * 60) + + def f_const(x): + # Condition is a Python bool - should constant-fold + cond = True + return jnp.where(cond, x * 2, x * 3) + + def f_traced(x, flag): + # Condition is a traced value - cannot constant-fold + return jnp.where(flag > 0.0, x * 2, x * 3) + + x = jnp.ones(10) + + jaxpr_const = jax.make_jaxpr(f_const)(x) + jaxpr_traced = jax.make_jaxpr(f_traced)(x, 1.0) + + print(f"\nConstant cond jaxpr ({len(jaxpr_const.eqns)} ops):") + print(jaxpr_const) + print(f"\nTraced cond jaxpr ({len(jaxpr_traced.eqns)} ops):") + print(jaxpr_traced) + + # Check HLO + lowered_const = jax.jit(f_const).lower(x) + lowered_traced = jax.jit(f_traced).lower(x, 1.0) + hlo_const = lowered_const.as_text() + hlo_traced = lowered_traced.as_text() + + select_const = hlo_const.count("select") + select_traced = hlo_traced.count("select") + print(f"\nHLO select ops: constant={select_const}, traced={select_traced}") + + +def test_constant_through_jnp_ops(): + """Test: does constant folding survive through jnp operations?""" + print("\n" + "=" * 60) + print("Test 2: Constant folding through jnp operations") + print("=" * 60) + + def f_inlined(x): + # Simulate what our specialization does: + # shared param inlined as literal, then used in jnp ops + v_param = 1.5e-6 # Was: shared_params[42] + v_computed = jnp.exp(v_param) # jnp op on literal + cond = v_computed > 0.5 # comparison + return jnp.where(cond, x * 2, x * 3) + + def f_array_lookup(x, shared_params): + # Original: shared_params array lookup + v_param = shared_params[42] + v_computed = jnp.exp(v_param) + cond = v_computed > 0.5 + return jnp.where(cond, x * 2, x * 3) + + x = jnp.ones(10) + shared = jnp.zeros(100) + + jaxpr_inlined = jax.make_jaxpr(f_inlined)(x) + jaxpr_lookup = jax.make_jaxpr(f_array_lookup)(x, shared) + + print(f"\nInlined literal jaxpr ({len(jaxpr_inlined.eqns)} ops):") + print(jaxpr_inlined) + print(f"\nArray lookup jaxpr ({len(jaxpr_lookup.eqns)} ops):") + print(jaxpr_lookup) + + # Check HLO + lowered_inlined = jax.jit(f_inlined).lower(x) + lowered_lookup = jax.jit(f_array_lookup).lower(x, shared) + hlo_inlined = lowered_inlined.as_text() + hlo_lookup = lowered_lookup.as_text() + + select_inlined = hlo_inlined.count("select") + select_lookup = hlo_lookup.count("select") + print(f"\nHLO select ops: inlined={select_inlined}, lookup={select_lookup}") + + +def test_python_float_vs_jnp(): + """Test: Python float literal vs jnp operation - what does JAX trace?""" + print("\n" + "=" * 60) + print("Test 3: Python float arithmetic vs jnp arithmetic") + print("=" * 60) + + def f_python_arith(x): + # Pure Python: should constant-fold completely + a = 1.5e-6 + b = a * 2.0 # Python multiplication + cond = b > 1e-6 # Python comparison -> True + return jnp.where(cond, x * 2, x * 3) + + def f_jnp_arith(x): + # jnp operations: might NOT constant-fold in jaxpr + a = 1.5e-6 + b = jnp.float64(a) * 2.0 # jnp multiplication + cond = b > 1e-6 # comparison on jnp result + return jnp.where(cond, x * 2, x * 3) + + x = jnp.ones(10) + + jaxpr_python = jax.make_jaxpr(f_python_arith)(x) + jaxpr_jnp = jax.make_jaxpr(f_jnp_arith)(x) + + print(f"\nPython arith jaxpr ({len(jaxpr_python.eqns)} ops):") + print(jaxpr_python) + print(f"\njnp arith jaxpr ({len(jaxpr_jnp.eqns)} ops):") + print(jaxpr_jnp) + + # Check HLO for both + lowered_python = jax.jit(f_python_arith).lower(x) + lowered_jnp = jax.jit(f_jnp_arith).lower(x) + hlo_python = lowered_python.as_text() + hlo_jnp = lowered_jnp.as_text() + + select_python = hlo_python.count("select") + select_jnp = hlo_jnp.count("select") + print(f"\nHLO select ops: python_arith={select_python}, jnp_arith={select_jnp}") + + +def test_generated_code_pattern(): + """Test the ACTUAL pattern used in generated eval code. + + The generated code does: + v123 = 1.5e-6 # inlined literal (was shared_params[42]) + Then later uses it in jnp operations. + + Key question: does assigning a Python float to a local variable, + then using it in jnp.where, get constant-folded? + """ + print("\n" + "=" * 60) + print("Test 4: Actual generated code pattern (assign + jnp.where)") + print("=" * 60) + + def f_generated_pattern(device_params): + # This mimics the actual generated eval code pattern + # Inlined shared params + v100 = 1.0 # TYPE = 1 (NMOS) + v101 = 0.0 # SWIGATE = 0 + v102 = 1.5e-6 # TOX + + # Device params (from vmap, traced) + v200 = device_params[0] # voltage + _v201 = device_params[1] # another voltage (unused, kept for array shape) + + # Computation chain (mimics what OpenVAF generates) + v300 = jnp.exp(v102 * 1e6) # Uses inlined literal + v301 = v300 * v200 # Mixes with traced value + + # Branch on static param + v400 = v100 > 0.5 # TYPE > 0.5 -> True for NMOS + result1 = jnp.where(v400, v301, -v301) # Should fold + + # Branch on static param through jnp op + v401 = jnp.abs(v101) # jnp op on inlined literal + v402 = v401 > 0.5 # Should be False + result2 = jnp.where(v402, result1 * 2, result1 * 3) # Should fold + + return result1 + result2 + + def f_array_pattern(device_params, shared_params): + # Original pattern: array lookups (all traced) + v100 = shared_params[0] + v101 = shared_params[1] + v102 = shared_params[2] + + v200 = device_params[0] + _v201 = device_params[1] # noqa: F841 + + v300 = jnp.exp(v102 * 1e6) + v301 = v300 * v200 + + v400 = v100 > 0.5 + result1 = jnp.where(v400, v301, -v301) + + v401 = jnp.abs(v101) + v402 = v401 > 0.5 + result2 = jnp.where(v402, result1 * 2, result1 * 3) + + return result1 + result2 + + dp = jnp.array([0.5, 0.3]) + sp = jnp.array([1.0, 0.0, 1.5e-6]) + + jaxpr_gen = jax.make_jaxpr(f_generated_pattern)(dp) + jaxpr_arr = jax.make_jaxpr(f_array_pattern)(dp, sp) + + print(f"\nInlined pattern jaxpr ({len(jaxpr_gen.eqns)} ops):") + print(jaxpr_gen) + print(f"\nArray pattern jaxpr ({len(jaxpr_arr.eqns)} ops):") + print(jaxpr_arr) + + # Check HLO + lowered_gen = jax.jit(f_generated_pattern).lower(dp) + lowered_arr = jax.jit(f_array_pattern).lower(dp, sp) + hlo_gen = lowered_gen.as_text() + hlo_arr = lowered_arr.as_text() + + select_gen = hlo_gen.count("select") + select_arr = hlo_arr.count("select") + print(f"\nHLO select ops: inlined={select_gen}, array={select_arr}") + + # Also count total HLO ops + print(f"HLO lines: inlined={len(hlo_gen.splitlines())}, array={len(hlo_arr.splitlines())}") + + +def test_vmap_interaction(): + """Test: does constant folding survive vmap? + + This is crucial because we vmap the eval function over devices. + """ + print("\n" + "=" * 60) + print("Test 5: Constant folding under vmap") + print("=" * 60) + + def f_inlined(device_params): + v_type = 1.0 # Inlined: TYPE = NMOS + cond = v_type > 0.5 + return jnp.where(cond, device_params[0] * 2, device_params[0] * 3) + + def f_lookup(device_params, shared_params): + v_type = shared_params[0] + cond = v_type > 0.5 + return jnp.where(cond, device_params[0] * 2, device_params[0] * 3) + + # vmap over batch of devices + f_inlined_vmapped = jax.vmap(f_inlined) + f_lookup_vmapped = jax.vmap(f_lookup, in_axes=(0, None)) + + batch_dp = jnp.ones((4, 3)) + sp = jnp.array([1.0]) + + jaxpr_inlined = jax.make_jaxpr(f_inlined_vmapped)(batch_dp) + jaxpr_lookup = jax.make_jaxpr(f_lookup_vmapped)(batch_dp, sp) + + print(f"\nvmapped inlined jaxpr ({len(jaxpr_inlined.eqns)} ops):") + print(jaxpr_inlined) + print(f"\nvmapped lookup jaxpr ({len(jaxpr_lookup.eqns)} ops):") + print(jaxpr_lookup) + + # HLO + lowered_inlined = jax.jit(f_inlined_vmapped).lower(batch_dp) + lowered_lookup = jax.jit(f_lookup_vmapped).lower(batch_dp, sp) + hlo_inlined = lowered_inlined.as_text() + hlo_lookup = lowered_lookup.as_text() + + select_inlined = hlo_inlined.count("select") + select_lookup = hlo_lookup.count("select") + print(f"\nHLO select ops: inlined={select_inlined}, lookup={select_lookup}") + print( + f"HLO lines: inlined={len(hlo_inlined.splitlines())}, lookup={len(hlo_lookup.splitlines())}" + ) + + +if __name__ == "__main__": + print(f"JAX version: {jax.__version__}") + print(f"Platform: {jax.default_backend()}") + print( + f"x64 enabled: {jax.config.x86_64_enabled if hasattr(jax.config, 'x86_64_enabled') else 'unknown'}" + ) + print() + + test_basic_constant_folding() + test_constant_through_jnp_ops() + test_python_float_vs_jnp() + test_generated_code_pattern() + test_vmap_interaction() From cf9f4b3f1d13ae311d7e1ea5db324e4a5cb6a6c2 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 04:25:39 +0000 Subject: [PATCH 10/79] perf: Remove redundant Tikhonov regularization from dense linear solver The dense Jacobian assembly (assemble_dense_jacobian / _build_system_dense_direct) already adds gmin (1e-12) diagonal regularization. The linear_solve function was adding an additional 1e-14 * eye(n) which was redundant and added unnecessary computation (eye allocation + matrix addition) per NR solve. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- vajax/analysis/solver_factories.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vajax/analysis/solver_factories.py b/vajax/analysis/solver_factories.py index 378e7c94..7f6aaa28 100644 --- a/vajax/analysis/solver_factories.py +++ b/vajax/analysis/solver_factories.py @@ -510,9 +510,9 @@ def enforce_noi(J, f): def linear_solve(J, f): """Solve J @ delta = -f using dense direct solver.""" - # Add Tikhonov regularization for numerical stability on GPU - reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype) - return jax.scipy.linalg.solve(J + reg, -f) + # Diagonal regularization (gmin + gshunt) is already applied during + # Jacobian assembly in assemble_dense_jacobian / _build_system_dense_direct. + return jax.scipy.linalg.solve(J, -f) logger.info( f"Creating dense full MNA solver: V({n_nodes}) + I({n_vsources}), " From d009e91b8288ebb5d8c6952d2aeed53822c8fcd8 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 04:30:49 +0000 Subject: [PATCH 11/79] perf: Make output recording unconditional in transient loop Replace jnp.where-wrapped writes to times_out, V_out, I_out with unconditional .at[step_idx].set(). On step rejection, step_idx doesn't advance so stale values get overwritten by the next accepted step. The caller trims output using step_idx, so values beyond it are ignored. This avoids materializing both branches of jnp.where on the full output arrays (up to max_steps x n_nodes) every timestep, saving ~8-12% per step. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- vajax/analysis/transient/full_mna.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vajax/analysis/transient/full_mna.py b/vajax/analysis/transient/full_mna.py index 80d58dab..fbf330ed 100644 --- a/vajax/analysis/transient/full_mna.py +++ b/vajax/analysis/transient/full_mna.py @@ -1767,21 +1767,18 @@ def _debug_step_callback( # Compute the voltage to record - use new_X which is the actual solution we're using # (either X_new if converged, or previous X if NR failed at min_dt) V_to_record = new_X[:n_external] - new_times_out = jnp.where( - accept_step, state.times_out.at[state.step_idx].set(t_next), state.times_out - ) - new_V_out = jnp.where( - accept_step, state.V_out.at[state.step_idx].set(V_to_record), state.V_out - ) - # For currents, use zero if NR failed at min_dt (current from bad solution is unreliable) + # Write unconditionally: on rejection step_idx doesn't advance, so stale + # values at step_idx get overwritten by the next accepted step. The caller + # trims output using step_idx, so values beyond it are ignored. This avoids + # materializing both branches of jnp.where on the full output arrays. I_to_record = jnp.where( nr_failed_at_min_dt, jnp.zeros(n_vsources, dtype=dtype) if n_vsources > 0 else jnp.zeros(1, dtype=dtype), I_vsource[:n_vsources] if n_vsources > 0 else jnp.zeros(1, dtype=dtype), ) - new_I_out = jnp.where( - accept_step, state.I_out.at[state.step_idx].set(I_to_record), state.I_out - ) + new_times_out = state.times_out.at[state.step_idx].set(t_next) + new_V_out = state.V_out.at[state.step_idx].set(V_to_record) + new_I_out = state.I_out.at[state.step_idx].set(I_to_record) new_step_idx = jnp.where(accept_step, state.step_idx + 1, state.step_idx) # Statistics From 29ac4c3a988476718d82bc4dc93438df74f4e7f7 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 05:07:14 +0000 Subject: [PATCH 12/79] fix: Fix VACASK test paths and wire dump_jaxpr into compare_vacask - Fix test_vacask_suite.py VACASK_ROOT to use vendored path (vendor/VACASK) instead of sibling directory (../VACASK). This was causing all 78 discovery/parsing tests to silently skip. - Add dump_jaxpr() to CircuitEngine for analyzing compiled simulation functions (build_system + nr_solve) via jaxpr and HLO dumps. - Wire dump_jaxpr into compare_vacask.py --analyze flag, replacing the scan-only analysis path. Now works with the default while_loop path. - Remove unused analyze_compiled_function from compare_vacask.py. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/compare_vacask.py | 139 +---------------------- tests/test_vacask_suite.py | 6 +- vajax/analysis/engine.py | 226 +++++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 137 deletions(-) diff --git a/scripts/compare_vacask.py b/scripts/compare_vacask.py index 20063930..070c9342 100644 --- a/scripts/compare_vacask.py +++ b/scripts/compare_vacask.py @@ -52,7 +52,6 @@ # Note: Set JAX_PLATFORMS=cpu before running for CPU-only mode import jax -import jax.numpy as jnp # Precision is auto-configured by vajax import (imported above via logging) # Metal/TPU use f32, CPU/CUDA use f64 @@ -64,104 +63,6 @@ ) from vajax.profiling import ProfileConfig - -def analyze_compiled_function(fn, args, name: str, output_dir: Optional[Path] = None): - """Dump jaxpr and cost analysis for a JIT-compiled function. - - Args: - fn: A JAX function (JIT-compiled or not) - args: Example arguments to trace with - name: Name for output files - output_dir: Optional directory to save analysis files - """ - print(f"\n{'=' * 70}") - print(f"JAX Analysis: {name}") - print(f"{'=' * 70}") - - # Lower the function to get HLO and cost analysis - # For JIT-compiled functions, we use .lower() directly - print(f"\n--- Lowering and compiling {name} ---") - try: - # If fn is already jitted, we can lower it directly - # Otherwise wrap it in jit first - if hasattr(fn, "lower"): - lowered = fn.lower(*args) - else: - lowered = jax.jit(fn).lower(*args) - - # Get the HLO text (MLIR representation) - hlo_text = lowered.as_text() - hlo_lines = hlo_text.split("\n") - print(f"HLO text: {len(hlo_lines)} lines") - - # Count operations in HLO - op_counts: Dict[str, int] = {} - for line in hlo_lines: - # Extract operation names from MLIR-style ops like: %0 = stablehlo.add - if "=" in line and "." in line: - parts = line.split("=") - if len(parts) >= 2: - op_part = parts[1].strip().split()[0] if parts[1].strip() else "" - if "." in op_part: - op_name = op_part.split("(")[0] # Remove args - op_counts[op_name] = op_counts.get(op_name, 0) + 1 - - if op_counts: - print(f"Top HLO ops: {dict(sorted(op_counts.items(), key=lambda x: -x[1])[:15])}") - - # Compile and get cost analysis - compiled = lowered.compile() - cost = compiled.cost_analysis() - print("\n--- Cost Analysis ---") - if cost: - for i, device_cost in enumerate(cost): - if device_cost and isinstance(device_cost, dict): - print(f"Device {i}:") - for key, val in device_cost.items(): - if isinstance(val, (int, float)): - if val > 1e9: - print(f" {key}: {val / 1e9:.2f}G") - elif val > 1e6: - print(f" {key}: {val / 1e6:.2f}M") - elif val > 1e3: - print(f" {key}: {val / 1e3:.2f}K") - else: - print(f" {key}: {val}") - else: - print(f" {key}: {val}") - elif device_cost: - print(f"Device {i}: {device_cost}") - else: - print("No cost analysis available (may not be supported on this backend)") - - # Save files if output_dir provided - if output_dir: - output_dir.mkdir(parents=True, exist_ok=True) - - # Save HLO text - hlo_file = output_dir / f"{name}_hlo.txt" - with open(hlo_file, "w") as f: - f.write(hlo_text) - print(f"\nHLO text saved to: {hlo_file}") - - # Try to get the jaxpr as well for the underlying computation - try: - # Create jaxpr from the unwrapped function if possible - jaxpr_text = str(jax.make_jaxpr(fn)(*args)) - jaxpr_file = output_dir / f"{name}_jaxpr.txt" - with open(jaxpr_file, "w") as f: - f.write(jaxpr_text) - print(f"JAXPR saved to: {jaxpr_file}") - except Exception: - pass # JIT functions may not produce useful jaxpr - - except Exception as e: - import traceback - - print(f"Failed to analyze: {e}") - traceback.print_exc() - - # Note: Benchmark configurations are now auto-discovered from # vajax.benchmarks.registry. The registry parses .sim files # to extract dt, t_stop, and device types automatically. @@ -451,41 +352,11 @@ def do_run(): ) startup_time = time.perf_counter() - startup_start - # Run analysis on compiled scan function if requested - if analyze and use_scan and hasattr(engine, "_cached_scan_fn"): - print("\n Running JAX analysis...") - # Get example inputs for the scan function - # The scan function signature is: (V_init, Q_init, all_vsource, all_isource, device_arrays) - # Must use total nodes (external + internal) from transient setup cache - setup_cache = getattr(engine, "_transient_setup_cache", None) - device_arrays = getattr(engine, "_device_arrays", None) - - if setup_cache is None or device_arrays is None: - print(" Warning: transient setup cache not found - analysis skipped") - else: - n_total = setup_cache["n_total"] - n_unknowns = setup_cache["n_unknowns"] - n_vsources = len([d for d in engine.devices if d["model"] == "vsource"]) - n_isources = len([d for d in engine.devices if d["model"] == "isource"]) - - # Create example arrays matching actual shapes - V_init = jnp.zeros(n_total, dtype=jnp.float64) - Q_init = jnp.zeros(n_unknowns, dtype=jnp.float64) - all_vsource = jnp.zeros((num_steps, n_vsources), dtype=jnp.float64) - all_isource = jnp.zeros((num_steps, n_isources), dtype=jnp.float64) - - # Determine output directory - out_dir = analyze_output_dir or Path(f"/tmp/vajax-analysis/{config.name}") - - # Analyze the scan function - analyze_compiled_function( - engine._cached_scan_fn, - (V_init, Q_init, all_vsource, all_isource, device_arrays), - f"{config.name}_scan_simulation", - out_dir, - ) - elif analyze and use_scan: - print("\n Warning: _cached_scan_fn not found - analysis skipped") + # Run analysis on compiled functions if requested + if analyze: + out_dir = analyze_output_dir or Path(f"/tmp/claude/jaxpr-analysis/{config.name}") + print(f"\n Dumping jaxpr/HLO analysis to {out_dir} ...") + engine.dump_jaxpr(out_dir) # Timed run - print perf_counter for correlation with Perfetto traces # prepare() already called above with same params, strategy is cached diff --git a/tests/test_vacask_suite.py b/tests/test_vacask_suite.py index c2538d13..37ab5804 100644 --- a/tests/test_vacask_suite.py +++ b/tests/test_vacask_suite.py @@ -22,12 +22,12 @@ from vajax.netlist.parser import parse_netlist -# Paths - VACASK is at ../VACASK relative to vajax +# Paths - VACASK is vendored at vendor/VACASK VAJAX_ROOT = Path(__file__).parent.parent -VACASK_ROOT = VAJAX_ROOT.parent / "VACASK" +VACASK_ROOT = VAJAX_ROOT / "vendor" / "VACASK" VACASK_TEST = VACASK_ROOT / "test" VACASK_DEVICES = VACASK_ROOT / "devices" -VACASK_BENCHMARK = VAJAX_ROOT / "vendor" / "VACASK" / "benchmark" +VACASK_BENCHMARK = VACASK_ROOT / "benchmark" def discover_benchmark_dirs() -> List[Path]: diff --git a/vajax/analysis/engine.py b/vajax/analysis/engine.py index 3e9fb54d..4ec9ef29 100644 --- a/vajax/analysis/engine.py +++ b/vajax/analysis/engine.py @@ -1022,6 +1022,232 @@ def run_transient(self) -> TransientResult: stats=stats, ) + def dump_jaxpr(self, output_dir: str | Path = "/tmp/claude/jaxpr-analysis") -> Path: + """Dump JAX IR analysis for the compiled simulation functions. + + Analyzes two hot-path functions after prepare(): + 1. build_system - Jacobian + residual assembly (per NR iteration) + 2. nr_solve - Full Newton-Raphson solve (per timestep) + + Uses the actual circuit's device_arrays and dimensions from the prepared + strategy, matching the calling conventions used during simulation. + + For each function, writes: + - HLO text (StableHLO MLIR representation) + - HLO operation counts (top ops by frequency) + - XLA cost analysis (flops, memory bytes) + + Args: + output_dir: Directory for output files. + + Returns: + Path to the output directory. + """ + if not getattr(self, "_prepared", False): + self.prepare() + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + strategy = self._prepared_strategy + + # Get cached functions and actual circuit data from the strategy + build_system_jit = getattr(strategy, "_cached_build_system_jit", None) + nr_solve = getattr(strategy, "_cached_full_mna_solver", None) + device_arrays = getattr(strategy, "_device_arrays_full_mna", None) + total_limit_states = getattr(strategy, "_total_limit_states", 0) + + if build_system_jit is None or nr_solve is None or device_arrays is None: + raise RuntimeError("No cached solver found. Call prepare() first.") + + # Get dimensions from transient setup cache + setup_cache = self._transient_setup_cache + n_total = setup_cache["n_total"] + n_unknowns = setup_cache["n_unknowns"] + n_vsources = len([d for d in self.devices if d["model"] == "vsource"]) + n_isources = len([d for d in self.devices if d["model"] == "isource"]) + + logger.info( + f"dump_jaxpr: n_total={n_total}, n_unknowns={n_unknowns}, " + f"n_vsources={n_vsources}, limit_states={total_limit_states}" + ) + + # Use the actual circuit's device_arrays and correct-shaped args. + # Values don't matter for HLO tracing — only shapes and dtypes. + # Match the calling conventions from full_mna.py DC solve path. + dtype = get_float_dtype() + X = jnp.zeros(n_total + n_vsources, dtype=dtype) + vsource_vals = jnp.zeros(n_vsources, dtype=dtype) + isource_vals = jnp.zeros(max(n_isources, 1), dtype=dtype) + Q_prev = jnp.zeros(n_unknowns, dtype=dtype) + integ_c0 = jnp.asarray(0.0, dtype=dtype) + gmin = jnp.asarray(1e-12, dtype=dtype) + gshunt = jnp.asarray(0.0, dtype=dtype) + integ_c1 = jnp.asarray(0.0, dtype=dtype) + integ_d1 = jnp.asarray(0.0, dtype=dtype) + dQdt_prev = jnp.zeros(n_unknowns, dtype=dtype) + integ_c2 = jnp.asarray(0.0, dtype=dtype) + Q_prev2 = jnp.zeros(n_unknowns, dtype=dtype) + limit_state = jnp.zeros(total_limit_states, dtype=dtype) + nr_iter = jnp.asarray(1, dtype=jnp.int32) + + # build_system_jit signature: (X, vsource_vals, isource_vals, Q_prev, + # integ_c0, device_arrays, gmin, gshunt, integ_c1, integ_d1, + # dQdt_prev, integ_c2, Q_prev2, limit_state, nr_iter) + build_args = ( + X, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays, + gmin, + gshunt, + integ_c1, + integ_d1, + dQdt_prev, + integ_c2, + Q_prev2, + limit_state, + nr_iter, + ) + + # nr_solve signature: (X_init, vsource_vals, isource_vals, Q_prev, + # integ_c0, device_arrays, gmin, gshunt, integ_c1, integ_d1, + # dQdt_prev, integ_c2, Q_prev2, limit_state_in) + # Uses None defaults for optional args, matching DC solve convention. + nr_args = ( + X, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays, + gmin, + gshunt, + integ_c1, + integ_d1, + None, + integ_c2, + None, + None, + ) + + results = {} + for name, fn, args in [ + ("build_system", build_system_jit, build_args), + ("nr_solve", nr_solve, nr_args), + ]: + results[name] = self._dump_single_jaxpr(name, fn, args, output_dir) + + return output_dir + + @staticmethod + def _dump_single_jaxpr(name: str, fn, args, output_dir: Path) -> dict[str, Any]: + """Analyze and dump jaxpr/HLO for a single function. + + Produces three artifacts per function: + - {name}.jaxpr.txt — JAX's high-level IR (from jax.make_jaxpr) + - {name}.hlo.txt — StableHLO/MLIR after lowering + - Log output with op counts and XLA cost analysis + + Returns: + Dict with 'jaxpr', 'hlo_lines', 'op_counts', 'cost' keys. + """ + result: dict[str, Any] = {"name": name} + logger.info(f"Analyzing {name}...") + + try: + # 1. Jaxpr (high-level JAX IR) + jaxpr = jax.make_jaxpr(fn)(*args) + jaxpr_text = str(jaxpr) + jaxpr_lines = jaxpr_text.split("\n") + result["jaxpr"] = jaxpr + result["jaxpr_lines"] = len(jaxpr_lines) + logger.info(f" {name}: {len(jaxpr_lines)} jaxpr lines") + + jaxpr_file = output_dir / f"{name}.jaxpr.txt" + with open(jaxpr_file, "w") as f: + f.write(jaxpr_text) + result["jaxpr_file"] = jaxpr_file + logger.info(f" Saved jaxpr: {jaxpr_file}") + + # Count jaxpr primitives + jaxpr_ops: dict[str, int] = {} + for eqn in jaxpr.jaxpr.eqns: + prim_name = str(eqn.primitive.name) + jaxpr_ops[prim_name] = jaxpr_ops.get(prim_name, 0) + 1 + result["jaxpr_ops"] = dict(sorted(jaxpr_ops.items(), key=lambda x: -x[1])) + if jaxpr_ops: + sorted_jaxpr_ops = sorted(jaxpr_ops.items(), key=lambda x: -x[1]) + logger.info(f" {name}: Top jaxpr primitives:") + for op, count in sorted_jaxpr_ops[:20]: + logger.info(f" {op:45s} {count:6d}") + logger.info(f" Total jaxpr eqns: {len(jaxpr.jaxpr.eqns)}") + + # 2. HLO (lowered StableHLO) + if hasattr(fn, "lower"): + lowered = fn.lower(*args) + else: + lowered = jax.jit(fn).lower(*args) + + hlo_text = lowered.as_text() + hlo_lines = hlo_text.split("\n") + result["hlo_lines"] = len(hlo_lines) + logger.info(f" {name}: {len(hlo_lines)} HLO lines") + + # Count HLO operations + op_counts: dict[str, int] = {} + for line in hlo_lines: + if "=" in line and "." in line: + parts = line.split("=") + if len(parts) >= 2: + op_part = parts[1].strip().split()[0] if parts[1].strip() else "" + if "." in op_part: + op_name = op_part.split("(")[0] + op_counts[op_name] = op_counts.get(op_name, 0) + 1 + + result["op_counts"] = dict(sorted(op_counts.items(), key=lambda x: -x[1])) + if op_counts: + sorted_ops = sorted(op_counts.items(), key=lambda x: -x[1]) + logger.info(f" {name}: Top HLO ops:") + for op, count in sorted_ops[:20]: + logger.info(f" {op:45s} {count:6d}") + logger.info(f" Total unique op types: {len(op_counts)}") + logger.info(f" Total ops: {sum(op_counts.values())}") + + # 3. Cost analysis + compiled = lowered.compile() + cost = compiled.cost_analysis() + result["cost"] = cost + if cost: + for i, device_cost in enumerate(cost): + if device_cost and isinstance(device_cost, dict): + logger.info(f" {name} cost (device {i}):") + for key, val in sorted(device_cost.items()): + if isinstance(val, (int, float)): + if val > 1e9: + logger.info(f" {key}: {val / 1e9:.2f}G") + elif val > 1e6: + logger.info(f" {key}: {val / 1e6:.2f}M") + elif val > 1e3: + logger.info(f" {key}: {val / 1e3:.2f}K") + else: + logger.info(f" {key}: {val:.2f}") + + # Save HLO + hlo_file = output_dir / f"{name}.hlo.txt" + with open(hlo_file, "w") as f: + f.write(hlo_text) + result["hlo_file"] = hlo_file + logger.info(f" Saved HLO: {hlo_file}") + + except Exception as e: + logger.error(f"Failed to analyze {name}: {e}", exc_info=True) + result["error"] = str(e) + + return result + # ========================================================================= # Node Collapse Implementation # ========================================================================= From 5476ad2ffb4d1963fb54f8adcbd7c7da5d14bab3 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 14:15:46 +0000 Subject: [PATCH 13/79] fix: Stop inlining shared params as literals in generated eval code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SCCP dead-block elimination is kept (eliminates 695/954 MIR blocks for PSP103), but the generated Python code now uses shared_params[i] array reads instead of literal float constants. Literal inlining (~1300 constants for PSP103) caused a 7.8x GPU regression on ring by embedding too many constants in the XLA kernel, hurting register pressure and instruction cache. CPU was 14% faster with inlining, but GPU went from 1.43ms/step to 11.17ms/step. The fix: pass concrete values to SCCP for dead-block analysis only, via a new build_sccp_known_values() method. The builder no longer accepts concrete_shared_values/concrete_shared_cache parameters — all generated code uses array reads. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- openvaf_jax/__init__.py | 105 +++++++++++++----------- openvaf_jax/codegen/function_builder.py | 55 ++++--------- vajax/analysis/openvaf_models.py | 35 ++++---- 3 files changed, 88 insertions(+), 107 deletions(-) diff --git a/openvaf_jax/__init__.py b/openvaf_jax/__init__.py index 7b136e49..bc29f910 100644 --- a/openvaf_jax/__init__.py +++ b/openvaf_jax/__init__.py @@ -1266,6 +1266,51 @@ def translate_init_array_split( return init_fn, metadata + def build_sccp_known_values( + self, + shared_indices: List[int], + shared_values: List[float], + shared_cache_indices: Optional[List[int]] = None, + shared_cache_values: Optional[List[float]] = None, + ) -> Optional[Dict[str, Any]]: + """Build SCCP known_values dict from concrete shared params and cache. + + Maps concrete values to their MIR value IDs for SCCP dead-block elimination. + This tells the SCCP pass which MIR values are constant, enabling it to + resolve static branches and eliminate dead blocks at codegen time. + + NOTE: These values are used ONLY for SCCP analysis, NOT for literal + inlining in the generated Python code. Array reads (shared_params[i], + shared_cache[i]) are preserved in the generated code for GPU efficiency. + Literal inlining was found to cause 7.8x GPU regression on PSP103 by + embedding ~1300 float constants in the XLA kernel. + + Args: + shared_indices: Original param indices for shared params + shared_values: Concrete float values for each shared param + shared_cache_indices: Cache column indices for shared cache entries + shared_cache_values: Concrete float values for each shared cache entry + + Returns: + Dict mapping MIR value IDs to concrete values, or None if empty. + """ + known: Dict[str, Any] = {} + + for j, orig_idx in enumerate(shared_indices): + value_id = self.param_idx_to_val.get(orig_idx) + if value_id: + known[value_id] = shared_values[j] + + if shared_cache_values is not None and shared_cache_indices: + for j, cache_col_idx in enumerate(shared_cache_indices): + mapping = self.cache_mapping[cache_col_idx] + eval_param_idx = mapping["eval_param"] + value_id = self.param_idx_to_val.get(eval_param_idx) + if value_id: + known[value_id] = shared_cache_values[j] + + return known if known else None + def translate_eval_array_with_cache_split( self, shared_indices: List[int], @@ -1274,8 +1319,7 @@ def translate_eval_array_with_cache_split( varying_cache_indices: Optional[List[int]] = None, use_limit_functions: bool = False, limit_param_map: Optional[Dict[int, Tuple[str, str]]] = None, - concrete_shared_values: Optional[List[float]] = None, - concrete_shared_cache: Optional[List[float]] = None, + sccp_known_values: Optional[Dict[str, Any]] = None, ) -> Tuple[Callable, Dict]: """Generate a vmappable eval function with split params and cache (internal API). @@ -1293,12 +1337,10 @@ def translate_eval_array_with_cache_split( limit_param_map: Dict mapping original param indices to (kind, name) tuples for limit-related params (prev_state, enable_lim, new_state, enable_integration). Excluded from shared/device params. - concrete_shared_values: If provided, concrete float values for each - shared param. Inlined as Python literals for trace-time - constant folding, eliminating jnp.where branches that - depend on static parameters. - concrete_shared_cache: If provided, concrete float values for each - shared cache entry. Same inlining behavior as above. + sccp_known_values: If provided, maps MIR value IDs to concrete values + for SCCP dead-block elimination. Build this with + build_sccp_known_values(). Values are used for SCCP + analysis only — NOT inlined as literals in generated code. Returns: Tuple of (eval_fn, metadata) @@ -1324,44 +1366,14 @@ def translate_eval_array_with_cache_split( assert self.dae_data is not None, "dae_data released, call before release_mir_data()" t0 = time.perf_counter() - n_inlined_params = len(concrete_shared_values) if concrete_shared_values else 0 - n_inlined_cache = len(concrete_shared_cache) if concrete_shared_cache else 0 + n_sccp = len(sccp_known_values) if sccp_known_values else 0 logger.info( f" translate_eval_array_with_cache_split: generating code " - f"(limit_funcs={use_limit_functions}, " - f"inlined_params={n_inlined_params}, inlined_cache={n_inlined_cache})..." + f"(limit_funcs={use_limit_functions}, sccp_known={n_sccp})..." ) - # Build SCCP known values from concrete shared params and cache - # This enables dead branch elimination at codegen time - sccp_known_values: Optional[Dict[str, Any]] = None - if concrete_shared_values is not None or concrete_shared_cache is not None: - sccp_known_values = {} - - # Map shared params to MIR value IDs - if concrete_shared_values is not None: - for j, orig_idx in enumerate(shared_indices): - value_id = self.param_idx_to_val.get(orig_idx) - if value_id: - sccp_known_values[value_id] = concrete_shared_values[j] - - # Map shared cache entries to MIR value IDs - if concrete_shared_cache is not None and shared_cache_indices: - for j, cache_col_idx in enumerate(shared_cache_indices): - mapping = self.cache_mapping[cache_col_idx] - eval_param_idx = mapping["eval_param"] - value_id = self.param_idx_to_val.get(eval_param_idx) - if value_id: - sccp_known_values[value_id] = concrete_shared_cache[j] - - if not sccp_known_values: - sccp_known_values = None - else: - logger.info( - f" SCCP: {len(sccp_known_values)} known values " - f"({n_inlined_params} from params, " - f"{len(sccp_known_values) - n_inlined_params} from cache)" - ) + if sccp_known_values: + logger.info(f" SCCP: {n_sccp} known values for dead-block elimination") # Build the eval function eval_param_names = list(self.module.param_names) @@ -1380,8 +1392,6 @@ def translate_eval_array_with_cache_split( varying_cache_indices, use_limit_functions=use_limit_functions, limit_param_map=limit_param_map, - concrete_shared_values=concrete_shared_values, - concrete_shared_cache=concrete_shared_cache, ) t1 = time.perf_counter() @@ -1421,12 +1431,7 @@ def translate_eval_array_with_cache_split( df.write(f"# use_limit_functions={use_limit_functions}\n") df.write(f"# shared_indices={shared_indices}\n") df.write(f"# varying_indices={varying_indices}\n") - df.write( - f"# concrete_shared_values={concrete_shared_values is not None} ({n_inlined_params} values)\n" - ) - df.write( - f"# concrete_shared_cache={concrete_shared_cache is not None} ({n_inlined_cache} values)\n" - ) + df.write(f"# sccp_known_values={n_sccp}\n") if builder.sccp is not None: dead_blocks = builder.sccp.get_dead_blocks() total_blocks = len(self.eval_mir.blocks) diff --git a/openvaf_jax/codegen/function_builder.py b/openvaf_jax/codegen/function_builder.py index 5f8b8a07..0c04bc9d 100644 --- a/openvaf_jax/codegen/function_builder.py +++ b/openvaf_jax/codegen/function_builder.py @@ -818,8 +818,6 @@ def build_with_cache_split( simparam_params: Optional[Dict[int, str]] = None, use_limit_functions: bool = False, limit_param_map: Optional[Dict[int, Tuple[str, str]]] = None, - concrete_shared_values: Optional[List[float]] = None, - concrete_shared_cache: Optional[List[float]] = None, ) -> Tuple[str, List[str]]: """Build eval function with split params and optional split cache. @@ -840,13 +838,6 @@ def build_with_cache_split( for limit-related params (prev_state, enable_lim, new_state, enable_integration). These are read from limit_state_in or set to constants instead of shared/device params. - concrete_shared_values: If provided, concrete float values for each - shared param index. When set, shared params are emitted - as Python literals instead of shared_params[N] lookups. - This enables JAX trace-time constant folding, eliminating - jnp.where branches that depend on static parameters. - concrete_shared_cache: If provided, concrete float values for each - shared cache index. Same inlining behavior as above. Returns: Tuple of (function_name, code_lines) @@ -906,11 +897,11 @@ def build_with_cache_split( body.append(assign("v3", ctx.zero())) ctx.defined_vars.add("v3") - # Map params from split arrays (inline concrete values when provided) - self._emit_param_mapping(body, ctx, idx_mapping, concrete_shared_values) + # Map params from split arrays + self._emit_param_mapping(body, ctx, idx_mapping) - # Map cache values (inline concrete values when provided) - self._emit_cache_mapping(body, ctx, cache_idx_mapping, concrete_shared_cache) + # Map cache values + self._emit_cache_mapping(body, ctx, cache_idx_mapping) # Pre-initialize all output variables to 0.0 to avoid NameError # for variables only assigned in conditional branches (NMOS/PMOS paths) @@ -993,7 +984,6 @@ def _emit_param_mapping( body: List[ast.stmt], ctx: CodeGenContext, idx_mapping: Dict[int, Tuple[str, any]], - concrete_shared_values: Optional[List[float]] = None, ): """Emit parameter mapping from split arrays. @@ -1004,9 +994,6 @@ def _emit_param_mapping( - source='shared': value is new index in shared_params - source='device': value is new index in device_params - source='simparam': value is simparam name (e.g., '$abstime') - concrete_shared_values: If provided, concrete float values for each - shared param. When set, shared params are emitted as Python literals - (e.g., `v123 = 1.5e-6`) instead of `v123 = shared_params[42]`. """ for i, param in enumerate(self.mir_func.params): var_name = f"{ctx.var_prefix}{param}" @@ -1014,16 +1001,12 @@ def _emit_param_mapping( if i in idx_mapping: source, value = idx_mapping[i] if source == "shared": - if concrete_shared_values is not None: - # Inline as Python literal for trace-time constant folding - body.append(assign(var_name, ast_const(concrete_shared_values[value]))) - else: - body.append( - assign( - var_name, - subscript(ast_name("shared_params"), ast_const(value)), - ) + body.append( + assign( + var_name, + subscript(ast_name("shared_params"), ast_const(value)), ) + ) elif source == "device": body.append( assign(var_name, subscript(ast_name("device_params"), ast_const(value))) @@ -1060,7 +1043,6 @@ def _emit_cache_mapping( body: List[ast.stmt], ctx: CodeGenContext, cache_idx_mapping: Dict[int, Tuple[str, int]], - concrete_shared_cache: Optional[List[float]] = None, ): """Emit cache value mapping from split cache arrays. @@ -1070,9 +1052,6 @@ def _emit_cache_mapping( body: List to append statements to ctx: Code generation context cache_idx_mapping: Maps cache index to (source, new_index) - concrete_shared_cache: If provided, concrete float values for each - shared cache entry. When set, shared cache values are emitted as - Python literals instead of shared_cache[N] lookups. """ for cache_idx, mapping in enumerate(self.cache_mapping): eval_param_idx = mapping["eval_param"] @@ -1082,18 +1061,12 @@ def _emit_cache_mapping( if cache_idx in cache_idx_mapping: source, new_idx = cache_idx_mapping[cache_idx] if source == "shared_cache": - if concrete_shared_cache is not None: - # Inline as Python literal for trace-time constant folding - body.append( - assign(var_name, ast_const(concrete_shared_cache[new_idx])) - ) - else: - body.append( - assign( - var_name, - subscript(ast_name("shared_cache"), ast_const(new_idx)), - ) + body.append( + assign( + var_name, + subscript(ast_name("shared_cache"), ast_const(new_idx)), ) + ) else: body.append( assign(var_name, subscript(ast_name("device_cache"), ast_const(new_idx))) diff --git a/vajax/analysis/openvaf_models.py b/vajax/analysis/openvaf_models.py index 23dbe2a7..49f5b4c9 100644 --- a/vajax/analysis/openvaf_models.py +++ b/vajax/analysis/openvaf_models.py @@ -915,22 +915,27 @@ def prepare_static_inputs( shared_cache_indices = [] varying_cache_indices = [] - # Split cache arrays (needed before eval codegen for inlining) + # Split cache arrays shared_cache = cache[0, shared_cache_indices] device_cache = cache[:, varying_cache_indices] - # Prepare concrete values for branch specialization: - # Inline shared params and shared cache as Python literals in the - # generated eval function. JAX tracing evaluates constant expressions - # at trace time, so downstream jnp.where(const_bool, a, b) only - # traces the taken branch — eliminating all static-param branches. - concrete_shared_values = shared_params_list # already List[float] - concrete_shared_cache_values = [float(v) for v in np.asarray(shared_cache)] - logger.info( - f"{model_type}: branch specialization: inlining " - f"{len(concrete_shared_values)} shared params + " - f"{len(concrete_shared_cache_values)} shared cache values as literals" + # Build SCCP known values for dead-block elimination. + # Shared params and cache are constant per simulation, so SCCP can + # resolve static branches and eliminate dead MIR blocks at codegen time. + # NOTE: values are used for SCCP analysis only — NOT inlined as literals + # in generated code (literal inlining causes 7.8x GPU regression). + shared_cache_values = [float(v) for v in np.asarray(shared_cache)] + sccp_known_values = translator.build_sccp_known_values( + shared_indices, + shared_params_list, + shared_cache_indices, + shared_cache_values, ) + if sccp_known_values: + logger.info( + f"{model_type}: SCCP dead-block elimination with " + f"{len(sccp_known_values)} known values" + ) # Generate eval function with cache split from vajax.analysis.limiting import fetlim, pnjlim @@ -947,8 +952,7 @@ def prepare_static_inputs( varying_cache_indices, use_limit_functions=use_device_limiting, limit_param_map=limit_param_map, - concrete_shared_values=concrete_shared_values, - concrete_shared_cache=concrete_shared_cache_values, + sccp_known_values=sccp_known_values, ) # Safety check: if limiting is enabled but lim_rhs could not be computed # (model uses inline limiting without $limit/BuiltinLimit calls), disable @@ -969,8 +973,7 @@ def prepare_static_inputs( varying_cache_indices, use_limit_functions=False, limit_param_map=limit_param_map, - concrete_shared_values=concrete_shared_values, - concrete_shared_cache=concrete_shared_cache_values, + sccp_known_values=sccp_known_values, ) split_fn = partial(split_fn, limit_funcs=limit_funcs) From b6950683a918206072e2c2233e1aeb43c6595605 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 14:33:51 +0000 Subject: [PATCH 14/79] fix: Add None check for _transient_setup_cache in dump_jaxpr Pyright flagged subscript access on Optional type. Add explicit None guard with descriptive error message. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- vajax/analysis/engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vajax/analysis/engine.py b/vajax/analysis/engine.py index 4ec9ef29..c2d18a68 100644 --- a/vajax/analysis/engine.py +++ b/vajax/analysis/engine.py @@ -1062,6 +1062,8 @@ def dump_jaxpr(self, output_dir: str | Path = "/tmp/claude/jaxpr-analysis") -> P # Get dimensions from transient setup cache setup_cache = self._transient_setup_cache + if setup_cache is None: + raise RuntimeError("No transient setup cache. Call prepare() first.") n_total = setup_cache["n_total"] n_unknowns = setup_cache["n_unknowns"] n_vsources = len([d for d in self.devices if d["model"] == "vsource"]) From 131e160b5d2469ea9be4557cbf95342213156e3f Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 14:50:10 +0000 Subject: [PATCH 15/79] perf: Add XLA flag sweep script and CI workflow for CUDA optimization Script tests 8 XLA flag configurations (autotune levels, command buffers, while-loop double buffering, PGLE) on CUDA benchmarks. Each config runs in a subprocess for clean XLA state. Workflow is dispatch-only (manual trigger) so it doesn't affect regular CI. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/xla-flag-sweep.yml | 120 ++++++++ scripts/sweep_xla_flags.py | 402 +++++++++++++++++++++++++++ 2 files changed, 522 insertions(+) create mode 100644 .github/workflows/xla-flag-sweep.yml create mode 100644 scripts/sweep_xla_flags.py diff --git a/.github/workflows/xla-flag-sweep.yml b/.github/workflows/xla-flag-sweep.yml new file mode 100644 index 00000000..7e506468 --- /dev/null +++ b/.github/workflows/xla-flag-sweep.yml @@ -0,0 +1,120 @@ +name: XLA Flag Sweep + +on: + workflow_dispatch: + inputs: + benchmarks: + description: 'Comma-separated benchmark names' + required: false + default: 'ring' + configs: + description: 'Comma-separated config names (empty = all)' + required: false + default: '' + n_runs: + description: 'Number of timed runs per config' + required: false + default: '3' + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: 0 + +jobs: + sweep: + name: XLA flag sweep (CUDA) + runs-on: nvidia-runner-1 + timeout-minutes: 120 + concurrency: + group: xla-sweep-${{ github.ref }} + + steps: + - name: Checkout with submodules + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: CUDA diagnostics + run: | + echo "=== GPU Info ===" + nvidia-smi 2>/dev/null || echo "nvidia-smi not available" + echo "=== CUDA Version ===" + nvcc --version 2>/dev/null || echo "nvcc not available" + + - name: Set up Python + uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Install dependencies + run: uv sync --frozen --all-extras + + - name: Verify GPU access + run: | + uv run python -c " + import jax + devices = jax.devices() + print(f'JAX devices: {devices}') + print(f'GPU available: {any(d.platform == \"gpu\" for d in devices)}') + for d in devices: + print(f' {d.platform}: {d}') + " + env: + JAX_PLATFORMS: cuda,cpu + + - name: Run XLA flag sweep + env: + JAX_PLATFORMS: cuda,cpu + JAX_ENABLE_X64: "1" + XLA_PYTHON_CLIENT_PREALLOCATE: "false" + XLA_PYTHON_CLIENT_ALLOCATOR: "platform" + run: | + CONFIGS_FLAG="" + if [ -n "${{ inputs.configs }}" ]; then + CONFIGS_FLAG="--configs ${{ inputs.configs }}" + fi + + uv run python scripts/sweep_xla_flags.py \ + --benchmark ${{ inputs.benchmarks }} \ + --n-runs ${{ inputs.n_runs }} \ + $CONFIGS_FLAG \ + --json-output /tmp/xla-sweep-results.json \ + 2>&1 | tee /tmp/xla-sweep.log + + - name: Generate summary + if: always() + run: | + { + echo "## XLA Flag Sweep Results" + echo "" + echo "**Benchmarks**: ${{ inputs.benchmarks }}" + echo "**Runs per config**: ${{ inputs.n_runs }}" + echo "" + + if [ -f /tmp/xla-sweep.log ]; then + echo '```' + # Extract the summary table + sed -n '/^={80}/,$ p' /tmp/xla-sweep.log | head -40 + echo '```' + fi + + if [ -f /tmp/xla-sweep-results.json ]; then + echo "" + echo "
Raw JSON results" + echo "" + echo '```json' + cat /tmp/xla-sweep-results.json + echo '```' + echo "
" + fi + } >> "$GITHUB_STEP_SUMMARY" + + - name: Upload results + if: always() + uses: actions/upload-artifact@v4 + with: + name: xla-sweep-results + path: | + /tmp/xla-sweep-results.json + /tmp/xla-sweep.log diff --git a/scripts/sweep_xla_flags.py b/scripts/sweep_xla_flags.py new file mode 100644 index 00000000..0175c12e --- /dev/null +++ b/scripts/sweep_xla_flags.py @@ -0,0 +1,402 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.10" +# dependencies = ["jax"] +# /// +"""Sweep XLA flag combinations to find optimal CUDA performance. + +Runs a benchmark circuit with different XLA flag configurations and +reports timing for each. Each configuration runs in a separate subprocess +to ensure clean XLA state. + +Usage: + # Run on GPU (auto-detects CUDA) + uv run scripts/sweep_xla_flags.py + + # Specific benchmark + uv run scripts/sweep_xla_flags.py --benchmark ring + + # Specific configurations only + uv run scripts/sweep_xla_flags.py --configs baseline,autotune2,command_buffer + + # Also include large circuit (needs sparse solver) + uv run scripts/sweep_xla_flags.py --benchmark ring,c6288 --include-sparse +""" + +import argparse +import json +import os +import subprocess +import sys +import time +from pathlib import Path + +# XLA flag configurations to test +FLAG_CONFIGS = { + "baseline": { + "description": "Current CI config (autotune=0)", + "xla_flags": "--xla_gpu_autotune_level=0", + "env": {}, + }, + "autotune2": { + "description": "Autotune level 2 (enables cuBLAS algorithm selection)", + "xla_flags": "--xla_gpu_autotune_level=2", + "env": {}, + }, + "autotune4": { + "description": "Autotune level 4 (full autotuning)", + "xla_flags": "--xla_gpu_autotune_level=4", + "env": {}, + }, + "command_buffer": { + "description": "Command buffers enabled (batch kernel launches)", + "xla_flags": "--xla_gpu_autotune_level=0", + "env": {}, + }, + "double_buffer": { + "description": "While-loop double buffering", + "xla_flags": ( + "--xla_gpu_autotune_level=0 --xla_gpu_enable_while_loop_double_buffering=true" + ), + "env": {}, + }, + "pgle": { + "description": "Profile-guided latency estimation (3 profiling runs)", + "xla_flags": "--xla_gpu_autotune_level=0", + "env": { + "JAX_ENABLE_PGLE": "true", + "JAX_PGLE_PROFILING_RUNS": "3", + }, + }, + "combined_safe": { + "description": "Autotune 2 + double buffering", + "xla_flags": ( + "--xla_gpu_autotune_level=2 --xla_gpu_enable_while_loop_double_buffering=true" + ), + "env": {}, + }, + "combined_aggressive": { + "description": "Autotune 4 + double buffering + PGLE", + "xla_flags": ( + "--xla_gpu_autotune_level=4 --xla_gpu_enable_while_loop_double_buffering=true" + ), + "env": { + "JAX_ENABLE_PGLE": "true", + "JAX_PGLE_PROFILING_RUNS": "3", + }, + }, +} + +# The subprocess script that runs a single benchmark +BENCHMARK_RUNNER = """ +import os +import sys +import time +import json + +sys.path.insert(0, os.environ["PROJECT_ROOT"]) + +# Memory config +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform") + +import jax +import jax.numpy as jnp + +from vajax.analysis import CircuitEngine + +benchmark_name = os.environ["BENCHMARK_NAME"] +use_sparse = os.environ.get("USE_SPARSE", "0") == "1" +use_scan = True +force_gpu = os.environ.get("FORCE_GPU", "0") == "1" +n_warmup = int(os.environ.get("N_WARMUP", "1")) +n_runs = int(os.environ.get("N_RUNS", "3")) + +from scripts.benchmark_utils import get_vacask_benchmarks + +benchmarks = get_vacask_benchmarks([benchmark_name]) +if not benchmarks: + print(json.dumps({"error": f"Benchmark {benchmark_name} not found"})) + sys.exit(1) + +name, sim_path = benchmarks[0] + +# Report JAX config +devices = jax.devices() +backend = devices[0].platform if devices else "unknown" +print(f"JAX backend: {backend}, devices: {[d.platform for d in devices]}", file=sys.stderr) +print(f"XLA_FLAGS: {os.environ.get('XLA_FLAGS', '(not set)')}", file=sys.stderr) + +engine = CircuitEngine.from_sim_file(str(sim_path)) +engine.prepare(use_sparse=use_sparse, force_gpu=force_gpu, use_scan=use_scan) + +# Get step count from sim parameters +dt = engine.sim_params.get("dt", 1e-6) +t_stop = engine.sim_params.get("tstop", engine.sim_params.get("t_stop", 1e-3)) +n_steps = int(t_stop / dt) if dt > 0 else 100 + +timings = [] + +for run_idx in range(n_warmup + n_runs): + # Re-prepare to reset state + if run_idx > 0: + engine.prepare(use_sparse=use_sparse, force_gpu=force_gpu, use_scan=use_scan) + + start = time.perf_counter() + result = engine.run_transient() + # Block until computation complete + if hasattr(result, 'voltages') and result.voltages is not None: + jax.block_until_ready(result.voltages) + elapsed = time.perf_counter() - start + + actual_steps = result.n_steps if hasattr(result, 'n_steps') else n_steps + ms_per_step = (elapsed * 1000.0) / max(actual_steps, 1) + + label = "warmup" if run_idx < n_warmup else f"run {run_idx - n_warmup}" + print(f" {label}: {elapsed:.3f}s ({actual_steps} steps, {ms_per_step:.3f} ms/step)", file=sys.stderr) + + if run_idx >= n_warmup: + timings.append({ + "elapsed_s": elapsed, + "n_steps": actual_steps, + "ms_per_step": ms_per_step, + }) + +# Report median timing +timings.sort(key=lambda t: t["ms_per_step"]) +median = timings[len(timings) // 2] + +print(json.dumps({ + "benchmark": benchmark_name, + "backend": backend, + "n_steps": median["n_steps"], + "ms_per_step": median["ms_per_step"], + "elapsed_s": median["elapsed_s"], + "n_runs": n_runs, + "all_timings": [t["ms_per_step"] for t in timings], +})) +""" + + +def run_config( + config_name: str, + config: dict, + benchmark: str, + project_root: Path, + use_sparse: bool, + force_gpu: bool, + n_warmup: int = 1, + n_runs: int = 3, +) -> dict: + """Run a single benchmark with a specific XLA flag configuration.""" + env = os.environ.copy() + env["PROJECT_ROOT"] = str(project_root) + env["BENCHMARK_NAME"] = benchmark + env["USE_SPARSE"] = "1" if use_sparse else "0" + env["FORCE_GPU"] = "1" if force_gpu else "0" + env["N_WARMUP"] = str(n_warmup) + env["N_RUNS"] = str(n_runs) + env["JAX_PLATFORMS"] = "cuda,cpu" if force_gpu else "cpu" + env["JAX_ENABLE_X64"] = "1" + + # Set XLA flags + env["XLA_FLAGS"] = config["xla_flags"] + + # Set additional env vars + for k, v in config.get("env", {}).items(): + env[k] = v + + # Memory allocation + env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") + env.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform") + + print(f"\n{'=' * 60}") + print(f"Config: {config_name} — {config['description']}") + print(f" XLA_FLAGS: {config['xla_flags']}") + if config.get("env"): + print(f" Extra env: {config['env']}") + print(f" Benchmark: {benchmark}") + print(f"{'=' * 60}") + + start = time.perf_counter() + try: + result = subprocess.run( + [sys.executable, "-c", BENCHMARK_RUNNER], + env=env, + capture_output=True, + text=True, + timeout=600, # 10 min max per config + cwd=str(project_root), + ) + except subprocess.TimeoutExpired: + return { + "config": config_name, + "benchmark": benchmark, + "error": "timeout (600s)", + "wall_time_s": time.perf_counter() - start, + } + + wall_time = time.perf_counter() - start + + # Print stderr (progress messages) + if result.stderr: + for line in result.stderr.strip().split("\n"): + print(f" {line}") + + # Parse JSON from last line of stdout + if result.returncode != 0: + print(f" ERROR: exit code {result.returncode}") + if result.stderr: + print(f" {result.stderr[-500:]}") + return { + "config": config_name, + "benchmark": benchmark, + "error": f"exit code {result.returncode}", + "wall_time_s": wall_time, + } + + try: + # Find the JSON line (last non-empty line of stdout) + lines = [l for l in result.stdout.strip().split("\n") if l.strip()] + data = json.loads(lines[-1]) + data["config"] = config_name + data["wall_time_s"] = wall_time + data["description"] = config["description"] + print(f" Result: {data['ms_per_step']:.3f} ms/step ({data['n_steps']} steps)") + return data + except (json.JSONDecodeError, IndexError) as e: + print(f" ERROR parsing output: {e}") + print(f" stdout: {result.stdout[-500:]}") + return { + "config": config_name, + "benchmark": benchmark, + "error": f"parse error: {e}", + "wall_time_s": wall_time, + } + + +def main(): + parser = argparse.ArgumentParser(description="Sweep XLA flag combinations") + parser.add_argument( + "--benchmark", + default="ring", + help="Comma-separated benchmark names (default: ring)", + ) + parser.add_argument( + "--configs", + default=None, + help="Comma-separated config names (default: all)", + ) + parser.add_argument( + "--include-sparse", + action="store_true", + help="Include sparse solver for large circuits", + ) + parser.add_argument( + "--force-gpu", + action="store_true", + default=True, + help="Force GPU backend (default: True)", + ) + parser.add_argument( + "--cpu-only", + action="store_true", + help="Run on CPU instead of GPU", + ) + parser.add_argument( + "--n-warmup", + type=int, + default=1, + help="Number of warmup runs (default: 1)", + ) + parser.add_argument( + "--n-runs", + type=int, + default=3, + help="Number of timed runs (default: 3)", + ) + parser.add_argument( + "--json-output", + default=None, + help="Path to write JSON results", + ) + args = parser.parse_args() + + project_root = Path(__file__).parent.parent + benchmarks = [b.strip() for b in args.benchmark.split(",")] + force_gpu = not args.cpu_only + + if args.configs: + config_names = [c.strip() for c in args.configs.split(",")] + configs = {k: FLAG_CONFIGS[k] for k in config_names if k in FLAG_CONFIGS} + else: + configs = FLAG_CONFIGS + + all_results = [] + + for benchmark in benchmarks: + # Determine if sparse needed + use_sparse = args.include_sparse and benchmark in ("c6288", "mul64") + + for config_name, config in configs.items(): + result = run_config( + config_name, + config, + benchmark, + project_root, + use_sparse=use_sparse, + force_gpu=force_gpu, + n_warmup=args.n_warmup, + n_runs=args.n_runs, + ) + all_results.append(result) + + # Print summary table + print(f"\n{'=' * 80}") + print("SUMMARY") + print(f"{'=' * 80}") + print( + f"{'Config':<25} {'Benchmark':<10} {'ms/step':>10} {'Steps':>8} {'Wall(s)':>10} {'vs base':>10}" + ) + print("-" * 80) + + # Group by benchmark for relative comparison + by_benchmark = {} + for r in all_results: + bm = r.get("benchmark", "?") + by_benchmark.setdefault(bm, []).append(r) + + for bm, results in by_benchmark.items(): + baseline_ms = None + for r in results: + if r.get("config") == "baseline" and "ms_per_step" in r: + baseline_ms = r["ms_per_step"] + break + + for r in results: + ms = r.get("ms_per_step", None) + steps = r.get("n_steps", "?") + wall = r.get("wall_time_s", 0) + config = r.get("config", "?") + err = r.get("error", None) + + if err: + print(f"{config:<25} {bm:<10} {'ERROR':>10} {'':>8} {wall:>10.1f} {err}") + elif ms is not None: + ratio_str = "" + if baseline_ms and baseline_ms > 0: + ratio = ms / baseline_ms + ratio_str = f"{ratio:.2f}x" + print(f"{config:<25} {bm:<10} {ms:>10.3f} {steps:>8} {wall:>10.1f} {ratio_str:>10}") + + # Save JSON + if args.json_output: + out_path = Path(args.json_output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nResults saved to {out_path}") + + +if __name__ == "__main__": + main() From 7f463c1c3721ceec4991e2ebfcc4ff22fbbdef87 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 15:09:48 +0000 Subject: [PATCH 16/79] perf: Disable SCCP for unified eval function (cache invalidation) SCCP dead-block elimination changes the generated eval function hash, invalidating the persistent XLA compilation cache. For ring (PSP103), this causes 49.5s cold compilation (vs 0.58s cache hit on main), which inflates the benchmark wall_time by ~152s. The actual per-step execution with SCCP is ~2.3ms vs ~1.4ms without, but the benchmark reports 9.84ms because JIT compilation happens inside run_while() and is counted as execution time. SCCP provides marginal benefit for the unified eval function (both NMOS+PMOS branches still needed), but the infrastructure is preserved in build_sccp_known_values() for future config-group specialization where each group has a unique TYPE value. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- vajax/analysis/openvaf_models.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/vajax/analysis/openvaf_models.py b/vajax/analysis/openvaf_models.py index 49f5b4c9..27aa1183 100644 --- a/vajax/analysis/openvaf_models.py +++ b/vajax/analysis/openvaf_models.py @@ -919,23 +919,16 @@ def prepare_static_inputs( shared_cache = cache[0, shared_cache_indices] device_cache = cache[:, varying_cache_indices] - # Build SCCP known values for dead-block elimination. - # Shared params and cache are constant per simulation, so SCCP can - # resolve static branches and eliminate dead MIR blocks at codegen time. - # NOTE: values are used for SCCP analysis only — NOT inlined as literals - # in generated code (literal inlining causes 7.8x GPU regression). - shared_cache_values = [float(v) for v in np.asarray(shared_cache)] - sccp_known_values = translator.build_sccp_known_values( - shared_indices, - shared_params_list, - shared_cache_indices, - shared_cache_values, - ) - if sccp_known_values: - logger.info( - f"{model_type}: SCCP dead-block elimination with " - f"{len(sccp_known_values)} known values" - ) + # SCCP dead-block elimination is DISABLED for the unified eval function. + # While SCCP can eliminate ~695/954 blocks for PSP103, the benefit is + # marginal (same code size, same XLA ops — XLA already CSEs branches). + # The cost is high: SCCP changes the JIT function hash, invalidating + # the persistent XLA compilation cache and adding ~99s cold-compile + # penalty for ring (49.5s × 2 compilations vs 0.58s cache hit). + # SCCP would be valuable for config-group-specialized eval functions + # where each group has a unique TYPE value, but that's deferred. + # Infrastructure is preserved in build_sccp_known_values() for reuse. + sccp_known_values = None # Generate eval function with cache split from vajax.analysis.limiting import fetlim, pnjlim From 97fdfb9ddf5c4b8685cb3cc4762c812eeb4c4c7e Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 16:07:12 +0000 Subject: [PATCH 17/79] perf: Return numpy arrays from run_transient to avoid CUDA dynamic_slice recompile extract_results() slices output to n_steps and converts to numpy, then run_transient() was converting back to JAX via jnp.asarray(). This created dynamically-sized JAX arrays (shape = n_steps) that triggered jit(dynamic_slice) recompilation on CUDA whenever the step count changed between warmup and actual run. All consumers already convert to numpy, so the round-trip was unnecessary. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/compare_vacask.py | 7 +++---- vajax/analysis/engine.py | 9 ++++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/scripts/compare_vacask.py b/scripts/compare_vacask.py index 070c9342..a6941eb5 100644 --- a/scripts/compare_vacask.py +++ b/scripts/compare_vacask.py @@ -35,6 +35,7 @@ from vajax._logging import enable_performance_logging enable_performance_logging(with_memory=True, with_perf_counter=True) + import re import sys import time @@ -367,13 +368,11 @@ def do_run(): print( f"AFTER_RUN_TRANSIENT: {after_transient:.6f} (elapsed: {after_transient - start:.6f}s)" ) - # Force completion of async JAX operations - first_node = next(iter(result.voltages)) - _ = float(result.voltages[first_node][0]) + # Results are numpy arrays (materialized by block_until_ready in full_mna) end = time.perf_counter() external_elapsed = end - start print( - f"TIMED_RUN_END: {end:.6f} (elapsed: {external_elapsed:.6f}s, sync took: {end - after_transient:.6f}s)" + f"TIMED_RUN_END: {end:.6f} (elapsed: {external_elapsed:.6f}s, extract took: {end - after_transient:.6f}s)" ) # Use wall_time from stats (excludes trace saving overhead) diff --git a/vajax/analysis/engine.py b/vajax/analysis/engine.py index c2d18a68..d2bfe180 100644 --- a/vajax/analysis/engine.py +++ b/vajax/analysis/engine.py @@ -1015,10 +1015,13 @@ def run_transient(self) -> TransientResult: # Extract sliced numpy results for TransientResult times_np, voltages, currents = extract_results(times_full, V_out, stats) + # Keep as numpy arrays — avoids creating dynamically-sized JAX arrays + # that trigger jit(dynamic_slice) recompilation on CUDA when n_steps + # varies between runs (the shape gets baked into the XLA kernel). return TransientResult( - times=jnp.asarray(times_np), - voltages={k: jnp.asarray(v) for k, v in voltages.items()}, - currents={k: jnp.asarray(v) for k, v in currents.items()}, + times=times_np, + voltages=voltages, + currents=currents, stats=stats, ) From f7adaf934aad29ddaab512cd496aac5df03d8534 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 16:35:00 +0000 Subject: [PATCH 18/79] ci: Install nsight-systems-cli for nsys profiling workflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit nsys was not installed — only cuda-nvcc and cuda-cudart-dev were. Add nsight-systems-cli package and discover its PATH dynamically. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 28efc4af..3b9ff2a7 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -78,11 +78,17 @@ jobs: with: workspaces: openvaf_jax/openvaf_py - - name: Install CUDA toolkit + - name: Install CUDA toolkit and nsys run: | sudo apt-get update - sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 + sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 nsight-systems-cli echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH + # nsys installs to /opt/nvidia/nsight-systems/*/target-linux-x64/ + NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) + if [ -n "$NSYS_BIN" ]; then + echo "$NSYS_BIN" >> $GITHUB_PATH + echo "Found nsys at: $NSYS_BIN" + fi - name: Install system dependencies run: | From 5fbff74ce72b0be40cb67e89d905df96b0ba7cf2 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 16:45:09 +0000 Subject: [PATCH 19/79] ci: Fix nsys package name to cuda-nsight-systems-12-6 The NVIDIA CUDA apt repo uses versioned package names like cuda-nsight-systems-12-6, not nsight-systems-cli. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 3b9ff2a7..28f0117f 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -81,14 +81,14 @@ jobs: - name: Install CUDA toolkit and nsys run: | sudo apt-get update - sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 nsight-systems-cli + sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-nsight-systems-12-6 echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH # nsys installs to /opt/nvidia/nsight-systems/*/target-linux-x64/ NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) if [ -n "$NSYS_BIN" ]; then echo "$NSYS_BIN" >> $GITHUB_PATH - echo "Found nsys at: $NSYS_BIN" fi + nsys --version - name: Install system dependencies run: | From 9f9e8ec63d89b01b4023e8b097cfe5fbf273375b Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 16:52:24 +0000 Subject: [PATCH 20/79] ci: Export nsys SQLite + CSV stats and improve summary - Export .sqlite for offline analysis without nsys installed - Generate CSV stats for kernel, API, memory transfer, and trace reports - Expand job summary with CUDA API and memory transfer sections - Upload all artifacts (nsys-rep, sqlite, CSV stats) Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 47 ++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 28f0117f..78385c5c 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -141,9 +141,32 @@ jobs: echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV" + - name: Export stats and SQLite + if: always() + run: | + REPORT="/tmp/${profile_name}.nsys-rep" + STATS_DIR="/tmp/nsys-stats" + mkdir -p "$STATS_DIR" + + if [ -f "$REPORT" ]; then + # Export to SQLite for offline analysis + nsys export --type=sqlite --output="/tmp/${profile_name}.sqlite" "$REPORT" || true + + # Generate key stats reports + for report in cuda_gpu_kern_sum cuda_api_sum cuda_gpu_mem_time_sum cuda_gpu_mem_size_sum; do + nsys stats "$REPORT" --report "$report" --format csv \ + --output "${STATS_DIR}/${report}" 2>/dev/null || true + done + + # Full kernel trace (top 100 by duration) + nsys stats "$REPORT" --report cuda_gpu_trace --format csv \ + --output "${STATS_DIR}/cuda_gpu_trace" 2>/dev/null || true + fi + - name: Generate summary if: always() run: | + REPORT="/tmp/${profile_name}.nsys-rep" { echo "## nsys GPU Profiling Results" echo "" @@ -152,20 +175,34 @@ jobs: echo "**Sparse:** ${{ inputs.sparse }}" echo "**Commit:** \`${{ github.sha }}\`" echo "" - if [ -f "/tmp/${profile_name}.nsys-rep" ]; then - echo "### Profile Summary" + if [ -f "$REPORT" ]; then + echo "### GPU Kernel Summary" + echo '```' + nsys stats "$REPORT" --report cuda_gpu_kern_sum 2>&1 | head -50 || true + echo '```' + echo "" + echo "### CUDA API Summary" + echo '```' + nsys stats "$REPORT" --report cuda_api_sum 2>&1 | head -30 || true + echo '```' echo "" + echo "### GPU Memory Transfers" echo '```' - nsys stats "/tmp/${profile_name}.nsys-rep" --report cuda_gpu_kern_sum 2>&1 | head -40 || true + nsys stats "$REPORT" --report cuda_gpu_mem_time_sum 2>&1 | head -20 || true echo '```' + else + echo "**No profile report found.**" fi } >> "$GITHUB_STEP_SUMMARY" - - name: Upload nsys report + - name: Upload nsys report and stats if: always() uses: actions/upload-artifact@v4 with: name: nsys-profile-${{ inputs.circuit }}-${{ github.sha }} - path: /tmp/nsys-*.nsys-rep + path: | + /tmp/nsys-*.nsys-rep + /tmp/nsys-*.sqlite + /tmp/nsys-stats/ if-no-files-found: ignore retention-days: 30 From 6cc1186bd4128158edb431ec863104f079418c1e Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 16:59:21 +0000 Subject: [PATCH 21/79] ci: Add libcudnn9-cuda-12 to nsys profiling workflow JAX CUDA plugin requires cuDNN to initialize the GPU backend. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 78385c5c..a193b38a 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -81,7 +81,7 @@ jobs: - name: Install CUDA toolkit and nsys run: | sudo apt-get update - sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-nsight-systems-12-6 + sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-nsight-systems-12-6 libcudnn9-cuda-12 echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH # nsys installs to /opt/nvidia/nsight-systems/*/target-linux-x64/ NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) From 927e4dc010c8657bf8a6b4f8fe806a78b95cfde2 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 17:04:16 +0000 Subject: [PATCH 22/79] ci: Install full cuda-toolkit-12-6 for nsys profiling Individual CUDA packages miss runtime libs like cuFFT. Use cuda-toolkit-12-6 (matching benchmark workflow) to get everything. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index a193b38a..fbe95751 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -81,7 +81,7 @@ jobs: - name: Install CUDA toolkit and nsys run: | sudo apt-get update - sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-nsight-systems-12-6 libcudnn9-cuda-12 + sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 cuda-nsight-systems-12-6 echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH # nsys installs to /opt/nvidia/nsight-systems/*/target-linux-x64/ NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) From 43fc0c81f57573079933260cf966e14379f1d7d2 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 17:05:09 +0000 Subject: [PATCH 23/79] ci: Add libcudss0-cuda-12 to nsys profiling workflow Match benchmark-comparison CUDA packages: cuda-toolkit, cudnn, cudss. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index fbe95751..593c1970 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -81,8 +81,13 @@ jobs: - name: Install CUDA toolkit and nsys run: | sudo apt-get update - sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 cuda-nsight-systems-12-6 + sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH + CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1) + if [ -n "$CUDSS_LIB" ]; then + CUDSS_DIR=$(dirname "$CUDSS_LIB") + echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + fi # nsys installs to /opt/nvidia/nsight-systems/*/target-linux-x64/ NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) if [ -n "$NSYS_BIN" ]; then From a0b91615c2a568f4d05f77858f5482b0a438c47a Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 17:10:55 +0000 Subject: [PATCH 24/79] ci: Add cancel-in-progress concurrency to all PR workflows New pushes to a PR branch now cancel any in-flight runs of the same workflow, avoiding wasted runner time on stale commits. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 6 ++++-- .github/workflows/lint.yml | 4 ++++ .github/workflows/test-pdk.yml | 4 ++++ .github/workflows/test.yml | 4 ++++ 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 517a072c..b319cdae 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -9,6 +9,10 @@ on: schedule: - cron: '0 6 * * *' # Daily at 6am UTC +concurrency: + group: benchmark-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 @@ -34,8 +38,6 @@ jobs: name: benchmark (${{ matrix.solver }}-${{ matrix.platform }}) runs-on: ${{ matrix.runner }} timeout-minutes: ${{ matrix.solver == 'sparse' && matrix.platform == 'cuda' && 360 || 90 }} - concurrency: - group: benchmark-${{ matrix.solver }}-${{ matrix.platform }}-${{ github.ref }} steps: - name: Checkout with submodules diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b6e2f231..54cb8165 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: lint-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: lint: runs-on: ubuntu-latest diff --git a/.github/workflows/test-pdk.yml b/.github/workflows/test-pdk.yml index 13a7f4e0..451d3615 100644 --- a/.github/workflows/test-pdk.yml +++ b/.github/workflows/test-pdk.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: pdk-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fa4d8db0..1e13b602 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: test-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 From 7356c78fa7049d45d9443477c651a352c4609663 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 17:29:02 +0000 Subject: [PATCH 25/79] perf: Exclude JIT warmup from nsys profiling via CUDA profiler API Use cudaProfilerStart()/cudaProfilerStop() to capture only the simulation run, excluding prepare() warmup (JIT compilation, module loading, memory allocation). Increase default timesteps from 50 to 500 for representative steady-state profiling. The previous 50-step profile was dominated by one-time startup costs (cuMemHostAlloc 89ms, cuModuleLoadFatBinary 9ms, etc.) that made the timing unrepresentative of actual benchmark behavior. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 8 +++- scripts/nsys_profile_target.py | 61 ++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 593c1970..28439c16 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -14,8 +14,8 @@ on: - c6288 default: ring timesteps: - description: 'Number of timesteps' - default: '50' + description: 'Number of timesteps (500+ recommended for steady-state profiling)' + default: '500' sparse: description: 'Use sparse solver (for large circuits)' type: boolean @@ -138,8 +138,12 @@ jobs: echo "Sparse: ${{ inputs.sparse }}" echo "Commit: ${{ github.sha }}" + # Use cudaProfilerApi capture range to profile ONLY the simulation run + # (excludes JIT warmup, module loading, memory allocation setup) nsys profile \ --trace=cuda,nvtx,osrt \ + --capture-range=cudaProfilerApi \ + --capture-range-end=stop \ --output "/tmp/${PROFILE_NAME}" \ uv run python scripts/nsys_profile_target.py \ ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py index db64ac4d..cfb30e55 100644 --- a/scripts/nsys_profile_target.py +++ b/scripts/nsys_profile_target.py @@ -1,26 +1,20 @@ #!/usr/bin/env python3 -"""Target script for nsys-jax profiling - runs circuit simulation. +"""Target script for nsys GPU profiling - runs circuit simulation. -This script is designed to be wrapped by nsys-jax: - nsys-jax -o profile.zip python scripts/nsys_profile_target.py [circuit] [timesteps] - -nsys-jax automatically handles: -- XLA_FLAGS configuration for HLO metadata dumping -- JAX_TRACEBACK_IN_LOCATIONS_LIMIT for stack traces -- JAX_ENABLE_COMPILATION_CACHE=false for metadata collection +Uses CUDA profiler API to capture ONLY the simulation run (not warmup/JIT). +Run with nsys --capture-range=cudaProfilerApi to enable selective capture. Usage: - python scripts/nsys_profile_target.py [circuit] [timesteps] + nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop \\ + -o profile uv run python scripts/nsys_profile_target.py ring 500 Arguments: - circuit: One of rc, graetz, mul, ring (default: ring) - timesteps: Number of timesteps to simulate (default: 50) - -Example: - nsys-jax -o /tmp/profile.zip python scripts/nsys_profile_target.py ring 100 + circuit: One of rc, graetz, mul, ring, c6288 (default: ring) + timesteps: Number of timesteps to simulate (default: 500) """ import argparse +import ctypes import sys from pathlib import Path @@ -32,8 +26,27 @@ from vajax.analysis import CircuitEngine +def _cuda_profiler_start(): + """Start CUDA profiler capture via cudaProfilerStart().""" + try: + libcudart = ctypes.CDLL("libcudart.so") + libcudart.cudaProfilerStart() + return True + except OSError: + return False + + +def _cuda_profiler_stop(): + """Stop CUDA profiler capture via cudaProfilerStop().""" + try: + libcudart = ctypes.CDLL("libcudart.so") + libcudart.cudaProfilerStop() + except OSError: + pass + + def main(): - parser = argparse.ArgumentParser(description="nsys-jax profiling target for VAJAX") + parser = argparse.ArgumentParser(description="nsys profiling target for VAJAX") parser.add_argument( "circuit", nargs="?", @@ -45,8 +58,8 @@ def main(): "timesteps", nargs="?", type=int, - default=50, - help="Number of timesteps to simulate (default: 50)", + default=500, + help="Number of timesteps to simulate (default: 500)", ) parser.add_argument( "--backend", @@ -89,7 +102,7 @@ def main(): print(f"Using dt={dt}") print() - # Prepare (includes 1-step JIT warmup) + # Prepare (includes 1-step JIT warmup) — NOT profiled print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...") engine.prepare( t_stop=args.timesteps * dt, @@ -99,10 +112,20 @@ def main(): print("Prepare complete") print() - # Profiled run - nsys-jax captures this automatically + # Start CUDA profiler capture — only the simulation run is profiled + has_profiler = _cuda_profiler_start() + if has_profiler: + print("CUDA profiler capture started (warmup excluded)") + else: + print("WARNING: cudaProfilerStart() unavailable — profiling entire process") + print(f"Starting profiled run ({args.timesteps} timesteps)...") result = engine.run_transient() + # Stop CUDA profiler capture + if has_profiler: + _cuda_profiler_stop() + print() print(f"Completed: {result.num_steps} timesteps") print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s") From 982caeb78822cbd02af9c56ced74b4fad788e71e Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 17:44:05 +0000 Subject: [PATCH 26/79] feat: Add HLO analysis script for dense benchmark circuits Dumps jaxpr, HLO op counts, and XLA cost analysis for build_system and nr_solve across any benchmark circuit. Useful for profiling XLA compilation and identifying optimization opportunities. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/analyze_dense_jaxpr.py | 217 +++++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 scripts/analyze_dense_jaxpr.py diff --git a/scripts/analyze_dense_jaxpr.py b/scripts/analyze_dense_jaxpr.py new file mode 100644 index 00000000..2d5baf1d --- /dev/null +++ b/scripts/analyze_dense_jaxpr.py @@ -0,0 +1,217 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.10" +# dependencies = ["jax", "jaxlib"] +# /// +"""Analyze JAX IR for dense benchmark circuits. + +Dumps jaxpr, HLO op counts, and cost analysis for the key hot paths: +1. build_system (Jacobian + residual assembly) +2. nr_solve (Newton-Raphson with while_loop) +3. run_while (full transient step with adaptive timestep) +""" + +import os +import sys +from pathlib import Path + +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import jax +import jax.numpy as jnp + +from vajax.analysis.engine import CircuitEngine +from vajax.benchmarks.registry import get_benchmark + + +def count_hlo_ops(hlo_text: str) -> dict[str, int]: + """Count operation types in HLO text.""" + op_counts: dict[str, int] = {} + for line in hlo_text.split("\n"): + if "=" in line and "." in line: + parts = line.split("=") + if len(parts) >= 2: + op_part = parts[1].strip().split()[0] if parts[1].strip() else "" + if "." in op_part: + op_name = op_part.split("(")[0] + op_counts[op_name] = op_counts.get(op_name, 0) + 1 + return dict(sorted(op_counts.items(), key=lambda x: -x[1])) + + +def analyze_function(name: str, fn, args, output_dir: Path): + """Analyze a single function: jaxpr, HLO, cost.""" + print(f"\n{'=' * 70}") + print(f" {name}") + print(f"{'=' * 70}") + + try: + if hasattr(fn, "lower"): + lowered = fn.lower(*args) + else: + lowered = jax.jit(fn).lower(*args) + + hlo_text = lowered.as_text() + hlo_lines = hlo_text.split("\n") + print(f" HLO: {len(hlo_lines)} lines") + + ops = count_hlo_ops(hlo_text) + if ops: + top = list(ops.items())[:20] + print(" Top HLO ops:") + for op, count in top: + print(f" {op:40s} {count:6d}") + + compiled = lowered.compile() + cost = compiled.cost_analysis() + if cost: + for i, device_cost in enumerate(cost): + if device_cost and isinstance(device_cost, dict): + print(f" Cost (device {i}):") + for key, val in sorted(device_cost.items()): + if isinstance(val, (int, float)): + if val > 1e9: + print(f" {key}: {val / 1e9:.2f}G") + elif val > 1e6: + print(f" {key}: {val / 1e6:.2f}M") + elif val > 1e3: + print(f" {key}: {val / 1e3:.2f}K") + else: + print(f" {key}: {val:.2f}") + + # Save HLO + output_dir.mkdir(parents=True, exist_ok=True) + hlo_file = output_dir / f"{name}.hlo.txt" + with open(hlo_file, "w") as f: + f.write(hlo_text) + print(f" Saved: {hlo_file}") + + except Exception as e: + import traceback + + print(f" Failed: {e}") + traceback.print_exc() + + +def analyze_benchmark(benchmark_name: str, output_dir: Path): + """Analyze all hot paths for a single benchmark.""" + print(f"\n{'#' * 70}") + print(f" Benchmark: {benchmark_name}") + print(f"{'#' * 70}") + + config = get_benchmark(benchmark_name) + engine = CircuitEngine(config.sim_path) + engine.parse() + + # Use short simulation for analysis (just need compilation, not full run) + num_steps = 100 + engine.prepare( + t_stop=config.dt * num_steps, + dt=config.dt, + use_sparse=False, + ) + + # Get strategy internals + strategy = engine._strategy + setup_cache = engine._transient_setup_cache + + n_total = setup_cache["n_total"] + n_unknowns = setup_cache["n_unknowns"] + n_vsources = len([d for d in engine.devices if d["model"] == "vsource"]) + n_isources = len([d for d in engine.devices if d["model"] == "isource"]) + n_augmented = n_unknowns + n_vsources + + print(f" Nodes: {n_total}, Unknowns: {n_unknowns}, Vsources: {n_vsources}") + print(f" Augmented system: {n_augmented}x{n_augmented}") + + bench_dir = output_dir / benchmark_name + + # 1. Analyze build_system (Jacobian + residual assembly) + build_fn = setup_cache.get("build_system_fn") + device_arrays = engine._device_arrays + if build_fn is not None and device_arrays is not None: + X = jnp.zeros(n_total + n_vsources, dtype=jnp.float64) + vsource_vals = jnp.zeros(n_vsources, dtype=jnp.float64) + isource_vals = jnp.zeros(max(n_isources, 0), dtype=jnp.float64) + Q_prev = jnp.zeros(n_unknowns, dtype=jnp.float64) + integ_c0 = jnp.asarray(1e9, dtype=jnp.float64) # typical 1/dt + gmin = jnp.asarray(1e-12, dtype=jnp.float64) + gshunt = jnp.asarray(0.0, dtype=jnp.float64) + integ_c1 = jnp.asarray(0.0, dtype=jnp.float64) + integ_d1 = jnp.asarray(0.0, dtype=jnp.float64) + dQdt_prev = jnp.zeros(n_unknowns, dtype=jnp.float64) + integ_c2 = jnp.asarray(0.0, dtype=jnp.float64) + Q_prev2 = jnp.zeros(n_unknowns, dtype=jnp.float64) + total_limit_states = setup_cache.get("total_limit_states", 0) + limit_state = jnp.zeros(total_limit_states, dtype=jnp.float64) + nr_iter = jnp.asarray(1, dtype=jnp.int32) + + build_args = ( + X, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays, + gmin, + gshunt, + integ_c1, + integ_d1, + dQdt_prev, + integ_c2, + Q_prev2, + limit_state, + nr_iter, + ) + analyze_function("build_system", build_fn, build_args, bench_dir) + + # 2. Analyze nr_solve + nr_solve = setup_cache.get("nr_solve_fn") + if nr_solve is not None and device_arrays is not None: + X_init = jnp.zeros(n_total + n_vsources, dtype=jnp.float64) + vsource_vals = jnp.zeros(n_vsources, dtype=jnp.float64) + isource_vals = jnp.zeros(max(n_isources, 0), dtype=jnp.float64) + Q_prev = jnp.zeros(n_unknowns, dtype=jnp.float64) + integ_c0 = jnp.asarray(1e9, dtype=jnp.float64) + + nr_args = (X_init, vsource_vals, isource_vals, Q_prev, integ_c0, device_arrays) + analyze_function("nr_solve", nr_solve, nr_args, bench_dir) + + # 3. Analyze full transient step (run_while) + run_while = getattr(strategy, "_jit_run_while", None) + if run_while is None: + # Try to find it in the cache + run_while = strategy._jit_run_while_cache.get( + strategy._get_cache_key() if hasattr(strategy, "_get_cache_key") else None + ) + + if run_while is not None: + print("\n Found run_while - analyzing full transient loop") + # This one is harder to trace without the actual state + # We'd need to construct a FullMNAState - skip for now + print(" (skipping run_while - complex state tuple)") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Analyze dense benchmark JAX IR") + parser.add_argument( + "--benchmark", + default="rc,graetz,mul,ring", + help="Comma-separated benchmarks", + ) + parser.add_argument( + "--output-dir", + default="/tmp/claude/jaxpr-analysis", + help="Output directory for HLO files", + ) + args = parser.parse_args() + + output_dir = Path(args.output_dir) + benchmarks = [b.strip() for b in args.benchmark.split(",")] + + for bench in benchmarks: + analyze_benchmark(bench, output_dir) From 75dc078b49f479ce4787aae8ab265191ab66a076 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 17:45:35 +0000 Subject: [PATCH 27/79] fix: Tolerate SIGSEGV during nsys teardown (exit 139) JAX/CUDA runtime cleanup can segfault after simulation completes, but the nsys report is still valid if capture range ended successfully. Only fail the workflow if the .nsys-rep file wasn't generated. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 28439c16..4ed6408e 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -140,13 +140,24 @@ jobs: # Use cudaProfilerApi capture range to profile ONLY the simulation run # (excludes JIT warmup, module loading, memory allocation setup) + # Tolerate exit code 139 (SIGSEGV during JAX/CUDA teardown) — the nsys + # report is still valid if "Capture range ended" was logged nsys profile \ --trace=cuda,nvtx,osrt \ --capture-range=cudaProfilerApi \ --capture-range-end=stop \ --output "/tmp/${PROFILE_NAME}" \ uv run python scripts/nsys_profile_target.py \ - ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} + ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} \ + || NSYS_EXIT=$? + + if [ "${NSYS_EXIT:-0}" -ne 0 ]; then + echo "::warning::nsys exited with code ${NSYS_EXIT} (139=SIGSEGV during teardown, report may still be valid)" + if [ ! -f "/tmp/${PROFILE_NAME}.nsys-rep" ]; then + echo "::error::nsys report not generated" + exit 1 + fi + fi echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV" From c39943cae99b1143bd42de684af9b63fe64266b0 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 17:48:04 +0000 Subject: [PATCH 28/79] fix: Set nsys profile_name env var before running nsys When nsys exits non-zero (e.g. 139/SIGSEGV during teardown), the env export at the end of the step never ran, leaving $profile_name empty in the export/summary steps. Move the export before the nsys command so subsequent steps can always find the report file. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 4ed6408e..594931c4 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -131,6 +131,9 @@ jobs: TIMESTAMP=$(date +%s) PROFILE_NAME="nsys-${{ inputs.circuit }}-${TIMESTAMP}" + # Export profile_name early so subsequent steps can find the report + # even if nsys exits non-zero (e.g. SIGSEGV during teardown) + echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV" echo "=== Starting nsys GPU Profiling ===" echo "Circuit: ${{ inputs.circuit }}" @@ -159,8 +162,6 @@ jobs: fi fi - echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV" - - name: Export stats and SQLite if: always() run: | From 869524dee49bac327a818eabbe4a45ec9e7268a8 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 18:40:27 +0000 Subject: [PATCH 29/79] fix: Add --force-export=true to nsys stats commands nsys export creates a SQLite file, then nsys stats refuses to use it because the timestamp is "older than the input file". Adding --force-export=true resolves the stale export check. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 594931c4..cd60edf8 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -176,12 +176,12 @@ jobs: # Generate key stats reports for report in cuda_gpu_kern_sum cuda_api_sum cuda_gpu_mem_time_sum cuda_gpu_mem_size_sum; do nsys stats "$REPORT" --report "$report" --format csv \ - --output "${STATS_DIR}/${report}" 2>/dev/null || true + --force-export=true --output "${STATS_DIR}/${report}" 2>/dev/null || true done # Full kernel trace (top 100 by duration) nsys stats "$REPORT" --report cuda_gpu_trace --format csv \ - --output "${STATS_DIR}/cuda_gpu_trace" 2>/dev/null || true + --force-export=true --output "${STATS_DIR}/cuda_gpu_trace" 2>/dev/null || true fi - name: Generate summary @@ -199,17 +199,17 @@ jobs: if [ -f "$REPORT" ]; then echo "### GPU Kernel Summary" echo '```' - nsys stats "$REPORT" --report cuda_gpu_kern_sum 2>&1 | head -50 || true + nsys stats "$REPORT" --report cuda_gpu_kern_sum --force-export=true 2>&1 | head -50 || true echo '```' echo "" echo "### CUDA API Summary" echo '```' - nsys stats "$REPORT" --report cuda_api_sum 2>&1 | head -30 || true + nsys stats "$REPORT" --report cuda_api_sum --force-export=true 2>&1 | head -30 || true echo '```' echo "" echo "### GPU Memory Transfers" echo '```' - nsys stats "$REPORT" --report cuda_gpu_mem_time_sum 2>&1 | head -20 || true + nsys stats "$REPORT" --report cuda_gpu_mem_time_sum --force-export=true 2>&1 | head -20 || true echo '```' else echo "**No profile report found.**" From 06dc3bc46b479317a1b89ad2a8ee652554ff44ac Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 18:53:03 +0000 Subject: [PATCH 30/79] fix: Drop cudaProfilerApi capture range from nsys profiling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit XLA uses its own CUDA runtime internally, so ctypes cudaProfilerStart() has no effect on XLA-launched kernels — nsys captured an empty profile. Revert to full capture with 500 timesteps (warmup <5% of total). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 10 +++---- scripts/nsys_profile_target.py | 43 +++++------------------------- 2 files changed, 10 insertions(+), 43 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index cd60edf8..38fdadfe 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -141,14 +141,12 @@ jobs: echo "Sparse: ${{ inputs.sparse }}" echo "Commit: ${{ github.sha }}" - # Use cudaProfilerApi capture range to profile ONLY the simulation run - # (excludes JIT warmup, module loading, memory allocation setup) - # Tolerate exit code 139 (SIGSEGV during JAX/CUDA teardown) — the nsys - # report is still valid if "Capture range ended" was logged + # Full capture with enough timesteps (500+) that warmup overhead is <5% + # Note: cudaProfilerApi capture range doesn't work with JAX/XLA — XLA uses + # its own CUDA runtime internally, so ctypes cudaProfilerStart() has no effect + # Tolerate exit code 139 (SIGSEGV during JAX/CUDA teardown) nsys profile \ --trace=cuda,nvtx,osrt \ - --capture-range=cudaProfilerApi \ - --capture-range-end=stop \ --output "/tmp/${PROFILE_NAME}" \ uv run python scripts/nsys_profile_target.py \ ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} \ diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py index cfb30e55..18098215 100644 --- a/scripts/nsys_profile_target.py +++ b/scripts/nsys_profile_target.py @@ -1,20 +1,17 @@ #!/usr/bin/env python3 """Target script for nsys GPU profiling - runs circuit simulation. -Uses CUDA profiler API to capture ONLY the simulation run (not warmup/JIT). -Run with nsys --capture-range=cudaProfilerApi to enable selective capture. - Usage: - nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop \\ - -o profile uv run python scripts/nsys_profile_target.py ring 500 + nsys profile -o profile uv run python scripts/nsys_profile_target.py ring 500 Arguments: circuit: One of rc, graetz, mul, ring, c6288 (default: ring) timesteps: Number of timesteps to simulate (default: 500) + +Use 500+ timesteps so JIT warmup overhead is <5% of total profile. """ import argparse -import ctypes import sys from pathlib import Path @@ -26,25 +23,6 @@ from vajax.analysis import CircuitEngine -def _cuda_profiler_start(): - """Start CUDA profiler capture via cudaProfilerStart().""" - try: - libcudart = ctypes.CDLL("libcudart.so") - libcudart.cudaProfilerStart() - return True - except OSError: - return False - - -def _cuda_profiler_stop(): - """Stop CUDA profiler capture via cudaProfilerStop().""" - try: - libcudart = ctypes.CDLL("libcudart.so") - libcudart.cudaProfilerStop() - except OSError: - pass - - def main(): parser = argparse.ArgumentParser(description="nsys profiling target for VAJAX") parser.add_argument( @@ -102,7 +80,7 @@ def main(): print(f"Using dt={dt}") print() - # Prepare (includes 1-step JIT warmup) — NOT profiled + # Prepare (includes 1-step JIT warmup) print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...") engine.prepare( t_stop=args.timesteps * dt, @@ -112,20 +90,11 @@ def main(): print("Prepare complete") print() - # Start CUDA profiler capture — only the simulation run is profiled - has_profiler = _cuda_profiler_start() - if has_profiler: - print("CUDA profiler capture started (warmup excluded)") - else: - print("WARNING: cudaProfilerStart() unavailable — profiling entire process") - + # Profiled run — nsys captures everything including warmup above, + # but with 500+ steps the warmup is a small fraction of total time print(f"Starting profiled run ({args.timesteps} timesteps)...") result = engine.run_transient() - # Stop CUDA profiler capture - if has_profiler: - _cuda_profiler_stop() - print() print(f"Completed: {result.num_steps} timesteps") print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s") From 338ced1ec50650fb8ab84e97492fa7848f6eb031 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 19:38:55 +0000 Subject: [PATCH 31/79] feat: Convert NR while_loop to fori_loop for GPU-resident execution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace lax.while_loop with lax.fori_loop + lax.cond in the NR inner loop. XLA compiles fori_loop as a counted GPU loop without host round-trips, eliminating ~3.6ms CPU-GPU overhead per NR iteration (88% of NR wall time was host orchestration in nsys profiling). Key changes: - Split 20-element state tuple into 7-element mutable state (carried by fori_loop) and 13-element frozen state (captured by closure) - fori_loop body uses lax.cond to skip NR computation after convergence (identity branch is essentially free) - Same numerical results — all 398 tests pass unchanged The outer transient while_loop stays unchanged (1 round-trip per timestep, amortized cost is acceptable). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- vajax/analysis/solver_factories.py | 297 ++++++++++++----------------- 1 file changed, 123 insertions(+), 174 deletions(-) diff --git a/vajax/analysis/solver_factories.py b/vajax/analysis/solver_factories.py index 7f6aaa28..d244239b 100644 --- a/vajax/analysis/solver_factories.py +++ b/vajax/analysis/solver_factories.py @@ -199,136 +199,6 @@ def _make_nr_solver_common( ] ) - def cond_fn(state): - _, iteration, converged, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _ = state - return jnp.logical_and(~converged, iteration < max_iterations) - - def body_fn(state): - ( - X, - iteration, - _, - _, - _, - _, - limit_state, - vsource_vals, - isource_vals, - Q_prev, - integ_c0, - device_arrays_arg, - gmin, - gshunt, - integ_c1, - integ_d1, - _dQdt_prev, - integ_c2, - _Q_prev2, - res_tol_floor, - ) = state - - J_or_data, f, Q, _, limit_state_out, max_res_contrib = build_system_jit( - X, - vsource_vals, - isource_vals, - Q_prev, - integ_c0, - device_arrays_arg, - gmin, - gshunt, - integ_c1, - integ_d1, - _dQdt_prev, - integ_c2, - _Q_prev2, - limit_state, - iteration, - ) - - # VACASK-style residual tolerance (coreopnr.cpp:929): - # tol[i] = max(|maxResContrib[i]| * reltol, residual_abstol[i]) - # With historic floor (coreopnr.cpp:921, relref=alllocal): - # tolref = max(maxResContrib[i], historicMaxResContrib[i]) - res_tol_nodes = jnp.maximum(max_res_contrib * reltol, res_tol_floor) - res_tol = jnp.concatenate( - [res_tol_nodes, jnp.full(n_vsources, vntol, dtype=res_tol_nodes.dtype)] - ) - if residual_conv_mask is not None: - f_check = jnp.where(residual_conv_mask, f, 0.0) - else: - f_check = f - max_f = jnp.max(jnp.abs(f_check)) - residual_converged = jnp.all(jnp.abs(f_check) < res_tol) - - # Enforce NOI constraints and solve linear system - J_or_data, f_solve = enforce_noi_fn(J_or_data, f) - delta = linear_solve_fn(J_or_data, f_solve) - - # VACASK-style delta convergence (before step limiting) - # Check the damped correction that would actually be applied - conv_delta = jnp.concatenate( - [ - delta[:n_unknowns] * nr_damping, - delta[n_unknowns:], - ] - ) - X_ref = jnp.concatenate([X[1:n_total], X[n_total:]]) - tol = jnp.maximum(jnp.abs(X_ref) * reltol, delta_abs_tol) - if residual_mask is not None: - conv_delta = jnp.where(residual_mask, conv_delta, 0.0) - # VACASK skips delta check at iteration 0 (coreopnr.cpp: if(iteration>1)) - delta_converged = (iteration == 0) | jnp.all(jnp.abs(conv_delta) < tol) - - # Track max delta for diagnostics - max_delta = jnp.max(jnp.abs(delta)) - - # Voltage-only step limiting: cap max voltage change per iteration. - V_delta = delta[:n_unknowns] - max_V_delta = jnp.max(jnp.abs(V_delta)) - V_scale = jnp.minimum(1.0, max_step / jnp.maximum(max_V_delta, 1e-30)) - V_damped = V_delta * V_scale * nr_damping - X_new = X.at[1:n_total].add(V_damped) - X_new = X_new.at[n_total:].add(delta[n_unknowns:]) - - # Clamp NOI nodes to 0V - if noi_indices is not None and len(noi_indices) > 0: - X_new = X_new.at[noi_indices].set(0.0) - - # VACASK-style AND convergence (nrsolver.h:226). - # Both solution delta and KCL residual must be below tolerance. - converged = jnp.logical_and(residual_converged, delta_converged) - - # VACASK preventedConvergence (nrsolver.cpp:326, coreopnr.cpp:778): - # When device limiting (pnjlim/fetlim) is active, block convergence. - if total_limit_states > 0: - limit_delta = jnp.max(jnp.abs(limit_state_out - limit_state)) - limit_ref = jnp.maximum(jnp.max(jnp.abs(limit_state)) * reltol, vntol) - limit_settled = limit_delta < limit_ref - converged = converged & limit_settled & (iteration >= 1) - - return ( - X_new, - iteration + 1, - converged, - max_f, - max_delta, - Q, - limit_state_out, - vsource_vals, - isource_vals, - Q_prev, - integ_c0, - device_arrays_arg, - gmin, - gshunt, - integ_c1, - integ_d1, - _dQdt_prev, - integ_c2, - _Q_prev2, - res_tol_floor, - ) - def nr_solve( X_init: Array, vsource_vals: Array, @@ -369,53 +239,132 @@ def nr_solve( _integ_d1 = jnp.asarray(integ_d1, dtype=jnp.float64) _integ_c2 = jnp.asarray(integ_c2, dtype=jnp.float64) + # --- fori_loop NR solver --- + # Mutable state carried through the loop (changes each NR iteration). + # Frozen state (vsource_vals, isource_vals, Q_prev, integration coefficients, + # device_arrays, gmin, gshunt, res_tol_floor) is captured by closure, + # halving the per-iteration carry size. init_Q = jnp.zeros(n_unknowns, dtype=jnp.float64) - init_state = ( - X_init, - jnp.array(0, dtype=jnp.int32), - jnp.array(False), - jnp.array(jnp.inf), - jnp.array(jnp.inf), - init_Q, - _limit_state, - vsource_vals, - isource_vals, - Q_prev, - _integ_c0, - device_arrays_arg, - _gmin, - _gshunt, - _integ_c1, - _integ_d1, - _dQdt_prev, - _integ_c2, - _Q_prev2, - _res_tol_floor, + mutable_init = ( + X_init, # 0: X - solution vector + jnp.array(0, dtype=jnp.int32), # 1: iterations - actual NR step count + jnp.array(False), # 2: converged + jnp.array(jnp.inf), # 3: max_f - max residual + jnp.array(jnp.inf), # 4: max_delta + init_Q, # 5: Q - charge vector + _limit_state, # 6: limit_state ) - result_state = lax.while_loop(cond_fn, body_fn, init_state) - ( - X_final, - iterations, - converged, - max_f, - _, - _, - limit_state_final, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - _, - ) = result_state + def _nr_step(i, mutable): + """Perform one NR iteration. Frozen state captured by closure.""" + X, iterations, _, _, _, _, limit_state = mutable + + J_or_data, f, Q, _, limit_state_out, max_res_contrib = build_system_jit( + X, + vsource_vals, + isource_vals, + Q_prev, + _integ_c0, + device_arrays_arg, + _gmin, + _gshunt, + _integ_c1, + _integ_d1, + _dQdt_prev, + _integ_c2, + _Q_prev2, + limit_state, + i, + ) + + # VACASK-style residual tolerance (coreopnr.cpp:929): + # tol[i] = max(|maxResContrib[i]| * reltol, residual_abstol[i]) + # With historic floor (coreopnr.cpp:921, relref=alllocal): + # tolref = max(maxResContrib[i], historicMaxResContrib[i]) + res_tol_nodes = jnp.maximum(max_res_contrib * reltol, _res_tol_floor) + res_tol = jnp.concatenate( + [res_tol_nodes, jnp.full(n_vsources, vntol, dtype=res_tol_nodes.dtype)] + ) + if residual_conv_mask is not None: + f_check = jnp.where(residual_conv_mask, f, 0.0) + else: + f_check = f + max_f = jnp.max(jnp.abs(f_check)) + residual_converged = jnp.all(jnp.abs(f_check) < res_tol) + + # Enforce NOI constraints and solve linear system + J_or_data, f_solve = enforce_noi_fn(J_or_data, f) + delta = linear_solve_fn(J_or_data, f_solve) + + # VACASK-style delta convergence (before step limiting) + # Check the damped correction that would actually be applied + conv_delta = jnp.concatenate( + [ + delta[:n_unknowns] * nr_damping, + delta[n_unknowns:], + ] + ) + X_ref = jnp.concatenate([X[1:n_total], X[n_total:]]) + tol = jnp.maximum(jnp.abs(X_ref) * reltol, delta_abs_tol) + if residual_mask is not None: + conv_delta = jnp.where(residual_mask, conv_delta, 0.0) + # VACASK skips delta check at iteration 0 (coreopnr.cpp: if(iteration>1)) + delta_converged = (i == 0) | jnp.all(jnp.abs(conv_delta) < tol) + + # Track max delta for diagnostics + max_delta = jnp.max(jnp.abs(delta)) + + # Voltage-only step limiting: cap max voltage change per iteration. + V_delta = delta[:n_unknowns] + max_V_delta = jnp.max(jnp.abs(V_delta)) + V_scale = jnp.minimum(1.0, max_step / jnp.maximum(max_V_delta, 1e-30)) + V_damped = V_delta * V_scale * nr_damping + X_new = X.at[1:n_total].add(V_damped) + X_new = X_new.at[n_total:].add(delta[n_unknowns:]) + + # Clamp NOI nodes to 0V + if noi_indices is not None and len(noi_indices) > 0: + X_new = X_new.at[noi_indices].set(0.0) + + # VACASK-style AND convergence (nrsolver.h:226). + # Both solution delta and KCL residual must be below tolerance. + converged = jnp.logical_and(residual_converged, delta_converged) + + # VACASK preventedConvergence (nrsolver.cpp:326, coreopnr.cpp:778): + # When device limiting (pnjlim/fetlim) is active, block convergence. + if total_limit_states > 0: + limit_delta = jnp.max(jnp.abs(limit_state_out - limit_state)) + limit_ref = jnp.maximum(jnp.max(jnp.abs(limit_state)) * reltol, vntol) + limit_settled = limit_delta < limit_ref + converged = converged & limit_settled & (i >= 1) + + return ( + X_new, + iterations + 1, + converged, + max_f, + max_delta, + Q, + limit_state_out, + ) + + def _nr_body(i, mutable): + """fori_loop body: skip if already converged, else do NR step. + + lax.cond compiles both branches but the identity branch is essentially + free (no compute). The key win is that fori_loop uses a compile-time + bound, so XLA keeps the entire loop on-GPU without host round-trips. + """ + converged = mutable[2] + return lax.cond( + converged, + lambda m: m, + lambda m: _nr_step(i, m), + mutable, + ) + + result_mutable = lax.fori_loop(0, max_iterations, _nr_body, mutable_init) + X_final, iterations, converged, max_f, _, _, limit_state_final = result_mutable # Recompute Q and I_vsource from converged solution _, _, Q_final, I_vsource, _, max_res_contrib_final = build_system_jit( From ab0a322f2c570eca377a1017891a1159cbc66fd9 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 20:17:53 +0000 Subject: [PATCH 32/79] Revert "feat: Convert NR while_loop to fori_loop for GPU-resident execution" This reverts commit 338ced1ec50650fb8ab84e97492fa7848f6eb031. --- vajax/analysis/solver_factories.py | 297 +++++++++++++++++------------ 1 file changed, 174 insertions(+), 123 deletions(-) diff --git a/vajax/analysis/solver_factories.py b/vajax/analysis/solver_factories.py index d244239b..7f6aaa28 100644 --- a/vajax/analysis/solver_factories.py +++ b/vajax/analysis/solver_factories.py @@ -199,6 +199,136 @@ def _make_nr_solver_common( ] ) + def cond_fn(state): + _, iteration, converged, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _ = state + return jnp.logical_and(~converged, iteration < max_iterations) + + def body_fn(state): + ( + X, + iteration, + _, + _, + _, + _, + limit_state, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays_arg, + gmin, + gshunt, + integ_c1, + integ_d1, + _dQdt_prev, + integ_c2, + _Q_prev2, + res_tol_floor, + ) = state + + J_or_data, f, Q, _, limit_state_out, max_res_contrib = build_system_jit( + X, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays_arg, + gmin, + gshunt, + integ_c1, + integ_d1, + _dQdt_prev, + integ_c2, + _Q_prev2, + limit_state, + iteration, + ) + + # VACASK-style residual tolerance (coreopnr.cpp:929): + # tol[i] = max(|maxResContrib[i]| * reltol, residual_abstol[i]) + # With historic floor (coreopnr.cpp:921, relref=alllocal): + # tolref = max(maxResContrib[i], historicMaxResContrib[i]) + res_tol_nodes = jnp.maximum(max_res_contrib * reltol, res_tol_floor) + res_tol = jnp.concatenate( + [res_tol_nodes, jnp.full(n_vsources, vntol, dtype=res_tol_nodes.dtype)] + ) + if residual_conv_mask is not None: + f_check = jnp.where(residual_conv_mask, f, 0.0) + else: + f_check = f + max_f = jnp.max(jnp.abs(f_check)) + residual_converged = jnp.all(jnp.abs(f_check) < res_tol) + + # Enforce NOI constraints and solve linear system + J_or_data, f_solve = enforce_noi_fn(J_or_data, f) + delta = linear_solve_fn(J_or_data, f_solve) + + # VACASK-style delta convergence (before step limiting) + # Check the damped correction that would actually be applied + conv_delta = jnp.concatenate( + [ + delta[:n_unknowns] * nr_damping, + delta[n_unknowns:], + ] + ) + X_ref = jnp.concatenate([X[1:n_total], X[n_total:]]) + tol = jnp.maximum(jnp.abs(X_ref) * reltol, delta_abs_tol) + if residual_mask is not None: + conv_delta = jnp.where(residual_mask, conv_delta, 0.0) + # VACASK skips delta check at iteration 0 (coreopnr.cpp: if(iteration>1)) + delta_converged = (iteration == 0) | jnp.all(jnp.abs(conv_delta) < tol) + + # Track max delta for diagnostics + max_delta = jnp.max(jnp.abs(delta)) + + # Voltage-only step limiting: cap max voltage change per iteration. + V_delta = delta[:n_unknowns] + max_V_delta = jnp.max(jnp.abs(V_delta)) + V_scale = jnp.minimum(1.0, max_step / jnp.maximum(max_V_delta, 1e-30)) + V_damped = V_delta * V_scale * nr_damping + X_new = X.at[1:n_total].add(V_damped) + X_new = X_new.at[n_total:].add(delta[n_unknowns:]) + + # Clamp NOI nodes to 0V + if noi_indices is not None and len(noi_indices) > 0: + X_new = X_new.at[noi_indices].set(0.0) + + # VACASK-style AND convergence (nrsolver.h:226). + # Both solution delta and KCL residual must be below tolerance. + converged = jnp.logical_and(residual_converged, delta_converged) + + # VACASK preventedConvergence (nrsolver.cpp:326, coreopnr.cpp:778): + # When device limiting (pnjlim/fetlim) is active, block convergence. + if total_limit_states > 0: + limit_delta = jnp.max(jnp.abs(limit_state_out - limit_state)) + limit_ref = jnp.maximum(jnp.max(jnp.abs(limit_state)) * reltol, vntol) + limit_settled = limit_delta < limit_ref + converged = converged & limit_settled & (iteration >= 1) + + return ( + X_new, + iteration + 1, + converged, + max_f, + max_delta, + Q, + limit_state_out, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays_arg, + gmin, + gshunt, + integ_c1, + integ_d1, + _dQdt_prev, + integ_c2, + _Q_prev2, + res_tol_floor, + ) + def nr_solve( X_init: Array, vsource_vals: Array, @@ -239,132 +369,53 @@ def nr_solve( _integ_d1 = jnp.asarray(integ_d1, dtype=jnp.float64) _integ_c2 = jnp.asarray(integ_c2, dtype=jnp.float64) - # --- fori_loop NR solver --- - # Mutable state carried through the loop (changes each NR iteration). - # Frozen state (vsource_vals, isource_vals, Q_prev, integration coefficients, - # device_arrays, gmin, gshunt, res_tol_floor) is captured by closure, - # halving the per-iteration carry size. init_Q = jnp.zeros(n_unknowns, dtype=jnp.float64) - mutable_init = ( - X_init, # 0: X - solution vector - jnp.array(0, dtype=jnp.int32), # 1: iterations - actual NR step count - jnp.array(False), # 2: converged - jnp.array(jnp.inf), # 3: max_f - max residual - jnp.array(jnp.inf), # 4: max_delta - init_Q, # 5: Q - charge vector - _limit_state, # 6: limit_state + init_state = ( + X_init, + jnp.array(0, dtype=jnp.int32), + jnp.array(False), + jnp.array(jnp.inf), + jnp.array(jnp.inf), + init_Q, + _limit_state, + vsource_vals, + isource_vals, + Q_prev, + _integ_c0, + device_arrays_arg, + _gmin, + _gshunt, + _integ_c1, + _integ_d1, + _dQdt_prev, + _integ_c2, + _Q_prev2, + _res_tol_floor, ) - def _nr_step(i, mutable): - """Perform one NR iteration. Frozen state captured by closure.""" - X, iterations, _, _, _, _, limit_state = mutable - - J_or_data, f, Q, _, limit_state_out, max_res_contrib = build_system_jit( - X, - vsource_vals, - isource_vals, - Q_prev, - _integ_c0, - device_arrays_arg, - _gmin, - _gshunt, - _integ_c1, - _integ_d1, - _dQdt_prev, - _integ_c2, - _Q_prev2, - limit_state, - i, - ) - - # VACASK-style residual tolerance (coreopnr.cpp:929): - # tol[i] = max(|maxResContrib[i]| * reltol, residual_abstol[i]) - # With historic floor (coreopnr.cpp:921, relref=alllocal): - # tolref = max(maxResContrib[i], historicMaxResContrib[i]) - res_tol_nodes = jnp.maximum(max_res_contrib * reltol, _res_tol_floor) - res_tol = jnp.concatenate( - [res_tol_nodes, jnp.full(n_vsources, vntol, dtype=res_tol_nodes.dtype)] - ) - if residual_conv_mask is not None: - f_check = jnp.where(residual_conv_mask, f, 0.0) - else: - f_check = f - max_f = jnp.max(jnp.abs(f_check)) - residual_converged = jnp.all(jnp.abs(f_check) < res_tol) - - # Enforce NOI constraints and solve linear system - J_or_data, f_solve = enforce_noi_fn(J_or_data, f) - delta = linear_solve_fn(J_or_data, f_solve) - - # VACASK-style delta convergence (before step limiting) - # Check the damped correction that would actually be applied - conv_delta = jnp.concatenate( - [ - delta[:n_unknowns] * nr_damping, - delta[n_unknowns:], - ] - ) - X_ref = jnp.concatenate([X[1:n_total], X[n_total:]]) - tol = jnp.maximum(jnp.abs(X_ref) * reltol, delta_abs_tol) - if residual_mask is not None: - conv_delta = jnp.where(residual_mask, conv_delta, 0.0) - # VACASK skips delta check at iteration 0 (coreopnr.cpp: if(iteration>1)) - delta_converged = (i == 0) | jnp.all(jnp.abs(conv_delta) < tol) - - # Track max delta for diagnostics - max_delta = jnp.max(jnp.abs(delta)) - - # Voltage-only step limiting: cap max voltage change per iteration. - V_delta = delta[:n_unknowns] - max_V_delta = jnp.max(jnp.abs(V_delta)) - V_scale = jnp.minimum(1.0, max_step / jnp.maximum(max_V_delta, 1e-30)) - V_damped = V_delta * V_scale * nr_damping - X_new = X.at[1:n_total].add(V_damped) - X_new = X_new.at[n_total:].add(delta[n_unknowns:]) - - # Clamp NOI nodes to 0V - if noi_indices is not None and len(noi_indices) > 0: - X_new = X_new.at[noi_indices].set(0.0) - - # VACASK-style AND convergence (nrsolver.h:226). - # Both solution delta and KCL residual must be below tolerance. - converged = jnp.logical_and(residual_converged, delta_converged) - - # VACASK preventedConvergence (nrsolver.cpp:326, coreopnr.cpp:778): - # When device limiting (pnjlim/fetlim) is active, block convergence. - if total_limit_states > 0: - limit_delta = jnp.max(jnp.abs(limit_state_out - limit_state)) - limit_ref = jnp.maximum(jnp.max(jnp.abs(limit_state)) * reltol, vntol) - limit_settled = limit_delta < limit_ref - converged = converged & limit_settled & (i >= 1) - - return ( - X_new, - iterations + 1, - converged, - max_f, - max_delta, - Q, - limit_state_out, - ) - - def _nr_body(i, mutable): - """fori_loop body: skip if already converged, else do NR step. - - lax.cond compiles both branches but the identity branch is essentially - free (no compute). The key win is that fori_loop uses a compile-time - bound, so XLA keeps the entire loop on-GPU without host round-trips. - """ - converged = mutable[2] - return lax.cond( - converged, - lambda m: m, - lambda m: _nr_step(i, m), - mutable, - ) - - result_mutable = lax.fori_loop(0, max_iterations, _nr_body, mutable_init) - X_final, iterations, converged, max_f, _, _, limit_state_final = result_mutable + result_state = lax.while_loop(cond_fn, body_fn, init_state) + ( + X_final, + iterations, + converged, + max_f, + _, + _, + limit_state_final, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + _, + ) = result_state # Recompute Q and I_vsource from converged solution _, _, Q_final, I_vsource, _, max_res_contrib_final = build_system_jit( From 880c12e873871a3b44b3fb564105061e1393dc47 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 20:30:08 +0000 Subject: [PATCH 33/79] feat: Enable XLA command buffer WHILE+CONDITIONAL for GPU profiling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add +WHILE,+CONDITIONAL to xla_gpu_enable_command_buffer flag in the nsys profiling workflow. These types are NOT in XLA's default command buffer set, causing lax.while_loop and lax.cond to fall back to WhileThunk/ConditionalThunk which do D2H transfers for predicates. With CUDA 12.4+ (runner has 12.6), XLA can use CUDA graph conditional nodes for on-device control flow — predicates evaluated on GPU via cudaGraphSetConditional(), no host round-trips needed. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 38fdadfe..1c6088aa 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -119,7 +119,7 @@ jobs: - name: Run nsys profiling env: JAX_PLATFORMS: cuda,cpu - XLA_FLAGS: "--xla_gpu_autotune_level=0" + XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform TF_CPP_MIN_LOG_LEVEL: "2" From b5a9e2702cefb049caea3172741dcdf78739fa68 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 20:51:45 +0000 Subject: [PATCH 34/79] feat: Add XLA command buffer diagnostics to nsys profiling workflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable verbose TF_CPP_VMODULE logging for command buffer scheduling, while_thunk, and gpu_command_buffer modules. Also dump HLO passes filtered to command-buffer to understand why WHILE+CONDITIONAL command buffer types are not being used despite being explicitly enabled. The previous run with +WHILE,+CONDITIONAL showed identical metrics to baseline — this diagnostic pass will reveal whether XLA is silently falling back to WhileThunk due to incompatible operations in the while loop body (likely cuSOLVER workspace allocations). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 40 ++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 1c6088aa..c3b503cf 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -119,10 +119,11 @@ jobs: - name: Run nsys profiling env: JAX_PLATFORMS: cuda,cpu - XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" + XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL --xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=command-buffer" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform - TF_CPP_MIN_LOG_LEVEL: "2" + TF_CPP_MIN_LOG_LEVEL: "0" + TF_CPP_VMODULE: "command_buffer_schedule=3,command_buffer_cmd=3,while_thunk=2,gpu_command_buffer=2" run: | SPARSE_FLAG="" if [ "${{ inputs.sparse }}" = "true" ]; then @@ -160,6 +161,40 @@ jobs: fi fi + - name: Capture XLA command buffer diagnostics + if: always() + run: | + XLA_DUMP="/tmp/xla_dump" + XLA_DIAG="/tmp/xla-diagnostics" + mkdir -p "$XLA_DIAG" + + # Summarize command buffer pass results + if [ -d "$XLA_DUMP" ]; then + echo "=== XLA dump files ===" + find "$XLA_DUMP" -maxdepth 1 -type f | head -50 + + # Look for command-buffer pass output + for f in "$XLA_DUMP"/*command*; do + if [ -f "$f" ]; then + echo "=== $(basename "$f") ===" >> "$XLA_DIAG/command_buffer_passes.txt" + head -200 "$f" >> "$XLA_DIAG/command_buffer_passes.txt" + echo "" >> "$XLA_DIAG/command_buffer_passes.txt" + fi + done + + # Look for while loop HLO + grep -rl "while" "$XLA_DUMP"/ 2>/dev/null | head -5 | while read -r f; do + echo "=== $(basename "$f") ===" >> "$XLA_DIAG/while_loop_hlo.txt" + grep -A5 -B2 "while" "$f" | head -100 >> "$XLA_DIAG/while_loop_hlo.txt" + echo "" >> "$XLA_DIAG/while_loop_hlo.txt" + done + + # Copy small dump files + find "$XLA_DUMP" -name "*command*" -size -1M -exec cp {} "$XLA_DIAG/" \; + else + echo "No XLA dump directory found" + fi + - name: Export stats and SQLite if: always() run: | @@ -223,5 +258,6 @@ jobs: /tmp/nsys-*.nsys-rep /tmp/nsys-*.sqlite /tmp/nsys-stats/ + /tmp/xla-diagnostics/ if-no-files-found: ignore retention-days: 30 From 2bd9412764dbaf66f22c082028bd263d1caf410b Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 21:05:39 +0000 Subject: [PATCH 35/79] chore: Clean up nsys workflow after command buffer investigation Remove XLA HLO dump and verbose logging (diagnostic-only, no longer needed). Keep +WHILE,+CONDITIONAL flag as harmless forward-compat. Root cause confirmed: cuSOLVER getrf lacks kCmdBufferCompatible FFI trait and internally uses cudaMalloc during graph capture, blocking the entire while loop body from being converted to a WhileCmd. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 42 +++--------------------------- 1 file changed, 4 insertions(+), 38 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index c3b503cf..431198f7 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -119,11 +119,12 @@ jobs: - name: Run nsys profiling env: JAX_PLATFORMS: cuda,cpu - XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL --xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=command-buffer" + # +WHILE,+CONDITIONAL: no-op today (cuSOLVER getrf blocks graph capture) + # but will auto-activate if XLA/cuSOLVER ever becomes compatible + XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform - TF_CPP_MIN_LOG_LEVEL: "0" - TF_CPP_VMODULE: "command_buffer_schedule=3,command_buffer_cmd=3,while_thunk=2,gpu_command_buffer=2" + TF_CPP_MIN_LOG_LEVEL: "2" run: | SPARSE_FLAG="" if [ "${{ inputs.sparse }}" = "true" ]; then @@ -161,40 +162,6 @@ jobs: fi fi - - name: Capture XLA command buffer diagnostics - if: always() - run: | - XLA_DUMP="/tmp/xla_dump" - XLA_DIAG="/tmp/xla-diagnostics" - mkdir -p "$XLA_DIAG" - - # Summarize command buffer pass results - if [ -d "$XLA_DUMP" ]; then - echo "=== XLA dump files ===" - find "$XLA_DUMP" -maxdepth 1 -type f | head -50 - - # Look for command-buffer pass output - for f in "$XLA_DUMP"/*command*; do - if [ -f "$f" ]; then - echo "=== $(basename "$f") ===" >> "$XLA_DIAG/command_buffer_passes.txt" - head -200 "$f" >> "$XLA_DIAG/command_buffer_passes.txt" - echo "" >> "$XLA_DIAG/command_buffer_passes.txt" - fi - done - - # Look for while loop HLO - grep -rl "while" "$XLA_DUMP"/ 2>/dev/null | head -5 | while read -r f; do - echo "=== $(basename "$f") ===" >> "$XLA_DIAG/while_loop_hlo.txt" - grep -A5 -B2 "while" "$f" | head -100 >> "$XLA_DIAG/while_loop_hlo.txt" - echo "" >> "$XLA_DIAG/while_loop_hlo.txt" - done - - # Copy small dump files - find "$XLA_DUMP" -name "*command*" -size -1M -exec cp {} "$XLA_DIAG/" \; - else - echo "No XLA dump directory found" - fi - - name: Export stats and SQLite if: always() run: | @@ -258,6 +225,5 @@ jobs: /tmp/nsys-*.nsys-rep /tmp/nsys-*.sqlite /tmp/nsys-stats/ - /tmp/xla-diagnostics/ if-no-files-found: ignore retention-days: 30 From cdc85c0fc87cb9a05a95c8ef3c3a374c5e554177 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Mon, 9 Mar 2026 22:43:18 +0000 Subject: [PATCH 36/79] feat: Add BaSpaCho dense solver integration for CUDA Add make_baspacho_dense_full_mna_solver() factory function that uses BaSpaCho's supernodal LU factorization on CUDA, replacing jax.scipy.linalg.solve (cuSOLVER getrf) for dense matrices on GPU. Benefits over default cuSOLVER path: - Symbolic analysis cached across NR iterations - Grow-only GPU memory (no per-call cudaMalloc after warmup) - Foundation for Phase 2b graph-capture compatibility On CUDA with BaSpaCho available, automatically selected for dense path with graceful fallback to standard JAX dense solver. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- vajax/analysis/solver_factories.py | 72 ++++++++++++++++++++++++++++ vajax/analysis/transient/full_mna.py | 59 ++++++++++++++++++----- 2 files changed, 120 insertions(+), 11 deletions(-) diff --git a/vajax/analysis/solver_factories.py b/vajax/analysis/solver_factories.py index 7f6aaa28..da085653 100644 --- a/vajax/analysis/solver_factories.py +++ b/vajax/analysis/solver_factories.py @@ -536,6 +536,78 @@ def linear_solve(J, f): ) +def make_baspacho_dense_full_mna_solver( + build_system_jit: Callable, + n_nodes: int, + n_vsources: int, + noi_indices: Optional[Array] = None, + internal_device_indices: Optional[Array] = None, + max_iterations: int = 100, + abstol: float = 1e-12, + total_limit_states: int = 0, + options: Optional["SimulationOptions"] = None, + max_step: float = 1e30, +) -> Callable: + """Create a dense NR solver using BaSpaCho LU on CUDA. + + Uses BaSpaCho's supernodal LU factorization with CUDA backend, + replacing jax.scipy.linalg.solve (cuSOLVER getrf) on GPU. Benefits: + - Symbolic analysis done once (cached across NR iterations) + - Grow-only GPU memory allocation (no per-call cudaMalloc after warmup) + - Foundation for Phase 2b graph-capture compatibility + + Falls back to standard dense solver if BaSpaCho is unavailable. + + Args: + Same as make_dense_full_mna_solver. + """ + from spineax.cudss.dense_baspacho_solver import baspacho_dense_solve + + masks = _compute_noi_masks( + noi_indices, n_nodes, internal_device_indices=internal_device_indices + ) + noi_res_idx = masks["noi_res_idx"] + + residual_mask = _build_augmented_mask(masks["residual_mask"], n_vsources) + residual_conv_mask = _build_augmented_conv_mask( + masks["residual_conv_mask"], residual_mask, n_vsources + ) + + def enforce_noi(J, f): + """Enforce NOI constraints on dense Jacobian.""" + if noi_res_idx is not None: + J = J.at[noi_res_idx, :].set(0.0) + J = J.at[:, noi_res_idx].set(0.0) + J = J.at[noi_res_idx, noi_res_idx].set(1.0) + f = f.at[noi_res_idx].set(0.0) + return J, f + + def linear_solve(J, f): + """Solve J @ delta = -f using BaSpaCho LU on CUDA.""" + return baspacho_dense_solve(J, -f) + + logger.info( + f"Creating BaSpaCho dense full MNA solver: V({n_nodes}) + I({n_vsources}), " + f"NOI: {noi_indices is not None}" + ) + return _make_nr_solver_common( + build_system_jit=build_system_jit, + n_nodes=n_nodes, + n_vsources=n_vsources, + linear_solve_fn=linear_solve, + enforce_noi_fn=enforce_noi, + noi_indices=noi_indices, + internal_device_indices=internal_device_indices, + max_iterations=max_iterations, + abstol=abstol, + total_limit_states=total_limit_states, + options=options, + max_step=max_step, + residual_mask=residual_mask, + residual_conv_mask=residual_conv_mask, + ) + + def make_spineax_full_mna_solver( build_system_jit: Callable, n_nodes: int, diff --git a/vajax/analysis/transient/full_mna.py b/vajax/analysis/transient/full_mna.py index fbf330ed..48deb7cb 100644 --- a/vajax/analysis/transient/full_mna.py +++ b/vajax/analysis/transient/full_mna.py @@ -36,6 +36,7 @@ from vajax._logging import logger from vajax.analysis.solver_factories import ( + make_baspacho_dense_full_mna_solver, make_dense_full_mna_solver, make_spineax_full_mna_solver, make_umfpack_ffi_full_mna_solver, @@ -53,6 +54,16 @@ def is_spineax_available() -> bool: return False +def is_baspacho_dense_available() -> bool: + """Check if BaSpaCho dense CUDA solver is available.""" + try: + from spineax.cudss.dense_baspacho_solver import is_available + + return is_available() + except ImportError: + return False + + from vajax.analysis.integration import IntegrationMethod from .adaptive import AdaptiveConfig, compute_lte_timestep_jax, predict_voltage_jax @@ -385,17 +396,43 @@ def _ensure_full_mna_solver(self, setup: TransientSetup) -> Callable: self._total_limit_states = total_limit_states build_system_jit = jax.jit(build_system_fn) - nr_solve = make_dense_full_mna_solver( - build_system_jit, - n_nodes, - n_vsources, - noi_indices=noi_indices, - internal_device_indices=internal_device_indices, - max_iterations=self.runner.options.tran_itl, - abstol=self.runner.options.abstol, - total_limit_states=total_limit_states, - options=self.runner.options, - ) + # On CUDA, try BaSpaCho dense solver (pre-allocated workspace, + # foundation for graph-capture compatibility in Phase 2b). + on_cuda_dense = jax.default_backend() in ("cuda", "gpu") + if on_cuda_dense and is_baspacho_dense_available(): + try: + logger.info("Using BaSpaCho dense solver (GPU, CUDA backend)") + nr_solve = make_baspacho_dense_full_mna_solver( + build_system_jit, + n_nodes, + n_vsources, + noi_indices=noi_indices, + internal_device_indices=internal_device_indices, + max_iterations=self.runner.options.tran_itl, + abstol=self.runner.options.abstol, + total_limit_states=total_limit_states, + options=self.runner.options, + ) + except Exception as e: + logger.warning( + f"BaSpaCho dense solver failed ({e}), falling back to JAX dense solver" + ) + nr_solve = None + else: + nr_solve = None + + if nr_solve is None: + nr_solve = make_dense_full_mna_solver( + build_system_jit, + n_nodes, + n_vsources, + noi_indices=noi_indices, + internal_device_indices=internal_device_indices, + max_iterations=self.runner.options.tran_itl, + abstol=self.runner.options.abstol, + total_limit_states=total_limit_states, + options=self.runner.options, + ) else: # Sparse path: use CSR direct stamping to eliminate COO intermediates n_augmented = setup.n_unknowns + n_vsources From 83d0fe892037e36b1e5a7613415e364e74d602c3 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 00:43:31 +0000 Subject: [PATCH 37/79] feat: Add BaSpaCho option to nsys profiling workflow - New `use_baspacho` input builds spineax from source with BaSpaCho dense solver (auto-fetches via CMake FetchContent) - Update command buffer comments: BaSpaCho has kCmdBufferCompatible trait, replacing cuSOLVER which blocked graph capture Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 431198f7..2c95423a 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -20,6 +20,10 @@ on: description: 'Use sparse solver (for large circuits)' type: boolean default: false + use_baspacho: + description: 'Build spineax from source with BaSpaCho dense solver' + type: boolean + default: false # Only one profiling job at a time on the GPU runner concurrency: @@ -116,11 +120,20 @@ jobs: - name: Install vajax with CUDA dependencies run: uv sync --extra cuda12 + - name: Install spineax with BaSpaCho from source + if: inputs.use_baspacho + run: | + # Replace PyPI spineax with git version that has BaSpaCho dense solver. + # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen). + uv pip install \ + "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ + -C cmake.define.SPINEAX_USE_BASPACHO=ON + - name: Run nsys profiling env: JAX_PLATFORMS: cuda,cpu - # +WHILE,+CONDITIONAL: no-op today (cuSOLVER getrf blocks graph capture) - # but will auto-activate if XLA/cuSOLVER ever becomes compatible + # +WHILE,+CONDITIONAL: enables command buffer capture for NR while loop + # BaSpaCho dense solver has kCmdBufferCompatible trait (replaces cuSOLVER) XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform From cf50c8f7781a489ef389bd09adb22bbb4b46ceaa Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 00:51:58 +0000 Subject: [PATCH 38/79] fix: Add libopenblas-dev for BaSpaCho build in nsys workflow BaSpaCho requires BLAS (find_package(BLAS REQUIRED) with OpenBLAS vendor). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 2c95423a..18a06516 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -101,7 +101,7 @@ jobs: - name: Install system dependencies run: | - sudo apt-get install -y libsuitesparse-dev swig cmake pkg-config + sudo apt-get install -y libsuitesparse-dev libopenblas-dev swig cmake pkg-config - name: Install uv uses: astral-sh/setup-uv@v6 From 6fefb8e35d8202c5b1bd8a5167035d211e43defc Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 16:25:53 +0000 Subject: [PATCH 39/79] chore: Add solver availability logging to nsys profiling script Enable INFO-level logging and explicit import checks for BaSpaCho dense solver, cuDSS sparse solver, and the C++ nanobind module. This helps diagnose whether BaSpaCho is actually activating during GPU profiling runs (nsys profile showed cuSOLVER kernels instead of BaSpaCho). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/nsys_profile_target.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py index 18098215..cf5b96de 100644 --- a/scripts/nsys_profile_target.py +++ b/scripts/nsys_profile_target.py @@ -12,6 +12,7 @@ """ import argparse +import logging import sys from pathlib import Path @@ -19,6 +20,9 @@ sys.path.insert(0, ".") +# Enable INFO logging so solver selection messages are visible +logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s") + # Import vajax first to auto-configure precision based on backend from vajax.analysis import CircuitEngine @@ -56,6 +60,29 @@ def main(): print(f"JAX devices: {jax.devices()}") print(f"Circuit: {args.circuit}") print(f"Timesteps: {args.timesteps}") + + # Explicit solver availability check + print() + print("=== Solver Availability ===") + try: + from spineax.cudss.dense_baspacho_solver import is_available + + print(" BaSpaCho dense import: OK") + print(f" BaSpaCho dense available: {is_available()}") + except ImportError as e: + print(f" BaSpaCho dense import: FAILED ({e})") + try: + from spineax.cudss.solver import CuDSSSolver # noqa: F401 + + print(" cuDSS sparse import: OK") + except ImportError as e: + print(f" cuDSS sparse import: FAILED ({e})") + try: + from spineax import baspacho_dense_solve as _mod + + print(f" baspacho_dense_solve C++ module: OK ({_mod})") + except ImportError as e: + print(f" baspacho_dense_solve C++ module: FAILED ({e})") print() # Find benchmark .sim file From d0a6271d66ff1da1a59cbcc8f910a40557dade00 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 16:41:16 +0000 Subject: [PATCH 40/79] fix: Force reinstall spineax with --reinstall --no-cache for BaSpaCho uv pip install was likely reusing a cached wheel from PyPI (built without SPINEAX_USE_BASPACHO=ON), so baspacho_dense_solve.so was never installed. Add --reinstall --no-cache to force a source build, plus diagnostics to verify the module is installed and importable. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 18a06516..62e3ecb0 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -124,11 +124,19 @@ jobs: if: inputs.use_baspacho run: | # Replace PyPI spineax with git version that has BaSpaCho dense solver. + # --reinstall forces rebuild even if same version is cached from PyPI. + # --no-cache avoids reusing a wheel built without SPINEAX_USE_BASPACHO. # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen). - uv pip install \ + uv pip install --reinstall --no-cache \ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ -C cmake.define.SPINEAX_USE_BASPACHO=ON + # Verify the module was installed + echo "=== Installed spineax files ===" + python -c "import spineax; import pathlib; p = pathlib.Path(spineax.__file__).parent; [print(f' {f.name}') for f in sorted(p.iterdir())]" + echo "=== Checking baspacho_dense_solve import ===" + python -c "from spineax import baspacho_dense_solve; print(' OK:', baspacho_dense_solve)" || echo " FAILED" + - name: Run nsys profiling env: JAX_PLATFORMS: cuda,cpu From 3c8568a65bee810cffad5119acf4d1193c399f07 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 16:41:47 +0000 Subject: [PATCH 41/79] chore: Add ls -lR diagnostic for spineax install verification Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 62e3ecb0..29cb9b83 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -134,6 +134,8 @@ jobs: # Verify the module was installed echo "=== Installed spineax files ===" python -c "import spineax; import pathlib; p = pathlib.Path(spineax.__file__).parent; [print(f' {f.name}') for f in sorted(p.iterdir())]" + echo "=== spineax directory listing ===" + ls -lR .venv/lib/python*/site-packages/spineax/ || true echo "=== Checking baspacho_dense_solve import ===" python -c "from spineax import baspacho_dense_solve; print(' OK:', baspacho_dense_solve)" || echo " FAILED" From ab03e95d6e3a381218f01bb920d3837e131337a0 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 16:42:20 +0000 Subject: [PATCH 42/79] chore: Add -vvv to uv pip install for BaSpaCho build diagnostics Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 29cb9b83..dc4778a3 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -127,7 +127,7 @@ jobs: # --reinstall forces rebuild even if same version is cached from PyPI. # --no-cache avoids reusing a wheel built without SPINEAX_USE_BASPACHO. # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen). - uv pip install --reinstall --no-cache \ + uv pip install --reinstall --no-cache -vvv \ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ -C cmake.define.SPINEAX_USE_BASPACHO=ON From f2adf5294ce1cedb22a64937122aaa4b5b6165f9 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 16:45:51 +0000 Subject: [PATCH 43/79] chore: Enable cancel-in-progress for nsys profiling concurrency group Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index dc4778a3..41240d3b 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -28,7 +28,7 @@ on: # Only one profiling job at a time on the GPU runner concurrency: group: nsys-profile - cancel-in-progress: false + cancel-in-progress: true env: CARGO_TERM_COLOR: always From 2347bee2a0fd9e9392c470b720fc8d78baf432f2 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 16:51:30 +0000 Subject: [PATCH 44/79] perf: Cache apt packages for CUDA toolkit across CI runs CUDA toolkit + nsys debs are ~2-3GB downloaded fresh every run on GitHub-hosted runners. Cache them in ~/apt-cache and restore to /var/cache/apt/archives/ so apt-get can skip re-downloading. Also merged system dependencies install into the same step. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 41240d3b..49dbfe99 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -82,10 +82,30 @@ jobs: with: workspaces: openvaf_jax/openvaf_py + - name: Cache apt packages + uses: actions/cache@v4 + id: apt-cache + with: + path: ~/apt-cache + key: apt-cuda-12.6-${{ runner.os }}-v1 + - name: Install CUDA toolkit and nsys run: | + # Restore cached debs if available + if [ -d ~/apt-cache ] && [ "$(ls ~/apt-cache/*.deb 2>/dev/null | wc -l)" -gt 0 ]; then + echo "=== Restoring $(ls ~/apt-cache/*.deb | wc -l) cached debs ===" + sudo cp ~/apt-cache/*.deb /var/cache/apt/archives/ 2>/dev/null || true + fi + sudo apt-get update - sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 + sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 \ + libsuitesparse-dev libopenblas-dev swig cmake pkg-config + + # Cache downloaded debs for next run + mkdir -p ~/apt-cache + cp /var/cache/apt/archives/*.deb ~/apt-cache/ 2>/dev/null || true + echo "=== Cached $(ls ~/apt-cache/*.deb 2>/dev/null | wc -l) debs ($(du -sh ~/apt-cache 2>/dev/null | cut -f1)) ===" + echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then @@ -99,10 +119,6 @@ jobs: fi nsys --version - - name: Install system dependencies - run: | - sudo apt-get install -y libsuitesparse-dev libopenblas-dev swig cmake pkg-config - - name: Install uv uses: astral-sh/setup-uv@v6 with: From 3ab2cb9c6521542934216ac4cf0a71ccbd9f4b49 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 17:04:39 +0000 Subject: [PATCH 45/79] fix: Use sccache for BaSpaCho cmake build, fix python path in CI Three fixes for nsys profiling workflow: - Add CMAKE_C/CXX/CUDA_COMPILER_LAUNCHER=sccache to cache the 344-object BaSpaCho+SuiteSparse cmake build across runs - Fix 'python: command not found' by using 'uv run python' for diagnostics - Reduce verbosity from -vvv to -v (less noise in logs) Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 49dbfe99..cd4703f1 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -143,17 +143,22 @@ jobs: # --reinstall forces rebuild even if same version is cached from PyPI. # --no-cache avoids reusing a wheel built without SPINEAX_USE_BASPACHO. # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen). - uv pip install --reinstall --no-cache -vvv \ + # Use sccache to cache C++/CUDA compilation (BaSpaCho + SuiteSparse = 344 objects) + export CMAKE_C_COMPILER_LAUNCHER=sccache + export CMAKE_CXX_COMPILER_LAUNCHER=sccache + export CMAKE_CUDA_COMPILER_LAUNCHER=sccache + + uv pip install --reinstall --no-cache -v \ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ -C cmake.define.SPINEAX_USE_BASPACHO=ON # Verify the module was installed echo "=== Installed spineax files ===" - python -c "import spineax; import pathlib; p = pathlib.Path(spineax.__file__).parent; [print(f' {f.name}') for f in sorted(p.iterdir())]" + uv run python -c "import spineax; import pathlib; p = pathlib.Path(spineax.__file__).parent; [print(f' {f.name}') for f in sorted(p.iterdir())]" echo "=== spineax directory listing ===" ls -lR .venv/lib/python*/site-packages/spineax/ || true echo "=== Checking baspacho_dense_solve import ===" - python -c "from spineax import baspacho_dense_solve; print(' OK:', baspacho_dense_solve)" || echo " FAILED" + uv run python -c "from spineax import baspacho_dense_solve; print(' OK:', baspacho_dense_solve)" || echo " FAILED" - name: Run nsys profiling env: From 3f31bf5cbf0efdb3abad63fdae2f5a9a8d017b2b Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 17:54:42 +0000 Subject: [PATCH 46/79] chore: Use cache-apt-pkgs-action for CUDA toolkit caching Replace manual apt deb caching with robtaylor/cache-apt-pkgs-action which handles caching installed packages end-to-end. On cache hit, packages are restored without running apt-get at all. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index cd4703f1..78cd150b 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -82,30 +82,17 @@ jobs: with: workspaces: openvaf_jax/openvaf_py - - name: Cache apt packages - uses: actions/cache@v4 - id: apt-cache + - name: Install CUDA toolkit and system dependencies + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources with: - path: ~/apt-cache - key: apt-cuda-12.6-${{ runner.os }}-v1 - - - name: Install CUDA toolkit and nsys - run: | - # Restore cached debs if available - if [ -d ~/apt-cache ] && [ "$(ls ~/apt-cache/*.deb 2>/dev/null | wc -l)" -gt 0 ]; then - echo "=== Restoring $(ls ~/apt-cache/*.deb | wc -l) cached debs ===" - sudo cp ~/apt-cache/*.deb /var/cache/apt/archives/ 2>/dev/null || true - fi - - sudo apt-get update - sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 \ + packages: >- + cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 + cuda-nsight-systems-12-6 libsuitesparse-dev libopenblas-dev swig cmake pkg-config + version: 1.0 - # Cache downloaded debs for next run - mkdir -p ~/apt-cache - cp /var/cache/apt/archives/*.deb ~/apt-cache/ 2>/dev/null || true - echo "=== Cached $(ls ~/apt-cache/*.deb 2>/dev/null | wc -l) debs ($(du -sh ~/apt-cache 2>/dev/null | cut -f1)) ===" - + - name: Set up CUDA environment + run: | echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then From 08e98c7c13315866cfc2aaffdfd396200c6ef9c3 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 17:55:49 +0000 Subject: [PATCH 47/79] chore: Consolidate all apt packages into single cache-apt-pkgs-action call Merge LLVM 18 and CUDA toolkit installs into one cached action call. Uses apt-sources parameter to add the LLVM apt repo. On cache hit, all packages restore without running apt-get at all. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 47 ++++++++++++------------------ 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 78cd150b..0a960d0b 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -57,55 +57,46 @@ jobs: echo "=== nsys Version ===" nsys --version 2>/dev/null || echo "nsys not in PATH" - - name: Install LLVM 18 - run: | - if ! llvm-config-18 --version 2>/dev/null; then - wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null - sudo chmod a+r /etc/apt/trusted.gpg.d/apt.llvm.org.asc - wget -q https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - rm llvm.sh - fi - echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - - - name: Set up Rust - uses: dtolnay/rust-toolchain@stable - - - name: Set up sccache - uses: mozilla-actions/sccache-action@v0.0.7 - with: - disable_annotations: true - - - name: Cache Rust dependencies - uses: Swatinem/rust-cache@v2 - with: - workspaces: openvaf_jax/openvaf_py - - - name: Install CUDA toolkit and system dependencies + - name: Install all apt dependencies (LLVM, CUDA, system libs) uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources with: packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 libsuitesparse-dev libopenblas-dev swig cmake pkg-config + apt-sources: >- + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main version: 1.0 - - name: Set up CUDA environment + - name: Set up LLVM and CUDA environment run: | + echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then CUDSS_DIR=$(dirname "$CUDSS_LIB") echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" fi - # nsys installs to /opt/nvidia/nsight-systems/*/target-linux-x64/ NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) if [ -n "$NSYS_BIN" ]; then echo "$NSYS_BIN" >> $GITHUB_PATH fi nsys --version + - name: Set up Rust + uses: dtolnay/rust-toolchain@stable + + - name: Set up sccache + uses: mozilla-actions/sccache-action@v0.0.7 + with: + disable_annotations: true + + - name: Cache Rust dependencies + uses: Swatinem/rust-cache@v2 + with: + workspaces: openvaf_jax/openvaf_py + - name: Install uv uses: astral-sh/setup-uv@v6 with: From 2749ae922ff677196a1d6e9b7ddecdb7958d8e0a Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:10:38 +0000 Subject: [PATCH 48/79] feat: Add runner selection to nsys profiling workflow Adds nvidia-runner-2 as an option, with per-runner concurrency groups so profiling jobs on different runners don't cancel each other. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 0a960d0b..49aa62c6 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -24,10 +24,17 @@ on: description: 'Build spineax from source with BaSpaCho dense solver' type: boolean default: false + runner: + description: 'GPU runner to use' + type: choice + options: + - nvidia-runner-1 + - nvidia-runner-2 + default: nvidia-runner-1 -# Only one profiling job at a time on the GPU runner +# Only one profiling job at a time per runner concurrency: - group: nsys-profile + group: nsys-profile-${{ inputs.runner }} cancel-in-progress: true env: @@ -37,7 +44,7 @@ env: jobs: nsys-profile: - runs-on: nvidia-runner-1 + runs-on: ${{ inputs.runner }} timeout-minutes: 60 steps: From 769de4377e7c7678f1f04acad2a7ac3b9c9500d4 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:12:22 +0000 Subject: [PATCH 49/79] chore: Add NVIDIA CUDA repo to apt-sources for cache-apt-pkgs-action Without this, CUDA packages (cuda-toolkit-12-6, libcudnn9, libcudss0, nsight-systems) can't be resolved by the caching action on clean runners. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 49aa62c6..8cfd4651 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -74,6 +74,7 @@ jobs: libsuitesparse-dev libopenblas-dev swig cmake pkg-config apt-sources: >- https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub | deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/ / version: 1.0 - name: Set up LLVM and CUDA environment From 495c6d1e6922c7fb2e9bd3d7a3c23c1c72dfe9d3 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:28:09 +0000 Subject: [PATCH 50/79] fix: Use YAML literal block for apt-sources, make env setup resilient - apt-sources needs `|` (not `>-`) to preserve newlines between source entries - dpkg -L and nsys --version now tolerate missing packages gracefully Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 8cfd4651..41cc76ae 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -72,7 +72,7 @@ jobs: cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 libsuitesparse-dev libopenblas-dev swig cmake pkg-config - apt-sources: >- + apt-sources: | https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub | deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/ / version: 1.0 @@ -81,7 +81,7 @@ jobs: run: | echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH - CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1) + CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then CUDSS_DIR=$(dirname "$CUDSS_LIB") echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" @@ -90,7 +90,7 @@ jobs: if [ -n "$NSYS_BIN" ]; then echo "$NSYS_BIN" >> $GITHUB_PATH fi - nsys --version + nsys --version || echo "::warning::nsys not found in PATH" - name: Set up Rust uses: dtolnay/rust-toolchain@stable From 3f4d08ee128cb88d90df561ef96f6cae64789d6b Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:32:05 +0000 Subject: [PATCH 51/79] fix: Pass CUDAToolkit_ROOT to cmake for BaSpaCho build scikit-build-core runs cmake in an isolated environment that doesn't inherit GITHUB_PATH additions. Explicitly set CUDAToolkit_ROOT so cmake can find nvcc on runners where CUDA was installed via apt. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 41cc76ae..fb10286a 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -136,7 +136,8 @@ jobs: uv pip install --reinstall --no-cache -v \ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ - -C cmake.define.SPINEAX_USE_BASPACHO=ON + -C cmake.define.SPINEAX_USE_BASPACHO=ON \ + -C cmake.define.CUDAToolkit_ROOT=/usr/local/cuda-12.6 # Verify the module was installed echo "=== Installed spineax files ===" From 6646a363a9a3ce0497de71273e06bac66da6cde7 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:36:39 +0000 Subject: [PATCH 52/79] fix: Enable install scripts for CUDA apt packages, dynamic CUDA root CUDA toolkit apt packages use post-install scripts to create /usr/local/cuda-* directories. Without execute_install_scripts=true, the cached dpkg files don't set up the toolkit properly. Also dynamically discover CUDAToolkit_ROOT instead of hardcoding the path, and export it as env var for cmake. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index fb10286a..1893f710 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -75,12 +75,26 @@ jobs: apt-sources: | https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub | deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/ / + execute_install_scripts: true version: 1.0 - name: Set up LLVM and CUDA environment run: | echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH + + # Find CUDA toolkit (may be at /usr/local/cuda-12.6 or elsewhere) + CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 2>/dev/null | sort -V | tail -1) + if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then + CUDA_ROOT="/usr/local/cuda" + fi + if [ -n "$CUDA_ROOT" ]; then + echo "${CUDA_ROOT}/bin" >> $GITHUB_PATH + echo "CUDAToolkit_ROOT=${CUDA_ROOT}" >> $GITHUB_ENV + echo "CUDA found at: ${CUDA_ROOT}" + else + echo "::warning::CUDA toolkit not found under /usr/local" + fi + CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then CUDSS_DIR=$(dirname "$CUDSS_LIB") @@ -137,7 +151,7 @@ jobs: uv pip install --reinstall --no-cache -v \ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ -C cmake.define.SPINEAX_USE_BASPACHO=ON \ - -C cmake.define.CUDAToolkit_ROOT=/usr/local/cuda-12.6 + -C cmake.define.CUDAToolkit_ROOT=${CUDAToolkit_ROOT:-/usr/local/cuda-12.6} # Verify the module was installed echo "=== Installed spineax files ===" From 0c21c0d11c36722c0c3851171de22fba73c6f816 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:39:29 +0000 Subject: [PATCH 53/79] fix: Bump apt cache version to invalidate stale cache Previous cache was created without execute_install_scripts, so CUDA toolkit post-install scripts weren't captured. Bumping version forces a fresh install with scripts enabled. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 1893f710..9e6b8097 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -76,7 +76,7 @@ jobs: https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub | deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/ / execute_install_scripts: true - version: 1.0 + version: 1.1 - name: Set up LLVM and CUDA environment run: | From 874ea2700610b899eb42cd35e6ac684d02880d9f Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:42:05 +0000 Subject: [PATCH 54/79] chore: Add dpkg and CUDA file diagnostics to env setup Helps debug whether cache-apt-pkgs-action properly restores CUDA toolkit files and post-install script results. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 9e6b8097..4ab754e8 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -82,6 +82,17 @@ jobs: run: | echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + # Diagnostic: check dpkg state and file existence + echo "=== dpkg status for key packages ===" + dpkg -l cuda-toolkit-12-6 cuda-nsight-systems-12-6 libcudss0-cuda-12 2>&1 | tail -5 || true + echo "" + echo "=== /usr/local/cuda* directories ===" + ls -la /usr/local/cuda* 2>/dev/null || echo " (none)" + echo "" + echo "=== nvcc locations ===" + find / -name nvcc -type f 2>/dev/null | head -5 || echo " (none)" + echo "" + # Find CUDA toolkit (may be at /usr/local/cuda-12.6 or elsewhere) CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 2>/dev/null | sort -V | tail -1) if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then From 090617ffb0697a14a41c8c946555faf552ace2e2 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 18:53:23 +0000 Subject: [PATCH 55/79] fix: Include runner name in apt cache key Prevents cache cross-contamination between runners with different pre-installed packages (runner-1 has CUDA toolkit, runner-2 doesn't). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 4ab754e8..33a4e2c0 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -76,7 +76,7 @@ jobs: https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub | deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/ / execute_install_scripts: true - version: 1.1 + version: 1.1-${{ inputs.runner }} - name: Set up LLVM and CUDA environment run: | From 205bffaa788d5bccae6bf4161daba93a390544b6 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 19:14:37 +0000 Subject: [PATCH 56/79] fix: Remove duplicate NVIDIA CUDA apt source causing Signed-By conflict The NVIDIA runner image already has the CUDA repo configured with its own GPG key (cuda-archive-keyring.gpg). Adding the same repo via apt-sources with a different key file causes APT to fail with "Conflicting values set for option Signed-By", preventing all package installation. Only add the LLVM repo via apt-sources. Bump cache version to invalidate the stale empty cache. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 33a4e2c0..1ae7a0a9 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -72,11 +72,12 @@ jobs: cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 libsuitesparse-dev libopenblas-dev swig cmake pkg-config + # NVIDIA CUDA repo is already on the runner image (cuda-archive-keyring.gpg). + # Adding it again via apt-sources causes "Conflicting Signed-By" error. apt-sources: | https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main - https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/3bf863cc.pub | deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/ / execute_install_scripts: true - version: 1.1-${{ inputs.runner }} + version: 1.2-${{ inputs.runner }} - name: Set up LLVM and CUDA environment run: | From 68d59d4352327e720ff53f7409ae69182aeb1990 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 19:25:33 +0000 Subject: [PATCH 57/79] ci: Add python-version to setup-uv for correct cache keys MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The setup-uv action uses the system Python version (3.12.3) in the cache key when python-version is not specified. Since workflows install a different Python version afterwards (3.10-3.13), the cache was keyed incorrectly — storing cp312 wheels but needing cp31x wheels. GitHub Actions caches are immutable per key, so the stale ~2MB cache could never be updated, causing every run to re-download all packages. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 1 + .github/workflows/lint.yml | 1 + .github/workflows/profile-nsys.yml | 1 + .github/workflows/test-pdk.yml | 1 + .github/workflows/test.yml | 3 +++ 5 files changed, 7 insertions(+) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index b319cdae..ac6dd7cc 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -154,6 +154,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + python-version: "3.12" - name: Set up Python run: uv python install 3.12 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 54cb8165..af776291 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,6 +21,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + python-version: "3.11" - name: Set up Python run: uv python install 3.11 diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 1ae7a0a9..9c4cc62f 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -135,6 +135,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + python-version: "3.13" - name: Set up Python run: uv python install 3.13 diff --git a/.github/workflows/test-pdk.yml b/.github/workflows/test-pdk.yml index 451d3615..19feacc6 100644 --- a/.github/workflows/test-pdk.yml +++ b/.github/workflows/test-pdk.yml @@ -72,6 +72,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + python-version: "3.10" - name: Set up Python run: uv python install 3.10 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1e13b602..aa2caa05 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -51,6 +51,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + python-version: "3.11" - name: Set up Python run: uv python install 3.11 @@ -138,6 +139,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + python-version: "3.11" - name: Set up Python run: uv python install 3.11 @@ -230,6 +232,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + python-version: "3.11" - name: Set up Python run: uv python install 3.11 From 5b91c1c8071c74867ebfd382935afb7f14aaf549 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 19:29:04 +0000 Subject: [PATCH 58/79] ci: Use cache-apt-pkgs-action for all apt installs Replace ad-hoc apt-get install and llvm.sh script calls with robtaylor/cache-apt-pkgs-action@feat/apt-sources across all workflows. This caches installed packages between runs, avoiding repeated ~30s apt-get update + install cycles. Changes per workflow: - test.yml: Merge LLVM + system deps into single cached step - benchmark-comparison.yml: Replace 4 apt steps (cache, CPU deps, CUDA deps, LLVM) with 2 conditional cached steps (CPU vs CUDA) - test-pdk.yml: Replace llvm.sh with cached LLVM packages - profile-nsys.yml: Remove stale version parameter Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 68 ++++++++++------------ .github/workflows/profile-nsys.yml | 1 - .github/workflows/test-pdk.yml | 16 +++-- .github/workflows/test.yml | 40 +++++++------ 4 files changed, 64 insertions(+), 61 deletions(-) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index ac6dd7cc..8c96312f 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -56,45 +56,34 @@ jobs: nvcc --version 2>/dev/null || echo "nvcc not in PATH" # ── System dependencies ────────────────────────────────────── - - name: Cache apt packages - uses: actions/cache@v4 - with: - path: /var/cache/apt/archives - key: apt-${{ runner.os }}-${{ runner.arch }}-${{ matrix.platform }}-v1 - restore-keys: | - apt-${{ runner.os }}-${{ runner.arch }}-${{ matrix.platform }}- - - - name: Install system dependencies (CPU) + - name: Install apt dependencies (CPU) if: matrix.platform == 'cpu' - run: | - sudo apt-get update - sudo apt-get install -y \ - cmake ninja-build \ - flex bison libfl-dev \ - libsuitesparse-dev libopenblas-dev \ + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 + cmake ninja-build flex bison libfl-dev + libsuitesparse-dev libopenblas-dev ccache bc + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true - - name: Install system dependencies (CUDA) + - name: Install apt dependencies (CUDA) if: matrix.platform == 'cuda' - run: | - sudo apt-get update - sudo apt-get install -y \ - cmake pkg-config swig \ + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 + cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 + cmake pkg-config swig libsuitesparse-dev libopenblas-dev + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true - # ── LLVM 18 (idempotent, works on all runners) ────────────── - - name: Install LLVM 18 + - name: Set LLVM and CUDA environment run: | - if ! llvm-config-18 --version 2>/dev/null; then - wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | \ - sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null - sudo chmod a+r /etc/apt/trusted.gpg.d/apt.llvm.org.asc - wget -q https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - rm llvm.sh - fi - sudo apt-get install -y lld-18 echo "/usr/lib/llvm-18/bin" >> "$GITHUB_PATH" echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> "$GITHUB_ENV" @@ -136,13 +125,18 @@ jobs: openvaf_jax/openvaf_py vendor/OpenVAF - # ── CUDA toolkit ───────────────────────────────────────────── - - name: Install CUDA toolkit + # ── CUDA environment ───────────────────────────────────────── + - name: Set CUDA environment if: matrix.platform == 'cuda' run: | - sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 - echo "/usr/local/cuda-12.6/bin" >> "$GITHUB_PATH" - CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1) + CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 2>/dev/null | sort -V | tail -1) + if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then + CUDA_ROOT="/usr/local/cuda" + fi + if [ -n "$CUDA_ROOT" ]; then + echo "${CUDA_ROOT}/bin" >> "$GITHUB_PATH" + fi + CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then CUDSS_DIR=$(dirname "$CUDSS_LIB") echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 9c4cc62f..36595a85 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -77,7 +77,6 @@ jobs: apt-sources: | https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main execute_install_scripts: true - version: 1.2-${{ inputs.runner }} - name: Set up LLVM and CUDA environment run: | diff --git a/.github/workflows/test-pdk.yml b/.github/workflows/test-pdk.yml index 19feacc6..b90a4f85 100644 --- a/.github/workflows/test-pdk.yml +++ b/.github/workflows/test-pdk.yml @@ -48,12 +48,16 @@ jobs: # Mask PDK path in all subsequent log output echo "::add-mask::/tmp/pdk-gf130" - - name: Install LLVM 18 - run: | - wget -q https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + - name: Install apt dependencies + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: llvm-18-dev clang-18 libclang-18-dev lld-18 + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set LLVM environment + run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - name: Set up Rust uses: dtolnay/rust-toolchain@stable diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aa2caa05..14c56026 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,12 +27,16 @@ jobs: with: submodules: recursive - - name: Install LLVM 18 - run: | - wget https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + - name: Install apt dependencies + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: llvm-18-dev clang-18 libclang-18-dev lld-18 + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set LLVM environment + run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - name: Set up Rust uses: dtolnay/rust-toolchain@stable @@ -110,12 +114,19 @@ jobs: with: submodules: recursive - - name: Install LLVM 18 - run: | - wget https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + - name: Install apt dependencies + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 + libsuitesparse-dev libopenblas-dev swig cmake + ${{ matrix.test-group.extra_packages }} + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set LLVM environment + run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - name: Set up Rust uses: dtolnay/rust-toolchain@stable @@ -130,11 +141,6 @@ jobs: with: workspaces: openvaf_jax/openvaf_py - - name: Install system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libsuitesparse-dev libopenblas-dev swig cmake ${{ matrix.test-group.extra_packages }} - - name: Install uv uses: astral-sh/setup-uv@v6 with: From 5cbb0734e3f13e3f1a406b584e81063cd6eea5bb Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 21:55:09 +0000 Subject: [PATCH 59/79] fix: Run ldconfig after apt cache restore to fix OpenBLAS symlinks The cache-apt-pkgs-action may restore .so files without proper symlinks, causing CMake's FindBLAS to fail with "Could NOT find BLAS". Running ldconfig after restore regenerates the symlinks. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 36595a85..7e9f9715 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -82,6 +82,15 @@ jobs: run: | echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + # Ensure ldconfig and BLAS symlinks are correct after apt cache restore. + # The cache-apt-pkgs-action may restore .so files without the expected + # symlinks (e.g. libopenblas.so -> libopenblas.so.0), causing CMake's + # FindBLAS to fail. Running ldconfig regenerates the symlinks. + sudo ldconfig + echo "=== OpenBLAS library check ===" + ls -la /usr/lib/x86_64-linux-gnu/libopenblas* 2>/dev/null || echo " (not found)" + ldconfig -p | grep openblas || echo " (not in ldconfig)" + # Diagnostic: check dpkg state and file existence echo "=== dpkg status for key packages ===" dpkg -l cuda-toolkit-12-6 cuda-nsight-systems-12-6 libcudss0-cuda-12 2>&1 | tail -5 || true From 28dc62364b969b5e7d171ac9c8172e53ca6d24cf Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 22:02:57 +0000 Subject: [PATCH 60/79] fix: Reinstall libopenblas-dev if .so missing after apt cache restore The cache-apt-pkgs-action registers dpkg metadata but may not restore actual library files to the filesystem. Detect and reinstall openblas if the shared library is missing. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 7e9f9715..0a737a4d 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -82,14 +82,13 @@ jobs: run: | echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - # Ensure ldconfig and BLAS symlinks are correct after apt cache restore. - # The cache-apt-pkgs-action may restore .so files without the expected - # symlinks (e.g. libopenblas.so -> libopenblas.so.0), causing CMake's - # FindBLAS to fail. Running ldconfig regenerates the symlinks. + # Workaround: cache-apt-pkgs-action may register dpkg metadata without + # actually restoring library files. Reinstall openblas if the .so is missing. + if ! ls /usr/lib/x86_64-linux-gnu/libopenblas* >/dev/null 2>&1; then + echo "OpenBLAS .so missing after cache restore, reinstalling..." + sudo apt-get install -y --reinstall libopenblas-dev + fi sudo ldconfig - echo "=== OpenBLAS library check ===" - ls -la /usr/lib/x86_64-linux-gnu/libopenblas* 2>/dev/null || echo " (not found)" - ldconfig -p | grep openblas || echo " (not in ldconfig)" # Diagnostic: check dpkg state and file existence echo "=== dpkg status for key packages ===" From 19e463310f27d127a4d0a426d6476f27c4b090f9 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 22:06:33 +0000 Subject: [PATCH 61/79] fix: Run apt --fix-broken install before reinstalling openblas The cache-apt-pkgs-action may create version conflicts between cached and pre-installed packages (e.g., gcc-14-base version mismatch). Run apt --fix-broken install first, then reinstall both openblas and suitesparse to ensure library files are present. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 0a737a4d..f289ad55 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -83,10 +83,12 @@ jobs: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV # Workaround: cache-apt-pkgs-action may register dpkg metadata without - # actually restoring library files. Reinstall openblas if the .so is missing. + # actually restoring library files, and may create version conflicts + # between cached and pre-installed packages. Fix both issues. if ! ls /usr/lib/x86_64-linux-gnu/libopenblas* >/dev/null 2>&1; then echo "OpenBLAS .so missing after cache restore, reinstalling..." - sudo apt-get install -y --reinstall libopenblas-dev + sudo apt-get --fix-broken install -y || true + sudo apt-get install -y --reinstall libopenblas-dev libsuitesparse-dev fi sudo ldconfig From 92060b69d63f423442c80aeb5ee146ea23ad8d7b Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 22:10:22 +0000 Subject: [PATCH 62/79] fix: Force dpkg overwrite to fix gcc-14-base version conflict from cache The apt cache action may register newer gcc-14-base dpkg metadata than what's installed on the runner, creating unresolvable dependency conflicts. Use --force-overwrite to install the correct versions. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index f289ad55..2669a0dd 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -84,11 +84,14 @@ jobs: # Workaround: cache-apt-pkgs-action may register dpkg metadata without # actually restoring library files, and may create version conflicts - # between cached and pre-installed packages. Fix both issues. + # between cached and pre-installed packages. Fix by forcing dpkg to + # overwrite the conflicting gcc-14-base version, then reinstall. if ! ls /usr/lib/x86_64-linux-gnu/libopenblas* >/dev/null 2>&1; then - echo "OpenBLAS .so missing after cache restore, reinstalling..." - sudo apt-get --fix-broken install -y || true - sudo apt-get install -y --reinstall libopenblas-dev libsuitesparse-dev + echo "OpenBLAS .so missing after cache restore, fixing dpkg state..." + sudo dpkg --configure -a || true + sudo apt-get update -qq + sudo apt-get install -y -o Dpkg::Options::="--force-overwrite" \ + gcc-14-base libgfortran5 libopenblas-dev libsuitesparse-dev fi sudo ldconfig From 11a1211e9556eee3a14610a5a32e3551987c16a9 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 22:14:59 +0000 Subject: [PATCH 63/79] fix: Remove apt workaround, deleted corrupted cache instead The cache-apt-pkgs-action cache was corrupted (gcc-14-base version mismatch with runner). Deleted the cache entry and removed the workaround. Fresh install will build a clean cache. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 2669a0dd..5356b174 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -82,17 +82,6 @@ jobs: run: | echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - # Workaround: cache-apt-pkgs-action may register dpkg metadata without - # actually restoring library files, and may create version conflicts - # between cached and pre-installed packages. Fix by forcing dpkg to - # overwrite the conflicting gcc-14-base version, then reinstall. - if ! ls /usr/lib/x86_64-linux-gnu/libopenblas* >/dev/null 2>&1; then - echo "OpenBLAS .so missing after cache restore, fixing dpkg state..." - sudo dpkg --configure -a || true - sudo apt-get update -qq - sudo apt-get install -y -o Dpkg::Options::="--force-overwrite" \ - gcc-14-base libgfortran5 libopenblas-dev libsuitesparse-dev - fi sudo ldconfig # Diagnostic: check dpkg state and file existence From 346dfd195a261d11a6bce46c9b5769abaf14fe7e Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 23:45:48 +0000 Subject: [PATCH 64/79] ci: Replace cuda-toolkit-12-6 with minimal CUDA packages cuda-toolkit-12-6 pulls in 216+ packages (cufft, curand, npp, nvjpeg, opencl, documentation, visual tools, etc.) that we don't need. Replace with the 7 specific packages required for building spineax/BaSpaCho: cuda-nvcc-12-6, cuda-cudart-dev-12-6, cuda-driver-dev-12-6, libcublas-dev-12-6, libcusolver-dev-12-6, libcusparse-dev-12-6, libnvjitlink-dev-12-6 Also removes stale dpkg diagnostics and ldconfig workaround from profile-nsys.yml (the corrupted apt cache was already deleted). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 5 ++++- .github/workflows/profile-nsys.yml | 18 ++++-------------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 8c96312f..3db1d5e0 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -75,7 +75,10 @@ jobs: with: packages: >- llvm-18-dev clang-18 libclang-18-dev lld-18 - cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 + cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 + libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 + libnvjitlink-dev-12-6 + libcudnn9-cuda-12 libcudss0-cuda-12 cmake pkg-config swig libsuitesparse-dev libopenblas-dev apt-sources: | diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 5356b174..294d4097 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -69,7 +69,10 @@ jobs: with: packages: >- llvm-18-dev clang-18 libclang-18-dev lld-18 - cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 + cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 + libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 + libnvjitlink-dev-12-6 + libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 libsuitesparse-dev libopenblas-dev swig cmake pkg-config # NVIDIA CUDA repo is already on the runner image (cuda-archive-keyring.gpg). @@ -82,19 +85,6 @@ jobs: run: | echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - sudo ldconfig - - # Diagnostic: check dpkg state and file existence - echo "=== dpkg status for key packages ===" - dpkg -l cuda-toolkit-12-6 cuda-nsight-systems-12-6 libcudss0-cuda-12 2>&1 | tail -5 || true - echo "" - echo "=== /usr/local/cuda* directories ===" - ls -la /usr/local/cuda* 2>/dev/null || echo " (none)" - echo "" - echo "=== nvcc locations ===" - find / -name nvcc -type f 2>/dev/null | head -5 || echo " (none)" - echo "" - # Find CUDA toolkit (may be at /usr/local/cuda-12.6 or elsewhere) CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 2>/dev/null | sort -V | tail -1) if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then From 2bc8f4e2dd1768130a8d9d50ed09b4d2d08ad6a4 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Tue, 10 Mar 2026 23:57:35 +0000 Subject: [PATCH 65/79] ci: Add libcufft-12-6 runtime library for JAX CUDA plugin JAX's xla_cuda12 plugin checks cuFFT version during initialization even if the application doesn't use FFTs. Without libcufft.so, the CUDA backend fails to initialize and falls back to CPU. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 2 +- .github/workflows/profile-nsys.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 3db1d5e0..79597791 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -77,7 +77,7 @@ jobs: llvm-18-dev clang-18 libclang-18-dev lld-18 cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 - libnvjitlink-dev-12-6 + libnvjitlink-dev-12-6 libcufft-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cmake pkg-config swig libsuitesparse-dev libopenblas-dev diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 294d4097..a93c44ac 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -71,7 +71,7 @@ jobs: llvm-18-dev clang-18 libclang-18-dev lld-18 cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 - libnvjitlink-dev-12-6 + libnvjitlink-dev-12-6 libcufft-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 libsuitesparse-dev libopenblas-dev swig cmake pkg-config From c13e4a51a2f31f3951e84c6ef2069cfea50edff4 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 00:01:27 +0000 Subject: [PATCH 66/79] ci: Disable uv cache pruning to preserve downloaded packages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit setup-uv's prune-cache (default: true) runs `uv cache prune` before saving, which was stripping nearly all cached content — leaving only ~2-5 MB of HTTP metadata instead of the ~200+ MB of downloaded wheels. Combined with GitHub Actions' immutable cache keys, this meant every run re-downloaded jaxlib (78 MB), scipy (33 MB), numpy (15 MB), etc. Also deleted the existing stale caches via `gh cache delete` so new properly-populated caches can be saved on the next run. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 1 + .github/workflows/lint.yml | 1 + .github/workflows/profile-nsys.yml | 1 + .github/workflows/test-pdk.yml | 1 + .github/workflows/test.yml | 3 +++ .github/workflows/xla-flag-sweep.yml | 1 + 6 files changed, 8 insertions(+) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 79597791..159c0495 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -151,6 +151,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false python-version: "3.12" - name: Set up Python diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index af776291..4cc68339 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,6 +21,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false python-version: "3.11" - name: Set up Python diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index a93c44ac..fab55686 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -126,6 +126,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false python-version: "3.13" - name: Set up Python diff --git a/.github/workflows/test-pdk.yml b/.github/workflows/test-pdk.yml index b90a4f85..34b4ab54 100644 --- a/.github/workflows/test-pdk.yml +++ b/.github/workflows/test-pdk.yml @@ -76,6 +76,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false python-version: "3.10" - name: Set up Python diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 14c56026..5dcbb739 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -55,6 +55,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false python-version: "3.11" - name: Set up Python @@ -145,6 +146,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false python-version: "3.11" - name: Set up Python @@ -238,6 +240,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false python-version: "3.11" - name: Set up Python diff --git a/.github/workflows/xla-flag-sweep.yml b/.github/workflows/xla-flag-sweep.yml index 7e506468..eb61cc1b 100644 --- a/.github/workflows/xla-flag-sweep.yml +++ b/.github/workflows/xla-flag-sweep.yml @@ -45,6 +45,7 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false cache-dependency-glob: "uv.lock" - name: Install dependencies From e23011347f19708c75674d318e476c3c2bc1930b Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 00:07:14 +0000 Subject: [PATCH 67/79] ci: Add cuda-libraries-12-6 and cuda-cupti-12-6 for JAX runtime JAX's xla_cuda12 plugin checks cuFFT, cuPTI, and other CUDA library versions during initialization. Use cuda-libraries-12-6 meta-package (13 runtime libs) instead of cherry-picking individual ones. Also add cuda-cupti-12-6 which JAX requires but isn't in cuda-libraries. Still much smaller than cuda-toolkit-12-6 (216+ packages including docs, visual tools, nvprof, etc). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 3 ++- .github/workflows/profile-nsys.yml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 159c0495..ae9f58f4 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -77,7 +77,8 @@ jobs: llvm-18-dev clang-18 libclang-18-dev lld-18 cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 - libnvjitlink-dev-12-6 libcufft-12-6 + libnvjitlink-dev-12-6 + cuda-libraries-12-6 cuda-cupti-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cmake pkg-config swig libsuitesparse-dev libopenblas-dev diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index fab55686..5299fc54 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -71,7 +71,8 @@ jobs: llvm-18-dev clang-18 libclang-18-dev lld-18 cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 - libnvjitlink-dev-12-6 libcufft-12-6 + libnvjitlink-dev-12-6 + cuda-libraries-12-6 cuda-cupti-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 cuda-nsight-systems-12-6 libsuitesparse-dev libopenblas-dev swig cmake pkg-config From 5ac968619a459a3b916ab3312363116a429b6056 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 00:09:26 +0000 Subject: [PATCH 68/79] ci: Pass sccache to spineax build via cmake.define flags Environment variables don't propagate through uv pip install's subprocess. Pass CMAKE_*_COMPILER_LAUNCHER via -C cmake.define so scikit-build-core forwards them to CMake. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 5299fc54..4a7c5b13 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -149,15 +149,14 @@ jobs: # --reinstall forces rebuild even if same version is cached from PyPI. # --no-cache avoids reusing a wheel built without SPINEAX_USE_BASPACHO. # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen). - # Use sccache to cache C++/CUDA compilation (BaSpaCho + SuiteSparse = 344 objects) - export CMAKE_C_COMPILER_LAUNCHER=sccache - export CMAKE_CXX_COMPILER_LAUNCHER=sccache - export CMAKE_CUDA_COMPILER_LAUNCHER=sccache - + # sccache caches C++/CUDA compilation (BaSpaCho + SuiteSparse = 344 objects) uv pip install --reinstall --no-cache -v \ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ -C cmake.define.SPINEAX_USE_BASPACHO=ON \ - -C cmake.define.CUDAToolkit_ROOT=${CUDAToolkit_ROOT:-/usr/local/cuda-12.6} + -C cmake.define.CUDAToolkit_ROOT=${CUDAToolkit_ROOT:-/usr/local/cuda-12.6} \ + -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=sccache \ + -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=sccache \ + -C cmake.define.CMAKE_CUDA_COMPILER_LAUNCHER=sccache # Verify the module was installed echo "=== Installed spineax files ===" From 4136e73a6c84e2c28ff66cfc2146b7909aba4915 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 00:13:04 +0000 Subject: [PATCH 69/79] ci: Add quarterly uv cache cleanup workflow Scheduled to run on the 1st of every 3rd month to prevent unbounded cache growth. Also supports manual dispatch for ad-hoc cleanup. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/cache-cleanup.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/cache-cleanup.yml diff --git a/.github/workflows/cache-cleanup.yml b/.github/workflows/cache-cleanup.yml new file mode 100644 index 00000000..a7ff2f52 --- /dev/null +++ b/.github/workflows/cache-cleanup.yml @@ -0,0 +1,23 @@ +name: Cache Cleanup + +on: + schedule: + - cron: '0 0 1 */3 *' # First day of every 3rd month + workflow_dispatch: + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Delete uv caches + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "=== Deleting uv caches ===" + gh cache list --repo ${{ github.repository }} --key setup-uv --json id,key,sizeInBytes \ + --jq '.[] | "\(.id)\t\(.key)\t\(.sizeInBytes)"' | \ + while IFS=$'\t' read -r id key size; do + echo "Deleting: $key ($(numfmt --to=iec $size))" + gh cache delete "$id" --repo ${{ github.repository }} + done + echo "Done" From 369f6acd2d4d1d0979f85dca628ce23336214933 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 02:20:24 +0000 Subject: [PATCH 70/79] fix: Add CUDA lib64 to LD_LIBRARY_PATH for JAX runtime JAX's xla_cuda12 plugin dlopen's all CUDA runtime libraries (cuSPARSE, cuFFT, cuBLAS, etc.) at initialization. Without CUDA_ROOT/lib64 in LD_LIBRARY_PATH, these libraries aren't found even though the apt packages are installed. Previous runs worked because cuda-toolkit-12-6 included ldconfig snippets; our minimal package set does not. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 3 +++ .github/workflows/profile-nsys.yml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index ae9f58f4..8e024132 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -139,6 +139,9 @@ jobs: fi if [ -n "$CUDA_ROOT" ]; then echo "${CUDA_ROOT}/bin" >> "$GITHUB_PATH" + # CUDA runtime libs (cuSPARSE, cuFFT, etc.) live in lib64; + # JAX checks all of them at startup via dlopen. + echo "LD_LIBRARY_PATH=${CUDA_ROOT}/lib64:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" fi CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 4a7c5b13..c58ef24b 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -94,6 +94,9 @@ jobs: if [ -n "$CUDA_ROOT" ]; then echo "${CUDA_ROOT}/bin" >> $GITHUB_PATH echo "CUDAToolkit_ROOT=${CUDA_ROOT}" >> $GITHUB_ENV + # CUDA runtime libs (cuSPARSE, cuFFT, etc.) live in lib64; + # JAX checks all of them at startup via dlopen. + echo "LD_LIBRARY_PATH=${CUDA_ROOT}/lib64:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" echo "CUDA found at: ${CUDA_ROOT}" else echo "::warning::CUDA toolkit not found under /usr/local" From 49c0bfd75e4772fa071db0498816a1217a7bea68 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 02:33:50 +0000 Subject: [PATCH 71/79] fix: Write LD_LIBRARY_PATH as single GITHUB_ENV entry GITHUB_ENV takes the last value per key within a step. Writing LD_LIBRARY_PATH twice (once for CUDA lib64, once for cuDSS) meant the CUDA path was lost. Build the full path in a shell variable first, then write it once. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 11 +++++++---- .github/workflows/profile-nsys.yml | 12 ++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 8e024132..162dd72a 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -137,18 +137,21 @@ jobs: if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then CUDA_ROOT="/usr/local/cuda" fi + EXTRA_LD="" if [ -n "$CUDA_ROOT" ]; then echo "${CUDA_ROOT}/bin" >> "$GITHUB_PATH" - # CUDA runtime libs (cuSPARSE, cuFFT, etc.) live in lib64; - # JAX checks all of them at startup via dlopen. - echo "LD_LIBRARY_PATH=${CUDA_ROOT}/lib64:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + # CUDA lib64 has cuSPARSE, cuFFT, etc. needed by JAX at startup. + EXTRA_LD="${CUDA_ROOT}/lib64" fi CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then CUDSS_DIR=$(dirname "$CUDSS_LIB") - echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + EXTRA_LD="${CUDSS_DIR}:${EXTRA_LD}" echo "cuDSS library found at: $CUDSS_LIB" fi + if [ -n "$EXTRA_LD" ]; then + echo "LD_LIBRARY_PATH=${EXTRA_LD}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + fi # ── Python environment ────────────────────────────────────── - name: Install uv diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index c58ef24b..9c974856 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -94,18 +94,22 @@ jobs: if [ -n "$CUDA_ROOT" ]; then echo "${CUDA_ROOT}/bin" >> $GITHUB_PATH echo "CUDAToolkit_ROOT=${CUDA_ROOT}" >> $GITHUB_ENV - # CUDA runtime libs (cuSPARSE, cuFFT, etc.) live in lib64; - # JAX checks all of them at startup via dlopen. - echo "LD_LIBRARY_PATH=${CUDA_ROOT}/lib64:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + # Build LD_LIBRARY_PATH with CUDA lib64 (cuSPARSE, cuFFT, etc.) + # Must be a single write — GITHUB_ENV takes last value per key. + EXTRA_LD="${CUDA_ROOT}/lib64" echo "CUDA found at: ${CUDA_ROOT}" else echo "::warning::CUDA toolkit not found under /usr/local" + EXTRA_LD="" fi CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then CUDSS_DIR=$(dirname "$CUDSS_LIB") - echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + EXTRA_LD="${CUDSS_DIR}:${EXTRA_LD}" + fi + if [ -n "$EXTRA_LD" ]; then + echo "LD_LIBRARY_PATH=${EXTRA_LD}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" fi NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) if [ -n "$NSYS_BIN" ]; then From 5d81200ec0fb458f964d61863ebb6867232fdbac Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 02:51:09 +0000 Subject: [PATCH 72/79] debug: Add phase timing to nsys profiling target Wraps each phase (JAX init, solver imports, vajax import, circuit parse, prepare + JIT warmup, transient simulation) with timing output to identify where startup time is spent. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- scripts/nsys_profile_target.py | 91 +++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 34 deletions(-) diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py index cf5b96de..5dffaf84 100644 --- a/scripts/nsys_profile_target.py +++ b/scripts/nsys_profile_target.py @@ -14,6 +14,7 @@ import argparse import logging import sys +import time from pathlib import Path import jax @@ -23,8 +24,21 @@ # Enable INFO logging so solver selection messages are visible logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s") -# Import vajax first to auto-configure precision based on backend -from vajax.analysis import CircuitEngine + +def timed(label): + """Context manager that prints elapsed time for a phase.""" + + class Timer: + def __enter__(self): + self.t0 = time.perf_counter() + print(f"[{label}] starting...", flush=True) + return self + + def __exit__(self, *exc): + dt = time.perf_counter() - self.t0 + print(f"[{label}] done in {dt:.2f}s", flush=True) + + return Timer() def main(): @@ -56,33 +70,35 @@ def main(): ) args = parser.parse_args() - print(f"JAX backend: {jax.default_backend()}") - print(f"JAX devices: {jax.devices()}") + with timed("JAX init"): + print(f"JAX backend: {jax.default_backend()}") + print(f"JAX devices: {jax.devices()}") print(f"Circuit: {args.circuit}") print(f"Timesteps: {args.timesteps}") # Explicit solver availability check print() print("=== Solver Availability ===") - try: - from spineax.cudss.dense_baspacho_solver import is_available - - print(" BaSpaCho dense import: OK") - print(f" BaSpaCho dense available: {is_available()}") - except ImportError as e: - print(f" BaSpaCho dense import: FAILED ({e})") - try: - from spineax.cudss.solver import CuDSSSolver # noqa: F401 - - print(" cuDSS sparse import: OK") - except ImportError as e: - print(f" cuDSS sparse import: FAILED ({e})") - try: - from spineax import baspacho_dense_solve as _mod - - print(f" baspacho_dense_solve C++ module: OK ({_mod})") - except ImportError as e: - print(f" baspacho_dense_solve C++ module: FAILED ({e})") + with timed("solver imports"): + try: + from spineax.cudss.dense_baspacho_solver import is_available + + print(" BaSpaCho dense import: OK") + print(f" BaSpaCho dense available: {is_available()}") + except ImportError as e: + print(f" BaSpaCho dense import: FAILED ({e})") + try: + from spineax.cudss.solver import CuDSSSolver # noqa: F401 + + print(" cuDSS sparse import: OK") + except ImportError as e: + print(f" cuDSS sparse import: FAILED ({e})") + try: + from spineax import baspacho_dense_solve as _mod + + print(f" baspacho_dense_solve C++ module: OK ({_mod})") + except ImportError as e: + print(f" baspacho_dense_solve C++ module: FAILED ({e})") print() # Find benchmark .sim file @@ -94,10 +110,15 @@ def main(): print(f"ERROR: Benchmark file not found: {sim_path}") sys.exit(1) + # Import vajax (auto-configures precision based on backend) + with timed("vajax import"): + from vajax.analysis import CircuitEngine + # Setup circuit using CircuitEngine - print(f"Setting up circuit from {sim_path}...") - engine = CircuitEngine(sim_path) - engine.parse() + with timed("circuit parse"): + print(f"Setting up circuit from {sim_path}...") + engine = CircuitEngine(sim_path) + engine.parse() print(f"Circuit size: {engine.num_nodes} nodes, {len(engine.devices)} devices") print() @@ -108,19 +129,21 @@ def main(): print() # Prepare (includes 1-step JIT warmup) - print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...") - engine.prepare( - t_stop=args.timesteps * dt, - dt=dt, - use_sparse=args.sparse, - ) + with timed("prepare + JIT warmup"): + print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...") + engine.prepare( + t_stop=args.timesteps * dt, + dt=dt, + use_sparse=args.sparse, + ) print("Prepare complete") print() # Profiled run — nsys captures everything including warmup above, # but with 500+ steps the warmup is a small fraction of total time - print(f"Starting profiled run ({args.timesteps} timesteps)...") - result = engine.run_transient() + with timed("transient simulation"): + print(f"Starting profiled run ({args.timesteps} timesteps)...") + result = engine.run_transient() print() print(f"Completed: {result.num_steps} timesteps") From b86327916676dd0691c4aeb321a9444f2dbf1414 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 02:59:55 +0000 Subject: [PATCH 73/79] perf: Scope nsys capture to run_transient() via NVTX range Add NVTX push/pop around run_transient() and use nsys --capture-range=nvtx to exclude JIT warmup and Python startup from the profiling window. This gives clean data for just the simulation hot path. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 7 ++++--- scripts/nsys_profile_target.py | 18 ++++++++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 9c974856..26ec0f32 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -200,12 +200,13 @@ jobs: echo "Sparse: ${{ inputs.sparse }}" echo "Commit: ${{ github.sha }}" - # Full capture with enough timesteps (500+) that warmup overhead is <5% - # Note: cudaProfilerApi capture range doesn't work with JAX/XLA — XLA uses - # its own CUDA runtime internally, so ctypes cudaProfilerStart() has no effect + # Capture only the run_transient() window via NVTX range marker, + # excluding JIT warmup and Python startup overhead. # Tolerate exit code 139 (SIGSEGV during JAX/CUDA teardown) nsys profile \ --trace=cuda,nvtx,osrt \ + --capture-range=nvtx \ + --nvtx-capture="run_transient" \ --output "/tmp/${PROFILE_NAME}" \ uv run python scripts/nsys_profile_target.py \ ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} \ diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py index 5dffaf84..0cb0e337 100644 --- a/scripts/nsys_profile_target.py +++ b/scripts/nsys_profile_target.py @@ -12,6 +12,7 @@ """ import argparse +import ctypes import logging import sys import time @@ -139,12 +140,25 @@ def main(): print("Prepare complete") print() - # Profiled run — nsys captures everything including warmup above, - # but with 500+ steps the warmup is a small fraction of total time + # NVTX range for nsys --capture-range=nvtx scoping + try: + _nvtx = ctypes.CDLL("libnvToolsExt.so") + _nvtx_push = _nvtx.nvtxRangePushA + _nvtx_push.argtypes = [ctypes.c_char_p] + _nvtx_pop = _nvtx.nvtxRangePop + except OSError: + _nvtx_push = _nvtx_pop = None + + if _nvtx_push: + _nvtx_push(b"run_transient") + with timed("transient simulation"): print(f"Starting profiled run ({args.timesteps} timesteps)...") result = engine.run_transient() + if _nvtx_pop: + _nvtx_pop() + print() print(f"Completed: {result.num_steps} timesteps") print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s") From ad5d880094f98cab522e0fee33c541d29990d542 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 03:16:39 +0000 Subject: [PATCH 74/79] fix: Drop NVTX capture-range and remove --no-cache from spineax build NVTX capture-range=nvtx prevents nsys from writing the report when SIGSEGV occurs during JAX/CUDA teardown. Revert to default capture mode with NVTX markers kept for annotation/filtering in post-analysis. Also remove --no-cache from spineax pip install (unnecessary slowdown) and add flush=True to profiling prints to prevent output loss on crash. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 10 ++++------ scripts/nsys_profile_target.py | 8 ++++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 26ec0f32..7776916c 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -154,10 +154,9 @@ jobs: run: | # Replace PyPI spineax with git version that has BaSpaCho dense solver. # --reinstall forces rebuild even if same version is cached from PyPI. - # --no-cache avoids reusing a wheel built without SPINEAX_USE_BASPACHO. # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen). # sccache caches C++/CUDA compilation (BaSpaCho + SuiteSparse = 344 objects) - uv pip install --reinstall --no-cache -v \ + uv pip install --reinstall -v \ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ -C cmake.define.SPINEAX_USE_BASPACHO=ON \ -C cmake.define.CUDAToolkit_ROOT=${CUDAToolkit_ROOT:-/usr/local/cuda-12.6} \ @@ -200,13 +199,12 @@ jobs: echo "Sparse: ${{ inputs.sparse }}" echo "Commit: ${{ github.sha }}" - # Capture only the run_transient() window via NVTX range marker, - # excluding JIT warmup and Python startup overhead. + # NVTX "run_transient" marker annotates the transient window for filtering. + # Cannot use --capture-range=nvtx: SIGSEGV during JAX teardown prevents + # nsys from finalizing the report in NVTX capture mode. # Tolerate exit code 139 (SIGSEGV during JAX/CUDA teardown) nsys profile \ --trace=cuda,nvtx,osrt \ - --capture-range=nvtx \ - --nvtx-capture="run_transient" \ --output "/tmp/${PROFILE_NAME}" \ uv run python scripts/nsys_profile_target.py \ ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} \ diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py index 0cb0e337..c313fbf2 100644 --- a/scripts/nsys_profile_target.py +++ b/scripts/nsys_profile_target.py @@ -153,15 +153,15 @@ def main(): _nvtx_push(b"run_transient") with timed("transient simulation"): - print(f"Starting profiled run ({args.timesteps} timesteps)...") + print(f"Starting profiled run ({args.timesteps} timesteps)...", flush=True) result = engine.run_transient() if _nvtx_pop: _nvtx_pop() - print() - print(f"Completed: {result.num_steps} timesteps") - print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s") + print(flush=True) + print(f"Completed: {result.num_steps} timesteps", flush=True) + print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s", flush=True) if __name__ == "__main__": From 20c1430bba709d8a56321619eeebc9977885e067 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 03:20:16 +0000 Subject: [PATCH 75/79] ci: Fix BLAS detection and add compiler cache to UMFPACK build scikit-build-core bundles CMake 3.31.6 which can't find OpenBLAS via the generic libblas.so symlink (cache-apt-pkgs-action doesn't replay update-alternatives hooks). Setting BLA_VENDOR=OpenBLAS makes FindBLAS search directly for libopenblas.so. Also adds compiler cache (sccache/ccache) to the SuiteSparse FetchContent build, and aligns benchmark CUDA packages with the nsys profiling workflow. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/benchmark-comparison.yml | 12 +++++++++--- .github/workflows/test.yml | 6 +++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 162dd72a..58de58e2 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -80,8 +80,9 @@ jobs: libnvjitlink-dev-12-6 cuda-libraries-12-6 cuda-cupti-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 - cmake pkg-config swig - libsuitesparse-dev libopenblas-dev + libsuitesparse-dev libopenblas-dev swig cmake pkg-config + # NVIDIA CUDA repo is already on the runner image (cuda-archive-keyring.gpg). + # Adding it again via apt-sources causes "Conflicting Signed-By" error. apt-sources: | https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main execute_install_scripts: true @@ -201,7 +202,12 @@ jobs: working-directory: vajax/sparse run: | uv pip install scikit-build-core nanobind - uv pip install --no-build-isolation . + LAUNCHER=${{ matrix.platform == 'cuda' && 'sccache' || 'ccache' }} + uv pip install --no-build-isolation \ + -C cmake.define.BLA_VENDOR=OpenBLAS \ + -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=$LAUNCHER \ + -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=$LAUNCHER \ + . # ── Run VAJAX unit tests (CPU dense only) ──────────────────── - name: Run VAJAX tests diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5dcbb739..56e20eb1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -174,7 +174,11 @@ jobs: working-directory: vajax/sparse run: | uv pip install scikit-build-core nanobind - uv pip install --no-build-isolation . + uv pip install --no-build-isolation \ + -C cmake.define.BLA_VENDOR=OpenBLAS \ + -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=sccache \ + -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=sccache \ + . - name: Run tests (${{ matrix.test-group.name }}) timeout-minutes: ${{ matrix.test-group.timeout }} From 5a221805b3efa7bf802fc6a27ef93a828e884a61 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 04:04:32 +0000 Subject: [PATCH 76/79] debug: Add XLA VLOG flags for command buffer conversion diagnostics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enables VLOG on command_buffer_conversion_pass (level 2) and while_thunk (level 3) to identify which thunks in the NR while loop body prevent CUDA graph conditional node capture. Without graph capture, each NR iteration requires a D2H sync for the loop predicate (574µs × 82K iterations = 47s overhead). Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 7776916c..2efa290f 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -181,6 +181,8 @@ jobs: XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform TF_CPP_MIN_LOG_LEVEL: "2" + # Debug: show which thunks prevent command buffer capture for while loops + TF_CPP_VLOG_FLAGS: "--vmodule=command_buffer_conversion_pass=2,command_buffer_cmd_emitter=2,while_thunk=3,command_buffer_cmd=5" run: | SPARSE_FLAG="" if [ "${{ inputs.sparse }}" = "true" ]; then From ca33f1883a31d5f00dc00cdfcf32a6620d226853 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 04:17:32 +0000 Subject: [PATCH 77/79] fix: Use TF_CPP_VMODULE for XLA debug logging (not TF_CPP_VLOG_FLAGS) TF_CPP_VLOG_FLAGS is not recognized by JAX/XLA. The correct env var for module-level VLOG control is TF_CPP_VMODULE. Also set TF_CPP_MIN_LOG_LEVEL=0 (was 2) so VLOG output is not suppressed. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 2efa290f..8e5dcb2c 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -180,9 +180,9 @@ jobs: XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform - TF_CPP_MIN_LOG_LEVEL: "2" + TF_CPP_MIN_LOG_LEVEL: "0" # Debug: show which thunks prevent command buffer capture for while loops - TF_CPP_VLOG_FLAGS: "--vmodule=command_buffer_conversion_pass=2,command_buffer_cmd_emitter=2,while_thunk=3,command_buffer_cmd=5" + TF_CPP_VMODULE: "command_buffer_conversion_pass=2,command_buffer_cmd_emitter=2,while_thunk=3,command_buffer_cmd=5" run: | SPARSE_FLAG="" if [ "${{ inputs.sparse }}" = "true" ]; then From e389e631d280a0dc32e26e46e8cef9e52ae714b3 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 04:31:26 +0000 Subject: [PATCH 78/79] ci: Comment out XLA VLOG flags after confirming command buffer findings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit VLOG output confirmed: - NR while_loop NOT captured as CUDA graph (WhileThunk fallback mode) - BaSpaCho FFI custom call blocks WhileThunk→WhileCmd conversion (lacks kCmdBufferCompatible trait) - XLA creates 4 partial command buffer graphs within the NR body - NR NEVER converges (all breaks at iter=20 = tran_itl limit) Commenting out VLOG to reduce noise for future profiling runs. Re-enable by uncommenting TF_CPP_VMODULE line. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 8e5dcb2c..184ef05f 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -180,9 +180,9 @@ jobs: XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform - TF_CPP_MIN_LOG_LEVEL: "0" + TF_CPP_MIN_LOG_LEVEL: "2" # Debug: show which thunks prevent command buffer capture for while loops - TF_CPP_VMODULE: "command_buffer_conversion_pass=2,command_buffer_cmd_emitter=2,while_thunk=3,command_buffer_cmd=5" + # TF_CPP_VMODULE: "command_buffer_conversion_pass=2,command_buffer_cmd_emitter=2,while_thunk=3,command_buffer_cmd=5" run: | SPARSE_FLAG="" if [ "${{ inputs.sparse }}" = "true" ]; then From ea5a6b4043648af4e40b69587ae38549f5fd07e1 Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 11 Mar 2026 04:57:21 +0000 Subject: [PATCH 79/79] ci: Re-enable XLA VLOG for command buffer verification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable reduced VLOG set (conversion_pass=2, while_thunk=3) to verify that BaSpaCho FFI Execute handler with kCmdBufferCompatible trait enables WhileThunk→WhileCmd conversion for the NR while_loop. Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6) --- .github/workflows/profile-nsys.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 184ef05f..86b729f6 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -180,9 +180,9 @@ jobs: XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform - TF_CPP_MIN_LOG_LEVEL: "2" + TF_CPP_MIN_LOG_LEVEL: "0" # Debug: show which thunks prevent command buffer capture for while loops - # TF_CPP_VMODULE: "command_buffer_conversion_pass=2,command_buffer_cmd_emitter=2,while_thunk=3,command_buffer_cmd=5" + TF_CPP_VMODULE: "command_buffer_conversion_pass=2,while_thunk=3" run: | SPARSE_FLAG="" if [ "${{ inputs.sparse }}" = "true" ]; then