diff --git a/.github/workflows/benchmark-comparison.yml b/.github/workflows/benchmark-comparison.yml index 517a072c..58de58e2 100644 --- a/.github/workflows/benchmark-comparison.yml +++ b/.github/workflows/benchmark-comparison.yml @@ -9,6 +9,10 @@ on: schedule: - cron: '0 6 * * *' # Daily at 6am UTC +concurrency: + group: benchmark-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 @@ -34,8 +38,6 @@ jobs: name: benchmark (${{ matrix.solver }}-${{ matrix.platform }}) runs-on: ${{ matrix.runner }} timeout-minutes: ${{ matrix.solver == 'sparse' && matrix.platform == 'cuda' && 360 || 90 }} - concurrency: - group: benchmark-${{ matrix.solver }}-${{ matrix.platform }}-${{ github.ref }} steps: - name: Checkout with submodules @@ -54,45 +56,39 @@ jobs: nvcc --version 2>/dev/null || echo "nvcc not in PATH" # ── System dependencies ────────────────────────────────────── - - name: Cache apt packages - uses: actions/cache@v4 - with: - path: /var/cache/apt/archives - key: apt-${{ runner.os }}-${{ runner.arch }}-${{ matrix.platform }}-v1 - restore-keys: | - apt-${{ runner.os }}-${{ runner.arch }}-${{ matrix.platform }}- - - - name: Install system dependencies (CPU) + - name: Install apt dependencies (CPU) if: matrix.platform == 'cpu' - run: | - sudo apt-get update - sudo apt-get install -y \ - cmake ninja-build \ - flex bison libfl-dev \ - libsuitesparse-dev libopenblas-dev \ + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 + cmake ninja-build flex bison libfl-dev + libsuitesparse-dev libopenblas-dev ccache bc + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true - - name: Install system dependencies (CUDA) + - name: Install apt dependencies (CUDA) if: matrix.platform == 'cuda' + uses: 
robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 + cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 + libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 + libnvjitlink-dev-12-6 + cuda-libraries-12-6 cuda-cupti-12-6 + libcudnn9-cuda-12 libcudss0-cuda-12 + libsuitesparse-dev libopenblas-dev swig cmake pkg-config + # NVIDIA CUDA repo is already on the runner image (cuda-archive-keyring.gpg). + # Adding it again via apt-sources causes "Conflicting Signed-By" error. + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set LLVM and CUDA environment run: | - sudo apt-get update - sudo apt-get install -y \ - cmake pkg-config swig \ - libsuitesparse-dev libopenblas-dev - - # ── LLVM 18 (idempotent, works on all runners) ────────────── - - name: Install LLVM 18 - run: | - if ! llvm-config-18 --version 2>/dev/null; then - wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | \ - sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null - sudo chmod a+r /etc/apt/trusted.gpg.d/apt.llvm.org.asc - wget -q https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - rm llvm.sh - fi - sudo apt-get install -y lld-18 echo "/usr/lib/llvm-18/bin" >> "$GITHUB_PATH" echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> "$GITHUB_ENV" @@ -134,24 +130,37 @@ jobs: openvaf_jax/openvaf_py vendor/OpenVAF - # ── CUDA toolkit ───────────────────────────────────────────── - - name: Install CUDA toolkit + # ── CUDA environment ───────────────────────────────────────── + - name: Set CUDA environment if: matrix.platform == 'cuda' run: | - sudo apt-get install -y cuda-toolkit-12-6 libcudnn9-cuda-12 libcudss0-cuda-12 - echo "/usr/local/cuda-12.6/bin" >> "$GITHUB_PATH" - CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 | grep '\.so' | head -1) + CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 
2>/dev/null | sort -V | tail -1) + if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then + CUDA_ROOT="/usr/local/cuda" + fi + EXTRA_LD="" + if [ -n "$CUDA_ROOT" ]; then + echo "${CUDA_ROOT}/bin" >> "$GITHUB_PATH" + # CUDA lib64 has cuSPARSE, cuFFT, etc. needed by JAX at startup. + EXTRA_LD="${CUDA_ROOT}/lib64" + fi + CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) if [ -n "$CUDSS_LIB" ]; then CUDSS_DIR=$(dirname "$CUDSS_LIB") - echo "LD_LIBRARY_PATH=${CUDSS_DIR}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + EXTRA_LD="${CUDSS_DIR}:${EXTRA_LD}" echo "cuDSS library found at: $CUDSS_LIB" fi + if [ -n "$EXTRA_LD" ]; then + echo "LD_LIBRARY_PATH=${EXTRA_LD}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + fi # ── Python environment ────────────────────────────────────── - name: Install uv uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false + python-version: "3.12" - name: Set up Python run: uv python install 3.12 @@ -193,7 +202,12 @@ jobs: working-directory: vajax/sparse run: | uv pip install scikit-build-core nanobind - uv pip install --no-build-isolation . + LAUNCHER=${{ matrix.platform == 'cuda' && 'sccache' || 'ccache' }} + uv pip install --no-build-isolation \ + -C cmake.define.BLA_VENDOR=OpenBLAS \ + -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=$LAUNCHER \ + -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=$LAUNCHER \ + . 
# ── Run VAJAX unit tests (CPU dense only) ──────────────────── - name: Run VAJAX tests diff --git a/.github/workflows/cache-cleanup.yml b/.github/workflows/cache-cleanup.yml new file mode 100644 index 00000000..a7ff2f52 --- /dev/null +++ b/.github/workflows/cache-cleanup.yml @@ -0,0 +1,23 @@ +name: Cache Cleanup + +on: + schedule: + - cron: '0 0 1 */3 *' # First day of every 3rd month + workflow_dispatch: + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Delete uv caches + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "=== Deleting uv caches ===" + gh cache list --repo ${{ github.repository }} --key setup-uv --json id,key,sizeInBytes \ + --jq '.[] | "\(.id)\t\(.key)\t\(.sizeInBytes)"' | \ + while IFS=$'\t' read -r id key size; do + echo "Deleting: $key ($(numfmt --to=iec $size))" + gh cache delete "$id" --repo ${{ github.repository }} + done + echo "Done" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b6e2f231..4cc68339 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: lint-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: lint: runs-on: ubuntu-latest @@ -17,6 +21,8 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false + python-version: "3.11" - name: Set up Python run: uv python install 3.11 diff --git a/.github/workflows/profile-nsys.yml b/.github/workflows/profile-nsys.yml index 28efc4af..86b729f6 100644 --- a/.github/workflows/profile-nsys.yml +++ b/.github/workflows/profile-nsys.yml @@ -14,17 +14,28 @@ on: - c6288 default: ring timesteps: - description: 'Number of timesteps' - default: '50' + description: 'Number of timesteps (500+ recommended for steady-state profiling)' + default: '500' sparse: description: 'Use sparse solver (for large circuits)' type: boolean default: false + use_baspacho: + description: 'Build 
spineax from source with BaSpaCho dense solver' + type: boolean + default: false + runner: + description: 'GPU runner to use' + type: choice + options: + - nvidia-runner-1 + - nvidia-runner-2 + default: nvidia-runner-1 -# Only one profiling job at a time on the GPU runner +# Only one profiling job at a time per runner concurrency: - group: nsys-profile - cancel-in-progress: false + group: nsys-profile-${{ inputs.runner }} + cancel-in-progress: true env: CARGO_TERM_COLOR: always @@ -33,7 +44,7 @@ env: jobs: nsys-profile: - runs-on: nvidia-runner-1 + runs-on: ${{ inputs.runner }} timeout-minutes: 60 steps: @@ -53,18 +64,59 @@ jobs: echo "=== nsys Version ===" nsys --version 2>/dev/null || echo "nsys not in PATH" - - name: Install LLVM 18 + - name: Install all apt dependencies (LLVM, CUDA, system libs) + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 + cuda-nvcc-12-6 cuda-cudart-dev-12-6 cuda-driver-dev-12-6 + libcublas-dev-12-6 libcusolver-dev-12-6 libcusparse-dev-12-6 + libnvjitlink-dev-12-6 + cuda-libraries-12-6 cuda-cupti-12-6 + libcudnn9-cuda-12 libcudss0-cuda-12 + cuda-nsight-systems-12-6 + libsuitesparse-dev libopenblas-dev swig cmake pkg-config + # NVIDIA CUDA repo is already on the runner image (cuda-archive-keyring.gpg). + # Adding it again via apt-sources causes "Conflicting Signed-By" error. + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set up LLVM and CUDA environment run: | - if ! 
llvm-config-18 --version 2>/dev/null; then - wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null - sudo chmod a+r /etc/apt/trusted.gpg.d/apt.llvm.org.asc - wget -q https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - rm llvm.sh - fi echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + # Find CUDA toolkit (may be at /usr/local/cuda-12.6 or elsewhere) + CUDA_ROOT=$(find /usr/local -maxdepth 1 -name "cuda-12*" -type d 2>/dev/null | sort -V | tail -1) + if [ -z "$CUDA_ROOT" ] && [ -d "/usr/local/cuda" ]; then + CUDA_ROOT="/usr/local/cuda" + fi + if [ -n "$CUDA_ROOT" ]; then + echo "${CUDA_ROOT}/bin" >> $GITHUB_PATH + echo "CUDAToolkit_ROOT=${CUDA_ROOT}" >> $GITHUB_ENV + # Build LD_LIBRARY_PATH with CUDA lib64 (cuSPARSE, cuFFT, etc.) + # Must be a single write — GITHUB_ENV takes last value per key. + EXTRA_LD="${CUDA_ROOT}/lib64" + echo "CUDA found at: ${CUDA_ROOT}" + else + echo "::warning::CUDA toolkit not found under /usr/local" + EXTRA_LD="" + fi + + CUDSS_LIB=$(dpkg -L libcudss0-cuda-12 2>/dev/null | grep '\.so' | head -1) + if [ -n "$CUDSS_LIB" ]; then + CUDSS_DIR=$(dirname "$CUDSS_LIB") + EXTRA_LD="${CUDSS_DIR}:${EXTRA_LD}" + fi + if [ -n "$EXTRA_LD" ]; then + echo "LD_LIBRARY_PATH=${EXTRA_LD}:${LD_LIBRARY_PATH}" >> "$GITHUB_ENV" + fi + NSYS_BIN=$(dirname "$(find /opt/nvidia -name nsys -type f 2>/dev/null | head -1)" 2>/dev/null) + if [ -n "$NSYS_BIN" ]; then + echo "$NSYS_BIN" >> $GITHUB_PATH + fi + nsys --version || echo "::warning::nsys not found in PATH" + - name: Set up Rust uses: dtolnay/rust-toolchain@stable @@ -78,20 +130,12 @@ jobs: with: workspaces: openvaf_jax/openvaf_py - - name: Install CUDA toolkit - run: | - sudo apt-get update - sudo apt-get install -y cuda-nvcc-12-6 cuda-cudart-dev-12-6 - echo "/usr/local/cuda-12.6/bin" >> $GITHUB_PATH - - - name: Install system dependencies - run: | - sudo apt-get install -y libsuitesparse-dev swig cmake pkg-config - - name: 
Install uv uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false + python-version: "3.13" - name: Set up Python run: uv python install 3.13 @@ -105,13 +149,40 @@ jobs: - name: Install vajax with CUDA dependencies run: uv sync --extra cuda12 + - name: Install spineax with BaSpaCho from source + if: inputs.use_baspacho + run: | + # Replace PyPI spineax with git version that has BaSpaCho dense solver. + # --reinstall forces rebuild even if same version is cached from PyPI. + # FetchContent auto-fetches BaSpaCho + deps (SuiteSparse, dispenso, Eigen). + # sccache caches C++/CUDA compilation (BaSpaCho + SuiteSparse = 344 objects) + uv pip install --reinstall -v \ + "spineax-vajax @ git+https://github.com/robtaylor/spineax.git@main" \ + -C cmake.define.SPINEAX_USE_BASPACHO=ON \ + -C cmake.define.CUDAToolkit_ROOT=${CUDAToolkit_ROOT:-/usr/local/cuda-12.6} \ + -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=sccache \ + -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=sccache \ + -C cmake.define.CMAKE_CUDA_COMPILER_LAUNCHER=sccache + + # Verify the module was installed + echo "=== Installed spineax files ===" + uv run python -c "import spineax; import pathlib; p = pathlib.Path(spineax.__file__).parent; [print(f' {f.name}') for f in sorted(p.iterdir())]" + echo "=== spineax directory listing ===" + ls -lR .venv/lib/python*/site-packages/spineax/ || true + echo "=== Checking baspacho_dense_solve import ===" + uv run python -c "from spineax import baspacho_dense_solve; print(' OK:', baspacho_dense_solve)" || echo " FAILED" + - name: Run nsys profiling env: JAX_PLATFORMS: cuda,cpu - XLA_FLAGS: "--xla_gpu_autotune_level=0" + # +WHILE,+CONDITIONAL: enables command buffer capture for NR while loop + # BaSpaCho dense solver has kCmdBufferCompatible trait (replaces cuSOLVER) + XLA_FLAGS: "--xla_gpu_autotune_level=0 --xla_gpu_enable_command_buffer=+WHILE,+CONDITIONAL" XLA_PYTHON_CLIENT_PREALLOCATE: "false" XLA_PYTHON_CLIENT_ALLOCATOR: platform - TF_CPP_MIN_LOG_LEVEL: "2" + 
TF_CPP_MIN_LOG_LEVEL: "0" + # Debug: show which thunks prevent command buffer capture for while loops + TF_CPP_VMODULE: "command_buffer_conversion_pass=2,while_thunk=3" run: | SPARSE_FLAG="" if [ "${{ inputs.sparse }}" = "true" ]; then @@ -120,6 +191,9 @@ jobs: TIMESTAMP=$(date +%s) PROFILE_NAME="nsys-${{ inputs.circuit }}-${TIMESTAMP}" + # Export profile_name early so subsequent steps can find the report + # even if nsys exits non-zero (e.g. SIGSEGV during teardown) + echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV" echo "=== Starting nsys GPU Profiling ===" echo "Circuit: ${{ inputs.circuit }}" @@ -127,17 +201,51 @@ jobs: echo "Sparse: ${{ inputs.sparse }}" echo "Commit: ${{ github.sha }}" + # NVTX "run_transient" marker annotates the transient window for filtering. + # Cannot use --capture-range=nvtx: SIGSEGV during JAX teardown prevents + # nsys from finalizing the report in NVTX capture mode. + # Tolerate exit code 139 (SIGSEGV during JAX/CUDA teardown) nsys profile \ --trace=cuda,nvtx,osrt \ --output "/tmp/${PROFILE_NAME}" \ uv run python scripts/nsys_profile_target.py \ - ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} + ${{ inputs.circuit }} ${{ inputs.timesteps }} ${SPARSE_FLAG} \ + || NSYS_EXIT=$? - echo "profile_name=${PROFILE_NAME}" >> "$GITHUB_ENV" + if [ "${NSYS_EXIT:-0}" -ne 0 ]; then + echo "::warning::nsys exited with code ${NSYS_EXIT} (139=SIGSEGV during teardown, report may still be valid)" + if [ ! 
-f "/tmp/${PROFILE_NAME}.nsys-rep" ]; then + echo "::error::nsys report not generated" + exit 1 + fi + fi + + - name: Export stats and SQLite + if: always() + run: | + REPORT="/tmp/${profile_name}.nsys-rep" + STATS_DIR="/tmp/nsys-stats" + mkdir -p "$STATS_DIR" + + if [ -f "$REPORT" ]; then + # Export to SQLite for offline analysis + nsys export --type=sqlite --output="/tmp/${profile_name}.sqlite" "$REPORT" || true + + # Generate key stats reports + for report in cuda_gpu_kern_sum cuda_api_sum cuda_gpu_mem_time_sum cuda_gpu_mem_size_sum; do + nsys stats "$REPORT" --report "$report" --format csv \ + --force-export=true --output "${STATS_DIR}/${report}" 2>/dev/null || true + done + + # Full kernel trace (top 100 by duration) + nsys stats "$REPORT" --report cuda_gpu_trace --format csv \ + --force-export=true --output "${STATS_DIR}/cuda_gpu_trace" 2>/dev/null || true + fi - name: Generate summary if: always() run: | + REPORT="/tmp/${profile_name}.nsys-rep" { echo "## nsys GPU Profiling Results" echo "" @@ -146,20 +254,34 @@ jobs: echo "**Sparse:** ${{ inputs.sparse }}" echo "**Commit:** \`${{ github.sha }}\`" echo "" - if [ -f "/tmp/${profile_name}.nsys-rep" ]; then - echo "### Profile Summary" + if [ -f "$REPORT" ]; then + echo "### GPU Kernel Summary" + echo '```' + nsys stats "$REPORT" --report cuda_gpu_kern_sum --force-export=true 2>&1 | head -50 || true + echo '```' + echo "" + echo "### CUDA API Summary" + echo '```' + nsys stats "$REPORT" --report cuda_api_sum --force-export=true 2>&1 | head -30 || true + echo '```' echo "" + echo "### GPU Memory Transfers" echo '```' - nsys stats "/tmp/${profile_name}.nsys-rep" --report cuda_gpu_kern_sum 2>&1 | head -40 || true + nsys stats "$REPORT" --report cuda_gpu_mem_time_sum --force-export=true 2>&1 | head -20 || true echo '```' + else + echo "**No profile report found.**" fi } >> "$GITHUB_STEP_SUMMARY" - - name: Upload nsys report + - name: Upload nsys report and stats if: always() uses: actions/upload-artifact@v4 with: 
name: nsys-profile-${{ inputs.circuit }}-${{ github.sha }} - path: /tmp/nsys-*.nsys-rep + path: | + /tmp/nsys-*.nsys-rep + /tmp/nsys-*.sqlite + /tmp/nsys-stats/ if-no-files-found: ignore retention-days: 30 diff --git a/.github/workflows/test-pdk.yml b/.github/workflows/test-pdk.yml index 13a7f4e0..34b4ab54 100644 --- a/.github/workflows/test-pdk.yml +++ b/.github/workflows/test-pdk.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: pdk-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 @@ -44,12 +48,16 @@ jobs: # Mask PDK path in all subsequent log output echo "::add-mask::/tmp/pdk-gf130" - - name: Install LLVM 18 - run: | - wget -q https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + - name: Install apt dependencies + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: llvm-18-dev clang-18 libclang-18-dev lld-18 + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set LLVM environment + run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - name: Set up Rust uses: dtolnay/rust-toolchain@stable @@ -68,6 +76,8 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false + python-version: "3.10" - name: Set up Python run: uv python install 3.10 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fa4d8db0..56e20eb1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: test-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always CARGO_INCREMENTAL: 0 @@ -23,12 +27,16 @@ jobs: with: submodules: 
recursive - - name: Install LLVM 18 - run: | - wget https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + - name: Install apt dependencies + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: llvm-18-dev clang-18 libclang-18-dev lld-18 + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set LLVM environment + run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - name: Set up Rust uses: dtolnay/rust-toolchain@stable @@ -47,6 +55,8 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false + python-version: "3.11" - name: Set up Python run: uv python install 3.11 @@ -105,12 +115,19 @@ jobs: with: submodules: recursive - - name: Install LLVM 18 - run: | - wget https://apt.llvm.org/llvm.sh - chmod +x llvm.sh - sudo ./llvm.sh 18 - echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV + - name: Install apt dependencies + uses: robtaylor/cache-apt-pkgs-action@feat/apt-sources + with: + packages: >- + llvm-18-dev clang-18 libclang-18-dev lld-18 + libsuitesparse-dev libopenblas-dev swig cmake + ${{ matrix.test-group.extra_packages }} + apt-sources: | + https://apt.llvm.org/llvm-snapshot.gpg.key | deb http://apt.llvm.org/noble/ llvm-toolchain-noble-18 main + execute_install_scripts: true + + - name: Set LLVM environment + run: echo "LLVM_SYS_181_PREFIX=/usr/lib/llvm-18" >> $GITHUB_ENV - name: Set up Rust uses: dtolnay/rust-toolchain@stable @@ -125,15 +142,12 @@ jobs: with: workspaces: openvaf_jax/openvaf_py - - name: Install system dependencies - run: | - sudo apt-get update - sudo apt-get install -y libsuitesparse-dev libopenblas-dev swig cmake ${{ matrix.test-group.extra_packages }} - - name: Install uv uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false + python-version: "3.11" - name: Set up 
Python run: uv python install 3.11 @@ -160,7 +174,11 @@ jobs: working-directory: vajax/sparse run: | uv pip install scikit-build-core nanobind - uv pip install --no-build-isolation . + uv pip install --no-build-isolation \ + -C cmake.define.BLA_VENDOR=OpenBLAS \ + -C cmake.define.CMAKE_C_COMPILER_LAUNCHER=sccache \ + -C cmake.define.CMAKE_CXX_COMPILER_LAUNCHER=sccache \ + . - name: Run tests (${{ matrix.test-group.name }}) timeout-minutes: ${{ matrix.test-group.timeout }} @@ -226,6 +244,8 @@ jobs: uses: astral-sh/setup-uv@v6 with: enable-cache: true + prune-cache: false + python-version: "3.11" - name: Set up Python run: uv python install 3.11 diff --git a/.github/workflows/xla-flag-sweep.yml b/.github/workflows/xla-flag-sweep.yml new file mode 100644 index 00000000..eb61cc1b --- /dev/null +++ b/.github/workflows/xla-flag-sweep.yml @@ -0,0 +1,121 @@ +name: XLA Flag Sweep + +on: + workflow_dispatch: + inputs: + benchmarks: + description: 'Comma-separated benchmark names' + required: false + default: 'ring' + configs: + description: 'Comma-separated config names (empty = all)' + required: false + default: '' + n_runs: + description: 'Number of timed runs per config' + required: false + default: '3' + +env: + CARGO_TERM_COLOR: always + CARGO_INCREMENTAL: 0 + +jobs: + sweep: + name: XLA flag sweep (CUDA) + runs-on: nvidia-runner-1 + timeout-minutes: 120 + concurrency: + group: xla-sweep-${{ github.ref }} + + steps: + - name: Checkout with submodules + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: CUDA diagnostics + run: | + echo "=== GPU Info ===" + nvidia-smi 2>/dev/null || echo "nvidia-smi not available" + echo "=== CUDA Version ===" + nvcc --version 2>/dev/null || echo "nvcc not available" + + - name: Set up Python + uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + prune-cache: false + cache-dependency-glob: "uv.lock" + + - name: Install dependencies + run: uv sync --frozen --all-extras + + - name: Verify GPU access + run: | + 
sed -n '/^=\{80\}/,$ p' /tmp/xla-sweep.log | head -40
Raw JSON results" + echo "" + echo '```json' + cat /tmp/xla-sweep-results.json + echo '```' + echo "
" + fi + } >> "$GITHUB_STEP_SUMMARY" + + - name: Upload results + if: always() + uses: actions/upload-artifact@v4 + with: + name: xla-sweep-results + path: | + /tmp/xla-sweep-results.json + /tmp/xla-sweep.log diff --git a/CLAUDE.md b/CLAUDE.md index 7ec6dfed..9c2246c1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -78,6 +78,22 @@ uv run python scripts/profile_gpu.py --benchmark ring,c6288 ``` +## Linting + +**Run before every commit.** CI runs both checks and will reject PRs that fail. + +```bash +# Lint check (import sorting, unused imports, etc.) +uv tool run ruff check vajax/ tests/ scripts/ benchmarks/ + +# Format check (code style) +uv tool run ruff format --check vajax/ tests/ scripts/ benchmarks/ + +# Auto-fix both +uv tool run ruff check --fix vajax/ tests/ scripts/ benchmarks/ +uv tool run ruff format vajax/ tests/ scripts/ benchmarks/ +``` + ## Precision Configuration Precision is auto-configured on import via `vajax/__init__.py`: diff --git a/openvaf_jax/__init__.py b/openvaf_jax/__init__.py index d0c15847..bc29f910 100644 --- a/openvaf_jax/__init__.py +++ b/openvaf_jax/__init__.py @@ -1266,6 +1266,51 @@ def translate_init_array_split( return init_fn, metadata + def build_sccp_known_values( + self, + shared_indices: List[int], + shared_values: List[float], + shared_cache_indices: Optional[List[int]] = None, + shared_cache_values: Optional[List[float]] = None, + ) -> Optional[Dict[str, Any]]: + """Build SCCP known_values dict from concrete shared params and cache. + + Maps concrete values to their MIR value IDs for SCCP dead-block elimination. + This tells the SCCP pass which MIR values are constant, enabling it to + resolve static branches and eliminate dead blocks at codegen time. + + NOTE: These values are used ONLY for SCCP analysis, NOT for literal + inlining in the generated Python code. Array reads (shared_params[i], + shared_cache[i]) are preserved in the generated code for GPU efficiency. 
+ Literal inlining was found to cause 7.8x GPU regression on PSP103 by + embedding ~1300 float constants in the XLA kernel. + + Args: + shared_indices: Original param indices for shared params + shared_values: Concrete float values for each shared param + shared_cache_indices: Cache column indices for shared cache entries + shared_cache_values: Concrete float values for each shared cache entry + + Returns: + Dict mapping MIR value IDs to concrete values, or None if empty. + """ + known: Dict[str, Any] = {} + + for j, orig_idx in enumerate(shared_indices): + value_id = self.param_idx_to_val.get(orig_idx) + if value_id: + known[value_id] = shared_values[j] + + if shared_cache_values is not None and shared_cache_indices: + for j, cache_col_idx in enumerate(shared_cache_indices): + mapping = self.cache_mapping[cache_col_idx] + eval_param_idx = mapping["eval_param"] + value_id = self.param_idx_to_val.get(eval_param_idx) + if value_id: + known[value_id] = shared_cache_values[j] + + return known if known else None + def translate_eval_array_with_cache_split( self, shared_indices: List[int], @@ -1274,6 +1319,7 @@ def translate_eval_array_with_cache_split( varying_cache_indices: Optional[List[int]] = None, use_limit_functions: bool = False, limit_param_map: Optional[Dict[int, Tuple[str, str]]] = None, + sccp_known_values: Optional[Dict[str, Any]] = None, ) -> Tuple[Callable, Dict]: """Generate a vmappable eval function with split params and cache (internal API). @@ -1291,6 +1337,10 @@ def translate_eval_array_with_cache_split( limit_param_map: Dict mapping original param indices to (kind, name) tuples for limit-related params (prev_state, enable_lim, new_state, enable_integration). Excluded from shared/device params. + sccp_known_values: If provided, maps MIR value IDs to concrete values + for SCCP dead-block elimination. Build this with + build_sccp_known_values(). Values are used for SCCP + analysis only — NOT inlined as literals in generated code. 
Returns: Tuple of (eval_fn, metadata) @@ -1316,10 +1366,15 @@ def translate_eval_array_with_cache_split( assert self.dae_data is not None, "dae_data released, call before release_mir_data()" t0 = time.perf_counter() + n_sccp = len(sccp_known_values) if sccp_known_values else 0 logger.info( - f" translate_eval_array_with_cache_split: generating code (limit_funcs={use_limit_functions})..." + f" translate_eval_array_with_cache_split: generating code " + f"(limit_funcs={use_limit_functions}, sccp_known={n_sccp})..." ) + if sccp_known_values: + logger.info(f" SCCP: {n_sccp} known values for dead-block elimination") + # Build the eval function eval_param_names = list(self.module.param_names) builder = EvalFunctionBuilder( @@ -1327,6 +1382,7 @@ def translate_eval_array_with_cache_split( self.dae_data, self.cache_mapping, self.param_idx_to_val, + sccp_known_values=sccp_known_values, eval_param_names=eval_param_names, ) fn_name, code_lines = builder.build_with_cache_split( @@ -1339,6 +1395,22 @@ def translate_eval_array_with_cache_split( ) t1 = time.perf_counter() + + # Log SCCP statistics + if builder.sccp is not None: + dead_blocks = builder.sccp.get_dead_blocks() + total_blocks = len(self.eval_mir.blocks) + n_constants = sum(1 for v in builder.sccp.lattice.values() if v.is_constant()) + static_branches = sum( + 1 + for b in self.eval_mir.blocks + if builder.sccp.get_static_branch_direction(b) is not None + ) + logger.info( + f" SCCP results: {len(dead_blocks)}/{total_blocks} blocks dead, " + f"{static_branches} static branches, {n_constants} constants propagated" + ) + logger.info( f" translate_eval_array_with_cache_split: code generated ({len(code_lines)} lines) in {t1 - t0:.1f}s" ) @@ -1359,6 +1431,11 @@ def translate_eval_array_with_cache_split( df.write(f"# use_limit_functions={use_limit_functions}\n") df.write(f"# shared_indices={shared_indices}\n") df.write(f"# varying_indices={varying_indices}\n") + df.write(f"# sccp_known_values={n_sccp}\n") + if builder.sccp is 
not None: + dead_blocks = builder.sccp.get_dead_blocks() + total_blocks = len(self.eval_mir.blocks) + df.write(f"# sccp: {len(dead_blocks)}/{total_blocks} blocks dead\n") df.write(code) logger.info(f" Generated code dumped to {dump_path}") diff --git a/openvaf_jax/codegen/function_builder.py b/openvaf_jax/codegen/function_builder.py index 8a4b68d9..0c04bc9d 100644 --- a/openvaf_jax/codegen/function_builder.py +++ b/openvaf_jax/codegen/function_builder.py @@ -980,7 +980,10 @@ def build_with_cache_split( return fn_name, code_str.split("\n") def _emit_param_mapping( - self, body: List[ast.stmt], ctx: CodeGenContext, idx_mapping: Dict[int, Tuple[str, any]] + self, + body: List[ast.stmt], + ctx: CodeGenContext, + idx_mapping: Dict[int, Tuple[str, any]], ): """Emit parameter mapping from split arrays. @@ -999,7 +1002,10 @@ def _emit_param_mapping( source, value = idx_mapping[i] if source == "shared": body.append( - assign(var_name, subscript(ast_name("shared_params"), ast_const(value))) + assign( + var_name, + subscript(ast_name("shared_params"), ast_const(value)), + ) ) elif source == "device": body.append( @@ -1041,6 +1047,11 @@ def _emit_cache_mapping( """Emit cache value mapping from split cache arrays. Always uses split cache format (shared_cache, device_cache) for uniform interface. 
+ + Args: + body: List to append statements to + ctx: Code generation context + cache_idx_mapping: Maps cache index to (source, new_index) """ for cache_idx, mapping in enumerate(self.cache_mapping): eval_param_idx = mapping["eval_param"] @@ -1051,7 +1062,10 @@ def _emit_cache_mapping( source, new_idx = cache_idx_mapping[cache_idx] if source == "shared_cache": body.append( - assign(var_name, subscript(ast_name("shared_cache"), ast_const(new_idx))) + assign( + var_name, + subscript(ast_name("shared_cache"), ast_const(new_idx)), + ) ) else: body.append( diff --git a/scripts/analyze_dense_jaxpr.py b/scripts/analyze_dense_jaxpr.py new file mode 100644 index 00000000..2d5baf1d --- /dev/null +++ b/scripts/analyze_dense_jaxpr.py @@ -0,0 +1,217 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.10" +# dependencies = ["jax", "jaxlib"] +# /// +"""Analyze JAX IR for dense benchmark circuits. + +Dumps jaxpr, HLO op counts, and cost analysis for the key hot paths: +1. build_system (Jacobian + residual assembly) +2. nr_solve (Newton-Raphson with while_loop) +3. run_while (full transient step with adaptive timestep) +""" + +import os +import sys +from pathlib import Path + +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import jax +import jax.numpy as jnp + +from vajax.analysis.engine import CircuitEngine +from vajax.benchmarks.registry import get_benchmark + + +def count_hlo_ops(hlo_text: str) -> dict[str, int]: + """Count operation types in HLO text.""" + op_counts: dict[str, int] = {} + for line in hlo_text.split("\n"): + if "=" in line and "." in line: + parts = line.split("=") + if len(parts) >= 2: + op_part = parts[1].strip().split()[0] if parts[1].strip() else "" + if "." 
in op_part: + op_name = op_part.split("(")[0] + op_counts[op_name] = op_counts.get(op_name, 0) + 1 + return dict(sorted(op_counts.items(), key=lambda x: -x[1])) + + +def analyze_function(name: str, fn, args, output_dir: Path): + """Analyze a single function: jaxpr, HLO, cost.""" + print(f"\n{'=' * 70}") + print(f" {name}") + print(f"{'=' * 70}") + + try: + if hasattr(fn, "lower"): + lowered = fn.lower(*args) + else: + lowered = jax.jit(fn).lower(*args) + + hlo_text = lowered.as_text() + hlo_lines = hlo_text.split("\n") + print(f" HLO: {len(hlo_lines)} lines") + + ops = count_hlo_ops(hlo_text) + if ops: + top = list(ops.items())[:20] + print(" Top HLO ops:") + for op, count in top: + print(f" {op:40s} {count:6d}") + + compiled = lowered.compile() + cost = compiled.cost_analysis() + if cost: + for i, device_cost in enumerate(cost): + if device_cost and isinstance(device_cost, dict): + print(f" Cost (device {i}):") + for key, val in sorted(device_cost.items()): + if isinstance(val, (int, float)): + if val > 1e9: + print(f" {key}: {val / 1e9:.2f}G") + elif val > 1e6: + print(f" {key}: {val / 1e6:.2f}M") + elif val > 1e3: + print(f" {key}: {val / 1e3:.2f}K") + else: + print(f" {key}: {val:.2f}") + + # Save HLO + output_dir.mkdir(parents=True, exist_ok=True) + hlo_file = output_dir / f"{name}.hlo.txt" + with open(hlo_file, "w") as f: + f.write(hlo_text) + print(f" Saved: {hlo_file}") + + except Exception as e: + import traceback + + print(f" Failed: {e}") + traceback.print_exc() + + +def analyze_benchmark(benchmark_name: str, output_dir: Path): + """Analyze all hot paths for a single benchmark.""" + print(f"\n{'#' * 70}") + print(f" Benchmark: {benchmark_name}") + print(f"{'#' * 70}") + + config = get_benchmark(benchmark_name) + engine = CircuitEngine(config.sim_path) + engine.parse() + + # Use short simulation for analysis (just need compilation, not full run) + num_steps = 100 + engine.prepare( + t_stop=config.dt * num_steps, + dt=config.dt, + use_sparse=False, + ) + + 
# Get strategy internals + strategy = engine._strategy + setup_cache = engine._transient_setup_cache + + n_total = setup_cache["n_total"] + n_unknowns = setup_cache["n_unknowns"] + n_vsources = len([d for d in engine.devices if d["model"] == "vsource"]) + n_isources = len([d for d in engine.devices if d["model"] == "isource"]) + n_augmented = n_unknowns + n_vsources + + print(f" Nodes: {n_total}, Unknowns: {n_unknowns}, Vsources: {n_vsources}") + print(f" Augmented system: {n_augmented}x{n_augmented}") + + bench_dir = output_dir / benchmark_name + + # 1. Analyze build_system (Jacobian + residual assembly) + build_fn = setup_cache.get("build_system_fn") + device_arrays = engine._device_arrays + if build_fn is not None and device_arrays is not None: + X = jnp.zeros(n_total + n_vsources, dtype=jnp.float64) + vsource_vals = jnp.zeros(n_vsources, dtype=jnp.float64) + isource_vals = jnp.zeros(max(n_isources, 0), dtype=jnp.float64) + Q_prev = jnp.zeros(n_unknowns, dtype=jnp.float64) + integ_c0 = jnp.asarray(1e9, dtype=jnp.float64) # typical 1/dt + gmin = jnp.asarray(1e-12, dtype=jnp.float64) + gshunt = jnp.asarray(0.0, dtype=jnp.float64) + integ_c1 = jnp.asarray(0.0, dtype=jnp.float64) + integ_d1 = jnp.asarray(0.0, dtype=jnp.float64) + dQdt_prev = jnp.zeros(n_unknowns, dtype=jnp.float64) + integ_c2 = jnp.asarray(0.0, dtype=jnp.float64) + Q_prev2 = jnp.zeros(n_unknowns, dtype=jnp.float64) + total_limit_states = setup_cache.get("total_limit_states", 0) + limit_state = jnp.zeros(total_limit_states, dtype=jnp.float64) + nr_iter = jnp.asarray(1, dtype=jnp.int32) + + build_args = ( + X, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays, + gmin, + gshunt, + integ_c1, + integ_d1, + dQdt_prev, + integ_c2, + Q_prev2, + limit_state, + nr_iter, + ) + analyze_function("build_system", build_fn, build_args, bench_dir) + + # 2. 
Analyze nr_solve + nr_solve = setup_cache.get("nr_solve_fn") + if nr_solve is not None and device_arrays is not None: + X_init = jnp.zeros(n_total + n_vsources, dtype=jnp.float64) + vsource_vals = jnp.zeros(n_vsources, dtype=jnp.float64) + isource_vals = jnp.zeros(max(n_isources, 0), dtype=jnp.float64) + Q_prev = jnp.zeros(n_unknowns, dtype=jnp.float64) + integ_c0 = jnp.asarray(1e9, dtype=jnp.float64) + + nr_args = (X_init, vsource_vals, isource_vals, Q_prev, integ_c0, device_arrays) + analyze_function("nr_solve", nr_solve, nr_args, bench_dir) + + # 3. Analyze full transient step (run_while) + run_while = getattr(strategy, "_jit_run_while", None) + if run_while is None: + # Try to find it in the cache + run_while = strategy._jit_run_while_cache.get( + strategy._get_cache_key() if hasattr(strategy, "_get_cache_key") else None + ) + + if run_while is not None: + print("\n Found run_while - analyzing full transient loop") + # This one is harder to trace without the actual state + # We'd need to construct a FullMNAState - skip for now + print(" (skipping run_while - complex state tuple)") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Analyze dense benchmark JAX IR") + parser.add_argument( + "--benchmark", + default="rc,graetz,mul,ring", + help="Comma-separated benchmarks", + ) + parser.add_argument( + "--output-dir", + default="/tmp/claude/jaxpr-analysis", + help="Output directory for HLO files", + ) + args = parser.parse_args() + + output_dir = Path(args.output_dir) + benchmarks = [b.strip() for b in args.benchmark.split(",")] + + for bench in benchmarks: + analyze_benchmark(bench, output_dir) diff --git a/scripts/analyze_parallelism.py b/scripts/analyze_parallelism.py new file mode 100644 index 00000000..a150de27 --- /dev/null +++ b/scripts/analyze_parallelism.py @@ -0,0 +1,1194 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [] +# /// +"""Analyze parallelism opportunities in VAJAX simulation 
matrices. + +For IREE & Baspacho test case context: given a circuit's Jacobian sparsity +pattern, reports what parallelism can be exploited during factorization, +assembly, and device evaluation. + +Key outputs: +- Elimination tree: dependency structure for sparse factorization +- Level-set parallelism: how many columns can be processed simultaneously +- Supernodal structure: dense blocks exploitable with BLAS-3 +- Fill-in analysis: memory requirements for factorization +- Device evaluation parallelism: scatter pattern from vmap'd device evals +- Pattern stability: sparsity is fixed across all NR iterations + +Usage: + # Analyze from a benchmark (captures matrices + device info) + JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py ring + JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py c6288 + + # Analyze existing Matrix Market file + uv run scripts/analyze_parallelism.py --from-mtx path/to/jacobian_0000.mtx + + # Output to specific directory + JAX_PLATFORMS=cpu uv run scripts/analyze_parallelism.py ring --output-dir /tmp/ring_par +""" + +import argparse +import json +import os +import sys +from collections import Counter +from pathlib import Path + +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +import numpy as np +import scipy.io +import scipy.sparse as sp +import scipy.sparse.linalg + +# --------------------------------------------------------------------------- +# Elimination tree +# --------------------------------------------------------------------------- + + +def symmetrize_pattern(A: sp.spmatrix) -> sp.csc_matrix: + """Compute |A| + |A^T| as a binary pattern (no values, just structure).""" + A_csc = sp.csc_matrix(A) + # Binary pattern: set all values to 1 + A_bin = sp.csc_matrix((np.ones(A_csc.nnz), A_csc.indices, A_csc.indptr), shape=A_csc.shape) + A_sym = A_bin + A_bin.T + # Re-binarize (eliminates any 2s from diagonal overlap) + A_sym.data[:] = 1.0 + A_sym.eliminate_zeros() + return sp.csc_matrix(A_sym) + + +def compute_etree(A_csc: 
sp.csc_matrix) -> np.ndarray: + """Compute elimination tree of a symmetric matrix. + + Uses Liu's algorithm with path compression (union-find). + Only the upper triangle is used. + + Args: + A_csc: Symmetric matrix in CSC format + + Returns: + parent array where parent[i] is the parent of column i, + or -1 for root(s) + """ + n = A_csc.shape[0] + parent = np.full(n, -1, dtype=np.int64) + ancestor = np.arange(n, dtype=np.int64) + + indptr = A_csc.indptr + indices = A_csc.indices + + for k in range(n): + for ptr in range(indptr[k], indptr[k + 1]): + i = indices[ptr] + if i >= k: + continue + # Find root of i with path compression + r = i + while ancestor[r] != r: + r = ancestor[r] + if r != k: + parent[r] = k + ancestor[r] = k + # Path compression for i + r = i + while ancestor[r] != k: + t = ancestor[r] + ancestor[r] = k + r = t + + return parent + + +def compute_level_sets(parent: np.ndarray) -> list[list[int]]: + """Compute level sets from an elimination tree. + + Level 0 = leaves, higher levels = closer to root. + Columns at the same level have no dependencies and can be + processed in parallel. 
+ + Returns: + List of levels, where levels[d] = list of column indices at depth d + (depth measured from leaves, so leaves are at depth 0) + """ + n = len(parent) + # Compute depth from root first + depth_from_root = np.full(n, -1, dtype=np.int64) + + # Find roots + roots = np.where(parent == -1)[0] + for r in roots: + depth_from_root[r] = 0 + + # BFS from roots to compute depth_from_root + # Build children lists for top-down traversal + children = [[] for _ in range(n)] + for i in range(n): + if parent[i] != -1: + children[parent[i]].append(i) + + queue = list(roots) + head = 0 + while head < len(queue): + node = queue[head] + head += 1 + for child in children[node]: + depth_from_root[child] = depth_from_root[node] + 1 + queue.append(child) + + max_depth = int(np.max(depth_from_root)) if n > 0 else 0 + + # Convert to bottom-up levels (leaves = 0) + depth_from_leaves = max_depth - depth_from_root + + levels: list[list[int]] = [[] for _ in range(max_depth + 1)] + for i in range(n): + levels[depth_from_leaves[i]].append(i) + + return levels + + +def compute_etree_stats(parent: np.ndarray, levels: list[list[int]]) -> dict: + """Compute statistics about elimination tree parallelism.""" + n = len(parent) + widths = [len(level) for level in levels] + height = len(levels) + + # Count leaves (nodes with no children) + has_child = np.zeros(n, dtype=bool) + for i in range(n): + if parent[i] != -1: + has_child[parent[i]] = True + n_leaves = int(np.sum(~has_child)) + + # Subtree sizes + subtree_size = np.ones(n, dtype=np.int64) + # Process bottom-up: levels[0] are leaves + for level in levels: + for node in level: + if parent[node] != -1: + subtree_size[parent[node]] += subtree_size[node] + + return { + "height": height, + "n_leaves": n_leaves, + "max_parallelism": max(widths) if widths else 0, + "avg_parallelism": float(np.mean(widths)) if widths else 0, + "min_parallelism": min(widths) if widths else 0, + "level_widths": widths, + "subtree_size_stats": { + "min": 
int(np.min(subtree_size)) if n > 0 else 0, + "max": int(np.max(subtree_size)) if n > 0 else 0, + "mean": float(np.mean(subtree_size)) if n > 0 else 0, + "median": float(np.median(subtree_size)) if n > 0 else 0, + }, + } + + +# --------------------------------------------------------------------------- +# Supernodal detection +# --------------------------------------------------------------------------- + + +def detect_supernodes(parent: np.ndarray, A_csc: sp.csc_matrix) -> list[list[int]]: + """Detect fundamental supernodes in the elimination tree. + + A fundamental supernode is a maximal chain of consecutive columns + j, j+1, ..., j+k where: + - parent[j] = j+1, parent[j+1] = j+2, ..., parent[j+k-1] = j+k + - The columns have nested sparsity patterns (each is a subset of the next) + + These can be factored as dense blocks using BLAS-3 operations. + """ + n = A_csc.shape[0] + if n == 0: + return [] + + # Count children per node + n_children = np.zeros(n, dtype=np.int64) + for i in range(n): + if parent[i] != -1: + n_children[parent[i]] += 1 + + # A node starts a new supernode if: + # - It has more than one child, OR + # - It is not the only child of its parent, OR + # - parent[i] != i + 1 + is_supernode_start = np.ones(n, dtype=bool) + for i in range(n - 1): + if parent[i] == i + 1 and n_children[i + 1] == 1: + is_supernode_start[i + 1] = False + + supernodes: list[list[int]] = [] + current: list[int] = [] + for i in range(n): + if is_supernode_start[i]: + if current: + supernodes.append(current) + current = [i] + else: + current.append(i) + if current: + supernodes.append(current) + + return supernodes + + +def supernode_stats(supernodes: list[list[int]]) -> dict: + """Compute statistics about supernodal structure.""" + sizes = [len(s) for s in supernodes] + size_counts = Counter(sizes) + + # Bucket into histogram ranges + histogram = {} + for size, count in sorted(size_counts.items()): + if size == 1: + histogram["1"] = histogram.get("1", 0) + count + elif size 
<= 4: + histogram["2-4"] = histogram.get("2-4", 0) + count + elif size <= 10: + histogram["5-10"] = histogram.get("5-10", 0) + count + elif size <= 50: + histogram["11-50"] = histogram.get("11-50", 0) + count + else: + histogram["51+"] = histogram.get("51+", 0) + count + + return { + "count": len(supernodes), + "largest": max(sizes) if sizes else 0, + "mean_size": float(np.mean(sizes)) if sizes else 0, + "median_size": float(np.median(sizes)) if sizes else 0, + "size_histogram": histogram, + } + + +# --------------------------------------------------------------------------- +# Fill-in analysis +# --------------------------------------------------------------------------- + + +def fill_in_analysis(A: sp.spmatrix) -> dict: + """Analyze fill-in from LU factorization using scipy's SuperLU. + + Uses MMD_AT_PLUS_A ordering for fill-reducing permutation. + """ + A_csc = sp.csc_matrix(A, dtype=np.float64) + n = A_csc.shape[0] + original_nnz = A_csc.nnz + + results = {} + + for ordering_name, permc_spec in [ + ("MMD_AT_PLUS_A", "MMD_AT_PLUS_A"), + ("COLAMD", "COLAMD"), + ]: + try: + lu = scipy.sparse.linalg.splu( + A_csc, + permc_spec=permc_spec, + options={"SymmetricMode": False}, + ) + l_nnz = lu.L.nnz + u_nnz = lu.U.nnz + factor_nnz = l_nnz + u_nnz - n # subtract diagonal counted twice + + results[ordering_name] = { + "L_nnz": l_nnz, + "U_nnz": u_nnz, + "factor_nnz": factor_nnz, + "fill_ratio": factor_nnz / max(original_nnz, 1), + "fill_in": factor_nnz - original_nnz, + } + except Exception as e: + results[ordering_name] = {"error": str(e)} + + return { + "original_nnz": original_nnz, + "orderings": results, + "best_ordering": min( + (k for k, v in results.items() if "error" not in v), + key=lambda k: results[k]["factor_nnz"], + default=None, + ), + } + + +# --------------------------------------------------------------------------- +# Matrix structure analysis +# --------------------------------------------------------------------------- + + +def 
matrix_structure_analysis(A: sp.spmatrix) -> dict: + """Analyze structural properties of the matrix.""" + A_csc = sp.csc_matrix(A) + A_csr = sp.csr_matrix(A) + n = A_csc.shape[0] + + # Bandwidth + rows, cols = A_csc.nonzero() + if len(rows) > 0: + bandwidth = int(np.max(np.abs(rows - cols))) + profile = int(np.sum(np.abs(rows - cols))) + else: + bandwidth = 0 + profile = 0 + + # Degree distribution (treating matrix as adjacency matrix) + row_nnz = np.diff(A_csr.indptr) + col_nnz = np.diff(A_csc.indptr) + + # Symmetry check + A_T = A_csc.T + sym_diff = A_csc - A_T + sym_diff.eliminate_zeros() + is_structurally_symmetric = sym_diff.nnz == 0 + + # Check for numerical symmetry + if is_structurally_symmetric: + val_diff = np.max(np.abs(A_csc.data - A_T.tocsc().data)) if A_csc.nnz > 0 else 0 + is_numerically_symmetric = val_diff < 1e-10 + else: + is_numerically_symmetric = False + + # Connected components (treating as undirected graph) + A_sym_pattern = symmetrize_pattern(A_csc) + n_components, labels = sp.csgraph.connected_components(A_sym_pattern, directed=False) + + component_sizes = Counter(labels.tolist()) + component_size_list = sorted(component_sizes.values(), reverse=True) + + # Diagonal dominance check + diag = np.abs(A_csc.diagonal()) + row_sums = np.array(np.abs(A_csr).sum(axis=1)).ravel() + off_diag_sums = row_sums - diag + diag_dominant_rows = int(np.sum(diag >= off_diag_sums)) + + return { + "size": n, + "nnz": A_csc.nnz, + "density_pct": A_csc.nnz / (n * n) * 100 if n > 0 else 0, + "bandwidth": bandwidth, + "profile": profile, + "is_structurally_symmetric": is_structurally_symmetric, + "is_numerically_symmetric": is_numerically_symmetric, + "connected_components": n_components, + "component_sizes": component_size_list[:10], # Top 10 + "diagonal_dominance": { + "dominant_rows": diag_dominant_rows, + "total_rows": n, + "pct": diag_dominant_rows / n * 100 if n > 0 else 0, + }, + "degree_stats": { + "row_min": int(np.min(row_nnz)) if n > 0 else 0, + "row_max": 
int(np.max(row_nnz)) if n > 0 else 0, + "row_mean": float(np.mean(row_nnz)) if n > 0 else 0, + "col_min": int(np.min(col_nnz)) if n > 0 else 0, + "col_max": int(np.max(col_nnz)) if n > 0 else 0, + "col_mean": float(np.mean(col_nnz)) if n > 0 else 0, + }, + } + + +# --------------------------------------------------------------------------- +# RCM ordering analysis +# --------------------------------------------------------------------------- + + +def rcm_analysis(A: sp.spmatrix) -> dict: + """Analyze effect of Reverse Cuthill-McKee ordering.""" + A_sym = symmetrize_pattern(A) + + try: + perm = sp.csgraph.reverse_cuthill_mckee(A_sym, symmetric_mode=True) + A_rcm = A_sym[perm][:, perm] + + rows_orig, cols_orig = A_sym.nonzero() + rows_rcm, cols_rcm = A_rcm.nonzero() + + bw_orig = int(np.max(np.abs(rows_orig - cols_orig))) if len(rows_orig) > 0 else 0 + bw_rcm = int(np.max(np.abs(rows_rcm - cols_rcm))) if len(rows_rcm) > 0 else 0 + + return { + "bandwidth_original": bw_orig, + "bandwidth_rcm": bw_rcm, + "bandwidth_reduction_pct": (1 - bw_rcm / max(bw_orig, 1)) * 100, + "permutation_available": True, + } + except Exception as e: + return {"error": str(e), "permutation_available": False} + + +# --------------------------------------------------------------------------- +# Device scatter pattern analysis +# --------------------------------------------------------------------------- + + +def analyze_device_scatter(engine) -> dict: + """Analyze device-to-matrix scatter patterns for assembly parallelism. 
+ + Examines the stamp index mappings to determine: + - How many matrix positions are written by multiple devices (conflicts) + - Maximum fan-in to any single position + - Independence structure between device evaluations + """ + setup = engine._build_transient_setup(backend="cpu", use_dense=True) + static_inputs_cache = setup["static_inputs_cache"] + n_unknowns = setup["n_unknowns"] + + model_info = {} + # Global position → set of unique (model_type, device_idx) writers + global_position_writers: dict[tuple[int, int], set[tuple[str, int]]] = {} + + for model_type, (voltage_indices, stamp_indices, *_rest) in static_inputs_cache.items(): + jac_rows = np.asarray(stamp_indices["jac_row_indices"]) + jac_cols = np.asarray(stamp_indices["jac_col_indices"]) + res_indices = np.asarray(stamp_indices["res_indices"]) + + n_devices = jac_rows.shape[0] + n_jac_entries = jac_rows.shape[1] + n_residuals = res_indices.shape[1] + + # Count unique positions per device + positions_per_device = [] + for dev_idx in range(n_devices): + valid = (jac_rows[dev_idx] >= 0) & (jac_cols[dev_idx] >= 0) + unique_pos = set() + for j in range(n_jac_entries): + if valid[j]: + pos = (int(jac_rows[dev_idx, j]), int(jac_cols[dev_idx, j])) + unique_pos.add(pos) + if pos not in global_position_writers: + global_position_writers[pos] = set() + global_position_writers[pos].add((model_type, dev_idx)) + positions_per_device.append(len(unique_pos)) + + # Count touched nodes per device (for residual fan-out) + nodes_per_device = [] + for dev_idx in range(n_devices): + valid_nodes = set() + for r in range(n_residuals): + idx = int(res_indices[dev_idx, r]) + if idx >= 0: + valid_nodes.add(idx) + nodes_per_device.append(len(valid_nodes)) + + model_info[model_type] = { + "n_devices": n_devices, + "jac_entries_per_device": n_jac_entries, + "residuals_per_device": n_residuals, + "unique_positions_per_device": { + "min": min(positions_per_device) if positions_per_device else 0, + "max": max(positions_per_device) if 
positions_per_device else 0, + "mean": float(np.mean(positions_per_device)) if positions_per_device else 0, + }, + "nodes_per_device": { + "min": min(nodes_per_device) if nodes_per_device else 0, + "max": max(nodes_per_device) if nodes_per_device else 0, + "mean": float(np.mean(nodes_per_device)) if nodes_per_device else 0, + }, + } + + # Analyze scatter conflicts (unique devices per position) + fan_in_counts = [len(writers) for writers in global_position_writers.values()] + fan_in_counter = Counter(fan_in_counts) + conflict_positions = sum(1 for c in fan_in_counts if c > 1) + + # Build device conflict graph: two devices conflict if they write to + # the same matrix position + n_total_devices = sum(info["n_devices"] for info in model_info.values()) + conflict_edges = 0 + for writers in global_position_writers.values(): + n_writers = len(writers) + if n_writers > 1: + conflict_edges += n_writers * (n_writers - 1) // 2 + + return { + "n_unknowns": n_unknowns, + "total_devices": n_total_devices, + "model_types": model_info, + "scatter_conflicts": { + "total_positions": len(global_position_writers), + "conflict_positions": conflict_positions, + "conflict_pct": conflict_positions / max(len(global_position_writers), 1) * 100, + "max_fan_in": max(fan_in_counts) if fan_in_counts else 0, + "fan_in_distribution": {str(k): v for k, v in sorted(fan_in_counter.items())}, + }, + "device_conflict_graph": { + "n_nodes": n_total_devices, + "n_edges": conflict_edges, + "note": "Edges connect devices that write to the same matrix position", + }, + } + + +# --------------------------------------------------------------------------- +# Pattern stability check +# --------------------------------------------------------------------------- + + +def check_pattern_stability(matrices: list[sp.spmatrix]) -> dict: + """Verify that sparsity pattern is identical across NR iterations. 
+ + This is a key property for IREE: the pattern is fixed, only values change, + so symbolic analysis can be compiled once and reused. + """ + if len(matrices) < 2: + return { + "is_fixed": True, + "n_samples": len(matrices), + "note": "Only one matrix available, cannot verify stability", + } + + ref = sp.csc_matrix(matrices[0]) + ref_pattern = set(zip(*ref.nonzero())) + + all_match = True + first_mismatch = None + + for idx, M in enumerate(matrices[1:], 1): + M_csc = sp.csc_matrix(M) + M_pattern = set(zip(*M_csc.nonzero())) + + if M_pattern != ref_pattern: + all_match = False + added = M_pattern - ref_pattern + removed = ref_pattern - M_pattern + first_mismatch = { + "index": idx, + "added_entries": len(added), + "removed_entries": len(removed), + } + break + + # Value variation statistics (how much do values change across iterations?) + if all_match and len(matrices) >= 2: + values = np.column_stack([sp.csc_matrix(M).data for M in matrices]) + rel_variation = np.std(values, axis=1) / (np.abs(np.mean(values, axis=1)) + 1e-30) + value_stats = { + "mean_relative_variation": float(np.mean(rel_variation)), + "max_relative_variation": float(np.max(rel_variation)), + "median_relative_variation": float(np.median(rel_variation)), + } + else: + value_stats = None + + return { + "is_fixed": all_match, + "n_samples": len(matrices), + "first_mismatch": first_mismatch, + "value_variation": value_stats, + "note": ( + "Sparsity pattern is identical across all samples — symbolic " + "factorization can be compiled once and reused for all NR iterations" + if all_match + else "WARNING: Sparsity pattern changes between iterations" + ), + } + + +# --------------------------------------------------------------------------- +# Full analysis pipeline +# --------------------------------------------------------------------------- + + +def analyze_matrix( + A: sp.spmatrix, + name: str = "", + all_matrices: list[sp.spmatrix] | None = None, +) -> dict: + """Run full parallelism analysis on a 
Jacobian matrix. + + Args: + A: The Jacobian matrix (any sparse format) + name: Circuit/benchmark name for labeling + all_matrices: Optional list of matrices for pattern stability check + + Returns: + Dict with all analysis results + """ + A_csc = sp.csc_matrix(A, dtype=np.float64) + n = A_csc.shape[0] + + print(f"Analyzing {n}x{n} matrix ({A_csc.nnz} nonzeros)...") + + # 1. Matrix structure + print(" Matrix structure...") + structure = matrix_structure_analysis(A_csc) + + # 2. Elimination tree on symmetrized pattern + print(" Elimination tree...") + A_sym = symmetrize_pattern(A_csc) + parent = compute_etree(A_sym) + levels = compute_level_sets(parent) + etree_stats = compute_etree_stats(parent, levels) + + # 3. Supernodes + print(" Supernodal detection...") + supernodes = detect_supernodes(parent, A_sym) + snode_stats = supernode_stats(supernodes) + + # 4. Fill-in analysis + print(" Fill-in analysis (SuperLU)...") + fill = fill_in_analysis(A_csc) + + # 5. RCM ordering + print(" RCM ordering...") + rcm = rcm_analysis(A_csc) + + # 6. Pattern stability + stability = None + if all_matrices and len(all_matrices) > 1: + print(f" Pattern stability ({len(all_matrices)} samples)...") + stability = check_pattern_stability(all_matrices) + + # Compute parallelism summary + # "Work" at each level = width (number of independent columns) + # Total sequential steps = height + # Total work = n (all columns must be processed) + # Parallelism efficiency = n / height (ideal speedup from parallelism) + parallelism_efficiency = n / max(etree_stats["height"], 1) + + analysis = { + "name": name, + "matrix": structure, + "_etree_parent": parent, # Full array, not serialized to JSON + "elimination_tree": { + **etree_stats, + "parallelism_efficiency": parallelism_efficiency, + "parent_array_sample": parent[: min(50, n)].tolist(), + "note": ( + f"Height {etree_stats['height']} levels with max width " + f"{etree_stats['max_parallelism']}. 
Columns at the same level " + f"can be factored in parallel. Efficiency = n/height = " + f"{parallelism_efficiency:.1f}x theoretical speedup." + ), + }, + "supernodes": snode_stats, + "fill_in": fill, + "rcm_ordering": rcm, + } + + if stability is not None: + analysis["pattern_stability"] = stability + + return analysis + + +# --------------------------------------------------------------------------- +# Device eval branch analysis +# --------------------------------------------------------------------------- + + +def analyze_eval_branches(engine) -> dict: + """Analyze jnp.where branches in compiled device eval functions. + + For each model type, checks the compiled model's parameter split to determine: + - How many device configurations exist (e.g., NMOS vs PMOS) + - Whether all eval branches are statically determinable at setup time + - How much specialization is possible + + This does NOT require dumping/parsing generated code — it analyzes the + actual shared_params, device_params, and device_cache arrays to determine + how many unique device variants exist. 
+ """ + result = {} + + for model_type, compiled in engine._compiled_models.items(): + if "shared_params" not in compiled: + continue + + sp = np.asarray(compiled["shared_params"]) + dp = np.asarray(compiled["device_params"]) + sc = np.asarray(compiled.get("shared_cache", np.array([]))) + dc = np.asarray(compiled.get("device_cache", np.empty((dp.shape[0], 0)))) + vp = np.asarray(compiled.get("voltage_positions_in_varying", np.array([], dtype=int))) + + n_devices = dp.shape[0] + n_varying = dp.shape[1] if dp.ndim > 1 else 0 + n_voltages = len(vp) + n_static_varying = n_varying - n_voltages + + # Identify non-voltage varying param columns + voltage_cols = set(vp.tolist()) if len(vp) > 0 else set() + static_cols = sorted(set(range(n_varying)) - voltage_cols) + + # Count unique device configurations (static params only) + if static_cols and n_devices > 1: + static_dp = dp[:, static_cols] + unique_configs, config_indices, config_counts = np.unique( + static_dp, axis=0, return_inverse=True, return_counts=True + ) + n_unique_configs = len(unique_configs) + config_sizes = config_counts.tolist() + elif n_devices > 1: + # No static varying params — all devices identical + n_unique_configs = 1 + config_sizes = [n_devices] + else: + n_unique_configs = 1 + config_sizes = [1] + + # Check device_cache uniformity + n_cache_cols = dc.shape[1] if dc.ndim > 1 else 0 + if n_cache_cols > 0 and n_devices > 1: + cache_uniform = int(np.sum(np.all(dc == dc[0:1, :], axis=0))) + cache_varying = n_cache_cols - cache_uniform + + # Count unique cache configurations + unique_dc, dc_indices = np.unique(dc, axis=0, return_inverse=True) + n_unique_cache = len(unique_dc) + else: + cache_uniform = n_cache_cols + cache_varying = 0 + n_unique_cache = 1 + + # Get param names for the varying columns if available + param_names = compiled.get("param_names", []) + param_kinds = compiled.get("param_kinds", []) + varying_indices = compiled.get("varying_indices", []) + + varying_param_info = [] + for col_idx, 
orig_idx in enumerate(varying_indices): + if col_idx in voltage_cols: + continue + name = param_names[orig_idx] if orig_idx < len(param_names) else f"param_{orig_idx}" + kind = param_kinds[orig_idx] if orig_idx < len(param_kinds) else "unknown" + if n_devices > 1: + vals = dp[:, col_idx] + unique_vals = np.unique(vals) + varying_param_info.append( + { + "name": name, + "kind": kind, + "n_unique": len(unique_vals), + "values": unique_vals.tolist() + if len(unique_vals) <= 10 + else f"{len(unique_vals)} values", + } + ) + + result[model_type] = { + "n_devices": n_devices, + "n_shared_params": len(sp), + "n_varying_params": n_varying, + "n_voltage_params": n_voltages, + "n_static_varying_params": n_static_varying, + "n_shared_cache": len(sc) if sc.ndim == 1 else (sc.shape[1] if sc.ndim > 1 else 0), + "n_device_cache_cols": n_cache_cols, + "cache_uniform_cols": cache_uniform, + "cache_varying_cols": cache_varying, + "n_unique_param_configs": n_unique_configs, + "n_unique_cache_configs": n_unique_cache, + "config_sizes": config_sizes, + "varying_static_params": varying_param_info, + "specialization_note": ( + f"All {n_devices} devices can be grouped into {n_unique_configs} " + f"specialized eval variant(s). Branches conditioned on shared_params " + f"({len(sp)} params) and device configuration ({n_static_varying} " + f"static varying params) can be resolved at compile time, eliminating " + f"jnp.where overhead for straight-line GPU kernels." + ), + } + + return result + + +# --------------------------------------------------------------------------- +# Benchmark mode: run simulation and analyze +# --------------------------------------------------------------------------- + + +def analyze_benchmark( + benchmark_name: str, + max_captures: int = 20, + t_stop_override: float | None = None, +) -> dict: + """Run a benchmark simulation, capture matrices, and analyze parallelism. + + Also captures device scatter pattern information. 
+ """ + import jax + + from vajax.analysis import CircuitEngine + from vajax.benchmarks.registry import get_benchmark + + info = get_benchmark(benchmark_name) + assert info is not None, f"Benchmark '{benchmark_name}' not found" + + engine = CircuitEngine(info.sim_path) + engine.parse() + + use_sparse = info.is_large + dt = info.dt + # Use override, or run just enough steps to capture max_captures matrices + # (~5 NR iterations per timestep, so max_captures/5 timesteps plus margin) + if t_stop_override is not None: + t_stop = t_stop_override + else: + # Short simulation: enough for max_captures NR systems + min_steps = max_captures * 2 # ~2x margin (5 NR iters, capture early ones) + t_stop = min(dt * min_steps, info.t_stop) + + print(f"Benchmark: {benchmark_name}") + print(f" Nodes: {engine.num_nodes}, Devices: {len(engine.devices)}") + print(f" Solver: {'sparse' if use_sparse else 'dense'}") + print(f" Transient: t_stop={t_stop:.2e}s, dt={dt:.2e}s") + + # Suppress step-by-step logging during simulation + import logging + + logging.getLogger("vajax").setLevel(logging.WARNING) + + # --- Device scatter analysis (before running simulation) --- + print("\nAnalyzing device scatter patterns...") + device_scatter = analyze_device_scatter(engine) + + # --- Capture matrices via monkey-patching --- + import vajax.analysis.solver_factories as sf + + captured_systems: list[tuple[np.ndarray, np.ndarray]] = [] + csr_info: dict = {} + capture_count = [0] + + def capture_cb(J_or_data: jax.Array, f: jax.Array): + if capture_count[0] >= max_captures: + return + captured_systems.append((np.asarray(J_or_data).copy(), np.asarray(f).copy())) + capture_count[0] += 1 + + original_make_nr = sf._make_nr_solver_common + + def patched_nr(*, linear_solve_fn, **kwargs): + def instrumented(J_or_data, f): + jax.debug.callback(capture_cb, J_or_data, f) + return linear_solve_fn(J_or_data, f) + + return original_make_nr(linear_solve_fn=instrumented, **kwargs) + + sf._make_nr_solver_common = 
patched_nr + + # Intercept sparse factories for CSR structure + for factory_name in ("make_umfpack_ffi_full_mna_solver", "make_spineax_full_mna_solver"): + original = getattr(sf, factory_name) + + def make_patched(orig): + def patched(*args, **kwargs): + n_nodes = args[1] if len(args) > 1 else kwargs.get("n_nodes") + n_vsources = args[2] if len(args) > 2 else kwargs.get("n_vsources") + bcsr_indptr = kwargs.get("bcsr_indptr") + bcsr_indices = kwargs.get("bcsr_indices") + if bcsr_indptr is None and len(args) > 4: + bcsr_indptr = args[4] + if bcsr_indices is None and len(args) > 5: + bcsr_indices = args[5] + if bcsr_indptr is not None and n_nodes is not None: + n_aug = n_nodes - 1 + n_vsources + csr_info["indptr"] = np.asarray(bcsr_indptr).copy() + csr_info["indices"] = np.asarray(bcsr_indices).copy() + csr_info["shape"] = (n_aug, n_aug) + return orig(*args, **kwargs) + + return patched + + patched = make_patched(original) + setattr(sf, factory_name, patched) + + import vajax.analysis.transient.full_mna as _full_mna + + setattr(_full_mna, factory_name, patched) + + # --- Run simulation --- + print("\nRunning simulation...") + engine.prepare(t_stop=t_stop, dt=dt, use_sparse=use_sparse) + result = engine.run_transient() + + convergence = result.stats.get("convergence_rate", 0) * 100 + print(f" Steps: {result.num_steps}, convergence: {convergence:.0f}%") + print(f" Captured {len(captured_systems)} linear systems") + + if not captured_systems: + print("ERROR: No systems captured!", file=sys.stderr) + sys.exit(1) + + # --- Build sparse matrices from captured data --- + matrices: list[sp.spmatrix] = [] + for J_or_data, f in captured_systems: + if use_sparse and "indptr" in csr_info: + mat = sp.csr_matrix( + (J_or_data, csr_info["indices"], csr_info["indptr"]), + shape=csr_info["shape"], + ) + else: + mat = sp.csc_matrix(J_or_data) + matrices.append(mat) + + # --- Analyze --- + print("\nAnalyzing first captured matrix...") + analysis = analyze_matrix(matrices[0], 
name=benchmark_name, all_matrices=matrices) + + # Add device scatter info + analysis["device_parallelism"] = device_scatter + + # Add circuit info + analysis["circuit"] = { + "name": benchmark_name, + "n_external_nodes": engine.num_nodes, + "n_devices": len(engine.devices), + "n_unknowns": device_scatter["n_unknowns"], + "simulation": { + "t_stop": t_stop, + "dt": dt, + "steps": result.num_steps, + "convergence_pct": convergence, + }, + } + + # Eval branch specialization analysis + print("\nAnalyzing eval function branches...") + branch_analysis = analyze_eval_branches(engine) + analysis["eval_specialization"] = branch_analysis + + # Add compilation note for IREE + analysis["iree_notes"] = { + "pattern_is_fixed": analysis.get("pattern_stability", {}).get("is_fixed", True), + "same_pattern_every_nr_iteration": True, + "values_change_every_iteration": True, + "typical_nr_iterations_per_step": "3-8", + "typical_timesteps": f"{result.num_steps}", + "total_solves": len(captured_systems), + "recommendation": ( + "The sparsity pattern is determined at circuit parse time and never changes. " + "Symbolic factorization (ordering, elimination tree, memory allocation) can " + "be compiled once. Only numerical factorization needs to run per NR iteration. " + f"For this circuit: {result.num_steps} timesteps x ~5 NR iters = " + f"~{result.num_steps * 5} factorizations with identical structure." 
+ ), + } + + return analysis + + +# --------------------------------------------------------------------------- +# Output +# --------------------------------------------------------------------------- + + +def write_analysis(analysis: dict, output_dir: Path): + """Write analysis results to output directory.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # JSON output (machine-readable) + json_path = output_dir / "parallelism_analysis.json" + + # Remove internal data not suitable for JSON + analysis_json = {k: v for k, v in analysis.items() if not k.startswith("_")} + analysis_json = json.loads(json.dumps(analysis_json, default=str)) + widths = analysis_json.get("elimination_tree", {}).get("level_widths", []) + if len(widths) > 100: + analysis_json["elimination_tree"]["level_widths_truncated"] = ( + widths[:50] + ["..."] + widths[-50:] + ) + del analysis_json["elimination_tree"]["level_widths"] + + with open(json_path, "w") as f: + json.dump(analysis_json, f, indent=2) + print(f" JSON: {json_path}") + + # Human-readable summary + summary_path = output_dir / "parallelism_summary.txt" + with open(summary_path, "w") as f: + name = analysis.get("name", "unknown") + f.write(f"{'=' * 70}\n") + f.write(f"Parallelism Analysis: {name}\n") + f.write(f"{'=' * 70}\n\n") + + mat = analysis["matrix"] + f.write( + f"Matrix: {mat['size']}x{mat['size']}, {mat['nnz']} nonzeros ({mat['density_pct']:.4f}%)\n" + ) + f.write(f"Bandwidth: {mat['bandwidth']}, Symmetric: {mat['is_structurally_symmetric']}\n") + f.write(f"Connected components: {mat['connected_components']}\n") + deg = mat["degree_stats"] + f.write( + f"Row degree: min={deg['row_min']}, max={deg['row_max']}, mean={deg['row_mean']:.1f}\n" + ) + f.write(f"Diagonal dominance: {mat['diagonal_dominance']['pct']:.1f}% of rows\n\n") + + et = analysis["elimination_tree"] + f.write("--- Elimination Tree ---\n") + f.write(f"Height (sequential steps): {et['height']}\n") + f.write(f"Leaves: {et['n_leaves']}\n") + f.write(f"Max 
parallelism (widest level): {et['max_parallelism']}\n") + f.write(f"Avg parallelism: {et['avg_parallelism']:.1f}\n") + f.write(f"Parallelism efficiency (n/height): {et['parallelism_efficiency']:.1f}x\n") + st = et["subtree_size_stats"] + f.write(f"Subtree sizes: min={st['min']}, max={st['max']}, median={st['median']:.0f}\n\n") + + sn = analysis["supernodes"] + f.write("--- Supernodes ---\n") + f.write(f"Count: {sn['count']} supernodes\n") + f.write(f"Largest: {sn['largest']} columns\n") + f.write(f"Mean size: {sn['mean_size']:.1f}\n") + f.write(f"Size distribution: {sn['size_histogram']}\n\n") + + fi = analysis["fill_in"] + f.write("--- Fill-in (LU factorization) ---\n") + f.write(f"Original nnz: {fi['original_nnz']}\n") + for order_name, order_data in fi["orderings"].items(): + if "error" not in order_data: + f.write( + f" {order_name}: factor_nnz={order_data['factor_nnz']}, " + f"fill_ratio={order_data['fill_ratio']:.2f}x, " + f"fill_in=+{order_data['fill_in']}\n" + ) + if fi["best_ordering"]: + f.write(f"Best ordering: {fi['best_ordering']}\n") + f.write("\n") + + rcm = analysis.get("rcm_ordering", {}) + if rcm.get("permutation_available"): + f.write("--- RCM Ordering ---\n") + f.write(f"Bandwidth: {rcm['bandwidth_original']} -> {rcm['bandwidth_rcm']} ") + f.write(f"({rcm['bandwidth_reduction_pct']:.1f}% reduction)\n\n") + + ps = analysis.get("pattern_stability") + if ps: + f.write("--- Pattern Stability ---\n") + f.write(f"Fixed pattern: {ps['is_fixed']} ({ps['n_samples']} samples)\n") + if ps.get("value_variation"): + vv = ps["value_variation"] + f.write(f"Value variation: mean_rel={vv['mean_relative_variation']:.4f}, ") + f.write(f"max_rel={vv['max_relative_variation']:.4f}\n") + f.write(f"{ps['note']}\n\n") + + dp = analysis.get("device_parallelism") + if dp: + f.write("--- Device Evaluation Parallelism ---\n") + f.write(f"Total devices: {dp['total_devices']}\n") + for mt, mi in dp["model_types"].items(): + f.write(f" {mt}: {mi['n_devices']} devices, ") + 
f.write(f"{mi['jac_entries_per_device']} Jacobian entries/device, ") + f.write(f"{mi['nodes_per_device']['mean']:.0f} nodes/device\n") + sc = dp["scatter_conflicts"] + f.write( + f"Scatter conflicts: {sc['conflict_positions']}/{sc['total_positions']} positions " + ) + f.write(f"({sc['conflict_pct']:.1f}%), max fan-in={sc['max_fan_in']}\n") + f.write(f"Fan-in distribution: {sc['fan_in_distribution']}\n\n") + + es = analysis.get("eval_specialization") + if es: + f.write("--- Eval Branch Specialization ---\n") + for mt, info in es.items(): + f.write(f" {mt}: {info['n_devices']} devices\n") + f.write(f" Params: {info['n_shared_params']} shared, ") + f.write(f"{info['n_voltage_params']} voltage, ") + f.write(f"{info['n_static_varying_params']} static-varying\n") + f.write(f" Cache: {info['n_shared_cache']} shared, ") + f.write(f"{info['cache_varying_cols']} device-varying\n") + f.write(f" Unique device configs: {info['n_unique_param_configs']}") + f.write(f" (sizes: {info['config_sizes']})\n") + if info.get("varying_static_params"): + for vp in info["varying_static_params"]: + f.write( + f" {vp['name']} ({vp['kind']}): {vp['n_unique']} unique values\n" + ) + f.write(f" {info['specialization_note']}\n") + f.write("\n") + + notes = analysis.get("iree_notes") + if notes: + f.write("--- IREE/Baspacho Notes ---\n") + f.write(f"{notes['recommendation']}\n") + + print(f" Summary: {summary_path}") + + # Write full elimination tree parent array (useful for solver development) + if "_etree_parent" in analysis: + etree_path = output_dir / "etree_parent.npy" + np.save(etree_path, analysis["_etree_parent"]) + print(f" Etree parent: {etree_path} ({len(analysis['_etree_parent'])} nodes)") + + # Write level-set widths as CSV (for plotting) + widths = analysis.get("elimination_tree", {}).get("level_widths", []) + if widths: + widths_path = output_dir / "level_set_widths.csv" + with open(widths_path, "w") as f: + f.write("level,width\n") + for i, w in enumerate(widths): + 
f.write(f"{i},{w}\n") + print(f" Level widths: {widths_path}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze parallelism opportunities in VAJAX simulation matrices" + ) + parser.add_argument( + "benchmark", + nargs="?", + help="Benchmark name (e.g. ring, c6288, graetz)", + ) + parser.add_argument( + "--from-mtx", + type=Path, + nargs="+", + help="Analyze existing Matrix Market file(s)", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Output directory (default: /tmp/claude/_parallelism)", + ) + parser.add_argument( + "--max-captures", + type=int, + default=20, + help="Max NR systems to capture for pattern stability check", + ) + parser.add_argument( + "--t-stop", + type=float, + default=None, + help="Override transient stop time (default: auto-short for analysis)", + ) + args = parser.parse_args() + + if args.from_mtx: + # Load from Matrix Market files + matrices = [] + for path in args.from_mtx: + print(f"Loading {path}...") + matrices.append(scipy.io.mmread(path)) + + name = args.from_mtx[0].stem.replace("jacobian_", "") + analysis = analyze_matrix(matrices[0], name=name, all_matrices=matrices) + + out_dir = args.output_dir or Path(f"/tmp/claude/{name}_parallelism") + write_analysis(analysis, out_dir) + + elif args.benchmark: + analysis = analyze_benchmark( + args.benchmark, + max_captures=args.max_captures, + t_stop_override=args.t_stop, + ) + out_dir = args.output_dir or Path(f"/tmp/claude/{args.benchmark}_parallelism") + write_analysis(analysis, out_dir) + + else: + parser.print_help() + sys.exit(1) + + print("\nDone.") + + +if __name__ == "__main__": + main() diff --git a/scripts/check_constant_folding.py b/scripts/check_constant_folding.py new file mode 100644 index 00000000..52053657 --- /dev/null +++ 
b/scripts/check_constant_folding.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = ["jax", "jaxlib"] +# /// +"""Check whether jnp.where constant folding works at jaxpr and HLO levels. + +Quick diagnostic to verify that inlining shared params as Python literals +actually eliminates jnp.where branches in the compiled XLA program. +""" + +import jax +import jax.numpy as jnp + + +def test_basic_constant_folding(): + """Test: does jnp.where with a Python bool constant-fold?""" + print("=" * 60) + print("Test 1: jnp.where with Python bool literal") + print("=" * 60) + + def f_const(x): + # Condition is a Python bool - should constant-fold + cond = True + return jnp.where(cond, x * 2, x * 3) + + def f_traced(x, flag): + # Condition is a traced value - cannot constant-fold + return jnp.where(flag > 0.0, x * 2, x * 3) + + x = jnp.ones(10) + + jaxpr_const = jax.make_jaxpr(f_const)(x) + jaxpr_traced = jax.make_jaxpr(f_traced)(x, 1.0) + + print(f"\nConstant cond jaxpr ({len(jaxpr_const.eqns)} ops):") + print(jaxpr_const) + print(f"\nTraced cond jaxpr ({len(jaxpr_traced.eqns)} ops):") + print(jaxpr_traced) + + # Check HLO + lowered_const = jax.jit(f_const).lower(x) + lowered_traced = jax.jit(f_traced).lower(x, 1.0) + hlo_const = lowered_const.as_text() + hlo_traced = lowered_traced.as_text() + + select_const = hlo_const.count("select") + select_traced = hlo_traced.count("select") + print(f"\nHLO select ops: constant={select_const}, traced={select_traced}") + + +def test_constant_through_jnp_ops(): + """Test: does constant folding survive through jnp operations?""" + print("\n" + "=" * 60) + print("Test 2: Constant folding through jnp operations") + print("=" * 60) + + def f_inlined(x): + # Simulate what our specialization does: + # shared param inlined as literal, then used in jnp ops + v_param = 1.5e-6 # Was: shared_params[42] + v_computed = jnp.exp(v_param) # jnp op on literal + cond = v_computed > 0.5 # comparison + 
return jnp.where(cond, x * 2, x * 3) + + def f_array_lookup(x, shared_params): + # Original: shared_params array lookup + v_param = shared_params[42] + v_computed = jnp.exp(v_param) + cond = v_computed > 0.5 + return jnp.where(cond, x * 2, x * 3) + + x = jnp.ones(10) + shared = jnp.zeros(100) + + jaxpr_inlined = jax.make_jaxpr(f_inlined)(x) + jaxpr_lookup = jax.make_jaxpr(f_array_lookup)(x, shared) + + print(f"\nInlined literal jaxpr ({len(jaxpr_inlined.eqns)} ops):") + print(jaxpr_inlined) + print(f"\nArray lookup jaxpr ({len(jaxpr_lookup.eqns)} ops):") + print(jaxpr_lookup) + + # Check HLO + lowered_inlined = jax.jit(f_inlined).lower(x) + lowered_lookup = jax.jit(f_array_lookup).lower(x, shared) + hlo_inlined = lowered_inlined.as_text() + hlo_lookup = lowered_lookup.as_text() + + select_inlined = hlo_inlined.count("select") + select_lookup = hlo_lookup.count("select") + print(f"\nHLO select ops: inlined={select_inlined}, lookup={select_lookup}") + + +def test_python_float_vs_jnp(): + """Test: Python float literal vs jnp operation - what does JAX trace?""" + print("\n" + "=" * 60) + print("Test 3: Python float arithmetic vs jnp arithmetic") + print("=" * 60) + + def f_python_arith(x): + # Pure Python: should constant-fold completely + a = 1.5e-6 + b = a * 2.0 # Python multiplication + cond = b > 1e-6 # Python comparison -> True + return jnp.where(cond, x * 2, x * 3) + + def f_jnp_arith(x): + # jnp operations: might NOT constant-fold in jaxpr + a = 1.5e-6 + b = jnp.float64(a) * 2.0 # jnp multiplication + cond = b > 1e-6 # comparison on jnp result + return jnp.where(cond, x * 2, x * 3) + + x = jnp.ones(10) + + jaxpr_python = jax.make_jaxpr(f_python_arith)(x) + jaxpr_jnp = jax.make_jaxpr(f_jnp_arith)(x) + + print(f"\nPython arith jaxpr ({len(jaxpr_python.eqns)} ops):") + print(jaxpr_python) + print(f"\njnp arith jaxpr ({len(jaxpr_jnp.eqns)} ops):") + print(jaxpr_jnp) + + # Check HLO for both + lowered_python = jax.jit(f_python_arith).lower(x) + lowered_jnp = 
jax.jit(f_jnp_arith).lower(x) + hlo_python = lowered_python.as_text() + hlo_jnp = lowered_jnp.as_text() + + select_python = hlo_python.count("select") + select_jnp = hlo_jnp.count("select") + print(f"\nHLO select ops: python_arith={select_python}, jnp_arith={select_jnp}") + + +def test_generated_code_pattern(): + """Test the ACTUAL pattern used in generated eval code. + + The generated code does: + v123 = 1.5e-6 # inlined literal (was shared_params[42]) + Then later uses it in jnp operations. + + Key question: does assigning a Python float to a local variable, + then using it in jnp.where, get constant-folded? + """ + print("\n" + "=" * 60) + print("Test 4: Actual generated code pattern (assign + jnp.where)") + print("=" * 60) + + def f_generated_pattern(device_params): + # This mimics the actual generated eval code pattern + # Inlined shared params + v100 = 1.0 # TYPE = 1 (NMOS) + v101 = 0.0 # SWIGATE = 0 + v102 = 1.5e-6 # TOX + + # Device params (from vmap, traced) + v200 = device_params[0] # voltage + _v201 = device_params[1] # another voltage (unused, kept for array shape) + + # Computation chain (mimics what OpenVAF generates) + v300 = jnp.exp(v102 * 1e6) # Uses inlined literal + v301 = v300 * v200 # Mixes with traced value + + # Branch on static param + v400 = v100 > 0.5 # TYPE > 0.5 -> True for NMOS + result1 = jnp.where(v400, v301, -v301) # Should fold + + # Branch on static param through jnp op + v401 = jnp.abs(v101) # jnp op on inlined literal + v402 = v401 > 0.5 # Should be False + result2 = jnp.where(v402, result1 * 2, result1 * 3) # Should fold + + return result1 + result2 + + def f_array_pattern(device_params, shared_params): + # Original pattern: array lookups (all traced) + v100 = shared_params[0] + v101 = shared_params[1] + v102 = shared_params[2] + + v200 = device_params[0] + _v201 = device_params[1] # noqa: F841 + + v300 = jnp.exp(v102 * 1e6) + v301 = v300 * v200 + + v400 = v100 > 0.5 + result1 = jnp.where(v400, v301, -v301) + + v401 = 
jnp.abs(v101) + v402 = v401 > 0.5 + result2 = jnp.where(v402, result1 * 2, result1 * 3) + + return result1 + result2 + + dp = jnp.array([0.5, 0.3]) + sp = jnp.array([1.0, 0.0, 1.5e-6]) + + jaxpr_gen = jax.make_jaxpr(f_generated_pattern)(dp) + jaxpr_arr = jax.make_jaxpr(f_array_pattern)(dp, sp) + + print(f"\nInlined pattern jaxpr ({len(jaxpr_gen.eqns)} ops):") + print(jaxpr_gen) + print(f"\nArray pattern jaxpr ({len(jaxpr_arr.eqns)} ops):") + print(jaxpr_arr) + + # Check HLO + lowered_gen = jax.jit(f_generated_pattern).lower(dp) + lowered_arr = jax.jit(f_array_pattern).lower(dp, sp) + hlo_gen = lowered_gen.as_text() + hlo_arr = lowered_arr.as_text() + + select_gen = hlo_gen.count("select") + select_arr = hlo_arr.count("select") + print(f"\nHLO select ops: inlined={select_gen}, array={select_arr}") + + # Also count total HLO ops + print(f"HLO lines: inlined={len(hlo_gen.splitlines())}, array={len(hlo_arr.splitlines())}") + + +def test_vmap_interaction(): + """Test: does constant folding survive vmap? + + This is crucial because we vmap the eval function over devices. 
+ """ + print("\n" + "=" * 60) + print("Test 5: Constant folding under vmap") + print("=" * 60) + + def f_inlined(device_params): + v_type = 1.0 # Inlined: TYPE = NMOS + cond = v_type > 0.5 + return jnp.where(cond, device_params[0] * 2, device_params[0] * 3) + + def f_lookup(device_params, shared_params): + v_type = shared_params[0] + cond = v_type > 0.5 + return jnp.where(cond, device_params[0] * 2, device_params[0] * 3) + + # vmap over batch of devices + f_inlined_vmapped = jax.vmap(f_inlined) + f_lookup_vmapped = jax.vmap(f_lookup, in_axes=(0, None)) + + batch_dp = jnp.ones((4, 3)) + sp = jnp.array([1.0]) + + jaxpr_inlined = jax.make_jaxpr(f_inlined_vmapped)(batch_dp) + jaxpr_lookup = jax.make_jaxpr(f_lookup_vmapped)(batch_dp, sp) + + print(f"\nvmapped inlined jaxpr ({len(jaxpr_inlined.eqns)} ops):") + print(jaxpr_inlined) + print(f"\nvmapped lookup jaxpr ({len(jaxpr_lookup.eqns)} ops):") + print(jaxpr_lookup) + + # HLO + lowered_inlined = jax.jit(f_inlined_vmapped).lower(batch_dp) + lowered_lookup = jax.jit(f_lookup_vmapped).lower(batch_dp, sp) + hlo_inlined = lowered_inlined.as_text() + hlo_lookup = lowered_lookup.as_text() + + select_inlined = hlo_inlined.count("select") + select_lookup = hlo_lookup.count("select") + print(f"\nHLO select ops: inlined={select_inlined}, lookup={select_lookup}") + print( + f"HLO lines: inlined={len(hlo_inlined.splitlines())}, lookup={len(hlo_lookup.splitlines())}" + ) + + +if __name__ == "__main__": + print(f"JAX version: {jax.__version__}") + print(f"Platform: {jax.default_backend()}") + print( + f"x64 enabled: {jax.config.x86_64_enabled if hasattr(jax.config, 'x86_64_enabled') else 'unknown'}" + ) + print() + + test_basic_constant_folding() + test_constant_through_jnp_ops() + test_python_float_vs_jnp() + test_generated_code_pattern() + test_vmap_interaction() diff --git a/scripts/compare_vacask.py b/scripts/compare_vacask.py index 20063930..a6941eb5 100644 --- a/scripts/compare_vacask.py +++ b/scripts/compare_vacask.py @@ 
-35,6 +35,7 @@ from vajax._logging import enable_performance_logging enable_performance_logging(with_memory=True, with_perf_counter=True) + import re import sys import time @@ -52,7 +53,6 @@ # Note: Set JAX_PLATFORMS=cpu before running for CPU-only mode import jax -import jax.numpy as jnp # Precision is auto-configured by vajax import (imported above via logging) # Metal/TPU use f32, CPU/CUDA use f64 @@ -64,104 +64,6 @@ ) from vajax.profiling import ProfileConfig - -def analyze_compiled_function(fn, args, name: str, output_dir: Optional[Path] = None): - """Dump jaxpr and cost analysis for a JIT-compiled function. - - Args: - fn: A JAX function (JIT-compiled or not) - args: Example arguments to trace with - name: Name for output files - output_dir: Optional directory to save analysis files - """ - print(f"\n{'=' * 70}") - print(f"JAX Analysis: {name}") - print(f"{'=' * 70}") - - # Lower the function to get HLO and cost analysis - # For JIT-compiled functions, we use .lower() directly - print(f"\n--- Lowering and compiling {name} ---") - try: - # If fn is already jitted, we can lower it directly - # Otherwise wrap it in jit first - if hasattr(fn, "lower"): - lowered = fn.lower(*args) - else: - lowered = jax.jit(fn).lower(*args) - - # Get the HLO text (MLIR representation) - hlo_text = lowered.as_text() - hlo_lines = hlo_text.split("\n") - print(f"HLO text: {len(hlo_lines)} lines") - - # Count operations in HLO - op_counts: Dict[str, int] = {} - for line in hlo_lines: - # Extract operation names from MLIR-style ops like: %0 = stablehlo.add - if "=" in line and "." in line: - parts = line.split("=") - if len(parts) >= 2: - op_part = parts[1].strip().split()[0] if parts[1].strip() else "" - if "." 
in op_part: - op_name = op_part.split("(")[0] # Remove args - op_counts[op_name] = op_counts.get(op_name, 0) + 1 - - if op_counts: - print(f"Top HLO ops: {dict(sorted(op_counts.items(), key=lambda x: -x[1])[:15])}") - - # Compile and get cost analysis - compiled = lowered.compile() - cost = compiled.cost_analysis() - print("\n--- Cost Analysis ---") - if cost: - for i, device_cost in enumerate(cost): - if device_cost and isinstance(device_cost, dict): - print(f"Device {i}:") - for key, val in device_cost.items(): - if isinstance(val, (int, float)): - if val > 1e9: - print(f" {key}: {val / 1e9:.2f}G") - elif val > 1e6: - print(f" {key}: {val / 1e6:.2f}M") - elif val > 1e3: - print(f" {key}: {val / 1e3:.2f}K") - else: - print(f" {key}: {val}") - else: - print(f" {key}: {val}") - elif device_cost: - print(f"Device {i}: {device_cost}") - else: - print("No cost analysis available (may not be supported on this backend)") - - # Save files if output_dir provided - if output_dir: - output_dir.mkdir(parents=True, exist_ok=True) - - # Save HLO text - hlo_file = output_dir / f"{name}_hlo.txt" - with open(hlo_file, "w") as f: - f.write(hlo_text) - print(f"\nHLO text saved to: {hlo_file}") - - # Try to get the jaxpr as well for the underlying computation - try: - # Create jaxpr from the unwrapped function if possible - jaxpr_text = str(jax.make_jaxpr(fn)(*args)) - jaxpr_file = output_dir / f"{name}_jaxpr.txt" - with open(jaxpr_file, "w") as f: - f.write(jaxpr_text) - print(f"JAXPR saved to: {jaxpr_file}") - except Exception: - pass # JIT functions may not produce useful jaxpr - - except Exception as e: - import traceback - - print(f"Failed to analyze: {e}") - traceback.print_exc() - - # Note: Benchmark configurations are now auto-discovered from # vajax.benchmarks.registry. The registry parses .sim files # to extract dt, t_stop, and device types automatically. 
@@ -451,41 +353,11 @@ def do_run(): ) startup_time = time.perf_counter() - startup_start - # Run analysis on compiled scan function if requested - if analyze and use_scan and hasattr(engine, "_cached_scan_fn"): - print("\n Running JAX analysis...") - # Get example inputs for the scan function - # The scan function signature is: (V_init, Q_init, all_vsource, all_isource, device_arrays) - # Must use total nodes (external + internal) from transient setup cache - setup_cache = getattr(engine, "_transient_setup_cache", None) - device_arrays = getattr(engine, "_device_arrays", None) - - if setup_cache is None or device_arrays is None: - print(" Warning: transient setup cache not found - analysis skipped") - else: - n_total = setup_cache["n_total"] - n_unknowns = setup_cache["n_unknowns"] - n_vsources = len([d for d in engine.devices if d["model"] == "vsource"]) - n_isources = len([d for d in engine.devices if d["model"] == "isource"]) - - # Create example arrays matching actual shapes - V_init = jnp.zeros(n_total, dtype=jnp.float64) - Q_init = jnp.zeros(n_unknowns, dtype=jnp.float64) - all_vsource = jnp.zeros((num_steps, n_vsources), dtype=jnp.float64) - all_isource = jnp.zeros((num_steps, n_isources), dtype=jnp.float64) - - # Determine output directory - out_dir = analyze_output_dir or Path(f"/tmp/vajax-analysis/{config.name}") - - # Analyze the scan function - analyze_compiled_function( - engine._cached_scan_fn, - (V_init, Q_init, all_vsource, all_isource, device_arrays), - f"{config.name}_scan_simulation", - out_dir, - ) - elif analyze and use_scan: - print("\n Warning: _cached_scan_fn not found - analysis skipped") + # Run analysis on compiled functions if requested + if analyze: + out_dir = analyze_output_dir or Path(f"/tmp/claude/jaxpr-analysis/{config.name}") + print(f"\n Dumping jaxpr/HLO analysis to {out_dir} ...") + engine.dump_jaxpr(out_dir) # Timed run - print perf_counter for correlation with Perfetto traces # prepare() already called above with same params, 
strategy is cached @@ -496,13 +368,11 @@ def do_run(): print( f"AFTER_RUN_TRANSIENT: {after_transient:.6f} (elapsed: {after_transient - start:.6f}s)" ) - # Force completion of async JAX operations - first_node = next(iter(result.voltages)) - _ = float(result.voltages[first_node][0]) + # Results are numpy arrays (materialized by block_until_ready in full_mna) end = time.perf_counter() external_elapsed = end - start print( - f"TIMED_RUN_END: {end:.6f} (elapsed: {external_elapsed:.6f}s, sync took: {end - after_transient:.6f}s)" + f"TIMED_RUN_END: {end:.6f} (elapsed: {external_elapsed:.6f}s, extract took: {end - after_transient:.6f}s)" ) # Use wall_time from stats (excludes trace saving overhead) diff --git a/scripts/nsys_profile_target.py b/scripts/nsys_profile_target.py index db64ac4d..c313fbf2 100644 --- a/scripts/nsys_profile_target.py +++ b/scripts/nsys_profile_target.py @@ -1,39 +1,49 @@ #!/usr/bin/env python3 -"""Target script for nsys-jax profiling - runs circuit simulation. - -This script is designed to be wrapped by nsys-jax: - nsys-jax -o profile.zip python scripts/nsys_profile_target.py [circuit] [timesteps] - -nsys-jax automatically handles: -- XLA_FLAGS configuration for HLO metadata dumping -- JAX_TRACEBACK_IN_LOCATIONS_LIMIT for stack traces -- JAX_ENABLE_COMPILATION_CACHE=false for metadata collection +"""Target script for nsys GPU profiling - runs circuit simulation. Usage: - python scripts/nsys_profile_target.py [circuit] [timesteps] + nsys profile -o profile uv run python scripts/nsys_profile_target.py ring 500 Arguments: - circuit: One of rc, graetz, mul, ring (default: ring) - timesteps: Number of timesteps to simulate (default: 50) + circuit: One of rc, graetz, mul, ring, c6288 (default: ring) + timesteps: Number of timesteps to simulate (default: 500) -Example: - nsys-jax -o /tmp/profile.zip python scripts/nsys_profile_target.py ring 100 +Use 500+ timesteps so JIT warmup overhead is <5% of total profile. 
""" import argparse +import ctypes +import logging import sys +import time from pathlib import Path import jax sys.path.insert(0, ".") -# Import vajax first to auto-configure precision based on backend -from vajax.analysis import CircuitEngine +# Enable INFO logging so solver selection messages are visible +logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s") + + +def timed(label): + """Context manager that prints elapsed time for a phase.""" + + class Timer: + def __enter__(self): + self.t0 = time.perf_counter() + print(f"[{label}] starting...", flush=True) + return self + + def __exit__(self, *exc): + dt = time.perf_counter() - self.t0 + print(f"[{label}] done in {dt:.2f}s", flush=True) + + return Timer() def main(): - parser = argparse.ArgumentParser(description="nsys-jax profiling target for VAJAX") + parser = argparse.ArgumentParser(description="nsys profiling target for VAJAX") parser.add_argument( "circuit", nargs="?", @@ -45,8 +55,8 @@ def main(): "timesteps", nargs="?", type=int, - default=50, - help="Number of timesteps to simulate (default: 50)", + default=500, + help="Number of timesteps to simulate (default: 500)", ) parser.add_argument( "--backend", @@ -61,10 +71,35 @@ def main(): ) args = parser.parse_args() - print(f"JAX backend: {jax.default_backend()}") - print(f"JAX devices: {jax.devices()}") + with timed("JAX init"): + print(f"JAX backend: {jax.default_backend()}") + print(f"JAX devices: {jax.devices()}") print(f"Circuit: {args.circuit}") print(f"Timesteps: {args.timesteps}") + + # Explicit solver availability check + print() + print("=== Solver Availability ===") + with timed("solver imports"): + try: + from spineax.cudss.dense_baspacho_solver import is_available + + print(" BaSpaCho dense import: OK") + print(f" BaSpaCho dense available: {is_available()}") + except ImportError as e: + print(f" BaSpaCho dense import: FAILED ({e})") + try: + from spineax.cudss.solver import CuDSSSolver # noqa: F401 + + print(" cuDSS sparse 
import: OK") + except ImportError as e: + print(f" cuDSS sparse import: FAILED ({e})") + try: + from spineax import baspacho_dense_solve as _mod + + print(f" baspacho_dense_solve C++ module: OK ({_mod})") + except ImportError as e: + print(f" baspacho_dense_solve C++ module: FAILED ({e})") print() # Find benchmark .sim file @@ -76,10 +111,15 @@ def main(): print(f"ERROR: Benchmark file not found: {sim_path}") sys.exit(1) + # Import vajax (auto-configures precision based on backend) + with timed("vajax import"): + from vajax.analysis import CircuitEngine + # Setup circuit using CircuitEngine - print(f"Setting up circuit from {sim_path}...") - engine = CircuitEngine(sim_path) - engine.parse() + with timed("circuit parse"): + print(f"Setting up circuit from {sim_path}...") + engine = CircuitEngine(sim_path) + engine.parse() print(f"Circuit size: {engine.num_nodes} nodes, {len(engine.devices)} devices") print() @@ -90,22 +130,38 @@ def main(): print() # Prepare (includes 1-step JIT warmup) - print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...") - engine.prepare( - t_stop=args.timesteps * dt, - dt=dt, - use_sparse=args.sparse, - ) + with timed("prepare + JIT warmup"): + print(f"Preparing ({args.timesteps} timesteps, includes JIT warmup)...") + engine.prepare( + t_stop=args.timesteps * dt, + dt=dt, + use_sparse=args.sparse, + ) print("Prepare complete") print() - # Profiled run - nsys-jax captures this automatically - print(f"Starting profiled run ({args.timesteps} timesteps)...") - result = engine.run_transient() + # NVTX range for nsys --capture-range=nvtx scoping + try: + _nvtx = ctypes.CDLL("libnvToolsExt.so") + _nvtx_push = _nvtx.nvtxRangePushA + _nvtx_push.argtypes = [ctypes.c_char_p] + _nvtx_pop = _nvtx.nvtxRangePop + except OSError: + _nvtx_push = _nvtx_pop = None - print() - print(f"Completed: {result.num_steps} timesteps") - print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s") + if _nvtx_push: + _nvtx_push(b"run_transient") + + with 
timed("transient simulation"): + print(f"Starting profiled run ({args.timesteps} timesteps)...", flush=True) + result = engine.run_transient() + + if _nvtx_pop: + _nvtx_pop() + + print(flush=True) + print(f"Completed: {result.num_steps} timesteps", flush=True) + print(f"Wall time: {result.stats.get('wall_time', 0):.3f}s", flush=True) if __name__ == "__main__": diff --git a/scripts/profile_nr_phases.py b/scripts/profile_nr_phases.py new file mode 100644 index 00000000..6100b4bb --- /dev/null +++ b/scripts/profile_nr_phases.py @@ -0,0 +1,257 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [] +# /// +"""Profile NR iteration phase breakdown using JAX profiling tools. + +Instruments the NR solver body with jax.named_scope annotations and +captures a Perfetto trace showing the time split between: + - build_system: device evaluation + Jacobian/residual assembly + - linear_solve: sparse or dense linear solve (J*delta = -f) + - convergence: residual/delta checks, step limiting, solution update + +Also uses jax.debug.callback timestamps for a quick text summary +(accurate on CPU since execution is synchronous). 
+ +Usage: + JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py ring + JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py c6288 + JAX_PLATFORMS=cpu uv run python scripts/profile_nr_phases.py c6288 --trace-dir /tmp/jax_trace +""" + +import argparse +import os +import time +from pathlib import Path + +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +import jax +import jax.numpy as jnp +import numpy as np + +# --------------------------------------------------------------------------- +# Phase timing via jax.debug.callback (CPU-accurate) +# --------------------------------------------------------------------------- + +phase_timings: list[dict] = [] +_phase_clock: dict[str, float] = {} + + +def _start_phase(phase_name_bytes): + """Record start time for a phase.""" + phase_name = ( + phase_name_bytes.tobytes().decode() + if hasattr(phase_name_bytes, "tobytes") + else str(phase_name_bytes) + ) + _phase_clock[phase_name] = time.perf_counter_ns() + + +def _end_phase(phase_name_bytes, iteration): + """Record end time for a phase.""" + phase_name = ( + phase_name_bytes.tobytes().decode() + if hasattr(phase_name_bytes, "tobytes") + else str(phase_name_bytes) + ) + start = _phase_clock.get(phase_name, 0) + elapsed_ns = time.perf_counter_ns() - start + phase_timings.append( + { + "phase": phase_name, + "iteration": int(iteration), + "elapsed_us": elapsed_ns / 1000, + } + ) + + +# --------------------------------------------------------------------------- +# Monkey-patch NR solver to add named scopes + timing callbacks +# --------------------------------------------------------------------------- + +import vajax.analysis.solver_factories as sf + +_original_make_nr = sf._make_nr_solver_common + + +def patched_make_nr_solver_common(*, build_system_jit, linear_solve_fn, enforce_noi_fn, **kwargs): + """Wrap build_system and linear_solve with named scopes and timing.""" + + def timed_build_system(*args): + with jax.named_scope("nr_build_system"): + return 
build_system_jit(*args) + + def timed_linear_solve(J_or_data, f): + with jax.named_scope("nr_linear_solve"): + return linear_solve_fn(J_or_data, f) + + def timed_enforce_noi(J_or_data, f): + with jax.named_scope("nr_enforce_noi"): + return enforce_noi_fn(J_or_data, f) + + return _original_make_nr( + build_system_jit=timed_build_system, + linear_solve_fn=timed_linear_solve, + enforce_noi_fn=timed_enforce_noi, + **kwargs, + ) + + +sf._make_nr_solver_common = patched_make_nr_solver_common + + +# --------------------------------------------------------------------------- +# Also add callback-based timing for text summary +# --------------------------------------------------------------------------- + +_original_make_nr2 = sf._make_nr_solver_common # This is now our patched version + + +def callback_timed_make_nr(*, build_system_jit, linear_solve_fn, enforce_noi_fn, **kwargs): + """Add jax.debug.callback timestamps around each phase.""" + + def timed_build_system(*args): + # Extract iteration from args (it's the last positional arg) + iteration = args[-1] if len(args) > 0 else jnp.array(0) + build_tag = jnp.array(list(b"build_system"), dtype=jnp.uint8) + jax.debug.callback(_start_phase, build_tag) + result = build_system_jit(*args) + jax.debug.callback(_end_phase, build_tag, iteration) + return result + + def timed_linear_solve(J_or_data, f): + solve_tag = jnp.array(list(b"linear_solve"), dtype=jnp.uint8) + jax.debug.callback(_start_phase, solve_tag) + result = linear_solve_fn(J_or_data, f) + jax.debug.callback(_end_phase, solve_tag, jnp.array(-1)) + return result + + return _original_make_nr2( + build_system_jit=timed_build_system, + linear_solve_fn=timed_linear_solve, + enforce_noi_fn=enforce_noi_fn, + **kwargs, + ) + + +sf._make_nr_solver_common = callback_timed_make_nr + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + import 
logging + + from vajax.analysis import CircuitEngine + from vajax.benchmarks.registry import get_benchmark + + parser = argparse.ArgumentParser(description="Profile NR phase breakdown") + parser.add_argument("benchmark", help="Benchmark name (e.g. ring, c6288)") + parser.add_argument( + "--trace-dir", + type=Path, + default=None, + help="Directory for Perfetto trace (default: /tmp/claude/_trace)", + ) + parser.add_argument("--t-stop", type=float, default=None, help="Override stop time") + parser.add_argument( + "--n-steps", type=int, default=10, help="Number of timesteps to profile (default: 10)" + ) + args = parser.parse_args() + + logging.getLogger("vajax").setLevel(logging.WARNING) + + info = get_benchmark(args.benchmark) + assert info is not None, f"Benchmark '{args.benchmark}' not found" + + engine = CircuitEngine(info.sim_path) + engine.parse() + + use_sparse = info.is_large + dt = info.dt + t_stop = args.t_stop or dt * args.n_steps + + print(f"Benchmark: {args.benchmark}") + print(f" Nodes: {engine.num_nodes}, Devices: {len(engine.devices)}") + print(f" Solver: {'sparse' if use_sparse else 'dense'}") + print(f" Profiling {args.n_steps} steps (t_stop={t_stop:.2e}s)") + + trace_dir = args.trace_dir or Path(f"/tmp/claude/{args.benchmark}_trace") + trace_dir.mkdir(parents=True, exist_ok=True) + + # Prepare (includes JIT warmup) + print("\nPreparing (JIT warmup)...") + engine.prepare(t_stop=t_stop, dt=dt, use_sparse=use_sparse) + + # Clear any timing from warmup + phase_timings.clear() + + # Run with Perfetto trace capture + print(f"Running with profiler trace -> {trace_dir}") + jax.profiler.start_trace(str(trace_dir)) + try: + result = engine.run_transient() + finally: + jax.profiler.stop_trace() + + convergence = result.stats.get("convergence_rate", 0) * 100 + print(f" Steps: {result.num_steps}, convergence: {convergence:.0f}%") + print(f" Trace saved to: {trace_dir}") + + # --- Analyze callback timings --- + if not phase_timings: + print("\nNo callback timings 
captured (expected inside lax.while_loop)") + print(f"View the Perfetto trace at: {trace_dir}") + print(" Open https://ui.perfetto.dev and load the trace file") + return + + print(f"\n{'=' * 60}") + print(f"NR Phase Timing Breakdown ({len(phase_timings)} measurements)") + print(f"{'=' * 60}") + + # Aggregate by phase + by_phase: dict[str, list[float]] = {} + for entry in phase_timings: + phase = entry["phase"] + if phase not in by_phase: + by_phase[phase] = [] + by_phase[phase].append(entry["elapsed_us"]) + + total_us = sum(sum(times) for times in by_phase.values()) + + print(f"\n{'Phase':<20} {'Count':>6} {'Total (ms)':>12} {'Mean (µs)':>12} {'%':>8}") + print(f"{'-' * 20} {'-' * 6} {'-' * 12} {'-' * 12} {'-' * 8}") + for phase, times in sorted(by_phase.items(), key=lambda x: -sum(x[1])): + total_ms = sum(times) / 1000 + mean_us = np.mean(times) + pct = sum(times) / total_us * 100 if total_us > 0 else 0 + print(f"{phase:<20} {len(times):>6} {total_ms:>12.2f} {mean_us:>12.1f} {pct:>7.1f}%") + + print(f"\n{'Total':.<20} {'':>6} {total_us / 1000:>12.2f} ms") + + # Per-NR-iteration breakdown (first few) + build_times = by_phase.get("build_system", []) + solve_times = by_phase.get("linear_solve", []) + + if build_times and solve_times: + n_show = min(10, len(build_times)) + print(f"\nPer-iteration breakdown (first {n_show}):") + print(f"{'Iter':>4} {'Build (µs)':>12} {'Solve (µs)':>12} {'Solve %':>8}") + print(f"{'-' * 4} {'-' * 12} {'-' * 12} {'-' * 8}") + for i in range(n_show): + b = build_times[i] + s = solve_times[i] if i < len(solve_times) else 0 + total = b + s + spct = s / total * 100 if total > 0 else 0 + print(f"{i:>4} {b:>12.1f} {s:>12.1f} {spct:>7.1f}%") + + print(f"\nPerfetto trace: {trace_dir}") + print(" Open https://ui.perfetto.dev and load the .pb or .json.gz file") + + +if __name__ == "__main__": + main() diff --git a/scripts/sweep_xla_flags.py b/scripts/sweep_xla_flags.py new file mode 100644 index 00000000..0175c12e --- /dev/null +++ 
b/scripts/sweep_xla_flags.py @@ -0,0 +1,402 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.10" +# dependencies = ["jax"] +# /// +"""Sweep XLA flag combinations to find optimal CUDA performance. + +Runs a benchmark circuit with different XLA flag configurations and +reports timing for each. Each configuration runs in a separate subprocess +to ensure clean XLA state. + +Usage: + # Run on GPU (auto-detects CUDA) + uv run scripts/sweep_xla_flags.py + + # Specific benchmark + uv run scripts/sweep_xla_flags.py --benchmark ring + + # Specific configurations only + uv run scripts/sweep_xla_flags.py --configs baseline,autotune2,command_buffer + + # Also include large circuit (needs sparse solver) + uv run scripts/sweep_xla_flags.py --benchmark ring,c6288 --include-sparse +""" + +import argparse +import json +import os +import subprocess +import sys +import time +from pathlib import Path + +# XLA flag configurations to test +FLAG_CONFIGS = { + "baseline": { + "description": "Current CI config (autotune=0)", + "xla_flags": "--xla_gpu_autotune_level=0", + "env": {}, + }, + "autotune2": { + "description": "Autotune level 2 (enables cuBLAS algorithm selection)", + "xla_flags": "--xla_gpu_autotune_level=2", + "env": {}, + }, + "autotune4": { + "description": "Autotune level 4 (full autotuning)", + "xla_flags": "--xla_gpu_autotune_level=4", + "env": {}, + }, + "command_buffer": { + "description": "Command buffers enabled (batch kernel launches)", + "xla_flags": "--xla_gpu_autotune_level=0", + "env": {}, + }, + "double_buffer": { + "description": "While-loop double buffering", + "xla_flags": ( + "--xla_gpu_autotune_level=0 --xla_gpu_enable_while_loop_double_buffering=true" + ), + "env": {}, + }, + "pgle": { + "description": "Profile-guided latency estimation (3 profiling runs)", + "xla_flags": "--xla_gpu_autotune_level=0", + "env": { + "JAX_ENABLE_PGLE": "true", + "JAX_PGLE_PROFILING_RUNS": "3", + }, + }, + "combined_safe": { + "description": 
"Autotune 2 + double buffering", + "xla_flags": ( + "--xla_gpu_autotune_level=2 --xla_gpu_enable_while_loop_double_buffering=true" + ), + "env": {}, + }, + "combined_aggressive": { + "description": "Autotune 4 + double buffering + PGLE", + "xla_flags": ( + "--xla_gpu_autotune_level=4 --xla_gpu_enable_while_loop_double_buffering=true" + ), + "env": { + "JAX_ENABLE_PGLE": "true", + "JAX_PGLE_PROFILING_RUNS": "3", + }, + }, +} + +# The subprocess script that runs a single benchmark +BENCHMARK_RUNNER = """ +import os +import sys +import time +import json + +sys.path.insert(0, os.environ["PROJECT_ROOT"]) + +# Memory config +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform") + +import jax +import jax.numpy as jnp + +from vajax.analysis import CircuitEngine + +benchmark_name = os.environ["BENCHMARK_NAME"] +use_sparse = os.environ.get("USE_SPARSE", "0") == "1" +use_scan = True +force_gpu = os.environ.get("FORCE_GPU", "0") == "1" +n_warmup = int(os.environ.get("N_WARMUP", "1")) +n_runs = int(os.environ.get("N_RUNS", "3")) + +from scripts.benchmark_utils import get_vacask_benchmarks + +benchmarks = get_vacask_benchmarks([benchmark_name]) +if not benchmarks: + print(json.dumps({"error": f"Benchmark {benchmark_name} not found"})) + sys.exit(1) + +name, sim_path = benchmarks[0] + +# Report JAX config +devices = jax.devices() +backend = devices[0].platform if devices else "unknown" +print(f"JAX backend: {backend}, devices: {[d.platform for d in devices]}", file=sys.stderr) +print(f"XLA_FLAGS: {os.environ.get('XLA_FLAGS', '(not set)')}", file=sys.stderr) + +engine = CircuitEngine.from_sim_file(str(sim_path)) +engine.prepare(use_sparse=use_sparse, force_gpu=force_gpu, use_scan=use_scan) + +# Get step count from sim parameters +dt = engine.sim_params.get("dt", 1e-6) +t_stop = engine.sim_params.get("tstop", engine.sim_params.get("t_stop", 1e-3)) +n_steps = int(t_stop / dt) if dt > 0 else 100 + +timings = 
[] + +for run_idx in range(n_warmup + n_runs): + # Re-prepare to reset state + if run_idx > 0: + engine.prepare(use_sparse=use_sparse, force_gpu=force_gpu, use_scan=use_scan) + + start = time.perf_counter() + result = engine.run_transient() + # Block until computation complete + if hasattr(result, 'voltages') and result.voltages is not None: + jax.block_until_ready(result.voltages) + elapsed = time.perf_counter() - start + + actual_steps = result.n_steps if hasattr(result, 'n_steps') else n_steps + ms_per_step = (elapsed * 1000.0) / max(actual_steps, 1) + + label = "warmup" if run_idx < n_warmup else f"run {run_idx - n_warmup}" + print(f" {label}: {elapsed:.3f}s ({actual_steps} steps, {ms_per_step:.3f} ms/step)", file=sys.stderr) + + if run_idx >= n_warmup: + timings.append({ + "elapsed_s": elapsed, + "n_steps": actual_steps, + "ms_per_step": ms_per_step, + }) + +# Report median timing +timings.sort(key=lambda t: t["ms_per_step"]) +median = timings[len(timings) // 2] + +print(json.dumps({ + "benchmark": benchmark_name, + "backend": backend, + "n_steps": median["n_steps"], + "ms_per_step": median["ms_per_step"], + "elapsed_s": median["elapsed_s"], + "n_runs": n_runs, + "all_timings": [t["ms_per_step"] for t in timings], +})) +""" + + +def run_config( + config_name: str, + config: dict, + benchmark: str, + project_root: Path, + use_sparse: bool, + force_gpu: bool, + n_warmup: int = 1, + n_runs: int = 3, +) -> dict: + """Run a single benchmark with a specific XLA flag configuration.""" + env = os.environ.copy() + env["PROJECT_ROOT"] = str(project_root) + env["BENCHMARK_NAME"] = benchmark + env["USE_SPARSE"] = "1" if use_sparse else "0" + env["FORCE_GPU"] = "1" if force_gpu else "0" + env["N_WARMUP"] = str(n_warmup) + env["N_RUNS"] = str(n_runs) + env["JAX_PLATFORMS"] = "cuda,cpu" if force_gpu else "cpu" + env["JAX_ENABLE_X64"] = "1" + + # Set XLA flags + env["XLA_FLAGS"] = config["xla_flags"] + + # Set additional env vars + for k, v in config.get("env", {}).items(): + 
env[k] = v + + # Memory allocation + env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") + env.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform") + + print(f"\n{'=' * 60}") + print(f"Config: {config_name} — {config['description']}") + print(f" XLA_FLAGS: {config['xla_flags']}") + if config.get("env"): + print(f" Extra env: {config['env']}") + print(f" Benchmark: {benchmark}") + print(f"{'=' * 60}") + + start = time.perf_counter() + try: + result = subprocess.run( + [sys.executable, "-c", BENCHMARK_RUNNER], + env=env, + capture_output=True, + text=True, + timeout=600, # 10 min max per config + cwd=str(project_root), + ) + except subprocess.TimeoutExpired: + return { + "config": config_name, + "benchmark": benchmark, + "error": "timeout (600s)", + "wall_time_s": time.perf_counter() - start, + } + + wall_time = time.perf_counter() - start + + # Print stderr (progress messages) + if result.stderr: + for line in result.stderr.strip().split("\n"): + print(f" {line}") + + # Parse JSON from last line of stdout + if result.returncode != 0: + print(f" ERROR: exit code {result.returncode}") + if result.stderr: + print(f" {result.stderr[-500:]}") + return { + "config": config_name, + "benchmark": benchmark, + "error": f"exit code {result.returncode}", + "wall_time_s": wall_time, + } + + try: + # Find the JSON line (last non-empty line of stdout) + lines = [l for l in result.stdout.strip().split("\n") if l.strip()] + data = json.loads(lines[-1]) + data["config"] = config_name + data["wall_time_s"] = wall_time + data["description"] = config["description"] + print(f" Result: {data['ms_per_step']:.3f} ms/step ({data['n_steps']} steps)") + return data + except (json.JSONDecodeError, IndexError) as e: + print(f" ERROR parsing output: {e}") + print(f" stdout: {result.stdout[-500:]}") + return { + "config": config_name, + "benchmark": benchmark, + "error": f"parse error: {e}", + "wall_time_s": wall_time, + } + + +def main(): + parser = 
argparse.ArgumentParser(description="Sweep XLA flag combinations") + parser.add_argument( + "--benchmark", + default="ring", + help="Comma-separated benchmark names (default: ring)", + ) + parser.add_argument( + "--configs", + default=None, + help="Comma-separated config names (default: all)", + ) + parser.add_argument( + "--include-sparse", + action="store_true", + help="Include sparse solver for large circuits", + ) + parser.add_argument( + "--force-gpu", + action="store_true", + default=True, + help="Force GPU backend (default: True)", + ) + parser.add_argument( + "--cpu-only", + action="store_true", + help="Run on CPU instead of GPU", + ) + parser.add_argument( + "--n-warmup", + type=int, + default=1, + help="Number of warmup runs (default: 1)", + ) + parser.add_argument( + "--n-runs", + type=int, + default=3, + help="Number of timed runs (default: 3)", + ) + parser.add_argument( + "--json-output", + default=None, + help="Path to write JSON results", + ) + args = parser.parse_args() + + project_root = Path(__file__).parent.parent + benchmarks = [b.strip() for b in args.benchmark.split(",")] + force_gpu = not args.cpu_only + + if args.configs: + config_names = [c.strip() for c in args.configs.split(",")] + configs = {k: FLAG_CONFIGS[k] for k in config_names if k in FLAG_CONFIGS} + else: + configs = FLAG_CONFIGS + + all_results = [] + + for benchmark in benchmarks: + # Determine if sparse needed + use_sparse = args.include_sparse and benchmark in ("c6288", "mul64") + + for config_name, config in configs.items(): + result = run_config( + config_name, + config, + benchmark, + project_root, + use_sparse=use_sparse, + force_gpu=force_gpu, + n_warmup=args.n_warmup, + n_runs=args.n_runs, + ) + all_results.append(result) + + # Print summary table + print(f"\n{'=' * 80}") + print("SUMMARY") + print(f"{'=' * 80}") + print( + f"{'Config':<25} {'Benchmark':<10} {'ms/step':>10} {'Steps':>8} {'Wall(s)':>10} {'vs base':>10}" + ) + print("-" * 80) + + # Group by benchmark for 
relative comparison + by_benchmark = {} + for r in all_results: + bm = r.get("benchmark", "?") + by_benchmark.setdefault(bm, []).append(r) + + for bm, results in by_benchmark.items(): + baseline_ms = None + for r in results: + if r.get("config") == "baseline" and "ms_per_step" in r: + baseline_ms = r["ms_per_step"] + break + + for r in results: + ms = r.get("ms_per_step", None) + steps = r.get("n_steps", "?") + wall = r.get("wall_time_s", 0) + config = r.get("config", "?") + err = r.get("error", None) + + if err: + print(f"{config:<25} {bm:<10} {'ERROR':>10} {'':>8} {wall:>10.1f} {err}") + elif ms is not None: + ratio_str = "" + if baseline_ms and baseline_ms > 0: + ratio = ms / baseline_ms + ratio_str = f"{ratio:.2f}x" + print(f"{config:<25} {bm:<10} {ms:>10.3f} {steps:>8} {wall:>10.1f} {ratio_str:>10}") + + # Save JSON + if args.json_output: + out_path = Path(args.json_output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nResults saved to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_vacask_suite.py b/tests/test_vacask_suite.py index c2538d13..37ab5804 100644 --- a/tests/test_vacask_suite.py +++ b/tests/test_vacask_suite.py @@ -22,12 +22,12 @@ from vajax.netlist.parser import parse_netlist -# Paths - VACASK is at ../VACASK relative to vajax +# Paths - VACASK is vendored at vendor/VACASK VAJAX_ROOT = Path(__file__).parent.parent -VACASK_ROOT = VAJAX_ROOT.parent / "VACASK" +VACASK_ROOT = VAJAX_ROOT / "vendor" / "VACASK" VACASK_TEST = VACASK_ROOT / "test" VACASK_DEVICES = VACASK_ROOT / "devices" -VACASK_BENCHMARK = VAJAX_ROOT / "vendor" / "VACASK" / "benchmark" +VACASK_BENCHMARK = VACASK_ROOT / "benchmark" def discover_benchmark_dirs() -> List[Path]: diff --git a/vajax/analysis/engine.py b/vajax/analysis/engine.py index 3e9fb54d..d2bfe180 100644 --- a/vajax/analysis/engine.py +++ b/vajax/analysis/engine.py @@ -1015,13 +1015,244 @@ def 
run_transient(self) -> TransientResult: # Extract sliced numpy results for TransientResult times_np, voltages, currents = extract_results(times_full, V_out, stats) + # Keep as numpy arrays — avoids creating dynamically-sized JAX arrays + # that trigger jit(dynamic_slice) recompilation on CUDA when n_steps + # varies between runs (the shape gets baked into the XLA kernel). return TransientResult( - times=jnp.asarray(times_np), - voltages={k: jnp.asarray(v) for k, v in voltages.items()}, - currents={k: jnp.asarray(v) for k, v in currents.items()}, + times=times_np, + voltages=voltages, + currents=currents, stats=stats, ) + def dump_jaxpr(self, output_dir: str | Path = "/tmp/claude/jaxpr-analysis") -> Path: + """Dump JAX IR analysis for the compiled simulation functions. + + Analyzes two hot-path functions after prepare(): + 1. build_system - Jacobian + residual assembly (per NR iteration) + 2. nr_solve - Full Newton-Raphson solve (per timestep) + + Uses the actual circuit's device_arrays and dimensions from the prepared + strategy, matching the calling conventions used during simulation. + + For each function, writes: + - HLO text (StableHLO MLIR representation) + - HLO operation counts (top ops by frequency) + - XLA cost analysis (flops, memory bytes) + + Args: + output_dir: Directory for output files. + + Returns: + Path to the output directory. 
+ """ + if not getattr(self, "_prepared", False): + self.prepare() + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + strategy = self._prepared_strategy + + # Get cached functions and actual circuit data from the strategy + build_system_jit = getattr(strategy, "_cached_build_system_jit", None) + nr_solve = getattr(strategy, "_cached_full_mna_solver", None) + device_arrays = getattr(strategy, "_device_arrays_full_mna", None) + total_limit_states = getattr(strategy, "_total_limit_states", 0) + + if build_system_jit is None or nr_solve is None or device_arrays is None: + raise RuntimeError("No cached solver found. Call prepare() first.") + + # Get dimensions from transient setup cache + setup_cache = self._transient_setup_cache + if setup_cache is None: + raise RuntimeError("No transient setup cache. Call prepare() first.") + n_total = setup_cache["n_total"] + n_unknowns = setup_cache["n_unknowns"] + n_vsources = len([d for d in self.devices if d["model"] == "vsource"]) + n_isources = len([d for d in self.devices if d["model"] == "isource"]) + + logger.info( + f"dump_jaxpr: n_total={n_total}, n_unknowns={n_unknowns}, " + f"n_vsources={n_vsources}, limit_states={total_limit_states}" + ) + + # Use the actual circuit's device_arrays and correct-shaped args. + # Values don't matter for HLO tracing — only shapes and dtypes. + # Match the calling conventions from full_mna.py DC solve path. 
+ dtype = get_float_dtype() + X = jnp.zeros(n_total + n_vsources, dtype=dtype) + vsource_vals = jnp.zeros(n_vsources, dtype=dtype) + isource_vals = jnp.zeros(max(n_isources, 1), dtype=dtype) + Q_prev = jnp.zeros(n_unknowns, dtype=dtype) + integ_c0 = jnp.asarray(0.0, dtype=dtype) + gmin = jnp.asarray(1e-12, dtype=dtype) + gshunt = jnp.asarray(0.0, dtype=dtype) + integ_c1 = jnp.asarray(0.0, dtype=dtype) + integ_d1 = jnp.asarray(0.0, dtype=dtype) + dQdt_prev = jnp.zeros(n_unknowns, dtype=dtype) + integ_c2 = jnp.asarray(0.0, dtype=dtype) + Q_prev2 = jnp.zeros(n_unknowns, dtype=dtype) + limit_state = jnp.zeros(total_limit_states, dtype=dtype) + nr_iter = jnp.asarray(1, dtype=jnp.int32) + + # build_system_jit signature: (X, vsource_vals, isource_vals, Q_prev, + # integ_c0, device_arrays, gmin, gshunt, integ_c1, integ_d1, + # dQdt_prev, integ_c2, Q_prev2, limit_state, nr_iter) + build_args = ( + X, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays, + gmin, + gshunt, + integ_c1, + integ_d1, + dQdt_prev, + integ_c2, + Q_prev2, + limit_state, + nr_iter, + ) + + # nr_solve signature: (X_init, vsource_vals, isource_vals, Q_prev, + # integ_c0, device_arrays, gmin, gshunt, integ_c1, integ_d1, + # dQdt_prev, integ_c2, Q_prev2, limit_state_in) + # Uses None defaults for optional args, matching DC solve convention. + nr_args = ( + X, + vsource_vals, + isource_vals, + Q_prev, + integ_c0, + device_arrays, + gmin, + gshunt, + integ_c1, + integ_d1, + None, + integ_c2, + None, + None, + ) + + results = {} + for name, fn, args in [ + ("build_system", build_system_jit, build_args), + ("nr_solve", nr_solve, nr_args), + ]: + results[name] = self._dump_single_jaxpr(name, fn, args, output_dir) + + return output_dir + + @staticmethod + def _dump_single_jaxpr(name: str, fn, args, output_dir: Path) -> dict[str, Any]: + """Analyze and dump jaxpr/HLO for a single function. 
+ + Produces three artifacts per function: + - {name}.jaxpr.txt — JAX's high-level IR (from jax.make_jaxpr) + - {name}.hlo.txt — StableHLO/MLIR after lowering + - Log output with op counts and XLA cost analysis + + Returns: + Dict with 'jaxpr', 'hlo_lines', 'op_counts', 'cost' keys. + """ + result: dict[str, Any] = {"name": name} + logger.info(f"Analyzing {name}...") + + try: + # 1. Jaxpr (high-level JAX IR) + jaxpr = jax.make_jaxpr(fn)(*args) + jaxpr_text = str(jaxpr) + jaxpr_lines = jaxpr_text.split("\n") + result["jaxpr"] = jaxpr + result["jaxpr_lines"] = len(jaxpr_lines) + logger.info(f" {name}: {len(jaxpr_lines)} jaxpr lines") + + jaxpr_file = output_dir / f"{name}.jaxpr.txt" + with open(jaxpr_file, "w") as f: + f.write(jaxpr_text) + result["jaxpr_file"] = jaxpr_file + logger.info(f" Saved jaxpr: {jaxpr_file}") + + # Count jaxpr primitives + jaxpr_ops: dict[str, int] = {} + for eqn in jaxpr.jaxpr.eqns: + prim_name = str(eqn.primitive.name) + jaxpr_ops[prim_name] = jaxpr_ops.get(prim_name, 0) + 1 + result["jaxpr_ops"] = dict(sorted(jaxpr_ops.items(), key=lambda x: -x[1])) + if jaxpr_ops: + sorted_jaxpr_ops = sorted(jaxpr_ops.items(), key=lambda x: -x[1]) + logger.info(f" {name}: Top jaxpr primitives:") + for op, count in sorted_jaxpr_ops[:20]: + logger.info(f" {op:45s} {count:6d}") + logger.info(f" Total jaxpr eqns: {len(jaxpr.jaxpr.eqns)}") + + # 2. HLO (lowered StableHLO) + if hasattr(fn, "lower"): + lowered = fn.lower(*args) + else: + lowered = jax.jit(fn).lower(*args) + + hlo_text = lowered.as_text() + hlo_lines = hlo_text.split("\n") + result["hlo_lines"] = len(hlo_lines) + logger.info(f" {name}: {len(hlo_lines)} HLO lines") + + # Count HLO operations + op_counts: dict[str, int] = {} + for line in hlo_lines: + if "=" in line and "." in line: + parts = line.split("=") + if len(parts) >= 2: + op_part = parts[1].strip().split()[0] if parts[1].strip() else "" + if "." 
in op_part: + op_name = op_part.split("(")[0] + op_counts[op_name] = op_counts.get(op_name, 0) + 1 + + result["op_counts"] = dict(sorted(op_counts.items(), key=lambda x: -x[1])) + if op_counts: + sorted_ops = sorted(op_counts.items(), key=lambda x: -x[1]) + logger.info(f" {name}: Top HLO ops:") + for op, count in sorted_ops[:20]: + logger.info(f" {op:45s} {count:6d}") + logger.info(f" Total unique op types: {len(op_counts)}") + logger.info(f" Total ops: {sum(op_counts.values())}") + + # 3. Cost analysis + compiled = lowered.compile() + cost = compiled.cost_analysis() + result["cost"] = cost + if cost: + for i, device_cost in enumerate(cost): + if device_cost and isinstance(device_cost, dict): + logger.info(f" {name} cost (device {i}):") + for key, val in sorted(device_cost.items()): + if isinstance(val, (int, float)): + if val > 1e9: + logger.info(f" {key}: {val / 1e9:.2f}G") + elif val > 1e6: + logger.info(f" {key}: {val / 1e6:.2f}M") + elif val > 1e3: + logger.info(f" {key}: {val / 1e3:.2f}K") + else: + logger.info(f" {key}: {val:.2f}") + + # Save HLO + hlo_file = output_dir / f"{name}.hlo.txt" + with open(hlo_file, "w") as f: + f.write(hlo_text) + result["hlo_file"] = hlo_file + logger.info(f" Saved HLO: {hlo_file}") + + except Exception as e: + logger.error(f"Failed to analyze {name}: {e}", exc_info=True) + result["error"] = str(e) + + return result + # ========================================================================= # Node Collapse Implementation # ========================================================================= diff --git a/vajax/analysis/openvaf_models.py b/vajax/analysis/openvaf_models.py index 999b1dfa..27aa1183 100644 --- a/vajax/analysis/openvaf_models.py +++ b/vajax/analysis/openvaf_models.py @@ -915,6 +915,21 @@ def prepare_static_inputs( shared_cache_indices = [] varying_cache_indices = [] + # Split cache arrays + shared_cache = cache[0, shared_cache_indices] + device_cache = cache[:, varying_cache_indices] + + # SCCP dead-block 
elimination is DISABLED for the unified eval function. + # While SCCP can eliminate ~695/954 blocks for PSP103, the benefit is + # marginal (same code size, same XLA ops — XLA already CSEs branches). + # The cost is high: SCCP changes the JIT function hash, invalidating + # the persistent XLA compilation cache and adding ~99s cold-compile + # penalty for ring (49.5s × 2 compilations vs 0.58s cache hit). + # SCCP would be valuable for config-group-specialized eval functions + # where each group has a unique TYPE value, but that's deferred. + # Infrastructure is preserved in build_sccp_known_values() for reuse. + sccp_known_values = None + # Generate eval function with cache split from vajax.analysis.limiting import fetlim, pnjlim @@ -930,6 +945,7 @@ def prepare_static_inputs( varying_cache_indices, use_limit_functions=use_device_limiting, limit_param_map=limit_param_map, + sccp_known_values=sccp_known_values, ) # Safety check: if limiting is enabled but lim_rhs could not be computed # (model uses inline limiting without $limit/BuiltinLimit calls), disable @@ -950,15 +966,12 @@ def prepare_static_inputs( varying_cache_indices, use_limit_functions=False, limit_param_map=limit_param_map, + sccp_known_values=sccp_known_values, ) split_fn = partial(split_fn, limit_funcs=limit_funcs) vmapped_split_fn = jax.jit(jax.vmap(split_fn, in_axes=(None, 0, None, 0, None, 0))) - # Split cache arrays - shared_cache = cache[0, shared_cache_indices] - device_cache = cache[:, varying_cache_indices] - # Build default simparams from model metadata simparams_used = split_meta.get("simparams_used", ["$analysis_type", "$mfactor", "gmin"]) simparam_count = split_meta.get("simparam_count", len(simparams_used)) diff --git a/vajax/analysis/solver_factories.py b/vajax/analysis/solver_factories.py index 378e7c94..da085653 100644 --- a/vajax/analysis/solver_factories.py +++ b/vajax/analysis/solver_factories.py @@ -510,9 +510,9 @@ def enforce_noi(J, f): def linear_solve(J, f): """Solve J @ delta = 
-f using dense direct solver.""" - # Add Tikhonov regularization for numerical stability on GPU - reg = 1e-14 * jnp.eye(J.shape[0], dtype=J.dtype) - return jax.scipy.linalg.solve(J + reg, -f) + # Diagonal regularization (gmin + gshunt) is already applied during + # Jacobian assembly in assemble_dense_jacobian / _build_system_dense_direct. + return jax.scipy.linalg.solve(J, -f) logger.info( f"Creating dense full MNA solver: V({n_nodes}) + I({n_vsources}), " @@ -536,6 +536,78 @@ def linear_solve(J, f): ) +def make_baspacho_dense_full_mna_solver( + build_system_jit: Callable, + n_nodes: int, + n_vsources: int, + noi_indices: Optional[Array] = None, + internal_device_indices: Optional[Array] = None, + max_iterations: int = 100, + abstol: float = 1e-12, + total_limit_states: int = 0, + options: Optional["SimulationOptions"] = None, + max_step: float = 1e30, +) -> Callable: + """Create a dense NR solver using BaSpaCho LU on CUDA. + + Uses BaSpaCho's supernodal LU factorization with CUDA backend, + replacing jax.scipy.linalg.solve (cuSOLVER getrf) on GPU. Benefits: + - Symbolic analysis done once (cached across NR iterations) + - Grow-only GPU memory allocation (no per-call cudaMalloc after warmup) + - Foundation for Phase 2b graph-capture compatibility + + Falls back to standard dense solver if BaSpaCho is unavailable. + + Args: + Same as make_dense_full_mna_solver. 
+ """ + from spineax.cudss.dense_baspacho_solver import baspacho_dense_solve + + masks = _compute_noi_masks( + noi_indices, n_nodes, internal_device_indices=internal_device_indices + ) + noi_res_idx = masks["noi_res_idx"] + + residual_mask = _build_augmented_mask(masks["residual_mask"], n_vsources) + residual_conv_mask = _build_augmented_conv_mask( + masks["residual_conv_mask"], residual_mask, n_vsources + ) + + def enforce_noi(J, f): + """Enforce NOI constraints on dense Jacobian.""" + if noi_res_idx is not None: + J = J.at[noi_res_idx, :].set(0.0) + J = J.at[:, noi_res_idx].set(0.0) + J = J.at[noi_res_idx, noi_res_idx].set(1.0) + f = f.at[noi_res_idx].set(0.0) + return J, f + + def linear_solve(J, f): + """Solve J @ delta = -f using BaSpaCho LU on CUDA.""" + return baspacho_dense_solve(J, -f) + + logger.info( + f"Creating BaSpaCho dense full MNA solver: V({n_nodes}) + I({n_vsources}), " + f"NOI: {noi_indices is not None}" + ) + return _make_nr_solver_common( + build_system_jit=build_system_jit, + n_nodes=n_nodes, + n_vsources=n_vsources, + linear_solve_fn=linear_solve, + enforce_noi_fn=enforce_noi, + noi_indices=noi_indices, + internal_device_indices=internal_device_indices, + max_iterations=max_iterations, + abstol=abstol, + total_limit_states=total_limit_states, + options=options, + max_step=max_step, + residual_mask=residual_mask, + residual_conv_mask=residual_conv_mask, + ) + + def make_spineax_full_mna_solver( build_system_jit: Callable, n_nodes: int, diff --git a/vajax/analysis/transient/full_mna.py b/vajax/analysis/transient/full_mna.py index 80d58dab..48deb7cb 100644 --- a/vajax/analysis/transient/full_mna.py +++ b/vajax/analysis/transient/full_mna.py @@ -36,6 +36,7 @@ from vajax._logging import logger from vajax.analysis.solver_factories import ( + make_baspacho_dense_full_mna_solver, make_dense_full_mna_solver, make_spineax_full_mna_solver, make_umfpack_ffi_full_mna_solver, @@ -53,6 +54,16 @@ def is_spineax_available() -> bool: return False +def 
is_baspacho_dense_available() -> bool: + """Check if BaSpaCho dense CUDA solver is available.""" + try: + from spineax.cudss.dense_baspacho_solver import is_available + + return is_available() + except ImportError: + return False + + from vajax.analysis.integration import IntegrationMethod from .adaptive import AdaptiveConfig, compute_lte_timestep_jax, predict_voltage_jax @@ -385,17 +396,43 @@ def _ensure_full_mna_solver(self, setup: TransientSetup) -> Callable: self._total_limit_states = total_limit_states build_system_jit = jax.jit(build_system_fn) - nr_solve = make_dense_full_mna_solver( - build_system_jit, - n_nodes, - n_vsources, - noi_indices=noi_indices, - internal_device_indices=internal_device_indices, - max_iterations=self.runner.options.tran_itl, - abstol=self.runner.options.abstol, - total_limit_states=total_limit_states, - options=self.runner.options, - ) + # On CUDA, try BaSpaCho dense solver (pre-allocated workspace, + # foundation for graph-capture compatibility in Phase 2b). 
+ on_cuda_dense = jax.default_backend() in ("cuda", "gpu") + if on_cuda_dense and is_baspacho_dense_available(): + try: + logger.info("Using BaSpaCho dense solver (GPU, CUDA backend)") + nr_solve = make_baspacho_dense_full_mna_solver( + build_system_jit, + n_nodes, + n_vsources, + noi_indices=noi_indices, + internal_device_indices=internal_device_indices, + max_iterations=self.runner.options.tran_itl, + abstol=self.runner.options.abstol, + total_limit_states=total_limit_states, + options=self.runner.options, + ) + except Exception as e: + logger.warning( + f"BaSpaCho dense solver failed ({e}), falling back to JAX dense solver" + ) + nr_solve = None + else: + nr_solve = None + + if nr_solve is None: + nr_solve = make_dense_full_mna_solver( + build_system_jit, + n_nodes, + n_vsources, + noi_indices=noi_indices, + internal_device_indices=internal_device_indices, + max_iterations=self.runner.options.tran_itl, + abstol=self.runner.options.abstol, + total_limit_states=total_limit_states, + options=self.runner.options, + ) else: # Sparse path: use CSR direct stamping to eliminate COO intermediates n_augmented = setup.n_unknowns + n_vsources @@ -1767,21 +1804,18 @@ def _debug_step_callback( # Compute the voltage to record - use new_X which is the actual solution we're using # (either X_new if converged, or previous X if NR failed at min_dt) V_to_record = new_X[:n_external] - new_times_out = jnp.where( - accept_step, state.times_out.at[state.step_idx].set(t_next), state.times_out - ) - new_V_out = jnp.where( - accept_step, state.V_out.at[state.step_idx].set(V_to_record), state.V_out - ) - # For currents, use zero if NR failed at min_dt (current from bad solution is unreliable) + # Write unconditionally: on rejection step_idx doesn't advance, so stale + # values at step_idx get overwritten by the next accepted step. The caller + # trims output using step_idx, so values beyond it are ignored. This avoids + # materializing both branches of jnp.where on the full output arrays. 
I_to_record = jnp.where( nr_failed_at_min_dt, jnp.zeros(n_vsources, dtype=dtype) if n_vsources > 0 else jnp.zeros(1, dtype=dtype), I_vsource[:n_vsources] if n_vsources > 0 else jnp.zeros(1, dtype=dtype), ) - new_I_out = jnp.where( - accept_step, state.I_out.at[state.step_idx].set(I_to_record), state.I_out - ) + new_times_out = state.times_out.at[state.step_idx].set(t_next) + new_V_out = state.V_out.at[state.step_idx].set(V_to_record) + new_I_out = state.I_out.at[state.step_idx].set(I_to_record) new_step_idx = jnp.where(accept_step, state.step_idx + 1, state.step_idx) # Statistics