diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml
index 517a072c..58de58e2 100644
--- a/.github/workflows/benchmark-comparison.yml
+++ b/.github/workflows/benchmark-comparison.yml
@@ -9,6 +9,10 @@ on:
schedule:
- cron: '0 6 * * *' # Daily at 6am UTC
+concurrency:
+ group: benchmark-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
@@ -34,8 +38,6 @@ jobs:
name: benchmark (${{ matrix.solver }}-${{ matrix.platform }})
runs-on: ${{ matrix.runner }}
timeout-minutes: ${{ matrix.solver == 'sparse' && matrix.platform == 'cuda' && 360 || 90 }}
- concurrency:
- group: benchmark-${{ matrix.solver }}-${{ matrix.platform }}-${{ github.ref }}
steps:
- name: Checkout with submodules
@@ -54,45 +56,39 @@ jobs:
nvcc --version 2>/dev/null || echo "nvcc not in PATH"
# ── System dependencies ──────────────────────────────────────
- - name: Cache apt packages
- uses: actions/cache@v4
- with:
- path: /var/cache/apt/archives
- key: apt-${{ runner.os }}-${{ runner.arch }}-${{ matrix.platform }}-v1
- restore-keys: |
- apt-${{ runner.os }}-${{ runner.arch }}-${{ matrix.platform }}-
-
- - name: Install system dependencies (CPU)
+ - name: Install apt dependencies (CPU)
if: matrix.platform == 'cpu'
- run: |
- sudo apt-get update
- sudo apt-get install -y \
- cmake ninja-build \
- flex bison libfl-dev \
- libsuitesparse-dev libopenblas-dev \
+ uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources
+ with:
+ packages: >-
+ llvm-18-dev clang-18 libclang-18-dev lld-18
+ cmake ninja-build flex bison libfl-dev
+ libsuitesparse-dev libopenblas-dev
ccache bc
+ apt-sources: |
+ https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main
+ execute_install_scripts: true
- - name: Install system dependencies (CUDA)
+ - name: Install apt dependencies (CUDA)
if: matrix.platform == 'cuda'
+ uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources
+ with:
+ packages: >-
+ llvm-18-dev clang-18 libclang-18-dev lld-18
+ cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6
+ libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6
+ libnvjitlink-dev-12-6
+ cuda-libraries-12-6 cuda-cupti-12-6
+ libcudnn9-cuda-12 libcudss0-cuda-12
+ libsuitesparse-dev libopenblas-dev swig cmake pkg-config
+ # NVIDIA CUDA repo is already on the runner image (cuda-archive-keyring.gpg).
+ # Adding it again via apt-sources causes "Conflicting Signed-By" error.
+ apt-sources: |
+ https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main
+ execute_install_scripts: true
+
+ - name: Set LLVM and CUDA environment
run: |
- sudo apt-get update
- sudo apt-get install -y \
- cmake pkg-config swig \
- libsuitesparse-dev libopenblas-dev
-
- # ── LLVM 18 (idempotent, works on all runners) ──────────────
- - name: Install LLVM 18
- run: |
- if ! llvm-config-18 --version 2>/dev/null; then
- wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | \
- sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null
- sudo chmod a+r /etc/apt/trusted.gpg.d/apt.llvm.org.asc
- wget -q https://apt.llvm.org/llvm.sh
- chmod +x llvm.sh
- sudo ./llvm.sh 18
- rm llvm.sh
- fi
- sudo apt-get install -y lld-18
echo "/usr/lib/llvm-18/bin" >> "$GITHUB_PATH"
echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> "$GITHUB_ENV"
@@ -134,24 +130,37 @@ jobs:
openvaf_jax/openvaf_py
vendor/OpenVAF
- # ── CUDA toolkit ─────────────────────────────────────────────
- - name: Install CUDA toolkit
+ # ── CUDA environment ─────────────────────────────────────────
+ - name: Set CUDA environment
if: matrix.platform == 'cuda'
run: |
- sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12
- echo "/usr/local/cuda-12.6/bin" >> "$GITHUB_PATH"
- CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1)
+ CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 2>/dev/null | sort -V | tail -1)
+ if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then
+ CUDA_ROOT="/usr/local/cuda"
+ fi
+ EXTRA_LD=""
+ if [ -n "$CUDA_ROOT" ]; then
+ echo "${CUDA_ROOT}/bin" >> "$GITHUB_PATH"
+ # CUDA lib64 has cuSPARSE, cuFFT, etc. needed by JAX at startup.
+ EXTRA_LD="${CUDA_ROOT}/lib64"
+ fi
+ CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1)
if [ -n "$CUDSS_LIB" ]; then
CUDSS_DIR=$(dirname "$CUDSS_LIB")
- echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV"
+ EXTRA_LD="${CUDSS_DIR}:${EXTRA_LD}"
echo "cuDSS library found at: $CUDSS_LIB"
fi
+ if [ -n "$EXTRA_LD" ]; then
+ echo "LD_LIBRARY_PATH=${EXTRA_LD}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV"
+ fi
# ── Python environment ──────────────────────────────────────
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
+ prune-cache: false
+ python-version: "3.12"
- name: Set up Python
run: uv python install 3.12
@@ -193,7 +202,12 @@ jobs:
working-directory: vajax/sparse
run: |
uv pip install scikit-build-core nanobind
- uv pip install --no-build-isolation .
+ LAUNCHER=${{ matrix.platform == 'cuda' && 'sccache' || 'ccache' }}
+ uv pip install --no-build-isolation \
+ -C cmake.define.BLA_VENDOR=OpenBLAS \
+ -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=$LAUNCHER \
+ -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=$LAUNCHER \
+ .
# ── Run VAJAX unit tests (CPU dense only) ────────────────────
- name: Run VAJAX tests
diff --git a/.github/workflows/cache-cleanup.yml b/.github/workflows/cache-cleanup.yml
new file mode 100644
index 00000000..a7ff2f52
--- /dev/null
+++ b/.github/workflows/cache-cleanup.yml
@@ -0,0 +1,23 @@
+name: Cache Cleanup
+
+on:
+ schedule:
+ - cron: '0 0 1 */3 *' # First day of every 3rd month
+ workflow_dispatch:
+
+jobs:
+ cleanup:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Delete uv caches
+ env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: |
+ echo "=== Deleting uv caches ==="
+          gh cache list --repo ${{ github.repository }} --key setup-uv --limit 1000 --json id,key,sizeInBytes \
+ --jq '.[] | "\(.id)\t\(.key)\t\(.sizeInBytes)"' | \
+ while IFS=$'\t' read -r id key size; do
+ echo "Deleting: $key ($(numfmt --to=iec $size))"
+ gh cache delete "$id" --repo ${{ github.repository }}
+ done
+ echo "Done"
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index b6e2f231..4cc68339 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -6,6 +6,10 @@ on:
pull_request:
branches: [main]
+concurrency:
+ group: lint-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
jobs:
lint:
runs-on: ubuntu-latest
@@ -17,6 +21,8 @@ jobs:
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
+ prune-cache: false
+ python-version: "3.11"
- name: Set up Python
run: uv python install 3.11
diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml
index 28efc4af..86b729f6 100644
--- a/.github/workflows/profile-nsys.yml
+++ b/.github/workflows/profile-nsys.yml
@@ -14,17 +14,28 @@ on:
- c6288
default: ring
timesteps:
- description: 'Number of timesteps'
- default: '50'
+ description: 'Number of timesteps (500+ recommended for steady-state profiling)'
+ default: '500'
sparse:
description: 'Use sparse solver (for large circuits)'
type: boolean
default: false
+ use_baspacho:
+ description: 'Build spineax from source with BaSpaCho dense solver'
+ type: boolean
+ default: false
+ runner:
+ description: 'GPU runner to use'
+ type: choice
+ options:
+ - nvidia-runner-1
+ - nvidia-runner-2
+ default: nvidia-runner-1
-# Only one profiling job at a time on the GPU runner
+# Only one profiling job at a time per runner
concurrency:
- group: nsys-profile
- cancel-in-progress: false
+ group: nsys-profile-${{ inputs.runner }}
+ cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
@@ -33,7 +44,7 @@ env:
jobs:
nsys-profile:
- runs-on: nvidia-runner-1
+ runs-on: ${{ inputs.runner }}
timeout-minutes: 60
steps:
@@ -53,18 +64,59 @@ jobs:
echo "=== nsys Version ==="
nsys --version 2>/dev/null || echo "nsys not in PATH"
- - name: Install LLVM 18
+ - name: Install all apt dependencies (LLVM, CUDA, system libs)
+ uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources
+ with:
+ packages: >-
+ llvm-18-dev clang-18 libclang-18-dev lld-18
+ cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6
+ libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6
+ libnvjitlink-dev-12-6
+ cuda-libraries-12-6 cuda-cupti-12-6
+ libcudnn9-cuda-12 libcudss0-cuda-12
+ cuda-nsight-systems-12-6
+ libsuitesparse-dev libopenblas-dev swig cmake pkg-config
+ # NVIDIA CUDA repo is already on the runner image (cuda-archive-keyring.gpg).
+ # Adding it again via apt-sources causes "Conflicting Signed-By" error.
+ apt-sources: |
+ https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main
+ execute_install_scripts: true
+
+ - name: Set up LLVM and CUDA environment
run: |
- if ! llvm-config-18 --version 2>/dev/null; then
- wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null
- sudo chmod a+r /etc/apt/trusted.gpg.d/apt.llvm.org.asc
- wget -q https://apt.llvm.org/llvm.sh
- chmod +x llvm.sh
- sudo ./llvm.sh 18
- rm llvm.sh
- fi
echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV
+ # Find CUDA toolkit (may be at /usr/local/cuda-12.6 or elsewhere)
+ CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 2>/dev/null | sort -V | tail -1)
+ if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then
+ CUDA_ROOT="/usr/local/cuda"
+ fi
+ if [ -n "$CUDA_ROOT" ]; then
+ echo "${CUDA_ROOT}/bin" >> $GITHUB_PATH
+ echo "CUDAToolkit_ROOT=${CUDA_ROOT}" >> $GITHUB_ENV
+ # Build LD_LIBRARY_PATH with CUDA lib64 (cuSPARSE, cuFFT, etc.)
+ # Must be a single write — GITHUB_ENV takes last value per key.
+ EXTRA_LD="${CUDA_ROOT}/lib64"
+ echo "CUDA found at: ${CUDA_ROOT}"
+ else
+ echo "::warning::CUDA toolkit not found under /usr/local"
+ EXTRA_LD=""
+ fi
+
+ CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1)
+ if [ -n "$CUDSS_LIB" ]; then
+ CUDSS_DIR=$(dirname "$CUDSS_LIB")
+ EXTRA_LD="${CUDSS_DIR}:${EXTRA_LD}"
+ fi
+ if [ -n "$EXTRA_LD" ]; then
+ echo "LD_LIBRARY_PATH=${EXTRA_LD}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV"
+ fi
+          NSYS_BIN=$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1 | xargs -r dirname)
+ if [ -n "$NSYS_BIN" ]; then
+ echo "$NSYS_BIN" >> $GITHUB_PATH
+ fi
+ nsys --version || echo "::warning::nsys not found in PATH"
+
- name: Set up Rust
uses: dtolnay/rust-toolchain@stable
@@ -78,20 +130,12 @@ jobs:
with:
workspaces: openvaf_jax/openvaf_py
- - name: Install CUDA toolkit
- run: |
- sudo apt-get update
- sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6
- echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH
-
- - name: Install system dependencies
- run: |
- sudo apt-get install -y libsuitesparse-dev swig cmake pkg-config
-
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
+ prune-cache: false
+ python-version: "3.13"
- name: Set up Python
run: uv python install 3.13
@@ -105,13 +149,40 @@ jobs:
- name: Install vajax with CUDA dependencies
run: uv sync --extra cuda12
+ - name: Install spineax with BaSpaCho from source
+ if: inputs.use_baspacho
+ run: |
+ # Replace PyPI spineax with git version that has BaSpaCho dense solver.
+ # --reinstall forces rebuild even if same version is cached from PyPI.
+ # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen).
+ # sccache caches C++/CUDA compilation (BaSpaCho + SuiteSparse = 344 objects)
+ uv pip install --reinstall -v \
+ "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \
+ -C cmake.define.SPINEAX_USE_BASPACHO=ON \
+ -C cmake.define.CUDAToolkit_ROOT=${CUDAToolkit_ROOT:-/usr/local/cuda-12.6} \
+ -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=sccache \
+ -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=sccache \
+ -C cmake.define.CMAKE_CUDA_COMPILER_LAUNCHER=sccache
+
+ # Verify the module was installed
+ echo "=== Installed spineax files ==="
+ uv run python -c "import spineax; import pathlib; p = pathlib.Path(spineax.__file__).parent; [print(f' {f.name}') for f in sorted(p.iterdir())]"
+ echo "=== spineax directory listing ==="
+ ls -lR .venv/lib/python*/site-packages/spineax/ || true
+ echo "=== Checking baspacho_dense_solve import ==="
+ uv run python -c "from spineax import baspacho_dense_solve; print(' OK:', baspacho_dense_solve)" || echo " FAILED"
+
- name: Run nsys profiling
env:
JAX_PLATFORMS: cuda,cpu
- XLA_FLAGS: "--xla_gpu_autotune_level=0"
+ # +WHILE,+CONDITIONAL: enables command buffer capture for NR while loop
+ # BaSpaCho dense solver has kCmdBufferCompatible trait (replaces cuSOLVER)
+ XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL"
XLA_PYTHON_CLIENT_PREALLOCATE: "false"
XLA_PYTHON_CLIENT_ALLOCATOR: platform
- TF_CPP_MIN_LOG_LEVEL: "2"
+ TF_CPP_MIN_LOG_LEVEL: "0"
+ # Debug: show which thunks prevent command buffer capture for while loops
+ TF_CPP_VMODULE: "command_buffer_conversion_pass=2,while_thunk=3"
run: |
SPARSE_FLAG=""
if [ "${{ inputs.sparse }}" = "true" ]; then
@@ -120,6 +191,9 @@ jobs:
TIMESTAMP=$(date +%s)
PROFILE_NAME="nsys-${{ inputs.circuit }}-${TIMESTAMP}"
+ # Export profile_name early so subsequent steps can find the report
+ # even if nsys exits non-zero (e.g. SIGSEGV during teardown)
+ echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV"
echo "=== Starting nsys GPU Profiling ==="
echo "Circuit: ${{ inputs.circuit }}"
@@ -127,17 +201,51 @@ jobs:
echo "Sparse: ${{ inputs.sparse }}"
echo "Commit: ${{ github.sha }}"
+ # NVTX "run_transient" marker annotates the transient window for filtering.
+ # Cannot use --capture-range=nvtx: SIGSEGV during JAX teardown prevents
+ # nsys from finalizing the report in NVTX capture mode.
+ # Tolerate exit code 139 (SIGSEGV during JAX/CUDA teardown)
nsys profile \
--trace=cuda,nvtx,osrt \
--output "/tmp/${PROFILE_NAME}" \
uv run python scripts/nsys_profile_target.py \
- ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG}
+ ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} \
+ || NSYS_EXIT=$?
- echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV"
+ if [ "${NSYS_EXIT:-0}" -ne 0 ]; then
+ echo "::warning::nsys exited with code ${NSYS_EXIT} (139=SIGSEGV during teardown, report may still be valid)"
+ if [ ! -f "/tmp/${PROFILE_NAME}.nsys-rep" ]; then
+ echo "::error::nsys report not generated"
+ exit 1
+ fi
+ fi
+
+ - name: Export stats and SQLite
+ if: always()
+ run: |
+ REPORT="/tmp/${profile_name}.nsys-rep"
+ STATS_DIR="/tmp/nsys-stats"
+ mkdir -p "$STATS_DIR"
+
+ if [ -f "$REPORT" ]; then
+ # Export to SQLite for offline analysis
+ nsys export --type=sqlite --output="/tmp/${profile_name}.sqlite" "$REPORT" || true
+
+ # Generate key stats reports
+ for report in cuda_gpu_kern_sum cuda_api_sum cuda_gpu_mem_time_sum cuda_gpu_mem_size_sum; do
+ nsys stats "$REPORT" --report "$report" --format csv \
+ --force-export=true --output "${STATS_DIR}/${report}" 2>/dev/null || true
+ done
+
+ # Full kernel trace (top 100 by duration)
+ nsys stats "$REPORT" --report cuda_gpu_trace --format csv \
+ --force-export=true --output "${STATS_DIR}/cuda_gpu_trace" 2>/dev/null || true
+ fi
- name: Generate summary
if: always()
run: |
+ REPORT="/tmp/${profile_name}.nsys-rep"
{
echo "## nsys GPU Profiling Results"
echo ""
@@ -146,20 +254,34 @@ jobs:
echo "**Sparse:** ${{ inputs.sparse }}"
echo "**Commit:** \`${{ github.sha }}\`"
echo ""
- if [ -f "/tmp/${profile_name}.nsys-rep" ]; then
- echo "### Profile Summary"
+ if [ -f "$REPORT" ]; then
+ echo "### GPU Kernel Summary"
+ echo '```'
+ nsys stats "$REPORT" --report cuda_gpu_kern_sum --force-export=true 2>&1 | head -50 || true
+ echo '```'
+ echo ""
+ echo "### CUDA API Summary"
+ echo '```'
+ nsys stats "$REPORT" --report cuda_api_sum --force-export=true 2>&1 | head -30 || true
+ echo '```'
echo ""
+ echo "### GPU Memory Transfers"
echo '```'
- nsys stats "/tmp/${profile_name}.nsys-rep" --report cuda_gpu_kern_sum 2>&1 | head -40 || true
+ nsys stats "$REPORT" --report cuda_gpu_mem_time_sum --force-export=true 2>&1 | head -20 || true
echo '```'
+ else
+ echo "**No profile report found.**"
fi
} >> "$GITHUB_STEP_SUMMARY"
- - name: Upload nsys report
+ - name: Upload nsys report and stats
if: always()
uses: actions/upload-artifact@v4
with:
name: nsys-profile-${{ inputs.circuit }}-${{ github.sha }}
- path: /tmp/nsys-*.nsys-rep
+ path: |
+ /tmp/nsys-*.nsys-rep
+ /tmp/nsys-*.sqlite
+ /tmp/nsys-stats/
if-no-files-found: ignore
retention-days: 30
diff --git a/.github/workflows/test-pdk.yml b/.github/workflows/test-pdk.yml
index 13a7f4e0..34b4ab54 100644
--- a/.github/workflows/test-pdk.yml
+++ b/.github/workflows/test-pdk.yml
@@ -6,6 +6,10 @@ on:
pull_request:
branches: [main]
+concurrency:
+ group: pdk-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
@@ -44,12 +48,16 @@ jobs:
# Mask PDK path in all subsequent log output
echo "::add-mask::/tmp/pdk-gf130"
- - name: Install LLVM 18
- run: |
- wget -q https://apt.llvm.org/llvm.sh
- chmod +x llvm.sh
- sudo ./llvm.sh 18
- echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV
+ - name: Install apt dependencies
+ uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources
+ with:
+ packages: llvm-18-dev clang-18 libclang-18-dev lld-18
+ apt-sources: |
+ https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main
+ execute_install_scripts: true
+
+ - name: Set LLVM environment
+ run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV
- name: Set up Rust
uses: dtolnay/rust-toolchain@stable
@@ -68,6 +76,8 @@ jobs:
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
+ prune-cache: false
+ python-version: "3.10"
- name: Set up Python
run: uv python install 3.10
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index fa4d8db0..56e20eb1 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -6,6 +6,10 @@ on:
pull_request:
branches: [main]
+concurrency:
+ group: test-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
@@ -23,12 +27,16 @@ jobs:
with:
submodules: recursive
- - name: Install LLVM 18
- run: |
- wget https://apt.llvm.org/llvm.sh
- chmod +x llvm.sh
- sudo ./llvm.sh 18
- echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV
+ - name: Install apt dependencies
+ uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources
+ with:
+ packages: llvm-18-dev clang-18 libclang-18-dev lld-18
+ apt-sources: |
+ https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main
+ execute_install_scripts: true
+
+ - name: Set LLVM environment
+ run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV
- name: Set up Rust
uses: dtolnay/rust-toolchain@stable
@@ -47,6 +55,8 @@ jobs:
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
+ prune-cache: false
+ python-version: "3.11"
- name: Set up Python
run: uv python install 3.11
@@ -105,12 +115,19 @@ jobs:
with:
submodules: recursive
- - name: Install LLVM 18
- run: |
- wget https://apt.llvm.org/llvm.sh
- chmod +x llvm.sh
- sudo ./llvm.sh 18
- echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV
+ - name: Install apt dependencies
+ uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources
+ with:
+ packages: >-
+ llvm-18-dev clang-18 libclang-18-dev lld-18
+ libsuitesparse-dev libopenblas-dev swig cmake
+ ${{ matrix.test-group.extra_packages }}
+ apt-sources: |
+ https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main
+ execute_install_scripts: true
+
+ - name: Set LLVM environment
+ run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV
- name: Set up Rust
uses: dtolnay/rust-toolchain@stable
@@ -125,15 +142,12 @@ jobs:
with:
workspaces: openvaf_jax/openvaf_py
- - name: Install system dependencies
- run: |
- sudo apt-get update
- sudo apt-get install -y libsuitesparse-dev libopenblas-dev swig cmake ${{ matrix.test-group.extra_packages }}
-
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
+ prune-cache: false
+ python-version: "3.11"
- name: Set up Python
run: uv python install 3.11
@@ -160,7 +174,11 @@ jobs:
working-directory: vajax/sparse
run: |
uv pip install scikit-build-core nanobind
- uv pip install --no-build-isolation .
+ uv pip install --no-build-isolation \
+ -C cmake.define.BLA_VENDOR=OpenBLAS \
+ -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=sccache \
+ -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=sccache \
+ .
- name: Run tests (${{ matrix.test-group.name }})
timeout-minutes: ${{ matrix.test-group.timeout }}
@@ -226,6 +244,8 @@ jobs:
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
+ prune-cache: false
+ python-version: "3.11"
- name: Set up Python
run: uv python install 3.11
diff --git a/.github/workflows/xla-flag-sweep.yml b/.github/workflows/xla-flag-sweep.yml
new file mode 100644
index 00000000..eb61cc1b
--- /dev/null
+++ b/.github/workflows/xla-flag-sweep.yml
@@ -0,0 +1,121 @@
+name: XLA Flag Sweep
+
+on:
+ workflow_dispatch:
+ inputs:
+ benchmarks:
+ description: 'Comma-separated benchmark names'
+ required: false
+ default: 'ring'
+ configs:
+ description: 'Comma-separated config names (empty = all)'
+ required: false
+ default: ''
+ n_runs:
+ description: 'Number of timed runs per config'
+ required: false
+ default: '3'
+
+env:
+ CARGO_TERM_COLOR: always
+ CARGO_INCREMENTAL: 0
+
+jobs:
+ sweep:
+ name: XLA flag sweep (CUDA)
+ runs-on: nvidia-runner-1
+ timeout-minutes: 120
+ concurrency:
+ group: xla-sweep-${{ github.ref }}
+
+ steps:
+ - name: Checkout with submodules
+ uses: actions/checkout@v4
+ with:
+ submodules: recursive
+
+ - name: CUDA diagnostics
+ run: |
+ echo "=== GPU Info ==="
+ nvidia-smi 2>/dev/null || echo "nvidia-smi not available"
+ echo "=== CUDA Version ==="
+ nvcc --version 2>/dev/null || echo "nvcc not available"
+
+ - name: Set up Python + uv
+ uses: astral-sh/setup-uv@v6
+ with:
+ enable-cache: true
+ prune-cache: false
+ cache-dependency-glob: "uv.lock"
+
+ - name: Install dependencies
+ run: uv sync --frozen --all-extras
+
+ - name: Verify GPU access
+ run: |
+ uv run python -c "
+ import jax
+ devices = jax.devices()
+ print(f'JAX devices: {devices}')
+ print(f'GPU available: {any(d.platform == \"gpu\" for d in devices)}')
+ for d in devices:
+ print(f' {d.platform}: {d}')
+ "
+ env:
+ JAX_PLATFORMS: cuda,cpu
+
+ - name: Run XLA flag sweep
+ env:
+ JAX_PLATFORMS: cuda,cpu
+ JAX_ENABLE_X64: "1"
+ XLA_PYTHON_CLIENT_PREALLOCATE: "false"
+ XLA_PYTHON_CLIENT_ALLOCATOR: "platform"
+ run: |
+ CONFIGS_FLAG=""
+ if [ -n "${{ inputs.configs }}" ]; then
+ CONFIGS_FLAG="--configs ${{ inputs.configs }}"
+ fi
+
+ uv run python scripts/sweep_xla_flags.py \
+ --benchmark ${{ inputs.benchmarks }} \
+ --n-runs ${{ inputs.n_runs }} \
+ $CONFIGS_FLAG \
+ --json-output /tmp/xla-sweep-results.json \
+ 2>&1 | tee /tmp/xla-sweep.log
+
+ - name: Generate summary
+ if: always()
+ run: |
+ {
+ echo "## XLA Flag Sweep Results"
+ echo ""
+ echo "**Benchmarks**: ${{ inputs.benchmarks }}"
+ echo "**Runs per config**: ${{ inputs.n_runs }}"
+ echo ""
+
+ if [ -f /tmp/xla-sweep.log ]; then
+ echo '```'
+ # Extract the summary table
+              sed -n '/^=\{80\}/,$p' /tmp/xla-sweep.log | head -40
+ echo '```'
+ fi
+
+ if [ -f /tmp/xla-sweep-results.json ]; then
+ echo ""
+              echo "<details><summary>Raw JSON results</summary>"
+              echo ""
+              echo '```json'
+              cat /tmp/xla-sweep-results.json
+              echo '```'
+              echo "</details>"
+              echo ""
+ fi
+ } >> "$GITHUB_STEP_SUMMARY"
+
+ - name: Upload results
+ if: always()
+ uses: actions/upload-artifact@v4
+ with:
+ name: xla-sweep-results
+ path: |
+ /tmp/xla-sweep-results.json
+ /tmp/xla-sweep.log
diff --git a/CLAUDE.md b/CLAUDE.md
index 7ec6dfed..9c2246c1 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -78,6 +78,22 @@ uv run python scripts/profile_gpu.py --benchmark ring,c6288
```
+## Linting
+
+**Run before every commit.** CI runs both checks and will reject PRs that fail.
+
+```bash
+# Lint check (import sorting, unused imports, etc.)
+uv tool run ruff check vajax/ tests/ scripts/ benchmarks/
+
+# Format check (code style)
+uv tool run ruff format --check vajax/ tests/ scripts/ benchmarks/
+
+# Auto-fix both
+uv tool run ruff check --fix vajax/ tests/ scripts/ benchmarks/
+uv tool run ruff format vajax/ tests/ scripts/ benchmarks/
+```
+
## Precision Configuration
Precision is auto-configured on import via `vajax/__init__.py`:
diff --git a/openvaf_jax/__init__.py b/openvaf_jax/__init__.py
index d0c15847..bc29f910 100644
--- a/openvaf_jax/__init__.py
+++ b/openvaf_jax/__init__.py
@@ -1266,6 +1266,51 @@ def translate_init_array_split(
return init_fn, metadata
+ def build_sccp_known_values(
+ self,
+ shared_indices: List[int],
+ shared_values: List[float],
+ shared_cache_indices: Optional[List[int]] = None,
+ shared_cache_values: Optional[List[float]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ """Build SCCP known_values dict from concrete shared params and cache.
+
+ Maps concrete values to their MIR value IDs for SCCP dead-block elimination.
+ This tells the SCCP pass which MIR values are constant, enabling it to
+ resolve static branches and eliminate dead blocks at codegen time.
+
+ NOTE: These values are used ONLY for SCCP analysis, NOT for literal
+ inlining in the generated Python code. Array reads (shared_params[i],
+ shared_cache[i]) are preserved in the generated code for GPU efficiency.
+ Literal inlining was found to cause 7.8x GPU regression on PSP103 by
+ embedding ~1300 float constants in the XLA kernel.
+
+ Args:
+ shared_indices: Original param indices for shared params
+ shared_values: Concrete float values for each shared param
+ shared_cache_indices: Cache column indices for shared cache entries
+ shared_cache_values: Concrete float values for each shared cache entry
+
+ Returns:
+ Dict mapping MIR value IDs to concrete values, or None if empty.
+ """
+ known: Dict[str, Any] = {}
+
+ for j, orig_idx in enumerate(shared_indices):
+ value_id = self.param_idx_to_val.get(orig_idx)
+ if value_id:
+ known[value_id] = shared_values[j]
+
+ if shared_cache_values is not None and shared_cache_indices:
+ for j, cache_col_idx in enumerate(shared_cache_indices):
+ mapping = self.cache_mapping[cache_col_idx]
+ eval_param_idx = mapping["eval_param"]
+ value_id = self.param_idx_to_val.get(eval_param_idx)
+ if value_id:
+ known[value_id] = shared_cache_values[j]
+
+ return known if known else None
+
def translate_eval_array_with_cache_split(
self,
shared_indices: List[int],
@@ -1274,6 +1319,7 @@ def translate_eval_array_with_cache_split(
varying_cache_indices: Optional[List[int]] = None,
use_limit_functions: bool = False,
limit_param_map: Optional[Dict[int, Tuple[str, str]]] = None,
+ sccp_known_values: Optional[Dict[str, Any]] = None,
) -> Tuple[Callable, Dict]:
"""Generate a vmappable eval function with split params and cache (internal API).
@@ -1291,6 +1337,10 @@ def translate_eval_array_with_cache_split(
limit_param_map: Dict mapping original param indices to (kind, name) tuples
for limit-related params (prev_state, enable_lim, new_state,
enable_integration). Excluded from shared/device params.
+ sccp_known_values: If provided, maps MIR value IDs to concrete values
+ for SCCP dead-block elimination. Build this with
+ build_sccp_known_values(). Values are used for SCCP
+ analysis only — NOT inlined as literals in generated code.
Returns:
Tuple of (eval_fn, metadata)
@@ -1316,10 +1366,15 @@ def translate_eval_array_with_cache_split(
assert self.dae_data is not None, "dae_data released, call before release_mir_data()"
t0 = time.perf_counter()
+ n_sccp = len(sccp_known_values) if sccp_known_values else 0
logger.info(
- f" translate_eval_array_with_cache_split: generating code (limit_funcs={use_limit_functions})..."
+ f" translate_eval_array_with_cache_split: generating code "
+ f"(limit_funcs={use_limit_functions}, sccp_known={n_sccp})..."
)
+ if sccp_known_values:
+ logger.info(f" SCCP: {n_sccp} known values for dead-block elimination")
+
# Build the eval function
eval_param_names = list(self.module.param_names)
builder = EvalFunctionBuilder(
@@ -1327,6 +1382,7 @@ def translate_eval_array_with_cache_split(
self.dae_data,
self.cache_mapping,
self.param_idx_to_val,
+ sccp_known_values=sccp_known_values,
eval_param_names=eval_param_names,
)
fn_name, code_lines = builder.build_with_cache_split(
@@ -1339,6 +1395,22 @@ def translate_eval_array_with_cache_split(
)
t1 = time.perf_counter()
+
+ # Log SCCP statistics
+ if builder.sccp is not None:
+ dead_blocks = builder.sccp.get_dead_blocks()
+ total_blocks = len(self.eval_mir.blocks)
+ n_constants = sum(1 for v in builder.sccp.lattice.values() if v.is_constant())
+ static_branches = sum(
+ 1
+ for b in self.eval_mir.blocks
+ if builder.sccp.get_static_branch_direction(b) is not None
+ )
+ logger.info(
+ f" SCCP results: {len(dead_blocks)}/{total_blocks} blocks dead, "
+ f"{static_branches} static branches, {n_constants} constants propagated"
+ )
+
logger.info(
f" translate_eval_array_with_cache_split: code generated ({len(code_lines)} lines) in {t1 - t0:.1f}s"
)
@@ -1359,6 +1431,11 @@ def translate_eval_array_with_cache_split(
df.write(f"# use_limit_functions={use_limit_functions}\n")
df.write(f"# shared_indices={shared_indices}\n")
df.write(f"# varying_indices={varying_indices}\n")
+ df.write(f"# sccp_known_values={n_sccp}\n")
+ if builder.sccp is not None:
+ dead_blocks = builder.sccp.get_dead_blocks()
+ total_blocks = len(self.eval_mir.blocks)
+ df.write(f"# sccp: {len(dead_blocks)}/{total_blocks} blocks dead\n")
df.write(code)
logger.info(f" Generated code dumped to {dump_path}")
diff --git a/openvaf_jax/codegen/function_builder.py b/openvaf_jax/codegen/function_builder.py
index 8a4b68d9..0c04bc9d 100644
--- a/openvaf_jax/codegen/function_builder.py
+++ b/openvaf_jax/codegen/function_builder.py
@@ -980,7 +980,10 @@ def build_with_cache_split(
return fn_name, code_str.split("\n")
def _emit_param_mapping(
- self, body: List[ast.stmt], ctx: CodeGenContext, idx_mapping: Dict[int, Tuple[str, any]]
+ self,
+ body: List[ast.stmt],
+ ctx: CodeGenContext,
+ idx_mapping: Dict[int, Tuple[str, any]],
):
"""Emit parameter mapping from split arrays.
@@ -999,7 +1002,10 @@ def _emit_param_mapping(
source, value = idx_mapping[i]
if source == "shared":
body.append(
- assign(var_name, subscript(ast_name("shared_params"), ast_const(value)))
+ assign(
+ var_name,
+ subscript(ast_name("shared_params"), ast_const(value)),
+ )
)
elif source == "device":
body.append(
@@ -1041,6 +1047,11 @@ def _emit_cache_mapping(
"""Emit cache value mapping from split cache arrays.
Always uses split cache format (shared_cache, device_cache) for uniform interface.
+
+ Args:
+ body: List to append statements to
+ ctx: Code generation context
+ cache_idx_mapping: Maps cache index to (source, new_index)
"""
for cache_idx, mapping in enumerate(self.cache_mapping):
eval_param_idx = mapping["eval_param"]
@@ -1051,7 +1062,10 @@ def _emit_cache_mapping(
source, new_idx = cache_idx_mapping[cache_idx]
if source == "shared_cache":
body.append(
- assign(var_name, subscript(ast_name("shared_cache"), ast_const(new_idx)))
+ assign(
+ var_name,
+ subscript(ast_name("shared_cache"), ast_const(new_idx)),
+ )
)
else:
body.append(
diff --git a/scripts/analyze_dense_jaxpr.py b/scripts/analyze_dense_jaxpr.py
new file mode 100644
index 00000000..2d5baf1d
--- /dev/null
+++ b/scripts/analyze_dense_jaxpr.py
@@ -0,0 +1,217 @@
+#!/usr/bin/env -S uv run --script
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["jax", "jaxlib"]
+# ///
+"""Analyze JAX IR for dense benchmark circuits.
+
+Dumps jaxpr, HLO op counts, and cost analysis for the key hot paths:
+1. build_system (Jacobian + residual assembly)
+2. nr_solve (Newton-Raphson with while_loop)
+3. run_while (full transient step with adaptive timestep)
+"""
+
+import os
+import sys
+from pathlib import Path
+
+os.environ.setdefault("JAX_PLATFORMS", "cpu")
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import jax
+import jax.numpy as jnp
+
+from vajax.analysis.engine import CircuitEngine
+from vajax.benchmarks.registry import get_benchmark
+
+
def count_hlo_ops(hlo_text: str) -> dict[str, int]:
    """Tally dotted op names (e.g. ``stablehlo.add``) in HLO text.

    Scans for lines of the shape ``lhs = op.name(...)`` as emitted by the
    HLO/StableHLO pretty-printer; lines without an ``=`` or a dotted op
    token are ignored.

    Args:
        hlo_text: HLO module text, e.g. from ``lowered.as_text()``.

    Returns:
        Mapping of op name -> occurrence count, sorted by descending count.
    """
    tally: dict[str, int] = {}
    for raw in hlo_text.split("\n"):
        if "=" not in raw or "." not in raw:
            continue
        pieces = raw.split("=")
        if len(pieces) < 2:
            continue
        rhs = pieces[1].strip()
        first_token = rhs.split()[0] if rhs else ""
        if "." not in first_token:
            continue
        # Strip the argument list: "stablehlo.add(%a," -> "stablehlo.add".
        op = first_token.split("(")[0]
        tally[op] = tally.get(op, 0) + 1
    return dict(sorted(tally.items(), key=lambda item: -item[1]))
+
+
def analyze_function(name: str, fn, args, output_dir: Path):
    """Analyze a single function: jaxpr, HLO, cost.

    Lowers ``fn`` on ``args``, prints the HLO line count, the 20 most
    frequent HLO ops, and the XLA cost analysis, then writes the HLO text
    to ``output_dir``.  Any failure is printed with a traceback instead of
    propagating, so one bad function does not abort the whole analysis.

    Args:
        name: Label used in headers and the saved-file name.
        fn: A jitted callable (has ``.lower``) or a plain function.
        args: Concrete arguments used for tracing/lowering.
        output_dir: Directory that receives ``<name>.hlo.txt``.
    """
    print(f"\n{'=' * 70}")
    print(f" {name}")
    print(f"{'=' * 70}")

    try:
        # AOT path: jitted callables already expose .lower(); wrap a plain
        # function in jax.jit first.
        if hasattr(fn, "lower"):
            lowered = fn.lower(*args)
        else:
            lowered = jax.jit(fn).lower(*args)

        hlo_text = lowered.as_text()
        hlo_lines = hlo_text.split("\n")
        print(f" HLO: {len(hlo_lines)} lines")

        ops = count_hlo_ops(hlo_text)
        if ops:
            top = list(ops.items())[:20]
            print(" Top HLO ops:")
            for op, count in top:
                print(f" {op:40s} {count:6d}")

        compiled = lowered.compile()
        # NOTE(review): iterates cost_analysis() as a per-device sequence of
        # dicts; some JAX versions return a single dict instead — confirm
        # against the pinned jax version.
        cost = compiled.cost_analysis()
        if cost:
            for i, device_cost in enumerate(cost):
                if device_cost and isinstance(device_cost, dict):
                    print(f" Cost (device {i}):")
                    for key, val in sorted(device_cost.items()):
                        if isinstance(val, (int, float)):
                            # Human-readable G/M/K suffixes.
                            if val > 1e9:
                                print(f" {key}: {val / 1e9:.2f}G")
                            elif val > 1e6:
                                print(f" {key}: {val / 1e6:.2f}M")
                            elif val > 1e3:
                                print(f" {key}: {val / 1e3:.2f}K")
                            else:
                                print(f" {key}: {val:.2f}")

        # Save HLO
        output_dir.mkdir(parents=True, exist_ok=True)
        hlo_file = output_dir / f"{name}.hlo.txt"
        with open(hlo_file, "w") as f:
            f.write(hlo_text)
        print(f" Saved: {hlo_file}")

    except Exception as e:
        import traceback

        print(f" Failed: {e}")
        traceback.print_exc()
+
+
def analyze_benchmark(benchmark_name: str, output_dir: Path):
    """Analyze all hot paths for a single benchmark.

    Prepares the dense solver for a short run, then lowers and analyzes
    the cached build_system and nr_solve callables via analyze_function.

    Args:
        benchmark_name: Name known to the benchmark registry.
        output_dir: Root directory; per-benchmark HLO goes in a subdir.

    NOTE(review): reads several private engine attributes
    (``_strategy``, ``_transient_setup_cache``, ``_device_arrays``) —
    confirm these stay stable across engine refactors.
    """
    print(f"\n{'#' * 70}")
    print(f" Benchmark: {benchmark_name}")
    print(f"{'#' * 70}")

    config = get_benchmark(benchmark_name)
    engine = CircuitEngine(config.sim_path)
    engine.parse()

    # Use short simulation for analysis (just need compilation, not full run)
    num_steps = 100
    engine.prepare(
        t_stop=config.dt * num_steps,
        dt=config.dt,
        use_sparse=False,
    )

    # Get strategy internals
    strategy = engine._strategy
    setup_cache = engine._transient_setup_cache

    n_total = setup_cache["n_total"]
    n_unknowns = setup_cache["n_unknowns"]
    n_vsources = len([d for d in engine.devices if d["model"] == "vsource"])
    n_isources = len([d for d in engine.devices if d["model"] == "isource"])
    n_augmented = n_unknowns + n_vsources

    print(f" Nodes: {n_total}, Unknowns: {n_unknowns}, Vsources: {n_vsources}")
    print(f" Augmented system: {n_augmented}x{n_augmented}")

    bench_dir = output_dir / benchmark_name

    # 1. Analyze build_system (Jacobian + residual assembly)
    build_fn = setup_cache.get("build_system_fn")
    device_arrays = engine._device_arrays
    if build_fn is not None and device_arrays is not None:
        # Zero-valued placeholder inputs: only shapes/dtypes matter for
        # lowering, not the numeric values.
        X = jnp.zeros(n_total + n_vsources, dtype=jnp.float64)
        vsource_vals = jnp.zeros(n_vsources, dtype=jnp.float64)
        isource_vals = jnp.zeros(max(n_isources, 0), dtype=jnp.float64)
        Q_prev = jnp.zeros(n_unknowns, dtype=jnp.float64)
        integ_c0 = jnp.asarray(1e9, dtype=jnp.float64)  # typical 1/dt
        gmin = jnp.asarray(1e-12, dtype=jnp.float64)
        gshunt = jnp.asarray(0.0, dtype=jnp.float64)
        integ_c1 = jnp.asarray(0.0, dtype=jnp.float64)
        integ_d1 = jnp.asarray(0.0, dtype=jnp.float64)
        dQdt_prev = jnp.zeros(n_unknowns, dtype=jnp.float64)
        integ_c2 = jnp.asarray(0.0, dtype=jnp.float64)
        Q_prev2 = jnp.zeros(n_unknowns, dtype=jnp.float64)
        total_limit_states = setup_cache.get("total_limit_states", 0)
        limit_state = jnp.zeros(total_limit_states, dtype=jnp.float64)
        nr_iter = jnp.asarray(1, dtype=jnp.int32)

        # NOTE(review): argument order must match build_system_fn's
        # signature exactly — confirm against the solver factory.
        build_args = (
            X,
            vsource_vals,
            isource_vals,
            Q_prev,
            integ_c0,
            device_arrays,
            gmin,
            gshunt,
            integ_c1,
            integ_d1,
            dQdt_prev,
            integ_c2,
            Q_prev2,
            limit_state,
            nr_iter,
        )
        analyze_function("build_system", build_fn, build_args, bench_dir)

    # 2. Analyze nr_solve
    nr_solve = setup_cache.get("nr_solve_fn")
    if nr_solve is not None and device_arrays is not None:
        X_init = jnp.zeros(n_total + n_vsources, dtype=jnp.float64)
        vsource_vals = jnp.zeros(n_vsources, dtype=jnp.float64)
        isource_vals = jnp.zeros(max(n_isources, 0), dtype=jnp.float64)
        Q_prev = jnp.zeros(n_unknowns, dtype=jnp.float64)
        integ_c0 = jnp.asarray(1e9, dtype=jnp.float64)

        nr_args = (X_init, vsource_vals, isource_vals, Q_prev, integ_c0, device_arrays)
        analyze_function("nr_solve", nr_solve, nr_args, bench_dir)

    # 3. Analyze full transient step (run_while)
    run_while = getattr(strategy, "_jit_run_while", None)
    if run_while is None:
        # Try to find it in the cache
        run_while = strategy._jit_run_while_cache.get(
            strategy._get_cache_key() if hasattr(strategy, "_get_cache_key") else None
        )

    if run_while is not None:
        print("\n Found run_while - analyzing full transient loop")
        # This one is harder to trace without the actual state
        # We'd need to construct a FullMNAState - skip for now
        print(" (skipping run_while - complex state tuple)")
+
+
if __name__ == "__main__":
    import argparse

    # CLI: analyze one or more registered benchmarks and dump their HLO.
    parser = argparse.ArgumentParser(description="Analyze dense benchmark JAX IR")
    parser.add_argument(
        "--benchmark",
        default="rc,graetz,mul,ring",
        help="Comma-separated benchmarks",
    )
    parser.add_argument(
        "--output-dir",
        default="/tmp/claude/jaxpr-analysis",
        help="Output directory for HLO files",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    # Each comma-separated name is analyzed independently.
    benchmarks = [b.strip() for b in args.benchmark.split(",")]

    for bench in benchmarks:
        analyze_benchmark(bench, output_dir)
diff --git a/scripts/analyze_parallelism.py b/scripts/analyze_parallelism.py
new file mode 100644
index 00000000..a150de27
--- /dev/null
+++ b/scripts/analyze_parallelism.py
@@ -0,0 +1,1194 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = []
+# ///
+"""Analyze parallelism opportunities in VAJAX simulation matrices.
+
+For IREE & Baspacho test case context: given a circuit's Jacobian sparsity
+pattern, reports what parallelism can be exploited during factorization,
+assembly, and device evaluation.
+
+Key outputs:
+- Elimination tree: dependency structure for sparse factorization
+- Level-set parallelism: how many columns can be processed simultaneously
+- Supernodal structure: dense blocks exploitable with BLAS-3
+- Fill-in analysis: memory requirements for factorization
+- Device evaluation parallelism: scatter pattern from vmap'd device evals
+- Pattern stability: sparsity is fixed across all NR iterations
+
+Usage:
+ # Analyze from a benchmark (captures matrices + device info)
+ JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py ring
+ JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py c6288
+
+ # Analyze existing Matrix Market file
+ uv run scripts/analyze_parallelism.py --from-mtx path/to/jacobian_0000.mtx
+
+ # Output to specific directory
+ JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py ring --output-dir /tmp/ring_par
+"""
+
+import argparse
+import json
+import os
+import sys
+from collections import Counter
+from pathlib import Path
+
+os.environ.setdefault("JAX_PLATFORMS", "cpu")
+
+import numpy as np
+import scipy.io
+import scipy.sparse as sp
+import scipy.sparse.linalg
+
+# ---------------------------------------------------------------------------
+# Elimination tree
+# ---------------------------------------------------------------------------
+
+
def symmetrize_pattern(A: sp.spmatrix) -> sp.csc_matrix:
    """Return the structural pattern of |A| + |A^T| as a binary CSC matrix.

    Values carry no information: every stored entry is set to 1.0, so the
    result encodes structure only (suitable for etree/graph algorithms).
    """
    csc = sp.csc_matrix(A)
    # Replace the numeric values with ones, keeping the exact sparsity.
    ones = np.ones(csc.nnz)
    pattern = sp.csc_matrix((ones, csc.indices, csc.indptr), shape=csc.shape)
    combined = pattern + pattern.T
    # Entries present in both triangles summed to 2 — collapse back to 1.
    combined.data[:] = 1.0
    combined.eliminate_zeros()
    return sp.csc_matrix(combined)
+
+
def compute_etree(A_csc: sp.csc_matrix) -> np.ndarray:
    """Compute elimination tree of a symmetric matrix.

    Uses Liu's algorithm with path compression (union-find).
    Only the upper triangle is used.

    Args:
        A_csc: Symmetric matrix in CSC format

    Returns:
        parent array where parent[i] is the parent of column i,
        or -1 for root(s)
    """
    n = A_csc.shape[0]
    parent = np.full(n, -1, dtype=np.int64)
    # ancestor[] is the union-find structure over already-processed
    # columns; initially every column is its own representative.
    ancestor = np.arange(n, dtype=np.int64)

    indptr = A_csc.indptr
    indices = A_csc.indices

    for k in range(n):
        for ptr in range(indptr[k], indptr[k + 1]):
            i = indices[ptr]
            if i >= k:
                # Only entries strictly above the diagonal matter.
                continue
            # Find root of i with path compression
            r = i
            while ancestor[r] != r:
                r = ancestor[r]
            if r != k:
                # First time i's subtree reaches column k: k becomes its
                # parent in the elimination tree.
                parent[r] = k
                ancestor[r] = k
            # Path compression for i
            r = i
            while ancestor[r] != k:
                t = ancestor[r]
                ancestor[r] = k
                r = t

    return parent
+
+
def compute_level_sets(parent: np.ndarray) -> list[list[int]]:
    """Group elimination-tree columns into dependency-free levels.

    levels[0] holds the leaves; each later level depends only on earlier
    ones, so all columns within one level can be factored in parallel.

    Args:
        parent: Etree parent array (-1 marks roots).

    Returns:
        levels[d] = sorted column indices at distance d from the leaves.
    """
    n = len(parent)

    # Build the child lists and collect roots for a top-down sweep.
    children: list[list[int]] = [[] for _ in range(n)]
    roots: list[int] = []
    for node in range(n):
        p = parent[node]
        if p == -1:
            roots.append(node)
        else:
            children[p].append(node)

    # Frontier-by-frontier BFS assigns each node its depth from the root.
    depth_from_root = np.full(n, -1, dtype=np.int64)
    frontier = roots
    depth = 0
    while frontier:
        nxt: list[int] = []
        for node in frontier:
            depth_from_root[node] = depth
            nxt.extend(children[node])
        frontier = nxt
        depth += 1

    max_depth = int(np.max(depth_from_root)) if n > 0 else 0

    # Flip so leaves sit at level 0.
    depth_from_leaves = max_depth - depth_from_root
    levels: list[list[int]] = [[] for _ in range(max_depth + 1)]
    for node in range(n):
        levels[depth_from_leaves[node]].append(node)

    return levels
+
+
def compute_etree_stats(parent: np.ndarray, levels: list[list[int]]) -> dict:
    """Summarize the parallelism the elimination tree exposes.

    Args:
        parent: Etree parent array (-1 marks roots).
        levels: Bottom-up level sets from compute_level_sets.

    Returns:
        Dict with tree height, leaf count, level-width statistics, and
        subtree-size statistics.
    """
    n = len(parent)
    widths = [len(level) for level in levels]

    # A leaf never appears as anyone's parent.
    is_parent = np.zeros(n, dtype=bool)
    for node in range(n):
        p = parent[node]
        if p != -1:
            is_parent[p] = True
    n_leaves = int(np.sum(~is_parent))

    # Accumulate subtree sizes bottom-up; levels[0] are the leaves, so a
    # child is always finalized before its parent is visited.
    subtree = np.ones(n, dtype=np.int64)
    for level in levels:
        for node in level:
            p = parent[node]
            if p != -1:
                subtree[p] += subtree[node]

    return {
        "height": len(levels),
        "n_leaves": n_leaves,
        "max_parallelism": max(widths) if widths else 0,
        "avg_parallelism": float(np.mean(widths)) if widths else 0,
        "min_parallelism": min(widths) if widths else 0,
        "level_widths": widths,
        "subtree_size_stats": {
            "min": int(np.min(subtree)) if n > 0 else 0,
            "max": int(np.max(subtree)) if n > 0 else 0,
            "mean": float(np.mean(subtree)) if n > 0 else 0,
            "median": float(np.median(subtree)) if n > 0 else 0,
        },
    }
+
+
+# ---------------------------------------------------------------------------
+# Supernodal detection
+# ---------------------------------------------------------------------------
+
+
def detect_supernodes(parent: np.ndarray, A_csc: sp.csc_matrix) -> list[list[int]]:
    """Detect chains of consecutive columns grouped as supernodes.

    A run of columns j, j+1, ..., j+k is merged when parent[i] == i + 1
    and column i + 1 has exactly one child, for every i in the run — i.e.
    a straight chain over consecutive indices in the elimination tree.

    NOTE(review): true fundamental supernodes additionally require nested
    column sparsity patterns; that condition is NOT verified here (A_csc
    is only used for its dimension), so these chains are an optimistic
    approximation — confirm before relying on them for BLAS-3 blocking.

    These can be factored as dense blocks using BLAS-3 operations.
    """
    n = A_csc.shape[0]
    if n == 0:
        return []

    # Count children per node
    n_children = np.zeros(n, dtype=np.int64)
    for i in range(n):
        if parent[i] != -1:
            n_children[parent[i]] += 1

    # A node starts a new supernode if:
    # - It has more than one child, OR
    # - It is not the only child of its parent, OR
    # - parent[i] != i + 1
    is_supernode_start = np.ones(n, dtype=bool)
    for i in range(n - 1):
        if parent[i] == i + 1 and n_children[i + 1] == 1:
            is_supernode_start[i + 1] = False

    # Sweep left to right, cutting a new group at every start marker.
    supernodes: list[list[int]] = []
    current: list[int] = []
    for i in range(n):
        if is_supernode_start[i]:
            if current:
                supernodes.append(current)
            current = [i]
        else:
            current.append(i)
    if current:
        supernodes.append(current)

    return supernodes
+
+
def supernode_stats(supernodes: list[list[int]]) -> dict:
    """Summarize supernode sizes, bucketed into a coarse histogram.

    Args:
        supernodes: Column groups as produced by detect_supernodes.

    Returns:
        Dict with group count, largest/mean/median size, and a histogram
        keyed by size bucket ("1", "2-4", "5-10", "11-50", "51+").
    """
    sizes = [len(group) for group in supernodes]

    def bucket(size: int) -> str:
        # Map a raw size to its histogram bucket label.
        if size == 1:
            return "1"
        if size <= 4:
            return "2-4"
        if size <= 10:
            return "5-10"
        if size <= 50:
            return "11-50"
        return "51+"

    histogram: dict[str, int] = {}
    for size in sorted(sizes):
        label = bucket(size)
        histogram[label] = histogram.get(label, 0) + 1

    return {
        "count": len(supernodes),
        "largest": max(sizes) if sizes else 0,
        "mean_size": float(np.mean(sizes)) if sizes else 0,
        "median_size": float(np.median(sizes)) if sizes else 0,
        "size_histogram": histogram,
    }
+
+
+# ---------------------------------------------------------------------------
+# Fill-in analysis
+# ---------------------------------------------------------------------------
+
+
def fill_in_analysis(A: sp.spmatrix) -> dict:
    """Analyze fill-in from LU factorization using scipy's SuperLU.

    Tries both MMD_AT_PLUS_A and COLAMD fill-reducing orderings and
    reports L/U nonzero counts, fill ratio, and which ordering produced
    the smaller factor.

    Args:
        A: Square matrix in any scipy sparse format.

    Returns:
        Dict with the original nnz, per-ordering results (an ``error``
        entry replaces the numbers when splu fails, e.g. on a singular
        matrix), and ``best_ordering`` (None if every ordering failed).
    """
    A_csc = sp.csc_matrix(A, dtype=np.float64)
    n = A_csc.shape[0]
    original_nnz = A_csc.nnz

    results = {}

    for ordering_name, permc_spec in [
        ("MMD_AT_PLUS_A", "MMD_AT_PLUS_A"),
        ("COLAMD", "COLAMD"),
    ]:
        try:
            lu = scipy.sparse.linalg.splu(
                A_csc,
                permc_spec=permc_spec,
                options={"SymmetricMode": False},
            )
            l_nnz = lu.L.nnz
            u_nnz = lu.U.nnz
            # L and U each store the diagonal, so L.nnz + U.nnz counts it
            # twice; subtract n once.
            factor_nnz = l_nnz + u_nnz - n  # subtract diagonal counted twice

            results[ordering_name] = {
                "L_nnz": l_nnz,
                "U_nnz": u_nnz,
                "factor_nnz": factor_nnz,
                "fill_ratio": factor_nnz / max(original_nnz, 1),
                "fill_in": factor_nnz - original_nnz,
            }
        except Exception as e:
            results[ordering_name] = {"error": str(e)}

    return {
        "original_nnz": original_nnz,
        "orderings": results,
        # Smallest factor among the orderings that succeeded.
        "best_ordering": min(
            (k for k, v in results.items() if "error" not in v),
            key=lambda k: results[k]["factor_nnz"],
            default=None,
        ),
    }
+
+
+# ---------------------------------------------------------------------------
+# Matrix structure analysis
+# ---------------------------------------------------------------------------
+
+
def matrix_structure_analysis(A: sp.spmatrix) -> dict:
    """Analyze structural properties of the matrix.

    Reports size/nnz/density, bandwidth and profile, structural and
    numerical symmetry, connected components of the symmetrized pattern,
    diagonal dominance, and row/column degree statistics.

    Args:
        A: Square matrix in any scipy sparse format.

    Returns:
        Dict of the structural metrics listed above.
    """
    A_csc = sp.csc_matrix(A)
    A_csr = sp.csr_matrix(A)
    n = A_csc.shape[0]

    # Bandwidth
    rows, cols = A_csc.nonzero()
    if len(rows) > 0:
        bandwidth = int(np.max(np.abs(rows - cols)))
        profile = int(np.sum(np.abs(rows - cols)))
    else:
        bandwidth = 0
        profile = 0

    # Degree distribution (treating matrix as adjacency matrix)
    row_nnz = np.diff(A_csr.indptr)
    col_nnz = np.diff(A_csc.indptr)

    # Symmetry check
    A_T = A_csc.T
    sym_diff = A_csc - A_T
    sym_diff.eliminate_zeros()
    is_structurally_symmetric = sym_diff.nnz == 0

    # Check for numerical symmetry
    if is_structurally_symmetric:
        # NOTE(review): assumes both CSC operands store entries in the same
        # (sorted) index order so .data arrays align entrywise — confirm.
        val_diff = np.max(np.abs(A_csc.data - A_T.tocsc().data)) if A_csc.nnz > 0 else 0
        is_numerically_symmetric = val_diff < 1e-10
    else:
        is_numerically_symmetric = False

    # Connected components (treating as undirected graph)
    # NOTE(review): relies on scipy.sparse re-exporting the csgraph
    # submodule as sp.csgraph — confirm on the pinned scipy version.
    A_sym_pattern = symmetrize_pattern(A_csc)
    n_components, labels = sp.csgraph.connected_components(A_sym_pattern, directed=False)

    component_sizes = Counter(labels.tolist())
    component_size_list = sorted(component_sizes.values(), reverse=True)

    # Diagonal dominance check (weak dominance: |a_ii| >= sum off-diag)
    diag = np.abs(A_csc.diagonal())
    row_sums = np.array(np.abs(A_csr).sum(axis=1)).ravel()
    off_diag_sums = row_sums - diag
    diag_dominant_rows = int(np.sum(diag >= off_diag_sums))

    return {
        "size": n,
        "nnz": A_csc.nnz,
        "density_pct": A_csc.nnz / (n * n) * 100 if n > 0 else 0,
        "bandwidth": bandwidth,
        "profile": profile,
        "is_structurally_symmetric": is_structurally_symmetric,
        "is_numerically_symmetric": is_numerically_symmetric,
        "connected_components": n_components,
        "component_sizes": component_size_list[:10],  # Top 10
        "diagonal_dominance": {
            "dominant_rows": diag_dominant_rows,
            "total_rows": n,
            "pct": diag_dominant_rows / n * 100 if n > 0 else 0,
        },
        "degree_stats": {
            "row_min": int(np.min(row_nnz)) if n > 0 else 0,
            "row_max": int(np.max(row_nnz)) if n > 0 else 0,
            "row_mean": float(np.mean(row_nnz)) if n > 0 else 0,
            "col_min": int(np.min(col_nnz)) if n > 0 else 0,
            "col_max": int(np.max(col_nnz)) if n > 0 else 0,
            "col_mean": float(np.mean(col_nnz)) if n > 0 else 0,
        },
    }
+
+
+# ---------------------------------------------------------------------------
+# RCM ordering analysis
+# ---------------------------------------------------------------------------
+
+
def rcm_analysis(A: sp.spmatrix) -> dict:
    """Measure the bandwidth reduction achievable with RCM reordering.

    Applies Reverse Cuthill-McKee to the symmetrized pattern of ``A`` and
    compares the bandwidth before and after the permutation.

    Returns:
        Dict with original/RCM bandwidths and the reduction percentage,
        or an ``error`` entry when the reordering fails.
    """
    pattern = symmetrize_pattern(A)

    def max_band(M) -> int:
        # Largest |row - col| over all stored entries (0 for empty M).
        r, c = M.nonzero()
        return int(np.max(np.abs(r - c))) if len(r) > 0 else 0

    try:
        perm = sp.csgraph.reverse_cuthill_mckee(pattern, symmetric_mode=True)
        reordered = pattern[perm][:, perm]

        bw_before = max_band(pattern)
        bw_after = max_band(reordered)

        return {
            "bandwidth_original": bw_before,
            "bandwidth_rcm": bw_after,
            "bandwidth_reduction_pct": (1 - bw_after / max(bw_before, 1)) * 100,
            "permutation_available": True,
        }
    except Exception as e:
        return {"error": str(e), "permutation_available": False}
+
+
+# ---------------------------------------------------------------------------
+# Device scatter pattern analysis
+# ---------------------------------------------------------------------------
+
+
def analyze_device_scatter(engine) -> dict:
    """Analyze device-to-matrix scatter patterns for assembly parallelism.

    Examines the stamp index mappings to determine:
    - How many matrix positions are written by multiple devices (conflicts)
    - Maximum fan-in to any single position
    - Independence structure between device evaluations

    Args:
        engine: Circuit engine.  NOTE(review): calls the private
            ``_build_transient_setup(backend=..., use_dense=...)`` API —
            confirm the signature stays stable.

    Returns:
        Dict with per-model stamp statistics, global scatter-conflict
        metrics, and a device conflict-graph summary.
    """
    setup = engine._build_transient_setup(backend="cpu", use_dense=True)
    static_inputs_cache = setup["static_inputs_cache"]
    n_unknowns = setup["n_unknowns"]

    model_info = {}
    # Global position → set of unique (model_type, device_idx) writers
    global_position_writers: dict[tuple[int, int], set[tuple[str, int]]] = {}

    for model_type, (voltage_indices, stamp_indices, *_rest) in static_inputs_cache.items():
        jac_rows = np.asarray(stamp_indices["jac_row_indices"])
        jac_cols = np.asarray(stamp_indices["jac_col_indices"])
        res_indices = np.asarray(stamp_indices["res_indices"])

        n_devices = jac_rows.shape[0]
        n_jac_entries = jac_rows.shape[1]
        n_residuals = res_indices.shape[1]

        # Count unique positions per device
        positions_per_device = []
        for dev_idx in range(n_devices):
            # Negative indices mark unused/grounded stamp slots — skipped.
            valid = (jac_rows[dev_idx] >= 0) & (jac_cols[dev_idx] >= 0)
            unique_pos = set()
            for j in range(n_jac_entries):
                if valid[j]:
                    pos = (int(jac_rows[dev_idx, j]), int(jac_cols[dev_idx, j]))
                    unique_pos.add(pos)
                    if pos not in global_position_writers:
                        global_position_writers[pos] = set()
                    global_position_writers[pos].add((model_type, dev_idx))
            positions_per_device.append(len(unique_pos))

        # Count touched nodes per device (for residual fan-out)
        nodes_per_device = []
        for dev_idx in range(n_devices):
            valid_nodes = set()
            for r in range(n_residuals):
                idx = int(res_indices[dev_idx, r])
                if idx >= 0:
                    valid_nodes.add(idx)
            nodes_per_device.append(len(valid_nodes))

        model_info[model_type] = {
            "n_devices": n_devices,
            "jac_entries_per_device": n_jac_entries,
            "residuals_per_device": n_residuals,
            "unique_positions_per_device": {
                "min": min(positions_per_device) if positions_per_device else 0,
                "max": max(positions_per_device) if positions_per_device else 0,
                "mean": float(np.mean(positions_per_device)) if positions_per_device else 0,
            },
            "nodes_per_device": {
                "min": min(nodes_per_device) if nodes_per_device else 0,
                "max": max(nodes_per_device) if nodes_per_device else 0,
                "mean": float(np.mean(nodes_per_device)) if nodes_per_device else 0,
            },
        }

    # Analyze scatter conflicts (unique devices per position)
    fan_in_counts = [len(writers) for writers in global_position_writers.values()]
    fan_in_counter = Counter(fan_in_counts)
    conflict_positions = sum(1 for c in fan_in_counts if c > 1)

    # Build device conflict graph: two devices conflict if they write to
    # the same matrix position
    n_total_devices = sum(info["n_devices"] for info in model_info.values())
    conflict_edges = 0
    for writers in global_position_writers.values():
        n_writers = len(writers)
        if n_writers > 1:
            # Each pair of co-writers contributes one conflict edge.
            conflict_edges += n_writers * (n_writers - 1) // 2

    return {
        "n_unknowns": n_unknowns,
        "total_devices": n_total_devices,
        "model_types": model_info,
        "scatter_conflicts": {
            "total_positions": len(global_position_writers),
            "conflict_positions": conflict_positions,
            "conflict_pct": conflict_positions / max(len(global_position_writers), 1) * 100,
            "max_fan_in": max(fan_in_counts) if fan_in_counts else 0,
            "fan_in_distribution": {str(k): v for k, v in sorted(fan_in_counter.items())},
        },
        "device_conflict_graph": {
            "n_nodes": n_total_devices,
            "n_edges": conflict_edges,
            "note": "Edges connect devices that write to the same matrix position",
        },
    }
+
+
+# ---------------------------------------------------------------------------
+# Pattern stability check
+# ---------------------------------------------------------------------------
+
+
def check_pattern_stability(matrices: list[sp.spmatrix]) -> dict:
    """Verify that all captured Jacobians share one sparsity pattern.

    A fixed pattern means symbolic factorization can be compiled once and
    reused for every NR iteration.  When the pattern is stable, value
    statistics quantify how much the numeric entries move across samples.

    Args:
        matrices: Sampled Jacobians, one per captured NR system.

    Returns:
        Dict with stability verdict, first mismatch details (if any),
        value-variation statistics, and a human-readable note.
    """
    if len(matrices) < 2:
        return {
            "is_fixed": True,
            "n_samples": len(matrices),
            "note": "Only one matrix available, cannot verify stability",
        }

    reference = sp.csc_matrix(matrices[0])
    reference_pattern = set(zip(*reference.nonzero()))

    mismatch = None
    for sample_idx in range(1, len(matrices)):
        pattern = set(zip(*sp.csc_matrix(matrices[sample_idx]).nonzero()))
        if pattern != reference_pattern:
            mismatch = {
                "index": sample_idx,
                "added_entries": len(pattern - reference_pattern),
                "removed_entries": len(reference_pattern - pattern),
            }
            break

    stable = mismatch is None

    value_stats = None
    if stable:
        # Matching patterns ⇒ the .data arrays line up entry-for-entry.
        # NOTE(review): assumes no matrix stores explicit zeros (nonzero()
        # skips them while .data keeps them) — confirm upstream capture.
        stacked = np.column_stack([sp.csc_matrix(M).data for M in matrices])
        spread = np.std(stacked, axis=1) / (np.abs(np.mean(stacked, axis=1)) + 1e-30)
        value_stats = {
            "mean_relative_variation": float(np.mean(spread)),
            "max_relative_variation": float(np.max(spread)),
            "median_relative_variation": float(np.median(spread)),
        }

    return {
        "is_fixed": stable,
        "n_samples": len(matrices),
        "first_mismatch": mismatch,
        "value_variation": value_stats,
        "note": (
            "Sparsity pattern is identical across all samples — symbolic "
            "factorization can be compiled once and reused for all NR iterations"
            if stable
            else "WARNING: Sparsity pattern changes between iterations"
        ),
    }
+
+
+# ---------------------------------------------------------------------------
+# Full analysis pipeline
+# ---------------------------------------------------------------------------
+
+
def analyze_matrix(
    A: sp.spmatrix,
    name: str = "",
    all_matrices: list[sp.spmatrix] | None = None,
) -> dict:
    """Run full parallelism analysis on a Jacobian matrix.

    Args:
        A: The Jacobian matrix (any sparse format)
        name: Circuit/benchmark name for labeling
        all_matrices: Optional list of matrices for pattern stability check

    Returns:
        Dict with all analysis results
    """
    A_csc = sp.csc_matrix(A, dtype=np.float64)
    n = A_csc.shape[0]

    print(f"Analyzing {n}x{n} matrix ({A_csc.nnz} nonzeros)...")

    # 1. Matrix structure
    print(" Matrix structure...")
    structure = matrix_structure_analysis(A_csc)

    # 2. Elimination tree on symmetrized pattern
    print(" Elimination tree...")
    A_sym = symmetrize_pattern(A_csc)
    parent = compute_etree(A_sym)
    levels = compute_level_sets(parent)
    etree_stats = compute_etree_stats(parent, levels)

    # 3. Supernodes
    print(" Supernodal detection...")
    supernodes = detect_supernodes(parent, A_sym)
    snode_stats = supernode_stats(supernodes)

    # 4. Fill-in analysis
    print(" Fill-in analysis (SuperLU)...")
    fill = fill_in_analysis(A_csc)

    # 5. RCM ordering
    print(" RCM ordering...")
    rcm = rcm_analysis(A_csc)

    # 6. Pattern stability
    stability = None
    if all_matrices and len(all_matrices) > 1:
        print(f" Pattern stability ({len(all_matrices)} samples)...")
        stability = check_pattern_stability(all_matrices)

    # Compute parallelism summary
    # "Work" at each level = width (number of independent columns)
    # Total sequential steps = height
    # Total work = n (all columns must be processed)
    # Parallelism efficiency = n / height (ideal speedup from parallelism)
    parallelism_efficiency = n / max(etree_stats["height"], 1)

    analysis = {
        "name": name,
        "matrix": structure,
        # NOTE(review): an ndarray is not JSON-serializable; presumably the
        # caller strips this key before json.dump — confirm.
        "_etree_parent": parent,  # Full array, not serialized to JSON
        "elimination_tree": {
            **etree_stats,
            "parallelism_efficiency": parallelism_efficiency,
            "parent_array_sample": parent[: min(50, n)].tolist(),
            "note": (
                f"Height {etree_stats['height']} levels with max width "
                f"{etree_stats['max_parallelism']}. Columns at the same level "
                f"can be factored in parallel. Efficiency = n/height = "
                f"{parallelism_efficiency:.1f}x theoretical speedup."
            ),
        },
        "supernodes": snode_stats,
        "fill_in": fill,
        "rcm_ordering": rcm,
    }

    if stability is not None:
        analysis["pattern_stability"] = stability

    return analysis
+
+
+# ---------------------------------------------------------------------------
+# Device eval branch analysis
+# ---------------------------------------------------------------------------
+
+
def analyze_eval_branches(engine) -> dict:
    """Analyze jnp.where branches in compiled device eval functions.

    For each model type, checks the compiled model's parameter split to determine:
    - How many device configurations exist (e.g., NMOS vs PMOS)
    - Whether all eval branches are statically determinable at setup time
    - How much specialization is possible

    This does NOT require dumping/parsing generated code — it analyzes the
    actual shared_params, device_params, and device_cache arrays to determine
    how many unique device variants exist.

    Args:
        engine: Circuit engine exposing ``_compiled_models`` (a mapping of
            model type -> compiled-model dict).

    Returns:
        Mapping of model type -> specialization statistics.
    """
    result = {}

    for model_type, compiled in engine._compiled_models.items():
        if "shared_params" not in compiled:
            continue

        # Fix: the original named this local `sp`, shadowing the
        # module-level `import scipy.sparse as sp` alias.
        shp = np.asarray(compiled["shared_params"])
        dp = np.asarray(compiled["device_params"])
        sc = np.asarray(compiled.get("shared_cache", np.array([])))
        dc = np.asarray(compiled.get("device_cache", np.empty((dp.shape[0], 0))))
        vp = np.asarray(compiled.get("voltage_positions_in_varying", np.array([], dtype=int)))

        n_devices = dp.shape[0]
        n_varying = dp.shape[1] if dp.ndim > 1 else 0
        n_voltages = len(vp)
        n_static_varying = n_varying - n_voltages

        # Identify non-voltage varying param columns
        voltage_cols = set(vp.tolist()) if len(vp) > 0 else set()
        static_cols = sorted(set(range(n_varying)) - voltage_cols)

        # Count unique device configurations (static params only)
        if static_cols and n_devices > 1:
            static_dp = dp[:, static_cols]
            unique_configs, config_counts = np.unique(
                static_dp, axis=0, return_counts=True
            )
            n_unique_configs = len(unique_configs)
            config_sizes = config_counts.tolist()
        elif n_devices > 1:
            # No static varying params — all devices identical
            n_unique_configs = 1
            config_sizes = [n_devices]
        else:
            n_unique_configs = 1
            config_sizes = [1]

        # Check device_cache uniformity
        n_cache_cols = dc.shape[1] if dc.ndim > 1 else 0
        if n_cache_cols > 0 and n_devices > 1:
            # Columns identical across all devices are "uniform".
            cache_uniform = int(np.sum(np.all(dc == dc[0:1, :], axis=0)))
            cache_varying = n_cache_cols - cache_uniform

            # Count unique cache configurations
            n_unique_cache = len(np.unique(dc, axis=0))
        else:
            cache_uniform = n_cache_cols
            cache_varying = 0
            n_unique_cache = 1

        # Get param names for the varying columns if available
        param_names = compiled.get("param_names", [])
        param_kinds = compiled.get("param_kinds", [])
        varying_indices = compiled.get("varying_indices", [])

        varying_param_info = []
        for col_idx, orig_idx in enumerate(varying_indices):
            if col_idx in voltage_cols:
                continue
            name = param_names[orig_idx] if orig_idx < len(param_names) else f"param_{orig_idx}"
            kind = param_kinds[orig_idx] if orig_idx < len(param_kinds) else "unknown"
            if n_devices > 1:
                vals = dp[:, col_idx]
                unique_vals = np.unique(vals)
                varying_param_info.append(
                    {
                        "name": name,
                        "kind": kind,
                        "n_unique": len(unique_vals),
                        # Only list values when the set is small.
                        "values": unique_vals.tolist()
                        if len(unique_vals) <= 10
                        else f"{len(unique_vals)} values",
                    }
                )

        result[model_type] = {
            "n_devices": n_devices,
            "n_shared_params": len(shp),
            "n_varying_params": n_varying,
            "n_voltage_params": n_voltages,
            "n_static_varying_params": n_static_varying,
            "n_shared_cache": len(sc) if sc.ndim == 1 else (sc.shape[1] if sc.ndim > 1 else 0),
            "n_device_cache_cols": n_cache_cols,
            "cache_uniform_cols": cache_uniform,
            "cache_varying_cols": cache_varying,
            "n_unique_param_configs": n_unique_configs,
            "n_unique_cache_configs": n_unique_cache,
            "config_sizes": config_sizes,
            "varying_static_params": varying_param_info,
            "specialization_note": (
                f"All {n_devices} devices can be grouped into {n_unique_configs} "
                f"specialized eval variant(s). Branches conditioned on shared_params "
                f"({len(shp)} params) and device configuration ({n_static_varying} "
                f"static varying params) can be resolved at compile time, eliminating "
                f"jnp.where overhead for straight-line GPU kernels."
            ),
        }

    return result
+
+
+# ---------------------------------------------------------------------------
+# Benchmark mode: run simulation and analyze
+# ---------------------------------------------------------------------------
+
+
+def analyze_benchmark(
+ benchmark_name: str,
+ max_captures: int = 20,
+ t_stop_override: float | None = None,
+) -> dict:
+ """Run a benchmark simulation, capture matrices, and analyze parallelism.
+
+ Also captures device scatter pattern information.
+ """
+ import jax
+
+ from vajax.analysis import CircuitEngine
+ from vajax.benchmarks.registry import get_benchmark
+
+ info = get_benchmark(benchmark_name)
+ assert info is not None, f"Benchmark '{benchmark_name}' not found"
+
+ engine = CircuitEngine(info.sim_path)
+ engine.parse()
+
+ use_sparse = info.is_large
+ dt = info.dt
+    # Use override, or run just enough steps to capture max_captures matrices.
+    # Each timestep yields ~5 NR systems, so max_captures*2 timesteps is ample.
+    if t_stop_override is not None:
+        t_stop = t_stop_override
+    else:
+        # Short simulation: enough for max_captures NR systems
+        min_steps = max_captures * 2  # wide margin (~5 NR iters/step, capture early ones)
+        t_stop = min(dt * min_steps, info.t_stop)
+
+ print(f"Benchmark: {benchmark_name}")
+ print(f" Nodes: {engine.num_nodes}, Devices: {len(engine.devices)}")
+ print(f" Solver: {'sparse' if use_sparse else 'dense'}")
+ print(f" Transient: t_stop={t_stop:.2e}s, dt={dt:.2e}s")
+
+ # Suppress step-by-step logging during simulation
+ import logging
+
+ logging.getLogger("vajax").setLevel(logging.WARNING)
+
+ # --- Device scatter analysis (before running simulation) ---
+ print("\nAnalyzing device scatter patterns...")
+ device_scatter = analyze_device_scatter(engine)
+
+ # --- Capture matrices via monkey-patching ---
+ import vajax.analysis.solver_factories as sf
+
+ captured_systems: list[tuple[np.ndarray, np.ndarray]] = []
+ csr_info: dict = {}
+ capture_count = [0]
+
+ def capture_cb(J_or_data: jax.Array, f: jax.Array):
+ if capture_count[0] >= max_captures:
+ return
+ captured_systems.append((np.asarray(J_or_data).copy(), np.asarray(f).copy()))
+ capture_count[0] += 1
+
+ original_make_nr = sf._make_nr_solver_common
+
+ def patched_nr(*, linear_solve_fn, **kwargs):
+ def instrumented(J_or_data, f):
+ jax.debug.callback(capture_cb, J_or_data, f)
+ return linear_solve_fn(J_or_data, f)
+
+ return original_make_nr(linear_solve_fn=instrumented, **kwargs)
+
+ sf._make_nr_solver_common = patched_nr
+
+ # Intercept sparse factories for CSR structure
+ for factory_name in ("make_umfpack_ffi_full_mna_solver", "make_spineax_full_mna_solver"):
+ original = getattr(sf, factory_name)
+
+ def make_patched(orig):
+ def patched(*args, **kwargs):
+ n_nodes = args[1] if len(args) > 1 else kwargs.get("n_nodes")
+ n_vsources = args[2] if len(args) > 2 else kwargs.get("n_vsources")
+ bcsr_indptr = kwargs.get("bcsr_indptr")
+ bcsr_indices = kwargs.get("bcsr_indices")
+ if bcsr_indptr is None and len(args) > 4:
+ bcsr_indptr = args[4]
+ if bcsr_indices is None and len(args) > 5:
+ bcsr_indices = args[5]
+ if bcsr_indptr is not None and n_nodes is not None:
+ n_aug = n_nodes - 1 + n_vsources
+ csr_info["indptr"] = np.asarray(bcsr_indptr).copy()
+ csr_info["indices"] = np.asarray(bcsr_indices).copy()
+ csr_info["shape"] = (n_aug, n_aug)
+ return orig(*args, **kwargs)
+
+ return patched
+
+ patched = make_patched(original)
+ setattr(sf, factory_name, patched)
+
+ import vajax.analysis.transient.full_mna as _full_mna
+
+ setattr(_full_mna, factory_name, patched)
+
+ # --- Run simulation ---
+ print("\nRunning simulation...")
+ engine.prepare(t_stop=t_stop, dt=dt, use_sparse=use_sparse)
+ result = engine.run_transient()
+
+ convergence = result.stats.get("convergence_rate", 0) * 100
+ print(f" Steps: {result.num_steps}, convergence: {convergence:.0f}%")
+ print(f" Captured {len(captured_systems)} linear systems")
+
+ if not captured_systems:
+ print("ERROR: No systems captured!", file=sys.stderr)
+ sys.exit(1)
+
+ # --- Build sparse matrices from captured data ---
+ matrices: list[sp.spmatrix] = []
+ for J_or_data, f in captured_systems:
+ if use_sparse and "indptr" in csr_info:
+ mat = sp.csr_matrix(
+ (J_or_data, csr_info["indices"], csr_info["indptr"]),
+ shape=csr_info["shape"],
+ )
+ else:
+ mat = sp.csc_matrix(J_or_data)
+ matrices.append(mat)
+
+ # --- Analyze ---
+ print("\nAnalyzing first captured matrix...")
+ analysis = analyze_matrix(matrices[0], name=benchmark_name, all_matrices=matrices)
+
+ # Add device scatter info
+ analysis["device_parallelism"] = device_scatter
+
+ # Add circuit info
+ analysis["circuit"] = {
+ "name": benchmark_name,
+ "n_external_nodes": engine.num_nodes,
+ "n_devices": len(engine.devices),
+ "n_unknowns": device_scatter["n_unknowns"],
+ "simulation": {
+ "t_stop": t_stop,
+ "dt": dt,
+ "steps": result.num_steps,
+ "convergence_pct": convergence,
+ },
+ }
+
+ # Eval branch specialization analysis
+ print("\nAnalyzing eval function branches...")
+ branch_analysis = analyze_eval_branches(engine)
+ analysis["eval_specialization"] = branch_analysis
+
+ # Add compilation note for IREE
+ analysis["iree_notes"] = {
+ "pattern_is_fixed": analysis.get("pattern_stability", {}).get("is_fixed", True),
+ "same_pattern_every_nr_iteration": True,
+ "values_change_every_iteration": True,
+ "typical_nr_iterations_per_step": "3-8",
+ "typical_timesteps": f"{result.num_steps}",
+ "total_solves": len(captured_systems),
+ "recommendation": (
+ "The sparsity pattern is determined at circuit parse time and never changes. "
+ "Symbolic factorization (ordering, elimination tree, memory allocation) can "
+ "be compiled once. Only numerical factorization needs to run per NR iteration. "
+ f"For this circuit: {result.num_steps} timesteps x ~5 NR iters = "
+ f"~{result.num_steps * 5} factorizations with identical structure."
+ ),
+ }
+
+ return analysis
+
+
+# ---------------------------------------------------------------------------
+# Output
+# ---------------------------------------------------------------------------
+
+
+def write_analysis(analysis: dict, output_dir: Path):
+ """Write analysis results to output directory."""
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # JSON output (machine-readable)
+ json_path = output_dir / "parallelism_analysis.json"
+
+ # Remove internal data not suitable for JSON
+ analysis_json = {k: v for k, v in analysis.items() if not k.startswith("_")}
+ analysis_json = json.loads(json.dumps(analysis_json, default=str))
+ widths = analysis_json.get("elimination_tree", {}).get("level_widths", [])
+ if len(widths) > 100:
+ analysis_json["elimination_tree"]["level_widths_truncated"] = (
+ widths[:50] + ["..."] + widths[-50:]
+ )
+ del analysis_json["elimination_tree"]["level_widths"]
+
+ with open(json_path, "w") as f:
+ json.dump(analysis_json, f, indent=2)
+ print(f" JSON: {json_path}")
+
+ # Human-readable summary
+ summary_path = output_dir / "parallelism_summary.txt"
+ with open(summary_path, "w") as f:
+ name = analysis.get("name", "unknown")
+ f.write(f"{'=' * 70}\n")
+ f.write(f"Parallelism Analysis: {name}\n")
+ f.write(f"{'=' * 70}\n\n")
+
+ mat = analysis["matrix"]
+ f.write(
+ f"Matrix: {mat['size']}x{mat['size']}, {mat['nnz']} nonzeros ({mat['density_pct']:.4f}%)\n"
+ )
+ f.write(f"Bandwidth: {mat['bandwidth']}, Symmetric: {mat['is_structurally_symmetric']}\n")
+ f.write(f"Connected components: {mat['connected_components']}\n")
+ deg = mat["degree_stats"]
+ f.write(
+ f"Row degree: min={deg['row_min']}, max={deg['row_max']}, mean={deg['row_mean']:.1f}\n"
+ )
+ f.write(f"Diagonal dominance: {mat['diagonal_dominance']['pct']:.1f}% of rows\n\n")
+
+ et = analysis["elimination_tree"]
+ f.write("--- Elimination Tree ---\n")
+ f.write(f"Height (sequential steps): {et['height']}\n")
+ f.write(f"Leaves: {et['n_leaves']}\n")
+ f.write(f"Max parallelism (widest level): {et['max_parallelism']}\n")
+ f.write(f"Avg parallelism: {et['avg_parallelism']:.1f}\n")
+ f.write(f"Parallelism efficiency (n/height): {et['parallelism_efficiency']:.1f}x\n")
+ st = et["subtree_size_stats"]
+ f.write(f"Subtree sizes: min={st['min']}, max={st['max']}, median={st['median']:.0f}\n\n")
+
+ sn = analysis["supernodes"]
+ f.write("--- Supernodes ---\n")
+ f.write(f"Count: {sn['count']} supernodes\n")
+ f.write(f"Largest: {sn['largest']} columns\n")
+ f.write(f"Mean size: {sn['mean_size']:.1f}\n")
+ f.write(f"Size distribution: {sn['size_histogram']}\n\n")
+
+ fi = analysis["fill_in"]
+ f.write("--- Fill-in (LU factorization) ---\n")
+ f.write(f"Original nnz: {fi['original_nnz']}\n")
+ for order_name, order_data in fi["orderings"].items():
+ if "error" not in order_data:
+ f.write(
+ f" {order_name}: factor_nnz={order_data['factor_nnz']}, "
+ f"fill_ratio={order_data['fill_ratio']:.2f}x, "
+ f"fill_in=+{order_data['fill_in']}\n"
+ )
+ if fi["best_ordering"]:
+ f.write(f"Best ordering: {fi['best_ordering']}\n")
+ f.write("\n")
+
+ rcm = analysis.get("rcm_ordering", {})
+ if rcm.get("permutation_available"):
+ f.write("--- RCM Ordering ---\n")
+ f.write(f"Bandwidth: {rcm['bandwidth_original']} -> {rcm['bandwidth_rcm']} ")
+ f.write(f"({rcm['bandwidth_reduction_pct']:.1f}% reduction)\n\n")
+
+ ps = analysis.get("pattern_stability")
+ if ps:
+ f.write("--- Pattern Stability ---\n")
+ f.write(f"Fixed pattern: {ps['is_fixed']} ({ps['n_samples']} samples)\n")
+ if ps.get("value_variation"):
+ vv = ps["value_variation"]
+ f.write(f"Value variation: mean_rel={vv['mean_relative_variation']:.4f}, ")
+ f.write(f"max_rel={vv['max_relative_variation']:.4f}\n")
+ f.write(f"{ps['note']}\n\n")
+
+ dp = analysis.get("device_parallelism")
+ if dp:
+ f.write("--- Device Evaluation Parallelism ---\n")
+ f.write(f"Total devices: {dp['total_devices']}\n")
+ for mt, mi in dp["model_types"].items():
+ f.write(f" {mt}: {mi['n_devices']} devices, ")
+ f.write(f"{mi['jac_entries_per_device']} Jacobian entries/device, ")
+ f.write(f"{mi['nodes_per_device']['mean']:.0f} nodes/device\n")
+ sc = dp["scatter_conflicts"]
+ f.write(
+ f"Scatter conflicts: {sc['conflict_positions']}/{sc['total_positions']} positions "
+ )
+ f.write(f"({sc['conflict_pct']:.1f}%), max fan-in={sc['max_fan_in']}\n")
+ f.write(f"Fan-in distribution: {sc['fan_in_distribution']}\n\n")
+
+ es = analysis.get("eval_specialization")
+ if es:
+ f.write("--- Eval Branch Specialization ---\n")
+ for mt, info in es.items():
+ f.write(f" {mt}: {info['n_devices']} devices\n")
+ f.write(f" Params: {info['n_shared_params']} shared, ")
+ f.write(f"{info['n_voltage_params']} voltage, ")
+ f.write(f"{info['n_static_varying_params']} static-varying\n")
+ f.write(f" Cache: {info['n_shared_cache']} shared, ")
+ f.write(f"{info['cache_varying_cols']} device-varying\n")
+ f.write(f" Unique device configs: {info['n_unique_param_configs']}")
+ f.write(f" (sizes: {info['config_sizes']})\n")
+ if info.get("varying_static_params"):
+ for vp in info["varying_static_params"]:
+ f.write(
+ f" {vp['name']} ({vp['kind']}): {vp['n_unique']} unique values\n"
+ )
+ f.write(f" {info['specialization_note']}\n")
+ f.write("\n")
+
+ notes = analysis.get("iree_notes")
+ if notes:
+ f.write("--- IREE/Baspacho Notes ---\n")
+ f.write(f"{notes['recommendation']}\n")
+
+ print(f" Summary: {summary_path}")
+
+ # Write full elimination tree parent array (useful for solver development)
+ if "_etree_parent" in analysis:
+ etree_path = output_dir / "etree_parent.npy"
+ np.save(etree_path, analysis["_etree_parent"])
+ print(f" Etree parent: {etree_path} ({len(analysis['_etree_parent'])} nodes)")
+
+ # Write level-set widths as CSV (for plotting)
+ widths = analysis.get("elimination_tree", {}).get("level_widths", [])
+ if widths:
+ widths_path = output_dir / "level_set_widths.csv"
+ with open(widths_path, "w") as f:
+ f.write("level,width\n")
+ for i, w in enumerate(widths):
+ f.write(f"{i},{w}\n")
+ print(f" Level widths: {widths_path}")
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Analyze parallelism opportunities in VAJAX simulation matrices"
+ )
+ parser.add_argument(
+ "benchmark",
+ nargs="?",
+ help="Benchmark name (e.g. ring, c6288, graetz)",
+ )
+ parser.add_argument(
+ "--from-mtx",
+ type=Path,
+ nargs="+",
+ help="Analyze existing Matrix Market file(s)",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=Path,
+ default=None,
+ help="Output directory (default: /tmp/claude/_parallelism)",
+ )
+ parser.add_argument(
+ "--max-captures",
+ type=int,
+ default=20,
+ help="Max NR systems to capture for pattern stability check",
+ )
+ parser.add_argument(
+ "--t-stop",
+ type=float,
+ default=None,
+ help="Override transient stop time (default: auto-short for analysis)",
+ )
+ args = parser.parse_args()
+
+ if args.from_mtx:
+ # Load from Matrix Market files
+ matrices = []
+ for path in args.from_mtx:
+ print(f"Loading {path}...")
+ matrices.append(scipy.io.mmread(path))
+
+ name = args.from_mtx[0].stem.replace("jacobian_", "")
+ analysis = analyze_matrix(matrices[0], name=name, all_matrices=matrices)
+
+ out_dir = args.output_dir or Path(f"/tmp/claude/{name}_parallelism")
+ write_analysis(analysis, out_dir)
+
+ elif args.benchmark:
+ analysis = analyze_benchmark(
+ args.benchmark,
+ max_captures=args.max_captures,
+ t_stop_override=args.t_stop,
+ )
+ out_dir = args.output_dir or Path(f"/tmp/claude/{args.benchmark}_parallelism")
+ write_analysis(analysis, out_dir)
+
+ else:
+ parser.print_help()
+ sys.exit(1)
+
+ print("\nDone.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/check_constant_folding.py b/scripts/check_constant_folding.py
new file mode 100644
index 00000000..52053657
--- /dev/null
+++ b/scripts/check_constant_folding.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["jax", "jaxlib"]
+# ///
+"""Check whether jnp.where constant folding works at jaxpr and HLO levels.
+
+Quick diagnostic to verify that inlining shared params as Python literals
+actually eliminates jnp.where branches in the compiled XLA program.
+"""
+
+import jax
+import jax.numpy as jnp
+
+
+def test_basic_constant_folding():
+ """Test: does jnp.where with a Python bool constant-fold?"""
+ print("=" * 60)
+ print("Test 1: jnp.where with Python bool literal")
+ print("=" * 60)
+
+ def f_const(x):
+ # Condition is a Python bool - should constant-fold
+ cond = True
+ return jnp.where(cond, x * 2, x * 3)
+
+ def f_traced(x, flag):
+ # Condition is a traced value - cannot constant-fold
+ return jnp.where(flag > 0.0, x * 2, x * 3)
+
+ x = jnp.ones(10)
+
+ jaxpr_const = jax.make_jaxpr(f_const)(x)
+ jaxpr_traced = jax.make_jaxpr(f_traced)(x, 1.0)
+
+ print(f"\nConstant cond jaxpr ({len(jaxpr_const.eqns)} ops):")
+ print(jaxpr_const)
+ print(f"\nTraced cond jaxpr ({len(jaxpr_traced.eqns)} ops):")
+ print(jaxpr_traced)
+
+ # Check HLO
+ lowered_const = jax.jit(f_const).lower(x)
+ lowered_traced = jax.jit(f_traced).lower(x, 1.0)
+ hlo_const = lowered_const.as_text()
+ hlo_traced = lowered_traced.as_text()
+
+ select_const = hlo_const.count("select")
+ select_traced = hlo_traced.count("select")
+ print(f"\nHLO select ops: constant={select_const}, traced={select_traced}")
+
+
+def test_constant_through_jnp_ops():
+ """Test: does constant folding survive through jnp operations?"""
+ print("\n" + "=" * 60)
+ print("Test 2: Constant folding through jnp operations")
+ print("=" * 60)
+
+ def f_inlined(x):
+ # Simulate what our specialization does:
+ # shared param inlined as literal, then used in jnp ops
+ v_param = 1.5e-6 # Was: shared_params[42]
+ v_computed = jnp.exp(v_param) # jnp op on literal
+ cond = v_computed > 0.5 # comparison
+ return jnp.where(cond, x * 2, x * 3)
+
+ def f_array_lookup(x, shared_params):
+ # Original: shared_params array lookup
+ v_param = shared_params[42]
+ v_computed = jnp.exp(v_param)
+ cond = v_computed > 0.5
+ return jnp.where(cond, x * 2, x * 3)
+
+ x = jnp.ones(10)
+ shared = jnp.zeros(100)
+
+ jaxpr_inlined = jax.make_jaxpr(f_inlined)(x)
+ jaxpr_lookup = jax.make_jaxpr(f_array_lookup)(x, shared)
+
+ print(f"\nInlined literal jaxpr ({len(jaxpr_inlined.eqns)} ops):")
+ print(jaxpr_inlined)
+ print(f"\nArray lookup jaxpr ({len(jaxpr_lookup.eqns)} ops):")
+ print(jaxpr_lookup)
+
+ # Check HLO
+ lowered_inlined = jax.jit(f_inlined).lower(x)
+ lowered_lookup = jax.jit(f_array_lookup).lower(x, shared)
+ hlo_inlined = lowered_inlined.as_text()
+ hlo_lookup = lowered_lookup.as_text()
+
+ select_inlined = hlo_inlined.count("select")
+ select_lookup = hlo_lookup.count("select")
+ print(f"\nHLO select ops: inlined={select_inlined}, lookup={select_lookup}")
+
+
+def test_python_float_vs_jnp():
+ """Test: Python float literal vs jnp operation - what does JAX trace?"""
+ print("\n" + "=" * 60)
+ print("Test 3: Python float arithmetic vs jnp arithmetic")
+ print("=" * 60)
+
+ def f_python_arith(x):
+ # Pure Python: should constant-fold completely
+ a = 1.5e-6
+ b = a * 2.0 # Python multiplication
+ cond = b > 1e-6 # Python comparison -> True
+ return jnp.where(cond, x * 2, x * 3)
+
+ def f_jnp_arith(x):
+ # jnp operations: might NOT constant-fold in jaxpr
+ a = 1.5e-6
+ b = jnp.float64(a) * 2.0 # jnp multiplication
+ cond = b > 1e-6 # comparison on jnp result
+ return jnp.where(cond, x * 2, x * 3)
+
+ x = jnp.ones(10)
+
+ jaxpr_python = jax.make_jaxpr(f_python_arith)(x)
+ jaxpr_jnp = jax.make_jaxpr(f_jnp_arith)(x)
+
+ print(f"\nPython arith jaxpr ({len(jaxpr_python.eqns)} ops):")
+ print(jaxpr_python)
+ print(f"\njnp arith jaxpr ({len(jaxpr_jnp.eqns)} ops):")
+ print(jaxpr_jnp)
+
+ # Check HLO for both
+ lowered_python = jax.jit(f_python_arith).lower(x)
+ lowered_jnp = jax.jit(f_jnp_arith).lower(x)
+ hlo_python = lowered_python.as_text()
+ hlo_jnp = lowered_jnp.as_text()
+
+ select_python = hlo_python.count("select")
+ select_jnp = hlo_jnp.count("select")
+ print(f"\nHLO select ops: python_arith={select_python}, jnp_arith={select_jnp}")
+
+
+def test_generated_code_pattern():
+ """Test the ACTUAL pattern used in generated eval code.
+
+ The generated code does:
+ v123 = 1.5e-6 # inlined literal (was shared_params[42])
+ Then later uses it in jnp operations.
+
+ Key question: does assigning a Python float to a local variable,
+ then using it in jnp.where, get constant-folded?
+ """
+ print("\n" + "=" * 60)
+ print("Test 4: Actual generated code pattern (assign + jnp.where)")
+ print("=" * 60)
+
+ def f_generated_pattern(device_params):
+ # This mimics the actual generated eval code pattern
+ # Inlined shared params
+ v100 = 1.0 # TYPE = 1 (NMOS)
+ v101 = 0.0 # SWIGATE = 0
+ v102 = 1.5e-6 # TOX
+
+ # Device params (from vmap, traced)
+ v200 = device_params[0] # voltage
+ _v201 = device_params[1] # another voltage (unused, kept for array shape)
+
+ # Computation chain (mimics what OpenVAF generates)
+ v300 = jnp.exp(v102 * 1e6) # Uses inlined literal
+ v301 = v300 * v200 # Mixes with traced value
+
+ # Branch on static param
+ v400 = v100 > 0.5 # TYPE > 0.5 -> True for NMOS
+ result1 = jnp.where(v400, v301, -v301) # Should fold
+
+ # Branch on static param through jnp op
+ v401 = jnp.abs(v101) # jnp op on inlined literal
+ v402 = v401 > 0.5 # Should be False
+ result2 = jnp.where(v402, result1 * 2, result1 * 3) # Should fold
+
+ return result1 + result2
+
+ def f_array_pattern(device_params, shared_params):
+ # Original pattern: array lookups (all traced)
+ v100 = shared_params[0]
+ v101 = shared_params[1]
+ v102 = shared_params[2]
+
+ v200 = device_params[0]
+ _v201 = device_params[1] # noqa: F841
+
+ v300 = jnp.exp(v102 * 1e6)
+ v301 = v300 * v200
+
+ v400 = v100 > 0.5
+ result1 = jnp.where(v400, v301, -v301)
+
+ v401 = jnp.abs(v101)
+ v402 = v401 > 0.5
+ result2 = jnp.where(v402, result1 * 2, result1 * 3)
+
+ return result1 + result2
+
+ dp = jnp.array([0.5, 0.3])
+ sp = jnp.array([1.0, 0.0, 1.5e-6])
+
+ jaxpr_gen = jax.make_jaxpr(f_generated_pattern)(dp)
+ jaxpr_arr = jax.make_jaxpr(f_array_pattern)(dp, sp)
+
+ print(f"\nInlined pattern jaxpr ({len(jaxpr_gen.eqns)} ops):")
+ print(jaxpr_gen)
+ print(f"\nArray pattern jaxpr ({len(jaxpr_arr.eqns)} ops):")
+ print(jaxpr_arr)
+
+ # Check HLO
+ lowered_gen = jax.jit(f_generated_pattern).lower(dp)
+ lowered_arr = jax.jit(f_array_pattern).lower(dp, sp)
+ hlo_gen = lowered_gen.as_text()
+ hlo_arr = lowered_arr.as_text()
+
+ select_gen = hlo_gen.count("select")
+ select_arr = hlo_arr.count("select")
+ print(f"\nHLO select ops: inlined={select_gen}, array={select_arr}")
+
+ # Also count total HLO ops
+ print(f"HLO lines: inlined={len(hlo_gen.splitlines())}, array={len(hlo_arr.splitlines())}")
+
+
+def test_vmap_interaction():
+ """Test: does constant folding survive vmap?
+
+ This is crucial because we vmap the eval function over devices.
+ """
+ print("\n" + "=" * 60)
+ print("Test 5: Constant folding under vmap")
+ print("=" * 60)
+
+ def f_inlined(device_params):
+ v_type = 1.0 # Inlined: TYPE = NMOS
+ cond = v_type > 0.5
+ return jnp.where(cond, device_params[0] * 2, device_params[0] * 3)
+
+ def f_lookup(device_params, shared_params):
+ v_type = shared_params[0]
+ cond = v_type > 0.5
+ return jnp.where(cond, device_params[0] * 2, device_params[0] * 3)
+
+ # vmap over batch of devices
+ f_inlined_vmapped = jax.vmap(f_inlined)
+ f_lookup_vmapped = jax.vmap(f_lookup, in_axes=(0, None))
+
+ batch_dp = jnp.ones((4, 3))
+ sp = jnp.array([1.0])
+
+ jaxpr_inlined = jax.make_jaxpr(f_inlined_vmapped)(batch_dp)
+ jaxpr_lookup = jax.make_jaxpr(f_lookup_vmapped)(batch_dp, sp)
+
+ print(f"\nvmapped inlined jaxpr ({len(jaxpr_inlined.eqns)} ops):")
+ print(jaxpr_inlined)
+ print(f"\nvmapped lookup jaxpr ({len(jaxpr_lookup.eqns)} ops):")
+ print(jaxpr_lookup)
+
+ # HLO
+ lowered_inlined = jax.jit(f_inlined_vmapped).lower(batch_dp)
+ lowered_lookup = jax.jit(f_lookup_vmapped).lower(batch_dp, sp)
+ hlo_inlined = lowered_inlined.as_text()
+ hlo_lookup = lowered_lookup.as_text()
+
+ select_inlined = hlo_inlined.count("select")
+ select_lookup = hlo_lookup.count("select")
+ print(f"\nHLO select ops: inlined={select_inlined}, lookup={select_lookup}")
+ print(
+ f"HLO lines: inlined={len(hlo_inlined.splitlines())}, lookup={len(hlo_lookup.splitlines())}"
+ )
+
+
+if __name__ == "__main__":
+ print(f"JAX version: {jax.__version__}")
+ print(f"Platform: {jax.default_backend()}")
+    print(
+        f"x64 enabled: {jax.config.jax_enable_x64 if hasattr(jax.config, 'jax_enable_x64') else 'unknown'}"
+    )
+ print()
+
+ test_basic_constant_folding()
+ test_constant_through_jnp_ops()
+ test_python_float_vs_jnp()
+ test_generated_code_pattern()
+ test_vmap_interaction()
diff --git a/scripts/compare_vacask.py b/scripts/compare_vacask.py
index 20063930..a6941eb5 100644
--- a/scripts/compare_vacask.py
+++ b/scripts/compare_vacask.py
@@ -35,6 +35,7 @@
from vajax._logging import enable_performance_logging
enable_performance_logging(with_memory=True, with_perf_counter=True)
+
import re
import sys
import time
@@ -52,7 +53,6 @@
# Note: Set JAX_PLATFORMS=cpu before running for CPU-only mode
import jax
-import jax.numpy as jnp
# Precision is auto-configured by vajax import (imported above via logging)
# Metal/TPU use f32, CPU/CUDA use f64
@@ -64,104 +64,6 @@
)
from vajax.profiling import ProfileConfig
-
-def analyze_compiled_function(fn, args, name: str, output_dir: Optional[Path] = None):
- """Dump jaxpr and cost analysis for a JIT-compiled function.
-
- Args:
- fn: A JAX function (JIT-compiled or not)
- args: Example arguments to trace with
- name: Name for output files
- output_dir: Optional directory to save analysis files
- """
- print(f"\n{'=' * 70}")
- print(f"JAX Analysis: {name}")
- print(f"{'=' * 70}")
-
- # Lower the function to get HLO and cost analysis
- # For JIT-compiled functions, we use .lower() directly
- print(f"\n--- Lowering and compiling {name} ---")
- try:
- # If fn is already jitted, we can lower it directly
- # Otherwise wrap it in jit first
- if hasattr(fn, "lower"):
- lowered = fn.lower(*args)
- else:
- lowered = jax.jit(fn).lower(*args)
-
- # Get the HLO text (MLIR representation)
- hlo_text = lowered.as_text()
- hlo_lines = hlo_text.split("\n")
- print(f"HLO text: {len(hlo_lines)} lines")
-
- # Count operations in HLO
- op_counts: Dict[str, int] = {}
- for line in hlo_lines:
- # Extract operation names from MLIR-style ops like: %0 = stablehlo.add
- if "=" in line and "." in line:
- parts = line.split("=")
- if len(parts) >= 2:
- op_part = parts[1].strip().split()[0] if parts[1].strip() else ""
- if "." in op_part:
- op_name = op_part.split("(")[0] # Remove args
- op_counts[op_name] = op_counts.get(op_name, 0) + 1
-
- if op_counts:
- print(f"Top HLO ops: {dict(sorted(op_counts.items(), key=lambda x: -x[1])[:15])}")
-
- # Compile and get cost analysis
- compiled = lowered.compile()
- cost = compiled.cost_analysis()
- print("\n--- Cost Analysis ---")
- if cost:
- for i, device_cost in enumerate(cost):
- if device_cost and isinstance(device_cost, dict):
- print(f"Device {i}:")
- for key, val in device_cost.items():
- if isinstance(val, (int, float)):
- if val > 1e9:
- print(f" {key}: {val / 1e9:.2f}G")
- elif val > 1e6:
- print(f" {key}: {val / 1e6:.2f}M")
- elif val > 1e3:
- print(f" {key}: {val / 1e3:.2f}K")
- else:
- print(f" {key}: {val}")
- else:
- print(f" {key}: {val}")
- elif device_cost:
- print(f"Device {i}: {device_cost}")
- else:
- print("No cost analysis available (may not be supported on this backend)")
-
- # Save files if output_dir provided
- if output_dir:
- output_dir.mkdir(parents=True, exist_ok=True)
-
- # Save HLO text
- hlo_file = output_dir / f"{name}_hlo.txt"
- with open(hlo_file, "w") as f:
- f.write(hlo_text)
- print(f"\nHLO text saved to: {hlo_file}")
-
- # Try to get the jaxpr as well for the underlying computation
- try:
- # Create jaxpr from the unwrapped function if possible
- jaxpr_text = str(jax.make_jaxpr(fn)(*args))
- jaxpr_file = output_dir / f"{name}_jaxpr.txt"
- with open(jaxpr_file, "w") as f:
- f.write(jaxpr_text)
- print(f"JAXPR saved to: {jaxpr_file}")
- except Exception:
- pass # JIT functions may not produce useful jaxpr
-
- except Exception as e:
- import traceback
-
- print(f"Failed to analyze: {e}")
- traceback.print_exc()
-
-
# Note: Benchmark configurations are now auto-discovered from
# vajax.benchmarks.registry. The registry parses .sim files
# to extract dt, t_stop, and device types automatically.
@@ -451,41 +353,11 @@ def do_run():
)
startup_time = time.perf_counter() - startup_start
- # Run analysis on compiled scan function if requested
- if analyze and use_scan and hasattr(engine, "_cached_scan_fn"):
- print("\n Running JAX analysis...")
- # Get example inputs for the scan function
- # The scan function signature is: (V_init, Q_init, all_vsource, all_isource, device_arrays)
- # Must use total nodes (external + internal) from transient setup cache
- setup_cache = getattr(engine, "_transient_setup_cache", None)
- device_arrays = getattr(engine, "_device_arrays", None)
-
- if setup_cache is None or device_arrays is None:
- print(" Warning: transient setup cache not found - analysis skipped")
- else:
- n_total = setup_cache["n_total"]
- n_unknowns = setup_cache["n_unknowns"]
- n_vsources = len([d for d in engine.devices if d["model"] == "vsource"])
- n_isources = len([d for d in engine.devices if d["model"] == "isource"])
-
- # Create example arrays matching actual shapes
- V_init = jnp.zeros(n_total, dtype=jnp.float64)
- Q_init = jnp.zeros(n_unknowns, dtype=jnp.float64)
- all_vsource = jnp.zeros((num_steps, n_vsources), dtype=jnp.float64)
- all_isource = jnp.zeros((num_steps, n_isources), dtype=jnp.float64)
-
- # Determine output directory
- out_dir = analyze_output_dir or Path(f"/tmp/vajax-analysis/{config.name}")
-
- # Analyze the scan function
- analyze_compiled_function(
- engine._cached_scan_fn,
- (V_init, Q_init, all_vsource, all_isource, device_arrays),
- f"{config.name}_scan_simulation",
- out_dir,
- )
- elif analyze and use_scan:
- print("\n Warning: _cached_scan_fn not found - analysis skipped")
+ # Run analysis on compiled functions if requested
+ if analyze:
+ out_dir = analyze_output_dir or Path(f"/tmp/claude/jaxpr-analysis/{config.name}")
+ print(f"\n Dumping jaxpr/HLO analysis to {out_dir} ...")
+ engine.dump_jaxpr(out_dir)
# Timed run - print perf_counter for correlation with Perfetto traces
# prepare() already called above with same params, strategy is cached
@@ -496,13 +368,11 @@ def do_run():
print(
f"AFTER_RUN_TRANSIENT: {after_transient:.6f} (elapsed: {after_transient - start:.6f}s)"
)
- # Force completion of async JAX operations
- first_node = next(iter(result.voltages))
- _ = float(result.voltages[first_node][0])
+ # Results are numpy arrays (materialized by block_until_ready in full_mna)
end = time.perf_counter()
external_elapsed = end - start
print(
- f"TIMED_RUN_END: {end:.6f} (elapsed: {external_elapsed:.6f}s, sync took: {end - after_transient:.6f}s)"
+ f"TIMED_RUN_END: {end:.6f} (elapsed: {external_elapsed:.6f}s, extract took: {end - after_transient:.6f}s)"
)
# Use wall_time from stats (excludes trace saving overhead)
diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py
index db64ac4d..c313fbf2 100644
--- a/scripts/nsys_profile_target.py
+++ b/scripts/nsys_profile_target.py
@@ -1,39 +1,49 @@
#!/usr/bin/env python3
-"""Target script for nsys-jax profiling - runs circuit simulation.
-
-This script is designed to be wrapped by nsys-jax:
- nsys-jax -o profile.zip python scripts/nsys_profile_target.py [circuit] [timesteps]
-
-nsys-jax automatically handles:
-- XLA_FLAGS configuration for HLO metadata dumping
-- JAX_TRACEBACK_IN_LOCATIONS_LIMIT for stack traces
-- JAX_ENABLE_COMPILATION_CACHE=false for metadata collection
+"""Target script for nsys GPU profiling - runs circuit simulation.
Usage:
- python scripts/nsys_profile_target.py [circuit] [timesteps]
+ nsys profile -o profile uv run python scripts/nsys_profile_target.py ring 500
Arguments:
- circuit: One of rc, graetz, mul, ring (default: ring)
- timesteps: Number of timesteps to simulate (default: 50)
+ circuit: One of rc, graetz, mul, ring, c6288 (default: ring)
+ timesteps: Number of timesteps to simulate (default: 500)
-Example:
- nsys-jax -o /tmp/profile.zip python scripts/nsys_profile_target.py ring 100
+Use 500+ timesteps so JIT warmup overhead is <5% of total profile.
"""
import argparse
+import ctypes
+import logging
import sys
+import time
from pathlib import Path
import jax
sys.path.insert(0, ".")
-# Import vajax first to auto-configure precision based on backend
-from vajax.analysis import CircuitEngine
+# Enable INFO logging so solver selection messages are visible
+logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")
+
+
+def timed(label):
+ """Context manager that prints elapsed time for a phase."""
+
+ class Timer:
+ def __enter__(self):
+ self.t0 = time.perf_counter()
+ print(f"[{label}] starting...", flush=True)
+ return self
+
+ def __exit__(self, *exc):
+ dt = time.perf_counter() - self.t0
+ print(f"[{label}] done in {dt:.2f}s", flush=True)
+
+ return Timer()
def main():
- parser = argparse.ArgumentParser(description="nsys-jax profiling target for VAJAX")
+ parser = argparse.ArgumentParser(description="nsys profiling target for VAJAX")
parser.add_argument(
"circuit",
nargs="?",
@@ -45,8 +55,8 @@ def main():
"timesteps",
nargs="?",
type=int,
- default=50,
- help="Number of timesteps to simulate (default: 50)",
+ default=500,
+ help="Number of timesteps to simulate (default: 500)",
)
parser.add_argument(
"--backend",
@@ -61,10 +71,35 @@ def main():
)
args = parser.parse_args()
- print(f"JAX backend: {jax.default_backend()}")
- print(f"JAX devices: {jax.devices()}")
+ with timed("JAX init"):
+ print(f"JAX backend: {jax.default_backend()}")
+ print(f"JAX devices: {jax.devices()}")
print(f"Circuit: {args.circuit}")
print(f"Timesteps: {args.timesteps}")
+
+ # Explicit solver availability check
+ print()
+ print("=== Solver Availability ===")
+ with timed("solver imports"):
+ try:
+ from spineax.cudss.dense_baspacho_solver import is_available
+
+ print(" BaSpaCho dense import: OK")
+ print(f" BaSpaCho dense available: {is_available()}")
+ except ImportError as e:
+ print(f" BaSpaCho dense import: FAILED ({e})")
+ try:
+ from spineax.cudss.solver import CuDSSSolver # noqa: F401
+
+ print(" cuDSS sparse import: OK")
+ except ImportError as e:
+ print(f" cuDSS sparse import: FAILED ({e})")
+ try:
+ from spineax import baspacho_dense_solve as _mod
+
+ print(f" baspacho_dense_solve C++ module: OK ({_mod})")
+ except ImportError as e:
+ print(f" baspacho_dense_solve C++ module: FAILED ({e})")
print()
# Find benchmark .sim file
@@ -76,10 +111,15 @@ def main():
print(f"ERROR: Benchmark file not found: {sim_path}")
sys.exit(1)
+ # Import vajax (auto-configures precision based on backend)
+ with timed("vajax import"):
+ from vajax.analysis import CircuitEngine
+
# Setup circuit using CircuitEngine
- print(f"Setting up circuit from {sim_path}...")
- engine = CircuitEngine(sim_path)
- engine.parse()
+ with timed("circuit parse"):
+ print(f"Setting up circuit from {sim_path}...")
+ engine = CircuitEngine(sim_path)
+ engine.parse()
print(f"Circuit size: {engine.num_nodes} nodes, {len(engine.devices)} devices")
print()
@@ -90,22 +130,38 @@ def main():
print()
# Prepare (includes 1-step JIT warmup)
- print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...")
- engine.prepare(
- t_stop=args.timesteps * dt,
- dt=dt,
- use_sparse=args.sparse,
- )
+ with timed("prepare + JIT warmup"):
+ print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...")
+ engine.prepare(
+ t_stop=args.timesteps * dt,
+ dt=dt,
+ use_sparse=args.sparse,
+ )
print("Prepare complete")
print()
- # Profiled run - nsys-jax captures this automatically
- print(f"Starting profiled run ({args.timesteps} timesteps)...")
- result = engine.run_transient()
+ # NVTX range for nsys --capture-range=nvtx scoping
+ try:
+ _nvtx = ctypes.CDLL("libnvToolsExt.so")
+ _nvtx_push = _nvtx.nvtxRangePushA
+ _nvtx_push.argtypes = [ctypes.c_char_p]
+ _nvtx_pop = _nvtx.nvtxRangePop
+ except OSError:
+ _nvtx_push = _nvtx_pop = None
- print()
- print(f"Completed: {result.num_steps} timesteps")
- print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s")
+ if _nvtx_push:
+ _nvtx_push(b"run_transient")
+
+ with timed("transient simulation"):
+ print(f"Starting profiled run ({args.timesteps} timesteps)...", flush=True)
+ result = engine.run_transient()
+
+ if _nvtx_pop:
+ _nvtx_pop()
+
+ print(flush=True)
+ print(f"Completed: {result.num_steps} timesteps", flush=True)
+ print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s", flush=True)
if __name__ == "__main__":
diff --git a/scripts/profile_nr_phases.py b/scripts/profile_nr_phases.py
new file mode 100644
index 00000000..6100b4bb
--- /dev/null
+++ b/scripts/profile_nr_phases.py
@@ -0,0 +1,257 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["jax", "numpy"]
+# ///
+"""Profile NR iteration phase breakdown using JAX profiling tools.
+
+Instruments the NR solver body with jax.named_scope annotations and
+captures a Perfetto trace showing the time split between:
+ - build_system: device evaluation + Jacobian/residual assembly
+ - linear_solve: sparse or dense linear solve (J*delta = -f)
+ - convergence: residual/delta checks, step limiting, solution update
+
+Also uses jax.debug.callback timestamps for a quick text summary
+(accurate on CPU since execution is synchronous).
+
+Usage:
+ JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py ring
+ JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py c6288
+ JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py c6288 --trace-dir /tmp/jax_trace
+"""
+
+import argparse
+import os
+import time
+from pathlib import Path
+
+os.environ.setdefault("JAX_PLATFORMS", "cpu")
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+# ---------------------------------------------------------------------------
+# Phase timing via jax.debug.callback (CPU-accurate)
+# ---------------------------------------------------------------------------
+
+phase_timings: list[dict] = []
+_phase_clock: dict[str, float] = {}
+
+
+def _start_phase(phase_name_bytes):
+ """Record start time for a phase."""
+ phase_name = (
+ phase_name_bytes.tobytes().decode()
+ if hasattr(phase_name_bytes, "tobytes")
+ else str(phase_name_bytes)
+ )
+ _phase_clock[phase_name] = time.perf_counter_ns()
+
+
+def _end_phase(phase_name_bytes, iteration):
+ """Record end time for a phase."""
+ phase_name = (
+ phase_name_bytes.tobytes().decode()
+ if hasattr(phase_name_bytes, "tobytes")
+ else str(phase_name_bytes)
+ )
+ start = _phase_clock.get(phase_name, 0)
+ elapsed_ns = time.perf_counter_ns() - start
+ phase_timings.append(
+ {
+ "phase": phase_name,
+ "iteration": int(iteration),
+ "elapsed_us": elapsed_ns / 1000,
+ }
+ )
+
+
+# ---------------------------------------------------------------------------
+# Monkey-patch NR solver to add named scopes + timing callbacks
+# ---------------------------------------------------------------------------
+
+import vajax.analysis.solver_factories as sf
+
+_original_make_nr = sf._make_nr_solver_common
+
+
+def patched_make_nr_solver_common(*, build_system_jit, linear_solve_fn, enforce_noi_fn, **kwargs):
+ """Wrap build_system and linear_solve with named scopes and timing."""
+
+ def timed_build_system(*args):
+ with jax.named_scope("nr_build_system"):
+ return build_system_jit(*args)
+
+ def timed_linear_solve(J_or_data, f):
+ with jax.named_scope("nr_linear_solve"):
+ return linear_solve_fn(J_or_data, f)
+
+ def timed_enforce_noi(J_or_data, f):
+ with jax.named_scope("nr_enforce_noi"):
+ return enforce_noi_fn(J_or_data, f)
+
+ return _original_make_nr(
+ build_system_jit=timed_build_system,
+ linear_solve_fn=timed_linear_solve,
+ enforce_noi_fn=timed_enforce_noi,
+ **kwargs,
+ )
+
+
+sf._make_nr_solver_common = patched_make_nr_solver_common
+
+
+# ---------------------------------------------------------------------------
+# Also add callback-based timing for text summary
+# ---------------------------------------------------------------------------
+
+_original_make_nr2 = sf._make_nr_solver_common # This is now our patched version
+
+
+def callback_timed_make_nr(*, build_system_jit, linear_solve_fn, enforce_noi_fn, **kwargs):
+ """Add jax.debug.callback timestamps around each phase."""
+
+ def timed_build_system(*args):
+ # Extract iteration from args (it's the last positional arg)
+ iteration = args[-1] if len(args) > 0 else jnp.array(0)
+ build_tag = jnp.array(list(b"build_system"), dtype=jnp.uint8)
+ jax.debug.callback(_start_phase, build_tag)
+ result = build_system_jit(*args)
+ jax.debug.callback(_end_phase, build_tag, iteration)
+ return result
+
+ def timed_linear_solve(J_or_data, f):
+ solve_tag = jnp.array(list(b"linear_solve"), dtype=jnp.uint8)
+ jax.debug.callback(_start_phase, solve_tag)
+ result = linear_solve_fn(J_or_data, f)
+ jax.debug.callback(_end_phase, solve_tag, jnp.array(-1))
+ return result
+
+ return _original_make_nr2(
+ build_system_jit=timed_build_system,
+ linear_solve_fn=timed_linear_solve,
+ enforce_noi_fn=enforce_noi_fn,
+ **kwargs,
+ )
+
+
+sf._make_nr_solver_common = callback_timed_make_nr
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main():
+ import logging
+
+ from vajax.analysis import CircuitEngine
+ from vajax.benchmarks.registry import get_benchmark
+
+ parser = argparse.ArgumentParser(description="Profile NR phase breakdown")
+ parser.add_argument("benchmark", help="Benchmark name (e.g. ring, c6288)")
+ parser.add_argument(
+ "--trace-dir",
+ type=Path,
+ default=None,
+ help="Directory for Perfetto trace (default: /tmp/claude/<benchmark>_trace)",
+ )
+ parser.add_argument("--t-stop", type=float, default=None, help="Override stop time")
+ parser.add_argument(
+ "--n-steps", type=int, default=10, help="Number of timesteps to profile (default: 10)"
+ )
+ args = parser.parse_args()
+
+ logging.getLogger("vajax").setLevel(logging.WARNING)
+
+ info = get_benchmark(args.benchmark)
+ assert info is not None, f"Benchmark '{args.benchmark}' not found"
+
+ engine = CircuitEngine(info.sim_path)
+ engine.parse()
+
+ use_sparse = info.is_large
+ dt = info.dt
+ t_stop = args.t_stop or dt * args.n_steps
+
+ print(f"Benchmark: {args.benchmark}")
+ print(f" Nodes: {engine.num_nodes}, Devices: {len(engine.devices)}")
+ print(f" Solver: {'sparse' if use_sparse else 'dense'}")
+ print(f" Profiling {args.n_steps} steps (t_stop={t_stop:.2e}s)")
+
+ trace_dir = args.trace_dir or Path(f"/tmp/claude/{args.benchmark}_trace")
+ trace_dir.mkdir(parents=True, exist_ok=True)
+
+ # Prepare (includes JIT warmup)
+ print("\nPreparing (JIT warmup)...")
+ engine.prepare(t_stop=t_stop, dt=dt, use_sparse=use_sparse)
+
+ # Clear any timing from warmup
+ phase_timings.clear()
+
+ # Run with Perfetto trace capture
+ print(f"Running with profiler trace -> {trace_dir}")
+ jax.profiler.start_trace(str(trace_dir))
+ try:
+ result = engine.run_transient()
+ finally:
+ jax.profiler.stop_trace()
+
+ convergence = result.stats.get("convergence_rate", 0) * 100
+ print(f" Steps: {result.num_steps}, convergence: {convergence:.0f}%")
+ print(f" Trace saved to: {trace_dir}")
+
+ # --- Analyze callback timings ---
+ if not phase_timings:
+ print("\nNo callback timings captured (expected inside lax.while_loop)")
+ print(f"View the Perfetto trace at: {trace_dir}")
+ print(" Open https://ui.perfetto.dev and load the trace file")
+ return
+
+ print(f"\n{'=' * 60}")
+ print(f"NR Phase Timing Breakdown ({len(phase_timings)} measurements)")
+ print(f"{'=' * 60}")
+
+ # Aggregate by phase
+ by_phase: dict[str, list[float]] = {}
+ for entry in phase_timings:
+ phase = entry["phase"]
+ if phase not in by_phase:
+ by_phase[phase] = []
+ by_phase[phase].append(entry["elapsed_us"])
+
+ total_us = sum(sum(times) for times in by_phase.values())
+
+ print(f"\n{'Phase':<20} {'Count':>6} {'Total (ms)':>12} {'Mean (µs)':>12} {'%':>8}")
+ print(f"{'-' * 20} {'-' * 6} {'-' * 12} {'-' * 12} {'-' * 8}")
+ for phase, times in sorted(by_phase.items(), key=lambda x: -sum(x[1])):
+ total_ms = sum(times) / 1000
+ mean_us = np.mean(times)
+ pct = sum(times) / total_us * 100 if total_us > 0 else 0
+ print(f"{phase:<20} {len(times):>6} {total_ms:>12.2f} {mean_us:>12.1f} {pct:>7.1f}%")
+
+ print(f"\n{'Total':.<20} {'':>6} {total_us / 1000:>12.2f} ms")
+
+ # Per-NR-iteration breakdown (first few)
+ build_times = by_phase.get("build_system", [])
+ solve_times = by_phase.get("linear_solve", [])
+
+ if build_times and solve_times:
+ n_show = min(10, len(build_times))
+ print(f"\nPer-iteration breakdown (first {n_show}):")
+ print(f"{'Iter':>4} {'Build (µs)':>12} {'Solve (µs)':>12} {'Solve %':>8}")
+ print(f"{'-' * 4} {'-' * 12} {'-' * 12} {'-' * 8}")
+ for i in range(n_show):
+ b = build_times[i]
+ s = solve_times[i] if i < len(solve_times) else 0
+ total = b + s
+ spct = s / total * 100 if total > 0 else 0
+ print(f"{i:>4} {b:>12.1f} {s:>12.1f} {spct:>7.1f}%")
+
+ print(f"\nPerfetto trace: {trace_dir}")
+ print(" Open https://ui.perfetto.dev and load the .pb or .json.gz file")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/sweep_xla_flags.py b/scripts/sweep_xla_flags.py
new file mode 100644
index 00000000..0175c12e
--- /dev/null
+++ b/scripts/sweep_xla_flags.py
@@ -0,0 +1,402 @@
+#!/usr/bin/env -S uv run --script
+# /// script
+# requires-python = ">=3.10"
+# dependencies = ["jax"]
+# ///
+"""Sweep XLA flag combinations to find optimal CUDA performance.
+
+Runs a benchmark circuit with different XLA flag configurations and
+reports timing for each. Each configuration runs in a separate subprocess
+to ensure clean XLA state.
+
+Usage:
+ # Run on GPU (auto-detects CUDA)
+ uv run scripts/sweep_xla_flags.py
+
+ # Specific benchmark
+ uv run scripts/sweep_xla_flags.py --benchmark ring
+
+ # Specific configurations only
+ uv run scripts/sweep_xla_flags.py --configs baseline,autotune2,command_buffer
+
+ # Also include large circuit (needs sparse solver)
+ uv run scripts/sweep_xla_flags.py --benchmark ring,c6288 --include-sparse
+"""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+# XLA flag configurations to test
+FLAG_CONFIGS = {
+ "baseline": {
+ "description": "Current CI config (autotune=0)",
+ "xla_flags": "--xla_gpu_autotune_level=0",
+ "env": {},
+ },
+ "autotune2": {
+ "description": "Autotune level 2 (enables cuBLAS algorithm selection)",
+ "xla_flags": "--xla_gpu_autotune_level=2",
+ "env": {},
+ },
+ "autotune4": {
+ "description": "Autotune level 4 (full autotuning)",
+ "xla_flags": "--xla_gpu_autotune_level=4",
+ "env": {},
+ },
+ "command_buffer": {
+ "description": "Command buffers enabled (batch kernel launches)",
+ "xla_flags": "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUDNN",
+ "env": {},
+ },
+ "double_buffer": {
+ "description": "While-loop double buffering",
+ "xla_flags": (
+ "--xla_gpu_autotune_level=0 --xla_gpu_enable_while_loop_double_buffering=true"
+ ),
+ "env": {},
+ },
+ "pgle": {
+ "description": "Profile-guided latency estimation (3 profiling runs)",
+ "xla_flags": "--xla_gpu_autotune_level=0",
+ "env": {
+ "JAX_ENABLE_PGLE": "true",
+ "JAX_PGLE_PROFILING_RUNS": "3",
+ },
+ },
+ "combined_safe": {
+ "description": "Autotune 2 + double buffering",
+ "xla_flags": (
+ "--xla_gpu_autotune_level=2 --xla_gpu_enable_while_loop_double_buffering=true"
+ ),
+ "env": {},
+ },
+ "combined_aggressive": {
+ "description": "Autotune 4 + double buffering + PGLE",
+ "xla_flags": (
+ "--xla_gpu_autotune_level=4 --xla_gpu_enable_while_loop_double_buffering=true"
+ ),
+ "env": {
+ "JAX_ENABLE_PGLE": "true",
+ "JAX_PGLE_PROFILING_RUNS": "3",
+ },
+ },
+}
+
+# The subprocess script that runs a single benchmark
+BENCHMARK_RUNNER = """
+import os
+import sys
+import time
+import json
+
+sys.path.insert(0, os.environ["PROJECT_ROOT"])
+
+# Memory config
+os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
+os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")
+
+import jax
+import jax.numpy as jnp
+
+from vajax.analysis import CircuitEngine
+
+benchmark_name = os.environ["BENCHMARK_NAME"]
+use_sparse = os.environ.get("USE_SPARSE", "0") == "1"
+use_scan = True
+force_gpu = os.environ.get("FORCE_GPU", "0") == "1"
+n_warmup = int(os.environ.get("N_WARMUP", "1"))
+n_runs = int(os.environ.get("N_RUNS", "3"))
+
+from scripts.benchmark_utils import get_vacask_benchmarks
+
+benchmarks = get_vacask_benchmarks([benchmark_name])
+if not benchmarks:
+ print(json.dumps({"error": f"Benchmark {benchmark_name} not found"}))
+ sys.exit(1)
+
+name, sim_path = benchmarks[0]
+
+# Report JAX config
+devices = jax.devices()
+backend = devices[0].platform if devices else "unknown"
+print(f"JAX backend: {backend}, devices: {[d.platform for d in devices]}", file=sys.stderr)
+print(f"XLA_FLAGS: {os.environ.get('XLA_FLAGS', '(not set)')}", file=sys.stderr)
+
+engine = CircuitEngine.from_sim_file(str(sim_path))
+engine.prepare(use_sparse=use_sparse, force_gpu=force_gpu, use_scan=use_scan)
+
+# Get step count from sim parameters
+dt = engine.sim_params.get("dt", 1e-6)
+t_stop = engine.sim_params.get("tstop", engine.sim_params.get("t_stop", 1e-3))
+n_steps = int(t_stop / dt) if dt > 0 else 100
+
+timings = []
+
+for run_idx in range(n_warmup + n_runs):
+ # Re-prepare to reset state
+ if run_idx > 0:
+ engine.prepare(use_sparse=use_sparse, force_gpu=force_gpu, use_scan=use_scan)
+
+ start = time.perf_counter()
+ result = engine.run_transient()
+ # Block until computation complete
+ if hasattr(result, 'voltages') and result.voltages is not None:
+ jax.block_until_ready(result.voltages)
+ elapsed = time.perf_counter() - start
+
+ actual_steps = result.num_steps if hasattr(result, 'num_steps') else n_steps
+ ms_per_step = (elapsed * 1000.0) / max(actual_steps, 1)
+
+ label = "warmup" if run_idx < n_warmup else f"run {run_idx - n_warmup}"
+ print(f" {label}: {elapsed:.3f}s ({actual_steps} steps, {ms_per_step:.3f} ms/step)", file=sys.stderr)
+
+ if run_idx >= n_warmup:
+ timings.append({
+ "elapsed_s": elapsed,
+ "n_steps": actual_steps,
+ "ms_per_step": ms_per_step,
+ })
+
+# Report median timing
+timings.sort(key=lambda t: t["ms_per_step"])
+median = timings[len(timings) // 2]
+
+print(json.dumps({
+ "benchmark": benchmark_name,
+ "backend": backend,
+ "n_steps": median["n_steps"],
+ "ms_per_step": median["ms_per_step"],
+ "elapsed_s": median["elapsed_s"],
+ "n_runs": n_runs,
+ "all_timings": [t["ms_per_step"] for t in timings],
+}))
+"""
+
+
+def run_config(
+ config_name: str,
+ config: dict,
+ benchmark: str,
+ project_root: Path,
+ use_sparse: bool,
+ force_gpu: bool,
+ n_warmup: int = 1,
+ n_runs: int = 3,
+) -> dict:
+ """Run a single benchmark with a specific XLA flag configuration."""
+ env = os.environ.copy()
+ env["PROJECT_ROOT"] = str(project_root)
+ env["BENCHMARK_NAME"] = benchmark
+ env["USE_SPARSE"] = "1" if use_sparse else "0"
+ env["FORCE_GPU"] = "1" if force_gpu else "0"
+ env["N_WARMUP"] = str(n_warmup)
+ env["N_RUNS"] = str(n_runs)
+ env["JAX_PLATFORMS"] = "cuda,cpu" if force_gpu else "cpu"
+ env["JAX_ENABLE_X64"] = "1"
+
+ # Set XLA flags
+ env["XLA_FLAGS"] = config["xla_flags"]
+
+ # Set additional env vars
+ for k, v in config.get("env", {}).items():
+ env[k] = v
+
+ # Memory allocation
+ env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
+ env.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")
+
+ print(f"\n{'=' * 60}")
+ print(f"Config: {config_name} — {config['description']}")
+ print(f" XLA_FLAGS: {config['xla_flags']}")
+ if config.get("env"):
+ print(f" Extra env: {config['env']}")
+ print(f" Benchmark: {benchmark}")
+ print(f"{'=' * 60}")
+
+ start = time.perf_counter()
+ try:
+ result = subprocess.run(
+ [sys.executable, "-c", BENCHMARK_RUNNER],
+ env=env,
+ capture_output=True,
+ text=True,
+ timeout=600, # 10 min max per config
+ cwd=str(project_root),
+ )
+ except subprocess.TimeoutExpired:
+ return {
+ "config": config_name,
+ "benchmark": benchmark,
+ "error": "timeout (600s)",
+ "wall_time_s": time.perf_counter() - start,
+ }
+
+ wall_time = time.perf_counter() - start
+
+ # Print stderr (progress messages)
+ if result.stderr:
+ for line in result.stderr.strip().split("\n"):
+ print(f" {line}")
+
+ # Parse JSON from last line of stdout
+ if result.returncode != 0:
+ print(f" ERROR: exit code {result.returncode}")
+ if result.stderr:
+ print(f" {result.stderr[-500:]}")
+ return {
+ "config": config_name,
+ "benchmark": benchmark,
+ "error": f"exit code {result.returncode}",
+ "wall_time_s": wall_time,
+ }
+
+ try:
+ # Find the JSON line (last non-empty line of stdout)
+ lines = [l for l in result.stdout.strip().split("\n") if l.strip()]
+ data = json.loads(lines[-1])
+ data["config"] = config_name
+ data["wall_time_s"] = wall_time
+ data["description"] = config["description"]
+ print(f" Result: {data['ms_per_step']:.3f} ms/step ({data['n_steps']} steps)")
+ return data
+ except (json.JSONDecodeError, IndexError) as e:
+ print(f" ERROR parsing output: {e}")
+ print(f" stdout: {result.stdout[-500:]}")
+ return {
+ "config": config_name,
+ "benchmark": benchmark,
+ "error": f"parse error: {e}",
+ "wall_time_s": wall_time,
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Sweep XLA flag combinations")
+ parser.add_argument(
+ "--benchmark",
+ default="ring",
+ help="Comma-separated benchmark names (default: ring)",
+ )
+ parser.add_argument(
+ "--configs",
+ default=None,
+ help="Comma-separated config names (default: all)",
+ )
+ parser.add_argument(
+ "--include-sparse",
+ action="store_true",
+ help="Include sparse solver for large circuits",
+ )
+ parser.add_argument(
+ "--force-gpu",
+ action="store_true",
+ default=True,
+ help="Force GPU backend (default: True)",
+ )
+ parser.add_argument(
+ "--cpu-only",
+ action="store_true",
+ help="Run on CPU instead of GPU",
+ )
+ parser.add_argument(
+ "--n-warmup",
+ type=int,
+ default=1,
+ help="Number of warmup runs (default: 1)",
+ )
+ parser.add_argument(
+ "--n-runs",
+ type=int,
+ default=3,
+ help="Number of timed runs (default: 3)",
+ )
+ parser.add_argument(
+ "--json-output",
+ default=None,
+ help="Path to write JSON results",
+ )
+ args = parser.parse_args()
+
+ project_root = Path(__file__).parent.parent
+ benchmarks = [b.strip() for b in args.benchmark.split(",")]
+ force_gpu = not args.cpu_only
+
+ if args.configs:
+ config_names = [c.strip() for c in args.configs.split(",")]
+ configs = {k: FLAG_CONFIGS[k] for k in config_names if k in FLAG_CONFIGS}
+ else:
+ configs = FLAG_CONFIGS
+
+ all_results = []
+
+ for benchmark in benchmarks:
+ # Determine if sparse needed
+ use_sparse = args.include_sparse and benchmark in ("c6288", "mul64")
+
+ for config_name, config in configs.items():
+ result = run_config(
+ config_name,
+ config,
+ benchmark,
+ project_root,
+ use_sparse=use_sparse,
+ force_gpu=force_gpu,
+ n_warmup=args.n_warmup,
+ n_runs=args.n_runs,
+ )
+ all_results.append(result)
+
+ # Print summary table
+ print(f"\n{'=' * 80}")
+ print("SUMMARY")
+ print(f"{'=' * 80}")
+ print(
+ f"{'Config':<25} {'Benchmark':<10} {'ms/step':>10} {'Steps':>8} {'Wall(s)':>10} {'vs base':>10}"
+ )
+ print("-" * 80)
+
+ # Group by benchmark for relative comparison
+ by_benchmark = {}
+ for r in all_results:
+ bm = r.get("benchmark", "?")
+ by_benchmark.setdefault(bm, []).append(r)
+
+ for bm, results in by_benchmark.items():
+ baseline_ms = None
+ for r in results:
+ if r.get("config") == "baseline" and "ms_per_step" in r:
+ baseline_ms = r["ms_per_step"]
+ break
+
+ for r in results:
+ ms = r.get("ms_per_step", None)
+ steps = r.get("n_steps", "?")
+ wall = r.get("wall_time_s", 0)
+ config = r.get("config", "?")
+ err = r.get("error", None)
+
+ if err:
+ print(f"{config:<25} {bm:<10} {'ERROR':>10} {'':>8} {wall:>10.1f} {err}")
+ elif ms is not None:
+ ratio_str = ""
+ if baseline_ms and baseline_ms > 0:
+ ratio = ms / baseline_ms
+ ratio_str = f"{ratio:.2f}x"
+ print(f"{config:<25} {bm:<10} {ms:>10.3f} {steps:>8} {wall:>10.1f} {ratio_str:>10}")
+
+ # Save JSON
+ if args.json_output:
+ out_path = Path(args.json_output)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(out_path, "w") as f:
+ json.dump(all_results, f, indent=2)
+ print(f"\nResults saved to {out_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/test_vacask_suite.py b/tests/test_vacask_suite.py
index c2538d13..37ab5804 100644
--- a/tests/test_vacask_suite.py
+++ b/tests/test_vacask_suite.py
@@ -22,12 +22,12 @@
from vajax.netlist.parser import parse_netlist
-# Paths - VACASK is at ../VACASK relative to vajax
+# Paths - VACASK is vendored at vendor/VACASK
VAJAX_ROOT = Path(__file__).parent.parent
-VACASK_ROOT = VAJAX_ROOT.parent / "VACASK"
+VACASK_ROOT = VAJAX_ROOT / "vendor" / "VACASK"
VACASK_TEST = VACASK_ROOT / "test"
VACASK_DEVICES = VACASK_ROOT / "devices"
-VACASK_BENCHMARK = VAJAX_ROOT / "vendor" / "VACASK" / "benchmark"
+VACASK_BENCHMARK = VACASK_ROOT / "benchmark"
def discover_benchmark_dirs() -> List[Path]:
diff --git a/vajax/analysis/engine.py b/vajax/analysis/engine.py
index 3e9fb54d..d2bfe180 100644
--- a/vajax/analysis/engine.py
+++ b/vajax/analysis/engine.py
@@ -1015,13 +1015,244 @@ def run_transient(self) -> TransientResult:
# Extract sliced numpy results for TransientResult
times_np, voltages, currents = extract_results(times_full, V_out, stats)
+ # Keep as numpy arrays — avoids creating dynamically-sized JAX arrays
+ # that trigger jit(dynamic_slice) recompilation on CUDA when n_steps
+ # varies between runs (the shape gets baked into the XLA kernel).
return TransientResult(
- times=jnp.asarray(times_np),
- voltages={k: jnp.asarray(v) for k, v in voltages.items()},
- currents={k: jnp.asarray(v) for k, v in currents.items()},
+ times=times_np,
+ voltages=voltages,
+ currents=currents,
stats=stats,
)
+ def dump_jaxpr(self, output_dir: str | Path = "/tmp/claude/jaxpr-analysis") -> Path:
+ """Dump JAX IR analysis for the compiled simulation functions.
+
+ Analyzes two hot-path functions after prepare():
+ 1. build_system - Jacobian + residual assembly (per NR iteration)
+ 2. nr_solve - Full Newton-Raphson solve (per timestep)
+
+ Uses the actual circuit's device_arrays and dimensions from the prepared
+ strategy, matching the calling conventions used during simulation.
+
+ For each function, writes:
+ - HLO text (StableHLO MLIR representation)
+ - HLO operation counts (top ops by frequency)
+ - XLA cost analysis (flops, memory bytes)
+
+ Args:
+ output_dir: Directory for output files.
+
+ Returns:
+ Path to the output directory.
+ """
+ if not getattr(self, "_prepared", False):
+ self.prepare()
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ strategy = self._prepared_strategy
+
+ # Get cached functions and actual circuit data from the strategy
+ build_system_jit = getattr(strategy, "_cached_build_system_jit", None)
+ nr_solve = getattr(strategy, "_cached_full_mna_solver", None)
+ device_arrays = getattr(strategy, "_device_arrays_full_mna", None)
+ total_limit_states = getattr(strategy, "_total_limit_states", 0)
+
+ if build_system_jit is None or nr_solve is None or device_arrays is None:
+ raise RuntimeError("No cached solver found. Call prepare() first.")
+
+ # Get dimensions from transient setup cache
+ setup_cache = self._transient_setup_cache
+ if setup_cache is None:
+ raise RuntimeError("No transient setup cache. Call prepare() first.")
+ n_total = setup_cache["n_total"]
+ n_unknowns = setup_cache["n_unknowns"]
+ n_vsources = len([d for d in self.devices if d["model"] == "vsource"])
+ n_isources = len([d for d in self.devices if d["model"] == "isource"])
+
+ logger.info(
+ f"dump_jaxpr: n_total={n_total}, n_unknowns={n_unknowns}, "
+ f"n_vsources={n_vsources}, limit_states={total_limit_states}"
+ )
+
+ # Use the actual circuit's device_arrays and correct-shaped args.
+ # Values don't matter for HLO tracing — only shapes and dtypes.
+ # Match the calling conventions from full_mna.py DC solve path.
+ dtype = get_float_dtype()
+ X = jnp.zeros(n_total + n_vsources, dtype=dtype)
+ vsource_vals = jnp.zeros(n_vsources, dtype=dtype)
+ isource_vals = jnp.zeros(max(n_isources, 1), dtype=dtype)
+ Q_prev = jnp.zeros(n_unknowns, dtype=dtype)
+ integ_c0 = jnp.asarray(0.0, dtype=dtype)
+ gmin = jnp.asarray(1e-12, dtype=dtype)
+ gshunt = jnp.asarray(0.0, dtype=dtype)
+ integ_c1 = jnp.asarray(0.0, dtype=dtype)
+ integ_d1 = jnp.asarray(0.0, dtype=dtype)
+ dQdt_prev = jnp.zeros(n_unknowns, dtype=dtype)
+ integ_c2 = jnp.asarray(0.0, dtype=dtype)
+ Q_prev2 = jnp.zeros(n_unknowns, dtype=dtype)
+ limit_state = jnp.zeros(total_limit_states, dtype=dtype)
+ nr_iter = jnp.asarray(1, dtype=jnp.int32)
+
+ # build_system_jit signature: (X, vsource_vals, isource_vals, Q_prev,
+ # integ_c0, device_arrays, gmin, gshunt, integ_c1, integ_d1,
+ # dQdt_prev, integ_c2, Q_prev2, limit_state, nr_iter)
+ build_args = (
+ X,
+ vsource_vals,
+ isource_vals,
+ Q_prev,
+ integ_c0,
+ device_arrays,
+ gmin,
+ gshunt,
+ integ_c1,
+ integ_d1,
+ dQdt_prev,
+ integ_c2,
+ Q_prev2,
+ limit_state,
+ nr_iter,
+ )
+
+ # nr_solve signature: (X_init, vsource_vals, isource_vals, Q_prev,
+ # integ_c0, device_arrays, gmin, gshunt, integ_c1, integ_d1,
+ # dQdt_prev, integ_c2, Q_prev2, limit_state_in)
+ # Uses None defaults for optional args, matching DC solve convention.
+ nr_args = (
+ X,
+ vsource_vals,
+ isource_vals,
+ Q_prev,
+ integ_c0,
+ device_arrays,
+ gmin,
+ gshunt,
+ integ_c1,
+ integ_d1,
+ None,
+ integ_c2,
+ None,
+ None,
+ )
+
+ results = {}
+ for name, fn, args in [
+ ("build_system", build_system_jit, build_args),
+ ("nr_solve", nr_solve, nr_args),
+ ]:
+ results[name] = self._dump_single_jaxpr(name, fn, args, output_dir)
+
+ return output_dir
+
+ @staticmethod
+ def _dump_single_jaxpr(name: str, fn, args, output_dir: Path) -> dict[str, Any]:
+ """Analyze and dump jaxpr/HLO for a single function.
+
+ Produces three artifacts per function:
+ - {name}.jaxpr.txt — JAX's high-level IR (from jax.make_jaxpr)
+ - {name}.hlo.txt — StableHLO/MLIR after lowering
+ - Log output with op counts and XLA cost analysis
+
+ Returns:
+ Dict with 'jaxpr', 'hlo_lines', 'op_counts', 'cost' keys.
+ """
+ result: dict[str, Any] = {"name": name}
+ logger.info(f"Analyzing {name}...")
+
+ try:
+ # 1. Jaxpr (high-level JAX IR)
+ jaxpr = jax.make_jaxpr(fn)(*args)
+ jaxpr_text = str(jaxpr)
+ jaxpr_lines = jaxpr_text.split("\n")
+ result["jaxpr"] = jaxpr
+ result["jaxpr_lines"] = len(jaxpr_lines)
+ logger.info(f" {name}: {len(jaxpr_lines)} jaxpr lines")
+
+ jaxpr_file = output_dir / f"{name}.jaxpr.txt"
+ with open(jaxpr_file, "w") as f:
+ f.write(jaxpr_text)
+ result["jaxpr_file"] = jaxpr_file
+ logger.info(f" Saved jaxpr: {jaxpr_file}")
+
+ # Count jaxpr primitives
+ jaxpr_ops: dict[str, int] = {}
+ for eqn in jaxpr.jaxpr.eqns:
+ prim_name = str(eqn.primitive.name)
+ jaxpr_ops[prim_name] = jaxpr_ops.get(prim_name, 0) + 1
+ result["jaxpr_ops"] = dict(sorted(jaxpr_ops.items(), key=lambda x: -x[1]))
+ if jaxpr_ops:
+ sorted_jaxpr_ops = sorted(jaxpr_ops.items(), key=lambda x: -x[1])
+ logger.info(f" {name}: Top jaxpr primitives:")
+ for op, count in sorted_jaxpr_ops[:20]:
+ logger.info(f" {op:45s} {count:6d}")
+ logger.info(f" Total jaxpr eqns: {len(jaxpr.jaxpr.eqns)}")
+
+ # 2. HLO (lowered StableHLO)
+ if hasattr(fn, "lower"):
+ lowered = fn.lower(*args)
+ else:
+ lowered = jax.jit(fn).lower(*args)
+
+ hlo_text = lowered.as_text()
+ hlo_lines = hlo_text.split("\n")
+ result["hlo_lines"] = len(hlo_lines)
+ logger.info(f" {name}: {len(hlo_lines)} HLO lines")
+
+ # Count HLO operations
+ op_counts: dict[str, int] = {}
+ for line in hlo_lines:
+ if "=" in line and "." in line:
+ parts = line.split("=")
+ if len(parts) >= 2:
+ op_part = parts[1].strip().split()[0] if parts[1].strip() else ""
+ if "." in op_part:
+ op_name = op_part.split("(")[0]
+ op_counts[op_name] = op_counts.get(op_name, 0) + 1
+
+ result["op_counts"] = dict(sorted(op_counts.items(), key=lambda x: -x[1]))
+ if op_counts:
+ sorted_ops = sorted(op_counts.items(), key=lambda x: -x[1])
+ logger.info(f" {name}: Top HLO ops:")
+ for op, count in sorted_ops[:20]:
+ logger.info(f" {op:45s} {count:6d}")
+ logger.info(f" Total unique op types: {len(op_counts)}")
+ logger.info(f" Total ops: {sum(op_counts.values())}")
+
+ # 3. Cost analysis
+ compiled = lowered.compile()
+ cost = compiled.cost_analysis()
+ result["cost"] = cost
+ if cost:
+ for i, device_cost in enumerate(cost):
+ if device_cost and isinstance(device_cost, dict):
+ logger.info(f" {name} cost (device {i}):")
+ for key, val in sorted(device_cost.items()):
+ if isinstance(val, (int, float)):
+ if val > 1e9:
+ logger.info(f" {key}: {val / 1e9:.2f}G")
+ elif val > 1e6:
+ logger.info(f" {key}: {val / 1e6:.2f}M")
+ elif val > 1e3:
+ logger.info(f" {key}: {val / 1e3:.2f}K")
+ else:
+ logger.info(f" {key}: {val:.2f}")
+
+ # Save HLO
+ hlo_file = output_dir / f"{name}.hlo.txt"
+ with open(hlo_file, "w") as f:
+ f.write(hlo_text)
+ result["hlo_file"] = hlo_file
+ logger.info(f" Saved HLO: {hlo_file}")
+
+ except Exception as e:
+ logger.error(f"Failed to analyze {name}: {e}", exc_info=True)
+ result["error"] = str(e)
+
+ return result
+
# =========================================================================
# Node Collapse Implementation
# =========================================================================
diff --git a/vajax/analysis/openvaf_models.py b/vajax/analysis/openvaf_models.py
index 999b1dfa..27aa1183 100644
--- a/vajax/analysis/openvaf_models.py
+++ b/vajax/analysis/openvaf_models.py
@@ -915,6 +915,21 @@ def prepare_static_inputs(
shared_cache_indices = []
varying_cache_indices = []
+ # Split cache arrays
+ shared_cache = cache[0, shared_cache_indices]
+ device_cache = cache[:, varying_cache_indices]
+
+ # SCCP dead-block elimination is DISABLED for the unified eval function.
+ # While SCCP can eliminate ~695/954 blocks for PSP103, the benefit is
+ # marginal (same code size, same XLA ops — XLA already CSEs branches).
+ # The cost is high: SCCP changes the JIT function hash, invalidating
+ # the persistent XLA compilation cache and adding ~99s cold-compile
+ # penalty for ring (49.5s × 2 compilations vs 0.58s cache hit).
+ # SCCP would be valuable for config-group-specialized eval functions
+ # where each group has a unique TYPE value, but that's deferred.
+ # Infrastructure is preserved in build_sccp_known_values() for reuse.
+ sccp_known_values = None
+
# Generate eval function with cache split
from vajax.analysis.limiting import fetlim, pnjlim
@@ -930,6 +945,7 @@ def prepare_static_inputs(
varying_cache_indices,
use_limit_functions=use_device_limiting,
limit_param_map=limit_param_map,
+ sccp_known_values=sccp_known_values,
)
# Safety check: if limiting is enabled but lim_rhs could not be computed
# (model uses inline limiting without $limit/BuiltinLimit calls), disable
@@ -950,15 +966,12 @@ def prepare_static_inputs(
varying_cache_indices,
use_limit_functions=False,
limit_param_map=limit_param_map,
+ sccp_known_values=sccp_known_values,
)
split_fn = partial(split_fn, limit_funcs=limit_funcs)
vmapped_split_fn = jax.jit(jax.vmap(split_fn, in_axes=(None, 0, None, 0, None, 0)))
- # Split cache arrays
- shared_cache = cache[0, shared_cache_indices]
- device_cache = cache[:, varying_cache_indices]
-
# Build default simparams from model metadata
simparams_used = split_meta.get("simparams_used", ["$analysis_type", "$mfactor", "gmin"])
simparam_count = split_meta.get("simparam_count", len(simparams_used))
diff --git a/vajax/analysis/solver_factories.py b/vajax/analysis/solver_factories.py
index 378e7c94..da085653 100644
--- a/vajax/analysis/solver_factories.py
+++ b/vajax/analysis/solver_factories.py
@@ -510,9 +510,9 @@ def enforce_noi(J, f):
def linear_solve(J, f):
"""Solve J @ delta = -f using dense direct solver."""
- # Add Tikhonov regularization for numerical stability on GPU
- reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype)
- return jax.scipy.linalg.solve(J + reg, -f)
+ # Diagonal regularization (gmin + gshunt) is already applied during
+ # Jacobian assembly in assemble_dense_jacobian / _build_system_dense_direct.
+ return jax.scipy.linalg.solve(J, -f)
logger.info(
f"Creating dense full MNA solver: V({n_nodes}) + I({n_vsources}), "
@@ -536,6 +536,78 @@ def linear_solve(J, f):
)
+def make_baspacho_dense_full_mna_solver(
+ build_system_jit: Callable,
+ n_nodes: int,
+ n_vsources: int,
+ noi_indices: Optional[Array] = None,
+ internal_device_indices: Optional[Array] = None,
+ max_iterations: int = 100,
+ abstol: float = 1e-12,
+ total_limit_states: int = 0,
+ options: Optional["SimulationOptions"] = None,
+ max_step: float = 1e30,
+) -> Callable:
+ """Create a dense NR solver using BaSpaCho LU on CUDA.
+
+ Uses BaSpaCho's supernodal LU factorization with CUDA backend,
+ replacing jax.scipy.linalg.solve (cuSOLVER getrf) on GPU. Benefits:
+ - Symbolic analysis done once (cached across NR iterations)
+ - Grow-only GPU memory allocation (no per-call cudaMalloc after warmup)
+ - Foundation for Phase 2b graph-capture compatibility
+
+ Falls back to standard dense solver if BaSpaCho is unavailable.
+
+ Args:
+ Same as make_dense_full_mna_solver.
+ """
+ from spineax.cudss.dense_baspacho_solver import baspacho_dense_solve
+
+ masks = _compute_noi_masks(
+ noi_indices, n_nodes, internal_device_indices=internal_device_indices
+ )
+ noi_res_idx = masks["noi_res_idx"]
+
+ residual_mask = _build_augmented_mask(masks["residual_mask"], n_vsources)
+ residual_conv_mask = _build_augmented_conv_mask(
+ masks["residual_conv_mask"], residual_mask, n_vsources
+ )
+
+ def enforce_noi(J, f):
+ """Enforce NOI constraints on dense Jacobian."""
+ if noi_res_idx is not None:
+ J = J.at[noi_res_idx, :].set(0.0)
+ J = J.at[:, noi_res_idx].set(0.0)
+ J = J.at[noi_res_idx, noi_res_idx].set(1.0)
+ f = f.at[noi_res_idx].set(0.0)
+ return J, f
+
+ def linear_solve(J, f):
+ """Solve J @ delta = -f using BaSpaCho LU on CUDA."""
+ return baspacho_dense_solve(J, -f)
+
+ logger.info(
+ f"Creating BaSpaCho dense full MNA solver: V({n_nodes}) + I({n_vsources}), "
+ f"NOI: {noi_indices is not None}"
+ )
+ return _make_nr_solver_common(
+ build_system_jit=build_system_jit,
+ n_nodes=n_nodes,
+ n_vsources=n_vsources,
+ linear_solve_fn=linear_solve,
+ enforce_noi_fn=enforce_noi,
+ noi_indices=noi_indices,
+ internal_device_indices=internal_device_indices,
+ max_iterations=max_iterations,
+ abstol=abstol,
+ total_limit_states=total_limit_states,
+ options=options,
+ max_step=max_step,
+ residual_mask=residual_mask,
+ residual_conv_mask=residual_conv_mask,
+ )
+
+
def make_spineax_full_mna_solver(
build_system_jit: Callable,
n_nodes: int,
diff --git a/vajax/analysis/transient/full_mna.py b/vajax/analysis/transient/full_mna.py
index 80d58dab..48deb7cb 100644
--- a/vajax/analysis/transient/full_mna.py
+++ b/vajax/analysis/transient/full_mna.py
@@ -36,6 +36,7 @@
from vajax._logging import logger
from vajax.analysis.solver_factories import (
+ make_baspacho_dense_full_mna_solver,
make_dense_full_mna_solver,
make_spineax_full_mna_solver,
make_umfpack_ffi_full_mna_solver,
@@ -53,6 +54,16 @@ def is_spineax_available() -> bool:
return False
+def is_baspacho_dense_available() -> bool:
+ """Check if BaSpaCho dense CUDA solver is available."""
+ try:
+ from spineax.cudss.dense_baspacho_solver import is_available
+
+ return is_available()
+ except ImportError:
+ return False
+
+
from vajax.analysis.integration import IntegrationMethod
from .adaptive import AdaptiveConfig, compute_lte_timestep_jax, predict_voltage_jax
@@ -385,17 +396,43 @@ def _ensure_full_mna_solver(self, setup: TransientSetup) -> Callable:
self._total_limit_states = total_limit_states
build_system_jit = jax.jit(build_system_fn)
- nr_solve = make_dense_full_mna_solver(
- build_system_jit,
- n_nodes,
- n_vsources,
- noi_indices=noi_indices,
- internal_device_indices=internal_device_indices,
- max_iterations=self.runner.options.tran_itl,
- abstol=self.runner.options.abstol,
- total_limit_states=total_limit_states,
- options=self.runner.options,
- )
+ # On CUDA, try BaSpaCho dense solver (pre-allocated workspace,
+ # foundation for graph-capture compatibility in Phase 2b).
+ on_cuda_dense = jax.default_backend() in ("cuda", "gpu")
+ if on_cuda_dense and is_baspacho_dense_available():
+ try:
+ logger.info("Using BaSpaCho dense solver (GPU, CUDA backend)")
+ nr_solve = make_baspacho_dense_full_mna_solver(
+ build_system_jit,
+ n_nodes,
+ n_vsources,
+ noi_indices=noi_indices,
+ internal_device_indices=internal_device_indices,
+ max_iterations=self.runner.options.tran_itl,
+ abstol=self.runner.options.abstol,
+ total_limit_states=total_limit_states,
+ options=self.runner.options,
+ )
+ except Exception as e:
+ logger.warning(
+ f"BaSpaCho dense solver failed ({e}), falling back to JAX dense solver"
+ )
+ nr_solve = None
+ else:
+ nr_solve = None
+
+ if nr_solve is None:
+ nr_solve = make_dense_full_mna_solver(
+ build_system_jit,
+ n_nodes,
+ n_vsources,
+ noi_indices=noi_indices,
+ internal_device_indices=internal_device_indices,
+ max_iterations=self.runner.options.tran_itl,
+ abstol=self.runner.options.abstol,
+ total_limit_states=total_limit_states,
+ options=self.runner.options,
+ )
else:
# Sparse path: use CSR direct stamping to eliminate COO intermediates
n_augmented = setup.n_unknowns + n_vsources
@@ -1767,21 +1804,18 @@ def _debug_step_callback(
# Compute the voltage to record - use new_X which is the actual solution we're using
# (either X_new if converged, or previous X if NR failed at min_dt)
V_to_record = new_X[:n_external]
- new_times_out = jnp.where(
- accept_step, state.times_out.at[state.step_idx].set(t_next), state.times_out
- )
- new_V_out = jnp.where(
- accept_step, state.V_out.at[state.step_idx].set(V_to_record), state.V_out
- )
- # For currents, use zero if NR failed at min_dt (current from bad solution is unreliable)
+ # Write unconditionally: on rejection step_idx doesn't advance, so stale
+ # values at step_idx get overwritten by the next accepted step. The caller
+ # trims output using step_idx, so values beyond it are ignored. This avoids
+ # materializing both branches of jnp.where on the full output arrays.
I_to_record = jnp.where(
nr_failed_at_min_dt,
jnp.zeros(n_vsources, dtype=dtype) if n_vsources > 0 else jnp.zeros(1, dtype=dtype),
I_vsource[:n_vsources] if n_vsources > 0 else jnp.zeros(1, dtype=dtype),
)
- new_I_out = jnp.where(
- accept_step, state.I_out.at[state.step_idx].set(I_to_record), state.I_out
- )
+ new_times_out = state.times_out.at[state.step_idx].set(t_next)
+ new_V_out = state.V_out.at[state.step_idx].set(V_to_record)
+ new_I_out = state.I_out.at[state.step_idx].set(I_to_record)
new_step_idx = jnp.where(accept_step, state.step_idx + 1, state.step_idx)
# Statistics